glam/backend/rag/multi_embedding_retriever.py

"""
Multi-Embedding Retriever for Heritage Data
Supports multiple embedding models using Qdrant's named vectors feature.
This enables:
- A/B testing different embedding models
- Cost optimization (cheap local embeddings vs paid API embeddings)
- Gradual migration between embedding models
- Fallback when one model is unavailable
Supported Embedding Models:
- openai_1536: text-embedding-3-small (1536-dim, $0.02/1M tokens)
- minilm_384: all-MiniLM-L6-v2 (384-dim, free/local)
- bge_768: bge-base-en-v1.5 (768-dim, free/local, high quality)
Collection Architecture:
Each collection has named vectors for each embedding model:
heritage_custodians:
vectors:
"openai_1536": VectorParams(size=1536)
"minilm_384": VectorParams(size=384)
payload: {name, ghcid, institution_type, ...}
heritage_persons:
vectors:
"openai_1536": VectorParams(size=1536)
"minilm_384": VectorParams(size=384)
payload: {name, headline, custodian_name, ...}
Usage:
retriever = MultiEmbeddingRetriever()
# Search with default model (auto-select based on availability)
results = retriever.search("museums in Amsterdam")
# Search with specific model
results = retriever.search("museums in Amsterdam", using="minilm_384")
# A/B test comparison
comparison = retriever.compare_models("museums in Amsterdam")
"""
import hashlib
import logging
import os
from dataclasses import dataclass, field
from enum import Enum
from typing import Any
logger = logging.getLogger(__name__)
class EmbeddingModel(str, Enum):
"""Supported embedding models with their configurations."""
OPENAI_1536 = "openai_1536"
MINILM_384 = "minilm_384"
BGE_768 = "bge_768"
@property
def dimension(self) -> int:
"""Get the vector dimension for this model."""
dims = {
"openai_1536": 1536,
"minilm_384": 384,
"bge_768": 768,
}
return dims[self.value]
@property
def model_name(self) -> str:
"""Get the actual model name for loading."""
names = {
"openai_1536": "text-embedding-3-small",
"minilm_384": "all-MiniLM-L6-v2",
"bge_768": "BAAI/bge-base-en-v1.5",
}
return names[self.value]
@property
def is_local(self) -> bool:
"""Check if this model runs locally (no API calls)."""
return self.value in ("minilm_384", "bge_768")
@property
def cost_per_1m_tokens(self) -> float:
"""Approximate cost per 1M tokens (0 for local models)."""
costs = {
"openai_1536": 0.02,
"minilm_384": 0.0,
"bge_768": 0.0,
}
return costs[self.value]
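
# Illustrative property values (derived from the tables above; not executed):
#   EmbeddingModel.MINILM_384.dimension            -> 384
#   EmbeddingModel.MINILM_384.model_name           -> "all-MiniLM-L6-v2"
#   EmbeddingModel.MINILM_384.is_local             -> True
#   EmbeddingModel.OPENAI_1536.cost_per_1m_tokens  -> 0.02
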
@dataclass
class MultiEmbeddingConfig:
"""Configuration for multi-embedding retriever."""
# Qdrant connection
qdrant_host: str = "localhost"
qdrant_port: int = 6333
qdrant_https: bool = False
qdrant_prefix: str | None = None
# API keys
openai_api_key: str | None = None
# Default embedding model preference order
# First available model is used if no explicit model is specified
model_preference: list[EmbeddingModel] = field(default_factory=lambda: [
EmbeddingModel.MINILM_384, # Free, fast, good quality
EmbeddingModel.OPENAI_1536, # Higher quality, paid
EmbeddingModel.BGE_768, # Free, high quality, slower
])
# Collection names
institutions_collection: str = "heritage_custodians"
persons_collection: str = "heritage_persons"
# Search defaults
default_k: int = 10
class MultiEmbeddingRetriever:
"""Retriever supporting multiple embedding models via Qdrant named vectors.
This class manages multiple embedding models and allows searching with
any available model. It handles:
- Model lazy-loading
- Automatic model selection based on availability
- Named vector creation and search
- A/B testing between models
"""
def __init__(self, config: MultiEmbeddingConfig | None = None):
"""Initialize multi-embedding retriever.
Args:
config: Configuration options. If None, uses environment variables.
"""
self.config = config or self._config_from_env()
# Lazy-loaded clients
self._qdrant_client = None
self._openai_client = None
self._st_models: dict[str, Any] = {} # Sentence transformer models
# Track available models per collection
self._available_models: dict[str, set[EmbeddingModel]] = {}
# Track whether each collection uses named vectors (vs single unnamed vector)
self._uses_named_vectors: dict[str, bool] = {}
logger.info(f"MultiEmbeddingRetriever initialized with preference: {[m.value for m in self.config.model_preference]}")
@staticmethod
def _config_from_env() -> MultiEmbeddingConfig:
"""Create configuration from environment variables."""
        # Accept the same truthy values as create_multi_embedding_retriever()
        use_production = os.getenv("QDRANT_USE_PRODUCTION", "").lower() in ("true", "1", "yes")
if use_production:
return MultiEmbeddingConfig(
qdrant_host=os.getenv("QDRANT_PROD_HOST", "bronhouder.nl"),
qdrant_port=443,
qdrant_https=True,
qdrant_prefix=os.getenv("QDRANT_PROD_PREFIX", "qdrant"),
openai_api_key=os.getenv("OPENAI_API_KEY"),
)
else:
return MultiEmbeddingConfig(
qdrant_host=os.getenv("QDRANT_HOST", "localhost"),
qdrant_port=int(os.getenv("QDRANT_PORT", "6333")),
openai_api_key=os.getenv("OPENAI_API_KEY"),
)
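
    # Example environment (variable names are real; values are illustrative):
    #   QDRANT_USE_PRODUCTION=true QDRANT_PROD_HOST=bronhouder.nl QDRANT_PROD_PREFIX=qdrant
    #   OPENAI_API_KEY=sk-...
    # or, for local development:
    #   QDRANT_HOST=localhost QDRANT_PORT=6333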
@property
def qdrant_client(self):
"""Lazy-load Qdrant client."""
if self._qdrant_client is None:
from qdrant_client import QdrantClient
if self.config.qdrant_https:
self._qdrant_client = QdrantClient(
host=self.config.qdrant_host,
port=self.config.qdrant_port,
https=True,
prefix=self.config.qdrant_prefix,
prefer_grpc=False,
timeout=30,
)
logger.info(f"Connected to Qdrant: https://{self.config.qdrant_host}/{self.config.qdrant_prefix or ''}")
else:
self._qdrant_client = QdrantClient(
host=self.config.qdrant_host,
port=self.config.qdrant_port,
)
logger.info(f"Connected to Qdrant: {self.config.qdrant_host}:{self.config.qdrant_port}")
return self._qdrant_client
@property
def openai_client(self):
"""Lazy-load OpenAI client."""
if self._openai_client is None:
if not self.config.openai_api_key:
raise RuntimeError("OpenAI API key not configured")
import openai
self._openai_client = openai.OpenAI(api_key=self.config.openai_api_key)
return self._openai_client
def _load_sentence_transformer(self, model: EmbeddingModel) -> Any:
"""Lazy-load a sentence-transformers model.
Args:
model: The embedding model to load
Returns:
Loaded SentenceTransformer model
"""
if model.value not in self._st_models:
try:
from sentence_transformers import SentenceTransformer
self._st_models[model.value] = SentenceTransformer(model.model_name)
logger.info(f"Loaded sentence-transformers model: {model.model_name}")
except ImportError:
raise RuntimeError(
"sentence-transformers not installed. Run: pip install sentence-transformers"
)
return self._st_models[model.value]
def get_embedding(self, text: str, model: EmbeddingModel) -> list[float]:
"""Get embedding vector for text using specified model.
Args:
text: Text to embed
model: Embedding model to use
Returns:
Embedding vector as list of floats
"""
if model == EmbeddingModel.OPENAI_1536:
response = self.openai_client.embeddings.create(
input=text,
model=model.model_name,
)
return response.data[0].embedding
elif model in (EmbeddingModel.MINILM_384, EmbeddingModel.BGE_768):
st_model = self._load_sentence_transformer(model)
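            # Note: for bge-*-v1.5 models, BAAI suggests prefixing short retrieval
            # queries with an instruction (e.g. "Represent this sentence for searching
            # relevant passages: "). This retriever deliberately embeds raw text for
            # both documents and queries; adding the prefix is a possible tuning knob.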
embedding = st_model.encode(text)
return embedding.tolist()
else:
raise ValueError(f"Unknown embedding model: {model}")
def get_embeddings_batch(
self,
texts: list[str],
model: EmbeddingModel,
batch_size: int = 32,
) -> list[list[float]]:
"""Get embedding vectors for multiple texts.
Args:
texts: List of texts to embed
model: Embedding model to use
batch_size: Batch size for processing
Returns:
List of embedding vectors
"""
if not texts:
return []
if model == EmbeddingModel.OPENAI_1536:
# OpenAI batch API (max 2048 per request)
all_embeddings = []
for i in range(0, len(texts), 2048):
batch = texts[i:i + 2048]
response = self.openai_client.embeddings.create(
input=batch,
model=model.model_name,
)
batch_embeddings = [item.embedding for item in sorted(response.data, key=lambda x: x.index)]
all_embeddings.extend(batch_embeddings)
return all_embeddings
elif model in (EmbeddingModel.MINILM_384, EmbeddingModel.BGE_768):
st_model = self._load_sentence_transformer(model)
embeddings = st_model.encode(texts, batch_size=batch_size, show_progress_bar=len(texts) > 100)
return embeddings.tolist()
else:
raise ValueError(f"Unknown embedding model: {model}")
def get_available_models(self, collection_name: str) -> set[EmbeddingModel]:
"""Get the embedding models available for a collection.
Checks which named vectors exist in the collection.
For single-vector collections, returns models matching the dimension.
Args:
collection_name: Name of the Qdrant collection
Returns:
Set of available EmbeddingModel values
"""
if collection_name in self._available_models:
return self._available_models[collection_name]
try:
info = self.qdrant_client.get_collection(collection_name)
vectors_config = info.config.params.vectors
available = set()
uses_named_vectors = False
# Check for named vectors (dict of vector configs)
if isinstance(vectors_config, dict):
# Named vectors - each key is a vector name
uses_named_vectors = True
for vector_name in vectors_config.keys():
try:
model = EmbeddingModel(vector_name)
available.add(model)
except ValueError:
logger.warning(f"Unknown vector name in collection: {vector_name}")
else:
# Single unnamed vector - check dimension to find compatible model
# Note: This doesn't mean we can use `using=model.value` in queries
uses_named_vectors = False
if hasattr(vectors_config, 'size'):
dim = vectors_config.size
for model in EmbeddingModel:
if model.dimension == dim:
available.add(model)
# Store both available models and whether named vectors are used
self._available_models[collection_name] = available
self._uses_named_vectors[collection_name] = uses_named_vectors
if uses_named_vectors:
logger.info(f"Collection '{collection_name}' uses named vectors: {[m.value for m in available]}")
else:
logger.info(f"Collection '{collection_name}' uses single vector (compatible with: {[m.value for m in available]})")
return available
except Exception as e:
logger.warning(f"Could not get available models for {collection_name}: {e}")
return set()
def uses_named_vectors(self, collection_name: str) -> bool:
"""Check if a collection uses named vectors (vs single unnamed vector).
Args:
collection_name: Name of the Qdrant collection
Returns:
True if collection has named vectors, False for single-vector collections
"""
# Ensure models are loaded (populates _uses_named_vectors)
self.get_available_models(collection_name)
return self._uses_named_vectors.get(collection_name, False)
def select_model(
self,
collection_name: str,
preferred: EmbeddingModel | None = None,
) -> EmbeddingModel | None:
"""Select the best available embedding model for a collection.
Args:
collection_name: Name of the collection
preferred: Preferred model (used if available)
Returns:
Selected EmbeddingModel or None if none available
"""
available = self.get_available_models(collection_name)
if not available:
# No named vectors - check if we can use any model
# This happens for legacy single-vector collections
try:
info = self.qdrant_client.get_collection(collection_name)
vectors_config = info.config.params.vectors
# Get vector dimension
dim = None
if hasattr(vectors_config, 'size'):
dim = vectors_config.size
elif isinstance(vectors_config, dict):
# Get first vector config
first_config = next(iter(vectors_config.values()), None)
if first_config and hasattr(first_config, 'size'):
dim = first_config.size
if dim:
for model in self.config.model_preference:
if model.dimension == dim:
return model
except Exception:
pass
return None
# If preferred model is available, use it
if preferred and preferred in available:
return preferred
# Otherwise, follow preference order
for model in self.config.model_preference:
if model in available:
# Check if model is usable (has API key if needed)
if model == EmbeddingModel.OPENAI_1536 and not self.config.openai_api_key:
continue
return model
return None
def search(
self,
query: str,
collection_name: str | None = None,
k: int | None = None,
using: EmbeddingModel | str | None = None,
filter_conditions: dict[str, Any] | None = None,
) -> list[dict[str, Any]]:
"""Search for similar documents using specified or auto-selected model.
Args:
query: Search query text
collection_name: Collection to search (default: institutions)
k: Number of results
using: Embedding model to use (auto-selected if None)
filter_conditions: Optional Qdrant filter conditions
Returns:
List of results with scores and payloads
"""
collection_name = collection_name or self.config.institutions_collection
k = k or self.config.default_k
# Resolve model
if using is not None:
if isinstance(using, str):
model = EmbeddingModel(using)
else:
model = using
else:
model = self.select_model(collection_name)
if model is None:
raise RuntimeError(f"No compatible embedding model for collection '{collection_name}'")
logger.info(f"Searching '{collection_name}' with {model.value}: {query[:50]}...")
# Get query embedding
query_vector = self.get_embedding(query, model)
# Build filter
from qdrant_client.http import models
query_filter = None
if filter_conditions:
query_filter = models.Filter(
must=[
models.FieldCondition(
key=key,
match=models.MatchValue(value=value),
)
for key, value in filter_conditions.items()
]
)
        # Only pass `using=<vector name>` when the collection actually has named
        # vectors; legacy single-vector collections must be queried without it.
        use_named_vector = self.uses_named_vectors(collection_name)

        # Search (query_points covers both cases; `using=None` targets the
        # default unnamed vector)
        results = self.qdrant_client.query_points(
            collection_name=collection_name,
            query=query_vector,
            using=model.value if use_named_vector else None,
            limit=k,
            with_payload=True,
            query_filter=query_filter,
        )
return [
{
"id": str(point.id),
"score": point.score,
"model": model.value,
"payload": point.payload or {},
}
for point in results.points
]
def search_persons(
self,
query: str,
k: int | None = None,
using: EmbeddingModel | str | None = None,
filter_custodian: str | None = None,
only_heritage_relevant: bool = False,
only_wcms: bool = False,
) -> list[dict[str, Any]]:
"""Search for persons/staff in the heritage_persons collection.
Args:
query: Search query text
k: Number of results
using: Embedding model to use
filter_custodian: Optional custodian slug to filter by
only_heritage_relevant: Only return heritage-relevant staff
only_wcms: Only return WCMS-registered profiles (heritage sector users)
Returns:
List of person results with scores
"""
k = k or self.config.default_k
# Build filters
filters = {}
if filter_custodian:
filters["custodian_slug"] = filter_custodian
if only_wcms:
filters["has_wcms"] = True
# Search with over-fetch for post-filtering
results = self.search(
query=query,
collection_name=self.config.persons_collection,
k=k * 2,
using=using,
filter_conditions=filters if filters else None,
)
# Post-filter for heritage_relevant if needed
if only_heritage_relevant:
results = [r for r in results if r.get("payload", {}).get("heritage_relevant", False)]
# Format results
formatted = []
for r in results[:k]:
payload = r.get("payload", {})
formatted.append({
"person_id": payload.get("staff_id", "") or hashlib.md5(
f"{payload.get('custodian_slug', '')}:{payload.get('name', '')}".encode()
).hexdigest()[:16],
"name": payload.get("name", ""),
"headline": payload.get("headline"),
"custodian_name": payload.get("custodian_name"),
"custodian_slug": payload.get("custodian_slug"),
"location": payload.get("location"),
"heritage_relevant": payload.get("heritage_relevant", False),
"heritage_type": payload.get("heritage_type"),
"linkedin_url": payload.get("linkedin_url"),
"score": r["score"],
"model": r["model"],
})
return formatted
def compare_models(
self,
query: str,
collection_name: str | None = None,
k: int = 10,
models: list[EmbeddingModel] | None = None,
) -> dict[str, Any]:
"""A/B test comparison of multiple embedding models.
Args:
query: Search query
collection_name: Collection to search
k: Number of results per model
models: Models to compare (default: all available)
Returns:
Dict with results per model and overlap analysis
"""
collection_name = collection_name or self.config.institutions_collection
# Determine which models to compare
available = self.get_available_models(collection_name)
if models:
models_to_test = [m for m in models if m in available]
else:
models_to_test = list(available)
if not models_to_test:
return {"error": "No models available for comparison"}
results = {}
all_ids = {}
for model in models_to_test:
try:
model_results = self.search(
query=query,
collection_name=collection_name,
k=k,
using=model,
)
results[model.value] = model_results
all_ids[model.value] = {r["id"] for r in model_results}
except Exception as e:
results[model.value] = {"error": str(e)}
all_ids[model.value] = set()
# Calculate overlap between models
overlap = {}
model_values = list(all_ids.keys())
for i, m1 in enumerate(model_values):
for m2 in model_values[i + 1:]:
if all_ids[m1] and all_ids[m2]:
intersection = all_ids[m1] & all_ids[m2]
union = all_ids[m1] | all_ids[m2]
jaccard = len(intersection) / len(union) if union else 0
overlap[f"{m1}_vs_{m2}"] = {
"jaccard_similarity": round(jaccard, 3),
"common_results": len(intersection),
"total_unique": len(union),
}
return {
"query": query,
"collection": collection_name,
"k": k,
"results": results,
"overlap_analysis": overlap,
}
def create_multi_embedding_collection(
self,
collection_name: str,
models: list[EmbeddingModel] | None = None,
) -> bool:
"""Create a new collection with named vectors for multiple embedding models.
Args:
collection_name: Name for the new collection
models: Embedding models to support (default: all)
Returns:
True if created successfully
"""
from qdrant_client.http.models import Distance, VectorParams
models = models or list(EmbeddingModel)
vectors_config = {
model.value: VectorParams(
size=model.dimension,
distance=Distance.COSINE,
)
for model in models
}
try:
self.qdrant_client.create_collection(
collection_name=collection_name,
vectors_config=vectors_config,
)
logger.info(f"Created multi-embedding collection '{collection_name}' with {[m.value for m in models]}")
# Clear cache
self._available_models.pop(collection_name, None)
return True
except Exception as e:
logger.error(f"Failed to create collection: {e}")
return False
def add_documents_multi_embedding(
self,
documents: list[dict[str, Any]],
collection_name: str,
models: list[EmbeddingModel] | None = None,
batch_size: int = 100,
) -> int:
"""Add documents with embeddings from multiple models.
Args:
documents: List of documents with 'text' and optional 'metadata' fields
collection_name: Target collection
models: Models to generate embeddings for (default: all available)
batch_size: Batch size for processing
Returns:
Number of documents added
"""
from qdrant_client.http import models as qmodels
# Determine which models to use
available = self.get_available_models(collection_name)
if models:
models_to_use = [m for m in models if m in available]
else:
models_to_use = list(available)
if not models_to_use:
raise RuntimeError(f"No embedding models available for collection '{collection_name}'")
# Filter valid documents
valid_docs = [d for d in documents if d.get("text")]
total_indexed = 0
for i in range(0, len(valid_docs), batch_size):
batch = valid_docs[i:i + batch_size]
texts = [d["text"] for d in batch]
# Generate embeddings for each model
embeddings_by_model = {}
for model in models_to_use:
try:
embeddings_by_model[model] = self.get_embeddings_batch(texts, model)
except Exception as e:
logger.warning(f"Failed to get {model.value} embeddings: {e}")
if not embeddings_by_model:
continue
# Create points with named vectors
points = []
for j, doc in enumerate(batch):
text = doc["text"]
metadata = doc.get("metadata", {})
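                # Qdrant point IDs must be unsigned integers or UUIDs; a 32-char md5
                # hex digest is parseable as a UUID in its unhyphenated ("simple")
                # form, so the fallback below should be accepted. Caller-supplied
                # doc["id"] values must satisfy the same constraint.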
point_id = doc.get("id") or hashlib.md5(text.encode()).hexdigest()
# Build named vectors dict
vectors = {}
for model, model_embeddings in embeddings_by_model.items():
vectors[model.value] = model_embeddings[j]
points.append(qmodels.PointStruct(
id=point_id,
vector=vectors,
payload={
"text": text,
**metadata,
}
))
# Upsert batch
self.qdrant_client.upsert(
collection_name=collection_name,
points=points,
)
total_indexed += len(points)
logger.info(f"Indexed {total_indexed}/{len(valid_docs)} documents with {len(models_to_use)} models")
return total_indexed
def get_stats(self) -> dict[str, Any]:
"""Get statistics about collections and available models.
Returns:
Dict with collection stats and model availability
"""
stats = {
"config": {
"qdrant_host": self.config.qdrant_host,
"qdrant_port": self.config.qdrant_port,
"model_preference": [m.value for m in self.config.model_preference],
"openai_available": bool(self.config.openai_api_key),
},
"collections": {},
}
for collection_name in [self.config.institutions_collection, self.config.persons_collection]:
try:
info = self.qdrant_client.get_collection(collection_name)
available_models = self.get_available_models(collection_name)
selected_model = self.select_model(collection_name)
stats["collections"][collection_name] = {
"vectors_count": info.vectors_count,
"points_count": info.points_count,
"status": info.status.value if info.status else "unknown",
"available_models": [m.value for m in available_models],
"selected_model": selected_model.value if selected_model else None,
}
except Exception as e:
stats["collections"][collection_name] = {"error": str(e)}
return stats
def close(self):
"""Close all connections."""
if self._qdrant_client:
self._qdrant_client.close()
self._qdrant_client = None
self._st_models.clear()
self._available_models.clear()
self._uses_named_vectors.clear()
def create_multi_embedding_retriever(use_production: bool | None = None) -> MultiEmbeddingRetriever:
"""Factory function to create a MultiEmbeddingRetriever.
Args:
use_production: If True, connect to production Qdrant.
Defaults to QDRANT_USE_PRODUCTION env var.
Returns:
Configured MultiEmbeddingRetriever instance
"""
if use_production is None:
use_production = os.getenv("QDRANT_USE_PRODUCTION", "").lower() in ("true", "1", "yes")
if use_production:
config = MultiEmbeddingConfig(
qdrant_host=os.getenv("QDRANT_PROD_HOST", "bronhouder.nl"),
qdrant_port=443,
qdrant_https=True,
qdrant_prefix=os.getenv("QDRANT_PROD_PREFIX", "qdrant"),
openai_api_key=os.getenv("OPENAI_API_KEY"),
)
else:
config = MultiEmbeddingConfig(
qdrant_host=os.getenv("QDRANT_HOST", "localhost"),
qdrant_port=int(os.getenv("QDRANT_PORT", "6333")),
openai_api_key=os.getenv("OPENAI_API_KEY"),
)
return MultiEmbeddingRetriever(config)
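

# ---------------------------------------------------------------------------
# Minimal smoke test: a sketch, not part of the library API. It assumes a
# reachable Qdrant instance (local by default) whose heritage_custodians
# collection is already populated; the query string is illustrative.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    retriever = create_multi_embedding_retriever()
    try:
        # Inspect which models each collection supports.
        stats = retriever.get_stats()
        for name, coll in stats["collections"].items():
            print(name, "->", coll.get("available_models", coll))

        # Search with the auto-selected model.
        for r in retriever.search("museums in Amsterdam", k=5):
            print(f"{r['score']:.3f} [{r['model']}] {r['payload'].get('name', '?')}")

        # Compare all available models on the same query.
        comparison = retriever.compare_models("museums in Amsterdam", k=5)
        print("overlap:", comparison.get("overlap_analysis"))
    finally:
        retriever.close()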