"""
|
|
Qdrant Vector Store Retriever for DSPy
|
|
|
|
Provides semantic search over heritage institution data using Qdrant vector database.
|
|
Implements DSPy retriever interface for RAG-enhanced SPARQL generation.
|
|
"""
|
|
|
|
import hashlib
|
|
import logging
|
|
import os
|
|
from typing import Any
|
|
|
|
from qdrant_client import QdrantClient
|
|
from qdrant_client.http import models
|
|
from qdrant_client.http.models import Distance, VectorParams
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# Default configuration
|
|
DEFAULT_QDRANT_HOST = os.getenv("QDRANT_HOST", "localhost")
|
|
DEFAULT_QDRANT_PORT = int(os.getenv("QDRANT_PORT", "6333"))
|
|
DEFAULT_COLLECTION_NAME = "heritage_institutions"
|
|
DEFAULT_EMBEDDING_DIM = 1536 # OpenAI text-embedding-3-small
|
|
|
|
|
|
class QdrantRetriever:
|
|
"""Qdrant-based retriever for heritage institution semantic search.
|
|
|
|
Implements a retriever compatible with DSPy's retrieval module pattern.
|
|
Uses Qdrant vector database for efficient similarity search.
|
|
|
|
Example usage with DSPy:
|
|
retriever = QdrantRetriever()
|
|
# Use in RAG pipeline
|
|
results = retriever("museums in Amsterdam")
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
host: str = DEFAULT_QDRANT_HOST,
|
|
port: int = DEFAULT_QDRANT_PORT,
|
|
collection_name: str = DEFAULT_COLLECTION_NAME,
|
|
embedding_model: str = "text-embedding-3-small",
|
|
embedding_dim: int = DEFAULT_EMBEDDING_DIM,
|
|
k: int = 5,
|
|
api_key: str | None = None,
|
|
url: str | None = None,
|
|
https: bool = False,
|
|
prefix: str | None = None,
|
|
) -> None:
|
|
"""Initialize Qdrant retriever.
|
|
|
|
Args:
|
|
host: Qdrant server hostname
|
|
port: Qdrant REST API port
|
|
collection_name: Name of the Qdrant collection
|
|
embedding_model: OpenAI embedding model name
|
|
embedding_dim: Dimension of embedding vectors
|
|
k: Number of results to retrieve
|
|
api_key: OpenAI API key for embeddings
|
|
url: Full URL to Qdrant (deprecated, use host/port/prefix instead)
|
|
https: Use HTTPS for connection
|
|
prefix: URL path prefix (e.g., 'qdrant' for /qdrant/*)
|
|
"""
|
|
self.host = host
|
|
self.port = port
|
|
self.collection_name = collection_name
|
|
self.embedding_model = embedding_model
|
|
self.embedding_dim = embedding_dim
|
|
self.k = k
|
|
self.api_key = api_key or os.getenv("OPENAI_API_KEY")
|
|
|
|
# Initialize Qdrant client
|
|
if url:
|
|
# Legacy: Direct URL connection (may not work with path prefixes)
|
|
self.client = QdrantClient(url=url, prefer_grpc=False, timeout=60)
|
|
logger.info(f"Initialized QdrantRetriever: {url}/{collection_name}")
|
|
elif https or port == 443:
|
|
# HTTPS connection with optional path prefix
|
|
# This is the correct way to connect via reverse proxy
|
|
self.client = QdrantClient(
|
|
host=host,
|
|
port=port,
|
|
https=True,
|
|
prefix=prefix,
|
|
prefer_grpc=False,
|
|
timeout=30
|
|
)
|
|
prefix_str = f"/{prefix}" if prefix else ""
|
|
logger.info(f"Initialized QdrantRetriever: https://{host}:{port}{prefix_str}/{collection_name}")
|
|
else:
|
|
# Standard host/port connection (local/SSH tunnel)
|
|
self.client = QdrantClient(host=host, port=port)
|
|
logger.info(f"Initialized QdrantRetriever: {host}:{port}/{collection_name}")
|
|
|
|
# Lazy-load OpenAI client
|
|
self._openai_client = None
|
|
|
|
@property
|
|
def openai_client(self) -> Any:
|
|
"""Lazy-load OpenAI client."""
|
|
if self._openai_client is None:
|
|
try:
|
|
import openai
|
|
self._openai_client = openai.OpenAI(api_key=self.api_key)
|
|
except ImportError:
|
|
raise ImportError("openai package required for embeddings. Install with: pip install openai")
|
|
return self._openai_client
|
|
|
|
def _get_embedding(self, text: str) -> list[float]:
|
|
"""Get embedding vector for text using OpenAI.
|
|
|
|
Args:
|
|
text: Text to embed
|
|
|
|
Returns:
|
|
Embedding vector as list of floats
|
|
"""
|
|
response = self.openai_client.embeddings.create(
|
|
input=text,
|
|
model=self.embedding_model
|
|
)
|
|
return response.data[0].embedding
|
|
|
|
def _get_embeddings_batch(self, texts: list[str]) -> list[list[float]]:
|
|
"""Get embedding vectors for multiple texts using OpenAI batch API.
|
|
|
|
Args:
|
|
texts: List of texts to embed (max 2048 per batch)
|
|
|
|
Returns:
|
|
List of embedding vectors
|
|
"""
|
|
if not texts:
|
|
return []
|
|
|
|
response = self.openai_client.embeddings.create(
|
|
input=texts,
|
|
model=self.embedding_model
|
|
)
|
|
# Return embeddings in same order as input
|
|
return [item.embedding for item in sorted(response.data, key=lambda x: x.index)]
|
|
|
|
def _text_to_id(self, text: str) -> str:
|
|
"""Generate deterministic ID from text.
|
|
|
|
Args:
|
|
text: Text to hash
|
|
|
|
Returns:
|
|
Hex string ID
|
|
"""
|
|
return hashlib.md5(text.encode()).hexdigest()
|
|
|
|
def ensure_collection(self) -> None:
|
|
"""Ensure the collection exists, create if not."""
|
|
collections = self.client.get_collections().collections
|
|
collection_names = [c.name for c in collections]
|
|
|
|
if self.collection_name not in collection_names:
|
|
logger.info(f"Creating collection: {self.collection_name}")
|
|
self.client.create_collection(
|
|
collection_name=self.collection_name,
|
|
vectors_config=VectorParams(
|
|
size=self.embedding_dim,
|
|
distance=Distance.COSINE
|
|
)
|
|
)
|
|
|
|
def add_documents(
|
|
self,
|
|
documents: list[dict[str, Any]],
|
|
batch_size: int = 100
|
|
) -> int:
|
|
"""Add documents to the collection using batch embeddings.
|
|
|
|
Args:
|
|
documents: List of documents with 'text' and optional 'metadata' fields
|
|
batch_size: Number of documents to embed and upsert per batch
|
|
|
|
Returns:
|
|
Number of documents added
|
|
"""
|
|
self.ensure_collection()
|
|
|
|
# Filter documents with valid text
|
|
valid_docs = [d for d in documents if d.get("text")]
|
|
|
|
total_indexed = 0
|
|
|
|
# Process in batches
|
|
for i in range(0, len(valid_docs), batch_size):
|
|
batch = valid_docs[i:i + batch_size]
|
|
texts = [d["text"] for d in batch]
|
|
|
|
# Get embeddings in batch (much faster than one at a time)
|
|
embeddings = self._get_embeddings_batch(texts)
|
|
|
|
# Create points
|
|
points = []
|
|
for doc, embedding in zip(batch, embeddings):
|
|
text = doc["text"]
|
|
metadata = doc.get("metadata", {})
|
|
point_id = self._text_to_id(text)
|
|
|
|
points.append(models.PointStruct(
|
|
id=point_id,
|
|
vector=embedding,
|
|
payload={
|
|
"text": text,
|
|
**metadata
|
|
}
|
|
))
|
|
|
|
# Upsert batch
|
|
self.client.upsert(
|
|
collection_name=self.collection_name,
|
|
points=points
|
|
)
|
|
total_indexed += len(points)
|
|
logger.info(f"Indexed {total_indexed}/{len(valid_docs)} documents")
|
|
|
|
return total_indexed
|
|
|
|
def search(
|
|
self,
|
|
query: str,
|
|
k: int | None = None,
|
|
filter_conditions: dict[str, Any] | None = None
|
|
) -> list[dict[str, Any]]:
|
|
"""Search for similar documents.
|
|
|
|
Args:
|
|
query: Search query text
|
|
k: Number of results (defaults to self.k)
|
|
filter_conditions: Optional Qdrant filter conditions
|
|
|
|
Returns:
|
|
List of matching documents with scores
|
|
"""
|
|
k = k or self.k
|
|
|
|
query_vector = self._get_embedding(query)
|
|
|
|
query_filter = None
|
|
if filter_conditions:
|
|
query_filter = models.Filter(
|
|
must=[
|
|
models.FieldCondition(
|
|
key=key,
|
|
match=models.MatchValue(value=value)
|
|
)
|
|
for key, value in filter_conditions.items()
|
|
]
|
|
)
|
|
|
|
# Use query_points API (qdrant-client >= 1.7)
|
|
results = self.client.query_points(
|
|
collection_name=self.collection_name,
|
|
query=query_vector,
|
|
limit=k,
|
|
with_payload=True,
|
|
query_filter=query_filter
|
|
)
|
|
|
|
return [
|
|
{
|
|
"id": str(r.id),
|
|
"score": r.score,
|
|
"text": r.payload.get("text", "") if r.payload else "",
|
|
"metadata": {k: v for k, v in r.payload.items() if k != "text"} if r.payload else {}
|
|
}
|
|
for r in results.points
|
|
]
|
|
|
|
def __call__(self, query: str, k: int | None = None) -> list[str]:
|
|
"""DSPy-compatible interface for retrieval.
|
|
|
|
Args:
|
|
query: Search query
|
|
k: Number of results
|
|
|
|
Returns:
|
|
List of passage texts
|
|
"""
|
|
results = self.search(query, k=k)
|
|
return [r["text"] for r in results]
|
|
|
|
def get_collection_info(self) -> dict[str, Any]:
|
|
"""Get collection statistics.
|
|
|
|
Returns:
|
|
Collection info dict
|
|
"""
|
|
try:
|
|
info = self.client.get_collection(self.collection_name)
|
|
return {
|
|
"name": self.collection_name,
|
|
"vectors_count": info.vectors_count,
|
|
"points_count": info.points_count,
|
|
"status": info.status.value if info.status else "unknown"
|
|
}
|
|
except Exception as e:
|
|
logger.warning(f"Failed to get collection info: {e}")
|
|
return {
|
|
"name": self.collection_name,
|
|
"error": str(e)
|
|
}
|
|
|
|
def delete_collection(self) -> bool:
|
|
"""Delete the collection.
|
|
|
|
Returns:
|
|
True if deleted successfully
|
|
"""
|
|
try:
|
|
self.client.delete_collection(self.collection_name)
|
|
logger.info(f"Deleted collection: {self.collection_name}")
|
|
return True
|
|
except Exception as e:
|
|
logger.warning(f"Failed to delete collection: {e}")
|
|
return False
|
|
|
|
|
|
class HeritageCustodianRetriever(QdrantRetriever):
    """Retriever specialized for heritage custodian (GLAM) institutions.

    Thin wrapper around :class:`QdrantRetriever` that pins the collection to
    ``heritage_custodians`` and adds domain-specific search helpers.
    """

    def __init__(
        self,
        host: str = DEFAULT_QDRANT_HOST,
        port: int = DEFAULT_QDRANT_PORT,
        **kwargs
    ) -> None:
        """Initialize heritage custodian retriever."""
        super().__init__(
            host=host,
            port=port,
            collection_name="heritage_custodians",
            **kwargs
        )

    def search_by_type(
        self,
        query: str,
        institution_type: str,
        k: int = 5
    ) -> list[dict[str, Any]]:
        """Search institutions restricted to a single type.

        Args:
            query: Search query
            institution_type: Institution type (MUSEUM, LIBRARY, ARCHIVE, etc.)
            k: Number of results

        Returns:
            List of matching institutions
        """
        type_filter = {"institution_type": institution_type}
        return self.search(query=query, k=k, filter_conditions=type_filter)

    def search_by_country(
        self,
        query: str,
        country_code: str,
        k: int = 5
    ) -> list[dict[str, Any]]:
        """Search institutions restricted to a single country.

        Args:
            query: Search query
            country_code: ISO 3166-1 alpha-2 country code
            k: Number of results

        Returns:
            List of matching institutions
        """
        country_filter = {"country": country_code}
        return self.search(query=query, k=k, filter_conditions=country_filter)

    def add_institution(
        self,
        name: str,
        description: str,
        institution_type: str,
        country: str,
        city: str | None = None,
        ghcid: str | None = None,
        **extra_metadata
    ) -> None:
        """Index a single institution.

        Args:
            name: Institution name
            description: Institution description
            institution_type: Type (MUSEUM, LIBRARY, ARCHIVE, etc.)
            country: ISO 3166-1 alpha-2 country code
            city: City name (optional)
            ghcid: Global Heritage Custodian ID (optional)
            **extra_metadata: Additional metadata fields
        """
        # The embedded text combines name and description for richer matching.
        searchable_text = f"{name}. {description}"

        metadata: dict[str, Any] = {
            "name": name,
            "institution_type": institution_type,
            "country": country,
            **extra_metadata,
        }

        # Optional fields are appended last, so they win over extra_metadata
        # (matching the original conditional assignments).
        for field, value in (("city", city), ("ghcid", ghcid)):
            if value:
                metadata[field] = value

        self.add_documents([{"text": searchable_text, "metadata": metadata}])
|
def create_retriever(
    host: str | None = None,
    port: int | None = None,
    collection_type: str = "heritage_custodians",
    use_production: bool | None = None
) -> QdrantRetriever:
    """Build a configured retriever instance.

    Args:
        host: Qdrant server hostname (defaults to env var or localhost)
        port: Qdrant REST API port (defaults to env var or 6333)
        collection_type: Type of collection ('heritage_custodians' or 'general')
        use_production: If True, connect to the production server via HTTPS.
            Defaults to the QDRANT_USE_PRODUCTION env var or False.

    Returns:
        Configured retriever instance

    Examples:
        # Local development (SSH tunnel or local Docker)
        retriever = create_retriever()

        # Production via HTTPS proxy
        retriever = create_retriever(use_production=True)

        # Or set environment variable:
        # export QDRANT_USE_PRODUCTION=true
        retriever = create_retriever()
    """
    # Resolve production mode from the environment when not given explicitly.
    if use_production is None:
        flag = os.getenv("QDRANT_USE_PRODUCTION", "").lower()
        use_production = flag in ("true", "1", "yes")

    # Both branches construct the same class; pick it once.
    retriever_cls = (
        HeritageCustodianRetriever
        if collection_type == "heritage_custodians"
        else QdrantRetriever
    )

    if use_production:
        # Production: connect via HTTPS reverse proxy. Explicit host/port
        # arguments are intentionally ignored in this mode.
        return retriever_cls(
            host=os.getenv("QDRANT_PROD_HOST", "bronhouder.nl"),
            port=443,
            https=True,
            prefix=os.getenv("QDRANT_PROD_PREFIX", "qdrant"),
        )

    # Local development: fall back to module defaults when unset.
    return retriever_cls(
        host=host or DEFAULT_QDRANT_HOST,
        port=port or DEFAULT_QDRANT_PORT,
    )