"""
|
|
Multi-Embedding Retriever for Heritage Data
|
|
|
|
Supports multiple embedding models using Qdrant's named vectors feature.
|
|
This enables:
|
|
- A/B testing different embedding models
|
|
- Cost optimization (cheap local embeddings vs paid API embeddings)
|
|
- Gradual migration between embedding models
|
|
- Fallback when one model is unavailable
|
|
|
|
Supported Embedding Models:
|
|
- openai_1536: text-embedding-3-small (1536-dim, $0.02/1M tokens)
|
|
- minilm_384: all-MiniLM-L6-v2 (384-dim, free/local)
|
|
- bge_768: bge-base-en-v1.5 (768-dim, free/local, high quality)
|
|
|
|
Collection Architecture:
|
|
Each collection has named vectors for each embedding model:
|
|
|
|
heritage_custodians:
|
|
vectors:
|
|
"openai_1536": VectorParams(size=1536)
|
|
"minilm_384": VectorParams(size=384)
|
|
payload: {name, ghcid, institution_type, ...}
|
|
|
|
heritage_persons:
|
|
vectors:
|
|
"openai_1536": VectorParams(size=1536)
|
|
"minilm_384": VectorParams(size=384)
|
|
payload: {name, headline, custodian_name, ...}
|
|
|
|
Usage:
|
|
retriever = MultiEmbeddingRetriever()
|
|
|
|
# Search with default model (auto-select based on availability)
|
|
results = retriever.search("museums in Amsterdam")
|
|
|
|
# Search with specific model
|
|
results = retriever.search("museums in Amsterdam", using="minilm_384")
|
|
|
|
# A/B test comparison
|
|
comparison = retriever.compare_models("museums in Amsterdam")
|
|
"""

import hashlib
import logging
import os
from dataclasses import dataclass, field
from enum import Enum
from typing import Any

logger = logging.getLogger(__name__)


class EmbeddingModel(str, Enum):
    """Supported embedding models with their configurations."""

    OPENAI_1536 = "openai_1536"
    MINILM_384 = "minilm_384"
    BGE_768 = "bge_768"

    @property
    def dimension(self) -> int:
        """Get the vector dimension for this model."""
        dims = {
            "openai_1536": 1536,
            "minilm_384": 384,
            "bge_768": 768,
        }
        return dims[self.value]

    @property
    def model_name(self) -> str:
        """Get the actual model name for loading."""
        names = {
            "openai_1536": "text-embedding-3-small",
            "minilm_384": "all-MiniLM-L6-v2",
            "bge_768": "BAAI/bge-base-en-v1.5",
        }
        return names[self.value]

    @property
    def is_local(self) -> bool:
        """Check if this model runs locally (no API calls)."""
        return self.value in ("minilm_384", "bge_768")

    @property
    def cost_per_1m_tokens(self) -> float:
        """Approximate cost per 1M tokens (0 for local models)."""
        costs = {
            "openai_1536": 0.02,
            "minilm_384": 0.0,
            "bge_768": 0.0,
        }
        return costs[self.value]


@dataclass
class MultiEmbeddingConfig:
    """Configuration for multi-embedding retriever."""

    # Qdrant connection
    qdrant_host: str = "localhost"
    qdrant_port: int = 6333
    qdrant_https: bool = False
    qdrant_prefix: str | None = None

    # API keys
    openai_api_key: str | None = None

    # Default embedding model preference order.
    # The first available model is used if no explicit model is specified.
    model_preference: list[EmbeddingModel] = field(default_factory=lambda: [
        EmbeddingModel.MINILM_384,   # Free, fast, good quality
        EmbeddingModel.OPENAI_1536,  # Higher quality, paid
        EmbeddingModel.BGE_768,      # Free, high quality, slower
    ])

    # Collection names
    institutions_collection: str = "heritage_custodians"
    persons_collection: str = "heritage_persons"

    # Search defaults
    default_k: int = 10


class MultiEmbeddingRetriever:
    """Retriever supporting multiple embedding models via Qdrant named vectors.

    This class manages multiple embedding models and allows searching with
    any available model. It handles:
    - Model lazy-loading
    - Automatic model selection based on availability
    - Named vector creation and search
    - A/B testing between models
    """

    def __init__(self, config: MultiEmbeddingConfig | None = None):
        """Initialize multi-embedding retriever.

        Args:
            config: Configuration options. If None, uses environment variables.
        """
        self.config = config or self._config_from_env()

        # Lazy-loaded clients
        self._qdrant_client = None
        self._openai_client = None
        self._st_models: dict[str, Any] = {}  # Sentence transformer models

        # Track available models per collection
        self._available_models: dict[str, set[EmbeddingModel]] = {}

        # Track whether each collection uses named vectors (vs single unnamed vector)
        self._uses_named_vectors: dict[str, bool] = {}

        logger.info(f"MultiEmbeddingRetriever initialized with preference: {[m.value for m in self.config.model_preference]}")

    @staticmethod
    def _config_from_env() -> MultiEmbeddingConfig:
        """Create configuration from environment variables."""
        use_production = os.getenv("QDRANT_USE_PRODUCTION", "false").lower() == "true"

        if use_production:
            return MultiEmbeddingConfig(
                qdrant_host=os.getenv("QDRANT_PROD_HOST", "bronhouder.nl"),
                qdrant_port=443,
                qdrant_https=True,
                qdrant_prefix=os.getenv("QDRANT_PROD_PREFIX", "qdrant"),
                openai_api_key=os.getenv("OPENAI_API_KEY"),
            )
        else:
            return MultiEmbeddingConfig(
                qdrant_host=os.getenv("QDRANT_HOST", "localhost"),
                qdrant_port=int(os.getenv("QDRANT_PORT", "6333")),
                openai_api_key=os.getenv("OPENAI_API_KEY"),
            )

    @property
    def qdrant_client(self):
        """Lazy-load Qdrant client."""
        if self._qdrant_client is None:
            from qdrant_client import QdrantClient

            if self.config.qdrant_https:
                self._qdrant_client = QdrantClient(
                    host=self.config.qdrant_host,
                    port=self.config.qdrant_port,
                    https=True,
                    prefix=self.config.qdrant_prefix,
                    prefer_grpc=False,
                    timeout=30,
                )
                logger.info(f"Connected to Qdrant: https://{self.config.qdrant_host}/{self.config.qdrant_prefix or ''}")
            else:
                self._qdrant_client = QdrantClient(
                    host=self.config.qdrant_host,
                    port=self.config.qdrant_port,
                )
                logger.info(f"Connected to Qdrant: {self.config.qdrant_host}:{self.config.qdrant_port}")

        return self._qdrant_client

    @property
    def openai_client(self):
        """Lazy-load OpenAI client."""
        if self._openai_client is None:
            if not self.config.openai_api_key:
                raise RuntimeError("OpenAI API key not configured")

            import openai

            self._openai_client = openai.OpenAI(api_key=self.config.openai_api_key)

        return self._openai_client

    def _load_sentence_transformer(self, model: EmbeddingModel) -> Any:
        """Lazy-load a sentence-transformers model.

        Args:
            model: The embedding model to load

        Returns:
            Loaded SentenceTransformer model
        """
        if model.value not in self._st_models:
            try:
                from sentence_transformers import SentenceTransformer

                self._st_models[model.value] = SentenceTransformer(model.model_name)
                logger.info(f"Loaded sentence-transformers model: {model.model_name}")
            except ImportError as exc:
                raise RuntimeError(
                    "sentence-transformers not installed. Run: pip install sentence-transformers"
                ) from exc

        return self._st_models[model.value]

    def get_embedding(self, text: str, model: EmbeddingModel) -> list[float]:
        """Get embedding vector for text using the specified model.

        Args:
            text: Text to embed
            model: Embedding model to use

        Returns:
            Embedding vector as list of floats
        """
        if model == EmbeddingModel.OPENAI_1536:
            response = self.openai_client.embeddings.create(
                input=text,
                model=model.model_name,
            )
            return response.data[0].embedding

        elif model in (EmbeddingModel.MINILM_384, EmbeddingModel.BGE_768):
            st_model = self._load_sentence_transformer(model)
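            # Note: the BGE authors recommend prefixing retrieval queries with
            # "Represent this sentence for searching relevant passages: " for
            # short queries; raw text is encoded here, matching how documents
            # are encoded in get_embeddings_batch below.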
            embedding = st_model.encode(text)
            return embedding.tolist()

        else:
            raise ValueError(f"Unknown embedding model: {model}")

    def get_embeddings_batch(
        self,
        texts: list[str],
        model: EmbeddingModel,
        batch_size: int = 32,
    ) -> list[list[float]]:
        """Get embedding vectors for multiple texts.

        Args:
            texts: List of texts to embed
            model: Embedding model to use
            batch_size: Batch size for processing

        Returns:
            List of embedding vectors
        """
        if not texts:
            return []

        if model == EmbeddingModel.OPENAI_1536:
            # OpenAI batch API (max 2048 per request)
            all_embeddings = []
            for i in range(0, len(texts), 2048):
                batch = texts[i:i + 2048]
                response = self.openai_client.embeddings.create(
                    input=batch,
                    model=model.model_name,
                )
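                # Sort by the API-provided index so each embedding lines up
                # with its position in the input batch.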
                batch_embeddings = [item.embedding for item in sorted(response.data, key=lambda x: x.index)]
                all_embeddings.extend(batch_embeddings)
            return all_embeddings

        elif model in (EmbeddingModel.MINILM_384, EmbeddingModel.BGE_768):
            st_model = self._load_sentence_transformer(model)
            embeddings = st_model.encode(texts, batch_size=batch_size, show_progress_bar=len(texts) > 100)
            return embeddings.tolist()

        else:
            raise ValueError(f"Unknown embedding model: {model}")

    def get_available_models(self, collection_name: str) -> set[EmbeddingModel]:
        """Get the embedding models available for a collection.

        Checks which named vectors exist in the collection.
        For single-vector collections, returns models matching the dimension.

        Args:
            collection_name: Name of the Qdrant collection

        Returns:
            Set of available EmbeddingModel values
        """
        if collection_name in self._available_models:
            return self._available_models[collection_name]

        try:
            info = self.qdrant_client.get_collection(collection_name)
            vectors_config = info.config.params.vectors

            available = set()
            uses_named_vectors = False

            # Check for named vectors (dict of vector configs)
            if isinstance(vectors_config, dict):
                # Named vectors - each key is a vector name
                uses_named_vectors = True
                for vector_name in vectors_config.keys():
                    try:
                        model = EmbeddingModel(vector_name)
                        available.add(model)
                    except ValueError:
                        logger.warning(f"Unknown vector name in collection: {vector_name}")
            else:
                # Single unnamed vector - check dimension to find compatible models.
                # Note: this doesn't mean we can use `using=model.value` in queries.
                uses_named_vectors = False
                if hasattr(vectors_config, 'size'):
                    dim = vectors_config.size
                    for model in EmbeddingModel:
                        if model.dimension == dim:
                            available.add(model)

            # Store both available models and whether named vectors are used
            self._available_models[collection_name] = available
            self._uses_named_vectors[collection_name] = uses_named_vectors

            if uses_named_vectors:
                logger.info(f"Collection '{collection_name}' uses named vectors: {[m.value for m in available]}")
            else:
                logger.info(f"Collection '{collection_name}' uses a single vector (compatible with: {[m.value for m in available]})")

            return available

        except Exception as e:
            logger.warning(f"Could not get available models for {collection_name}: {e}")
            return set()

    def uses_named_vectors(self, collection_name: str) -> bool:
        """Check if a collection uses named vectors (vs a single unnamed vector).

        Args:
            collection_name: Name of the Qdrant collection

        Returns:
            True if the collection has named vectors, False for single-vector collections
        """
        # Ensure models are loaded (populates _uses_named_vectors)
        self.get_available_models(collection_name)
        return self._uses_named_vectors.get(collection_name, False)

    def select_model(
        self,
        collection_name: str,
        preferred: EmbeddingModel | None = None,
    ) -> EmbeddingModel | None:
        """Select the best available embedding model for a collection.

        Args:
            collection_name: Name of the collection
            preferred: Preferred model (used if available)

        Returns:
            Selected EmbeddingModel, or None if none is available
        """
        available = self.get_available_models(collection_name)

        if not available:
            # No named vectors - check if we can use any model.
            # This happens for legacy single-vector collections.
            try:
                info = self.qdrant_client.get_collection(collection_name)
                vectors_config = info.config.params.vectors

                # Get vector dimension
                dim = None
                if hasattr(vectors_config, 'size'):
                    dim = vectors_config.size
                elif isinstance(vectors_config, dict):
                    # Get first vector config
                    first_config = next(iter(vectors_config.values()), None)
                    if first_config and hasattr(first_config, 'size'):
                        dim = first_config.size

                if dim:
                    for model in self.config.model_preference:
                        if model.dimension == dim:
                            return model
            except Exception:
                pass

            return None

        # If the preferred model is available, use it
        if preferred and preferred in available:
            return preferred

        # Otherwise, follow preference order
        for model in self.config.model_preference:
            if model in available:
                # Check if the model is usable (has an API key if needed)
                if model == EmbeddingModel.OPENAI_1536 and not self.config.openai_api_key:
                    continue
                return model

        return None

    def search(
        self,
        query: str,
        collection_name: str | None = None,
        k: int | None = None,
        using: EmbeddingModel | str | None = None,
        filter_conditions: dict[str, Any] | None = None,
    ) -> list[dict[str, Any]]:
        """Search for similar documents using a specified or auto-selected model.

        Args:
            query: Search query text
            collection_name: Collection to search (default: institutions)
            k: Number of results
            using: Embedding model to use (auto-selected if None)
            filter_conditions: Optional Qdrant filter conditions

        Returns:
            List of results with scores and payloads
        """
        collection_name = collection_name or self.config.institutions_collection
        k = k or self.config.default_k

        # Resolve model
        if using is not None:
            model = EmbeddingModel(using) if isinstance(using, str) else using
        else:
            model = self.select_model(collection_name)

        if model is None:
            raise RuntimeError(f"No compatible embedding model for collection '{collection_name}'")

        logger.info(f"Searching '{collection_name}' with {model.value}: {query[:50]}...")

        # Get query embedding
        query_vector = self.get_embedding(query, model)

        # Build filter
        from qdrant_client.http import models

        query_filter = None
        if filter_conditions:
            query_filter = models.Filter(
                must=[
                    models.FieldCondition(
                        key=key,
                        match=models.MatchValue(value=value),
                    )
                    for key, value in filter_conditions.items()
                ]
            )

        # Check if the collection uses named vectors (not just a single unnamed vector).
        # Only pass `using=model.value` if the collection has actual named vectors.
        use_named_vector = self.uses_named_vectors(collection_name)

        # Search
        if use_named_vector:
            results = self.qdrant_client.query_points(
                collection_name=collection_name,
                query=query_vector,
                using=model.value,
                limit=k,
                with_payload=True,
                query_filter=query_filter,
            )
        else:
            # Legacy single-vector search
            results = self.qdrant_client.query_points(
                collection_name=collection_name,
                query=query_vector,
                limit=k,
                with_payload=True,
                query_filter=query_filter,
            )

        return [
            {
                "id": str(point.id),
                "score": point.score,
                "model": model.value,
                "payload": point.payload or {},
            }
            for point in results.points
        ]

    def search_persons(
        self,
        query: str,
        k: int | None = None,
        using: EmbeddingModel | str | None = None,
        filter_custodian: str | None = None,
        only_heritage_relevant: bool = False,
        only_wcms: bool = False,
    ) -> list[dict[str, Any]]:
        """Search for persons/staff in the heritage_persons collection.

        Args:
            query: Search query text
            k: Number of results
            using: Embedding model to use
            filter_custodian: Optional custodian slug to filter by
            only_heritage_relevant: Only return heritage-relevant staff
            only_wcms: Only return WCMS-registered profiles (heritage sector users)

        Returns:
            List of person results with scores
        """
        k = k or self.config.default_k

        # Build filters
        filters = {}
        if filter_custodian:
            filters["custodian_slug"] = filter_custodian
        if only_wcms:
            filters["has_wcms"] = True

        # Search with over-fetch to leave room for post-filtering
        results = self.search(
            query=query,
            collection_name=self.config.persons_collection,
            k=k * 2,
            using=using,
            filter_conditions=filters if filters else None,
        )

        # Post-filter for heritage_relevant if needed
        if only_heritage_relevant:
            results = [r for r in results if r.get("payload", {}).get("heritage_relevant", False)]

        # Format results
        formatted = []
        for r in results[:k]:
            payload = r.get("payload", {})
            formatted.append({
                # Fall back to a deterministic hash of custodian + name when no
                # staff_id is stored, so the same person always gets the same ID
                "person_id": payload.get("staff_id", "") or hashlib.md5(
                    f"{payload.get('custodian_slug', '')}:{payload.get('name', '')}".encode()
                ).hexdigest()[:16],
                "name": payload.get("name", ""),
                "headline": payload.get("headline"),
                "custodian_name": payload.get("custodian_name"),
                "custodian_slug": payload.get("custodian_slug"),
                "location": payload.get("location"),
                "heritage_relevant": payload.get("heritage_relevant", False),
                "heritage_type": payload.get("heritage_type"),
                "linkedin_url": payload.get("linkedin_url"),
                "score": r["score"],
                "model": r["model"],
            })

        return formatted

    def compare_models(
        self,
        query: str,
        collection_name: str | None = None,
        k: int = 10,
        models: list[EmbeddingModel] | None = None,
    ) -> dict[str, Any]:
        """A/B test comparison of multiple embedding models.

        Args:
            query: Search query
            collection_name: Collection to search
            k: Number of results per model
            models: Models to compare (default: all available)

        Returns:
            Dict with results per model and overlap analysis
        """
        collection_name = collection_name or self.config.institutions_collection

        # Determine which models to compare
        available = self.get_available_models(collection_name)
        if models:
            models_to_test = [m for m in models if m in available]
        else:
            models_to_test = list(available)

        if not models_to_test:
            return {"error": "No models available for comparison"}

        results = {}
        all_ids = {}

        for model in models_to_test:
            try:
                model_results = self.search(
                    query=query,
                    collection_name=collection_name,
                    k=k,
                    using=model,
                )
                results[model.value] = model_results
                all_ids[model.value] = {r["id"] for r in model_results}
            except Exception as e:
                results[model.value] = {"error": str(e)}
                all_ids[model.value] = set()

        # Calculate overlap between models
        overlap = {}
        model_values = list(all_ids.keys())
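        # For each pair of models, Jaccard similarity = |A ∩ B| / |A ∪ B| over
        # the returned point-ID sets: 1.0 means identical results, 0.0 disjoint.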
        for i, m1 in enumerate(model_values):
            for m2 in model_values[i + 1:]:
                if all_ids[m1] and all_ids[m2]:
                    intersection = all_ids[m1] & all_ids[m2]
                    union = all_ids[m1] | all_ids[m2]
                    jaccard = len(intersection) / len(union) if union else 0
                    overlap[f"{m1}_vs_{m2}"] = {
                        "jaccard_similarity": round(jaccard, 3),
                        "common_results": len(intersection),
                        "total_unique": len(union),
                    }

        return {
            "query": query,
            "collection": collection_name,
            "k": k,
            "results": results,
            "overlap_analysis": overlap,
        }

    def create_multi_embedding_collection(
        self,
        collection_name: str,
        models: list[EmbeddingModel] | None = None,
    ) -> bool:
        """Create a new collection with named vectors for multiple embedding models.

        Args:
            collection_name: Name for the new collection
            models: Embedding models to support (default: all)

        Returns:
            True if created successfully
        """
        from qdrant_client.http.models import Distance, VectorParams

        models = models or list(EmbeddingModel)

        vectors_config = {
            model.value: VectorParams(
                size=model.dimension,
                distance=Distance.COSINE,
            )
            for model in models
        }

        try:
            self.qdrant_client.create_collection(
                collection_name=collection_name,
                vectors_config=vectors_config,
            )
            logger.info(f"Created multi-embedding collection '{collection_name}' with {[m.value for m in models]}")

            # Clear caches so the new schema is re-read on next access
            self._available_models.pop(collection_name, None)
            self._uses_named_vectors.pop(collection_name, None)

            return True

        except Exception as e:
            logger.error(f"Failed to create collection: {e}")
            return False

    def add_documents_multi_embedding(
        self,
        documents: list[dict[str, Any]],
        collection_name: str,
        models: list[EmbeddingModel] | None = None,
        batch_size: int = 100,
    ) -> int:
        """Add documents with embeddings from multiple models.

        Args:
            documents: List of documents with 'text' and optional 'metadata' fields
            collection_name: Target collection
            models: Models to generate embeddings for (default: all available)
            batch_size: Batch size for processing

        Returns:
            Number of documents added
        """
        from qdrant_client.http import models as qmodels

        # Determine which models to use
        available = self.get_available_models(collection_name)
        if models:
            models_to_use = [m for m in models if m in available]
        else:
            models_to_use = list(available)

        if not models_to_use:
            raise RuntimeError(f"No embedding models available for collection '{collection_name}'")

        # Filter valid documents
        valid_docs = [d for d in documents if d.get("text")]
        total_indexed = 0

        for i in range(0, len(valid_docs), batch_size):
            batch = valid_docs[i:i + batch_size]
            texts = [d["text"] for d in batch]

            # Generate embeddings for each model
            embeddings_by_model = {}
            for model in models_to_use:
                try:
                    embeddings_by_model[model] = self.get_embeddings_batch(texts, model)
                except Exception as e:
                    logger.warning(f"Failed to get {model.value} embeddings: {e}")

            if not embeddings_by_model:
                continue

            # Create points with named vectors
            points = []
            for j, doc in enumerate(batch):
                text = doc["text"]
                metadata = doc.get("metadata", {})
                point_id = doc.get("id") or hashlib.md5(text.encode()).hexdigest()

                # Build named vectors dict
                vectors = {}
                for model, model_embeddings in embeddings_by_model.items():
                    vectors[model.value] = model_embeddings[j]

                points.append(qmodels.PointStruct(
                    id=point_id,
                    vector=vectors,
                    payload={
                        "text": text,
                        **metadata,
                    },
                ))

            # Upsert batch
            self.qdrant_client.upsert(
                collection_name=collection_name,
                points=points,
            )
            total_indexed += len(points)
            logger.info(f"Indexed {total_indexed}/{len(valid_docs)} documents with {len(models_to_use)} models")

        return total_indexed

    def get_stats(self) -> dict[str, Any]:
        """Get statistics about collections and available models.

        Returns:
            Dict with collection stats and model availability
        """
        stats = {
            "config": {
                "qdrant_host": self.config.qdrant_host,
                "qdrant_port": self.config.qdrant_port,
                "model_preference": [m.value for m in self.config.model_preference],
                "openai_available": bool(self.config.openai_api_key),
            },
            "collections": {},
        }

        for collection_name in [self.config.institutions_collection, self.config.persons_collection]:
            try:
                info = self.qdrant_client.get_collection(collection_name)
                available_models = self.get_available_models(collection_name)
                selected_model = self.select_model(collection_name)

                stats["collections"][collection_name] = {
                    "vectors_count": info.vectors_count,
                    "points_count": info.points_count,
                    "status": info.status.value if info.status else "unknown",
                    "available_models": [m.value for m in available_models],
                    "selected_model": selected_model.value if selected_model else None,
                }
            except Exception as e:
                stats["collections"][collection_name] = {"error": str(e)}

        return stats

    def close(self):
        """Close all connections."""
        if self._qdrant_client:
            self._qdrant_client.close()
            self._qdrant_client = None
        self._st_models.clear()
        self._available_models.clear()
        self._uses_named_vectors.clear()


def create_multi_embedding_retriever(use_production: bool | None = None) -> MultiEmbeddingRetriever:
    """Factory function to create a MultiEmbeddingRetriever.

    Args:
        use_production: If True, connect to production Qdrant.
            Defaults to the QDRANT_USE_PRODUCTION env var.

    Returns:
        Configured MultiEmbeddingRetriever instance
    """
    if use_production is None:
        use_production = os.getenv("QDRANT_USE_PRODUCTION", "").lower() in ("true", "1", "yes")

    if use_production:
        config = MultiEmbeddingConfig(
            qdrant_host=os.getenv("QDRANT_PROD_HOST", "bronhouder.nl"),
            qdrant_port=443,
            qdrant_https=True,
            qdrant_prefix=os.getenv("QDRANT_PROD_PREFIX", "qdrant"),
            openai_api_key=os.getenv("OPENAI_API_KEY"),
        )
    else:
        config = MultiEmbeddingConfig(
            qdrant_host=os.getenv("QDRANT_HOST", "localhost"),
            qdrant_port=int(os.getenv("QDRANT_PORT", "6333")),
            openai_api_key=os.getenv("OPENAI_API_KEY"),
        )

    return MultiEmbeddingRetriever(config)
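

if __name__ == "__main__":
    # Illustrative smoke test, not part of the original module: a minimal
    # sketch assuming a reachable Qdrant instance at the configured host/port
    # and at least one usable embedding backend (sentence-transformers
    # installed, or OPENAI_API_KEY set). Collection names and the example
    # query come from the module defaults above.
    logging.basicConfig(level=logging.INFO)

    retriever = create_multi_embedding_retriever()
    try:
        # Report collections, available models, and the auto-selected model
        print(retriever.get_stats())

        # Auto-selected model search against the default institutions collection
        for hit in retriever.search("museums in Amsterdam", k=3):
            print(f"{hit['score']:.3f} [{hit['model']}] {hit['payload'].get('name', '')}")
    finally:
        retriever.close()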