# glam/scripts/sync/migrate_to_multi_embedding.py
# Snapshot: 2025-12-14 17:09:55 +01:00 — 506 lines, 17 KiB, Python
#!/usr/bin/env python3
"""
Migration Script: Add Multi-Embedding Support to Existing Qdrant Collections
This script migrates existing single-vector Qdrant collections to support
multiple embedding models using Qdrant's named vectors feature.
The migration process:
1. Create a new collection with named vectors for each embedding model
2. Copy existing data from old collection
3. Generate embeddings for new models (local sentence-transformers)
4. Rename collections (old -> backup, new -> original name)
Supported Embedding Models:
- minilm_384: all-MiniLM-L6-v2 (384-dim, free/local)
- openai_1536: text-embedding-3-small (1536-dim, API)
- bge_768: bge-base-en-v1.5 (768-dim, free/local)
Usage:
# Dry run to see what would be migrated
python -m scripts.sync.migrate_to_multi_embedding --dry-run
# Migrate heritage_persons collection
python -m scripts.sync.migrate_to_multi_embedding --collection heritage_persons
# Migrate with specific models
python -m scripts.sync.migrate_to_multi_embedding --models minilm_384,openai_1536
# Use production Qdrant
QDRANT_USE_PRODUCTION=true python -m scripts.sync.migrate_to_multi_embedding
"""
import argparse
import logging
import os
import sys
from datetime import datetime
from pathlib import Path
from typing import Any
# Add project root to path
PROJECT_ROOT = Path(__file__).parent.parent.parent
sys.path.insert(0, str(PROJECT_ROOT / "src"))
# Root logging configuration for the whole migration run.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

# Configuration
DEFAULT_COLLECTIONS = ["heritage_persons", "heritage_custodians"]  # migrated when --collection is omitted
BATCH_SIZE = 100  # points per scroll/upsert batch
def get_qdrant_client(use_production: bool = False):
    """Return a QdrantClient connected to local or production Qdrant.

    Production uses HTTPS on port 443 with host/prefix taken from
    QDRANT_PROD_HOST / QDRANT_PROD_PREFIX; local uses plain HTTP with
    QDRANT_HOST / QDRANT_PORT (defaults: localhost:6333).
    """
    from qdrant_client import QdrantClient

    if not use_production:
        host = os.getenv("QDRANT_HOST", "localhost")
        port = int(os.getenv("QDRANT_PORT", "6333"))
        client = QdrantClient(host=host, port=port)
        logger.info(f"Connected to local Qdrant: {host}:{port}")
        return client

    host = os.getenv("QDRANT_PROD_HOST", "bronhouder.nl")
    prefix = os.getenv("QDRANT_PROD_PREFIX", "qdrant")
    client = QdrantClient(
        host=host,
        port=443,
        https=True,
        prefix=prefix,
        prefer_grpc=False,
        timeout=60,
    )
    logger.info(f"Connected to production Qdrant: https://{host}/{prefix}")
    return client
def get_collection_info(client, collection_name: str) -> dict[str, Any] | None:
"""Get collection information including vector configuration."""
try:
info = client.get_collection(collection_name)
vectors_config = info.config.params.vectors
# Determine vector configuration
if isinstance(vectors_config, dict):
# Already has named vectors
named_vectors = {}
for name, config in vectors_config.items():
named_vectors[name] = {
"size": config.size,
"distance": str(config.distance),
}
return {
"name": collection_name,
"vectors_count": info.vectors_count,
"points_count": info.points_count,
"vector_type": "named",
"named_vectors": named_vectors,
}
else:
# Single vector config
return {
"name": collection_name,
"vectors_count": info.vectors_count,
"points_count": info.points_count,
"vector_type": "single",
"vector_size": vectors_config.size if hasattr(vectors_config, 'size') else None,
"distance": str(vectors_config.distance) if hasattr(vectors_config, 'distance') else None,
}
except Exception as e:
logger.warning(f"Could not get collection info for '{collection_name}': {e}")
return None
def detect_existing_model(vector_size: int) -> str | None:
"""Detect which embedding model was used based on vector size."""
size_to_model = {
384: "minilm_384",
768: "bge_768",
1536: "openai_1536",
}
return size_to_model.get(vector_size)
def create_multi_embedding_collection(
    client,
    collection_name: str,
    models: list[str],
    backup_suffix: str = "_backup",
) -> bool:
    """Create *collection_name* with one named vector per embedding model.

    Any existing collection with the same name is deleted first, so the
    caller fully controls the final name (e.g. a temporary "_multi_new"
    staging name during migration).

    Args:
        client: Qdrant client
        collection_name: Exact name for the new collection
        models: List of model keys to support (unknown keys are skipped)
        backup_suffix: Unused; kept for backward compatibility

    Returns:
        True if the collection was created successfully
    """
    from qdrant_client.http.models import Distance, VectorParams

    # Known model keys -> vector dimensionality and human-readable model name.
    model_configs = {
        "minilm_384": {"size": 384, "name": "all-MiniLM-L6-v2"},
        "openai_1536": {"size": 1536, "name": "text-embedding-3-small"},
        "bge_768": {"size": 768, "name": "BAAI/bge-base-en-v1.5"},
    }

    # Build the named-vectors config, skipping unrecognized model keys.
    vectors_config = {}
    for model in models:
        if model not in model_configs:
            logger.warning(f"Unknown model: {model}, skipping")
            continue
        config = model_configs[model]
        vectors_config[model] = VectorParams(
            size=config["size"],
            distance=Distance.COSINE,
        )
    if not vectors_config:
        logger.error("No valid models specified")
        return False

    try:
        # BUG FIX: this function previously appended "_multi_new" to
        # collection_name, but its caller already passes a "_multi_new"-suffixed
        # name — the collection was created double-suffixed while points were
        # later upserted into the (nonexistent) single-suffixed name. Create
        # the collection under exactly the name the caller asked for.
        try:
            client.delete_collection(collection_name)
        except Exception:
            pass  # best-effort: the collection may simply not exist yet
        client.create_collection(
            collection_name=collection_name,
            vectors_config=vectors_config,
        )
        logger.info(f"Created collection '{collection_name}' with models: {list(vectors_config.keys())}")
        return True
    except Exception as e:
        logger.error(f"Failed to create collection: {e}")
        return False
def migrate_collection(
    client,
    collection_name: str,
    target_models: list[str],
    batch_size: int = BATCH_SIZE,
    dry_run: bool = False,
    openai_api_key: str | None = None,
) -> dict[str, Any]:
    """Migrate a collection to multi-embedding format.

    Copies every point from *collection_name* into a staging collection named
    "<collection_name>_multi_new" whose points carry one named vector per
    target model: existing vectors are reused for the detected source model,
    and the remaining models' vectors are generated locally with
    sentence-transformers (or via the OpenAI API for openai_1536). The
    original collection is left untouched; the final swap is a documented
    manual step.

    Args:
        client: Qdrant client
        collection_name: Collection to migrate
        target_models: Target embedding models
        batch_size: Batch size for processing
        dry_run: If True, only show what would be done
        openai_api_key: OpenAI API key for openai_1536 model

    Returns:
        Migration result dict with keys: collection, status, points_migrated,
        errors — plus status-specific extras (e.g. new_collection, models).
    """
    from qdrant_client.http import models as qmodels
    result = {
        "collection": collection_name,
        "status": "pending",
        "points_migrated": 0,
        "errors": [],
    }
    # Get current collection info
    info = get_collection_info(client, collection_name)
    if not info:
        result["status"] = "error"
        result["errors"].append(f"Collection '{collection_name}' not found")
        return result
    logger.info(f"Collection '{collection_name}': {info['vectors_count']} vectors, type={info['vector_type']}")
    # Check if already has named vectors
    if info["vector_type"] == "named":
        existing_models = set(info["named_vectors"].keys())
        missing_models = set(target_models) - existing_models
        if not missing_models:
            logger.info(f"Collection already has all target models: {target_models}")
            result["status"] = "already_migrated"
            return result
        logger.info(f"Collection has {existing_models}, adding {missing_models}")
        # TODO: Implement adding vectors to existing named vector collection
        result["status"] = "partial"
        result["existing_models"] = list(existing_models)
        result["missing_models"] = list(missing_models)
        return result
    # Single vector collection - full migration needed
    existing_size = info.get("vector_size")
    existing_model = detect_existing_model(existing_size) if existing_size else None
    logger.info(f"Detected existing model: {existing_model} ({existing_size}-dim)")
    if dry_run:
        logger.info(f"[DRY RUN] Would migrate {info['points_count']} points to models: {target_models}")
        result["status"] = "dry_run"
        result["would_migrate"] = info["points_count"]
        result["existing_model"] = existing_model
        result["target_models"] = target_models
        return result
    # Load local embedding models up front; OpenAI is called per batch below.
    from sentence_transformers import SentenceTransformer
    st_models = {}
    for model in target_models:
        if model == "minilm_384":
            st_models[model] = SentenceTransformer("all-MiniLM-L6-v2")
            logger.info(f"Loaded {model}")
        elif model == "bge_768":
            st_models[model] = SentenceTransformer("BAAI/bge-base-en-v1.5")
            logger.info(f"Loaded {model}")
        elif model == "openai_1536":
            if not openai_api_key:
                # NOTE: the model key stays in target_models; the embedding
                # loop below re-checks openai_api_key, so it is skipped there too.
                logger.warning("OpenAI API key not provided, skipping openai_1536")
                continue
            # Will use OpenAI client
            logger.info(f"Will use OpenAI API for {model}")
    if existing_model and existing_model not in target_models:
        # Carry the original vectors over alongside the newly generated models.
        target_models = [existing_model] + [m for m in target_models if m != existing_model]
        logger.info(f"Including existing model, final models: {target_models}")
    # Create new collection
    temp_name = f"{collection_name}_multi_new"
    if not create_multi_embedding_collection(client, temp_name, target_models):
        result["status"] = "error"
        result["errors"].append("Failed to create new collection")
        return result
    # Migrate data in batches
    offset = None
    total_migrated = 0
    while True:
        # Scroll through existing collection
        records, next_offset = client.scroll(
            collection_name=collection_name,
            limit=batch_size,
            offset=offset,
            with_payload=True,
            with_vectors=True,
        )
        if not records:
            break
        # Process batch
        points_to_upsert = []
        texts_to_embed = []
        for record in records:
            payload = record.payload or {}
            # assumes the embeddable text lives under payload["text"] — TODO confirm
            text = payload.get("text", "")
            texts_to_embed.append(text)
        # Generate embeddings for each model (one batch call per model).
        embeddings_by_model = {}
        for model in target_models:
            if model == existing_model:
                # Use existing vectors
                embeddings_by_model[model] = [
                    record.vector if not isinstance(record.vector, dict)
                    else record.vector.get(model, [])
                    for record in records
                ]
            elif model in st_models:
                # Generate with sentence-transformers
                embeddings = st_models[model].encode(texts_to_embed, show_progress_bar=False)
                embeddings_by_model[model] = embeddings.tolist()
            elif model == "openai_1536" and openai_api_key:
                # Generate with OpenAI
                import openai
                client_openai = openai.OpenAI(api_key=openai_api_key)
                response = client_openai.embeddings.create(
                    input=texts_to_embed,
                    model="text-embedding-3-small",
                )
                # Re-sort by index so embeddings line up with texts_to_embed.
                embeddings_by_model[model] = [
                    item.embedding for item in sorted(response.data, key=lambda x: x.index)
                ]
        # Create points with named vectors
        for i, record in enumerate(records):
            vectors = {}
            for model in target_models:
                # Guard against models that produced no/short embedding lists.
                if model in embeddings_by_model and i < len(embeddings_by_model[model]):
                    vectors[model] = embeddings_by_model[model][i]
            if vectors:
                points_to_upsert.append(qmodels.PointStruct(
                    id=record.id,
                    vector=vectors,
                    payload=record.payload,
                ))
        # Upsert to new collection
        if points_to_upsert:
            client.upsert(
                collection_name=temp_name,
                points=points_to_upsert,
            )
            total_migrated += len(points_to_upsert)
            logger.info(f"Migrated {total_migrated} points...")
        if next_offset is None:
            break
        offset = next_offset
    # Swap collections
    # NOTE(review): backup_name is computed but never used — presumably
    # intended for a future automated rename/alias step; verify before removing.
    backup_name = f"{collection_name}_backup_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    try:
        # Rename old collection to backup
        # Note: Qdrant doesn't have rename, so we need to create alias or document the mapping
        logger.info(f"Migration complete. New collection: {temp_name}")
        logger.info(f"To complete migration, manually:")
        logger.info(f" 1. Verify new collection: {temp_name}")
        logger.info(f" 2. Delete old collection: {collection_name}")
        logger.info(f" 3. Recreate {collection_name} with data from {temp_name}")
        result["status"] = "success"
        result["points_migrated"] = total_migrated
        result["new_collection"] = temp_name
        result["models"] = target_models
    except Exception as e:
        result["status"] = "error"
        result["errors"].append(f"Failed during swap: {e}")
    return result
def main():
    """CLI entry point: show collection info or run the migration."""
    parser = argparse.ArgumentParser(
        description="Migrate Qdrant collections to multi-embedding format"
    )
    parser.add_argument(
        "--collection", "-c",
        type=str,
        default=None,
        help="Collection to migrate (default: all default collections)",
    )
    parser.add_argument(
        "--models", "-m",
        type=str,
        default="minilm_384",
        help="Comma-separated list of models: minilm_384,openai_1536,bge_768",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Show what would be migrated without making changes",
    )
    parser.add_argument(
        "--production",
        action="store_true",
        help="Use production Qdrant",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=BATCH_SIZE,
        help=f"Batch size for processing (default: {BATCH_SIZE})",
    )
    parser.add_argument(
        "--info",
        action="store_true",
        help="Only show collection information",
    )
    args = parser.parse_args()

    # Requested embedding models, e.g. "minilm_384,bge_768".
    requested_models = [name.strip() for name in args.models.split(",")]

    # Either the --production flag or the QDRANT_USE_PRODUCTION env var selects prod.
    env_flag = os.getenv("QDRANT_USE_PRODUCTION", "").lower()
    client = get_qdrant_client(args.production or env_flag in ("true", "1", "yes"))

    # A single named collection, or the default set.
    target_collections = [args.collection] if args.collection else DEFAULT_COLLECTIONS

    if args.info:
        # Info mode - just show collection status, then exit.
        print("\n=== Collection Information ===\n")
        for name in target_collections:
            details = get_collection_info(client, name)
            if not details:
                print(f"Collection: {name} - NOT FOUND\n")
                continue
            print(f"Collection: {details['name']}")
            print(f" Vectors: {details['vectors_count']}")
            print(f" Points: {details['points_count']}")
            print(f" Type: {details['vector_type']}")
            if details['vector_type'] == 'single':
                model = detect_existing_model(details['vector_size'])
                print(f" Vector size: {details['vector_size']} ({model or 'unknown'})")
            else:
                print(f" Named vectors: {list(details['named_vectors'].keys())}")
            print()
        return

    # Migration mode
    print("\n=== Multi-Embedding Migration ===")
    print(f"Target models: {requested_models}")
    print(f"Collections: {target_collections}")
    print(f"Dry run: {args.dry_run}")
    print()

    openai_key = os.getenv("OPENAI_API_KEY")
    results = []
    for name in target_collections:
        print(f"\n--- Migrating {name} ---")
        outcome = migrate_collection(
            client,
            name,
            requested_models,
            batch_size=args.batch_size,
            dry_run=args.dry_run,
            openai_api_key=openai_key,
        )
        results.append(outcome)
        print(f"Result: {outcome['status']}")
        for err in outcome.get("errors") or []:
            print(f" Error: {err}")

    print("\n=== Migration Summary ===")
    for outcome in results:
        print(f"{outcome['collection']}: {outcome['status']} ({outcome.get('points_migrated', 0)} points)")


if __name__ == "__main__":
    main()