#!/usr/bin/env python3
"""
Migrate Qdrant collection from OpenAI embeddings to local MiniLM embeddings.

This script:
1. Creates a new collection with 384-dim vectors (MiniLM)
2. Scrolls through all existing documents
3. Re-embeds with MiniLM and inserts into new collection
4. Tests the new collection performance
5. Optionally swaps collection names

Expected performance improvement: 300-1300ms → 22-25ms per embedding

Usage:
    python migrate_qdrant_to_minilm.py --dry-run     # Test without creating
    python migrate_qdrant_to_minilm.py --migrate     # Run migration
    python migrate_qdrant_to_minilm.py --swap        # Swap collections after verification
"""

import argparse
import logging
import time
from typing import Any
import sys

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Configuration
QDRANT_HOST = "localhost"
QDRANT_PORT = 6333
SOURCE_COLLECTION = "heritage_custodians"
TARGET_COLLECTION = "heritage_custodians_minilm"
BATCH_SIZE = 100  # Documents per scroll/upsert batch
MINILM_DIM = 384  # all-MiniLM-L6-v2 output dimension


def get_qdrant_client():
    """Return a Qdrant client connected to the configured host/port.

    Import is deferred so the script can print --help without qdrant_client
    installed.
    """
    from qdrant_client import QdrantClient
    return QdrantClient(host=QDRANT_HOST, port=QDRANT_PORT)


def load_minilm_model():
    """Load and return the MiniLM sentence-transformer model.

    Import is deferred because sentence_transformers is heavy to import.
    """
    from sentence_transformers import SentenceTransformer
    logger.info("Loading MiniLM model...")
    start = time.time()
    model = SentenceTransformer("all-MiniLM-L6-v2")
    logger.info(f"Model loaded in {time.time() - start:.1f}s")
    return model


def create_target_collection(client) -> bool:
    """Create the target collection with MiniLM vector configuration.

    Returns True if the collection was created, False if it already existed
    (in which case its current point count is logged).
    """
    from qdrant_client.http.models import Distance, VectorParams

    # Idempotency guard: never clobber an existing target collection.
    collections = client.get_collections().collections
    existing = [c.name for c in collections]
    if TARGET_COLLECTION in existing:
        logger.warning(f"Collection '{TARGET_COLLECTION}' already exists!")
        info = client.get_collection(TARGET_COLLECTION)
        logger.info(f"  Points: {info.points_count}")
        return False

    # Create with MiniLM dimensions
    client.create_collection(
        collection_name=TARGET_COLLECTION,
        vectors_config=VectorParams(
            size=MINILM_DIM,
            distance=Distance.COSINE,
        ),
    )
    logger.info(f"Created collection '{TARGET_COLLECTION}' with {MINILM_DIM}-dim vectors")
    return True


def migrate_documents(client, model, limit: int | None = None) -> dict:
    """Migrate documents from the source to the target collection.

    Scrolls the source collection in batches, re-embeds each document's
    ``payload["text"]`` with MiniLM, and upserts into the target collection.
    Documents with an empty/missing "text" payload are skipped.

    Args:
        client: Qdrant client.
        model: Loaded SentenceTransformer model.
        limit: If set, stop after scrolling at least this many documents.

    Returns:
        Stats dict with counts and cumulative embedding/indexing times (ms).
    """
    from qdrant_client.http.models import PointStruct

    stats = {
        "total_scrolled": 0,
        "total_indexed": 0,
        "batches": 0,
        "embedding_time_ms": 0,
        "indexing_time_ms": 0,
    }

    # Scroll through source collection using cursor-style pagination.
    offset = None
    batch_num = 0
    while True:
        result = client.scroll(
            collection_name=SOURCE_COLLECTION,
            limit=BATCH_SIZE,
            offset=offset,
            with_payload=True,
            with_vectors=False,  # Don't need old (OpenAI) vectors
        )
        points = result[0]
        offset = result[1]

        if not points:
            break

        batch_num += 1
        stats["batches"] = batch_num
        stats["total_scrolled"] += len(points)

        # Extract texts for embedding; keep points and texts aligned.
        texts = []
        valid_points = []
        for point in points:
            text = point.payload.get("text", "")
            if text:
                texts.append(text)
                valid_points.append(point)

        if not texts:
            logger.warning(f"Batch {batch_num}: No valid texts found")
            continue

        # Generate embeddings with MiniLM
        embed_start = time.time()
        embeddings = model.encode(texts, show_progress_bar=False)
        embed_time = (time.time() - embed_start) * 1000
        stats["embedding_time_ms"] += embed_time

        # Create points for the target collection. NOTE: keep the original
        # point id as-is — Qdrant ids must be an unsigned int or a UUID
        # string, so stringifying an integer id would be rejected.
        new_points = [
            PointStruct(
                id=point.id,
                vector=embedding.tolist(),
                payload=point.payload,
            )
            for point, embedding in zip(valid_points, embeddings)
        ]

        # Index into target collection
        index_start = time.time()
        client.upsert(
            collection_name=TARGET_COLLECTION,
            points=new_points,
        )
        index_time = (time.time() - index_start) * 1000
        stats["indexing_time_ms"] += index_time
        stats["total_indexed"] += len(new_points)

        logger.info(
            f"Batch {batch_num}: {len(points)} docs, "
            f"embed={embed_time:.0f}ms ({embed_time/len(texts):.1f}ms/doc), "
            f"index={index_time:.0f}ms, "
            f"total={stats['total_indexed']}"
        )

        # Optional document cap (checked after the batch completes).
        if limit and stats["total_scrolled"] >= limit:
            logger.info(f"Reached limit of {limit} documents")
            break

        # A None offset means the scroll cursor is exhausted.
        if offset is None:
            break

    return stats


def test_search_performance(client, model, queries: list[str]) -> dict:
    """Run each query against the target collection and log timing.

    Returns a list of per-query dicts with embed/search/total times (ms),
    result count, and the top result's "name" payload field (or None).
    """
    results = []

    for query in queries:
        # Embed query
        embed_start = time.time()
        query_vector = model.encode(query).tolist()
        embed_time = (time.time() - embed_start) * 1000

        # Search
        search_start = time.time()
        search_results = client.search(
            collection_name=TARGET_COLLECTION,
            query_vector=query_vector,
            limit=10,
        )
        search_time = (time.time() - search_start) * 1000

        total_time = embed_time + search_time
        results.append({
            "query": query,
            "embed_ms": embed_time,
            "search_ms": search_time,
            "total_ms": total_time,
            "num_results": len(search_results),
            "top_result": search_results[0].payload.get("name") if search_results else None,
        })

        logger.info(
            f"Query: '{query}' → "
            f"embed={embed_time:.0f}ms, search={search_time:.0f}ms, total={total_time:.0f}ms, "
            f"top='{results[-1]['top_result']}'"
        )

    return results


def swap_collections(client):
    """Attempt to 'swap' source and target collection names.

    NOTE(review): Qdrant has no direct collection rename. This function only
    creates an alias pointing at the source collection (it does NOT move any
    data) and then logs instructions to update calling code to use the new
    collection name. Returns False if the backup alias name is taken.
    """
    backup_name = f"{SOURCE_COLLECTION}_openai_backup"

    # Refuse to proceed if a collection already occupies the backup name.
    collections = [c.name for c in client.get_collections().collections]
    if backup_name in collections:
        logger.error(f"Backup collection '{backup_name}' already exists!")
        return False

    logger.info(f"Renaming '{SOURCE_COLLECTION}' → '{backup_name}'")
    client.update_collection_aliases(
        change_aliases_operations=[
            {"create_alias": {"alias_name": backup_name, "collection_name": SOURCE_COLLECTION}}
        ]
    )

    # This is tricky - Qdrant doesn't have direct rename
    # We'll need to recreate with aliases or just use the new collection name
    logger.warning("NOTE: Qdrant doesn't support direct collection rename.")
    logger.warning(f"Update your code to use '{TARGET_COLLECTION}' instead of '{SOURCE_COLLECTION}'")
    logger.warning("Or update the retriever to use the new collection name.")

    return True


def main():
    """CLI entry point: dry-run, migrate, test, or swap based on flags."""
    parser = argparse.ArgumentParser(description="Migrate Qdrant to MiniLM embeddings")
    parser.add_argument("--dry-run", action="store_true", help="Test without creating")
    parser.add_argument("--migrate", action="store_true", help="Run full migration")
    parser.add_argument("--test-only", action="store_true", help="Test existing target collection")
    parser.add_argument("--limit", type=int, help="Limit number of documents to migrate")
    parser.add_argument("--swap", action="store_true", help="Swap collections (after verification)")
    args = parser.parse_args()

    client = get_qdrant_client()

    # Get source stats
    source_info = client.get_collection(SOURCE_COLLECTION)
    logger.info(f"Source collection '{SOURCE_COLLECTION}':")
    logger.info(f"  Points: {source_info.points_count}")
    logger.info(f"  Vector size: {source_info.config.params.vectors.size}")

    if args.dry_run:
        logger.info("\n=== DRY RUN ===")
        logger.info(f"Would migrate {source_info.points_count} documents")
        logger.info(f"Target: {TARGET_COLLECTION} with {MINILM_DIM}-dim vectors")

        # Test MiniLM loading and embedding
        model = load_minilm_model()
        test_text = "archives in Utrecht"
        start = time.time()
        embedding = model.encode(test_text)
        logger.info(f"Test embedding: {(time.time()-start)*1000:.0f}ms, dim={len(embedding)}")
        return

    if args.test_only:
        logger.info("\n=== TESTING TARGET COLLECTION ===")
        model = load_minilm_model()
        test_queries = [
            "archives in Utrecht",
            "museums in Amsterdam",
            "library Belgium",
            "Japanese art museum",
        ]
        test_search_performance(client, model, test_queries)
        return

    if args.migrate:
        logger.info("\n=== STARTING MIGRATION ===")
        model = load_minilm_model()

        # Create target collection (prompt before writing into an existing one)
        created = create_target_collection(client)
        if not created:
            response = input("Collection exists. Continue migration anyway? [y/N] ")
            if response.lower() != 'y':
                logger.info("Aborted.")
                return

        # Migrate documents
        start_time = time.time()
        stats = migrate_documents(client, model, limit=args.limit)
        total_time = time.time() - start_time

        logger.info("\n=== MIGRATION COMPLETE ===")
        logger.info(f"Total documents: {stats['total_indexed']}")
        logger.info(f"Total time: {total_time:.1f}s")
        logger.info(f"Embedding time: {stats['embedding_time_ms']/1000:.1f}s")
        logger.info(f"Indexing time: {stats['indexing_time_ms']/1000:.1f}s")
        # Guard against an empty migration (no documents / no valid texts).
        if stats['total_indexed']:
            logger.info(f"Avg per doc: {total_time*1000/stats['total_indexed']:.1f}ms")

        # Test performance
        logger.info("\n=== TESTING PERFORMANCE ===")
        test_queries = [
            "archives in Utrecht",
            "museums in Amsterdam",
            "library Belgium",
        ]
        test_search_performance(client, model, test_queries)
        return

    if args.swap:
        logger.info("\n=== SWAPPING COLLECTIONS ===")
        swap_collections(client)
        return

    parser.print_help()


if __name__ == "__main__":
    main()