glam/scripts/migrate_qdrant_to_minilm.py
2025-12-16 09:02:52 +01:00

332 lines
11 KiB
Python

#!/usr/bin/env python3
"""
Migrate Qdrant collection from OpenAI embeddings to local MiniLM embeddings.
This script:
1. Creates a new collection with 384-dim vectors (MiniLM)
2. Scrolls through all existing documents
3. Re-embeds with MiniLM and inserts into new collection
4. Tests the new collection performance
5. Optionally swaps collection names
Expected performance improvement: 300-1300ms → 22-25ms per embedding
Usage:
python migrate_qdrant_to_minilm.py --dry-run # Test without creating
python migrate_qdrant_to_minilm.py --migrate # Run migration
python migrate_qdrant_to_minilm.py --swap # Swap collections after verification
"""
import argparse
import logging
import time
from typing import Any
import sys
# Root logger: timestamped INFO-level messages (used by all functions below).
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Configuration
QDRANT_HOST = "localhost"
QDRANT_PORT = 6333
SOURCE_COLLECTION = "heritage_custodians"  # existing collection (OpenAI vectors)
TARGET_COLLECTION = "heritage_custodians_minilm"  # new collection (MiniLM vectors)
BATCH_SIZE = 100 # Documents per batch
MINILM_DIM = 384  # all-MiniLM-L6-v2 output dimensionality
def get_qdrant_client():
    """Construct a Qdrant client for the configured host and port."""
    # Imported lazily so the script can print --help without qdrant installed.
    from qdrant_client import QdrantClient

    client = QdrantClient(host=QDRANT_HOST, port=QDRANT_PORT)
    return client
def load_minilm_model():
    """Load the all-MiniLM-L6-v2 sentence-transformer, logging the load time."""
    from sentence_transformers import SentenceTransformer

    logger.info("Loading MiniLM model...")
    load_start = time.time()
    transformer = SentenceTransformer("all-MiniLM-L6-v2")
    elapsed = time.time() - load_start
    logger.info(f"Model loaded in {elapsed:.1f}s")
    return transformer
def create_target_collection(client) -> bool:
    """Create the MiniLM target collection unless it already exists.

    Args:
        client: Qdrant client.

    Returns:
        True when a fresh collection was created, False when one was
        already present (its point count is logged instead).
    """
    from qdrant_client.http.models import Distance, VectorParams

    # Refuse to clobber an existing collection; report its size instead.
    existing_names = {c.name for c in client.get_collections().collections}
    if TARGET_COLLECTION in existing_names:
        logger.warning(f"Collection '{TARGET_COLLECTION}' already exists!")
        info = client.get_collection(TARGET_COLLECTION)
        logger.info(f" Points: {info.points_count}")
        return False

    # Cosine distance matches how sentence-transformer embeddings are compared.
    client.create_collection(
        collection_name=TARGET_COLLECTION,
        vectors_config=VectorParams(size=MINILM_DIM, distance=Distance.COSINE),
    )
    logger.info(f"Created collection '{TARGET_COLLECTION}' with {MINILM_DIM}-dim vectors")
    return True
def migrate_documents(client, model, limit: int | None = None) -> dict:
    """Re-embed every document from the source collection into the target.

    Scrolls the source collection in batches of BATCH_SIZE, embeds each
    document's "text" payload field with MiniLM, and upserts the results
    (same ids, same payloads, new vectors) into the target collection.

    Args:
        client: Qdrant client.
        model: SentenceTransformer used to produce the new embeddings.
        limit: Optional cap on documents scrolled. Checked per batch, so
            up to BATCH_SIZE - 1 extra documents may still be processed.

    Returns:
        Stats dict: total_scrolled, total_indexed, batches, and cumulative
        embedding_time_ms / indexing_time_ms.
    """
    from qdrant_client.http.models import PointStruct

    stats = {
        "total_scrolled": 0,
        "total_indexed": 0,
        "batches": 0,
        "embedding_time_ms": 0,
        "indexing_time_ms": 0,
    }

    offset = None
    batch_num = 0
    while True:
        # Old vectors are not fetched; we re-embed from the payload text.
        result = client.scroll(
            collection_name=SOURCE_COLLECTION,
            limit=BATCH_SIZE,
            offset=offset,
            with_payload=True,
            with_vectors=False,
        )
        points = result[0]
        offset = result[1]
        if not points:
            break

        batch_num += 1
        stats["batches"] = batch_num
        stats["total_scrolled"] += len(points)

        # Skip documents whose "text" payload field is missing or empty.
        texts = []
        valid_points = []
        for point in points:
            text = point.payload.get("text", "")
            if text:
                texts.append(text)
                valid_points.append(point)

        # Bug fix: the original used `continue` here, which skipped both
        # the --limit check and the end-of-scroll check below. If the last
        # batch contained no valid texts, scroll(offset=None) restarted
        # from the beginning and the loop never terminated.
        if not texts:
            logger.warning(f"Batch {batch_num}: No valid texts found")
        else:
            # Generate embeddings with MiniLM.
            embed_start = time.time()
            embeddings = model.encode(texts, show_progress_bar=False)
            embed_time = (time.time() - embed_start) * 1000
            stats["embedding_time_ms"] += embed_time

            # Rebuild points with new vectors; ids and payloads carry over.
            new_points = [
                PointStruct(
                    id=str(point.id),
                    vector=embedding.tolist(),
                    payload=point.payload,
                )
                for point, embedding in zip(valid_points, embeddings)
            ]

            index_start = time.time()
            client.upsert(
                collection_name=TARGET_COLLECTION,
                points=new_points,
            )
            index_time = (time.time() - index_start) * 1000
            stats["indexing_time_ms"] += index_time
            stats["total_indexed"] += len(new_points)

            # Progress log. (Removed an unused `avg_embed` computation
            # that was never referenced.)
            logger.info(
                f"Batch {batch_num}: {len(points)} docs, "
                f"embed={embed_time:.0f}ms ({embed_time/len(texts):.1f}ms/doc), "
                f"index={index_time:.0f}ms, "
                f"total={stats['total_indexed']}"
            )

        # Stop once the requested document cap has been reached.
        if limit and stats["total_scrolled"] >= limit:
            logger.info(f"Reached limit of {limit} documents")
            break
        # Qdrant signals end-of-scroll with a None next-page offset.
        if offset is None:
            break

    return stats
def test_search_performance(client, model, queries: list[str]) -> list[dict]:
    """Run sample queries against the target collection and time each stage.

    Args:
        client: Qdrant client.
        model: SentenceTransformer used to embed each query.
        queries: Natural-language query strings to run.

    Returns:
        Per-query dicts with embed/search/total timings (ms), result
        count, and the "name" payload of the top hit (None if no hits).
        (Annotation fixed: the original said `dict` but a list is returned.)
    """
    results = []
    for query in queries:
        # Time query embedding separately from the vector search itself.
        embed_start = time.time()
        query_vector = model.encode(query).tolist()
        embed_time = (time.time() - embed_start) * 1000

        search_start = time.time()
        search_results = client.search(
            collection_name=TARGET_COLLECTION,
            query_vector=query_vector,
            limit=10,
        )
        search_time = (time.time() - search_start) * 1000
        total_time = embed_time + search_time

        results.append({
            "query": query,
            "embed_ms": embed_time,
            "search_ms": search_time,
            "total_ms": total_time,
            "num_results": len(search_results),
            "top_result": search_results[0].payload.get("name") if search_results else None,
        })
        # Bug fix: the original adjacent f-strings concatenated with no
        # separator, producing "Query: 'x'embed=..."; add ": " between them.
        logger.info(
            f"Query: '{query}': "
            f"embed={embed_time:.0f}ms, search={search_time:.0f}ms, total={total_time:.0f}ms, "
            f"top='{results[-1]['top_result']}'"
        )
    return results
def swap_collections(client):
    """Tag the source collection with a backup alias (Qdrant cannot rename).

    Creates an alias '<source>_openai_backup' pointing at the source
    collection, then warns the operator that code must be switched over to
    the target collection name manually.

    Args:
        client: Qdrant client.

    Returns:
        False if the backup name is already taken, True otherwise.
    """
    backup_name = f"{SOURCE_COLLECTION}_openai_backup"

    # Refuse to proceed if the backup name is already in use.
    collections = [c.name for c in client.get_collections().collections]
    if backup_name in collections:
        logger.error(f"Backup collection '{backup_name}' already exists!")
        return False

    # Bug fix: the original message ran the two names together
    # ("'a''b'"); add the missing arrow between them.
    logger.info(f"Renaming '{SOURCE_COLLECTION}' -> '{backup_name}'")
    client.update_collection_aliases(
        change_aliases_operations=[
            {"create_alias": {"alias_name": backup_name, "collection_name": SOURCE_COLLECTION}}
        ]
    )
    # This is tricky - Qdrant doesn't have direct rename
    # We'll need to recreate with aliases or just use the new collection name
    logger.warning("NOTE: Qdrant doesn't support direct collection rename.")
    logger.warning(f"Update your code to use '{TARGET_COLLECTION}' instead of '{SOURCE_COLLECTION}'")
    logger.warning("Or update the retriever to use the new collection name.")
    return True
def main():
    """CLI entry point: dispatch to dry-run, migrate, test, or swap mode."""
    parser = argparse.ArgumentParser(description="Migrate Qdrant to MiniLM embeddings")
    parser.add_argument("--dry-run", action="store_true", help="Test without creating")
    parser.add_argument("--migrate", action="store_true", help="Run full migration")
    parser.add_argument("--test-only", action="store_true", help="Test existing target collection")
    parser.add_argument("--limit", type=int, help="Limit number of documents to migrate")
    parser.add_argument("--swap", action="store_true", help="Swap collections (after verification)")
    args = parser.parse_args()

    client = get_qdrant_client()

    # Report source collection stats up front in every mode.
    source_info = client.get_collection(SOURCE_COLLECTION)
    logger.info(f"Source collection '{SOURCE_COLLECTION}':")
    logger.info(f" Points: {source_info.points_count}")
    logger.info(f" Vector size: {source_info.config.params.vectors.size}")

    if args.dry_run:
        logger.info("\n=== DRY RUN ===")
        logger.info(f"Would migrate {source_info.points_count} documents")
        logger.info(f"Target: {TARGET_COLLECTION} with {MINILM_DIM}-dim vectors")
        # Verify the model loads and time a single embedding.
        model = load_minilm_model()
        test_text = "archives in Utrecht"
        start = time.time()
        embedding = model.encode(test_text)
        logger.info(f"Test embedding: {(time.time()-start)*1000:.0f}ms, dim={len(embedding)}")
        return

    if args.test_only:
        logger.info("\n=== TESTING TARGET COLLECTION ===")
        model = load_minilm_model()
        test_queries = [
            "archives in Utrecht",
            "museums in Amsterdam",
            "library Belgium",
            "Japanese art museum",
        ]
        test_search_performance(client, model, test_queries)
        return

    if args.migrate:
        logger.info("\n=== STARTING MIGRATION ===")
        model = load_minilm_model()

        # Create target collection; confirm before re-upserting into one
        # that already exists.
        created = create_target_collection(client)
        if not created:
            response = input("Collection exists. Continue migration anyway? [y/N] ")
            if response.lower() != 'y':
                logger.info("Aborted.")
                return

        # Migrate documents
        start_time = time.time()
        stats = migrate_documents(client, model, limit=args.limit)
        total_time = time.time() - start_time

        logger.info("\n=== MIGRATION COMPLETE ===")
        logger.info(f"Total documents: {stats['total_indexed']}")
        logger.info(f"Total time: {total_time:.1f}s")
        logger.info(f"Embedding time: {stats['embedding_time_ms']/1000:.1f}s")
        logger.info(f"Indexing time: {stats['indexing_time_ms']/1000:.1f}s")
        # Bug fix: guard the per-doc average against ZeroDivisionError when
        # nothing was indexed (empty source or all-empty texts).
        if stats['total_indexed']:
            logger.info(f"Avg per doc: {total_time*1000/stats['total_indexed']:.1f}ms")

        # Test performance
        logger.info("\n=== TESTING PERFORMANCE ===")
        test_queries = [
            "archives in Utrecht",
            "museums in Amsterdam",
            "library Belgium",
        ]
        test_search_performance(client, model, test_queries)
        return

    if args.swap:
        logger.info("\n=== SWAPPING COLLECTIONS ===")
        swap_collections(client)
        return

    # No mode flag given: show usage.
    parser.print_help()
if __name__ == "__main__":
main()