glam/scripts/sync/qdrant_sync.py
2025-12-17 11:58:40 +01:00

476 lines
17 KiB
Python

#!/usr/bin/env python3
"""
Qdrant Sync Module - Sync custodian YAML files to Qdrant vector database.
This module syncs all custodian YAML files to Qdrant for semantic search
and RAG-enhanced SPARQL generation.
Usage:
python -m scripts.sync.qdrant_sync [--dry-run] [--limit N] [--host HOST] [--port PORT]
"""
import argparse
import os
import sys
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Optional
import yaml
# Add project root to path
PROJECT_ROOT = Path(__file__).parent.parent.parent
sys.path.insert(0, str(PROJECT_ROOT))
from scripts.sync import BaseSyncer, SyncResult, SyncStatus, DEFAULT_CUSTODIAN_DIR
# Configuration
QDRANT_HOST = os.getenv("QDRANT_HOST", "localhost")
QDRANT_PORT = int(os.getenv("QDRANT_PORT", "6333"))
QDRANT_URL = os.getenv("QDRANT_URL", "")
COLLECTION_NAME = "heritage_custodians"
BATCH_SIZE = 50
def extract_institution_text(data: dict[str, Any]) -> str:
    """Build a newline-joined searchable text blob from institution data.

    Combines name, Wikidata alternative labels and description, institution
    type, primary location, GHCID, up to three collections, and up to two
    Memory of the World inscriptions into one plain-text document suitable
    for embedding.

    Args:
        data: Parsed custodian YAML mapping.

    Returns:
        Newline-joined text; empty string when nothing usable is present.
    """
    parts: list[str] = []
    # `or {}` guards against keys that are present but explicitly null in YAML.
    original = data.get("original_entry") or {}
    # Name: prefer the curated claim, then the emic name, then raw entry fields.
    custodian_name = data.get("custodian_name") or {}
    name = (
        custodian_name.get("claim_value") or
        custodian_name.get("emic_name") or
        original.get("name") or
        data.get("name", "")
    )
    if name:
        parts.append(f"Name: {name}")
    # Alternative names from Wikidata
    wikidata = data.get("wikidata_enrichment") or {}
    labels = wikidata.get("wikidata_labels") or {}
    if labels:
        alt_names = [v for v in labels.values() if v and v != name][:5]
        if alt_names:
            # dict.fromkeys dedupes while preserving insertion order, so the
            # output text is deterministic (set() iteration order changes
            # across runs due to hash randomization).
            parts.append(f"Also known as: {', '.join(dict.fromkeys(alt_names))}")
    # Description: prefer the flat English field, fall back to the language map.
    description = wikidata.get("wikidata_description_en", "")
    if not description:
        description = (wikidata.get("wikidata_descriptions") or {}).get("en", "")
    if description:
        parts.append(description)
    # Institution type
    inst_type = original.get("institution_type") or data.get("institution_type", "")
    if inst_type:
        parts.append(f"Type: {inst_type}")

    def safe_str(val):
        """Coerce a location field (str / dict / other) to a plain string."""
        if val is None:
            return None
        if isinstance(val, str):
            return val
        if isinstance(val, dict):
            return val.get("name") or val.get("label") or val.get("value") or str(val)
        return str(val)

    # Location: the first `locations` entry wins; the flat `location` dict is
    # only consulted when `locations` is empty (matches original precedence).
    locations = original.get("locations", []) or data.get("locations", [])
    location = data.get("location", {})
    loc = None
    if locations:
        loc = locations[0] if isinstance(locations[0], dict) else None
    elif location and isinstance(location, dict):
        loc = location
    if loc:
        loc_parts = [
            s for s in (
                safe_str(loc.get("city")),
                safe_str(loc.get("region")),
                safe_str(loc.get("country")),
            )
            if s
        ]
        if loc_parts:
            parts.append(f"Location: {', '.join(loc_parts)}")
    # GHCID
    ghcid = (data.get("ghcid") or {}).get("ghcid_current", "")
    if ghcid:
        parts.append(f"GHCID: {ghcid}")
    # Collections: first three, name plus description truncated to 200 chars.
    collections = data.get("collections", [])
    if collections and isinstance(collections, list):
        for coll in collections[:3]:
            if isinstance(coll, dict):
                coll_name = coll.get("name", coll.get("collection_name", ""))
                coll_desc = coll.get("description", "")
                if coll_name:
                    parts.append(f"Collection: {coll_name}")
                if coll_desc:
                    parts.append(coll_desc[:200])
    # MoW inscriptions: first two only.
    mow = original.get("mow_inscriptions", [])
    if mow and isinstance(mow, list):
        for inscription in mow[:2]:
            if isinstance(inscription, dict):
                mow_name = inscription.get("name", "")
                if mow_name:
                    parts.append(f"Memory of the World: {mow_name}")
    return "\n".join(parts)
def extract_metadata(data: dict[str, Any], filepath: Path) -> dict[str, Any]:
    """Extract flat filter metadata from institution data for Qdrant payloads.

    Args:
        data: Parsed custodian YAML mapping.
        filepath: Source YAML file; only its name is recorded.

    Returns:
        Dict with ``filename`` plus any of: ``ghcid``, ``name``,
        ``institution_type``, ``country``, ``city``, ``region``,
        ``wikidata_id`` (keys omitted when the value is missing).
    """
    metadata: dict[str, Any] = {
        "filename": filepath.name,
    }
    # `or {}` guards against keys that are present but explicitly null in YAML.
    original = data.get("original_entry") or {}
    # GHCID
    ghcid = (data.get("ghcid") or {}).get("ghcid_current", "")
    if ghcid:
        metadata["ghcid"] = ghcid
    # Name: same precedence as extract_institution_text.
    custodian_name = data.get("custodian_name") or {}
    name = (
        custodian_name.get("claim_value") or
        custodian_name.get("emic_name") or
        original.get("name") or
        data.get("name", "")
    )
    if name:
        metadata["name"] = name
    # Institution type
    inst_type = original.get("institution_type") or data.get("institution_type", "")
    if inst_type:
        metadata["institution_type"] = inst_type
    # Location: first `locations` entry wins; flat `location` dict is the
    # fallback. Non-dict entries are skipped (guard matches
    # extract_institution_text; previously this raised AttributeError).
    locations = original.get("locations", []) or data.get("locations", [])
    location = data.get("location", {})
    loc = None
    if locations:
        loc = locations[0] if isinstance(locations[0], dict) else None
    elif location and isinstance(location, dict):
        loc = location
    if loc:
        if loc.get("country"):
            metadata["country"] = loc["country"]
        if loc.get("city"):
            metadata["city"] = loc["city"]
        # Use region_code (ISO 3166-2) for filtering, fallback to region name
        if loc.get("region_code"):
            metadata["region"] = loc["region_code"]  # e.g., "NH" not "Noord-Holland"
        elif loc.get("region"):
            metadata["region"] = loc["region"]
    # Country fallback: GHCID prefix (e.g. "NL-0001" -> "NL").
    if not metadata.get("country") and ghcid:
        metadata["country"] = ghcid.split("-")[0]
    # Wikidata ID
    wikidata_id = (
        original.get("wikidata_id") or
        (data.get("wikidata_enrichment") or {}).get("wikidata_entity_id", "")
    )
    if wikidata_id:
        metadata["wikidata_id"] = wikidata_id
    return metadata
class QdrantSyncer(BaseSyncer):
    """Sync custodian YAML files to Qdrant vector database.

    Two connection handles are created lazily: a lightweight ``QdrantClient``
    used for health/status checks, and a heavyweight retriever (which loads
    the embedding stack) used only when actually indexing documents.
    """

    # Database identifier reported in SyncResult / progress output.
    database_name = "qdrant"

    def __init__(
        self,
        host: str = QDRANT_HOST,
        port: int = QDRANT_PORT,
        url: str = QDRANT_URL,
        collection: str = COLLECTION_NAME,
        batch_size: int = BATCH_SIZE,
        **kwargs
    ):
        """Store connection settings; no connection is opened here.

        Args:
            host: Qdrant hostname (ignored when ``url`` is non-empty).
            port: Qdrant REST port (ignored when ``url`` is non-empty).
            url: Full Qdrant URL; overrides host/port when set.
            collection: Target collection name.
            batch_size: Number of documents per indexing batch.
            **kwargs: Forwarded to BaseSyncer.
        """
        super().__init__(**kwargs)
        self.host = host
        self.port = port
        self.url = url
        self.collection = collection
        self.batch_size = batch_size
        self._retriever = None  # lazy heavyweight handle (see _get_retriever)
        self._client = None  # lazy lightweight handle (see _get_client)

    def _get_client(self):
        """Lazy-load the Qdrant client (lightweight, for connection checks)."""
        if self._client is None:
            try:
                # Imported lazily so the module can be loaded (e.g. for a dry
                # run) without qdrant_client installed.
                from qdrant_client import QdrantClient
                if self.url:
                    self._client = QdrantClient(url=self.url, timeout=10)
                else:
                    self._client = QdrantClient(
                        host=self.host,
                        port=self.port,
                        timeout=10,
                    )
            except ImportError as e:
                self.logger.error(f"Cannot import qdrant_client: {e}")
                raise
        return self._client

    def _get_retriever(self):
        """Lazy-load the retriever (heavyweight, for indexing with embeddings)."""
        if self._retriever is None:
            try:
                # The retriever package lives under src/; extend sys.path so
                # the import resolves when running this module as a script.
                sys.path.insert(0, str(PROJECT_ROOT / "src"))
                from glam_extractor.api.qdrant_retriever import HeritageCustodianRetriever
                if self.url:
                    self._retriever = HeritageCustodianRetriever(url=self.url)
                else:
                    self._retriever = HeritageCustodianRetriever(
                        host=self.host,
                        port=self.port,
                    )
            except ImportError as e:
                self.logger.error(f"Cannot import retriever: {e}")
                raise
        return self._retriever

    def check_connection(self) -> bool:
        """Check if Qdrant is available (uses lightweight client).

        Returns:
            True when a collections listing succeeds, False otherwise.
        """
        try:
            client = self._get_client()
            # Simple health check - just list collections
            # (the return value is not needed; the call succeeding is enough)
            collections = client.get_collections()
            return True
        except Exception as e:
            self.logger.error(f"Qdrant connection failed: {e}")
            return False

    def get_status(self) -> dict:
        """Get Qdrant status (uses lightweight client).

        Returns:
            Dict with ``status``, vector/point counts and whether the target
            collection exists; ``{"status": "unavailable", "error": ...}``
            when the server cannot be reached.
        """
        try:
            client = self._get_client()
            collections = client.get_collections().collections
            collection_names = [c.name for c in collections]
            # Check if our collection exists
            if self.collection in collection_names:
                info = client.get_collection(self.collection)
                return {
                    "status": info.status.value if info.status else "unknown",
                    "vectors_count": info.vectors_count or 0,
                    "points_count": info.points_count or 0,
                    "collection_exists": True,
                }
            else:
                return {
                    "status": "ready",
                    "vectors_count": 0,
                    "points_count": 0,
                    "collection_exists": False,
                    "available_collections": collection_names,
                }
        except Exception as e:
            return {"status": "unavailable", "error": str(e)}

    def sync(self, limit: Optional[int] = None, dry_run: bool = False, recreate: bool = True) -> SyncResult:
        """Sync all YAML files to Qdrant.

        Args:
            limit: Optional cap on the number of YAML files processed.
            dry_run: Parse and prepare documents but do not index anything.
            recreate: Delete and recreate the collection before indexing.

        Returns:
            SyncResult with per-file counts, status and any error message.
        """
        result = SyncResult(
            database="qdrant",
            status=SyncStatus.IN_PROGRESS,
            start_time=datetime.now(timezone.utc),
        )
        # Check OpenAI API key (required for embeddings)
        if not dry_run and not os.getenv("OPENAI_API_KEY"):
            result.status = SyncStatus.FAILED
            result.error_message = "OPENAI_API_KEY environment variable is required for embeddings"
            result.end_time = datetime.now(timezone.utc)
            return result
        # Load YAML files
        yaml_files = self._list_yaml_files()
        if limit:
            yaml_files = yaml_files[:limit]
        self.progress.total_files = len(yaml_files)
        self.progress.current_database = "qdrant"
        self.logger.info(f"Processing {len(yaml_files)} YAML files...")
        # Prepare documents: parse each YAML file into a text + metadata pair.
        documents = []
        for i, filepath in enumerate(yaml_files):
            self.progress.processed_files = i + 1
            self.progress.current_file = filepath.name
            self._report_progress()
            try:
                with open(filepath, "r", encoding="utf-8") as f:
                    data = yaml.safe_load(f)
                if not data:
                    continue
                text = extract_institution_text(data)
                # Skip near-empty documents: too little text to embed usefully.
                if not text or len(text) < 20:
                    continue
                metadata = extract_metadata(data, filepath)
                documents.append({
                    "text": text,
                    "metadata": metadata,
                })
                result.records_succeeded += 1
            except Exception as e:
                # Per-file failures are logged and recorded but do not abort
                # the whole sync run.
                self.logger.warning(f"Error processing {filepath.name}: {e}")
                result.records_failed += 1
                self.progress.errors.append(f"{filepath.name}: {str(e)}")
        result.records_processed = len(yaml_files)
        result.details["documents_prepared"] = len(documents)
        if dry_run:
            self.logger.info(f"[DRY RUN] Would index {len(documents)} documents to Qdrant")
            result.status = SyncStatus.SUCCESS
            result.details["dry_run"] = True
            result.end_time = datetime.now(timezone.utc)
            return result
        # Check connection (building the retriever first, since it is needed
        # below and its import can itself fail)
        try:
            retriever = self._get_retriever()
        except Exception as e:
            result.status = SyncStatus.FAILED
            result.error_message = str(e)
            result.end_time = datetime.now(timezone.utc)
            return result
        if not self.check_connection():
            result.status = SyncStatus.FAILED
            result.error_message = f"Cannot connect to Qdrant at {self.url or f'{self.host}:{self.port}'}"
            result.end_time = datetime.now(timezone.utc)
            return result
        # Recreate collection
        if recreate:
            self.logger.info(f"Deleting collection: {self.collection}")
            retriever.delete_collection()
            retriever.ensure_collection()
        # Index documents
        self.logger.info(f"Indexing {len(documents)} documents...")
        try:
            indexed = retriever.add_documents(documents, batch_size=self.batch_size)
            result.details["indexed_count"] = indexed
            info = retriever.get_collection_info()
            result.details["final_vectors_count"] = info.get("vectors_count", 0)
            result.status = SyncStatus.SUCCESS
        except Exception as e:
            self.logger.error(f"Indexing failed: {e}")
            result.status = SyncStatus.FAILED
            result.error_message = str(e)
        result.end_time = datetime.now(timezone.utc)
        return result
def main():
    """CLI entry point: parse arguments, run the sync, print a summary."""
    cli = argparse.ArgumentParser(description="Sync custodian YAML files to Qdrant")
    cli.add_argument("--dry-run", action="store_true", help="Parse files but don't index")
    cli.add_argument("--limit", type=int, help="Limit number of files to process")
    cli.add_argument("--host", default=QDRANT_HOST, help="Qdrant server hostname")
    cli.add_argument("--port", type=int, default=QDRANT_PORT, help="Qdrant REST API port")
    cli.add_argument("--url", default=QDRANT_URL, help="Full Qdrant URL (overrides host/port)")
    cli.add_argument("--collection", default=COLLECTION_NAME, help="Qdrant collection name")
    cli.add_argument("--batch-size", type=int, default=BATCH_SIZE, help="Documents per batch")
    cli.add_argument("--no-recreate", action="store_true", help="Don't recreate collection")
    opts = cli.parse_args()

    import logging
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(message)s",
    )

    syncer = QdrantSyncer(
        host=opts.host,
        port=opts.port,
        url=opts.url,
        collection=opts.collection,
        batch_size=opts.batch_size,
    )

    bar = "=" * 60
    print(bar)
    print("Qdrant Sync")
    print(bar)

    if not opts.dry_run:
        # Best-effort pre-flight status report; failures here are non-fatal.
        target = opts.url if opts.url else f"{opts.host}:{opts.port}"
        print(f"Checking connection to {target}...")
        try:
            status = syncer.get_status()
            print(f" Status: {status.get('status', 'unknown')}")
            if status.get('vectors_count'):
                print(f" Vectors: {status['vectors_count']:,}")
        except Exception as e:
            print(f" Connection check failed: {e}")

    result = syncer.sync(limit=opts.limit, dry_run=opts.dry_run, recreate=not opts.no_recreate)

    print("\n" + bar)
    print(f"Sync Result: {result.status.value.upper()}")
    print(f" Processed: {result.records_processed}")
    print(f" Succeeded: {result.records_succeeded}")
    print(f" Failed: {result.records_failed}")
    print(f" Documents: {result.details.get('documents_prepared', 0)}")
    print(f" Duration: {result.duration_seconds:.2f}s")
    if result.error_message:
        print(f" Error: {result.error_message}")
    print(bar)


if __name__ == "__main__":
    main()