glam/scripts/index_institutions_direct.py
2025-12-21 00:01:54 +01:00

449 lines
14 KiB
Python

#!/usr/bin/env python3
"""
Index Heritage Institutions in Qdrant using Direct HTTP API
This script bypasses the qdrant-client library which has issues with reverse proxy URLs.
Uses requests library directly for reliable operation.
Usage:
python scripts/index_institutions_direct.py --data-dir /tmp/dutch_custodians
"""
import argparse
import logging
import os
import sys
import uuid
from pathlib import Path
from typing import Any
import requests
import yaml
from openai import OpenAI
PROJECT_ROOT = Path(__file__).parent.parent
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S"
)
logger = logging.getLogger(__name__)
# Qdrant configuration
QDRANT_BASE_URL = "https://bronhouder.nl/qdrant"
COLLECTION_NAME = "heritage_custodians"
EMBEDDING_MODEL = "text-embedding-3-small"
EMBEDDING_DIM = 1536
def load_yaml_file(filepath: Path) -> dict[str, Any] | None:
    """Parse one YAML document from disk.

    Returns the parsed contents, or None when the file cannot be read or
    parsed; failures are logged as warnings instead of raising so a single
    bad file does not abort a whole indexing run.
    """
    try:
        return yaml.safe_load(filepath.read_text(encoding="utf-8"))
    except Exception as exc:
        logger.warning(f"Failed to load {filepath}: {exc}")
        return None
def extract_institution_text(data: dict[str, Any]) -> str:
    """Build the text blob that gets embedded for semantic search.

    Joins name, alternative names, description, institution type and
    location into newline-separated lines, preferring curated fields and
    falling back to Wikidata enrichment.

    Returns an empty string when the record carries none of those fields.
    """
    parts = []
    original = data.get("original_entry", {})
    # Name: curated claim first, then emic name, then raw entry fields.
    name = (
        data.get("custodian_name", {}).get("claim_value") or
        data.get("custodian_name", {}).get("emic_name") or
        original.get("name") or
        data.get("name", "")
    )
    if name:
        parts.append(f"Name: {name}")
    # Alternative names from Wikidata labels (per-language variants).
    wikidata = data.get("wikidata_enrichment", {})
    labels = wikidata.get("wikidata_labels", {})
    if labels:
        alt_names = [v for v in labels.values() if v and v != name][:5]
        # FIX: dedupe with dict.fromkeys() instead of set() so the order is
        # insertion-stable. set() iteration order varies with string-hash
        # randomization, which made the embedded text (and therefore the
        # embeddings) nondeterministic across runs.
        alt_names = list(dict.fromkeys(alt_names))
        if alt_names:
            parts.append(f"Also known as: {', '.join(alt_names)}")
    # Description: flat English field first, then the per-language map.
    description = wikidata.get("wikidata_description_en", "")
    if not description:
        descriptions = wikidata.get("wikidata_descriptions", {})
        description = descriptions.get("en", "")
    if description:
        parts.append(description)
    # Institution type
    inst_type = original.get("institution_type") or data.get("institution_type", "")
    if inst_type:
        parts.append(f"Type: {inst_type}")
    # Location: first entry of the locations list, then the flat "location"
    # dict, then Wikidata's located_in label.
    locations = original.get("locations", []) or data.get("locations", [])
    location = data.get("location", {})

    def safe_str(val):
        # Location fields may be plain strings or {name/label} dicts.
        if val is None:
            return None
        if isinstance(val, str):
            return val
        if isinstance(val, dict):
            return val.get("name") or val.get("label") or str(val)
        return str(val)

    city = None
    region = None
    country = None
    if locations and isinstance(locations, list):
        loc = locations[0]
        if isinstance(loc, dict):
            city = safe_str(loc.get("city"))
            region = safe_str(loc.get("region"))
            country = safe_str(loc.get("country"))
    if not city:
        city = safe_str(location.get("city"))
    if not region:
        region = safe_str(location.get("region"))
    if not country:
        country = safe_str(location.get("country"))
    # Last resort: use the Wikidata "located in" label as the city.
    if not city and not region:
        wikidata_loc = wikidata.get("located_in", {})
        if wikidata_loc:
            city = safe_str(wikidata_loc.get("label"))
    location_parts = [p for p in [city, region, country] if p]
    if location_parts:
        parts.append(f"Location: {', '.join(location_parts)}")
    return "\n".join(parts)
def extract_metadata(data: dict[str, Any], filepath: Path) -> dict[str, Any]:
    """Assemble the Qdrant payload for one institution record.

    Pulls identifier, name, type, location, coordinates and Wikidata id out
    of the nested YAML structure, preferring curated fields over enrichment
    fallbacks. Only "ghcid" is guaranteed to be present in the result.
    """
    def _text(value):
        # Location fields may be plain strings or {name/label} dicts.
        if value is None:
            return None
        if isinstance(value, str):
            return value
        if isinstance(value, dict):
            return value.get("name") or value.get("label") or str(value)
        return str(value)

    original = data.get("original_entry", {})
    payload: dict[str, Any] = {}

    # Stable identifier: current GHCID, else the original entry's, else
    # fall back to the filename stem.
    payload["ghcid"] = (
        data.get("ghcid", {}).get("ghcid_current")
        or original.get("ghcid")
        or filepath.stem
    )

    display_name = (
        data.get("custodian_name", {}).get("claim_value")
        or data.get("custodian_name", {}).get("emic_name")
        or original.get("name")
        or data.get("name", "")
    )
    if display_name:
        payload["name"] = display_name

    kind = original.get("institution_type") or data.get("institution_type", "")
    if kind:
        payload["institution_type"] = kind

    # Primary location source: first entry of the locations list.
    location_list = original.get("locations", []) or data.get("locations", [])
    if location_list and isinstance(location_list, list):
        first = location_list[0]
        if isinstance(first, dict):
            if first.get("city"):
                payload["city"] = _text(first["city"])
            # Prefer the ISO 3166-2 region code (e.g. "NH", not
            # "Noord-Holland") so payload filtering stays uniform.
            if first.get("region_code"):
                payload["region"] = first["region_code"]
            elif first.get("region"):
                payload["region"] = _text(first["region"])
            if first.get("country"):
                payload["country"] = _text(first["country"])

    # Secondary source: the flat "location" dict, same precedence rules.
    fallback = data.get("location", {})
    if "city" not in payload and fallback.get("city"):
        payload["city"] = _text(fallback["city"])
    if "region" not in payload:
        if fallback.get("region_code"):
            payload["region"] = fallback["region_code"]
        elif fallback.get("region"):
            payload["region"] = _text(fallback["region"])
    if "country" not in payload and fallback.get("country"):
        payload["country"] = _text(fallback["country"])

    # Coordinates: explicit fields, then Google Maps, then Wikidata.
    lat = data.get("latitude")
    lon = data.get("longitude")
    if lat is None or lon is None:
        coords = data.get("google_maps_enrichment", {}).get("coordinates", {})
        if coords:
            lat = lat or coords.get("latitude")
            lon = lon or coords.get("longitude")
    if lat is None or lon is None:
        wd_coords = data.get("wikidata_enrichment", {}).get("wikidata_coordinates", {})
        if wd_coords:
            lat = lat or wd_coords.get("latitude")
            lon = lon or wd_coords.get("longitude")
    for key, value in (("latitude", lat), ("longitude", lon)):
        if value is not None:
            try:
                payload[key] = float(value)
            except (ValueError, TypeError):
                pass  # skip malformed coordinates rather than abort indexing

    qid = (
        original.get("wikidata_id")
        or data.get("wikidata_enrichment", {}).get("wikidata_entity_id", "")
    )
    if qid:
        payload["wikidata_id"] = qid
    return payload
def find_institution_files(data_dir: Path) -> list[Path]:
    """Collect institution YAML files from *data_dir* (non-recursive).

    Hidden files, non-YAML extensions, and schema/config/template/test/
    example files are skipped; results are returned sorted by path. A
    permission error on the directory is logged and yields an empty list.
    """
    excluded_patterns = ("_schema", "_config", "_template", "test_", "example_")
    matches: list[Path] = []
    try:
        for entry in os.listdir(data_dir):
            if entry.startswith('.'):
                continue
            if not entry.endswith(('.yaml', '.yml')):
                continue
            if any(pattern in entry.lower() for pattern in excluded_patterns):
                continue
            candidate = data_dir / entry
            if candidate.is_file():
                matches.append(candidate)
    except PermissionError:
        logger.warning(f"Permission denied accessing {data_dir}")
    return sorted(matches)
def create_collection():
    """Ensure the Qdrant collection exists, creating it when missing.

    Returns True when the collection is present (pre-existing or newly
    created), False when the create request was rejected.
    """
    collection_url = f"{QDRANT_BASE_URL}/collections/{COLLECTION_NAME}"
    # An existing collection answers 200 on GET; anything else -> create it.
    if requests.get(collection_url, timeout=30).status_code == 200:
        logger.info(f"Collection {COLLECTION_NAME} already exists")
        return True
    body = {"vectors": {"size": EMBEDDING_DIM, "distance": "Cosine"}}
    resp = requests.put(collection_url, json=body, timeout=30)
    if resp.status_code not in (200, 201):
        logger.error(f"Failed to create collection: {resp.status_code} - {resp.text}")
        return False
    logger.info(f"Created collection {COLLECTION_NAME}")
    return True
def get_embeddings(texts: list[str], client: OpenAI) -> list[list[float]]:
    """Embed *texts* with the configured OpenAI model.

    Returns one embedding vector per input text, in the same order.
    """
    result = client.embeddings.create(model=EMBEDDING_MODEL, input=texts)
    return [entry.embedding for entry in result.data]
def upsert_points(points: list[dict], timeout: int = 120):
    """PUT one batch of points into the Qdrant collection.

    Returns True on HTTP 200/201; otherwise logs the failure and returns
    False so the caller can decide whether to continue.
    """
    resp = requests.put(
        f"{QDRANT_BASE_URL}/collections/{COLLECTION_NAME}/points",
        json={"points": points},
        timeout=timeout,
    )
    if resp.status_code not in (200, 201):
        logger.error(f"Failed to upsert points: {resp.status_code} - {resp.text}")
        return False
    return True
def main() -> None:
    """CLI entry point: embed institution YAML files and index them in Qdrant.

    Pipeline: parse arguments -> validate env/paths -> load and filter the
    YAML records -> optional dry run -> (re)create the collection -> embed
    texts in batches via OpenAI -> upsert each batch -> report final stats.
    Exits 1 on missing API key, missing data dir, or collection-create
    failure; exits 0 on no files or dry run.
    """
    parser = argparse.ArgumentParser(
        description="Index heritage institutions in Qdrant using direct HTTP API"
    )
    parser.add_argument(
        "--data-dir",
        type=Path,
        default=PROJECT_ROOT / "data" / "custodian",
        help="Directory containing institution YAML files"
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=50,
        help="Number of documents to index per batch"
    )
    parser.add_argument(
        "--recreate",
        action="store_true",
        help="Delete and recreate the collection"
    )
    parser.add_argument(
        "--limit",
        type=int,
        default=None,
        help="Limit number of files to process (for testing)"
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Parse files but don't index"
    )
    args = parser.parse_args()
    # Fail fast on missing credentials before doing any work.
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        logger.error("OPENAI_API_KEY environment variable is required")
        sys.exit(1)
    # Check data directory
    if not args.data_dir.exists():
        logger.error(f"Data directory not found: {args.data_dir}")
        sys.exit(1)
    # Find files
    logger.info(f"Scanning for institution files in {args.data_dir}")
    files = find_institution_files(args.data_dir)
    logger.info(f"Found {len(files)} institution files")
    # NOTE(review): truthiness test means `--limit 0` behaves like no limit.
    if args.limit:
        files = files[:args.limit]
        logger.info(f"Limited to {len(files)} files")
    if not files:
        logger.warning("No institution files found")
        sys.exit(0)
    # Prepare documents: skip unreadable/empty files and records whose
    # extracted text is shorter than 20 chars (too thin to embed usefully).
    documents = []
    for filepath in files:
        data = load_yaml_file(filepath)
        if not data:
            continue
        text = extract_institution_text(data)
        if not text or len(text) < 20:
            continue
        metadata = extract_metadata(data, filepath)
        documents.append({
            "text": text,
            "metadata": metadata,
        })
    logger.info(f"Prepared {len(documents)} documents for indexing")
    # Dry run: show a sample of up to 5 records, then exit without indexing.
    if args.dry_run:
        logger.info("Dry run - not indexing")
        for doc in documents[:5]:
            logger.info(f" - {doc['metadata'].get('name', 'Unknown')}: {len(doc['text'])} chars")
            logger.info(f" Region: {doc['metadata'].get('region', 'N/A')}")
        sys.exit(0)
    # Handle collection: optional delete, then ensure it exists.
    if args.recreate:
        logger.info(f"Deleting collection {COLLECTION_NAME}")
        resp = requests.delete(f"{QDRANT_BASE_URL}/collections/{COLLECTION_NAME}", timeout=30)
        logger.info(f"Delete result: {resp.status_code}")
    if not create_collection():
        sys.exit(1)
    # Initialize OpenAI client
    client = OpenAI(api_key=api_key)
    # Index in batches
    total_indexed = 0
    for i in range(0, len(documents), args.batch_size):
        batch = documents[i:i + args.batch_size]
        texts = [doc["text"] for doc in batch]
        logger.info(f"Processing batch {i // args.batch_size + 1}/{(len(documents) + args.batch_size - 1) // args.batch_size} ({len(batch)} docs)")
        # Get embeddings (one vector per document text, order-preserving)
        embeddings = get_embeddings(texts, client)
        # Prepare points: fresh random UUIDs as point ids (re-running the
        # script therefore adds new points rather than updating old ones),
        # with the extracted metadata as the payload.
        points = []
        for j, (doc, embedding) in enumerate(zip(batch, embeddings)):
            point_id = str(uuid.uuid4())
            points.append({
                "id": point_id,
                "vector": embedding,
                "payload": doc["metadata"]
            })
        # Upsert; a failed batch is logged but the remaining batches proceed.
        if upsert_points(points):
            total_indexed += len(points)
            logger.info(f"Indexed {total_indexed}/{len(documents)} documents")
        else:
            logger.error(f"Failed to index batch starting at {i}")
    # Final stats: best-effort read of the collection info.
    resp = requests.get(f"{QDRANT_BASE_URL}/collections/{COLLECTION_NAME}", timeout=30)
    if resp.status_code == 200:
        info = resp.json().get("result", {})
        # NOTE(review): newer Qdrant versions may report null for
        # "vectors_count" ("points_count" is the usual total) — confirm
        # against the deployed server version.
        vectors_count = info.get("vectors_count", 0)
        logger.info(f"Indexing complete! Collection has {vectors_count} vectors")
    logger.info(f"Total documents indexed: {total_indexed}")
# Script entry point: `python scripts/index_institutions_direct.py --data-dir ...`
if __name__ == "__main__":
    main()