#!/usr/bin/env python3
"""
Qdrant Sync Module - Sync custodian YAML files to Qdrant vector database.

This module syncs all custodian YAML files to Qdrant for semantic search
and RAG-enhanced SPARQL generation.

Usage:
    python -m scripts.sync.qdrant_sync [--dry-run] [--limit N] [--host HOST] [--port PORT]
"""

import argparse
import os
import sys
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Optional

import yaml

# Make the project root importable so `scripts.sync` resolves when this
# module is run directly (must happen before the project import below).
PROJECT_ROOT = Path(__file__).parent.parent.parent
sys.path.insert(0, str(PROJECT_ROOT))

from scripts.sync import BaseSyncer, SyncResult, SyncStatus, DEFAULT_CUSTODIAN_DIR

# Configuration — overridable via environment variables.
QDRANT_HOST = os.getenv("QDRANT_HOST", "localhost")
QDRANT_PORT = int(os.getenv("QDRANT_PORT", "6333"))
QDRANT_URL = os.getenv("QDRANT_URL", "")
COLLECTION_NAME = "heritage_custodians"
BATCH_SIZE = 50
def _safe_str(val: Any) -> Optional[str]:
    """Coerce a YAML value to a display string.

    Returns None for None, the string itself for str, a best-effort
    name/label/value lookup for dicts, and str(val) for anything else.
    """
    if val is None:
        return None
    if isinstance(val, str):
        return val
    if isinstance(val, dict):
        return val.get("name") or val.get("label") or val.get("value") or str(val)
    return str(val)


def _format_location(loc: dict[str, Any]) -> str:
    """Build a 'city, region, country' string from a location dict (may be '')."""
    fields = (
        _safe_str(loc.get("city")),
        _safe_str(loc.get("region")),
        _safe_str(loc.get("country")),
    )
    return ", ".join(part for part in fields if part)


def extract_institution_text(data: dict[str, Any]) -> str:
    """Extract searchable text from institution data.

    Builds a newline-joined block of labelled facts (name, aliases,
    description, type, location, GHCID, collections, MoW inscriptions)
    suitable for embedding and semantic search.

    Args:
        data: Parsed custodian YAML document.

    Returns:
        Newline-joined text; empty string if nothing usable was found.
    """
    parts = []

    original = data.get("original_entry", {})

    # Name: prefer the curated claim, then the emic name, then raw entries.
    name = (
        data.get("custodian_name", {}).get("claim_value") or
        data.get("custodian_name", {}).get("emic_name") or
        original.get("name") or
        data.get("name", "")
    )
    if name:
        parts.append(f"Name: {name}")

    # Alternative names from Wikidata labels (max 5, deduplicated).
    wikidata = data.get("wikidata_enrichment", {})
    labels = wikidata.get("wikidata_labels", {})
    if labels:
        alt_names = [v for v in labels.values() if v and v != name][:5]
        if alt_names:
            # dict.fromkeys dedupes while preserving order; set() would make
            # the embedding text nondeterministic between runs.
            parts.append(f"Also known as: {', '.join(dict.fromkeys(alt_names))}")

    # Description: flat English key first, then the per-language mapping.
    description = wikidata.get("wikidata_description_en", "")
    if not description:
        description = wikidata.get("wikidata_descriptions", {}).get("en", "")
    if description:
        parts.append(description)

    # Institution type
    inst_type = original.get("institution_type") or data.get("institution_type", "")
    if inst_type:
        parts.append(f"Type: {inst_type}")

    # Location: prefer the first entry of original_entry.locations,
    # fall back to the top-level location mapping.
    locations = original.get("locations", []) or data.get("locations", [])
    location = data.get("location", {})
    if locations:
        first = locations[0]
        if isinstance(first, dict):
            formatted = _format_location(first)
            if formatted:
                parts.append(f"Location: {formatted}")
    elif location and isinstance(location, dict):
        formatted = _format_location(location)
        if formatted:
            parts.append(f"Location: {formatted}")

    # GHCID (current global heritage custodian identifier)
    ghcid = data.get("ghcid", {}).get("ghcid_current", "")
    if ghcid:
        parts.append(f"GHCID: {ghcid}")

    # Collections (first 3; descriptions truncated to 200 chars)
    collections = data.get("collections", [])
    if collections and isinstance(collections, list):
        for coll in collections[:3]:
            if isinstance(coll, dict):
                coll_name = coll.get("name", coll.get("collection_name", ""))
                coll_desc = coll.get("description", "")
                if coll_name:
                    parts.append(f"Collection: {coll_name}")
                if coll_desc:
                    parts.append(coll_desc[:200])

    # Memory of the World inscriptions (first 2)
    mow = original.get("mow_inscriptions", [])
    if mow and isinstance(mow, list):
        for inscription in mow[:2]:
            if isinstance(inscription, dict):
                mow_name = inscription.get("name", "")
                if mow_name:
                    parts.append(f"Memory of the World: {mow_name}")

    return "\n".join(parts)
def _location_metadata(loc: dict[str, Any]) -> dict[str, Any]:
    """Extract country/city/region filter fields from a location dict."""
    out: dict[str, Any] = {}
    if loc.get("country"):
        out["country"] = loc["country"]
    if loc.get("city"):
        out["city"] = loc["city"]
    # Prefer the ISO 3166-2 region_code (e.g. "NH") over the display name
    # ("Noord-Holland") so region filters stay language-independent.
    if loc.get("region_code"):
        out["region"] = loc["region_code"]
    elif loc.get("region"):
        out["region"] = loc["region"]
    return out


def extract_metadata(data: dict[str, Any], filepath: Path) -> dict[str, Any]:
    """Extract metadata for filtering from institution data.

    Args:
        data: Parsed custodian YAML document.
        filepath: Source YAML file (its name is stored for traceability).

    Returns:
        Flat dict of filterable fields; keys are omitted when absent.
    """
    metadata: dict[str, Any] = {
        "filename": filepath.name,
    }

    original = data.get("original_entry", {})

    # GHCID
    ghcid = data.get("ghcid", {}).get("ghcid_current", "")
    if ghcid:
        metadata["ghcid"] = ghcid

    # Name: same precedence as extract_institution_text.
    name = (
        data.get("custodian_name", {}).get("claim_value") or
        data.get("custodian_name", {}).get("emic_name") or
        original.get("name") or
        data.get("name", "")
    )
    if name:
        metadata["name"] = name

    # Institution type
    inst_type = original.get("institution_type") or data.get("institution_type", "")
    if inst_type:
        metadata["institution_type"] = inst_type

    # Location: first entry of original_entry.locations, else the
    # top-level location mapping. Guard against non-dict YAML entries.
    locations = original.get("locations", []) or data.get("locations", [])
    location = data.get("location", {})

    if locations:
        loc = locations[0]
        if isinstance(loc, dict):
            metadata.update(_location_metadata(loc))
    elif location and isinstance(location, dict):
        metadata.update(_location_metadata(location))

    # Country fallback from the GHCID prefix (e.g. "NL-0001" -> "NL").
    if not metadata.get("country") and ghcid:
        metadata["country"] = ghcid.split("-")[0]

    # Wikidata ID
    wikidata_id = (
        original.get("wikidata_id") or
        data.get("wikidata_enrichment", {}).get("wikidata_entity_id", "")
    )
    if wikidata_id:
        metadata["wikidata_id"] = wikidata_id

    return metadata
class QdrantSyncer(BaseSyncer):
    """Sync custodian YAML files to Qdrant vector database.

    Two lazily-created connections are used: a lightweight ``qdrant_client``
    for health/status checks, and a heavyweight project retriever (with
    embedding support) for the actual indexing.
    """

    database_name = "qdrant"

    def __init__(
        self,
        host: str = QDRANT_HOST,
        port: int = QDRANT_PORT,
        url: str = QDRANT_URL,
        collection: str = COLLECTION_NAME,
        batch_size: int = BATCH_SIZE,
        **kwargs
    ):
        """Store connection settings; clients are created lazily on first use.

        Args:
            host: Qdrant server hostname (ignored when ``url`` is non-empty).
            port: Qdrant REST API port (ignored when ``url`` is non-empty).
            url: Full Qdrant URL; overrides host/port when set.
            collection: Target collection name.
            batch_size: Number of documents per indexing batch.
            **kwargs: Forwarded to ``BaseSyncer``.
        """
        super().__init__(**kwargs)
        self.host = host
        self.port = port
        self.url = url
        self.collection = collection
        self.batch_size = batch_size
        self._retriever = None  # heavyweight retriever (embeddings), lazy
        self._client = None  # lightweight qdrant_client, lazy

    def _get_client(self):
        """Lazy-load the Qdrant client (lightweight, for connection checks).

        Raises:
            ImportError: If ``qdrant_client`` is not installed.
        """
        if self._client is None:
            try:
                from qdrant_client import QdrantClient

                if self.url:
                    self._client = QdrantClient(url=self.url, timeout=10)
                else:
                    self._client = QdrantClient(
                        host=self.host,
                        port=self.port,
                        timeout=10,
                    )
            except ImportError as e:
                self.logger.error(f"Cannot import qdrant_client: {e}")
                raise
        return self._client

    def _get_retriever(self):
        """Lazy-load the retriever (heavyweight, for indexing with embeddings).

        Raises:
            ImportError: If the project retriever package is unavailable.
        """
        if self._retriever is None:
            try:
                # The retriever package lives under src/, which is not on
                # sys.path by default when running as a script module.
                sys.path.insert(0, str(PROJECT_ROOT / "src"))
                from glam_extractor.api.qdrant_retriever import HeritageCustodianRetriever

                if self.url:
                    self._retriever = HeritageCustodianRetriever(url=self.url)
                else:
                    self._retriever = HeritageCustodianRetriever(
                        host=self.host,
                        port=self.port,
                    )
            except ImportError as e:
                self.logger.error(f"Cannot import retriever: {e}")
                raise
        return self._retriever

    def check_connection(self) -> bool:
        """Check if Qdrant is available (uses lightweight client)."""
        try:
            client = self._get_client()
            # Simple health check - just list collections; the result is
            # irrelevant, only that the call succeeds.
            client.get_collections()
            return True
        except Exception as e:
            self.logger.error(f"Qdrant connection failed: {e}")
            return False

    def get_status(self) -> dict:
        """Get Qdrant status (uses lightweight client).

        Returns:
            Dict with ``status`` plus vector/point counts when the target
            collection exists; ``{"status": "unavailable", "error": ...}``
            when Qdrant cannot be reached.
        """
        try:
            client = self._get_client()
            collections = client.get_collections().collections
            collection_names = [c.name for c in collections]

            # Check if our collection exists
            if self.collection in collection_names:
                info = client.get_collection(self.collection)
                return {
                    "status": info.status.value if info.status else "unknown",
                    "vectors_count": info.vectors_count or 0,
                    "points_count": info.points_count or 0,
                    "collection_exists": True,
                }
            else:
                return {
                    "status": "ready",
                    "vectors_count": 0,
                    "points_count": 0,
                    "collection_exists": False,
                    "available_collections": collection_names,
                }
        except Exception as e:
            return {"status": "unavailable", "error": str(e)}

    def sync(self, limit: Optional[int] = None, dry_run: bool = False, recreate: bool = True) -> SyncResult:
        """Sync all YAML files to Qdrant.

        Args:
            limit: If set, only process the first ``limit`` YAML files.
            dry_run: Parse and prepare documents but do not index.
            recreate: Delete and recreate the collection before indexing.

        Returns:
            SyncResult with counts, timing, and status details.
        """
        result = SyncResult(
            database="qdrant",
            status=SyncStatus.IN_PROGRESS,
            start_time=datetime.now(timezone.utc),
        )

        # Check OpenAI API key up front (required for embeddings); skipped
        # for dry runs, which never call the embedding API.
        if not dry_run and not os.getenv("OPENAI_API_KEY"):
            result.status = SyncStatus.FAILED
            result.error_message = "OPENAI_API_KEY environment variable is required for embeddings"
            result.end_time = datetime.now(timezone.utc)
            return result

        # Load YAML files
        yaml_files = self._list_yaml_files()
        if limit:
            yaml_files = yaml_files[:limit]

        self.progress.total_files = len(yaml_files)
        self.progress.current_database = "qdrant"

        self.logger.info(f"Processing {len(yaml_files)} YAML files...")

        # Prepare documents (text + filter metadata) before touching Qdrant.
        documents = []
        for i, filepath in enumerate(yaml_files):
            self.progress.processed_files = i + 1
            self.progress.current_file = filepath.name
            self._report_progress()

            try:
                with open(filepath, "r", encoding="utf-8") as f:
                    data = yaml.safe_load(f)

                if not data:
                    continue

                text = extract_institution_text(data)
                if not text or len(text) < 20:
                    # Too little text to embed usefully; skip silently.
                    continue

                metadata = extract_metadata(data, filepath)

                documents.append({
                    "text": text,
                    "metadata": metadata,
                })
                result.records_succeeded += 1
            except Exception as e:
                self.logger.warning(f"Error processing {filepath.name}: {e}")
                result.records_failed += 1
                self.progress.errors.append(f"{filepath.name}: {str(e)}")

        result.records_processed = len(yaml_files)
        result.details["documents_prepared"] = len(documents)

        if dry_run:
            self.logger.info(f"[DRY RUN] Would index {len(documents)} documents to Qdrant")
            result.status = SyncStatus.SUCCESS
            result.details["dry_run"] = True
            result.end_time = datetime.now(timezone.utc)
            return result

        # Check connection (retriever import may itself fail)
        try:
            retriever = self._get_retriever()
        except Exception as e:
            result.status = SyncStatus.FAILED
            result.error_message = str(e)
            result.end_time = datetime.now(timezone.utc)
            return result

        if not self.check_connection():
            result.status = SyncStatus.FAILED
            result.error_message = f"Cannot connect to Qdrant at {self.url or f'{self.host}:{self.port}'}"
            result.end_time = datetime.now(timezone.utc)
            return result

        # Recreate collection (full re-index) unless told otherwise.
        if recreate:
            self.logger.info(f"Deleting collection: {self.collection}")
            retriever.delete_collection()

        retriever.ensure_collection()

        # Index documents
        self.logger.info(f"Indexing {len(documents)} documents...")
        try:
            indexed = retriever.add_documents(documents, batch_size=self.batch_size)
            result.details["indexed_count"] = indexed

            info = retriever.get_collection_info()
            result.details["final_vectors_count"] = info.get("vectors_count", 0)

            result.status = SyncStatus.SUCCESS
        except Exception as e:
            self.logger.error(f"Indexing failed: {e}")
            result.status = SyncStatus.FAILED
            result.error_message = str(e)

        result.end_time = datetime.now(timezone.utc)
        return result
def main():
    """CLI entry point: parse arguments, run the Qdrant sync, print a summary."""
    parser = argparse.ArgumentParser(description="Sync custodian YAML files to Qdrant")
    parser.add_argument("--dry-run", action="store_true", help="Parse files but don't index")
    parser.add_argument("--limit", type=int, help="Limit number of files to process")
    parser.add_argument("--host", default=QDRANT_HOST, help="Qdrant server hostname")
    parser.add_argument("--port", type=int, default=QDRANT_PORT, help="Qdrant REST API port")
    parser.add_argument("--url", default=QDRANT_URL, help="Full Qdrant URL (overrides host/port)")
    parser.add_argument("--collection", default=COLLECTION_NAME, help="Qdrant collection name")
    parser.add_argument("--batch-size", type=int, default=BATCH_SIZE, help="Documents per batch")
    parser.add_argument("--no-recreate", action="store_true", help="Don't recreate collection")
    args = parser.parse_args()

    import logging

    logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")

    syncer = QdrantSyncer(
        host=args.host,
        port=args.port,
        url=args.url,
        collection=args.collection,
        batch_size=args.batch_size,
    )

    banner = "=" * 60
    print(banner)
    print("Qdrant Sync")
    print(banner)

    # Optional pre-flight status check (skipped for dry runs).
    if not args.dry_run:
        location = args.url or f"{args.host}:{args.port}"
        print(f"Checking connection to {location}...")
        try:
            status = syncer.get_status()
            print(f" Status: {status.get('status', 'unknown')}")
            if status.get('vectors_count'):
                print(f" Vectors: {status['vectors_count']:,}")
        except Exception as e:
            print(f" Connection check failed: {e}")

    result = syncer.sync(limit=args.limit, dry_run=args.dry_run, recreate=not args.no_recreate)

    # Summary report.
    print("\n" + banner)
    print(f"Sync Result: {result.status.value.upper()}")
    print(f" Processed: {result.records_processed}")
    print(f" Succeeded: {result.records_succeeded}")
    print(f" Failed: {result.records_failed}")
    print(f" Documents: {result.details.get('documents_prepared', 0)}")
    print(f" Duration: {result.duration_seconds:.2f}s")
    if result.error_message:
        print(f" Error: {result.error_message}")
    print(banner)


if __name__ == "__main__":
    main()
|