#!/usr/bin/env python3
"""
Migrate Qdrant collection from OpenAI embeddings to local MiniLM embeddings.

This script:
1. Creates a new collection with 384-dim vectors (MiniLM)
2. Scrolls through all existing documents
3. Re-embeds with MiniLM and inserts into new collection
4. Tests the new collection performance
5. Optionally swaps collection names

Expected performance improvement: 300-1300ms → 22-25ms per embedding

Usage:
    python migrate_qdrant_to_minilm.py --dry-run   # Test without creating
    python migrate_qdrant_to_minilm.py --migrate   # Run migration
    python migrate_qdrant_to_minilm.py --swap      # Swap collections after verification
"""
|
|
|
|
import argparse
|
|
import logging
|
|
import time
|
|
from typing import Any
|
|
import sys
|
|
|
|
# Configure root logging once at import time; all progress reporting in this
# script goes through this module-level logger.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Configuration
QDRANT_HOST = "localhost"  # Qdrant server host (local instance)
QDRANT_PORT = 6333  # default Qdrant HTTP port
SOURCE_COLLECTION = "heritage_custodians"  # existing OpenAI-embedded collection
TARGET_COLLECTION = "heritage_custodians_minilm"  # new MiniLM-embedded collection
BATCH_SIZE = 100  # Documents per batch
MINILM_DIM = 384  # all-MiniLM-L6-v2 output dimensionality
|
|
|
|
|
|
def get_qdrant_client():
    """Return a Qdrant client connected to the configured host and port."""
    # Imported lazily so the script can print --help without qdrant installed.
    from qdrant_client import QdrantClient

    client = QdrantClient(host=QDRANT_HOST, port=QDRANT_PORT)
    return client
|
|
|
|
|
|
def load_minilm_model():
    """Load and return the all-MiniLM-L6-v2 sentence-transformer model."""
    # Imported lazily: sentence_transformers is heavy and only needed here.
    from sentence_transformers import SentenceTransformer

    logger.info("Loading MiniLM model...")
    t0 = time.time()
    minilm = SentenceTransformer("all-MiniLM-L6-v2")
    logger.info(f"Model loaded in {time.time() - t0:.1f}s")
    return minilm
|
|
|
|
|
|
def create_target_collection(client) -> bool:
    """Create the MiniLM target collection with cosine distance.

    Returns True if the collection was created, False if it already existed
    (in which case its current point count is logged).
    """
    from qdrant_client.http.models import Distance, VectorParams

    # Skip creation if the target name is already taken.
    existing_names = [c.name for c in client.get_collections().collections]

    if TARGET_COLLECTION in existing_names:
        logger.warning(f"Collection '{TARGET_COLLECTION}' already exists!")
        info = client.get_collection(TARGET_COLLECTION)
        logger.info(f" Points: {info.points_count}")
        return False

    # 384-dim cosine vectors to match MiniLM output.
    vector_config = VectorParams(size=MINILM_DIM, distance=Distance.COSINE)
    client.create_collection(
        collection_name=TARGET_COLLECTION,
        vectors_config=vector_config,
    )
    logger.info(f"Created collection '{TARGET_COLLECTION}' with {MINILM_DIM}-dim vectors")
    return True
|
|
|
|
|
|
def migrate_documents(client, model, limit: int | None = None) -> dict:
    """Re-embed documents from the source collection into the target collection.

    Scrolls SOURCE_COLLECTION in batches of BATCH_SIZE, embeds each document's
    "text" payload field with MiniLM, and upserts the results into
    TARGET_COLLECTION. Documents with a missing/empty "text" field are skipped.

    Args:
        client: Qdrant client.
        model: SentenceTransformer producing MINILM_DIM-dim embeddings.
        limit: Optional cap on the number of documents scrolled.

    Returns:
        Stats dict: total_scrolled, total_indexed, batches, and cumulative
        embedding_time_ms / indexing_time_ms.
    """
    from qdrant_client.http.models import PointStruct

    stats = {
        "total_scrolled": 0,
        "total_indexed": 0,
        "batches": 0,
        "embedding_time_ms": 0,
        "indexing_time_ms": 0,
    }

    # Scroll through source collection; offset=None means "start from the top".
    offset = None
    batch_num = 0

    while True:
        # Fetch the next page; old vectors are not needed since we re-embed.
        points, offset = client.scroll(
            collection_name=SOURCE_COLLECTION,
            limit=BATCH_SIZE,
            offset=offset,
            with_payload=True,
            with_vectors=False,
        )

        if not points:
            break

        batch_num += 1
        stats["batches"] = batch_num
        stats["total_scrolled"] += len(points)

        # Keep only documents that actually have text to embed.
        texts = []
        valid_points = []
        for point in points:
            text = point.payload.get("text", "")
            if text:
                texts.append(text)
                valid_points.append(point)

        if not texts:
            logger.warning(f"Batch {batch_num}: No valid texts found")
            # BUGFIX: the original `continue` ran even when offset was None,
            # which re-scrolled from the start of the collection and looped
            # forever if the final batch contained no valid texts.
            if offset is None:
                break
            continue

        # Generate embeddings with MiniLM (batched for throughput).
        embed_start = time.time()
        embeddings = model.encode(texts, show_progress_bar=False)
        embed_time = (time.time() - embed_start) * 1000
        stats["embedding_time_ms"] += embed_time

        # BUGFIX: pass point.id through unchanged. Qdrant point ids must be
        # unsigned ints or UUID strings; str(int_id) is neither and would be
        # rejected on upsert.
        new_points = [
            PointStruct(
                id=point.id,
                vector=embedding.tolist(),
                payload=point.payload,
            )
            for point, embedding in zip(valid_points, embeddings)
        ]

        # Index into target collection.
        index_start = time.time()
        client.upsert(
            collection_name=TARGET_COLLECTION,
            points=new_points,
        )
        index_time = (time.time() - index_start) * 1000
        stats["indexing_time_ms"] += index_time

        stats["total_indexed"] += len(new_points)

        # Progress log with per-document embedding latency.
        logger.info(
            f"Batch {batch_num}: {len(points)} docs, "
            f"embed={embed_time:.0f}ms ({embed_time/len(texts):.1f}ms/doc), "
            f"index={index_time:.0f}ms, "
            f"total={stats['total_indexed']}"
        )

        # Stop early once the requested document limit is reached.
        if limit and stats["total_scrolled"] >= limit:
            logger.info(f"Reached limit of {limit} documents")
            break

        # offset is None when the source collection is exhausted.
        if offset is None:
            break

    return stats
|
|
|
|
|
|
def test_search_performance(client, model, queries: list[str]) -> list[dict]:
    """Run each query against the target collection and measure latency.

    For every query: embed it with the MiniLM model, search TARGET_COLLECTION
    (top 10), and record embed/search/total latency in milliseconds plus the
    number of hits and the top hit's "name" payload field.

    Args:
        client: Qdrant client.
        model: SentenceTransformer used to embed the query text.
        queries: Query strings to benchmark.

    Returns:
        One result dict per query.

    Note:
        Fixed return annotation: the function builds and returns a list of
        dicts, not a single dict.
    """
    results = []

    for query in queries:
        # Embed query
        embed_start = time.time()
        query_vector = model.encode(query).tolist()
        embed_time = (time.time() - embed_start) * 1000

        # Search
        search_start = time.time()
        search_results = client.search(
            collection_name=TARGET_COLLECTION,
            query_vector=query_vector,
            limit=10,
        )
        search_time = (time.time() - search_start) * 1000

        total_time = embed_time + search_time

        results.append({
            "query": query,
            "embed_ms": embed_time,
            "search_ms": search_time,
            "total_ms": total_time,
            "num_results": len(search_results),
            "top_result": search_results[0].payload.get("name") if search_results else None,
        })

        logger.info(
            f"Query: '{query}' → "
            f"embed={embed_time:.0f}ms, search={search_time:.0f}ms, total={total_time:.0f}ms, "
            f"top='{results[-1]['top_result']}'"
        )

    return results
|
|
|
|
|
|
def swap_collections(client):
    """Mark the source collection with a backup alias.

    Qdrant has no direct rename operation, so this does NOT physically swap
    collection names: it creates a '<source>_openai_backup' alias pointing at
    the source collection and logs instructions for the operator to switch
    their code over to the target collection name.

    Returns:
        False if the backup name is already in use; True otherwise.
    """
    backup_name = f"{SOURCE_COLLECTION}_openai_backup"

    # Refuse to proceed if the backup name is already taken.
    existing = [c.name for c in client.get_collections().collections]
    if backup_name in existing:
        logger.error(f"Backup collection '{backup_name}' already exists!")
        return False

    logger.info(f"Renaming '{SOURCE_COLLECTION}' → '{backup_name}'")
    alias_op = {"create_alias": {"alias_name": backup_name, "collection_name": SOURCE_COLLECTION}}
    client.update_collection_aliases(change_aliases_operations=[alias_op])

    # Qdrant cannot rename collections in place; the operator must point the
    # application at the new collection name instead.
    logger.warning("NOTE: Qdrant doesn't support direct collection rename.")
    logger.warning(f"Update your code to use '{TARGET_COLLECTION}' instead of '{SOURCE_COLLECTION}'")
    logger.warning("Or update the retriever to use the new collection name.")

    return True
|
|
|
|
|
|
def main():
    """CLI entry point: dry-run, migrate, test, or swap the collections."""
    parser = argparse.ArgumentParser(description="Migrate Qdrant to MiniLM embeddings")
    parser.add_argument("--dry-run", action="store_true", help="Test without creating")
    parser.add_argument("--migrate", action="store_true", help="Run full migration")
    parser.add_argument("--test-only", action="store_true", help="Test existing target collection")
    parser.add_argument("--limit", type=int, help="Limit number of documents to migrate")
    parser.add_argument("--swap", action="store_true", help="Swap collections (after verification)")
    args = parser.parse_args()

    client = get_qdrant_client()

    # Get source stats (reported up front for every mode).
    source_info = client.get_collection(SOURCE_COLLECTION)
    logger.info(f"Source collection '{SOURCE_COLLECTION}':")
    logger.info(f" Points: {source_info.points_count}")
    # NOTE(review): assumes an unnamed (single) vector config; with named
    # vectors `config.params.vectors` is a dict and `.size` would fail.
    logger.info(f" Vector size: {source_info.config.params.vectors.size}")

    if args.dry_run:
        logger.info("\n=== DRY RUN ===")
        logger.info(f"Would migrate {source_info.points_count} documents")
        logger.info(f"Target: {TARGET_COLLECTION} with {MINILM_DIM}-dim vectors")

        # Test MiniLM loading and embedding latency without touching Qdrant.
        model = load_minilm_model()
        test_text = "archives in Utrecht"
        start = time.time()
        embedding = model.encode(test_text)
        logger.info(f"Test embedding: {(time.time()-start)*1000:.0f}ms, dim={len(embedding)}")
        return

    if args.test_only:
        logger.info("\n=== TESTING TARGET COLLECTION ===")
        model = load_minilm_model()

        test_queries = [
            "archives in Utrecht",
            "museums in Amsterdam",
            "library Belgium",
            "Japanese art museum",
        ]
        test_search_performance(client, model, test_queries)
        return

    if args.migrate:
        logger.info("\n=== STARTING MIGRATION ===")
        model = load_minilm_model()

        # Create target collection (False means it already existed).
        created = create_target_collection(client)
        if not created:
            response = input("Collection exists. Continue migration anyway? [y/N] ")
            if response.lower() != 'y':
                logger.info("Aborted.")
                return

        # Migrate documents
        start_time = time.time()
        stats = migrate_documents(client, model, limit=args.limit)
        total_time = time.time() - start_time

        logger.info("\n=== MIGRATION COMPLETE ===")
        logger.info(f"Total documents: {stats['total_indexed']}")
        logger.info(f"Total time: {total_time:.1f}s")
        logger.info(f"Embedding time: {stats['embedding_time_ms']/1000:.1f}s")
        logger.info(f"Indexing time: {stats['indexing_time_ms']/1000:.1f}s")
        # BUGFIX: guard against ZeroDivisionError when nothing was migrated
        # (e.g. every document lacked a "text" payload field).
        if stats['total_indexed']:
            logger.info(f"Avg per doc: {total_time*1000/stats['total_indexed']:.1f}ms")

        # Test performance
        logger.info("\n=== TESTING PERFORMANCE ===")
        test_queries = [
            "archives in Utrecht",
            "museums in Amsterdam",
            "library Belgium",
        ]
        test_search_performance(client, model, test_queries)
        return

    if args.swap:
        logger.info("\n=== SWAPPING COLLECTIONS ===")
        swap_collections(client)
        return

    parser.print_help()


if __name__ == "__main__":
    main()
|