""" Multi-Embedding Retriever for Heritage Data Supports multiple embedding models using Qdrant's named vectors feature. This enables: - A/B testing different embedding models - Cost optimization (cheap local embeddings vs paid API embeddings) - Gradual migration between embedding models - Fallback when one model is unavailable Supported Embedding Models: - openai_1536: text-embedding-3-small (1536-dim, $0.02/1M tokens) - minilm_384: all-MiniLM-L6-v2 (384-dim, free/local) - bge_768: bge-base-en-v1.5 (768-dim, free/local, high quality) Collection Architecture: Each collection has named vectors for each embedding model: heritage_custodians: vectors: "openai_1536": VectorParams(size=1536) "minilm_384": VectorParams(size=384) payload: {name, ghcid, institution_type, ...} heritage_persons: vectors: "openai_1536": VectorParams(size=1536) "minilm_384": VectorParams(size=384) payload: {name, headline, custodian_name, ...} Usage: retriever = MultiEmbeddingRetriever() # Search with default model (auto-select based on availability) results = retriever.search("museums in Amsterdam") # Search with specific model results = retriever.search("museums in Amsterdam", using="minilm_384") # A/B test comparison comparison = retriever.compare_models("museums in Amsterdam") """ import hashlib import logging import os from dataclasses import dataclass, field from enum import Enum from typing import Any, Literal logger = logging.getLogger(__name__) class EmbeddingModel(str, Enum): """Supported embedding models with their configurations.""" OPENAI_1536 = "openai_1536" MINILM_384 = "minilm_384" BGE_768 = "bge_768" @property def dimension(self) -> int: """Get the vector dimension for this model.""" dims = { "openai_1536": 1536, "minilm_384": 384, "bge_768": 768, } return dims[self.value] @property def model_name(self) -> str: """Get the actual model name for loading.""" names = { "openai_1536": "text-embedding-3-small", "minilm_384": "all-MiniLM-L6-v2", "bge_768": "BAAI/bge-base-en-v1.5", } return names[self.value] @property def is_local(self) -> bool: """Check if this model runs locally (no API calls).""" return self.value in ("minilm_384", "bge_768") @property def cost_per_1m_tokens(self) -> float: """Approximate cost per 1M tokens (0 for local models).""" costs = { "openai_1536": 0.02, "minilm_384": 0.0, "bge_768": 0.0, } return costs[self.value] @dataclass class MultiEmbeddingConfig: """Configuration for multi-embedding retriever.""" # Qdrant connection qdrant_host: str = "localhost" qdrant_port: int = 6333 qdrant_https: bool = False qdrant_prefix: str | None = None # API keys openai_api_key: str | None = None # Default embedding model preference order # First available model is used if no explicit model is specified model_preference: list[EmbeddingModel] = field(default_factory=lambda: [ EmbeddingModel.MINILM_384, # Free, fast, good quality EmbeddingModel.OPENAI_1536, # Higher quality, paid EmbeddingModel.BGE_768, # Free, high quality, slower ]) # Collection names institutions_collection: str = "heritage_custodians" persons_collection: str = "heritage_persons" # Search defaults default_k: int = 10 class MultiEmbeddingRetriever: """Retriever supporting multiple embedding models via Qdrant named vectors. This class manages multiple embedding models and allows searching with any available model. 


class MultiEmbeddingRetriever:
    """Retriever supporting multiple embedding models via Qdrant named vectors.

    This class manages multiple embedding models and allows searching with any
    available model. It handles:
    - Lazy model loading
    - Automatic model selection based on availability
    - Named vector creation and search
    - A/B testing between models
    """

    def __init__(self, config: MultiEmbeddingConfig | None = None):
        """Initialize the multi-embedding retriever.

        Args:
            config: Configuration options. If None, uses environment variables.
        """
        self.config = config or self._config_from_env()

        # Lazy-loaded clients
        self._qdrant_client = None
        self._openai_client = None
        self._st_models: dict[str, Any] = {}  # Sentence-transformer models

        # Track available models per collection
        self._available_models: dict[str, set[EmbeddingModel]] = {}

        # Track whether each collection uses named vectors (vs. a single unnamed vector)
        self._uses_named_vectors: dict[str, bool] = {}

        logger.info(
            f"MultiEmbeddingRetriever initialized with preference: "
            f"{[m.value for m in self.config.model_preference]}"
        )

    @staticmethod
    def _config_from_env() -> MultiEmbeddingConfig:
        """Create configuration from environment variables."""
        use_production = os.getenv("QDRANT_USE_PRODUCTION", "false").lower() in ("true", "1", "yes")

        if use_production:
            return MultiEmbeddingConfig(
                qdrant_host=os.getenv("QDRANT_PROD_HOST", "bronhouder.nl"),
                qdrant_port=443,
                qdrant_https=True,
                qdrant_prefix=os.getenv("QDRANT_PROD_PREFIX", "qdrant"),
                openai_api_key=os.getenv("OPENAI_API_KEY"),
            )
        else:
            return MultiEmbeddingConfig(
                qdrant_host=os.getenv("QDRANT_HOST", "localhost"),
                qdrant_port=int(os.getenv("QDRANT_PORT", "6333")),
                openai_api_key=os.getenv("OPENAI_API_KEY"),
            )

    @property
    def qdrant_client(self):
        """Lazy-load the Qdrant client."""
        if self._qdrant_client is None:
            from qdrant_client import QdrantClient

            if self.config.qdrant_https:
                self._qdrant_client = QdrantClient(
                    host=self.config.qdrant_host,
                    port=self.config.qdrant_port,
                    https=True,
                    prefix=self.config.qdrant_prefix,
                    prefer_grpc=False,
                    timeout=30,
                )
                logger.info(
                    f"Connected to Qdrant: https://{self.config.qdrant_host}/"
                    f"{self.config.qdrant_prefix or ''}"
                )
            else:
                self._qdrant_client = QdrantClient(
                    host=self.config.qdrant_host,
                    port=self.config.qdrant_port,
                )
                logger.info(f"Connected to Qdrant: {self.config.qdrant_host}:{self.config.qdrant_port}")
        return self._qdrant_client

    @property
    def openai_client(self):
        """Lazy-load the OpenAI client."""
        if self._openai_client is None:
            if not self.config.openai_api_key:
                raise RuntimeError("OpenAI API key not configured")
            import openai
            self._openai_client = openai.OpenAI(api_key=self.config.openai_api_key)
        return self._openai_client

    def _load_sentence_transformer(self, model: EmbeddingModel) -> Any:
        """Lazy-load a sentence-transformers model.

        Args:
            model: The embedding model to load

        Returns:
            Loaded SentenceTransformer model
        """
        if model.value not in self._st_models:
            try:
                from sentence_transformers import SentenceTransformer
                self._st_models[model.value] = SentenceTransformer(model.model_name)
                logger.info(f"Loaded sentence-transformers model: {model.model_name}")
            except ImportError as e:
                raise RuntimeError(
                    "sentence-transformers not installed. Run: pip install sentence-transformers"
                ) from e
        return self._st_models[model.value]
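
    # Usage sketch for the lazy-loading path (kept as a comment so nothing runs
    # at import time; assumes sentence-transformers is installed and uses
    # get_embedding(), defined below):
    #
    #     r = MultiEmbeddingRetriever()
    #     vec = r.get_embedding("museums in Amsterdam", EmbeddingModel.MINILM_384)
    #     len(vec)  # -> 384 (model is downloaded/loaded on first use only)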

    def get_embedding(self, text: str, model: EmbeddingModel) -> list[float]:
        """Get the embedding vector for a text using the specified model.

        Args:
            text: Text to embed
            model: Embedding model to use

        Returns:
            Embedding vector as a list of floats
        """
        if model == EmbeddingModel.OPENAI_1536:
            response = self.openai_client.embeddings.create(
                input=text,
                model=model.model_name,
            )
            return response.data[0].embedding
        elif model in (EmbeddingModel.MINILM_384, EmbeddingModel.BGE_768):
            st_model = self._load_sentence_transformer(model)
            embedding = st_model.encode(text)
            return embedding.tolist()
        else:
            raise ValueError(f"Unknown embedding model: {model}")

    def get_embeddings_batch(
        self,
        texts: list[str],
        model: EmbeddingModel,
        batch_size: int = 32,
    ) -> list[list[float]]:
        """Get embedding vectors for multiple texts.

        Args:
            texts: List of texts to embed
            model: Embedding model to use
            batch_size: Batch size for processing

        Returns:
            List of embedding vectors
        """
        if not texts:
            return []

        if model == EmbeddingModel.OPENAI_1536:
            # OpenAI batch API (max 2048 inputs per request)
            all_embeddings = []
            for i in range(0, len(texts), 2048):
                batch = texts[i:i + 2048]
                response = self.openai_client.embeddings.create(
                    input=batch,
                    model=model.model_name,
                )
                batch_embeddings = [
                    item.embedding
                    for item in sorted(response.data, key=lambda x: x.index)
                ]
                all_embeddings.extend(batch_embeddings)
            return all_embeddings
        elif model in (EmbeddingModel.MINILM_384, EmbeddingModel.BGE_768):
            st_model = self._load_sentence_transformer(model)
            embeddings = st_model.encode(
                texts,
                batch_size=batch_size,
                show_progress_bar=len(texts) > 100,
            )
            return embeddings.tolist()
        else:
            raise ValueError(f"Unknown embedding model: {model}")

    def get_available_models(self, collection_name: str) -> set[EmbeddingModel]:
        """Get the embedding models available for a collection.

        Checks which named vectors exist in the collection. For single-vector
        collections, returns the models whose dimension matches.

        Args:
            collection_name: Name of the Qdrant collection

        Returns:
            Set of available EmbeddingModel values
        """
        if collection_name in self._available_models:
            return self._available_models[collection_name]

        try:
            info = self.qdrant_client.get_collection(collection_name)
            vectors_config = info.config.params.vectors

            available = set()
            uses_named_vectors = False

            # Check for named vectors (a dict of vector configs)
            if isinstance(vectors_config, dict):
                # Named vectors: each key is a vector name
                uses_named_vectors = True
                for vector_name in vectors_config.keys():
                    try:
                        model = EmbeddingModel(vector_name)
                        available.add(model)
                    except ValueError:
                        logger.warning(f"Unknown vector name in collection: {vector_name}")
            else:
                # Single unnamed vector: match by dimension to find compatible models.
                # Note: this does not mean `using=model.value` can be passed in queries.
                uses_named_vectors = False
                if hasattr(vectors_config, 'size'):
                    dim = vectors_config.size
                    for model in EmbeddingModel:
                        if model.dimension == dim:
                            available.add(model)

            # Cache both the available models and the named-vector flag
            self._available_models[collection_name] = available
            self._uses_named_vectors[collection_name] = uses_named_vectors

            if uses_named_vectors:
                logger.info(
                    f"Collection '{collection_name}' uses named vectors: "
                    f"{[m.value for m in available]}"
                )
            else:
                logger.info(
                    f"Collection '{collection_name}' uses a single vector "
                    f"(compatible with: {[m.value for m in available]})"
                )

            return available
        except Exception as e:
            logger.warning(f"Could not get available models for {collection_name}: {e}")
            return set()
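
    # Sketch of the two collection shapes the method above distinguishes
    # (illustrative only; mirrors the Qdrant client's VectorParams types):
    #
    #     named:  vectors_config == {"openai_1536": VectorParams(size=1536, ...),
    #                                "minilm_384": VectorParams(size=384, ...)}
    #     single: vectors_config == VectorParams(size=384, ...)
    #
    # A named-vector collection can be queried with `using="minilm_384"`; a
    # single-vector collection must be queried without `using`.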

    def uses_named_vectors(self, collection_name: str) -> bool:
        """Check whether a collection uses named vectors (vs. a single unnamed vector).

        Args:
            collection_name: Name of the Qdrant collection

        Returns:
            True if the collection has named vectors, False for single-vector collections
        """
        # Ensure the collection has been inspected (populates _uses_named_vectors)
        self.get_available_models(collection_name)
        return self._uses_named_vectors.get(collection_name, False)

    def select_model(
        self,
        collection_name: str,
        preferred: EmbeddingModel | None = None,
    ) -> EmbeddingModel | None:
        """Select the best available embedding model for a collection.

        Args:
            collection_name: Name of the collection
            preferred: Preferred model (used if available)

        Returns:
            Selected EmbeddingModel, or None if none is available
        """
        available = self.get_available_models(collection_name)

        if not available:
            # No compatible model found; re-check the collection directly.
            # This covers legacy single-vector collections and transient errors.
            try:
                info = self.qdrant_client.get_collection(collection_name)
                vectors_config = info.config.params.vectors

                # Get the vector dimension
                dim = None
                if hasattr(vectors_config, 'size'):
                    dim = vectors_config.size
                elif isinstance(vectors_config, dict):
                    # Use the first vector config
                    first_config = next(iter(vectors_config.values()), None)
                    if first_config and hasattr(first_config, 'size'):
                        dim = first_config.size

                if dim:
                    for model in self.config.model_preference:
                        if model.dimension == dim:
                            return model
            except Exception:
                pass
            return None

        # If the preferred model is available, use it (an explicit preference
        # bypasses the API-key check below)
        if preferred and preferred in available:
            return preferred

        # Otherwise, follow the preference order
        for model in self.config.model_preference:
            if model in available:
                # Skip models that need an API key we don't have
                if model == EmbeddingModel.OPENAI_1536 and not self.config.openai_api_key:
                    continue
                return model

        return None
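
    # Worked example of the selection logic above, under the default preference
    # order [minilm_384, openai_1536, bge_768] (hypothetical collections):
    #
    #     available models                 OPENAI_API_KEY   select_model(...)
    #     {minilm_384, openai_1536}        unset            MINILM_384
    #     {openai_1536}                    unset            None
    #     {openai_1536}                    set              OPENAI_1536
    #     single vector, size=768          either           BGE_768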

    def search(
        self,
        query: str,
        collection_name: str | None = None,
        k: int | None = None,
        using: EmbeddingModel | str | None = None,
        filter_conditions: dict[str, Any] | None = None,
    ) -> list[dict[str, Any]]:
        """Search for similar documents using a specified or auto-selected model.

        Args:
            query: Search query text
            collection_name: Collection to search (default: institutions)
            k: Number of results
            using: Embedding model to use (auto-selected if None)
            filter_conditions: Optional Qdrant filter conditions

        Returns:
            List of results with scores and payloads
        """
        collection_name = collection_name or self.config.institutions_collection
        k = k or self.config.default_k

        # Resolve the model
        if using is not None:
            model = EmbeddingModel(using) if isinstance(using, str) else using
        else:
            model = self.select_model(collection_name)

        if model is None:
            raise RuntimeError(f"No compatible embedding model for collection '{collection_name}'")

        logger.info(f"Searching '{collection_name}' with {model.value}: {query[:50]}...")

        # Get the query embedding
        query_vector = self.get_embedding(query, model)

        # Build the filter
        from qdrant_client.http import models

        query_filter = None
        if filter_conditions:
            query_filter = models.Filter(
                must=[
                    models.FieldCondition(
                        key=key,
                        match=models.MatchValue(value=value),
                    )
                    for key, value in filter_conditions.items()
                ]
            )

        # Only pass `using=model.value` if the collection has actual named
        # vectors; legacy single-vector collections are queried without it.
        use_named_vector = self.uses_named_vectors(collection_name)

        results = self.qdrant_client.query_points(
            collection_name=collection_name,
            query=query_vector,
            using=model.value if use_named_vector else None,
            limit=k,
            with_payload=True,
            query_filter=query_filter,
        )

        return [
            {
                "id": str(point.id),
                "score": point.score,
                "model": model.value,
                "payload": point.payload or {},
            }
            for point in results.points
        ]
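
    # Usage sketch for search() with a metadata filter (the payload key is
    # illustrative; actual keys depend on how the collection was indexed):
    #
    #     retriever.search(
    #         "maritime museums",
    #         filter_conditions={"institution_type": "museum"},
    #         k=5,
    #         using="minilm_384",
    #     )
    #     # -> [{"id": "...", "score": 0.83, "model": "minilm_384",
    #     #      "payload": {"name": "...", ...}}, ...]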

    def search_persons(
        self,
        query: str,
        k: int | None = None,
        using: EmbeddingModel | str | None = None,
        filter_custodian: str | None = None,
        only_heritage_relevant: bool = False,
        only_wcms: bool = False,
    ) -> list[dict[str, Any]]:
        """Search for persons/staff in the heritage_persons collection.

        Args:
            query: Search query text
            k: Number of results
            using: Embedding model to use
            filter_custodian: Optional custodian slug to filter by
            only_heritage_relevant: Only return heritage-relevant staff
            only_wcms: Only return WCMS-registered profiles (heritage sector users)

        Returns:
            List of person results with scores
        """
        k = k or self.config.default_k

        # Build filters
        filters = {}
        if filter_custodian:
            filters["custodian_slug"] = filter_custodian
        if only_wcms:
            filters["has_wcms"] = True

        # Over-fetch to leave room for post-filtering
        results = self.search(
            query=query,
            collection_name=self.config.persons_collection,
            k=k * 2,
            using=using,
            filter_conditions=filters if filters else None,
        )

        # Post-filter for heritage relevance if requested
        if only_heritage_relevant:
            results = [r for r in results if r.get("payload", {}).get("heritage_relevant", False)]

        # Format results
        formatted = []
        for r in results[:k]:
            payload = r.get("payload", {})
            formatted.append({
                "person_id": payload.get("staff_id", "") or hashlib.md5(
                    f"{payload.get('custodian_slug', '')}:{payload.get('name', '')}".encode()
                ).hexdigest()[:16],
                "name": payload.get("name", ""),
                "headline": payload.get("headline"),
                "custodian_name": payload.get("custodian_name"),
                "custodian_slug": payload.get("custodian_slug"),
                "location": payload.get("location"),
                "heritage_relevant": payload.get("heritage_relevant", False),
                "heritage_type": payload.get("heritage_type"),
                "linkedin_url": payload.get("linkedin_url"),
                "score": r["score"],
                "model": r["model"],
            })

        return formatted

    def compare_models(
        self,
        query: str,
        collection_name: str | None = None,
        k: int = 10,
        models: list[EmbeddingModel] | None = None,
    ) -> dict[str, Any]:
        """A/B test comparison of multiple embedding models.

        Args:
            query: Search query
            collection_name: Collection to search
            k: Number of results per model
            models: Models to compare (default: all available)

        Returns:
            Dict with results per model and overlap analysis
        """
        collection_name = collection_name or self.config.institutions_collection

        # Determine which models to compare
        available = self.get_available_models(collection_name)
        if models:
            models_to_test = [m for m in models if m in available]
        else:
            models_to_test = list(available)

        if not models_to_test:
            return {"error": "No models available for comparison"}

        results = {}
        all_ids = {}

        for model in models_to_test:
            try:
                model_results = self.search(
                    query=query,
                    collection_name=collection_name,
                    k=k,
                    using=model,
                )
                results[model.value] = model_results
                all_ids[model.value] = {r["id"] for r in model_results}
            except Exception as e:
                results[model.value] = {"error": str(e)}
                all_ids[model.value] = set()

        # Calculate pairwise overlap between models
        overlap = {}
        model_values = list(all_ids.keys())
        for i, m1 in enumerate(model_values):
            for m2 in model_values[i + 1:]:
                if all_ids[m1] and all_ids[m2]:
                    intersection = all_ids[m1] & all_ids[m2]
                    union = all_ids[m1] | all_ids[m2]
                    jaccard = len(intersection) / len(union) if union else 0
                    overlap[f"{m1}_vs_{m2}"] = {
                        "jaccard_similarity": round(jaccard, 3),
                        "common_results": len(intersection),
                        "total_unique": len(union),
                    }

        return {
            "query": query,
            "collection": collection_name,
            "k": k,
            "results": results,
            "overlap_analysis": overlap,
        }
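
    # Shape of a compare_models() report, with hypothetical numbers. The
    # Jaccard similarity is |A ∩ B| / |A ∪ B| over the two result-ID sets,
    # e.g. 7 shared IDs out of 13 distinct -> 7/13 ≈ 0.538:
    #
    #     {
    #         "query": "museums in Amsterdam",
    #         "collection": "heritage_custodians",
    #         "k": 10,
    #         "results": {"minilm_384": [...], "openai_1536": [...]},
    #         "overlap_analysis": {
    #             "minilm_384_vs_openai_1536": {
    #                 "jaccard_similarity": 0.538,
    #                 "common_results": 7,
    #                 "total_unique": 13,
    #             }
    #         },
    #     }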

    def create_multi_embedding_collection(
        self,
        collection_name: str,
        models: list[EmbeddingModel] | None = None,
    ) -> bool:
        """Create a new collection with named vectors for multiple embedding models.

        Args:
            collection_name: Name for the new collection
            models: Embedding models to support (default: all)

        Returns:
            True if created successfully
        """
        from qdrant_client.http.models import Distance, VectorParams

        models = models or list(EmbeddingModel)

        vectors_config = {
            model.value: VectorParams(
                size=model.dimension,
                distance=Distance.COSINE,
            )
            for model in models
        }

        try:
            self.qdrant_client.create_collection(
                collection_name=collection_name,
                vectors_config=vectors_config,
            )
            logger.info(
                f"Created multi-embedding collection '{collection_name}' "
                f"with {[m.value for m in models]}"
            )
            # Clear cached metadata for this collection
            self._available_models.pop(collection_name, None)
            self._uses_named_vectors.pop(collection_name, None)
            return True
        except Exception as e:
            logger.error(f"Failed to create collection: {e}")
            return False

    def add_documents_multi_embedding(
        self,
        documents: list[dict[str, Any]],
        collection_name: str,
        models: list[EmbeddingModel] | None = None,
        batch_size: int = 100,
    ) -> int:
        """Add documents with embeddings from multiple models.

        Args:
            documents: List of documents with 'text' and optional 'metadata' fields
            collection_name: Target collection
            models: Models to generate embeddings for (default: all available)
            batch_size: Batch size for processing

        Returns:
            Number of documents added
        """
        from qdrant_client.http import models as qmodels

        # Determine which models to use
        available = self.get_available_models(collection_name)
        if models:
            models_to_use = [m for m in models if m in available]
        else:
            models_to_use = list(available)

        if not models_to_use:
            raise RuntimeError(f"No embedding models available for collection '{collection_name}'")

        # Keep only documents that have text
        valid_docs = [d for d in documents if d.get("text")]

        total_indexed = 0
        for i in range(0, len(valid_docs), batch_size):
            batch = valid_docs[i:i + batch_size]
            texts = [d["text"] for d in batch]

            # Generate embeddings for each model
            embeddings_by_model = {}
            for model in models_to_use:
                try:
                    embeddings_by_model[model] = self.get_embeddings_batch(texts, model)
                except Exception as e:
                    logger.warning(f"Failed to get {model.value} embeddings: {e}")

            if not embeddings_by_model:
                continue

            # Create points with named vectors
            points = []
            for j, doc in enumerate(batch):
                text = doc["text"]
                metadata = doc.get("metadata", {})
                point_id = doc.get("id") or hashlib.md5(text.encode()).hexdigest()

                # Build the named-vectors dict
                vectors = {}
                for model, model_embeddings in embeddings_by_model.items():
                    vectors[model.value] = model_embeddings[j]

                points.append(qmodels.PointStruct(
                    id=point_id,
                    vector=vectors,
                    payload={
                        "text": text,
                        **metadata,
                    },
                ))

            # Upsert the batch
            self.qdrant_client.upsert(
                collection_name=collection_name,
                points=points,
            )
            total_indexed += len(points)

        logger.info(
            f"Indexed {total_indexed}/{len(valid_docs)} documents "
            f"with {len(models_to_use)} models"
        )
        return total_indexed
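
    # Indexing sketch for the two methods above (collection name, document
    # fields, and texts are hypothetical; the md5 fallback ID assumes texts
    # are unique):
    #
    #     retriever.create_multi_embedding_collection(
    #         "heritage_custodians_v2",
    #         models=[EmbeddingModel.MINILM_384, EmbeddingModel.OPENAI_1536],
    #     )
    #     retriever.add_documents_multi_embedding(
    #         documents=[{"text": "Maritime museum in Rotterdam",
    #                     "metadata": {"institution_type": "museum"}}],
    #         collection_name="heritage_custodians_v2",
    #     )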

    def get_stats(self) -> dict[str, Any]:
        """Get statistics about collections and available models.

        Returns:
            Dict with collection stats and model availability
        """
        stats = {
            "config": {
                "qdrant_host": self.config.qdrant_host,
                "qdrant_port": self.config.qdrant_port,
                "model_preference": [m.value for m in self.config.model_preference],
                "openai_available": bool(self.config.openai_api_key),
            },
            "collections": {},
        }

        for collection_name in [self.config.institutions_collection, self.config.persons_collection]:
            try:
                info = self.qdrant_client.get_collection(collection_name)
                available_models = self.get_available_models(collection_name)
                selected_model = self.select_model(collection_name)

                stats["collections"][collection_name] = {
                    "vectors_count": info.vectors_count,
                    "points_count": info.points_count,
                    "status": info.status.value if info.status else "unknown",
                    "available_models": [m.value for m in available_models],
                    "selected_model": selected_model.value if selected_model else None,
                }
            except Exception as e:
                stats["collections"][collection_name] = {"error": str(e)}

        return stats

    def close(self):
        """Close all connections."""
        if self._qdrant_client:
            self._qdrant_client.close()
            self._qdrant_client = None
        self._st_models.clear()
        self._available_models.clear()
        self._uses_named_vectors.clear()


def create_multi_embedding_retriever(use_production: bool | None = None) -> MultiEmbeddingRetriever:
    """Factory function to create a MultiEmbeddingRetriever.

    Args:
        use_production: If True, connect to production Qdrant.
            Defaults to the QDRANT_USE_PRODUCTION env var.

    Returns:
        Configured MultiEmbeddingRetriever instance
    """
    if use_production is None:
        use_production = os.getenv("QDRANT_USE_PRODUCTION", "").lower() in ("true", "1", "yes")

    if use_production:
        config = MultiEmbeddingConfig(
            qdrant_host=os.getenv("QDRANT_PROD_HOST", "bronhouder.nl"),
            qdrant_port=443,
            qdrant_https=True,
            qdrant_prefix=os.getenv("QDRANT_PROD_PREFIX", "qdrant"),
            openai_api_key=os.getenv("OPENAI_API_KEY"),
        )
    else:
        config = MultiEmbeddingConfig(
            qdrant_host=os.getenv("QDRANT_HOST", "localhost"),
            qdrant_port=int(os.getenv("QDRANT_PORT", "6333")),
            openai_api_key=os.getenv("OPENAI_API_KEY"),
        )

    return MultiEmbeddingRetriever(config)
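

if __name__ == "__main__":
    # Minimal smoke-test sketch, assuming a local Qdrant at localhost:6333
    # with the collections named in MultiEmbeddingConfig already populated.
    # The query string is illustrative only.
    logging.basicConfig(level=logging.INFO)

    retriever = create_multi_embedding_retriever(use_production=False)
    try:
        print(retriever.get_stats())
        for hit in retriever.search("museums in Amsterdam", k=3):
            print(f"{hit['score']:.3f} [{hit['model']}] {hit['payload'].get('name')}")
    finally:
        retriever.close()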