"""
Qdrant Vector Store Retriever for DSPy

Provides semantic search over heritage institution data using Qdrant vector database.
Implements DSPy retriever interface for RAG-enhanced SPARQL generation.
"""

import hashlib
import logging
import os
from typing import Any

from qdrant_client import QdrantClient
from qdrant_client.http import models
from qdrant_client.http.models import Distance, VectorParams

logger = logging.getLogger(__name__)

# Default configuration
DEFAULT_QDRANT_HOST = os.getenv("QDRANT_HOST", "localhost")
DEFAULT_QDRANT_PORT = int(os.getenv("QDRANT_PORT", "6333"))
DEFAULT_COLLECTION_NAME = "heritage_institutions"
DEFAULT_EMBEDDING_DIM = 1536  # OpenAI text-embedding-3-small


class QdrantRetriever:
    """Qdrant-based retriever for heritage institution semantic search.

    Implements a retriever compatible with DSPy's retrieval module pattern.
    Uses Qdrant vector database for efficient similarity search.

    Example usage with DSPy:
        retriever = QdrantRetriever()
        # Use in RAG pipeline
        results = retriever("museums in Amsterdam")
    """

    def __init__(
        self,
        host: str = DEFAULT_QDRANT_HOST,
        port: int = DEFAULT_QDRANT_PORT,
        collection_name: str = DEFAULT_COLLECTION_NAME,
        embedding_model: str = "text-embedding-3-small",
        embedding_dim: int = DEFAULT_EMBEDDING_DIM,
        k: int = 5,
        api_key: str | None = None,
        url: str | None = None,
        https: bool = False,
        prefix: str | None = None,
    ) -> None:
        """Initialize Qdrant retriever.

        Args:
            host: Qdrant server hostname
            port: Qdrant REST API port (port 443 implies HTTPS)
            collection_name: Name of the Qdrant collection
            embedding_model: OpenAI embedding model name
            embedding_dim: Dimension of embedding vectors (used when the
                collection is created)
            k: Default number of results returned by search() when no k is given
            api_key: OpenAI API key for embeddings; falls back to the
                OPENAI_API_KEY environment variable. This is not a Qdrant key.
            url: Full URL to Qdrant (deprecated, use host/port/prefix instead).
                When set, host/port/https/prefix are ignored.
            https: Use HTTPS for connection
            prefix: URL path prefix (e.g., 'qdrant' for /qdrant/*)
        """
        self.host = host
        self.port = port
        self.collection_name = collection_name
        self.embedding_model = embedding_model
        self.embedding_dim = embedding_dim
        self.k = k
        self.api_key = api_key or os.getenv("OPENAI_API_KEY")

        # Initialize Qdrant client
        if url:
            # Legacy: Direct URL connection (may not work with path prefixes)
            self.client = QdrantClient(url=url, prefer_grpc=False, timeout=60)
            logger.info("Initialized QdrantRetriever: %s/%s", url, collection_name)
        elif https or port == 443:
            # HTTPS connection with optional path prefix.
            # This is the correct way to connect via reverse proxy.
            self.client = QdrantClient(
                host=host,
                port=port,
                https=True,
                prefix=prefix,
                prefer_grpc=False,
                timeout=30,
            )
            prefix_str = f"/{prefix}" if prefix else ""
            logger.info(
                "Initialized QdrantRetriever: https://%s:%s%s/%s",
                host, port, prefix_str, collection_name,
            )
        else:
            # Standard host/port connection (local/SSH tunnel).
            # The prefix is forwarded here too so a path-prefixed plain-HTTP proxy
            # works; prefix=None is the client's default, so nothing changes otherwise.
            self.client = QdrantClient(host=host, port=port, prefix=prefix)
            prefix_str = f"/{prefix}" if prefix else ""
            logger.info(
                "Initialized QdrantRetriever: %s:%s%s/%s",
                host, port, prefix_str, collection_name,
            )

        # Lazy-load OpenAI client (created on first access of openai_client)
        self._openai_client = None

    @property
    def openai_client(self) -> Any:
        """Return the OpenAI client, creating it on first access.

        Raises:
            ImportError: If the ``openai`` package is not installed.
        """
        if self._openai_client is None:
            try:
                import openai
            except ImportError as err:
                raise ImportError(
                    "openai package required for embeddings. Install with: pip install openai"
                ) from err
            self._openai_client = openai.OpenAI(api_key=self.api_key)
        return self._openai_client

    def _get_embedding(self, text: str) -> list[float]:
        """Get embedding vector for text using OpenAI.

        Args:
            text: Text to embed

        Returns:
            Embedding vector as list of floats
        """
        response = self.openai_client.embeddings.create(
            input=text,
            model=self.embedding_model,
        )
        return response.data[0].embedding

    def _get_embeddings_batch(self, texts: list[str]) -> list[list[float]]:
        """Get embedding vectors for multiple texts using OpenAI batch API.

        Args:
            texts: List of texts to embed (max 2048 per batch; not enforced here)

        Returns:
            List of embedding vectors, in the same order as ``texts``
        """
        if not texts:
            return []

        response = self.openai_client.embeddings.create(
            input=texts,
            model=self.embedding_model,
        )
        # Re-order by each item's index so output lines up with the input order.
        return [item.embedding for item in sorted(response.data, key=lambda x: x.index)]

    def _text_to_id(self, text: str) -> str:
        """Generate a deterministic point ID from text.

        MD5 is used only to derive a stable identifier, not for security.
        Identical texts map to the same ID, so re-adding a text overwrites
        its existing point on upsert.

        Args:
            text: Text to hash

        Returns:
            32-character hex string ID
        """
        return hashlib.md5(text.encode()).hexdigest()

    def ensure_collection(self) -> None:
        """Create the collection (cosine distance, ``embedding_dim``) if it does not exist."""
        existing = {c.name for c in self.client.get_collections().collections}

        if self.collection_name not in existing:
            logger.info("Creating collection: %s", self.collection_name)
            self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config=VectorParams(
                    size=self.embedding_dim,
                    distance=Distance.COSINE,
                ),
            )

    def add_documents(
        self,
        documents: list[dict[str, Any]],
        batch_size: int = 100,
    ) -> int:
        """Add documents to the collection using batch embeddings.

        Documents without a truthy ``text`` field are skipped. Point IDs are
        derived from the text, so re-adding the same text upserts in place.

        Args:
            documents: List of documents with 'text' and optional 'metadata' fields
            batch_size: Number of documents to embed and upsert per batch (>= 1)

        Returns:
            Number of documents added

        Raises:
            ValueError: If ``batch_size`` is less than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        self.ensure_collection()

        # Filter documents with valid text
        valid_docs = [d for d in documents if d.get("text")]
        total_indexed = 0

        # Process in batches
        for start in range(0, len(valid_docs), batch_size):
            batch = valid_docs[start:start + batch_size]
            texts = [d["text"] for d in batch]

            # Get embeddings in batch (much faster than one at a time)
            embeddings = self._get_embeddings_batch(texts)

            points = []
            for doc, embedding in zip(batch, embeddings):
                text = doc["text"]
                # Tolerate an explicit "metadata": None.
                metadata = doc.get("metadata") or {}
                points.append(models.PointStruct(
                    id=self._text_to_id(text),
                    vector=embedding,
                    # "text" is written last so a metadata key of the same name
                    # cannot replace the text that was actually embedded.
                    payload={**metadata, "text": text},
                ))

            # Upsert batch
            self.client.upsert(
                collection_name=self.collection_name,
                points=points,
            )
            total_indexed += len(points)
            logger.info("Indexed %d/%d documents", total_indexed, len(valid_docs))

        return total_indexed

    def search(
        self,
        query: str,
        k: int | None = None,
        filter_conditions: dict[str, Any] | None = None,
    ) -> list[dict[str, Any]]:
        """Search for similar documents.

        Args:
            query: Search query text
            k: Number of results (None or 0 falls back to self.k)
            filter_conditions: Optional mapping of payload field -> required value.
                All conditions must hold. A scalar value requires an exact match;
                a list/tuple/set value matches if the field equals any element.

        Returns:
            List of matching documents, each a dict with keys
            'id', 'score', 'text' and 'metadata' (payload minus 'text')
        """
        k = k or self.k
        query_vector = self._get_embedding(query)

        query_filter = None
        if filter_conditions:
            conditions = []
            for field, value in filter_conditions.items():
                if isinstance(value, (list, tuple, set, frozenset)):
                    match = models.MatchAny(any=list(value))
                else:
                    match = models.MatchValue(value=value)
                conditions.append(models.FieldCondition(key=field, match=match))
            query_filter = models.Filter(must=conditions)

        # Use query_points API (introduced in qdrant-client 1.10)
        results = self.client.query_points(
            collection_name=self.collection_name,
            query=query_vector,
            limit=k,
            with_payload=True,
            query_filter=query_filter,
        )

        documents = []
        for point in results.points:
            payload = point.payload or {}
            documents.append({
                "id": str(point.id),
                "score": point.score,
                "text": payload.get("text", ""),
                "metadata": {key: value for key, value in payload.items() if key != "text"},
            })
        return documents

    def __call__(self, query: str, k: int | None = None) -> list[str]:
        """DSPy-compatible interface for retrieval.

        Args:
            query: Search query
            k: Number of results

        Returns:
            List of passage texts
        """
        return [r["text"] for r in self.search(query, k=k)]

    def get_collection_info(self) -> dict[str, Any]:
        """Get collection statistics.

        Returns:
            Collection info dict; on failure, a dict with 'name' and 'error'
        """
        try:
            info = self.client.get_collection(self.collection_name)
            return {
                "name": self.collection_name,
                # NOTE(review): vectors_count is deprecated in recent Qdrant
                # releases and may be missing; getattr avoids discarding the
                # remaining stats into the error path.
                "vectors_count": getattr(info, "vectors_count", None),
                "points_count": info.points_count,
                "status": info.status.value if info.status else "unknown",
            }
        except Exception as e:
            logger.warning("Failed to get collection info: %s", e)
            return {
                "name": self.collection_name,
                "error": str(e),
            }

    def delete_collection(self) -> bool:
        """Delete the collection.

        Returns:
            True if deleted successfully, False if the client raised
        """
        try:
            self.client.delete_collection(self.collection_name)
            logger.info("Deleted collection: %s", self.collection_name)
            return True
        except Exception as e:
            logger.warning("Failed to delete collection: %s", e)
            return False


class HeritageCustodianRetriever(QdrantRetriever):
    """Specialized retriever for heritage custodian institution data.

    Provides domain-specific search functionality for GLAM institutions.
    """

    def __init__(
        self,
        host: str = DEFAULT_QDRANT_HOST,
        port: int = DEFAULT_QDRANT_PORT,
        **kwargs,
    ) -> None:
        """Initialize heritage custodian retriever.

        Uses the "heritage_custodians" collection unless ``collection_name``
        is passed explicitly in ``kwargs``. Remaining kwargs go to
        QdrantRetriever.
        """
        kwargs.setdefault("collection_name", "heritage_custodians")
        super().__init__(host=host, port=port, **kwargs)

    def search_by_type(
        self,
        query: str,
        institution_type: str,
        k: int = 5,
    ) -> list[dict[str, Any]]:
        """Search institutions filtered by type.

        Args:
            query: Search query
            institution_type: Institution type (MUSEUM, LIBRARY, ARCHIVE, etc.)
            k: Number of results

        Returns:
            List of matching institutions
        """
        return self.search(
            query=query,
            k=k,
            filter_conditions={"institution_type": institution_type},
        )

    def search_by_country(
        self,
        query: str,
        country_code: str,
        k: int = 5,
    ) -> list[dict[str, Any]]:
        """Search institutions filtered by country.

        Args:
            query: Search query
            country_code: ISO 3166-1 alpha-2 country code
            k: Number of results

        Returns:
            List of matching institutions
        """
        return self.search(
            query=query,
            k=k,
            filter_conditions={"country": country_code},
        )

    def add_institution(
        self,
        name: str,
        description: str,
        institution_type: str,
        country: str,
        city: str | None = None,
        ghcid: str | None = None,
        **extra_metadata,
    ) -> None:
        """Add a single institution to the index.

        Args:
            name: Institution name
            description: Institution description
            institution_type: Type (MUSEUM, LIBRARY, ARCHIVE, etc.)
            country: ISO 3166-1 alpha-2 country code
            city: City name (optional)
            ghcid: Global Heritage Custodian ID (optional)
            **extra_metadata: Additional metadata fields (these override
                name/institution_type/country if they reuse those keys)
        """
        # Create searchable text combining name and description
        text = f"{name}. {description}"

        metadata = {
            "name": name,
            "institution_type": institution_type,
            "country": country,
            **extra_metadata,
        }
        if city:
            metadata["city"] = city
        if ghcid:
            metadata["ghcid"] = ghcid

        self.add_documents([{"text": text, "metadata": metadata}])


def create_retriever(
    host: str | None = None,
    port: int | None = None,
    collection_type: str = "heritage_custodians",
    use_production: bool | None = None,
) -> QdrantRetriever:
    """Factory function to create a retriever instance.

    Args:
        host: Qdrant server hostname (defaults to env var or localhost).
            Ignored in production mode.
        port: Qdrant REST API port (defaults to env var or 6333).
            Ignored in production mode.
        collection_type: Type of collection ('heritage_custodians' or 'general');
            any value other than 'heritage_custodians' yields a plain QdrantRetriever.
        use_production: If True, connect to production server via HTTPS
            (host from QDRANT_PROD_HOST, path prefix from QDRANT_PROD_PREFIX).
            Defaults to QDRANT_USE_PRODUCTION env var ("true"/"1"/"yes") or False.

    Returns:
        Configured retriever instance

    Examples:
        # Local development (SSH tunnel or local Docker)
        retriever = create_retriever()

        # Production via HTTPS proxy
        retriever = create_retriever(use_production=True)

        # Or set environment variable:
        # export QDRANT_USE_PRODUCTION=true
        retriever = create_retriever()
    """
    # Check if production mode requested
    if use_production is None:
        flag = os.getenv("QDRANT_USE_PRODUCTION", "").strip().lower()
        use_production = flag in ("true", "1", "yes")

    retriever_cls: type[QdrantRetriever] = (
        HeritageCustodianRetriever
        if collection_type == "heritage_custodians"
        else QdrantRetriever
    )

    if use_production:
        # Production: connect via HTTPS reverse proxy
        prod_host = os.getenv("QDRANT_PROD_HOST", "bronhouder.nl")
        prod_prefix = os.getenv("QDRANT_PROD_PREFIX", "qdrant")
        return retriever_cls(host=prod_host, port=443, https=True, prefix=prod_prefix)

    # Local development
    return retriever_cls(
        host=host or DEFAULT_QDRANT_HOST,
        port=port or DEFAULT_QDRANT_PORT,
    )