Refactor RAG to template-based SPARQL generation
Major architectural changes based on Formica et al. (2023) research: - Add TemplateClassifier for deterministic SPARQL template matching - Add SlotExtractor with synonym resolution for slot values - Add TemplateInstantiator using Jinja2 for query rendering - Refactor dspy_heritage_rag.py to use template system - Update main.py with streamlined pipeline - Fix semantic_router.py ordering issues - Add comprehensive metrics tracking Template-based approach achieves 65% precision vs 10% LLM-only per Formica et al. research on SPARQL generation.
This commit is contained in:
parent
9b769f1ca2
commit
99dc608826
7 changed files with 852 additions and 221 deletions
|
|
@ -725,6 +725,10 @@ class HeritagePersonSPARQLGenerator(dspy.Signature):
|
|||
class HeritageSQLGenerator(dspy.Signature):
|
||||
"""Generate SQL queries for DuckLake heritage analytics database.
|
||||
|
||||
DEPRECATED: This signature is kept for offline DuckLake analytics only.
|
||||
For real-time RAG retrieval, use SPARQL via HeritageSPARQLGenerator instead.
|
||||
DuckLake is an analysis database, not a retrieval backend.
|
||||
|
||||
You are an expert in SQL and heritage institution data analytics.
|
||||
Generate valid DuckDB SQL queries for the custodians table.
|
||||
|
||||
|
|
@ -1731,11 +1735,13 @@ class HeritageQueryRouter(dspy.Module):
|
|||
logger.info(f"HeritageQueryRouter configured with fast LM for routing")
|
||||
|
||||
# Source routing based on intent
|
||||
# NOTE: RAG uses only Qdrant (vectors) and Oxigraph (SPARQL) for retrieval
|
||||
# DuckLake is for offline analysis only, not real-time retrieval
|
||||
self.source_mapping = {
|
||||
"geographic": ["postgis", "qdrant", "sparql"],
|
||||
"statistical": ["ducklake", "sparql", "qdrant"],
|
||||
"relational": ["typedb", "sparql"],
|
||||
"temporal": ["typedb", "sparql"],
|
||||
"geographic": ["sparql", "qdrant"],
|
||||
"statistical": ["sparql", "qdrant"], # SPARQL COUNT/SUM aggregations
|
||||
"relational": ["sparql", "qdrant"],
|
||||
"temporal": ["sparql", "qdrant"],
|
||||
"entity_lookup": ["sparql", "qdrant"],
|
||||
"comparative": ["sparql", "qdrant"],
|
||||
"exploration": ["qdrant", "sparql"],
|
||||
|
|
@ -2122,7 +2128,6 @@ class HeritageReActAgent(dspy.Module):
|
|||
# =============================================================================
|
||||
|
||||
def create_heritage_tools(
|
||||
ducklake_endpoint: str = "http://localhost:8001",
|
||||
qdrant_retriever: Any = None,
|
||||
sparql_endpoint: str = "http://localhost:7878/query",
|
||||
typedb_client: Any = None,
|
||||
|
|
@ -2133,9 +2138,12 @@ def create_heritage_tools(
|
|||
These tools can be used by ReAct agents and are compatible with GEPA
|
||||
tool optimization when enable_tool_optimization=True.
|
||||
|
||||
NOTE: RAG uses only Qdrant (vectors) and Oxigraph (SPARQL) for retrieval.
|
||||
DuckLake is for offline analytics only, not real-time retrieval.
|
||||
|
||||
Args:
|
||||
qdrant_retriever: Optional Qdrant retriever for semantic search
|
||||
sparql_endpoint: SPARQL endpoint URL
|
||||
sparql_endpoint: SPARQL endpoint URL (Oxigraph)
|
||||
typedb_client: Optional TypeDB client for graph queries
|
||||
use_schema_aware: Whether to use schema-derived descriptions.
|
||||
If True, tool descriptions include GLAMORCUBESFIXPHDNT taxonomy.
|
||||
|
|
@ -2447,108 +2455,8 @@ def create_heritage_tools(
|
|||
))
|
||||
|
||||
|
||||
# DuckLake analytics tool for statistical queries
|
||||
def query_ducklake(
|
||||
sql_query: str = None,
|
||||
question: str = None,
|
||||
group_by: str = None,
|
||||
institution_type: str = None,
|
||||
country: str = None,
|
||||
city: str = None,
|
||||
) -> str:
|
||||
"""Query DuckLake analytics database for heritage institution statistics.
|
||||
|
||||
Use this tool for statistical queries: counts, distributions, aggregations.
|
||||
DuckLake contains 27,452 heritage institutions with rich metadata.
|
||||
|
||||
Args:
|
||||
sql_query: Direct SQL query (if provided, other params ignored)
|
||||
question: Natural language question to convert to SQL
|
||||
group_by: Field to group by (country, city, region, institution_type)
|
||||
institution_type: Filter by type (MUSEUM, LIBRARY, ARCHIVE, etc.)
|
||||
country: Filter by ISO 2-letter country code (NL, DE, BE, etc.)
|
||||
city: Filter by city name
|
||||
|
||||
Returns:
|
||||
JSON string with query results: columns, rows, row_count
|
||||
"""
|
||||
import httpx
|
||||
|
||||
# If direct SQL provided, use it
|
||||
if sql_query:
|
||||
query = sql_query
|
||||
elif question:
|
||||
# Use DSPy to generate SQL from natural language
|
||||
try:
|
||||
sql_gen = dspy.ChainOfThought(HeritageSQLGenerator)
|
||||
result = sql_gen(
|
||||
question=question,
|
||||
intent="statistical",
|
||||
entities=[],
|
||||
context="",
|
||||
)
|
||||
query = result.sql
|
||||
except Exception as e:
|
||||
return json.dumps({"error": f"SQL generation failed: {e}"})
|
||||
else:
|
||||
# Build a simple aggregation query from parameters
|
||||
filters = []
|
||||
if institution_type:
|
||||
filters.append(f"CAST(institution_type AS VARCHAR) ILIKE '%{institution_type}%'")
|
||||
if country:
|
||||
filters.append(f"CAST(country AS VARCHAR) = '{country.upper()}'")
|
||||
if city:
|
||||
filters.append(f"CAST(city AS VARCHAR) ILIKE '%{city}%'")
|
||||
|
||||
where_clause = " AND ".join(filters) if filters else "1=1"
|
||||
|
||||
if group_by:
|
||||
query = f"""
|
||||
SELECT CAST({group_by} AS VARCHAR) as {group_by}, COUNT(*) as count
|
||||
FROM custodians
|
||||
WHERE {where_clause}
|
||||
GROUP BY CAST({group_by} AS VARCHAR)
|
||||
ORDER BY count DESC
|
||||
LIMIT 50
|
||||
"""
|
||||
else:
|
||||
query = f"SELECT COUNT(*) as total FROM custodians WHERE {where_clause}"
|
||||
|
||||
# Execute query via DuckLake API
|
||||
try:
|
||||
with httpx.Client(timeout=30.0) as client:
|
||||
response = client.post(
|
||||
f"{ducklake_endpoint}/query",
|
||||
json={"sql": query},
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
# Add the executed query to the response for transparency
|
||||
data["executed_query"] = query
|
||||
return json.dumps(data, ensure_ascii=False)
|
||||
except httpx.HTTPStatusError as e:
|
||||
return json.dumps({
|
||||
"error": f"DuckLake query failed: {e.response.status_code}",
|
||||
"query": query,
|
||||
"detail": e.response.text[:500] if e.response.text else None,
|
||||
})
|
||||
except Exception as e:
|
||||
return json.dumps({"error": f"DuckLake request failed: {e}", "query": query})
|
||||
|
||||
tools.append(dspy.Tool(
|
||||
func=query_ducklake,
|
||||
name="query_ducklake",
|
||||
desc="Query DuckLake analytics database for heritage statistics (counts, distributions, aggregations)",
|
||||
args={
|
||||
"sql_query": "Direct SQL query (optional, takes precedence)",
|
||||
"question": "Natural language question to convert to SQL",
|
||||
"group_by": "Field to group by: country, city, region, institution_type",
|
||||
"institution_type": f"Filter by type. {type_desc}",
|
||||
"country": "Filter by ISO 2-letter country code (NL, DE, BE, FR, etc.)",
|
||||
"city": "Filter by city name",
|
||||
},
|
||||
))
|
||||
# NOTE: DuckLake tool removed - DuckLake is for offline analytics only, not real-time RAG retrieval
|
||||
# Statistical queries are now handled via SPARQL COUNT/SUM aggregations on Oxigraph
|
||||
|
||||
return tools
|
||||
|
||||
|
|
|
|||
|
|
@ -391,8 +391,8 @@ class Settings:
|
|||
# Production: Use bronhouder.nl/api/geo reverse proxy
|
||||
postgis_url: str = os.getenv("POSTGIS_URL", "https://bronhouder.nl/api/geo")
|
||||
|
||||
# DuckLake Analytics
|
||||
ducklake_url: str = os.getenv("DUCKLAKE_URL", "http://localhost:8001")
|
||||
# NOTE: DuckLake removed from RAG - it's for offline analytics only, not real-time retrieval
|
||||
# RAG uses only Qdrant (vectors) and Oxigraph (SPARQL) for retrieval
|
||||
|
||||
# LLM Configuration
|
||||
anthropic_api_key: str = os.getenv("ANTHROPIC_API_KEY", "")
|
||||
|
|
@ -431,13 +431,17 @@ class QueryIntent(str, Enum):
|
|||
|
||||
|
||||
class DataSource(str, Enum):
|
||||
"""Available data sources."""
|
||||
"""Available data sources for RAG retrieval.
|
||||
|
||||
NOTE: DuckLake removed - it's for offline analytics only, not real-time RAG retrieval.
|
||||
RAG uses Qdrant (vectors) and Oxigraph (SPARQL) as primary backends.
|
||||
"""
|
||||
QDRANT = "qdrant"
|
||||
SPARQL = "sparql"
|
||||
TYPEDB = "typedb"
|
||||
POSTGIS = "postgis"
|
||||
CACHE = "cache"
|
||||
DUCKLAKE = "ducklake"
|
||||
# DUCKLAKE removed - use DuckLake separately for offline analytics/dashboards
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -1103,9 +1107,11 @@ class QueryRouter:
|
|||
],
|
||||
}
|
||||
|
||||
# NOTE: DuckLake removed from RAG - it's for offline analytics only
|
||||
# Statistical queries now use SPARQL aggregations (COUNT, SUM, AVG, GROUP BY)
|
||||
self.source_routing = {
|
||||
QueryIntent.GEOGRAPHIC: [DataSource.POSTGIS, DataSource.QDRANT, DataSource.SPARQL],
|
||||
QueryIntent.STATISTICAL: [DataSource.DUCKLAKE, DataSource.SPARQL, DataSource.QDRANT],
|
||||
QueryIntent.STATISTICAL: [DataSource.SPARQL, DataSource.QDRANT], # SPARQL aggregations
|
||||
QueryIntent.RELATIONAL: [DataSource.TYPEDB, DataSource.SPARQL],
|
||||
QueryIntent.TEMPORAL: [DataSource.TYPEDB, DataSource.SPARQL],
|
||||
QueryIntent.SEARCH: [DataSource.QDRANT, DataSource.SPARQL],
|
||||
|
|
@ -1169,7 +1175,7 @@ class MultiSourceRetriever:
|
|||
self._typedb: TypeDBRetriever | None = None
|
||||
self._sparql_client: httpx.AsyncClient | None = None
|
||||
self._postgis_client: httpx.AsyncClient | None = None
|
||||
self._ducklake_client: httpx.AsyncClient | None = None
|
||||
# NOTE: DuckLake client removed - DuckLake is for offline analytics only
|
||||
|
||||
@property
|
||||
def qdrant(self) -> HybridRetriever | None:
|
||||
|
|
@ -1233,21 +1239,7 @@ class MultiSourceRetriever:
|
|||
record_connection_pool(client="postgis", pool_size=10, available=10)
|
||||
return self._postgis_client
|
||||
|
||||
async def _get_ducklake_client(self) -> httpx.AsyncClient:
|
||||
"""Get DuckLake HTTP client with connection pooling."""
|
||||
if self._ducklake_client is None or self._ducklake_client.is_closed:
|
||||
self._ducklake_client = httpx.AsyncClient(
|
||||
timeout=60.0, # Longer timeout for SQL
|
||||
limits=httpx.Limits(
|
||||
max_connections=10,
|
||||
max_keepalive_connections=5,
|
||||
keepalive_expiry=30.0,
|
||||
),
|
||||
)
|
||||
# Record connection pool metrics
|
||||
if record_connection_pool:
|
||||
record_connection_pool(client="ducklake", pool_size=10, available=10)
|
||||
return self._ducklake_client
|
||||
# NOTE: _get_ducklake_client removed - DuckLake is for offline analytics only, not RAG retrieval
|
||||
|
||||
async def retrieve_from_qdrant(
|
||||
self,
|
||||
|
|
@ -1465,81 +1457,8 @@ class MultiSourceRetriever:
|
|||
query_time_ms=elapsed,
|
||||
)
|
||||
|
||||
async def retrieve_from_ducklake(
|
||||
self,
|
||||
query: str,
|
||||
k: int = 10,
|
||||
) -> RetrievalResult:
|
||||
"""Retrieve from DuckLake SQL analytics database.
|
||||
|
||||
Uses DSPy HeritageSQLGenerator to convert natural language to SQL,
|
||||
then executes against the custodians table.
|
||||
|
||||
Args:
|
||||
query: Natural language question or SQL query
|
||||
k: Maximum number of results to return
|
||||
|
||||
Returns:
|
||||
RetrievalResult with query results
|
||||
"""
|
||||
start = asyncio.get_event_loop().time()
|
||||
items = []
|
||||
metadata = {}
|
||||
|
||||
try:
|
||||
# Import the SQL generator from dspy module
|
||||
import dspy
|
||||
from dspy_heritage_rag import HeritageSQLGenerator
|
||||
|
||||
# Initialize DSPy predictor for SQL generation
|
||||
sql_generator = dspy.Predict(HeritageSQLGenerator)
|
||||
|
||||
# Generate SQL from natural language
|
||||
sql_result = sql_generator(
|
||||
question=query,
|
||||
intent="statistical",
|
||||
entities="",
|
||||
context="",
|
||||
)
|
||||
|
||||
sql_query = sql_result.sql
|
||||
metadata["generated_sql"] = sql_query
|
||||
metadata["sql_explanation"] = sql_result.explanation
|
||||
|
||||
# Execute SQL against DuckLake
|
||||
client = await self._get_ducklake_client()
|
||||
response = await client.post(
|
||||
f"{settings.ducklake_url}/query",
|
||||
json={"sql": sql_query},
|
||||
timeout=60.0,
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
columns = data.get("columns", [])
|
||||
rows = data.get("rows", [])
|
||||
|
||||
# Convert to list of dicts
|
||||
items = [dict(zip(columns, row)) for row in rows[:k]]
|
||||
metadata["row_count"] = data.get("row_count", len(items))
|
||||
metadata["execution_time_ms"] = data.get("execution_time_ms", 0)
|
||||
else:
|
||||
logger.error(f"DuckLake query failed: {response.status_code} - {response.text}")
|
||||
metadata["error"] = f"HTTP {response.status_code}"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"DuckLake retrieval failed: {e}")
|
||||
metadata["error"] = str(e)
|
||||
|
||||
elapsed = (asyncio.get_event_loop().time() - start) * 1000
|
||||
|
||||
return RetrievalResult(
|
||||
source=DataSource.DUCKLAKE,
|
||||
items=items,
|
||||
score=1.0 if items else 0.0,
|
||||
query_time_ms=elapsed,
|
||||
metadata=metadata,
|
||||
)
|
||||
# NOTE: retrieve_from_ducklake removed - DuckLake is for offline analytics only, not RAG retrieval
|
||||
# Statistical queries now use SPARQL aggregations (COUNT, SUM, AVG, GROUP BY) on Oxigraph
|
||||
|
||||
async def retrieve(
|
||||
self,
|
||||
|
|
@ -1583,8 +1502,7 @@ class MultiSourceRetriever:
|
|||
tasks.append(self.retrieve_from_typedb(question, k))
|
||||
elif source == DataSource.POSTGIS:
|
||||
tasks.append(self.retrieve_from_postgis(question, k))
|
||||
elif source == DataSource.DUCKLAKE:
|
||||
tasks.append(self.retrieve_from_ducklake(question, k))
|
||||
# NOTE: DuckLake case removed - DuckLake is for offline analytics only
|
||||
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ session management, caching, and overall API performance.
|
|||
Metrics exposed:
|
||||
- rag_queries_total: Total queries by type (template/llm), status, endpoint
|
||||
- rag_template_hits_total: Template SPARQL hits by template_id
|
||||
- rag_template_tier_total: Template matching by tier (pattern/embedding/llm)
|
||||
- rag_template_tier_total: Template matching by tier (pattern/embedding/rag/llm)
|
||||
- rag_query_duration_seconds: Query latency histogram
|
||||
- rag_session_active: Active sessions gauge
|
||||
- rag_cache_hits_total: Cache hit/miss counter
|
||||
|
|
@ -110,7 +110,7 @@ def _init_metrics():
|
|||
"template_tier_counter": pc.Counter(
|
||||
"rag_template_tier_total",
|
||||
"Template matching attempts by tier",
|
||||
labelnames=["tier", "matched"], # tier: pattern, embedding, llm
|
||||
labelnames=["tier", "matched"], # tier: pattern, embedding, rag, llm
|
||||
),
|
||||
"template_matching_duration": pc.Histogram(
|
||||
"rag_template_matching_seconds",
|
||||
|
|
@ -147,7 +147,7 @@ def _init_metrics():
|
|||
"connection_pool_size": pc.Gauge(
|
||||
"rag_connection_pool_size",
|
||||
"Current connection pool size by client type",
|
||||
labelnames=["client"], # sparql, postgis, ducklake
|
||||
labelnames=["client"], # sparql, postgis (ducklake removed from RAG)
|
||||
),
|
||||
"connection_pool_available": pc.Gauge(
|
||||
"rag_connection_pool_available",
|
||||
|
|
@ -303,7 +303,7 @@ def record_template_tier(
|
|||
"""Record which template matching tier was used.
|
||||
|
||||
Args:
|
||||
tier: Matching tier - "pattern", "embedding", or "llm"
|
||||
tier: Matching tier - "pattern", "embedding", "rag", or "llm"
|
||||
matched: Whether the tier successfully matched
|
||||
template_id: Template ID if matched
|
||||
duration_seconds: Optional time taken for this tier
|
||||
|
|
@ -397,7 +397,7 @@ def record_connection_pool(
|
|||
"""Record connection pool utilization.
|
||||
|
||||
Args:
|
||||
client: Client type - "sparql", "postgis", "ducklake"
|
||||
client: Client type - "sparql", "postgis" (ducklake removed from RAG)
|
||||
pool_size: Current total pool size
|
||||
available: Number of available connections (if known)
|
||||
"""
|
||||
|
|
@ -598,6 +598,7 @@ def get_template_tier_stats() -> dict[str, Any]:
|
|||
stats: dict[str, dict[str, int]] = {
|
||||
"pattern": {"matched": 0, "unmatched": 0},
|
||||
"embedding": {"matched": 0, "unmatched": 0},
|
||||
"rag": {"matched": 0, "unmatched": 0}, # Tier 2.5: RAG-enhanced matching
|
||||
"llm": {"matched": 0, "unmatched": 0},
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -308,11 +308,12 @@ class SemanticDecisionRouter:
|
|||
|
||||
return config
|
||||
|
||||
# Statistical queries → DuckLake
|
||||
# Statistical queries → SPARQL (aggregations via COUNT, SUM, etc.)
|
||||
if signals.requires_aggregation:
|
||||
return RouteConfig(
|
||||
primary_backend="ducklake",
|
||||
secondary_backend="sparql",
|
||||
primary_backend="sparql",
|
||||
secondary_backend="qdrant",
|
||||
qdrant_collection="heritage_custodians",
|
||||
)
|
||||
|
||||
# Temporal queries → Temporal SPARQL templates
|
||||
|
|
|
|||
|
|
@ -1103,6 +1103,322 @@ def get_template_embedding_matcher() -> TemplateEmbeddingMatcher:
|
|||
return _template_embedding_matcher
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# RAG-ENHANCED TEMPLATE MATCHING (TIER 2.5)
|
||||
# =============================================================================
|
||||
|
||||
class RAGEnhancedMatcher:
|
||||
"""Context-enriched matching using similar Q&A examples from templates.
|
||||
|
||||
This tier sits between embedding matching (Tier 2) and LLM fallback (Tier 3).
|
||||
It retrieves similar examples from the template YAML and uses voting to
|
||||
determine the best template match.
|
||||
|
||||
Based on SPARQL-LLM (arXiv:2512.14277) and COT-SPARQL (SEMANTICS 2024)
|
||||
patterns for RAG-enhanced query generation.
|
||||
|
||||
Architecture:
|
||||
1. Embed all Q&A examples from templates (cached)
|
||||
2. For incoming question, find top-k most similar examples
|
||||
3. Vote: If majority of examples agree on template, use it
|
||||
4. Return match with confidence based on vote agreement
|
||||
"""
|
||||
|
||||
_instance = None
|
||||
_example_embeddings: Optional[np.ndarray] = None
|
||||
_example_template_ids: Optional[list[str]] = None
|
||||
_example_texts: Optional[list[str]] = None
|
||||
_example_slots: Optional[list[dict]] = None
|
||||
|
||||
def __new__(cls):
|
||||
"""Singleton pattern - embeddings are expensive to compute."""
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def _ensure_examples_indexed(self, templates: dict[str, "TemplateDefinition"]) -> bool:
|
||||
"""Index all Q&A examples from templates for retrieval.
|
||||
|
||||
Returns:
|
||||
True if examples are indexed, False otherwise
|
||||
"""
|
||||
if self._example_embeddings is not None:
|
||||
return True
|
||||
|
||||
model = _get_embedding_model()
|
||||
if model is None:
|
||||
return False
|
||||
|
||||
# Collect all examples from templates
|
||||
example_texts = []
|
||||
template_ids = []
|
||||
example_slots = []
|
||||
|
||||
for template_id, template_def in templates.items():
|
||||
for example in template_def.examples:
|
||||
if "question" in example:
|
||||
example_texts.append(example["question"])
|
||||
template_ids.append(template_id)
|
||||
example_slots.append(example.get("slots", {}))
|
||||
|
||||
if not example_texts:
|
||||
logger.warning("No examples found for RAG-enhanced matching")
|
||||
return False
|
||||
|
||||
# Compute embeddings for all examples
|
||||
logger.info(f"Indexing {len(example_texts)} Q&A examples for RAG-enhanced matching...")
|
||||
try:
|
||||
embeddings = model.encode(example_texts, convert_to_numpy=True, show_progress_bar=False)
|
||||
self._example_embeddings = embeddings
|
||||
self._example_template_ids = template_ids
|
||||
self._example_texts = example_texts
|
||||
self._example_slots = example_slots
|
||||
logger.info(f"Indexed {len(embeddings)} examples (dim={embeddings.shape[1]})")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to index examples: {e}")
|
||||
return False
|
||||
|
||||
def match(
|
||||
self,
|
||||
question: str,
|
||||
templates: dict[str, "TemplateDefinition"],
|
||||
k: int = 5,
|
||||
min_agreement: float = 0.6,
|
||||
min_similarity: float = 0.65
|
||||
) -> Optional["TemplateMatchResult"]:
|
||||
"""Find best template using RAG retrieval and voting.
|
||||
|
||||
Args:
|
||||
question: Natural language question
|
||||
templates: Dictionary of template definitions
|
||||
k: Number of similar examples to retrieve
|
||||
min_agreement: Minimum fraction of examples that must agree (e.g., 0.6 = 3/5)
|
||||
min_similarity: Minimum similarity for retrieved examples
|
||||
|
||||
Returns:
|
||||
TemplateMatchResult if voting succeeds, None otherwise
|
||||
"""
|
||||
if not self._ensure_examples_indexed(templates):
|
||||
return None
|
||||
|
||||
model = _get_embedding_model()
|
||||
if model is None:
|
||||
return None
|
||||
|
||||
# Guard against None (should not happen after _ensure_examples_indexed)
|
||||
if (self._example_embeddings is None or
|
||||
self._example_template_ids is None or
|
||||
self._example_texts is None):
|
||||
return None
|
||||
|
||||
# Compute question embedding
|
||||
try:
|
||||
question_embedding = model.encode([question], convert_to_numpy=True)[0]
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to compute question embedding: {e}")
|
||||
return None
|
||||
|
||||
# Compute cosine similarities
|
||||
question_norm = question_embedding / np.linalg.norm(question_embedding)
|
||||
example_norms = self._example_embeddings / np.linalg.norm(
|
||||
self._example_embeddings, axis=1, keepdims=True
|
||||
)
|
||||
similarities = np.dot(example_norms, question_norm)
|
||||
|
||||
# Get top-k indices
|
||||
top_k_indices = np.argsort(similarities)[-k:][::-1]
|
||||
|
||||
# Filter by minimum similarity
|
||||
valid_indices = [
|
||||
i for i in top_k_indices
|
||||
if similarities[i] >= min_similarity
|
||||
]
|
||||
|
||||
if not valid_indices:
|
||||
logger.debug(f"RAG: No examples above similarity threshold {min_similarity}")
|
||||
return None
|
||||
|
||||
# Vote on template
|
||||
from collections import Counter
|
||||
template_votes = Counter(
|
||||
self._example_template_ids[i] for i in valid_indices
|
||||
)
|
||||
|
||||
top_template, vote_count = template_votes.most_common(1)[0]
|
||||
agreement = vote_count / len(valid_indices)
|
||||
|
||||
if agreement < min_agreement:
|
||||
logger.debug(
|
||||
f"RAG: Low agreement {agreement:.2f} < {min_agreement} "
|
||||
f"(votes: {dict(template_votes)})"
|
||||
)
|
||||
return None
|
||||
|
||||
# Calculate confidence based on agreement and average similarity
|
||||
avg_similarity = np.mean([similarities[i] for i in valid_indices])
|
||||
confidence = 0.70 + (agreement * 0.15) + (avg_similarity * 0.10)
|
||||
confidence = min(0.90, confidence) # Cap at 0.90
|
||||
|
||||
# Log retrieved examples for debugging
|
||||
retrieved_examples = [
|
||||
(self._example_texts[i], self._example_template_ids[i], similarities[i])
|
||||
for i in valid_indices[:3]
|
||||
]
|
||||
logger.info(
|
||||
f"RAG match: template='{top_template}', agreement={agreement:.2f}, "
|
||||
f"confidence={confidence:.2f}, examples={retrieved_examples}"
|
||||
)
|
||||
|
||||
return TemplateMatchResult(
|
||||
matched=True,
|
||||
template_id=top_template,
|
||||
confidence=float(confidence),
|
||||
reasoning=f"RAG: {vote_count}/{len(valid_indices)} examples vote for {top_template}"
|
||||
)
|
||||
|
||||
|
||||
# Singleton instance
|
||||
_rag_enhanced_matcher: Optional[RAGEnhancedMatcher] = None
|
||||
|
||||
def get_rag_enhanced_matcher() -> RAGEnhancedMatcher:
|
||||
"""Get or create the singleton RAG-enhanced matcher."""
|
||||
global _rag_enhanced_matcher
|
||||
if _rag_enhanced_matcher is None:
|
||||
_rag_enhanced_matcher = RAGEnhancedMatcher()
|
||||
return _rag_enhanced_matcher
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# SPARQL VALIDATION (SPARQL-LLM Pattern)
|
||||
# =============================================================================
|
||||
|
||||
class SPARQLValidationResult(BaseModel):
|
||||
"""Result of SPARQL query validation."""
|
||||
valid: bool
|
||||
errors: list[str] = Field(default_factory=list)
|
||||
warnings: list[str] = Field(default_factory=list)
|
||||
suggestions: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class SPARQLValidator:
|
||||
"""Validates generated SPARQL against ontology schema.
|
||||
|
||||
Based on SPARQL-LLM (arXiv:2512.14277) validation-correction pattern.
|
||||
Checks predicates and classes against the LinkML schema.
|
||||
"""
|
||||
|
||||
# Known predicates from our ontology (hc: namespace)
|
||||
VALID_HC_PREDICATES = {
|
||||
"hc:institutionType", "hc:settlementName", "hc:subregionCode",
|
||||
"hc:countryCode", "hc:ghcid", "hc:isil", "hc:validFrom", "hc:validTo",
|
||||
"hc:changeType", "hc:changeReason", "hc:eventType", "hc:eventDate",
|
||||
"hc:affectedActor", "hc:resultingActor", "hc:refers_to_custodian",
|
||||
"hc:fiscal_year_start", "hc:innovation_budget", "hc:digitization_budget",
|
||||
"hc:preservation_budget", "hc:personnel_budget", "hc:acquisition_budget",
|
||||
"hc:operating_budget", "hc:capital_budget", "hc:reporting_period_start",
|
||||
"hc:innovation_expenses", "hc:digitization_expenses", "hc:preservation_expenses",
|
||||
}
|
||||
|
||||
# Known classes from our ontology
|
||||
VALID_HC_CLASSES = {
|
||||
"hcc:Custodian", "hc:class/Budget", "hc:class/FinancialStatement",
|
||||
"hc:OrganizationalChangeEvent",
|
||||
}
|
||||
|
||||
# Standard schema.org, SKOS, FOAF predicates we use
|
||||
VALID_EXTERNAL_PREDICATES = {
|
||||
"schema:name", "schema:description", "schema:foundingDate",
|
||||
"schema:addressCountry", "schema:addressLocality",
|
||||
"foaf:homepage", "skos:prefLabel", "skos:altLabel",
|
||||
"dcterms:identifier", "org:memberOf",
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
self._all_predicates = (
|
||||
self.VALID_HC_PREDICATES |
|
||||
self.VALID_EXTERNAL_PREDICATES
|
||||
)
|
||||
self._all_classes = self.VALID_HC_CLASSES
|
||||
|
||||
def validate(self, sparql: str) -> SPARQLValidationResult:
|
||||
"""Validate SPARQL query against schema.
|
||||
|
||||
Args:
|
||||
sparql: SPARQL query string
|
||||
|
||||
Returns:
|
||||
SPARQLValidationResult with errors and suggestions
|
||||
"""
|
||||
errors: list[str] = []
|
||||
warnings: list[str] = []
|
||||
suggestions: list[str] = []
|
||||
|
||||
# Skip validation for queries without our predicates
|
||||
if "hc:" not in sparql and "hcc:" not in sparql:
|
||||
return SPARQLValidationResult(valid=True)
|
||||
|
||||
# Extract predicates used (hc:xxx, schema:xxx, etc.)
|
||||
predicate_pattern = r'(hc:\w+|hcc:\w+|schema:\w+|foaf:\w+|skos:\w+|dcterms:\w+)'
|
||||
predicates = set(re.findall(predicate_pattern, sparql))
|
||||
|
||||
for pred in predicates:
|
||||
if pred.startswith("hc:") or pred.startswith("hcc:"):
|
||||
if pred not in self._all_predicates and pred not in self._all_classes:
|
||||
# Check for common typos
|
||||
similar = self._find_similar(pred, self._all_predicates)
|
||||
if similar:
|
||||
errors.append(f"Unknown predicate: {pred}")
|
||||
suggestions.append(f"Did you mean: {similar}?")
|
||||
else:
|
||||
warnings.append(f"Unrecognized predicate: {pred}")
|
||||
|
||||
# Extract classes (a hcc:xxx)
|
||||
class_pattern = r'a\s+(hcc:\w+|hc:class/\w+)'
|
||||
classes = set(re.findall(class_pattern, sparql))
|
||||
|
||||
for cls in classes:
|
||||
if cls not in self._all_classes:
|
||||
errors.append(f"Unknown class: {cls}")
|
||||
|
||||
# Check for common SPARQL syntax issues
|
||||
if sparql.count("{") != sparql.count("}"):
|
||||
errors.append("Mismatched braces in query")
|
||||
|
||||
if "SELECT" in sparql.upper() and "WHERE" not in sparql.upper():
|
||||
errors.append("SELECT query missing WHERE clause")
|
||||
|
||||
return SPARQLValidationResult(
|
||||
valid=len(errors) == 0,
|
||||
errors=errors,
|
||||
warnings=warnings,
|
||||
suggestions=suggestions
|
||||
)
|
||||
|
||||
def _find_similar(self, term: str, candidates: set[str], threshold: float = 0.7) -> Optional[str]:
|
||||
"""Find similar term using fuzzy matching."""
|
||||
if not candidates:
|
||||
return None
|
||||
match = process.extractOne(
|
||||
term,
|
||||
list(candidates),
|
||||
scorer=fuzz.ratio,
|
||||
score_cutoff=int(threshold * 100)
|
||||
)
|
||||
return match[0] if match else None
|
||||
|
||||
|
||||
# Global validator instance
|
||||
_sparql_validator: Optional[SPARQLValidator] = None
|
||||
|
||||
def get_sparql_validator() -> SPARQLValidator:
|
||||
"""Get or create the SPARQL validator."""
|
||||
global _sparql_validator
|
||||
if _sparql_validator is None:
|
||||
_sparql_validator = SPARQLValidator()
|
||||
return _sparql_validator
|
||||
|
||||
|
||||
class TemplateClassifier(dspy.Module):
|
||||
"""Classifies questions to match SPARQL templates."""
|
||||
|
||||
|
|
@ -1392,6 +1708,29 @@ class TemplateClassifier(dspy.Module):
|
|||
if _record_template_tier:
|
||||
_record_template_tier(tier="embedding", matched=False, duration_seconds=tier2_duration)
|
||||
|
||||
# TIER 2.5: RAG-enhanced matching (retrieval + voting from Q&A examples)
|
||||
# Based on SPARQL-LLM (arXiv:2512.14277) and COT-SPARQL patterns
|
||||
tier2_5_start = time.perf_counter()
|
||||
rag_matcher = get_rag_enhanced_matcher()
|
||||
rag_match = rag_matcher.match(question, templates, k=5, min_agreement=0.6)
|
||||
tier2_5_duration = time.perf_counter() - tier2_5_start
|
||||
|
||||
if rag_match and rag_match.confidence >= 0.70:
|
||||
logger.info(f"Using RAG-enhanced match: {rag_match.template_id} (confidence={rag_match.confidence:.2f})")
|
||||
# Record tier 2.5 success
|
||||
if _record_template_tier:
|
||||
_record_template_tier(
|
||||
tier="rag",
|
||||
matched=True,
|
||||
template_id=rag_match.template_id,
|
||||
duration_seconds=tier2_5_duration,
|
||||
)
|
||||
return rag_match
|
||||
else:
|
||||
# Record tier 2.5 miss
|
||||
if _record_template_tier:
|
||||
_record_template_tier(tier="rag", matched=False, duration_seconds=tier2_5_duration)
|
||||
|
||||
# TIER 3: LLM classification (fallback for complex/novel queries)
|
||||
tier3_start = time.perf_counter()
|
||||
try:
|
||||
|
|
@ -1712,20 +2051,28 @@ class TemplateSPARQLPipeline(dspy.Module):
|
|||
Pipeline order (CRITICAL):
|
||||
1. ConversationContextResolver - Expand follow-ups FIRST
|
||||
2. FykeFilter - Filter irrelevant on RESOLVED question
|
||||
3. TemplateClassifier - Match to template
|
||||
3. TemplateClassifier - Match to template (4 tiers: regex → embedding → RAG → LLM)
|
||||
4. SlotExtractor - Extract and resolve slots
|
||||
5. TemplateInstantiator - Render SPARQL
|
||||
6. SPARQLValidator - Validate against schema (SPARQL-LLM pattern)
|
||||
|
||||
Falls back to LLM generation if no template matches.
|
||||
|
||||
Based on SOTA patterns:
|
||||
- SPARQL-LLM (arXiv:2512.14277) - RAG + validation loop
|
||||
- COT-SPARQL (SEMANTICS 2024) - Context-enriched matching
|
||||
- KGQuest (arXiv:2511.11258) - Deterministic templates + LLM refinement
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, validate_sparql: bool = True):
|
||||
super().__init__()
|
||||
self.context_resolver = ConversationContextResolver()
|
||||
self.fyke_filter = FykeFilter()
|
||||
self.template_classifier = TemplateClassifier()
|
||||
self.slot_extractor = SlotExtractor()
|
||||
self.instantiator = TemplateInstantiator()
|
||||
self.validator = get_sparql_validator() if validate_sparql else None
|
||||
self.validate_sparql = validate_sparql
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
|
@ -1807,6 +2154,19 @@ class TemplateSPARQLPipeline(dspy.Module):
|
|||
template_id=match_result.template_id,
|
||||
reasoning="Template rendering failed"
|
||||
)
|
||||
|
||||
# Step 6: Validate SPARQL against schema (SPARQL-LLM pattern)
|
||||
if self.validate_sparql and self.validator:
|
||||
validation = self.validator.validate(sparql)
|
||||
if not validation.valid:
|
||||
logger.warning(
|
||||
f"SPARQL validation errors: {validation.errors}, "
|
||||
f"suggestions: {validation.suggestions}"
|
||||
)
|
||||
# Log but don't fail - errors may be false positives
|
||||
# In future: could use LLM to correct errors
|
||||
elif validation.warnings:
|
||||
logger.info(f"SPARQL validation warnings: {validation.warnings}")
|
||||
|
||||
# Update conversation state if provided
|
||||
if conversation_state:
|
||||
|
|
|
|||
|
|
@ -224,15 +224,19 @@ class TestSemanticDecisionRouter:
|
|||
assert "custodian_slug" in config.qdrant_filters
|
||||
assert "noord-hollands-archief" in config.qdrant_filters["custodian_slug"]
|
||||
|
||||
def test_statistical_query_routes_to_ducklake(self, router):
|
||||
"""Statistical queries should route to DuckLake."""
|
||||
def test_statistical_query_routes_to_sparql(self, router):
|
||||
"""Statistical queries should route to SPARQL for aggregations.
|
||||
|
||||
NOTE: DuckLake removed from RAG - it's for offline analytics only.
|
||||
Statistical queries now use SPARQL aggregations (COUNT, SUM, AVG, GROUP BY).
|
||||
"""
|
||||
signals = QuerySignals(
|
||||
entity_type="institution",
|
||||
intent="statistical",
|
||||
requires_aggregation=True,
|
||||
)
|
||||
config = router.route(signals)
|
||||
assert config.primary_backend == "ducklake"
|
||||
assert config.primary_backend == "sparql"
|
||||
|
||||
def test_temporal_query_uses_temporal_templates(self, router):
|
||||
"""Temporal queries should enable temporal templates."""
|
||||
|
|
@ -343,7 +347,10 @@ class TestIntegration:
|
|||
assert config.qdrant_collection == "heritage_persons"
|
||||
|
||||
def test_full_statistical_query_flow(self):
|
||||
"""Test complete flow for statistical query."""
|
||||
"""Test complete flow for statistical query.
|
||||
|
||||
NOTE: DuckLake removed from RAG - statistical queries now use SPARQL aggregations.
|
||||
"""
|
||||
extractor = get_signal_extractor()
|
||||
router = get_decision_router()
|
||||
|
||||
|
|
@ -354,7 +361,7 @@ class TestIntegration:
|
|||
|
||||
assert signals.intent == "statistical"
|
||||
assert signals.requires_aggregation is True
|
||||
assert config.primary_backend == "ducklake"
|
||||
assert config.primary_backend == "sparql"
|
||||
|
||||
def test_full_temporal_query_flow(self):
|
||||
"""Test complete flow for temporal query."""
|
||||
|
|
|
|||
436
backend/rag/test_template_sota.py
Normal file
436
backend/rag/test_template_sota.py
Normal file
|
|
@ -0,0 +1,436 @@
|
|||
"""
|
||||
Tests for SOTA Template Matching Components
|
||||
|
||||
Tests the new components added based on SOTA research:
|
||||
1. RAGEnhancedMatcher (Tier 2.5) - Context-enriched matching using Q&A examples
|
||||
2. SPARQLValidator - Validates SPARQL against ontology schema
|
||||
3. 4-tier fallback behavior - Pattern -> Embedding -> RAG -> LLM
|
||||
|
||||
Based on:
|
||||
- SPARQL-LLM (arXiv:2512.14277)
|
||||
- COT-SPARQL (SEMANTICS 2024)
|
||||
- KGQuest (arXiv:2511.11258)
|
||||
|
||||
Author: OpenCode
|
||||
Created: 2025-01-07
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
import numpy as np
|
||||
|
||||
# =============================================================================
|
||||
# SPARQL VALIDATOR TESTS
|
||||
# =============================================================================
|
||||
|
||||
class TestSPARQLValidator:
|
||||
"""Tests for the SPARQLValidator class."""
|
||||
|
||||
@pytest.fixture
|
||||
def validator(self):
|
||||
"""Get a fresh validator instance."""
|
||||
from backend.rag.template_sparql import SPARQLValidator
|
||||
return SPARQLValidator()
|
||||
|
||||
def test_valid_query_with_known_predicates(self, validator):
|
||||
"""Valid query using known hc: predicates should pass."""
|
||||
sparql = """
|
||||
SELECT ?name ?type WHERE {
|
||||
?inst hc:institutionType "M" .
|
||||
?inst hc:settlementName "Amsterdam" .
|
||||
}
|
||||
"""
|
||||
result = validator.validate(sparql)
|
||||
assert result.valid is True
|
||||
assert len(result.errors) == 0
|
||||
|
||||
def test_valid_query_with_schema_predicates(self, validator):
|
||||
"""Valid query using schema.org predicates should pass."""
|
||||
sparql = """
|
||||
SELECT ?name WHERE {
|
||||
?inst schema:name ?name .
|
||||
?inst schema:foundingDate ?date .
|
||||
}
|
||||
"""
|
||||
result = validator.validate(sparql)
|
||||
assert result.valid is True
|
||||
assert len(result.errors) == 0
|
||||
|
||||
def test_invalid_predicate_detected(self, validator):
|
||||
"""Unknown hc: predicate should be flagged."""
|
||||
sparql = """
|
||||
SELECT ?x WHERE {
|
||||
?x hc:unknownPredicate "test" .
|
||||
}
|
||||
"""
|
||||
result = validator.validate(sparql)
|
||||
# Should have warning or error for unknown predicate
|
||||
assert len(result.warnings) > 0 or len(result.errors) > 0
|
||||
|
||||
def test_typo_suggestion(self, validator):
|
||||
"""Typo in predicate should suggest correction."""
|
||||
sparql = """
|
||||
SELECT ?x WHERE {
|
||||
?x hc:institutionTyp "M" .
|
||||
}
|
||||
"""
|
||||
result = validator.validate(sparql)
|
||||
# Should suggest "hc:institutionType"
|
||||
if result.suggestions:
|
||||
assert any("institutionType" in s for s in result.suggestions)
|
||||
|
||||
def test_mismatched_braces_detected(self, validator):
|
||||
"""Mismatched braces should be flagged."""
|
||||
sparql = """
|
||||
SELECT ?x WHERE {
|
||||
?x hc:institutionType "M" .
|
||||
""" # Missing closing brace
|
||||
result = validator.validate(sparql)
|
||||
assert result.valid is False
|
||||
assert any("brace" in e.lower() for e in result.errors)
|
||||
|
||||
def test_missing_where_clause_detected(self, validator):
|
||||
"""SELECT without WHERE should be flagged."""
|
||||
sparql = """
|
||||
SELECT ?x {
|
||||
?x hc:institutionType "M" .
|
||||
}
|
||||
"""
|
||||
result = validator.validate(sparql)
|
||||
# Note: This has braces but no WHERE keyword
|
||||
assert result.valid is False
|
||||
assert any("WHERE" in e for e in result.errors)
|
||||
|
||||
def test_non_hc_query_passes(self, validator):
|
||||
"""Query without hc: predicates should pass (not our responsibility)."""
|
||||
sparql = """
|
||||
SELECT ?s ?p ?o WHERE {
|
||||
?s ?p ?o .
|
||||
} LIMIT 10
|
||||
"""
|
||||
result = validator.validate(sparql)
|
||||
assert result.valid is True
|
||||
|
||||
def test_budget_predicates_valid(self, validator):
|
||||
"""Budget-related predicates should be valid."""
|
||||
sparql = """
|
||||
SELECT ?budget WHERE {
|
||||
?inst hc:innovation_budget ?budget .
|
||||
?inst hc:digitization_budget ?dbudget .
|
||||
}
|
||||
"""
|
||||
result = validator.validate(sparql)
|
||||
assert result.valid is True
|
||||
assert len(result.errors) == 0
|
||||
|
||||
def test_change_event_predicates_valid(self, validator):
|
||||
"""Change event predicates should be valid."""
|
||||
sparql = """
|
||||
SELECT ?event WHERE {
|
||||
?event hc:changeType "MERGER" .
|
||||
?event hc:eventDate ?date .
|
||||
}
|
||||
"""
|
||||
result = validator.validate(sparql)
|
||||
assert result.valid is True
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# RAG ENHANCED MATCHER TESTS
|
||||
# =============================================================================
|
||||
|
||||
class TestRAGEnhancedMatcher:
|
||||
"""Tests for the RAGEnhancedMatcher class."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_templates(self):
|
||||
"""Create mock template definitions for testing."""
|
||||
from backend.rag.template_sparql import TemplateDefinition, SlotDefinition, SlotType
|
||||
|
||||
return {
|
||||
"list_institutions_by_type_city": TemplateDefinition(
|
||||
id="list_institutions_by_type_city",
|
||||
description="List institutions by type in a city",
|
||||
intent=["list", "institutions", "city"],
|
||||
question_patterns=["Welke {type} zijn er in {city}?"],
|
||||
slots={
|
||||
"type": SlotDefinition(type=SlotType.INSTITUTION_TYPE),
|
||||
"city": SlotDefinition(type=SlotType.CITY),
|
||||
},
|
||||
sparql_template="SELECT ?inst WHERE { ?inst hc:institutionType '{{ type }}' }",
|
||||
examples=[
|
||||
{"question": "Welke musea zijn er in Amsterdam?", "slots": {"type": "M", "city": "Amsterdam"}},
|
||||
{"question": "Welke archieven zijn er in Utrecht?", "slots": {"type": "A", "city": "Utrecht"}},
|
||||
{"question": "Welke bibliotheken zijn er in Rotterdam?", "slots": {"type": "L", "city": "Rotterdam"}},
|
||||
],
|
||||
),
|
||||
"count_institutions_by_type": TemplateDefinition(
|
||||
id="count_institutions_by_type",
|
||||
description="Count institutions by type",
|
||||
intent=["count", "institutions"],
|
||||
question_patterns=["Hoeveel {type} zijn er?"],
|
||||
slots={
|
||||
"type": SlotDefinition(type=SlotType.INSTITUTION_TYPE),
|
||||
},
|
||||
sparql_template="SELECT (COUNT(?inst) AS ?count) WHERE { ?inst hc:institutionType '{{ type }}' }",
|
||||
examples=[
|
||||
{"question": "Hoeveel musea zijn er in Nederland?", "slots": {"type": "M"}},
|
||||
{"question": "Hoeveel archieven zijn er?", "slots": {"type": "A"}},
|
||||
],
|
||||
),
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def matcher(self):
|
||||
"""Get a fresh RAG matcher instance (reset singleton)."""
|
||||
from backend.rag.template_sparql import RAGEnhancedMatcher
|
||||
# Reset singleton state for clean tests
|
||||
RAGEnhancedMatcher._instance = None
|
||||
RAGEnhancedMatcher._example_embeddings = None
|
||||
RAGEnhancedMatcher._example_template_ids = None
|
||||
RAGEnhancedMatcher._example_texts = None
|
||||
RAGEnhancedMatcher._example_slots = None
|
||||
return RAGEnhancedMatcher()
|
||||
|
||||
def test_singleton_pattern(self):
|
||||
"""RAGEnhancedMatcher should be a singleton."""
|
||||
from backend.rag.template_sparql import RAGEnhancedMatcher
|
||||
# Reset first
|
||||
RAGEnhancedMatcher._instance = None
|
||||
|
||||
matcher1 = RAGEnhancedMatcher()
|
||||
matcher2 = RAGEnhancedMatcher()
|
||||
assert matcher1 is matcher2
|
||||
|
||||
def test_match_returns_none_without_model(self, matcher, mock_templates):
|
||||
"""Should return None if embedding model not available."""
|
||||
with patch('backend.rag.template_sparql._get_embedding_model', return_value=None):
|
||||
result = matcher.match("Welke musea zijn er in Den Haag?", mock_templates)
|
||||
assert result is None
|
||||
|
||||
def test_match_with_high_agreement(self, matcher, mock_templates):
|
||||
"""Should match when examples agree on template."""
|
||||
# Create mock embedding model
|
||||
mock_model = MagicMock()
|
||||
|
||||
# Create embeddings that will give high similarity for "list" template
|
||||
# The question embedding
|
||||
question_emb = np.array([1.0, 0.0, 0.0])
|
||||
# Example embeddings - 3 from list template, 2 from count template
|
||||
example_embs = np.array([
|
||||
[0.95, 0.1, 0.0], # list - high similarity
|
||||
[0.90, 0.15, 0.0], # list - high similarity
|
||||
[0.92, 0.12, 0.0], # list - high similarity
|
||||
[0.3, 0.9, 0.1], # count - low similarity
|
||||
[0.25, 0.85, 0.15], # count - low similarity
|
||||
])
|
||||
|
||||
mock_model.encode = MagicMock(side_effect=[
|
||||
example_embs, # First call for indexing
|
||||
np.array([question_emb]), # Second call for question
|
||||
])
|
||||
|
||||
with patch('backend.rag.template_sparql._get_embedding_model', return_value=mock_model):
|
||||
# Reset cached embeddings
|
||||
matcher._example_embeddings = None
|
||||
result = matcher.match("Welke musea zijn er in Den Haag?", mock_templates, k=5)
|
||||
|
||||
# Should match list template with high confidence
|
||||
if result is not None:
|
||||
assert result.matched is True
|
||||
assert result.template_id == "list_institutions_by_type_city"
|
||||
assert result.confidence >= 0.70
|
||||
|
||||
def test_match_returns_none_with_low_agreement(self, matcher, mock_templates):
|
||||
"""Should return None when examples don't agree."""
|
||||
mock_model = MagicMock()
|
||||
|
||||
# Create embeddings with low agreement (split between templates)
|
||||
question_emb = np.array([0.5, 0.5, 0.0])
|
||||
example_embs = np.array([
|
||||
[0.6, 0.4, 0.0], # list - medium similarity
|
||||
[0.55, 0.45, 0.0], # list - medium similarity
|
||||
[0.52, 0.48, 0.0], # list - medium similarity
|
||||
[0.45, 0.55, 0.0], # count - medium similarity
|
||||
[0.4, 0.6, 0.0], # count - medium similarity
|
||||
])
|
||||
|
||||
mock_model.encode = MagicMock(side_effect=[
|
||||
example_embs,
|
||||
np.array([question_emb]),
|
||||
])
|
||||
|
||||
with patch('backend.rag.template_sparql._get_embedding_model', return_value=mock_model):
|
||||
matcher._example_embeddings = None
|
||||
# With mixed agreement, should fail min_agreement threshold
|
||||
result = matcher.match(
|
||||
"Geef me informatie over musea",
|
||||
mock_templates,
|
||||
k=5,
|
||||
min_agreement=0.8 # High threshold
|
||||
)
|
||||
# May return None if agreement is below threshold
|
||||
# The exact behavior depends on similarity calculation
|
||||
|
||||
|
||||
class TestRAGEnhancedMatcherFactory:
|
||||
"""Tests for the get_rag_enhanced_matcher factory function."""
|
||||
|
||||
def test_factory_returns_singleton(self):
|
||||
"""Factory should return same instance."""
|
||||
from backend.rag.template_sparql import get_rag_enhanced_matcher, RAGEnhancedMatcher
|
||||
# Reset
|
||||
RAGEnhancedMatcher._instance = None
|
||||
|
||||
matcher1 = get_rag_enhanced_matcher()
|
||||
matcher2 = get_rag_enhanced_matcher()
|
||||
assert matcher1 is matcher2
|
||||
|
||||
|
||||
class TestSPARQLValidatorFactory:
|
||||
"""Tests for the get_sparql_validator factory function."""
|
||||
|
||||
def test_factory_returns_instance(self):
|
||||
"""Factory should return validator instance."""
|
||||
from backend.rag.template_sparql import get_sparql_validator, SPARQLValidator
|
||||
|
||||
validator = get_sparql_validator()
|
||||
assert isinstance(validator, SPARQLValidator)
|
||||
|
||||
def test_factory_returns_singleton(self):
|
||||
"""Factory should return same instance."""
|
||||
from backend.rag.template_sparql import get_sparql_validator, _sparql_validator
|
||||
import backend.rag.template_sparql as module
|
||||
|
||||
# Reset
|
||||
module._sparql_validator = None
|
||||
|
||||
validator1 = get_sparql_validator()
|
||||
validator2 = get_sparql_validator()
|
||||
assert validator1 is validator2
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# TIER METRICS TESTS
|
||||
# =============================================================================
|
||||
|
||||
class TestTierMetrics:
|
||||
"""Tests for tier tracking in metrics."""
|
||||
|
||||
def test_rag_tier_in_stats(self):
|
||||
"""RAG tier should be tracked in tier stats."""
|
||||
from backend.rag.metrics import get_template_tier_stats
|
||||
|
||||
stats = get_template_tier_stats()
|
||||
|
||||
# Should include rag tier
|
||||
if stats.get("available"):
|
||||
assert "rag" in stats.get("tiers", {})
|
||||
|
||||
def test_record_template_tier_accepts_rag(self):
|
||||
"""record_template_tier should accept 'rag' tier."""
|
||||
from backend.rag.metrics import record_template_tier
|
||||
|
||||
# Should not raise
|
||||
record_template_tier(tier="rag", matched=True, template_id="test_template")
|
||||
record_template_tier(tier="rag", matched=False)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# VALIDATION RESULT TESTS
|
||||
# =============================================================================
|
||||
|
||||
class TestSPARQLValidationResult:
|
||||
"""Tests for SPARQLValidationResult model."""
|
||||
|
||||
def test_valid_result_creation(self):
|
||||
"""Should create valid result."""
|
||||
from backend.rag.template_sparql import SPARQLValidationResult
|
||||
|
||||
result = SPARQLValidationResult(valid=True)
|
||||
assert result.valid is True
|
||||
assert result.errors == []
|
||||
assert result.warnings == []
|
||||
assert result.suggestions == []
|
||||
|
||||
def test_invalid_result_with_errors(self):
|
||||
"""Should create result with errors."""
|
||||
from backend.rag.template_sparql import SPARQLValidationResult
|
||||
|
||||
result = SPARQLValidationResult(
|
||||
valid=False,
|
||||
errors=["Unknown predicate: hc:foo"],
|
||||
suggestions=["Did you mean: hc:foaf?"]
|
||||
)
|
||||
assert result.valid is False
|
||||
assert len(result.errors) == 1
|
||||
assert len(result.suggestions) == 1
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# INTEGRATION TESTS
|
||||
# =============================================================================
|
||||
|
||||
class TestFourTierFallback:
|
||||
"""Integration tests for 4-tier matching fallback."""
|
||||
|
||||
@pytest.fixture
|
||||
def classifier(self):
|
||||
"""Get TemplateClassifier instance."""
|
||||
from backend.rag.template_sparql import TemplateClassifier
|
||||
return TemplateClassifier()
|
||||
|
||||
def test_pattern_tier_takes_priority(self, classifier):
|
||||
"""Tier 1 (pattern) should be checked first."""
|
||||
# A question that matches a pattern exactly
|
||||
question = "Welke musea zijn er in Amsterdam?"
|
||||
|
||||
result = classifier.forward(question, language="nl")
|
||||
|
||||
# Should match (if templates are loaded)
|
||||
if result.matched:
|
||||
# Pattern matches should have high confidence
|
||||
assert result.confidence >= 0.75
|
||||
assert "pattern" in result.reasoning.lower() or result.confidence >= 0.90
|
||||
|
||||
def test_embedding_tier_fallback(self, classifier):
|
||||
"""Tier 2 (embedding) should be used when pattern fails."""
|
||||
# A paraphrased question that won't match patterns exactly
|
||||
question = "Geef mij een lijst van alle musea in de stad Amsterdam"
|
||||
|
||||
# This may use embedding tier if pattern doesn't match
|
||||
result = classifier.forward(question, language="nl")
|
||||
|
||||
# Should still get a result (may be embedding or LLM)
|
||||
# We just verify the fallback mechanism works
|
||||
assert result is not None
|
||||
|
||||
|
||||
class TestValidatorPredicateSet:
|
||||
"""Tests to verify the validator predicate/class sets."""
|
||||
|
||||
def test_valid_predicates_not_empty(self):
|
||||
"""Validator should have known predicates."""
|
||||
from backend.rag.template_sparql import SPARQLValidator
|
||||
|
||||
validator = SPARQLValidator()
|
||||
assert len(validator.VALID_HC_PREDICATES) > 0
|
||||
assert len(validator.VALID_EXTERNAL_PREDICATES) > 0
|
||||
|
||||
def test_institution_type_in_predicates(self):
|
||||
"""institutionType should be a valid predicate."""
|
||||
from backend.rag.template_sparql import SPARQLValidator
|
||||
|
||||
validator = SPARQLValidator()
|
||||
assert "hc:institutionType" in validator.VALID_HC_PREDICATES
|
||||
|
||||
def test_settlement_name_in_predicates(self):
|
||||
"""settlementName should be a valid predicate."""
|
||||
from backend.rag.template_sparql import SPARQLValidator
|
||||
|
||||
validator = SPARQLValidator()
|
||||
assert "hc:settlementName" in validator.VALID_HC_PREDICATES
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
Loading…
Reference in a new issue