glam/scripts/index_institutions_qdrant.py
2025-12-14 17:09:55 +01:00

464 lines
15 KiB
Python

#!/usr/bin/env python3
"""
Index Heritage Institutions in Qdrant
This script reads heritage custodian YAML files and indexes them in Qdrant
for semantic search and RAG-enhanced SPARQL generation.
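Usage:
python scripts/index_institutions_qdrant.py [--data-dir DATA_DIR] [--host HOST] [--port PORT]
[--url URL] [--collection NAME] [--batch-size N] [--recreate] [--dry-run]
Indexing (anything other than --dry-run) requires the OPENAI_API_KEY environment
variable, which is used to compute embeddings.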
Examples:
# Index all institutions from default data directory
python scripts/index_institutions_qdrant.py
# Index from specific directory
python scripts/index_institutions_qdrant.py --data-dir data/custodian/
# Connect to remote Qdrant
python scripts/index_institutions_qdrant.py --host 91.98.224.44 --port 6333
"""
import argparse
import logging
import os
import sys
from pathlib import Path
from typing import Any
import yaml
# Put ``<project root>/src`` at the front of the import path so the
# project's own packages can be imported by this script.
PROJECT_ROOT = Path(__file__).parent.parent
sys.path.insert(0, str(PROJECT_ROOT.joinpath("src")))

# Script-wide logging: INFO level with a timestamped, level-tagged format.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
def load_yaml_file(filepath: Path) -> dict[str, Any] | None:
    """Load a YAML file and return its top-level mapping.

    Loading is best-effort: any read or parse failure is logged as a warning
    and ``None`` is returned, so one bad file does not abort a bulk run.
    ``yaml.safe_load`` is used, so no arbitrary Python objects are constructed.

    Args:
        filepath: Path of the YAML file to read (decoded as UTF-8).

    Returns:
        The parsed mapping, or ``None`` when the file could not be read or
        parsed, is empty, or its top level is not a mapping (e.g. a list).
    """
    try:
        with open(filepath, "r", encoding="utf-8") as f:
            content = yaml.safe_load(f)
    except Exception as e:  # deliberate boundary: skip unreadable/invalid files
        logger.warning(f"Failed to load {filepath}: {e}")
        return None
    if content is None:
        # Empty document: nothing to index, not worth a warning.
        return None
    if not isinstance(content, dict):
        # Callers treat the result as a record and call ``.get`` on it.
        logger.warning(
            f"Skipping {filepath}: top-level YAML is "
            f"{type(content).__name__}, expected a mapping"
        )
        return None
    return content
def extract_institution_text(data: dict[str, Any]) -> str:
"""Extract searchable text from institution data."""
parts = []
# Get original_entry if present (our actual file format)
original = data.get("original_entry", {})
# Name (primary identifier) - try multiple locations
name = (
data.get("custodian_name", {}).get("claim_value") or # New format
data.get("custodian_name", {}).get("emic_name") or
original.get("name") or
data.get("name", "")
)
if name:
parts.append(f"Name: {name}")
# Alternative names from Wikidata enrichment
wikidata = data.get("wikidata_enrichment", {})
labels = wikidata.get("wikidata_labels", {})
if labels:
alt_names = [v for k, v in labels.items() if v and v != name][:5]
if alt_names:
parts.append(f"Also known as: {', '.join(set(alt_names))}")
# Description from Wikidata
description = wikidata.get("wikidata_description_en", "")
if not description:
descriptions = wikidata.get("wikidata_descriptions", {})
description = descriptions.get("en", "")
if description:
parts.append(description)
# Institution type
inst_type = original.get("institution_type") or data.get("institution_type", "")
if inst_type:
parts.append(f"Type: {inst_type}")
# Location - try multiple locations
locations = original.get("locations", []) or data.get("locations", [])
location = data.get("location", {})
def safe_str(val):
"""Safely convert value to string, handling dicts and None."""
if val is None:
return None
if isinstance(val, str):
return val
if isinstance(val, dict):
# Try common keys for nested location dicts
return val.get("name") or val.get("label") or val.get("value") or str(val)
return str(val)
if locations:
for loc in locations[:1]: # Just first location
if not isinstance(loc, dict):
continue
loc_parts = []
city = safe_str(loc.get("city"))
region = safe_str(loc.get("region"))
country = safe_str(loc.get("country"))
if city:
loc_parts.append(city)
if region:
loc_parts.append(region)
if country:
loc_parts.append(country)
if loc_parts:
parts.append(f"Location: {', '.join(loc_parts)}")
elif location and isinstance(location, dict):
loc_parts = []
city = safe_str(location.get("city"))
region = safe_str(location.get("region"))
country = safe_str(location.get("country"))
if city:
loc_parts.append(city)
if region:
loc_parts.append(region)
if country:
loc_parts.append(country)
if loc_parts:
parts.append(f"Location: {', '.join(loc_parts)}")
# GHCID for reference
ghcid = data.get("ghcid", {}).get("ghcid_current", "")
if ghcid:
parts.append(f"GHCID: {ghcid}")
# Collections
collections = data.get("collections", [])
if collections and isinstance(collections, list):
for coll in collections[:3]: # Limit to first 3
if isinstance(coll, dict):
coll_name = coll.get("name", coll.get("collection_name", ""))
coll_desc = coll.get("description", "")
if coll_name:
parts.append(f"Collection: {coll_name}")
if coll_desc:
parts.append(coll_desc[:200]) # Truncate long descriptions
# MoW inscriptions
mow = original.get("mow_inscriptions", [])
if mow and isinstance(mow, list):
for inscription in mow[:2]:
if isinstance(inscription, dict):
mow_name = inscription.get("name", "")
if mow_name:
parts.append(f"Memory of the World: {mow_name}")
# Combine all parts
return "\n".join(parts)
def extract_metadata(data: dict[str, Any], filepath: Path) -> dict[str, Any]:
"""Extract metadata for filtering from institution data."""
metadata = {
"filename": filepath.name,
}
# Get original_entry if present (our actual file format)
original = data.get("original_entry", {})
# GHCID
ghcid = data.get("ghcid", {}).get("ghcid_current", "")
if ghcid:
metadata["ghcid"] = ghcid
# Name - try multiple locations
name = (
data.get("custodian_name", {}).get("claim_value") or
data.get("custodian_name", {}).get("emic_name") or
original.get("name") or
data.get("name", "")
)
if name:
metadata["name"] = name
# Institution type
inst_type = original.get("institution_type") or data.get("institution_type", "")
if inst_type:
metadata["institution_type"] = inst_type
# Location - try multiple locations
locations = original.get("locations", []) or data.get("locations", [])
location = data.get("location", {})
if locations:
loc = locations[0] # First location
if loc.get("country"):
metadata["country"] = loc["country"]
if loc.get("city"):
metadata["city"] = loc["city"]
if loc.get("region"):
metadata["region"] = loc["region"]
elif location:
if location.get("country"):
metadata["country"] = location["country"]
if location.get("city"):
metadata["city"] = location["city"]
if location.get("region"):
metadata["region"] = location["region"]
# Also extract country from GHCID if not found elsewhere
if not metadata.get("country") and ghcid:
parts = ghcid.split("-")
if parts:
metadata["country"] = parts[0]
# Extract coordinates (latitude/longitude) from multiple sources
# Priority: 1. Top-level lat/lon, 2. Google Maps, 3. Wikidata, 4. locations array
lat = None
lon = None
# 1. Top-level latitude/longitude (most reliable - merged from enrichment)
lat = data.get("latitude")
lon = data.get("longitude")
# 2. Google Maps enrichment
if lat is None or lon is None:
google_maps = data.get("google_maps_enrichment", {})
coords = google_maps.get("coordinates", {})
if coords:
lat = lat or coords.get("latitude")
lon = lon or coords.get("longitude")
# 3. Wikidata enrichment (headquarter_location.coordinates or wikidata_coordinates)
if lat is None or lon is None:
wikidata = data.get("wikidata_enrichment", {})
# Try wikidata_coordinates first (YAML anchor reference)
wikidata_coords = wikidata.get("wikidata_coordinates", {})
if wikidata_coords:
lat = lat or wikidata_coords.get("latitude")
lon = lon or wikidata_coords.get("longitude")
# Then try headquarter_location.coordinates
hq_location = wikidata.get("headquarter_location", {})
hq_coords = hq_location.get("coordinates", {})
if hq_coords:
lat = lat or hq_coords.get("latitude")
lon = lon or hq_coords.get("longitude")
# 4. Locations array
if lat is None or lon is None:
if locations and isinstance(locations[0], dict):
loc = locations[0]
lat = lat or loc.get("latitude") or loc.get("lat")
lon = lon or loc.get("longitude") or loc.get("lon")
# Store coordinates if found
if lat is not None:
try:
metadata["latitude"] = float(lat)
except (ValueError, TypeError):
pass
if lon is not None:
try:
metadata["longitude"] = float(lon)
except (ValueError, TypeError):
pass
# Wikidata ID
wikidata_id = (
original.get("wikidata_id") or
data.get("wikidata_enrichment", {}).get("wikidata_entity_id", "")
)
if wikidata_id:
metadata["wikidata_id"] = wikidata_id
# Identifiers
identifiers = data.get("identifiers", [])
for ident in identifiers:
if isinstance(ident, dict):
scheme = ident.get("identifier_scheme", "").lower()
value = ident.get("identifier_value", "")
if scheme and value and scheme not in ["ghcid", "ghcid_uuid", "ghcid_uuid_sha256", "ghcid_numeric", "record_id"]:
metadata[f"id_{scheme}"] = str(value)
return metadata
def find_institution_files(data_dir: Path) -> list[Path]:
    """Find institution YAML files under *data_dir*, recursively.

    Matches ``*.yaml`` and ``*.yml`` regular files (directories whose names
    end in those suffixes are ignored). Files whose lower-cased name contains
    any of ``_schema``, ``_config``, ``_template``, ``test_`` or ``example_``
    are treated as non-institution files and skipped.

    Args:
        data_dir: Root directory to search.

    Returns:
        The matching paths, deduplicated and sorted.
    """
    excluded_markers = ("_schema", "_config", "_template", "test_", "example_")
    # rglob also covers data_dir itself, so no separate top-level pass is needed.
    candidates = {
        path
        for suffix in ("*.yaml", "*.yml")
        for path in data_dir.rglob(suffix)
        if path.is_file()
    }
    return sorted(
        path
        for path in candidates
        if not any(marker in path.name.lower() for marker in excluded_markers)
    )
def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the CLI parser; QDRANT_HOST/QDRANT_PORT/QDRANT_URL supply defaults."""
    parser = argparse.ArgumentParser(
        description="Index heritage institutions in Qdrant for semantic search"
    )
    parser.add_argument(
        "--data-dir",
        type=Path,
        default=PROJECT_ROOT / "data" / "custodian",
        help="Directory containing institution YAML files",
    )
    parser.add_argument(
        "--host",
        default=os.getenv("QDRANT_HOST", "localhost"),
        help="Qdrant server hostname",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=int(os.getenv("QDRANT_PORT", "6333")),
        help="Qdrant REST API port",
    )
    parser.add_argument(
        "--url",
        default=os.getenv("QDRANT_URL", ""),
        help="Full Qdrant URL (e.g., https://bronhouder.nl/qdrant). Overrides host/port.",
    )
    parser.add_argument(
        "--collection",
        default=None,
        help=(
            "Qdrant collection name. The retriever's configured collection is "
            "used; if a different name is given here the run aborts instead of "
            "writing to (or recreating) another collection."
        ),
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=50,
        help="Number of documents to index per batch",
    )
    parser.add_argument(
        "--recreate",
        action="store_true",
        help="Delete and recreate the collection",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Parse files but don't index",
    )
    return parser


def _prepare_documents(files: list[Path]) -> list[dict[str, Any]]:
    """Turn YAML files into ``{"text", "metadata"}`` documents, skipping unusable ones."""
    documents = []
    for filepath in files:
        data = load_yaml_file(filepath)
        # yaml.safe_load may yield a list or scalar; only mappings are records.
        if not data or not isinstance(data, dict):
            continue
        text = extract_institution_text(data)
        if not text or len(text) < 20:
            logger.debug(f"Skipping {filepath.name}: insufficient text")
            continue
        documents.append({
            "text": text,
            "metadata": extract_metadata(data, filepath),
        })
    return documents


def main():
    """Scan the data directory, build documents and index them in Qdrant.

    Exits with status 0 when no files are found or on ``--dry-run``. Exits with
    status 1 when the data directory is missing, the retriever cannot be
    imported, OPENAI_API_KEY is unset, or ``--collection`` does not match the
    retriever's collection. Invalid arguments exit via argparse.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()
    # A non-positive batch size cannot make progress; reject it up front.
    if args.batch_size < 1:
        parser.error("--batch-size must be a positive integer")

    # Check data directory exists
    if not args.data_dir.exists():
        logger.error(f"Data directory not found: {args.data_dir}")
        sys.exit(1)

    # Find institution files
    logger.info(f"Scanning for institution files in {args.data_dir}")
    files = find_institution_files(args.data_dir)
    logger.info(f"Found {len(files)} institution files")
    if not files:
        logger.warning("No institution files found")
        sys.exit(0)

    documents = _prepare_documents(files)
    logger.info(f"Prepared {len(documents)} documents for indexing")

    if args.dry_run:
        logger.info("Dry run - not indexing")
        for doc in documents[:5]:
            logger.info(f"  - {doc['metadata'].get('name', 'Unknown')}: {len(doc['text'])} chars")
        sys.exit(0)

    # Imported lazily so --dry-run works without the indexing dependencies.
    try:
        from glam_extractor.api.qdrant_retriever import HeritageCustodianRetriever
    except ImportError as e:
        logger.error(f"Failed to import retriever: {e}")
        logger.error("Make sure qdrant-client and openai are installed:")
        logger.error("  pip install qdrant-client openai")
        sys.exit(1)

    # Check for OpenAI API key
    if not os.getenv("OPENAI_API_KEY"):
        logger.error("OPENAI_API_KEY environment variable is required for embeddings")
        sys.exit(1)

    # Connect to Qdrant
    if args.url:
        logger.info(f"Connecting to Qdrant at {args.url}")
        retriever = HeritageCustodianRetriever(url=args.url)
    else:
        logger.info(f"Connecting to Qdrant at {args.host}:{args.port}")
        retriever = HeritageCustodianRetriever(
            host=args.host,
            port=args.port,
        )

    # The retriever is constructed without a collection name, so --collection
    # cannot be applied here. On a mismatch, abort; otherwise --recreate would
    # delete the retriever's collection rather than the one named on the CLI.
    # NOTE(review): if the constructor accepts a collection parameter, pass
    # args.collection to it instead — TODO confirm against the retriever API.
    if args.collection is not None and args.collection != retriever.collection_name:
        logger.error(
            f"--collection {args.collection!r} does not match the retriever's "
            f"collection {retriever.collection_name!r}; refusing to continue"
        )
        sys.exit(1)

    # Optionally recreate collection
    if args.recreate:
        logger.warning(f"Deleting collection: {retriever.collection_name}")
        retriever.delete_collection()

    # Ensure collection exists
    retriever.ensure_collection()

    # Index documents
    logger.info(f"Indexing {len(documents)} documents...")
    indexed = retriever.add_documents(documents, batch_size=args.batch_size)

    # Report results
    info = retriever.get_collection_info()
    logger.info("Indexing complete!")
    logger.info(f"  Documents indexed: {indexed}")
    logger.info(f"  Collection status: {info.get('status', 'unknown')}")
    logger.info(f"  Total vectors: {info.get('vectors_count', 0)}")


if __name__ == "__main__":
    main()