#!/usr/bin/env python3
|
|
"""
|
|
Migration Script: Add Multi-Embedding Support to Existing Qdrant Collections
|
|
|
|
This script migrates existing single-vector Qdrant collections to support
|
|
multiple embedding models using Qdrant's named vectors feature.
|
|
|
|
The migration process:
|
|
1. Create a new collection with named vectors for each embedding model
|
|
2. Copy existing data from old collection
|
|
3. Generate embeddings for new models (local sentence-transformers)
|
|
4. Rename collections (old -> backup, new -> original name)
|
|
|
|
Supported Embedding Models:
|
|
- minilm_384: all-MiniLM-L6-v2 (384-dim, free/local)
|
|
- openai_1536: text-embedding-3-small (1536-dim, API)
|
|
- bge_768: bge-base-en-v1.5 (768-dim, free/local)
|
|
|
|
Usage:
|
|
# Dry run to see what would be migrated
|
|
python -m scripts.sync.migrate_to_multi_embedding --dry-run
|
|
|
|
# Migrate heritage_persons collection
|
|
python -m scripts.sync.migrate_to_multi_embedding --collection heritage_persons
|
|
|
|
# Migrate with specific models
|
|
python -m scripts.sync.migrate_to_multi_embedding --models minilm_384,openai_1536
|
|
|
|
# Use production Qdrant
|
|
QDRANT_USE_PRODUCTION=true python -m scripts.sync.migrate_to_multi_embedding
|
|
"""
|
|
|
|
import argparse
import logging
import os
import sys
from datetime import datetime
from pathlib import Path
from typing import Any

# Add project root to sys.path so `src` packages resolve when this file is
# executed directly as a script (not via `python -m`).
PROJECT_ROOT = Path(__file__).parent.parent.parent
sys.path.insert(0, str(PROJECT_ROOT / "src"))

# Module-wide logging: timestamped INFO-level output to stderr.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)


# Configuration
# Collections migrated when --collection is not given on the CLI.
DEFAULT_COLLECTIONS = ["heritage_persons", "heritage_custodians"]
# Points fetched/upserted per scroll batch during migration.
BATCH_SIZE = 100
|
|
|
|
|
|
def get_qdrant_client(use_production: bool = False):
    """Build and return a Qdrant client.

    Connects to the production instance (HTTPS on port 443 with a URL path
    prefix) when *use_production* is true, otherwise to a local instance.
    Connection details are read from environment variables with fallbacks.
    """
    from qdrant_client import QdrantClient

    if not use_production:
        local_host = os.getenv("QDRANT_HOST", "localhost")
        local_port = int(os.getenv("QDRANT_PORT", "6333"))
        local_client = QdrantClient(host=local_host, port=local_port)
        logger.info(f"Connected to local Qdrant: {local_host}:{local_port}")
        return local_client

    prod_host = os.getenv("QDRANT_PROD_HOST", "bronhouder.nl")
    prod_prefix = os.getenv("QDRANT_PROD_PREFIX", "qdrant")
    prod_client = QdrantClient(
        host=prod_host,
        port=443,
        https=True,
        prefix=prod_prefix,
        prefer_grpc=False,  # REST is required behind the path-prefixed proxy
        timeout=60,
    )
    logger.info(f"Connected to production Qdrant: https://{prod_host}/{prod_prefix}")
    return prod_client
|
|
|
|
|
|
def get_collection_info(client, collection_name: str) -> dict[str, Any] | None:
|
|
"""Get collection information including vector configuration."""
|
|
try:
|
|
info = client.get_collection(collection_name)
|
|
vectors_config = info.config.params.vectors
|
|
|
|
# Determine vector configuration
|
|
if isinstance(vectors_config, dict):
|
|
# Already has named vectors
|
|
named_vectors = {}
|
|
for name, config in vectors_config.items():
|
|
named_vectors[name] = {
|
|
"size": config.size,
|
|
"distance": str(config.distance),
|
|
}
|
|
return {
|
|
"name": collection_name,
|
|
"vectors_count": info.vectors_count,
|
|
"points_count": info.points_count,
|
|
"vector_type": "named",
|
|
"named_vectors": named_vectors,
|
|
}
|
|
else:
|
|
# Single vector config
|
|
return {
|
|
"name": collection_name,
|
|
"vectors_count": info.vectors_count,
|
|
"points_count": info.points_count,
|
|
"vector_type": "single",
|
|
"vector_size": vectors_config.size if hasattr(vectors_config, 'size') else None,
|
|
"distance": str(vectors_config.distance) if hasattr(vectors_config, 'distance') else None,
|
|
}
|
|
except Exception as e:
|
|
logger.warning(f"Could not get collection info for '{collection_name}': {e}")
|
|
return None
|
|
|
|
|
|
def detect_existing_model(vector_size: int) -> str | None:
|
|
"""Detect which embedding model was used based on vector size."""
|
|
size_to_model = {
|
|
384: "minilm_384",
|
|
768: "bge_768",
|
|
1536: "openai_1536",
|
|
}
|
|
return size_to_model.get(vector_size)
|
|
|
|
|
|
def create_multi_embedding_collection(
    client,
    collection_name: str,
    models: list[str],
    backup_suffix: str = "_backup",
) -> bool:
    """Create a collection with named vectors for multiple embedding models.

    The collection is created under *collection_name* exactly as given.
    (A previous version appended a ``_multi_new`` suffix here; the caller in
    ``migrate_collection`` already passes a ``..._multi_new`` temp name, so
    the created collection ended up double-suffixed and the caller then
    upserted into a collection that had never been created.)

    Args:
        client: Qdrant client
        collection_name: Name for the new collection (used verbatim)
        models: List of model names to support; unknown names are skipped
        backup_suffix: Unused; kept for backward compatibility of the signature

    Returns:
        True if created successfully, False on error or if no valid model
        was supplied.
    """
    from qdrant_client.http.models import Distance, VectorParams

    # Known embedding models and their vector dimensionalities.
    model_configs = {
        "minilm_384": {"size": 384, "name": "all-MiniLM-L6-v2"},
        "openai_1536": {"size": 1536, "name": "text-embedding-3-small"},
        "bge_768": {"size": 768, "name": "BAAI/bge-base-en-v1.5"},
    }

    # Build the named-vectors configuration, skipping unknown model names.
    vectors_config = {}
    for model in models:
        if model not in model_configs:
            logger.warning(f"Unknown model: {model}, skipping")
            continue
        vectors_config[model] = VectorParams(
            size=model_configs[model]["size"],
            distance=Distance.COSINE,
        )

    if not vectors_config:
        logger.error("No valid models specified")
        return False

    try:
        # Drop a stale collection of the same name so creation cannot fail
        # on a leftover from a previous interrupted run.
        try:
            client.delete_collection(collection_name)
        except Exception:
            pass  # best effort: the collection may simply not exist

        client.create_collection(
            collection_name=collection_name,
            vectors_config=vectors_config,
        )
        logger.info(f"Created collection '{collection_name}' with models: {list(vectors_config.keys())}")
        return True

    except Exception as e:
        logger.error(f"Failed to create collection: {e}")
        return False
|
|
|
|
|
|
def migrate_collection(
    client,
    collection_name: str,
    target_models: list[str],
    batch_size: int = BATCH_SIZE,
    dry_run: bool = False,
    openai_api_key: str | None = None,
) -> dict[str, Any]:
    """Migrate a single-vector collection to multi-embedding (named vectors).

    Copies every point from *collection_name* into a new collection named
    ``{collection_name}_multi_new``, attaching one named vector per target
    model: the existing vector is reused for the detected original model,
    sentence-transformers generates vectors locally for minilm_384/bge_768,
    and the OpenAI API is used for openai_1536 (only if a key is given).
    The original collection is left untouched; the final rename/swap must be
    done manually (see the log output at the end).

    Args:
        client: Qdrant client
        collection_name: Collection to migrate
        target_models: Target embedding models
        batch_size: Points per scroll/upsert batch
        dry_run: If True, only report what would be done; no writes
        openai_api_key: OpenAI API key for the openai_1536 model

    Returns:
        Migration result dict with keys "collection", "status",
        "points_migrated", "errors", plus status-specific extras.
    """
    from qdrant_client.http import models as qmodels

    # Result skeleton; extra keys are added depending on the path taken.
    result = {
        "collection": collection_name,
        "status": "pending",
        "points_migrated": 0,
        "errors": [],
    }

    # Get current collection info (None means missing/unreachable).
    info = get_collection_info(client, collection_name)
    if not info:
        result["status"] = "error"
        result["errors"].append(f"Collection '{collection_name}' not found")
        return result

    logger.info(f"Collection '{collection_name}': {info['vectors_count']} vectors, type={info['vector_type']}")

    # Collections that already use named vectors are not re-migrated here.
    if info["vector_type"] == "named":
        existing_models = set(info["named_vectors"].keys())
        missing_models = set(target_models) - existing_models

        if not missing_models:
            logger.info(f"Collection already has all target models: {target_models}")
            result["status"] = "already_migrated"
            return result

        logger.info(f"Collection has {existing_models}, adding {missing_models}")
        # TODO: Implement adding vectors to existing named vector collection
        result["status"] = "partial"
        result["existing_models"] = list(existing_models)
        result["missing_models"] = list(missing_models)
        return result

    # Single vector collection - full migration needed.
    # Guess which model produced the current vectors from their size alone.
    existing_size = info.get("vector_size")
    existing_model = detect_existing_model(existing_size) if existing_size else None

    logger.info(f"Detected existing model: {existing_model} ({existing_size}-dim)")

    if dry_run:
        logger.info(f"[DRY RUN] Would migrate {info['points_count']} points to models: {target_models}")
        result["status"] = "dry_run"
        result["would_migrate"] = info["points_count"]
        result["existing_model"] = existing_model
        result["target_models"] = target_models
        return result

    # Load local embedding models up front (downloads weights on first use).
    from sentence_transformers import SentenceTransformer

    st_models = {}
    for model in target_models:
        if model == "minilm_384":
            st_models[model] = SentenceTransformer("all-MiniLM-L6-v2")
            logger.info(f"Loaded {model}")
        elif model == "bge_768":
            st_models[model] = SentenceTransformer("BAAI/bge-base-en-v1.5")
            logger.info(f"Loaded {model}")
        elif model == "openai_1536":
            if not openai_api_key:
                logger.warning("OpenAI API key not provided, skipping openai_1536")
                continue
            # Will use OpenAI client (created per batch in the loop below).
            logger.info(f"Will use OpenAI API for {model}")

    # Always carry the existing model's vectors over, even if not requested,
    # so no embedding data is lost. Note st_models is NOT loaded for it; the
    # copy path below reuses the stored vectors instead of re-encoding.
    if existing_model and existing_model not in target_models:
        target_models = [existing_model] + [m for m in target_models if m != existing_model]
        logger.info(f"Including existing model, final models: {target_models}")

    # Create new collection. temp_name must match the collection name that
    # create_multi_embedding_collection actually creates, since all upserts
    # below target temp_name directly.
    temp_name = f"{collection_name}_multi_new"
    if not create_multi_embedding_collection(client, temp_name, target_models):
        result["status"] = "error"
        result["errors"].append("Failed to create new collection")
        return result

    # Migrate data in batches using Qdrant's scroll cursor.
    offset = None
    total_migrated = 0

    while True:
        # Scroll through existing collection; next_offset is None on the
        # final page.
        records, next_offset = client.scroll(
            collection_name=collection_name,
            limit=batch_size,
            offset=offset,
            with_payload=True,
            with_vectors=True,
        )

        if not records:
            break

        # Process batch: collect the text to embed for each record.
        points_to_upsert = []
        texts_to_embed = []

        for record in records:
            payload = record.payload or {}
            # Points without a "text" payload field are embedded as "".
            text = payload.get("text", "")
            texts_to_embed.append(text)

        # Generate embeddings for each model: model -> list aligned with records.
        embeddings_by_model = {}

        for model in target_models:
            if model == existing_model:
                # Reuse existing vectors; handle both plain-list and
                # named-dict vector shapes defensively.
                embeddings_by_model[model] = [
                    record.vector if not isinstance(record.vector, dict)
                    else record.vector.get(model, [])
                    for record in records
                ]
            elif model in st_models:
                # Generate with sentence-transformers (local, batched).
                embeddings = st_models[model].encode(texts_to_embed, show_progress_bar=False)
                embeddings_by_model[model] = embeddings.tolist()
            elif model == "openai_1536" and openai_api_key:
                # Generate with OpenAI.
                # NOTE(review): a new OpenAI client is constructed per batch;
                # hoisting it above the loop would avoid repeated setup.
                import openai
                client_openai = openai.OpenAI(api_key=openai_api_key)
                response = client_openai.embeddings.create(
                    input=texts_to_embed,
                    model="text-embedding-3-small",
                )
                # Sort by index: the API does not guarantee response order.
                embeddings_by_model[model] = [
                    item.embedding for item in sorted(response.data, key=lambda x: x.index)
                ]

        # Create points carrying all available named vectors for each record.
        for i, record in enumerate(records):
            vectors = {}
            for model in target_models:
                if model in embeddings_by_model and i < len(embeddings_by_model[model]):
                    vectors[model] = embeddings_by_model[model][i]

            # Points with no vectors at all (e.g. every model skipped) are
            # silently dropped from the new collection.
            if vectors:
                points_to_upsert.append(qmodels.PointStruct(
                    id=record.id,
                    vector=vectors,
                    payload=record.payload,
                ))

        # Upsert to new collection.
        if points_to_upsert:
            client.upsert(
                collection_name=temp_name,
                points=points_to_upsert,
            )
            total_migrated += len(points_to_upsert)
            logger.info(f"Migrated {total_migrated} points...")

        if next_offset is None:
            break
        offset = next_offset

    # Swap collections.
    # NOTE(review): backup_name is currently unused — Qdrant has no rename
    # operation, so the swap is left as a documented manual step below.
    backup_name = f"{collection_name}_backup_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

    try:
        # Rename old collection to backup.
        # Note: Qdrant doesn't have rename, so we need to create alias or document the mapping
        logger.info(f"Migration complete. New collection: {temp_name}")
        logger.info(f"To complete migration, manually:")
        logger.info(f" 1. Verify new collection: {temp_name}")
        logger.info(f" 2. Delete old collection: {collection_name}")
        logger.info(f" 3. Recreate {collection_name} with data from {temp_name}")

        result["status"] = "success"
        result["points_migrated"] = total_migrated
        result["new_collection"] = temp_name
        result["models"] = target_models

    except Exception as e:
        result["status"] = "error"
        result["errors"].append(f"Failed during swap: {e}")

    return result
|
|
|
|
|
|
def main():
    """CLI entry point: parse arguments, then show info or run migrations."""
    parser = argparse.ArgumentParser(
        description="Migrate Qdrant collections to multi-embedding format"
    )
    parser.add_argument(
        "--collection", "-c",
        type=str,
        default=None,
        help="Collection to migrate (default: all default collections)",
    )
    parser.add_argument(
        "--models", "-m",
        type=str,
        default="minilm_384",
        help="Comma-separated list of models: minilm_384,openai_1536,bge_768",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Show what would be migrated without making changes",
    )
    parser.add_argument(
        "--production",
        action="store_true",
        help="Use production Qdrant",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=BATCH_SIZE,
        help=f"Batch size for processing (default: {BATCH_SIZE})",
    )
    parser.add_argument(
        "--info",
        action="store_true",
        help="Only show collection information",
    )
    args = parser.parse_args()

    # Parse the comma-separated model list from the CLI.
    selected_models = [name.strip() for name in args.models.split(",")]

    # The environment variable can force production even without the flag.
    env_flag = os.getenv("QDRANT_USE_PRODUCTION", "").lower()
    client = get_qdrant_client(args.production or env_flag in ("true", "1", "yes"))

    # One explicit collection, or the default set.
    target_collections = [args.collection] if args.collection else DEFAULT_COLLECTIONS

    # Info mode: print each collection's status and exit.
    if args.info:
        print("\n=== Collection Information ===\n")
        for name in target_collections:
            details = get_collection_info(client, name)
            if details is None:
                print(f"Collection: {name} - NOT FOUND\n")
                continue
            print(f"Collection: {details['name']}")
            print(f" Vectors: {details['vectors_count']}")
            print(f" Points: {details['points_count']}")
            print(f" Type: {details['vector_type']}")
            if details['vector_type'] == 'single':
                detected = detect_existing_model(details['vector_size'])
                print(f" Vector size: {details['vector_size']} ({detected or 'unknown'})")
            else:
                print(f" Named vectors: {list(details['named_vectors'].keys())}")
            print()
        return

    # Migration mode.
    print("\n=== Multi-Embedding Migration ===")
    print(f"Target models: {selected_models}")
    print(f"Collections: {target_collections}")
    print(f"Dry run: {args.dry_run}")
    print()

    openai_api_key = os.getenv("OPENAI_API_KEY")

    outcomes = []
    for name in target_collections:
        print(f"\n--- Migrating {name} ---")
        outcome = migrate_collection(
            client,
            name,
            selected_models,
            batch_size=args.batch_size,
            dry_run=args.dry_run,
            openai_api_key=openai_api_key,
        )
        outcomes.append(outcome)
        print(f"Result: {outcome['status']}")
        for err in outcome.get("errors") or []:
            print(f" Error: {err}")

    print("\n=== Migration Summary ===")
    for outcome in outcomes:
        print(f"{outcome['collection']}: {outcome['status']} ({outcome.get('points_migrated', 0)} points)")


if __name__ == "__main__":
    main()
|