glam/src/glam_extractor/api/qdrant_retriever.py
2025-12-21 00:01:54 +01:00

488 lines
16 KiB
Python

"""
Qdrant Vector Store Retriever for DSPy
Provides semantic search over heritage institution data using Qdrant vector database.
Implements DSPy retriever interface for RAG-enhanced SPARQL generation.
"""
import hashlib
import logging
import os
import uuid
from typing import Any

from qdrant_client import QdrantClient
from qdrant_client.http import models
from qdrant_client.http.models import Distance, VectorParams
logger = logging.getLogger(__name__)
# Default configuration
DEFAULT_QDRANT_HOST = os.getenv("QDRANT_HOST", "localhost")
DEFAULT_QDRANT_PORT = int(os.getenv("QDRANT_PORT", "6333"))
DEFAULT_COLLECTION_NAME = "heritage_institutions"
DEFAULT_EMBEDDING_DIM = 1536 # OpenAI text-embedding-3-small
class QdrantRetriever:
    """Qdrant-based retriever for heritage institution semantic search.

    Implements a retriever compatible with DSPy's retrieval module pattern.
    Uses Qdrant vector database for efficient similarity search.

    Example usage with DSPy:
        retriever = QdrantRetriever()
        # Use in RAG pipeline
        results = retriever("museums in Amsterdam")
    """

    def __init__(
        self,
        host: str = DEFAULT_QDRANT_HOST,
        port: int = DEFAULT_QDRANT_PORT,
        collection_name: str = DEFAULT_COLLECTION_NAME,
        embedding_model: str = "text-embedding-3-small",
        embedding_dim: int = DEFAULT_EMBEDDING_DIM,
        k: int = 5,
        api_key: str | None = None,
        url: str | None = None,
        https: bool = False,
        prefix: str | None = None,
    ) -> None:
        """Initialize Qdrant retriever.

        Args:
            host: Qdrant server hostname
            port: Qdrant REST API port
            collection_name: Name of the Qdrant collection
            embedding_model: OpenAI embedding model name
            embedding_dim: Dimension of embedding vectors
            k: Default number of results to retrieve
            api_key: OpenAI API key for embeddings (falls back to the
                OPENAI_API_KEY environment variable)
            url: Full URL to Qdrant (deprecated, use host/port/prefix instead)
            https: Use HTTPS for connection
            prefix: URL path prefix (e.g., 'qdrant' for /qdrant/*)
        """
        self.host = host
        self.port = port
        self.collection_name = collection_name
        self.embedding_model = embedding_model
        self.embedding_dim = embedding_dim
        self.k = k
        self.api_key = api_key or os.getenv("OPENAI_API_KEY")
        # Initialize Qdrant client; three connection modes are supported.
        if url:
            # Legacy: Direct URL connection (may not work with path prefixes)
            self.client = QdrantClient(url=url, prefer_grpc=False, timeout=60)
            logger.info(f"Initialized QdrantRetriever: {url}/{collection_name}")
        elif https or port == 443:
            # HTTPS connection with optional path prefix -- the correct way
            # to connect via a reverse proxy. Port 443 implies HTTPS.
            self.client = QdrantClient(
                host=host,
                port=port,
                https=True,
                prefix=prefix,
                prefer_grpc=False,
                timeout=30
            )
            prefix_str = f"/{prefix}" if prefix else ""
            logger.info(f"Initialized QdrantRetriever: https://{host}:{port}{prefix_str}/{collection_name}")
        else:
            # Standard host/port connection (local/SSH tunnel)
            self.client = QdrantClient(host=host, port=port)
            logger.info(f"Initialized QdrantRetriever: {host}:{port}/{collection_name}")
        # OpenAI client is created lazily on first embedding request.
        self._openai_client = None

    @property
    def openai_client(self) -> Any:
        """Lazy-load the OpenAI client on first access.

        Raises:
            ImportError: If the ``openai`` package is not installed.
        """
        if self._openai_client is None:
            try:
                import openai
                self._openai_client = openai.OpenAI(api_key=self.api_key)
            except ImportError as exc:
                # Chain the original error so the import failure is traceable.
                raise ImportError("openai package required for embeddings. Install with: pip install openai") from exc
        return self._openai_client

    def _get_embedding(self, text: str) -> list[float]:
        """Get embedding vector for text using OpenAI.

        Args:
            text: Text to embed

        Returns:
            Embedding vector as list of floats
        """
        response = self.openai_client.embeddings.create(
            input=text,
            model=self.embedding_model
        )
        return response.data[0].embedding

    def _get_embeddings_batch(self, texts: list[str]) -> list[list[float]]:
        """Get embedding vectors for multiple texts using OpenAI batch API.

        Args:
            texts: List of texts to embed (max 2048 per batch)

        Returns:
            List of embedding vectors, in the same order as the input texts
        """
        if not texts:
            return []
        response = self.openai_client.embeddings.create(
            input=texts,
            model=self.embedding_model
        )
        # Sort by the API-provided index to guarantee input order is preserved.
        return [item.embedding for item in sorted(response.data, key=lambda x: x.index)]

    def _text_to_id(self, text: str) -> str:
        """Generate a deterministic, Qdrant-valid point ID from text.

        Qdrant only accepts unsigned integers or UUIDs as point IDs, so the
        MD5 digest is formatted as a canonical (hyphenated) UUID string.
        The underlying 128-bit value is identical to the raw hex digest,
        so IDs remain stable for previously indexed documents.

        Args:
            text: Text to hash

        Returns:
            Canonical UUID string derived from the MD5 hash of the text
        """
        return str(uuid.UUID(hex=hashlib.md5(text.encode()).hexdigest()))

    def ensure_collection(self) -> None:
        """Ensure the collection exists, create it (cosine distance) if not."""
        collections = self.client.get_collections().collections
        collection_names = [c.name for c in collections]
        if self.collection_name not in collection_names:
            logger.info(f"Creating collection: {self.collection_name}")
            self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config=VectorParams(
                    size=self.embedding_dim,
                    distance=Distance.COSINE
                )
            )

    def add_documents(
        self,
        documents: list[dict[str, Any]],
        batch_size: int = 100
    ) -> int:
        """Add documents to the collection using batch embeddings.

        Documents without a non-empty 'text' field are skipped. Point IDs are
        derived from the text, so re-adding the same text overwrites the
        existing point (idempotent upsert).

        Args:
            documents: List of documents with 'text' and optional 'metadata' fields
            batch_size: Number of documents to embed and upsert per batch

        Returns:
            Number of documents added
        """
        self.ensure_collection()
        # Filter documents with valid text
        valid_docs = [d for d in documents if d.get("text")]
        total_indexed = 0
        # Process in batches
        for i in range(0, len(valid_docs), batch_size):
            batch = valid_docs[i:i + batch_size]
            texts = [d["text"] for d in batch]
            # Get embeddings in batch (much faster than one at a time)
            embeddings = self._get_embeddings_batch(texts)
            # Create points; the text is stored alongside metadata in payload.
            points = []
            for doc, embedding in zip(batch, embeddings):
                text = doc["text"]
                metadata = doc.get("metadata", {})
                point_id = self._text_to_id(text)
                points.append(models.PointStruct(
                    id=point_id,
                    vector=embedding,
                    payload={
                        "text": text,
                        **metadata
                    }
                ))
            # Upsert batch
            self.client.upsert(
                collection_name=self.collection_name,
                points=points
            )
            total_indexed += len(points)
        logger.info(f"Indexed {total_indexed}/{len(valid_docs)} documents")
        return total_indexed

    def search(
        self,
        query: str,
        k: int | None = None,
        filter_conditions: dict[str, Any] | None = None
    ) -> list[dict[str, Any]]:
        """Search for similar documents.

        Args:
            query: Search query text
            k: Number of results (defaults to self.k when None)
            filter_conditions: Optional exact-match payload filters,
                as a mapping of payload key -> required value

        Returns:
            List of matching documents with 'id', 'score', 'text' and
            'metadata' fields
        """
        # Use an explicit None check so an intentional k=0 is not coerced
        # to the default.
        k = self.k if k is None else k
        query_vector = self._get_embedding(query)
        query_filter = None
        if filter_conditions:
            # All conditions must match (AND semantics).
            query_filter = models.Filter(
                must=[
                    models.FieldCondition(
                        key=key,
                        match=models.MatchValue(value=value)
                    )
                    for key, value in filter_conditions.items()
                ]
            )
        # Use query_points API (qdrant-client >= 1.7)
        results = self.client.query_points(
            collection_name=self.collection_name,
            query=query_vector,
            limit=k,
            with_payload=True,
            query_filter=query_filter
        )
        # Split the payload into the stored text and the remaining metadata.
        # (Avoid reusing 'k' as the comprehension variable -- it would shadow
        # the parameter above.)
        return [
            {
                "id": str(r.id),
                "score": r.score,
                "text": r.payload.get("text", "") if r.payload else "",
                "metadata": {key: value for key, value in r.payload.items() if key != "text"} if r.payload else {}
            }
            for r in results.points
        ]

    def __call__(self, query: str, k: int | None = None) -> list[str]:
        """DSPy-compatible interface for retrieval.

        Args:
            query: Search query
            k: Number of results

        Returns:
            List of passage texts
        """
        results = self.search(query, k=k)
        return [r["text"] for r in results]

    def get_collection_info(self) -> dict[str, Any]:
        """Get collection statistics.

        Returns:
            Collection info dict; on failure, a dict with the error message
            instead of raising.
        """
        try:
            info = self.client.get_collection(self.collection_name)
            return {
                "name": self.collection_name,
                "vectors_count": info.vectors_count,
                "points_count": info.points_count,
                "status": info.status.value if info.status else "unknown"
            }
        except Exception as e:
            logger.warning(f"Failed to get collection info: {e}")
            return {
                "name": self.collection_name,
                "error": str(e)
            }

    def delete_collection(self) -> bool:
        """Delete the collection.

        Returns:
            True if deleted successfully, False otherwise (error is logged)
        """
        try:
            self.client.delete_collection(self.collection_name)
            logger.info(f"Deleted collection: {self.collection_name}")
            return True
        except Exception as e:
            logger.warning(f"Failed to delete collection: {e}")
            return False
class HeritageCustodianRetriever(QdrantRetriever):
    """Specialized retriever for heritage custodian institution data.

    Provides domain-specific search functionality for GLAM institutions.
    """

    def __init__(
        self,
        host: str = DEFAULT_QDRANT_HOST,
        port: int = DEFAULT_QDRANT_PORT,
        **kwargs
    ) -> None:
        """Initialize heritage custodian retriever."""
        # Pin the collection name; every other option passes through
        # unchanged to the base class.
        super().__init__(
            host=host,
            port=port,
            collection_name="heritage_custodians",
            **kwargs
        )

    def search_by_type(
        self,
        query: str,
        institution_type: str,
        k: int = 5
    ) -> list[dict[str, Any]]:
        """Search institutions filtered by type.

        Args:
            query: Search query
            institution_type: Institution type (MUSEUM, LIBRARY, ARCHIVE, etc.)
            k: Number of results

        Returns:
            List of matching institutions
        """
        type_filter = {"institution_type": institution_type}
        return self.search(query=query, k=k, filter_conditions=type_filter)

    def search_by_country(
        self,
        query: str,
        country_code: str,
        k: int = 5
    ) -> list[dict[str, Any]]:
        """Search institutions filtered by country.

        Args:
            query: Search query
            country_code: ISO 3166-1 alpha-2 country code
            k: Number of results

        Returns:
            List of matching institutions
        """
        country_filter = {"country": country_code}
        return self.search(query=query, k=k, filter_conditions=country_filter)

    def add_institution(
        self,
        name: str,
        description: str,
        institution_type: str,
        country: str,
        city: str | None = None,
        ghcid: str | None = None,
        **extra_metadata
    ) -> None:
        """Add a single institution to the index.

        Args:
            name: Institution name
            description: Institution description
            institution_type: Type (MUSEUM, LIBRARY, ARCHIVE, etc.)
            country: ISO 3166-1 alpha-2 country code
            city: City name (optional)
            ghcid: Global Heritage Custodian ID (optional)
            **extra_metadata: Additional metadata fields
        """
        # The searchable text combines name and description.
        searchable = f"{name}. {description}"
        # Base fields first; extra_metadata may override them, while city
        # and ghcid (when given) always take final precedence.
        payload: dict[str, Any] = {
            "name": name,
            "institution_type": institution_type,
            "country": country,
        }
        payload.update(extra_metadata)
        if city:
            payload["city"] = city
        if ghcid:
            payload["ghcid"] = ghcid
        self.add_documents([{"text": searchable, "metadata": payload}])
def create_retriever(
    host: str | None = None,
    port: int | None = None,
    collection_type: str = "heritage_custodians",
    use_production: bool | None = None
) -> QdrantRetriever:
    """Factory function to create a retriever instance.

    Args:
        host: Qdrant server hostname (defaults to env var or localhost)
        port: Qdrant REST API port (defaults to env var or 6333)
        collection_type: Type of collection ('heritage_custodians' or 'general')
        use_production: If True, connect to production server via HTTPS.
            Defaults to QDRANT_USE_PRODUCTION env var or False.

    Returns:
        Configured retriever instance

    Examples:
        # Local development (SSH tunnel or local Docker)
        retriever = create_retriever()
        # Production via HTTPS proxy
        retriever = create_retriever(use_production=True)
        # Or set environment variable:
        # export QDRANT_USE_PRODUCTION=true
        retriever = create_retriever()
    """
    # Resolve the production flag from the environment when not given
    # explicitly.
    if use_production is None:
        env_flag = os.getenv("QDRANT_USE_PRODUCTION", "").lower()
        use_production = env_flag in ("true", "1", "yes")

    if use_production:
        # Production: connect via the HTTPS reverse proxy; the host/port
        # arguments are ignored in this mode.
        connection: dict[str, Any] = {
            "host": os.getenv("QDRANT_PROD_HOST", "bronhouder.nl"),
            "port": 443,
            "https": True,
            "prefix": os.getenv("QDRANT_PROD_PREFIX", "qdrant"),
        }
    else:
        # Local development: fall back to module defaults when unset.
        connection = {
            "host": host or DEFAULT_QDRANT_HOST,
            "port": port or DEFAULT_QDRANT_PORT,
        }

    # Any collection_type other than 'heritage_custodians' yields the
    # general-purpose retriever with its default collection.
    retriever_cls = (
        HeritageCustodianRetriever
        if collection_type == "heritage_custodians"
        else QdrantRetriever
    )
    return retriever_cls(**connection)