#!/usr/bin/env python3
|
|
"""
|
|
Index Heritage Institutions in Qdrant using Direct HTTP API
|
|
|
|
This script bypasses the qdrant-client library which has issues with reverse proxy URLs.
|
|
Uses requests library directly for reliable operation.
|
|
|
|
Usage:
|
|
python scripts/index_institutions_direct.py --data-dir /tmp/dutch_custodians
|
|
"""
|
|
|
|
import argparse
|
|
import logging
|
|
import os
|
|
import sys
|
|
import uuid
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
import requests
|
|
import yaml
|
|
from openai import OpenAI
|
|
|
|
# Resolve the repository root from this script's location (scripts/ -> repo root).
PROJECT_ROOT = Path(__file__).parent.parent

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S"
)
logger = logging.getLogger(__name__)

# Qdrant configuration
# NOTE(review): base URL is a reverse-proxy path (see module docstring) —
# confirm it matches the current deployment before reuse.
QDRANT_BASE_URL = "https://bronhouder.nl/qdrant"
COLLECTION_NAME = "heritage_custodians"
# OpenAI embedding model and its fixed output dimensionality; EMBEDDING_DIM
# must match the vector size used when the collection is created.
EMBEDDING_MODEL = "text-embedding-3-small"
EMBEDDING_DIM = 1536
|
|
|
|
|
|
def load_yaml_file(filepath: Path) -> dict[str, Any] | None:
    """Parse *filepath* as YAML; return the contents, or None on any failure.

    Failures are logged as warnings so a single bad file does not abort a
    bulk indexing run (deliberate best-effort behavior).
    """
    try:
        with filepath.open("r", encoding="utf-8") as fh:
            return yaml.safe_load(fh)
    except Exception as exc:
        logger.warning(f"Failed to load {filepath}: {exc}")
        return None
|
|
|
|
|
|
def extract_institution_text(data: dict[str, Any]) -> str:
    """Build a newline-separated, human-readable text blob for embedding.

    Combines the institution's name, alternative Wikidata labels, English
    description, institution type and location into one string. All nested
    lookups are guarded with ``or {}`` because YAML records may contain
    explicit nulls for these keys (``data.get(key, {})`` would return None,
    not the default, and crash on the chained ``.get``).

    Args:
        data: Parsed institution YAML document.

    Returns:
        The searchable text; empty string if the record has no usable fields.
    """
    parts = []

    original = data.get("original_entry") or {}
    custodian_name = data.get("custodian_name") or {}

    # Name: prefer the curated claim, then the emic name, then raw entries.
    name = (
        custodian_name.get("claim_value") or
        custodian_name.get("emic_name") or
        original.get("name") or
        data.get("name", "")
    )
    if name:
        parts.append(f"Name: {name}")

    # Alternative names from Wikidata labels (first 5, deduplicated with
    # dict.fromkeys so the order is stable across runs — set() ordering is
    # randomized by string hash randomization).
    wikidata = data.get("wikidata_enrichment") or {}
    labels = wikidata.get("wikidata_labels") or {}
    if labels:
        alt_names = [v for v in labels.values() if v and v != name][:5]
        if alt_names:
            parts.append(f"Also known as: {', '.join(dict.fromkeys(alt_names))}")

    # Description: flat English field first, then the per-language dict.
    description = wikidata.get("wikidata_description_en", "")
    if not description:
        descriptions = wikidata.get("wikidata_descriptions") or {}
        description = descriptions.get("en", "")
    if description:
        parts.append(description)

    # Institution type
    inst_type = original.get("institution_type") or data.get("institution_type", "")
    if inst_type:
        parts.append(f"Type: {inst_type}")

    # Location: first entry of original locations, then the flat "location"
    # dict, then Wikidata's "located_in" label as a last resort.
    locations = original.get("locations") or data.get("locations") or []
    location = data.get("location") or {}

    def safe_str(val):
        # Normalize str / dict / scalar location values to a display string.
        if val is None:
            return None
        if isinstance(val, str):
            return val
        if isinstance(val, dict):
            return val.get("name") or val.get("label") or str(val)
        return str(val)

    city = region = country = None

    if isinstance(locations, list) and locations:
        loc = locations[0]
        if isinstance(loc, dict):
            city = safe_str(loc.get("city"))
            region = safe_str(loc.get("region"))
            country = safe_str(loc.get("country"))

    if not city:
        city = safe_str(location.get("city"))
    if not region:
        region = safe_str(location.get("region"))
    if not country:
        country = safe_str(location.get("country"))

    # Try wikidata location
    if not city and not region:
        wikidata_loc = wikidata.get("located_in") or {}
        if wikidata_loc:
            city = safe_str(wikidata_loc.get("label"))

    location_parts = [p for p in [city, region, country] if p]
    if location_parts:
        parts.append(f"Location: {', '.join(location_parts)}")

    return "\n".join(parts)
|
|
|
|
|
|
def extract_metadata(data: dict[str, Any], filepath: Path) -> dict[str, Any]:
    """Extract the Qdrant payload (filterable metadata) for one institution.

    Args:
        data: Parsed institution YAML document.
        filepath: Source file; its stem is the GHCID fallback.

    Returns:
        Payload dict with any of: ghcid, name, institution_type, city, region,
        country, latitude, longitude, wikidata_id. Nested lookups are guarded
        with ``or {}`` because YAML records may contain explicit nulls
        (``data.get(key, {})`` would return None and crash on the chained get).
    """
    metadata = {}

    original = data.get("original_entry") or {}

    # GHCID: current id, else the original entry's, else the filename stem.
    ghcid = data.get("ghcid") or {}
    metadata["ghcid"] = ghcid.get("ghcid_current") or original.get("ghcid") or filepath.stem

    # Name: same precedence chain as the embedded text.
    custodian_name = data.get("custodian_name") or {}
    name = (
        custodian_name.get("claim_value") or
        custodian_name.get("emic_name") or
        original.get("name") or
        data.get("name", "")
    )
    if name:
        metadata["name"] = name

    # Institution type
    inst_type = original.get("institution_type") or data.get("institution_type", "")
    if inst_type:
        metadata["institution_type"] = inst_type

    # Location
    locations = original.get("locations") or data.get("locations") or []
    location = data.get("location") or {}

    def safe_str(val):
        # Normalize str / dict / scalar location values to a display string.
        if val is None:
            return None
        if isinstance(val, str):
            return val
        if isinstance(val, dict):
            return val.get("name") or val.get("label") or str(val)
        return str(val)

    if isinstance(locations, list) and locations:
        loc = locations[0]
        if isinstance(loc, dict):
            if loc.get("city"):
                metadata["city"] = safe_str(loc["city"])
            # Use region_code (ISO 3166-2) for filtering, fallback to region name
            if loc.get("region_code"):
                metadata["region"] = loc["region_code"]  # e.g., "NH" not "Noord-Holland"
            elif loc.get("region"):
                metadata["region"] = safe_str(loc["region"])
            if loc.get("country"):
                metadata["country"] = safe_str(loc["country"])

    # Fallback to the flat location dict for anything still missing.
    if "city" not in metadata and location.get("city"):
        metadata["city"] = safe_str(location["city"])
    if "region" not in metadata:
        if location.get("region_code"):
            metadata["region"] = location["region_code"]
        elif location.get("region"):
            metadata["region"] = safe_str(location["region"])
    if "country" not in metadata and location.get("country"):
        metadata["country"] = safe_str(location["country"])

    # Coordinates: top-level first, then Google Maps, then Wikidata enrichment.
    # "is not None" (not "or") so a legitimate 0.0 coordinate is kept.
    lat = data.get("latitude")
    lon = data.get("longitude")

    if lat is None or lon is None:
        coords = (data.get("google_maps_enrichment") or {}).get("coordinates") or {}
        lat = lat if lat is not None else coords.get("latitude")
        lon = lon if lon is not None else coords.get("longitude")

    if lat is None or lon is None:
        wikidata_coords = (data.get("wikidata_enrichment") or {}).get("wikidata_coordinates") or {}
        lat = lat if lat is not None else wikidata_coords.get("latitude")
        lon = lon if lon is not None else wikidata_coords.get("longitude")

    if lat is not None:
        try:
            metadata["latitude"] = float(lat)
        except (ValueError, TypeError):
            pass  # non-numeric coordinate: omit rather than fail the record
    if lon is not None:
        try:
            metadata["longitude"] = float(lon)
        except (ValueError, TypeError):
            pass

    # Wikidata ID
    wikidata_id = (
        original.get("wikidata_id") or
        (data.get("wikidata_enrichment") or {}).get("wikidata_entity_id", "")
    )
    if wikidata_id:
        metadata["wikidata_id"] = wikidata_id

    return metadata
|
|
|
|
|
|
def find_institution_files(data_dir: Path) -> list[Path]:
    """Return sorted institution YAML files in *data_dir* (non-recursive).

    Skips hidden files, non-YAML extensions, directories, and names containing
    schema/config/template/test/example markers. Returns an empty list when
    the directory is unreadable.
    """
    excluded_patterns = ("_schema", "_config", "_template", "test_", "example_")

    def is_valid_file(name: str) -> bool:
        # Only visible .yaml/.yml files whose lowercase name has no excluded marker.
        if not name.endswith((".yaml", ".yml")):
            return False
        if name.startswith("."):
            return False
        name_lower = name.lower()
        return not any(excl in name_lower for excl in excluded_patterns)

    files = []
    try:
        # pathlib iteration instead of os.listdir + manual join: same entries,
        # and entries are already Path objects.
        for filepath in data_dir.iterdir():
            if is_valid_file(filepath.name) and filepath.is_file():
                files.append(filepath)
    except PermissionError:
        logger.warning(f"Permission denied accessing {data_dir}")

    return sorted(files)
|
|
|
|
|
|
def create_collection():
    """Ensure the Qdrant collection exists, creating it if necessary.

    Returns:
        True when the collection already exists or was created; False on a
        failed creation request (the error response is logged).
    """
    collection_url = f"{QDRANT_BASE_URL}/collections/{COLLECTION_NAME}"

    # An existing collection is reused as-is.
    resp = requests.get(collection_url, timeout=30)
    if resp.status_code == 200:
        logger.info(f"Collection {COLLECTION_NAME} already exists")
        return True

    # Create a cosine-distance vector index sized for the embedding model.
    resp = requests.put(
        collection_url,
        json={"vectors": {"size": EMBEDDING_DIM, "distance": "Cosine"}},
        timeout=30,
    )
    if resp.status_code in (200, 201):
        logger.info(f"Created collection {COLLECTION_NAME}")
        return True

    logger.error(f"Failed to create collection: {resp.status_code} - {resp.text}")
    return False
|
|
|
|
|
|
def get_embeddings(texts: list[str], client: OpenAI) -> list[list[float]]:
    """Embed *texts* with the configured OpenAI model, preserving input order."""
    response = client.embeddings.create(model=EMBEDDING_MODEL, input=texts)
    return [item.embedding for item in response.data]
|
|
|
|
|
|
def upsert_points(points: list[dict], timeout: int = 120):
    """PUT a batch of points into the Qdrant collection.

    Args:
        points: Qdrant point dicts (id / vector / payload).
        timeout: Request timeout in seconds.

    Returns:
        True on HTTP 200/201; False otherwise (the error response is logged).
    """
    url = f"{QDRANT_BASE_URL}/collections/{COLLECTION_NAME}/points"
    resp = requests.put(url, json={"points": points}, timeout=timeout)

    if resp.status_code not in (200, 201):
        logger.error(f"Failed to upsert points: {resp.status_code} - {resp.text}")
        return False
    return True
|
|
|
|
|
|
def _build_arg_parser() -> argparse.ArgumentParser:
    """CLI definition for the indexing script."""
    parser = argparse.ArgumentParser(
        description="Index heritage institutions in Qdrant using direct HTTP API"
    )
    parser.add_argument(
        "--data-dir",
        type=Path,
        default=PROJECT_ROOT / "data" / "custodian",
        help="Directory containing institution YAML files"
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=50,
        help="Number of documents to index per batch"
    )
    parser.add_argument(
        "--recreate",
        action="store_true",
        help="Delete and recreate the collection"
    )
    parser.add_argument(
        "--limit",
        type=int,
        default=None,
        help="Limit number of files to process (for testing)"
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Parse files but don't index"
    )
    return parser


def _prepare_documents(files: list[Path]) -> list[dict]:
    """Load each YAML file and build {text, metadata} documents.

    Files that fail to parse, or whose searchable text is shorter than 20
    characters, are skipped.
    """
    documents = []
    for filepath in files:
        data = load_yaml_file(filepath)
        if not data:
            continue
        text = extract_institution_text(data)
        if not text or len(text) < 20:
            continue
        documents.append({
            "text": text,
            "metadata": extract_metadata(data, filepath),
        })
    return documents


def main():
    """Parse args, prepare documents, and index them into Qdrant."""
    args = _build_arg_parser().parse_args()

    # Fail fast on missing credentials / data directory.
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        logger.error("OPENAI_API_KEY environment variable is required")
        sys.exit(1)

    if not args.data_dir.exists():
        logger.error(f"Data directory not found: {args.data_dir}")
        sys.exit(1)

    logger.info(f"Scanning for institution files in {args.data_dir}")
    files = find_institution_files(args.data_dir)
    logger.info(f"Found {len(files)} institution files")

    # "is not None" so an explicit --limit 0 is honored instead of ignored.
    if args.limit is not None:
        files = files[:args.limit]
        logger.info(f"Limited to {len(files)} files")

    if not files:
        logger.warning("No institution files found")
        sys.exit(0)

    documents = _prepare_documents(files)
    logger.info(f"Prepared {len(documents)} documents for indexing")

    if args.dry_run:
        logger.info("Dry run - not indexing")
        for doc in documents[:5]:
            logger.info(f" - {doc['metadata'].get('name', 'Unknown')}: {len(doc['text'])} chars")
            logger.info(f" Region: {doc['metadata'].get('region', 'N/A')}")
        sys.exit(0)

    # Handle collection lifecycle.
    if args.recreate:
        logger.info(f"Deleting collection {COLLECTION_NAME}")
        resp = requests.delete(f"{QDRANT_BASE_URL}/collections/{COLLECTION_NAME}", timeout=30)
        logger.info(f"Delete result: {resp.status_code}")

    if not create_collection():
        sys.exit(1)

    client = OpenAI(api_key=api_key)

    # Index in batches; a failed batch is logged but does not abort the run.
    total_indexed = 0
    num_batches = (len(documents) + args.batch_size - 1) // args.batch_size
    for i in range(0, len(documents), args.batch_size):
        batch = documents[i:i + args.batch_size]
        texts = [doc["text"] for doc in batch]

        logger.info(f"Processing batch {i // args.batch_size + 1}/{num_batches} ({len(batch)} docs)")

        embeddings = get_embeddings(texts, client)

        # Store the embedded text alongside the metadata so search hits can be
        # displayed without re-reading the source YAML files.
        points = [
            {
                "id": str(uuid.uuid4()),
                "vector": embedding,
                "payload": {**doc["metadata"], "text": doc["text"]},
            }
            for doc, embedding in zip(batch, embeddings)
        ]

        if upsert_points(points):
            total_indexed += len(points)
            logger.info(f"Indexed {total_indexed}/{len(documents)} documents")
        else:
            logger.error(f"Failed to index batch starting at {i}")

    # Final stats. NOTE(review): newer Qdrant versions report "points_count"
    # and may return null for "vectors_count" — fall back between the two.
    resp = requests.get(f"{QDRANT_BASE_URL}/collections/{COLLECTION_NAME}", timeout=30)
    if resp.status_code == 200:
        info = resp.json().get("result", {})
        vectors_count = info.get("points_count") or info.get("vectors_count") or 0
        logger.info(f"Indexing complete! Collection has {vectors_count} vectors")

    logger.info(f"Total documents indexed: {total_indexed}")
|
|
|
|
|
|
# Script entry point: only run when executed directly, not on import.
if __name__ == "__main__":
    main()
|