- Add 'databases' field to TemplateDefinition and TemplateMatchResult - Support values: 'oxigraph' (SPARQL/KG), 'qdrant' (vector search) - Add helper methods use_oxigraph() and use_qdrant() - Default to both databases for backward compatibility - Allows templates to skip vector search for factual/geographic queries
3312 lines
133 KiB
Python
3312 lines
133 KiB
Python
"""
|
|
Template-Based SPARQL Query Generation System
|
|
|
|
This module implements a template-based approach to SPARQL query generation,
|
|
replacing error-prone LLM-generated queries with deterministic, validated templates.
|
|
|
|
Architecture (CRITICAL ORDERING):
|
|
=================================
|
|
1. ConversationContextResolver (DSPy) - Resolves elliptical follow-ups FIRST
|
|
"En in Enschede?" → "Welke archieven zijn er in Enschede?"
|
|
|
|
2. FykeFilter (DSPy) - Filters irrelevant questions on RESOLVED input
|
|
⚠️ MUST operate on resolved question, not raw input!
|
|
|
|
3. TemplateClassifier (DSPy) - Matches to SPARQL template
|
|
4. SlotExtractor (DSPy) - Extracts slot values with synonym resolution
|
|
5. TemplateInstantiator (Jinja2) - Renders final SPARQL query
|
|
|
|
Based on:
|
|
- docs/plan/prompt-query_template_mapping/
|
|
- Formica et al. (2023) - Template-based SPARQL achieves 65% precision vs 10% LLM-only
|
|
- DSPy 2.6+ GEPA optimization
|
|
|
|
Author: OpenCode
|
|
Created: 2025-01-06
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
import re
|
|
import time
|
|
from dataclasses import dataclass, field
|
|
from enum import Enum
|
|
from pathlib import Path
|
|
from typing import Any, Literal, Optional
|
|
|
|
import dspy
|
|
import numpy as np
|
|
from dspy import History
|
|
from jinja2 import Environment, BaseLoader
|
|
from pydantic import BaseModel, Field
|
|
from rapidfuzz import fuzz, process
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# =============================================================================
|
|
# METRICS INTEGRATION (Optional - No-Op if Not Available)
|
|
# =============================================================================
|
|
|
|
_record_template_tier = None
|
|
|
|
try:
|
|
from metrics import record_template_tier as _record_tier
|
|
_record_template_tier = _record_tier
|
|
logger.debug("Metrics module available for template tier tracking")
|
|
except ImportError:
|
|
logger.debug("Metrics module not available - template tier tracking disabled")
|
|
|
|
# =============================================================================
|
|
# CONFIGURATION
|
|
# =============================================================================
|
|
|
|
# Lazy-loaded sentence transformer model
|
|
_embedding_model = None
|
|
_embedding_model_name = "paraphrase-multilingual-MiniLM-L12-v2" # Multilingual, 384-dim
|
|
|
|
def _get_embedding_model():
    """Return the shared sentence-transformer model, loading it on first use.

    Returns None when sentence-transformers is not installed or the model
    fails to load; the failure is not cached, so a later call retries.
    """
    global _embedding_model
    # Fast path: model already loaded.
    if _embedding_model is not None:
        return _embedding_model
    try:
        from sentence_transformers import SentenceTransformer
        logger.info(f"Loading embedding model: {_embedding_model_name}")
        _embedding_model = SentenceTransformer(_embedding_model_name)
        logger.info("Embedding model loaded successfully")
        return _embedding_model
    except ImportError:
        logger.warning("sentence-transformers not installed, embedding matching disabled")
        return None
    except Exception as e:
        logger.warning(f"Failed to load embedding model: {e}")
        return None
|
|
|
|
def _find_data_path(filename: str) -> Path:
|
|
"""Find data file in multiple possible locations.
|
|
|
|
Supports both local development (backend/rag/ → data/) and
|
|
server deployment (e.g., /opt/glam-backend/rag/data/).
|
|
"""
|
|
# Try relative to module location (local dev: backend/rag → glam/data)
|
|
module_dir = Path(__file__).parent
|
|
candidates = [
|
|
module_dir.parent.parent / "data" / filename, # Local: glam/data/
|
|
module_dir / "data" / filename, # Server: rag/data/
|
|
Path("/opt/glam-backend/rag/data") / filename, # Server explicit path
|
|
]
|
|
|
|
for candidate in candidates:
|
|
if candidate.exists():
|
|
return candidate
|
|
|
|
# Return first candidate (will report as missing in logs)
|
|
return candidates[0]
|
|
|
|
TEMPLATES_PATH = _find_data_path("sparql_templates.yaml")
|
|
VALIDATION_RULES_PATH = _find_data_path("validation/sparql_validation_rules.json")
|
|
|
|
# LinkML schema path for dynamic ontology loading
|
|
LINKML_SCHEMA_PATH = Path(__file__).parent.parent.parent / "schemas" / "20251121" / "linkml"
|
|
|
|
# Oxigraph SPARQL endpoint for KG queries
|
|
SPARQL_ENDPOINT = "http://localhost:7878/query"
|
|
|
|
|
|
# =============================================================================
|
|
# ONTOLOGY LOADER (Dynamic Schema Loading)
|
|
# =============================================================================
|
|
|
|
class OntologyLoader:
    """Dynamically loads predicates and valid values from LinkML schema and Knowledge Graph.

    This class eliminates hardcoded heuristics by:
    1. Loading slot_uri definitions from LinkML YAML files
    2. Loading enum definitions from validation rules JSON
    3. Querying the Knowledge Graph for valid enum values
    4. Caching results for performance with TTL-based expiration

    Architecture:
        LinkML Schema → slot_uri predicates → SPARQLValidator
        Validation Rules JSON → enums, mappings → SynonymResolver
        Knowledge Graph → SPARQL queries → valid slot values

    Caching:
        - KG query results are cached with configurable TTL (default: 5 minutes)
        - Use refresh_kg_cache() to force reload of KG data
        - Use clear_cache() to reset all cached data
    """

    # NOTE(review): all of these are class-level mutable attributes shared by
    # every instance; safe only because __new__ enforces a singleton.
    _instance = None
    # Valid predicate names/URIs (from LinkML slot_uri entries and the
    # validation-rules property mappings).
    _predicates: set[str] = set()
    _external_predicates: set[str] = set()
    # Valid class URIs (from LinkML class_uri entries).
    _classes: set[str] = set()
    # slot name → set of valid values loaded from the Knowledge Graph.
    _slot_values: dict[str, set[str]] = {}
    _synonyms: dict[str, dict[str, str]] = {}
    # Enum definitions (e.g. HeritageTypeEnum) from validation rules JSON.
    _enums: dict[str, dict] = {}
    # Single-letter institution type codes from HeritageTypeEnum.
    _institution_type_codes: set[str] = set()
    # Lower-cased type name/synonym → single-letter code.
    _institution_type_mappings: dict[str, str] = {}
    # Lower-cased region name → ISO 3166-2 code.
    _subregion_mappings: dict[str, str] = {}
    # Lower-cased country key → mapped value (Wikidata IRI per docstring below).
    _country_mappings: dict[str, str] = {}
    # True once load() has populated everything above.
    _loaded: bool = False

    # TTL-based caching for KG queries
    _kg_cache: dict[str, set[str]] = {}  # query_hash → result set
    _kg_cache_timestamps: dict[str, float] = {}  # query_hash → timestamp
    _kg_cache_ttl: float = 300.0  # 5 minutes default TTL
    _kg_values_last_refresh: float = 0.0  # timestamp of last KG values refresh

    def __new__(cls):
        """Singleton pattern: always return the one shared instance."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance
|
|
|
|
    def _load_from_validation_rules(self) -> None:
        """Load enums, mappings and predicates from validation rules JSON.

        Populates _enums, _institution_type_codes, the three lower-cased
        mapping dicts, and adds property-mapping names to _predicates.
        A missing or unreadable file is logged and skipped (best effort).
        """
        if not VALIDATION_RULES_PATH.exists():
            logger.warning(f"Validation rules not found: {VALIDATION_RULES_PATH}")
            return

        try:
            with open(VALIDATION_RULES_PATH) as f:
                rules = json.load(f)

            # Load enums (HeritageTypeEnum, etc.)
            self._enums = rules.get("enums", {})

            # Extract valid institution type codes from HeritageTypeEnum
            heritage_enum = self._enums.get("HeritageTypeEnum", {})
            self._institution_type_codes = set(heritage_enum.get("values", []))

            # Load institution type mappings (case-insensitive lookup)
            for k, v in rules.get("institution_type_mappings", {}).items():
                self._institution_type_mappings[k.lower()] = v

            # Load subregion mappings
            for k, v in rules.get("subregion_mappings", {}).items():
                self._subregion_mappings[k.lower()] = v

            # Load country mappings
            for k, v in rules.get("country_mappings", {}).items():
                self._country_mappings[k.lower()] = v

            # Load property mappings to extract predicates
            property_mappings = rules.get("property_mappings", {})
            for prop_name, prop_def in property_mappings.items():
                if isinstance(prop_def, dict) and "error" not in prop_def:
                    # Add the property name as a valid predicate
                    self._predicates.add(prop_name)

            # Load namespace prefixes to build external predicates
            # NOTE(review): `namespaces` is read but never used below — looks
            # like an unfinished feature; confirm before removing.
            namespaces = rules.get("namespaces", {})

            logger.info(
                f"Loaded from validation rules: {len(self._enums)} enums, "
                f"{len(self._institution_type_codes)} type codes, "
                f"{len(self._institution_type_mappings)} type mappings, "
                f"{len(self._subregion_mappings)} subregion mappings"
            )

        except Exception as e:
            logger.warning(f"Failed to load validation rules: {e}")
|
|
|
|
    def _load_from_linkml(self) -> None:
        """Load predicates and classes from LinkML schema YAML files.

        Scans modules/slots/*.yaml for slot_uri entries (→ _predicates) and
        modules/classes/*.yaml for class_uri entries (→ _classes). Files
        that fail to parse are logged at debug level and skipped.
        """
        import yaml

        slots_dir = LINKML_SCHEMA_PATH / "modules" / "slots"
        if not slots_dir.exists():
            logger.warning(f"LinkML slots directory not found: {slots_dir}")
            return

        # Scan all slot YAML files for slot_uri definitions
        for yaml_file in slots_dir.glob("*.yaml"):
            try:
                with open(yaml_file) as f:
                    data = yaml.safe_load(f)

                if not data or "slots" not in data:
                    continue

                for slot_name, slot_def in data.get("slots", {}).items():
                    if isinstance(slot_def, dict) and "slot_uri" in slot_def:
                        uri = slot_def["slot_uri"]
                        self._predicates.add(uri)

            except Exception as e:
                logger.debug(f"Error loading {yaml_file}: {e}")

        # Load classes from classes directory (optional; no warning if absent)
        classes_dir = LINKML_SCHEMA_PATH / "modules" / "classes"
        if classes_dir.exists():
            for yaml_file in classes_dir.glob("*.yaml"):
                try:
                    with open(yaml_file) as f:
                        data = yaml.safe_load(f)

                    if not data or "classes" not in data:
                        continue

                    for class_name, class_def in data.get("classes", {}).items():
                        if isinstance(class_def, dict) and "class_uri" in class_def:
                            uri = class_def["class_uri"]
                            self._classes.add(uri)

                except Exception as e:
                    logger.debug(f"Error loading {yaml_file}: {e}")

        logger.info(f"Loaded {len(self._predicates)} predicates and {len(self._classes)} classes from LinkML")
|
|
|
|
    def _query_kg_for_values(self, sparql_query: str, use_cache: bool = True) -> set[str]:
        """Execute SPARQL query against the Knowledge Graph with TTL-based caching.

        Args:
            sparql_query: SPARQL query to execute
            use_cache: Whether to use cached results (default: True)

        Returns:
            Set of values from the query results (every variable of every
            binding is flattened into one set), or empty set on failure.

        Caching:
            Results are cached using query hash as key. Cached results are
            returned if within TTL window, otherwise a fresh query is made.
            On query failure a stale cached result, if present, is returned.
        """
        import hashlib
        import urllib.request
        import urllib.parse

        # Generate cache key from query hash (MD5 is used only as a cache
        # key here, not for anything security-sensitive).
        query_hash = hashlib.md5(sparql_query.encode()).hexdigest()

        # Check cache if enabled
        if use_cache and query_hash in self._kg_cache:
            cache_time = self._kg_cache_timestamps.get(query_hash, 0)
            if time.time() - cache_time < self._kg_cache_ttl:
                logger.debug(f"KG cache hit for query (age: {time.time() - cache_time:.1f}s)")
                return self._kg_cache[query_hash]
            else:
                # Expired: fall through to a fresh query below.
                logger.debug(f"KG cache expired for query (age: {time.time() - cache_time:.1f}s)")

        try:
            # Encode query as a GET request against the SPARQL endpoint.
            params = urllib.parse.urlencode({"query": sparql_query})
            url = f"{SPARQL_ENDPOINT}?{params}"

            req = urllib.request.Request(url)
            req.add_header("Accept", "application/sparql-results+json")

            with urllib.request.urlopen(req, timeout=5) as response:
                result = json.loads(response.read().decode())

            values = set()
            # Flatten all variables of all bindings into one value set.
            for binding in result.get("results", {}).get("bindings", []):
                for var_name, var_data in binding.items():
                    values.add(var_data.get("value", ""))

            # Cache the results
            self._kg_cache[query_hash] = values
            self._kg_cache_timestamps[query_hash] = time.time()

            return values

        except Exception as e:
            logger.debug(f"KG query failed (using fallback): {e}")
            # Return cached value if available, even if expired
            if query_hash in self._kg_cache:
                logger.debug("Returning stale cached KG results due to query failure")
                return self._kg_cache[query_hash]
            return set()
|
|
|
|
def _load_institution_types_from_kg(self) -> None:
|
|
"""Load valid institution types from the Knowledge Graph."""
|
|
# Query for distinct institution types
|
|
query = """
|
|
PREFIX hc: <https://nde.nl/ontology/hc/>
|
|
SELECT DISTINCT ?type WHERE {
|
|
?s hc:institutionType ?type .
|
|
}
|
|
"""
|
|
|
|
values = self._query_kg_for_values(query)
|
|
if values:
|
|
self._slot_values["institution_type"] = values
|
|
logger.info(f"Loaded {len(values)} institution types from KG")
|
|
|
|
def _load_subregions_from_kg(self) -> None:
|
|
"""Load valid subregion codes from the Knowledge Graph."""
|
|
query = """
|
|
PREFIX hc: <https://nde.nl/ontology/hc/>
|
|
SELECT DISTINCT ?code WHERE {
|
|
?s hc:subregionCode ?code .
|
|
}
|
|
"""
|
|
|
|
values = self._query_kg_for_values(query)
|
|
if values:
|
|
self._slot_values["subregion"] = values
|
|
logger.info(f"Loaded {len(values)} subregion codes from KG")
|
|
|
|
def _load_countries_from_kg(self) -> None:
|
|
"""Load valid country codes from the Knowledge Graph."""
|
|
query = """
|
|
PREFIX hc: <https://nde.nl/ontology/hc/>
|
|
SELECT DISTINCT ?code WHERE {
|
|
?s hc:countryCode ?code .
|
|
}
|
|
"""
|
|
|
|
values = self._query_kg_for_values(query)
|
|
if values:
|
|
self._slot_values["country"] = values
|
|
logger.info(f"Loaded {len(values)} country codes from KG")
|
|
|
|
def _load_cities_from_kg(self) -> None:
|
|
"""Load valid city names from the Knowledge Graph."""
|
|
query = """
|
|
PREFIX hc: <https://nde.nl/ontology/hc/>
|
|
SELECT DISTINCT ?city WHERE {
|
|
?s hc:settlementName ?city .
|
|
}
|
|
"""
|
|
|
|
values = self._query_kg_for_values(query)
|
|
if values:
|
|
self._slot_values["city"] = values
|
|
logger.info(f"Loaded {len(values)} cities from KG")
|
|
|
|
    def load(self) -> None:
        """Load all ontology data from LinkML, validation rules, and KG.

        Idempotent: a no-op after the first call (guarded by _loaded).
        NOTE(review): not thread-safe — two concurrent first calls could
        load twice; confirm callers are single-threaded.
        """
        if self._loaded:
            return

        logger.info("Loading ontology from LinkML schema, validation rules, and Knowledge Graph...")

        # Load from validation rules JSON (enums, mappings)
        self._load_from_validation_rules()

        # Load predicates from LinkML schema YAML files
        self._load_from_linkml()

        # Load valid values from Knowledge Graph
        self._load_institution_types_from_kg()
        self._load_subregions_from_kg()
        self._load_countries_from_kg()
        self._load_cities_from_kg()

        self._loaded = True

        logger.info(
            f"Ontology loaded: {len(self._predicates)} predicates, "
            f"{len(self._classes)} classes, "
            f"{len(self._slot_values)} slot value sets, "
            f"{len(self._institution_type_codes)} institution type codes"
        )
|
|
|
|
    def get_predicates(self) -> set[str]:
        """Get all valid predicates from the ontology (triggers lazy load)."""
        self.load()
        return self._predicates

    def get_classes(self) -> set[str]:
        """Get all valid classes from the ontology (triggers lazy load)."""
        self.load()
        return self._classes

    def get_valid_values(self, slot_name: str) -> set[str]:
        """Get valid values for a slot from the Knowledge Graph.

        Returns an empty set when the KG provided no values for the slot.
        """
        self.load()
        return self._slot_values.get(slot_name, set())

    def is_valid_value(self, slot_name: str, value: str) -> bool:
        """Check if a value is valid for a slot.

        Open-world default: when the KG supplied no values for the slot,
        every value is accepted. Matching is exact or upper-cased.
        """
        valid_values = self.get_valid_values(slot_name)
        if not valid_values:
            return True  # No KG data, assume valid
        return value in valid_values or value.upper() in valid_values
|
|
|
|
    def get_institution_type_codes(self) -> set[str]:
        """Get valid single-letter institution type codes from HeritageTypeEnum.

        Returns:
            Set of valid codes: {"G", "L", "A", "M", "O", "R", "C", "U", "B", "E", "S", "F", "I", "X", "P", "H", "D", "N", "T"}
        """
        self.load()
        return self._institution_type_codes

    def get_institution_type_mappings(self) -> dict[str, str]:
        """Get institution type mappings (name → code).

        Returns:
            Dict mapping lower-cased type names/synonyms to single-letter codes.
            Example: {"museum": "M", "library": "L", "archive": "A", ...}
        """
        self.load()
        return self._institution_type_mappings

    def get_subregion_mappings(self) -> dict[str, str]:
        """Get subregion mappings (name → ISO code).

        Returns:
            Dict mapping lower-cased region names to ISO 3166-2 codes.
            Example: {"noord-holland": "NL-NH", "limburg": "NL-LI", ...}
        """
        self.load()
        return self._subregion_mappings

    def get_country_mappings(self) -> dict[str, str]:
        """Get country mappings (ISO code → Wikidata ID).

        Returns:
            Dict mapping lower-cased ISO country codes to Wikidata entity IRIs.
            Example: {"nl": "wd:Q55", "de": "wd:Q183", ...}
        """
        self.load()
        return self._country_mappings

    def get_enum_values(self, enum_name: str) -> list[str]:
        """Get valid values for a specific enum from the schema.

        Args:
            enum_name: Name of the enum (e.g., "HeritageTypeEnum", "OrganizationalChangeEventTypeEnum")

        Returns:
            List of valid enum values, or empty list if not found.
        """
        self.load()
        enum_def = self._enums.get(enum_name, {})
        return enum_def.get("values", [])
|
|
|
|
    def set_kg_cache_ttl(self, ttl_seconds: float) -> None:
        """Set the TTL for KG query cache.

        Args:
            ttl_seconds: Time-to-live in seconds for cached KG query results.
                Default is 300 seconds (5 minutes).
        """
        # Instance assignment shadows the class-level default; fine for the
        # singleton usage of this class.
        self._kg_cache_ttl = ttl_seconds
        logger.info(f"KG cache TTL set to {ttl_seconds} seconds")

    def get_kg_cache_ttl(self) -> float:
        """Get the current TTL (in seconds) for the KG query cache."""
        return self._kg_cache_ttl

    def clear_kg_cache(self) -> None:
        """Clear the KG query cache, forcing fresh queries on next access."""
        self._kg_cache.clear()
        self._kg_cache_timestamps.clear()
        logger.info("KG query cache cleared")
|
|
|
|
def refresh_kg_values(self) -> None:
|
|
"""Force refresh of KG-loaded slot values (institution types, subregions, etc.).
|
|
|
|
This clears the KG cache and reloads all slot values from the Knowledge Graph.
|
|
Useful when KG data has been updated and you need fresh values.
|
|
"""
|
|
# Clear KG cache
|
|
self.clear_kg_cache()
|
|
|
|
# Clear slot values loaded from KG
|
|
kg_slots = ["institution_type", "subregion", "country", "city"]
|
|
for slot in kg_slots:
|
|
if slot in self._slot_values:
|
|
del self._slot_values[slot]
|
|
|
|
# Reload from KG
|
|
self._load_institution_types_from_kg()
|
|
self._load_subregions_from_kg()
|
|
self._load_countries_from_kg()
|
|
self._load_cities_from_kg()
|
|
|
|
self._kg_values_last_refresh = time.time()
|
|
logger.info("KG slot values refreshed")
|
|
|
|
def get_kg_cache_stats(self) -> dict[str, Any]:
|
|
"""Get statistics about the KG query cache.
|
|
|
|
Returns:
|
|
Dict with cache statistics including size, age of entries,
|
|
and last refresh time.
|
|
"""
|
|
now = time.time()
|
|
stats = {
|
|
"cache_size": len(self._kg_cache),
|
|
"ttl_seconds": self._kg_cache_ttl,
|
|
"last_kg_refresh": self._kg_values_last_refresh,
|
|
"entries": {}
|
|
}
|
|
|
|
for query_hash, timestamp in self._kg_cache_timestamps.items():
|
|
age = now - timestamp
|
|
stats["entries"][query_hash[:8]] = {
|
|
"age_seconds": round(age, 1),
|
|
"expired": age >= self._kg_cache_ttl,
|
|
"result_count": len(self._kg_cache.get(query_hash, set()))
|
|
}
|
|
|
|
return stats
|
|
|
|
def clear_all_cache(self) -> None:
|
|
"""Clear all cached data and reset to initial state.
|
|
|
|
This resets the OntologyLoader to its initial state, requiring
|
|
a full reload on next access. Use with caution.
|
|
"""
|
|
self._predicates.clear()
|
|
self._external_predicates.clear()
|
|
self._classes.clear()
|
|
self._slot_values.clear()
|
|
self._synonyms.clear()
|
|
self._enums.clear()
|
|
self._institution_type_codes.clear()
|
|
self._institution_type_mappings.clear()
|
|
self._subregion_mappings.clear()
|
|
self._country_mappings.clear()
|
|
self._kg_cache.clear()
|
|
self._kg_cache_timestamps.clear()
|
|
self._loaded = False
|
|
self._kg_values_last_refresh = 0.0
|
|
logger.info("All OntologyLoader cache cleared")
|
|
|
|
|
|
# Global ontology loader instance
|
|
_ontology_loader: Optional[OntologyLoader] = None
|
|
|
|
def get_ontology_loader() -> OntologyLoader:
    """Return the process-wide OntologyLoader, creating it on first call."""
    global _ontology_loader
    if _ontology_loader is not None:
        return _ontology_loader
    _ontology_loader = OntologyLoader()
    return _ontology_loader
|
|
|
|
# Standard SPARQL prefixes
|
|
SPARQL_PREFIXES = """PREFIX hc: <https://nde.nl/ontology/hc/>
|
|
PREFIX hcc: <https://nde.nl/ontology/hc/class/>
|
|
PREFIX crm: <http://www.cidoc-crm.org/cidoc-crm/>
|
|
PREFIX schema: <http://schema.org/>
|
|
PREFIX skos: <http://www.w3.org/2004/02/skos/core#>
|
|
PREFIX org: <http://www.w3.org/ns/org#>
|
|
PREFIX foaf: <http://xmlns.com/foaf/0.1/>
|
|
PREFIX dcterms: <http://purl.org/dc/terms/>
|
|
PREFIX xsd: <http://www.w3.org/2001/XMLSchema#>
|
|
PREFIX wd: <http://www.wikidata.org/entity/>
|
|
PREFIX geo: <http://www.w3.org/2003/01/geo/wgs84_pos#>"""
|
|
|
|
|
|
# =============================================================================
|
|
# PYDANTIC MODELS
|
|
# =============================================================================
|
|
|
|
class SlotType(str, Enum):
    """Types of template slots.

    String-valued so members round-trip cleanly through YAML/JSON and
    pydantic validation.
    """
    INSTITUTION_TYPE = "institution_type"
    SUBREGION = "subregion"
    COUNTRY = "country"
    CITY = "city"
    INSTITUTION_NAME = "institution_name"
    BUDGET_CATEGORY = "budget_category"
    STRING = "string"
    INTEGER = "integer"
    DECIMAL = "decimal"
|
|
|
|
|
|
class SlotDefinition(BaseModel):
    """Definition of a template slot."""
    # Primary slot type used for extraction/resolution.
    type: SlotType
    # Whether the slot must be filled for the template to be usable.
    required: bool = True
    # Value used when the slot is not provided.
    default: Optional[str] = None
    # Example values for this slot.
    examples: list[str] = Field(default_factory=list)
    # Alternative slot types to try when the primary type yields nothing.
    fallback_types: list[SlotType] = Field(default_factory=list)
    # Explicit whitelist of allowed values (empty = unrestricted).
    valid_values: list[str] = Field(default_factory=list)
|
|
|
|
|
|
class TemplateDefinition(BaseModel):
    """Definition of a SPARQL query template."""
    # Stable template identifier.
    id: str
    description: str
    # Intent labels this template serves.
    intent: list[str]
    # Example natural-language question patterns used for matching.
    question_patterns: list[str]
    # Slot name → definition for the template's fillable slots.
    slots: dict[str, SlotDefinition]
    # Primary Jinja2 SPARQL template.
    sparql_template: str
    # Optional variant templates (presumably selected per available
    # slot/identifier — confirm at the instantiation call site).
    sparql_template_alt: Optional[str] = None
    sparql_template_region: Optional[str] = None
    sparql_template_country: Optional[str] = None
    sparql_template_isil: Optional[str] = None
    sparql_template_ghcid: Optional[str] = None
    examples: list[dict[str, Any]] = Field(default_factory=list)
    # Response rendering configuration (template-driven, not hardcoded)
    # Available modes: "table", "count", "prose", "chart", "map"
    # Fast path rule: If "prose" is NOT in response_modes, LLM generation is skipped
    response_modes: list[str] = Field(default_factory=lambda: ["prose"])
    # Optional per-language UI templates for formatting simple answers
    ui_template: Optional[dict[str, str]] = None
    # Database routing configuration
    # Available databases: "oxigraph" (SPARQL/KG), "qdrant" (vector search)
    # When specified, only the listed databases are queried (skipping others)
    # Default: query both databases (backward compatible)
    # Use ["oxigraph"] for factual/geographic queries where vector search adds noise
    databases: list[str] = Field(default_factory=lambda: ["oxigraph", "qdrant"])
|
|
|
|
|
|
class FollowUpPattern(BaseModel):
    """Definition of a follow-up question pattern."""
    description: str
    # Surface patterns that identify this kind of follow-up.
    patterns: list[str]
    # Slot names carried over from the previous turn.
    slot_inheritance: list[str] = Field(default_factory=list)
    # Template id this follow-up transforms into, if any.
    transforms_to: Optional[str] = None
    resolution_strategy: str
    # Whether resolving this pattern needs the previous turn's results.
    requires_previous_results: bool = False


class FykeFilterConfig(BaseModel):
    """Configuration for the Fyke filter (relevance gate)."""
    out_of_scope_keywords: list[str]
    out_of_scope_categories: list[str]
    heritage_keywords: list[str]
    # Canned out-of-scope reply (presumably keyed by language — confirm).
    standard_response: dict[str, str]
|
|
|
|
|
|
class ConversationTurn(BaseModel):
    """A single turn in conversation history."""
    role: Literal["user", "assistant"]
    # Raw text of the turn.
    content: str
    # Follow-up-resolved form of a user question, when resolution ran.
    resolved_question: Optional[str] = None
    # Template matched for this turn, if any.
    template_id: Optional[str] = None
    # Slot values extracted for this turn.
    slots: dict[str, str] = Field(default_factory=dict)
    # Query results attached to this turn.
    results: list[dict[str, Any]] = Field(default_factory=list)
|
|
|
|
|
|
class ConversationState(BaseModel):
    """State tracking across conversation turns."""
    # Full turn history (user and assistant).
    turns: list[ConversationTurn] = Field(default_factory=list)
    # Slot values accumulated across user turns (later values override).
    current_slots: dict[str, str] = Field(default_factory=dict)
    # Template id from the most recent turn that carried one.
    current_template_id: Optional[str] = None
    # Conversation language code (defaults to Dutch).
    language: str = "nl"

    def add_turn(self, turn: ConversationTurn) -> None:
        """Add a turn and update current state.

        Slots are inherited only from user turns; the template id is taken
        from any turn that carries one.
        """
        self.turns.append(turn)
        if turn.role == "user" and turn.slots:
            # Inherit slots from user turns
            self.current_slots.update(turn.slots)
        if turn.template_id:
            self.current_template_id = turn.template_id

    def get_previous_user_turn(self) -> Optional[ConversationTurn]:
        """Get the most recent user turn, or None when there is none."""
        for turn in reversed(self.turns):
            if turn.role == "user":
                return turn
        return None

    def to_dspy_history(self) -> History:
        """Convert to DSPy History object.

        Prefers each turn's resolved question over its raw content, and
        keeps only the last 6 turns.
        """
        messages = []
        for turn in self.turns[-6:]:  # Keep last 6 turns for context
            messages.append({
                "role": turn.role,
                "content": turn.resolved_question or turn.content
            })
        return History(messages=messages)
|
|
|
|
|
|
class TemplateMatchResult(BaseModel):
    """Result of template matching."""
    # True when a template was matched.
    matched: bool
    template_id: Optional[str] = None
    confidence: float = 0.0
    # Resolved slot values used to instantiate the template.
    slots: dict[str, str] = Field(default_factory=dict)
    # Rendered SPARQL query, when instantiation succeeded.
    sparql: Optional[str] = None
    reasoning: str = ""
    # Response rendering configuration (passed through from template definition)
    response_modes: list[str] = Field(default_factory=lambda: ["prose"])
    ui_template: Optional[dict[str, str]] = None
    # Database routing configuration (passed through from template definition)
    # When ["oxigraph"] only, vector search is skipped for faster, deterministic results
    databases: list[str] = Field(default_factory=lambda: ["oxigraph", "qdrant"])

    def requires_llm(self) -> bool:
        """Check if this template requires LLM prose generation.

        Fast path rule: If "prose" is NOT in response_modes, LLM generation is skipped.
        """
        return "prose" in self.response_modes

    def use_oxigraph(self) -> bool:
        """Check if this template should query Oxigraph (SPARQL/KG)."""
        return "oxigraph" in self.databases

    def use_qdrant(self) -> bool:
        """Check if this template should query Qdrant (vector search)."""
        return "qdrant" in self.databases
|
|
|
|
|
|
class ResolvedQuestion(BaseModel):
    """Result of conversation context resolution."""
    # The user's question exactly as entered.
    original: str
    # The question after follow-up/ellipsis resolution.
    resolved: str
    is_follow_up: bool = False
    # Follow-up category when is_follow_up is True.
    follow_up_type: Optional[str] = None
    # Slot values carried over from the previous turn.
    inherited_slots: dict[str, str] = Field(default_factory=dict)
    confidence: float = 1.0


class FykeResult(BaseModel):
    """Result of Fyke filter."""
    is_relevant: bool
    confidence: float
    reasoning: str
    # Canned reply to return when the question is out of scope.
    standard_response: Optional[str] = None
|
|
|
|
|
|
# =============================================================================
|
|
# SYNONYM MAPPINGS (loaded from OntologyLoader)
|
|
# =============================================================================
|
|
|
|
class SynonymResolver:
    """Resolves natural language terms to canonical slot values.

    Uses OntologyLoader to get mappings from the validation rules JSON,
    eliminating hardcoded heuristics. Additional synonyms are merged in
    from the templates YAML on first use (lazy load()).
    """

    def __init__(self):
        # Lower-cased term → single-letter institution type code.
        self._institution_types: dict[str, str] = {}
        # Lower-cased region name → ISO 3166-2 code.
        self._subregions: dict[str, str] = {}
        # Lower-cased country name → country code.
        self._countries: dict[str, str] = {}
        # Known city names.
        self._cities: set[str] = set()
        # Lower-cased budget term → canonical budget category.
        self._budget_categories: dict[str, str] = {}
        self._valid_type_codes: set[str] = set()  # From HeritageTypeEnum
        # True once load() has run.
        self._loaded = False
|
|
|
|
    def load(self) -> None:
        """Load synonym mappings from OntologyLoader and templates.

        Merge order (later steps never override earlier values, EXCEPT for
        countries, where YAML synonyms deliberately win because they carry
        ISO codes rather than Wikidata IDs):
        1. OntologyLoader mappings (validation rules JSON)
        2. Templates YAML `_slot_types` synonyms
        3. Hardcoded Dutch/English fallbacks for common institution types
        Idempotent after the first call.
        """
        if self._loaded:
            return

        # Get mappings from OntologyLoader (loads from validation rules JSON)
        ontology = get_ontology_loader()
        ontology.load()

        # Get institution type mappings from OntologyLoader
        self._institution_types = dict(ontology.get_institution_type_mappings())

        # Get valid type codes from HeritageTypeEnum (replaces hardcoded "MLAGORCUBESFIXPHDNT")
        self._valid_type_codes = ontology.get_institution_type_codes()

        # Get subregion mappings from OntologyLoader
        self._subregions = dict(ontology.get_subregion_mappings())

        # Get country mappings from OntologyLoader
        self._countries = dict(ontology.get_country_mappings())

        # Load additional synonyms from templates YAML
        if TEMPLATES_PATH.exists():
            try:
                import yaml
                with open(TEMPLATES_PATH) as f:
                    templates = yaml.safe_load(f)

                slot_types = templates.get("_slot_types", {})

                # Institution type synonyms (merge with ontology mappings)
                inst_synonyms = slot_types.get("institution_type", {}).get("synonyms", {})
                for k, v in inst_synonyms.items():
                    key = k.lower().replace("_", " ")
                    if key not in self._institution_types:
                        self._institution_types[key] = v

                # Subregion synonyms (merge with ontology mappings)
                region_synonyms = slot_types.get("subregion", {}).get("synonyms", {})
                for k, v in region_synonyms.items():
                    key = k.lower().replace("_", " ")
                    if key not in self._subregions:
                        self._subregions[key] = v

                # Country synonyms (OVERRIDE ontology mappings - YAML synonyms use ISO codes)
                # The ontology mappings have Wikidata IDs (wd:Q31) but we need ISO codes (BE)
                country_synonyms = slot_types.get("country", {}).get("synonyms", {})
                for k, v in country_synonyms.items():
                    key = k.lower().replace("_", " ")
                    # Always use YAML value - it has ISO codes, not Wikidata IDs
                    self._countries[key] = v

                # Budget category synonyms
                budget_synonyms = slot_types.get("budget_category", {}).get("synonyms", {})
                for k, v in budget_synonyms.items():
                    self._budget_categories[k.lower().replace("_", " ")] = v

            except Exception as e:
                logger.warning(f"Failed to load template synonyms: {e}")

        # Add common Dutch institution type synonyms (fallback for NLP variations)
        # These are added as fallback if not already in the ontology mappings
        dutch_types = {
            "museum": "M", "musea": "M", "museums": "M",
            "bibliotheek": "L", "bibliotheken": "L", "library": "L", "libraries": "L",
            "archief": "A", "archieven": "A", "archive": "A", "archives": "A",
            "galerie": "G", "galerij": "G", "galerijen": "G", "gallery": "G", "galleries": "G",
        }
        for k, v in dutch_types.items():
            if k not in self._institution_types:
                self._institution_types[k] = v

        self._loaded = True
        logger.info(f"Loaded {len(self._institution_types)} institution types, "
                    f"{len(self._subregions)} subregions, {len(self._countries)} countries, "
                    f"{len(self._valid_type_codes)} valid type codes")
|
|
|
|
    def resolve_institution_type(self, term: str) -> Optional[str]:
        """Resolve institution type term to single-letter code.

        Resolution order: exact lower-cased lookup, pass-through of an
        already-valid code, then fuzzy match (ratio >= 80). Returns None
        when nothing matches.
        """
        self.load()
        term_lower = term.lower().strip()

        # Direct match
        if term_lower in self._institution_types:
            return self._institution_types[term_lower]

        # Already a valid code - use ontology-derived codes instead of hardcoded string
        if term.upper() in self._valid_type_codes:
            return term.upper()

        # Fuzzy match
        if self._institution_types:
            match = process.extractOne(
                term_lower,
                list(self._institution_types.keys()),
                scorer=fuzz.ratio,
                score_cutoff=80
            )
            if match:
                return self._institution_types[match[0]]

        return None
|
|
|
|
def resolve_subregion(self, term: str) -> Optional[str]:
|
|
"""Resolve subregion term to ISO 3166-2 code."""
|
|
self.load()
|
|
term_lower = term.lower().strip()
|
|
|
|
# Direct match
|
|
if term_lower in self._subregions:
|
|
return self._subregions[term_lower]
|
|
|
|
# Already a valid code (e.g., NL-NH)
|
|
if re.match(r'^[A-Z]{2}-[A-Z]{2,3}$', term.upper()):
|
|
return term.upper()
|
|
|
|
# Fuzzy match
|
|
if self._subregions:
|
|
match = process.extractOne(
|
|
term_lower,
|
|
list(self._subregions.keys()),
|
|
scorer=fuzz.ratio,
|
|
score_cutoff=75
|
|
)
|
|
if match:
|
|
return self._subregions[match[0]]
|
|
|
|
return None
|
|
|
|
def resolve_country(self, term: str) -> Optional[str]:
|
|
"""Resolve country term to ISO 3166-1 alpha-2 code."""
|
|
self.load()
|
|
term_lower = term.lower().strip()
|
|
|
|
# Direct match
|
|
if term_lower in self._countries:
|
|
return self._countries[term_lower]
|
|
|
|
# Already an ISO 3166-1 alpha-2 code (e.g., "NL", "BE", "DE")
|
|
if re.match(r'^[A-Z]{2}$', term.upper()):
|
|
return term.upper()
|
|
|
|
# Fuzzy match
|
|
if self._countries:
|
|
match = process.extractOne(
|
|
term_lower,
|
|
list(self._countries.keys()),
|
|
scorer=fuzz.ratio,
|
|
score_cutoff=80
|
|
)
|
|
if match:
|
|
return self._countries[match[0]]
|
|
|
|
return None
|
|
|
|
def resolve_city(self, term: str) -> str:
|
|
"""Normalize city name (title case, common corrections)."""
|
|
# Common Dutch city name corrections
|
|
corrections = {
|
|
"den haag": "Den Haag",
|
|
"the hague": "Den Haag",
|
|
"'s-gravenhage": "Den Haag",
|
|
"s-gravenhage": "Den Haag",
|
|
"'s-hertogenbosch": "'s-Hertogenbosch",
|
|
"den bosch": "'s-Hertogenbosch",
|
|
}
|
|
|
|
term_lower = term.lower().strip()
|
|
if term_lower in corrections:
|
|
return corrections[term_lower]
|
|
|
|
# Title case with Dutch article handling
|
|
if term_lower.startswith("'s-"):
|
|
return "'" + term[1:2] + "-" + term[3:].title()
|
|
|
|
return term.title()
|
|
|
|
def resolve_budget_category(self, term: str) -> Optional[str]:
|
|
"""Resolve budget category term to canonical slot name.
|
|
|
|
Args:
|
|
term: Budget category term (e.g., "innovatie", "digitalisering", "innovation budget")
|
|
|
|
Returns:
|
|
Canonical budget category slot name (e.g., "innovation", "digitization") or None
|
|
"""
|
|
self.load()
|
|
# Normalize: lowercase and replace underscores with spaces (consistent with synonym loading)
|
|
term_normalized = term.lower().strip().replace("_", " ")
|
|
|
|
# Direct match from synonyms
|
|
if term_normalized in self._budget_categories:
|
|
return self._budget_categories[term_normalized]
|
|
|
|
# Already a valid category
|
|
valid_categories = [
|
|
"innovation", "digitization", "preservation", "personnel",
|
|
"acquisition", "operating", "capital", "external_funding",
|
|
"internal_funding", "endowment_draw"
|
|
]
|
|
if term_normalized in valid_categories:
|
|
return term_normalized
|
|
|
|
# Fuzzy match against loaded synonyms
|
|
if self._budget_categories:
|
|
match = process.extractOne(
|
|
term_normalized,
|
|
list(self._budget_categories.keys()),
|
|
scorer=fuzz.ratio,
|
|
score_cutoff=75
|
|
)
|
|
if match:
|
|
return self._budget_categories[match[0]]
|
|
|
|
return None
|
|
|
|
def is_region(self, term: str) -> bool:
|
|
"""Check if a term is a known region/province name.
|
|
|
|
This is used to disambiguate between city and region patterns
|
|
when both would match the same question structure.
|
|
|
|
Args:
|
|
term: Location term to check (e.g., "Noord-Holland", "Amsterdam")
|
|
|
|
Returns:
|
|
True if the term resolves to a known region, False otherwise
|
|
"""
|
|
self.load()
|
|
term_lower = term.lower().strip()
|
|
|
|
# Check if it's in our subregions mapping
|
|
if term_lower in self._subregions:
|
|
return True
|
|
|
|
# Check if it's already a valid ISO code
|
|
if re.match(r'^[A-Z]{2}-[A-Z]{2,3}$', term.upper()):
|
|
return True
|
|
|
|
# Fuzzy match with high threshold to avoid false positives
|
|
if self._subregions:
|
|
match = process.extractOne(
|
|
term_lower,
|
|
list(self._subregions.keys()),
|
|
scorer=fuzz.ratio,
|
|
score_cutoff=85 # Higher threshold than resolve_subregion
|
|
)
|
|
if match:
|
|
return True
|
|
|
|
return False
|
|
|
|
def is_country(self, term: str) -> bool:
|
|
"""Check if a term is a known country name.
|
|
|
|
This is used to disambiguate between city and country patterns
|
|
when both would match the same question structure.
|
|
|
|
Args:
|
|
term: Location term to check (e.g., "Belgium", "Netherlands")
|
|
|
|
Returns:
|
|
True if the term resolves to a known country, False otherwise
|
|
"""
|
|
self.load()
|
|
term_lower = term.lower().strip()
|
|
|
|
# Check if it's in our countries mapping
|
|
if term_lower in self._countries:
|
|
return True
|
|
|
|
# Check if it's already a valid ISO country code (2-letter)
|
|
if re.match(r'^[A-Z]{2}$', term.upper()):
|
|
return True
|
|
|
|
# Fuzzy match with high threshold to avoid false positives
|
|
if self._countries:
|
|
match = process.extractOne(
|
|
term_lower,
|
|
list(self._countries.keys()),
|
|
scorer=fuzz.ratio,
|
|
score_cutoff=85 # High threshold to avoid "Berlin" matching "Belgium"
|
|
)
|
|
if match:
|
|
return True
|
|
|
|
return False
|
|
|
|
|
|
# Lazily-created module-level SynonymResolver shared by the whole process.
_synonym_resolver: Optional[SynonymResolver] = None


def get_synonym_resolver() -> SynonymResolver:
    """Return the process-wide SynonymResolver, creating it on first use."""
    global _synonym_resolver
    if _synonym_resolver is None:
        _synonym_resolver = SynonymResolver()
    return _synonym_resolver
|
|
|
|
|
|
# =============================================================================
|
|
# SCHEMA-AWARE SLOT VALIDATOR (SOTA Pattern)
|
|
# =============================================================================
|
|
|
|
class SlotValidationResult(BaseModel):
    """Result of slot value validation against ontology.

    Produced by SchemaAwareSlotValidator.validate_slot() and friends.
    """
    # True when the original value passed validation; a corrected_value may
    # still be set when valid is False but a fuzzy match was found.
    valid: bool
    # The value exactly as extracted, before any normalization/correction.
    original_value: str
    # Canonical replacement value when a synonym or fuzzy match resolved one;
    # None when no correction was needed or found.
    corrected_value: Optional[str] = None
    # Name of the slot that was validated (e.g. "institution_type").
    slot_name: str
    # Human-readable validation errors; empty when the value is valid.
    errors: list[str] = Field(default_factory=list)
    # Hints for the caller, e.g. "Did you mean 'museum' → 'M'?".
    suggestions: list[str] = Field(default_factory=list)
    # Match confidence in [0.0, 1.0]; 1.0 for exact/direct matches.
    confidence: float = 1.0
|
|
|
|
|
|
class SchemaAwareSlotValidator:
    """Validates and auto-corrects slot values against the ontology schema.

    Based on KGQuest (arXiv:2511.11258) pattern for schema-aware slot filling.
    Uses the validation rules JSON which contains:
    - Valid enum values for each slot type
    - Synonym mappings for fuzzy matching
    - Property constraints from LinkML schema

    Architecture:
    1. Load valid values from validation rules JSON
    2. For each extracted slot, check if value is valid
    3. If invalid, attempt fuzzy match to find correct value
    4. Return validation result with corrections and suggestions

    Caching:
    - KG validation results are cached with TTL (default: 5 minutes)
    - Use clear_kg_validation_cache() to reset cache
    """

    _instance = None
    # NOTE: state is intentionally class-level — the class is a singleton, so
    # all instances share these caches.
    _valid_values: dict[str, set[str]] = {}
    _synonym_maps: dict[str, dict[str, str]] = {}
    _loaded: bool = False

    # TTL-based caching for KG validation
    _kg_validation_cache: dict[str, bool] = {}  # "slot_name:value" → is_valid
    _kg_validation_timestamps: dict[str, float] = {}  # "slot_name:value" → timestamp
    _kg_validation_ttl: float = 300.0  # 5 minutes default TTL

    def __new__(cls):
        """Singleton pattern."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def _merge_rule_mappings(self, rules: dict, rules_key: str, slot_name: str) -> None:
        """Merge synonym→canonical mappings from the rules JSON into the caches.

        Keys are lowercased; existing entries are merged into, not overwritten.
        """
        if rules_key not in rules:
            return
        mappings = rules[rules_key]
        synonyms = self._synonym_maps.setdefault(slot_name, {})
        for k, v in mappings.items():
            synonyms[k.lower()] = v
        self._valid_values.setdefault(slot_name, set()).update(mappings.values())

    def _load_validation_rules(self) -> None:
        """Load valid values and synonyms from validation rules JSON and SynonymResolver."""
        if self._loaded:
            return

        # First, load from the SynonymResolver (has comprehensive Dutch mappings)
        resolver = get_synonym_resolver()
        resolver.load()

        # Copy resolver mappings; canonical values become the valid-value sets.
        if resolver._institution_types:
            self._synonym_maps["institution_type"] = dict(resolver._institution_types)
            self._valid_values["institution_type"] = set(resolver._institution_types.values())
        if resolver._subregions:
            self._synonym_maps["subregion"] = dict(resolver._subregions)
            self._valid_values["subregion"] = set(resolver._subregions.values())
        if resolver._countries:
            self._synonym_maps["country"] = dict(resolver._countries)
            self._valid_values["country"] = set(resolver._countries.values())
        if resolver._budget_categories:
            self._synonym_maps["budget_category"] = dict(resolver._budget_categories)
            self._valid_values["budget_category"] = set(resolver._budget_categories.values())

        # Then, augment with validation rules JSON (enum definitions and slot constraints)
        if not VALIDATION_RULES_PATH.exists():
            logger.warning(f"Validation rules not found: {VALIDATION_RULES_PATH}")
            self._loaded = True
            return

        try:
            with open(VALIDATION_RULES_PATH) as f:
                rules = json.load(f)

            # Augment mappings from the rules JSON (merge, don't overwrite).
            self._merge_rule_mappings(rules, "institution_type_mappings", "institution_type")
            self._merge_rule_mappings(rules, "subregion_mappings", "subregion")
            self._merge_rule_mappings(rules, "country_mappings", "country")

            # Load enum valid values from the 'enums' section.
            if "enums" in rules:
                for enum_name, enum_data in rules["enums"].items():
                    if isinstance(enum_data, dict) and "permissible_values" in enum_data:
                        self._valid_values[enum_name] = set(enum_data["permissible_values"].keys())

            # Load slot constraints: a slot whose LinkML range is a known enum
            # shares that enum's permissible values.
            if "slots" in rules:
                for slot_name, slot_data in rules["slots"].items():
                    if isinstance(slot_data, dict) and "range" in slot_data:
                        range_enum = slot_data["range"]
                        if range_enum in self._valid_values:
                            self._valid_values[slot_name] = self._valid_values[range_enum]

            self._loaded = True
            logger.info(
                f"Loaded schema validation rules: "
                f"{len(self._valid_values)} value sets, "
                f"{len(self._synonym_maps)} synonym maps"
            )

        except Exception as e:
            logger.warning(f"Failed to load validation rules: {e}")
            self._loaded = True

    def validate_slot(
        self,
        slot_name: str,
        value: str,
        auto_correct: bool = True
    ) -> SlotValidationResult:
        """Validate a slot value against the ontology schema.

        Args:
            slot_name: Name of the slot (e.g., "institution_type", "subregion")
            value: Extracted value to validate
            auto_correct: Whether to attempt fuzzy matching for correction

        Returns:
            SlotValidationResult with validation status and corrections
        """
        self._load_validation_rules()

        result = SlotValidationResult(
            valid=True,
            original_value=value,
            slot_name=slot_name
        )

        # Normalize value
        value_normalized = value.strip()
        value_lower = value_normalized.lower()

        # Check synonym maps first (most common case)
        if slot_name in self._synonym_maps:
            synonym_map = self._synonym_maps[slot_name]
            if value_lower in synonym_map:
                # Direct synonym match - return canonical value
                result.corrected_value = synonym_map[value_lower]
                result.confidence = 1.0
                return result

        # Check if value is already a valid canonical value
        if slot_name in self._valid_values:
            valid_set = self._valid_values[slot_name]
            if value_normalized in valid_set or value_normalized.upper() in valid_set:
                result.confidence = 1.0
                return result

        # Value not valid - attempt correction if enabled
        result.valid = False
        result.errors.append(f"Invalid value '{value}' for slot '{slot_name}'")

        if auto_correct:
            # Try fuzzy matching against synonym keys
            if slot_name in self._synonym_maps:
                synonym_keys = list(self._synonym_maps[slot_name].keys())
                match = process.extractOne(
                    value_lower,
                    synonym_keys,
                    scorer=fuzz.ratio,
                    score_cutoff=70
                )
                if match:
                    corrected = self._synonym_maps[slot_name][match[0]]
                    result.corrected_value = corrected
                    result.confidence = match[1] / 100.0
                    result.suggestions.append(
                        f"Did you mean '{match[0]}' → '{corrected}'?"
                    )
                    return result

            # Try fuzzy matching against valid values directly.
            # BUG FIX: guard on slot presence — previously this referenced
            # `valid_set`, which was unbound (UnboundLocalError) whenever the
            # slot had a synonym map but no valid-value set.
            if slot_name in self._valid_values:
                match = process.extractOne(
                    value_normalized,
                    list(self._valid_values[slot_name]),
                    scorer=fuzz.ratio,
                    score_cutoff=70
                )
                if match:
                    result.corrected_value = match[0]
                    result.confidence = match[1] / 100.0
                    result.suggestions.append(f"Did you mean '{match[0]}'?")
                    return result

        # No correction found - provide suggestions
        if slot_name in self._valid_values:
            sample_values = list(self._valid_values[slot_name])[:5]
            result.suggestions.append(
                f"Valid values include: {', '.join(sample_values)}"
            )

        return result

    def validate_slot_against_kg(
        self,
        slot_name: str,
        value: str,
        use_cache: bool = True
    ) -> bool:
        """Validate a slot value against the Knowledge Graph with TTL-based caching.

        This is a fallback when local validation has no data for the slot.
        Queries the Oxigraph endpoint to verify the value exists in the KG.

        The KG validation uses actual values stored in the Knowledge Graph,
        which may differ from the static validation rules JSON.

        Args:
            slot_name: Name of the slot (e.g., "institution_type", "city", "subregion")
            value: Value to validate against the KG
            use_cache: Whether to use cached validation results (default: True)

        Returns:
            True if value exists in KG or if KG is unavailable, False if value is invalid.

        Note:
            This method is non-blocking - if KG query fails, it returns True to avoid
            rejecting potentially valid values.

        Caching:
            Results are cached using (slot_name, value) as key. Cached results are
            returned if within TTL window (default 5 minutes).
        """
        # Generate cache key
        cache_key = f"{slot_name}:{value}"

        # Check cache if enabled
        if use_cache and cache_key in self._kg_validation_cache:
            cache_time = self._kg_validation_timestamps.get(cache_key, 0)
            if time.time() - cache_time < self._kg_validation_ttl:
                logger.debug(f"KG validation cache hit for {slot_name}='{value}'")
                return self._kg_validation_cache[cache_key]

        ontology = get_ontology_loader()
        ontology.load()

        # Map slot names to OntologyLoader slot value keys (aliases share keys).
        slot_key_map = {
            "institution_type": "institution_type",
            "type": "institution_type",
            "city": "city",
            "settlement": "city",
            "subregion": "subregion",
            "province": "subregion",
            "country": "country",
        }

        slot_key = slot_key_map.get(slot_name, slot_name)

        # Use OntologyLoader's KG-based validation
        is_valid = ontology.is_valid_value(slot_key, value)

        # Cache the result
        self._kg_validation_cache[cache_key] = is_valid
        self._kg_validation_timestamps[cache_key] = time.time()

        return is_valid

    def validate_slot_with_kg_fallback(
        self,
        slot_name: str,
        value: str,
        auto_correct: bool = True
    ) -> SlotValidationResult:
        """Validate slot with KG fallback for values not in local validation rules.

        This method first tries local validation (fast, uses cached rules).
        If local validation has no data for the slot, it falls back to
        querying the Knowledge Graph for validation.

        Args:
            slot_name: Name of the slot to validate
            value: Value to validate
            auto_correct: Whether to attempt fuzzy matching for correction

        Returns:
            SlotValidationResult with validation status, corrections, and source
        """
        # First try local validation
        result = self.validate_slot(slot_name, value, auto_correct)

        # If local validation has no data for this slot, try KG validation
        if slot_name not in self._valid_values and slot_name not in self._synonym_maps:
            kg_valid = self.validate_slot_against_kg(slot_name, value)
            if kg_valid:
                result.valid = True
                result.errors = []
                result.confidence = 0.85  # KG validation has slightly lower confidence
                logger.debug(f"Slot '{slot_name}' validated against KG: {value}")
            else:
                result.valid = False
                result.errors.append(f"Value '{value}' not found in Knowledge Graph for slot '{slot_name}'")
                logger.debug(f"Slot '{slot_name}' failed KG validation: {value}")

        return result

    def validate_slots(
        self,
        slots: dict[str, str],
        auto_correct: bool = True
    ) -> dict[str, SlotValidationResult]:
        """Validate multiple slots at once.

        Args:
            slots: Dictionary of slot_name -> value
            auto_correct: Whether to attempt corrections

        Returns:
            Dictionary of slot_name -> SlotValidationResult
        """
        return {
            slot_name: self.validate_slot(slot_name, value, auto_correct)
            for slot_name, value in slots.items()
        }

    def get_corrected_slots(
        self,
        slots: dict[str, str],
        min_confidence: float = 0.7
    ) -> dict[str, str]:
        """Get slots with auto-corrected values applied.

        Args:
            slots: Original slot values
            min_confidence: Minimum confidence for applying corrections

        Returns:
            Dictionary with corrected values applied
        """
        results = self.validate_slots(slots, auto_correct=True)
        corrected = {}

        for slot_name, result in results.items():
            if result.corrected_value and result.confidence >= min_confidence:
                corrected[slot_name] = result.corrected_value
                if result.corrected_value != result.original_value:
                    logger.info(
                        f"Auto-corrected slot '{slot_name}': "
                        f"'{result.original_value}' → '{result.corrected_value}' "
                        f"(confidence={result.confidence:.2f})"
                    )
            else:
                corrected[slot_name] = result.original_value

        return corrected

    def set_kg_validation_ttl(self, ttl_seconds: float) -> None:
        """Set the TTL for KG validation cache.

        Args:
            ttl_seconds: Time-to-live in seconds for cached validation results.
                Default is 300 seconds (5 minutes).
        """
        self._kg_validation_ttl = ttl_seconds
        logger.info(f"KG validation cache TTL set to {ttl_seconds} seconds")

    def get_kg_validation_ttl(self) -> float:
        """Get the current TTL for KG validation cache."""
        return self._kg_validation_ttl

    def clear_kg_validation_cache(self) -> None:
        """Clear the KG validation cache, forcing fresh validations on next access."""
        self._kg_validation_cache.clear()
        self._kg_validation_timestamps.clear()
        logger.info("KG validation cache cleared")

    def get_kg_validation_cache_stats(self) -> dict[str, Any]:
        """Get statistics about the KG validation cache.

        Returns:
            Dict with cache statistics including size, hit rate, and age of entries.
        """
        now = time.time()
        valid_count = sum(1 for v in self._kg_validation_cache.values() if v)
        invalid_count = len(self._kg_validation_cache) - valid_count

        expired_count = sum(
            1 for cache_key in self._kg_validation_timestamps
            if now - self._kg_validation_timestamps[cache_key] >= self._kg_validation_ttl
        )

        return {
            "cache_size": len(self._kg_validation_cache),
            "valid_entries": valid_count,
            "invalid_entries": invalid_count,
            "expired_entries": expired_count,
            "ttl_seconds": self._kg_validation_ttl,
        }
|
|
|
|
|
|
# Lazily-created module-level SchemaAwareSlotValidator shared by the process.
_schema_slot_validator: Optional[SchemaAwareSlotValidator] = None


def get_schema_slot_validator() -> SchemaAwareSlotValidator:
    """Return the process-wide SchemaAwareSlotValidator, creating it lazily."""
    global _schema_slot_validator
    if _schema_slot_validator is None:
        _schema_slot_validator = SchemaAwareSlotValidator()
    return _schema_slot_validator
|
|
|
|
|
|
# =============================================================================
|
|
# DSPy SIGNATURES
|
|
# =============================================================================
|
|
|
|
class ConversationContextSignature(dspy.Signature):
    """Resolve elliptical follow-up questions using conversation history.

    CRITICAL: This module runs FIRST, before any filtering or classification.
    It expands short follow-up questions into complete, self-contained questions.

    Examples of resolution:
    - Previous: "Welke archieven zijn er in Den Haag?"
      Current: "En in Enschede?"
      Resolved: "Welke archieven zijn er in Enschede?"

    - Previous: "Welke musea zijn er in Amsterdam?"
      Current: "Hoeveel?"
      Resolved: "Hoeveel musea zijn er in Amsterdam?"

    - Previous: Listed 5 museums
      Current: "Vertel me meer over de eerste"
      Resolved: "Vertel me meer over [first museum name]"

    The resolved question must be a complete, standalone question that would
    make sense without any conversation history.
    """

    # NOTE(review): this docstring and the field descs below form the LLM
    # prompt for dspy.ChainOfThought — wording changes alter model behavior.

    # --- Inputs ---
    question: str = dspy.InputField(
        desc="Current user question (may be elliptical follow-up)"
    )
    history: History = dspy.InputField(
        desc="Previous conversation turns",
        default=History(messages=[])
    )
    previous_slots: str = dspy.InputField(
        desc="JSON string of slot values from previous query (e.g., {\"institution_type\": \"A\", \"city\": \"Den Haag\"})",
        default="{}"
    )
    previous_results_summary: str = dspy.InputField(
        desc="Brief summary of previous query results for ordinal/pronoun resolution",
        default=""
    )

    # --- Outputs ---
    resolved_question: str = dspy.OutputField(
        desc="Fully expanded, self-contained question. If not a follow-up, return original question unchanged."
    )
    is_follow_up: bool = dspy.OutputField(
        desc="True if this was an elliptical follow-up that needed resolution"
    )
    follow_up_type: str = dspy.OutputField(
        desc="Type of follow-up: 'location_swap', 'type_swap', 'count_from_list', 'details_request', 'ordinal_reference', 'pronoun_reference', or 'none'"
    )
    inherited_slots_json: str = dspy.OutputField(
        desc="JSON string of slots inherited from previous query (e.g., {\"institution_type\": \"A\"})"
    )
    confidence: float = dspy.OutputField(
        desc="Confidence in resolution (0.0-1.0)"
    )
|
|
|
|
|
|
class FykeFilterSignature(dspy.Signature):
    """Determine if a RESOLVED question is relevant to heritage institutions.

    CRITICAL: This filter operates on the RESOLVED question from ConversationContextResolver,
    NOT the raw user input. This prevents false positives on short follow-ups like
    "En in Enschede?" which resolves to "Welke archieven zijn er in Enschede?"

    Heritage institutions include:
    - Museums (musea)
    - Archives (archieven)
    - Libraries (bibliotheken)
    - Galleries (galerijen)
    - Heritage societies
    - Cultural institutions
    - Collections

    Questions about these topics are RELEVANT.
    Questions about shopping, weather, sports, restaurants, etc. are NOT relevant.

    When in doubt, err on the side of relevance (return True).
    """

    # NOTE(review): this docstring and the field descs below form the LLM
    # prompt for dspy.ChainOfThought — wording changes alter model behavior.

    # --- Inputs ---
    resolved_question: str = dspy.InputField(
        desc="The fully resolved question (after context resolution)"
    )
    conversation_topic: str = dspy.InputField(
        desc="Brief summary of what the conversation has been about so far",
        default="heritage institutions"
    )

    # --- Outputs ---
    is_relevant: bool = dspy.OutputField(
        desc="True if question is about heritage institutions, False otherwise"
    )
    confidence: float = dspy.OutputField(
        desc="Confidence in relevance classification (0.0-1.0)"
    )
    reasoning: str = dspy.OutputField(
        desc="Brief explanation of why question is/isn't relevant"
    )
|
|
|
|
|
|
class TemplateClassifierSignature(dspy.Signature):
    """Classify a heritage question to match it with a SPARQL template.

    Given a resolved question about heritage institutions, determine which
    SPARQL query template best matches the user's intent.

    Available template IDs (return the EXACT ID string, not the number):
    - list_institutions_by_type_city: List institutions of type X in city Y
    - list_institutions_by_type_region: List institutions of type X in region Y (province/state)
    - list_institutions_by_type_country: List institutions of type X in country Y
    - count_institutions_by_type_location: Count institutions of type X in location Y
    - count_institutions_by_type: Count all institutions grouped by type
    - find_institution_by_name: Find specific institution by name
    - list_all_institutions_in_city: List all institutions in city Y
    - find_institutions_by_founding_date: Find oldest/newest institutions
    - find_institution_by_identifier: Find by ISIL/GHCID
    - compare_locations: Compare institutions between locations
    - find_custodians_by_budget_threshold: Find custodians with budget category above/below threshold (e.g., "Which custodians spend more than 5000 euros on innovation?")
    - none: No template matches (fall back to LLM generation)

    CRITICAL DISAMBIGUATION - Province vs City:
    Some Dutch locations are BOTH a province AND a city with the same name.
    When the location name alone is used (without "stad" or "de stad"),
    DEFAULT TO PROVINCE (use list_institutions_by_type_region):
    - Groningen → province (NL-GR), NOT the city
    - Utrecht → province (NL-UT), NOT the city
    - Limburg → province (NL-LI), NOT the city
    - Friesland/Frisia → province (NL-FR)
    - Drenthe → province (NL-DR)
    - Gelderland → province (NL-GE)
    - Overijssel → province (NL-OV)
    - Flevoland → province (NL-FL)
    - Zeeland → province (NL-ZE)
    - Noord-Holland → province (NL-NH)
    - Zuid-Holland → province (NL-ZH)
    - Noord-Brabant → province (NL-NB)

    Use list_institutions_by_type_city ONLY when:
    - The question explicitly says "de stad" or "in de stad" (the city)
    - The location is clearly just a city (e.g., Amsterdam, Rotterdam, Den Haag)

    IMPORTANT: Return the template ID string exactly as shown (e.g. "count_institutions_by_type_location"),
    NOT a number. Return "none" if no template matches well.
    """

    # NOTE(review): this docstring and the field descs below form the LLM
    # prompt for dspy.ChainOfThought — wording changes alter model behavior.
    # The template ID list must stay in sync with the templates YAML.

    # --- Inputs ---
    question: str = dspy.InputField(
        desc="Resolved natural language question about heritage institutions"
    )
    language: str = dspy.InputField(
        desc="Language code: nl, en, de, fr",
        default="nl"
    )

    # --- Outputs ---
    template_id: str = dspy.OutputField(
        desc="EXACT template ID string from the list above (e.g. 'count_institutions_by_type_location'), or 'none'"
    )
    confidence: float = dspy.OutputField(
        desc="Confidence in template match (0.0-1.0). Return 'none' as template_id if below 0.6"
    )
    reasoning: str = dspy.OutputField(
        desc="Brief explanation of why this template matches"
    )
|
|
|
|
|
class SlotExtractorSignature(dspy.Signature):
    """Extract slot values from a question for a specific template.

    Given a question and the template it matched, extract the values needed
    to fill in the template's slots.

    Slot types and expected formats:
    - institution_type: Return single-letter code (M, L, A, G, O, R, C, U, B, E, S, F, I, X, P, H, D, N, T)
      Examples: "musea" → "M", "archieven" → "A", "bibliotheken" → "L"
    - city: Return city name with proper capitalization
      Examples: "amsterdam" → "Amsterdam", "den haag" → "Den Haag"
    - region/subregion: Return ISO 3166-2 code
      Examples: "Noord-Holland" → "NL-NH", "Gelderland" → "NL-GE"
    - country: Return ISO 3166-1 alpha-2 code (NOT Wikidata Q-numbers!)
      Examples: "Nederland" → "NL", "Belgium" → "BE", "Duitsland" → "DE", "Frankrijk" → "FR"
    - location: For generic location slots, return the location name as-is (will be resolved by type)
    - institution_name: Return the institution name as mentioned
    - limit: Return integer (default 10)

    Return slots as a JSON object with slot names as keys.
    """

    # NOTE(review): this docstring and the field descs below form the LLM
    # prompt for dspy.ChainOfThought — wording changes alter model behavior.

    # --- Inputs ---
    question: str = dspy.InputField(
        desc="The user's question"
    )
    template_id: str = dspy.InputField(
        desc="ID of the matched template"
    )
    required_slots: str = dspy.InputField(
        desc="Comma-separated list of required slot names for this template"
    )
    inherited_slots: str = dspy.InputField(
        desc="JSON string of slots inherited from conversation context",
        default="{}"
    )

    # --- Outputs ---
    slots_json: str = dspy.OutputField(
        desc="JSON object with extracted slot values, e.g., {\"institution_type\": \"M\", \"city\": \"Amsterdam\"}"
    )
    extraction_notes: str = dspy.OutputField(
        desc="Notes about any slots that couldn't be extracted or needed inference"
    )
|
|
|
|
|
|
# =============================================================================
|
|
# DSPy MODULES
|
|
# =============================================================================
|
|
|
|
class ConversationContextResolver(dspy.Module):
    """Expands elliptical follow-up questions into standalone questions.

    CRITICAL: this module MUST run FIRST in the pipeline, before FykeFilter.

    A short follow-up such as "En in Enschede?" becomes a complete question
    ("Welke archieven zijn er in Enschede?") so that downstream modules see
    the user's full intent.
    """

    def __init__(self):
        super().__init__()
        self.resolve = dspy.ChainOfThought(ConversationContextSignature)

    def forward(
        self,
        question: str,
        conversation_state: Optional[ConversationState] = None,
    ) -> ResolvedQuestion:
        """Resolve a potentially elliptical question.

        Args:
            question: Current user question (may be elliptical)
            conversation_state: Full conversation state with history

        Returns:
            ResolvedQuestion with expanded question and metadata
        """
        # Without any usable history, the question is taken at face value.
        if conversation_state is None:
            logger.debug(f"[ContextResolver] No conversation_state provided for '{question}'")
            return ResolvedQuestion(
                original=question,
                resolved=question,
                is_follow_up=False,
                confidence=1.0
            )
        if not conversation_state.turns:
            logger.debug(f"[ContextResolver] Empty turns in conversation_state for '{question}'")
            return ResolvedQuestion(
                original=question,
                resolved=question,
                is_follow_up=False,
                confidence=1.0
            )

        logger.debug(f"[ContextResolver] Resolving '{question}' with {len(conversation_state.turns)} previous turns")
        logger.debug(f"[ContextResolver] Current slots: {conversation_state.current_slots}")

        # Assemble predictor inputs from the conversation state.
        history = conversation_state.to_dspy_history()
        previous_slots = json.dumps(conversation_state.current_slots)

        # Summarize previous results so ordinal/pronoun references ("the
        # first one") can be grounded by the model.
        prev_turn = conversation_state.get_previous_user_turn()
        if prev_turn and prev_turn.results:
            names = [r.get("name", str(r))[:50] for r in prev_turn.results[:5]]
            results_summary = ", ".join(names)
        else:
            results_summary = ""

        try:
            prediction = self.resolve(
                question=question,
                history=history,
                previous_slots=previous_slots,
                previous_results_summary=results_summary
            )

            # Inherited slots arrive as a JSON string; tolerate bad output.
            try:
                inherited = json.loads(prediction.inherited_slots_json)
            except (json.JSONDecodeError, TypeError):
                inherited = {}

            follow_up_type = prediction.follow_up_type
            return ResolvedQuestion(
                original=question,
                resolved=prediction.resolved_question,
                is_follow_up=prediction.is_follow_up,
                follow_up_type=None if follow_up_type == "none" else follow_up_type,
                inherited_slots=inherited,
                confidence=prediction.confidence
            )

        except Exception as e:
            # Fail open: keep the original question with reduced confidence.
            logger.warning(f"Context resolution failed: {e}, returning original")
            return ResolvedQuestion(
                original=question,
                resolved=question,
                is_follow_up=False,
                confidence=0.5
            )
|
|
|
|
|
|
class FykeFilter(dspy.Module):
    """Filters out irrelevant questions with standard response.

    CRITICAL: Must operate on RESOLVED question from ConversationContextResolver,
    not the raw user input. This prevents false positives on follow-ups.

    Named after Dutch "fuik" (fish trap) - catches irrelevant questions.

    Decision pipeline (see ``forward``):
        1. Fast keyword pass: heritage keywords -> relevant, out-of-scope
           keywords -> irrelevant (both at 0.95 confidence, no LLM call).
        2. Ambiguous cases fall through to a DSPy ChainOfThought classifier.
        3. Any classifier failure errs on the side of relevance (0.5 confidence).
    """

    def __init__(self, config: Optional[FykeFilterConfig] = None):
        """Initialize the filter.

        Args:
            config: Optional pre-built configuration; when omitted, the
                configuration is loaded from the templates YAML (with
                hardcoded defaults as fallback).
        """
        super().__init__()
        # LLM-backed classifier used only when keyword heuristics are inconclusive.
        self.classify = dspy.ChainOfThought(FykeFilterSignature)
        self.config = config or self._load_config()

    def _load_config(self) -> FykeFilterConfig:
        """Load Fyke configuration from templates YAML.

        Builds hardcoded defaults first, then overlays any values found under
        the ``fyke_filter`` key in the templates YAML. On any load/parse
        failure the defaults are returned unchanged.
        """
        default_config = FykeFilterConfig(
            out_of_scope_keywords=[
                "tandpasta", "toothpaste", "supermarkt", "restaurant",
                "hotel", "weer", "weather", "voetbal", "soccer"
            ],
            out_of_scope_categories=["shopping", "sports", "cooking"],
            heritage_keywords=[
                "museum", "musea", "archief", "archieven", "bibliotheek",
                "galerie", "erfgoed", "heritage", "collectie", "collection"
            ],
            standard_response={
                "nl": "Ik kan je helpen met vragen over erfgoedinstellingen zoals musea, archieven, bibliotheken en galerijen.",
                "en": "I can help you with questions about heritage institutions such as museums, archives, libraries and galleries.",
                "de": "Ich kann Ihnen bei Fragen zu Kulturerbeinstitutionen wie Museen, Archiven und Bibliotheken helfen.",
                "fr": "Je peux vous aider avec des questions sur les institutions patrimoniales."
            }
        )

        if TEMPLATES_PATH.exists():
            try:
                import yaml
                with open(TEMPLATES_PATH) as f:
                    templates = yaml.safe_load(f)
                fyke_config = templates.get("fyke_filter", {})
                # Each field falls back to the hardcoded default when absent in YAML.
                return FykeFilterConfig(
                    out_of_scope_keywords=fyke_config.get("out_of_scope_keywords", default_config.out_of_scope_keywords),
                    out_of_scope_categories=fyke_config.get("out_of_scope_categories", default_config.out_of_scope_categories),
                    heritage_keywords=fyke_config.get("heritage_keywords", default_config.heritage_keywords),
                    standard_response=fyke_config.get("standard_response", default_config.standard_response)
                )
            except Exception as e:
                logger.warning(f"Failed to load Fyke config: {e}")

        return default_config

    def forward(
        self,
        resolved_question: str,
        conversation_topic: str = "heritage institutions",
        language: str = "nl"
    ) -> FykeResult:
        """Check if resolved question is relevant to heritage.

        Args:
            resolved_question: The RESOLVED question (not raw input!)
            conversation_topic: Summary of conversation so far
            language: Language code for standard response

        Returns:
            FykeResult with relevance decision
        """
        question_lower = resolved_question.lower()

        # Quick check: obvious heritage keywords → definitely relevant
        # NOTE(review): substring matching can over-trigger, e.g. "weer"
        # matches inside longer words — consider word-boundary matching; confirm
        # against real traffic before changing.
        for keyword in self.config.heritage_keywords:
            if keyword in question_lower:
                return FykeResult(
                    is_relevant=True,
                    confidence=0.95,
                    reasoning=f"Contains heritage keyword: {keyword}"
                )

        # Quick check: obvious out-of-scope keywords → definitely irrelevant
        for keyword in self.config.out_of_scope_keywords:
            if keyword in question_lower:
                return FykeResult(
                    is_relevant=False,
                    confidence=0.95,
                    reasoning=f"Contains out-of-scope keyword: {keyword}",
                    # Localized refusal text; English is the fallback language.
                    standard_response=self.config.standard_response.get(
                        language, self.config.standard_response.get("en")
                    )
                )

        # Use DSPy for ambiguous cases
        try:
            result = self.classify(
                resolved_question=resolved_question,
                conversation_topic=conversation_topic
            )

            return FykeResult(
                is_relevant=result.is_relevant,
                confidence=result.confidence,
                reasoning=result.reasoning,
                # Only attach the standard refusal text when the question is rejected.
                standard_response=None if result.is_relevant else self.config.standard_response.get(
                    language, self.config.standard_response.get("en")
                )
            )

        except Exception as e:
            logger.warning(f"Fyke classification failed: {e}, assuming relevant")
            # Err on side of relevance
            return FykeResult(
                is_relevant=True,
                confidence=0.5,
                reasoning=f"Classification failed, assuming relevant: {e}"
            )
|
|
|
|
|
|
# =============================================================================
|
|
# EMBEDDING-BASED TEMPLATE MATCHING
|
|
# =============================================================================
|
|
|
|
class TemplateEmbeddingMatcher:
    """Matches questions to templates using semantic embeddings.

    Uses sentence-transformers to compute embeddings for template patterns
    and find the best match for incoming questions based on cosine similarity.

    This provides semantic matching that can handle:
    - Paraphrases ("Welke musea..." vs "Zijn er musea die...")
    - Synonyms ("instellingen" vs "organisaties")
    - Different word orders
    - Multilingual queries (Dutch, English, German)

    Implementation notes:
    - Singleton (see ``__new__``); pattern embeddings are computed once on the
      first ``match`` call and cached for the process lifetime.
    - NOTE(review): the cache is never invalidated — if the template set
      changes at runtime, stale embeddings are used. Confirm templates are
      static per process before relying on this.
    """

    _instance = None
    # Parallel caches: row i of _pattern_embeddings corresponds to
    # _pattern_template_ids[i] and _pattern_texts[i].
    _pattern_embeddings: Optional[np.ndarray] = None
    _pattern_template_ids: Optional[list[str]] = None
    _pattern_texts: Optional[list[str]] = None

    def __new__(cls):
        """Singleton pattern - embeddings are expensive to compute."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def _ensure_embeddings_computed(self, templates: dict[str, "TemplateDefinition"]) -> bool:
        """Compute and cache embeddings for all template patterns.

        Idempotent: returns immediately when embeddings are already cached.

        Returns:
            True if embeddings are available, False otherwise
        """
        if self._pattern_embeddings is not None:
            return True

        model = _get_embedding_model()
        if model is None:
            return False

        # Collect all patterns with their template IDs
        pattern_texts = []
        template_ids = []

        for template_id, template_def in templates.items():
            for pattern in template_def.question_patterns:
                # Normalize pattern: replace {slot} with generic placeholder
                normalized = re.sub(r'\{[^}]+\}', '[VALUE]', pattern)
                pattern_texts.append(normalized)
                template_ids.append(template_id)

        if not pattern_texts:
            logger.warning("No patterns found for embedding computation")
            return False

        # Compute embeddings for all patterns
        logger.info(f"Computing embeddings for {len(pattern_texts)} template patterns...")
        try:
            embeddings = model.encode(pattern_texts, convert_to_numpy=True, show_progress_bar=False)
            self._pattern_embeddings = embeddings
            self._pattern_template_ids = template_ids
            self._pattern_texts = pattern_texts
            logger.info(f"Computed {len(embeddings)} pattern embeddings (dim={embeddings.shape[1]})")
            return True
        except Exception as e:
            logger.warning(f"Failed to compute pattern embeddings: {e}")
            return False

    def match(
        self,
        question: str,
        templates: dict[str, "TemplateDefinition"],
        min_similarity: float = 0.70
    ) -> Optional["TemplateMatchResult"]:
        """Find best matching template using embedding similarity.

        Args:
            question: Natural language question
            templates: Dictionary of template definitions
            min_similarity: Minimum cosine similarity threshold (0-1)

        Returns:
            TemplateMatchResult if similarity >= threshold, None otherwise
        """
        if not self._ensure_embeddings_computed(templates):
            return None

        model = _get_embedding_model()
        if model is None:
            return None

        # Normalize question: replace numbers with placeholder
        # (mirrors the '[VALUE]' placeholder used when embedding patterns)
        normalized_question = re.sub(r'\d+(?:[.,]\d+)?', '[VALUE]', question)

        # Compute question embedding
        try:
            question_embedding = model.encode([normalized_question], convert_to_numpy=True)[0]
        except Exception as e:
            logger.warning(f"Failed to compute question embedding: {e}")
            return None

        # Guard against None embeddings (should not happen after _ensure_embeddings_computed)
        if self._pattern_embeddings is None or self._pattern_template_ids is None or self._pattern_texts is None:
            return None

        # Compute cosine similarities
        # Normalize vectors for cosine similarity
        question_norm = question_embedding / np.linalg.norm(question_embedding)
        pattern_norms = self._pattern_embeddings / np.linalg.norm(self._pattern_embeddings, axis=1, keepdims=True)

        # Dot products of unit vectors == cosine similarities, one per pattern.
        similarities = np.dot(pattern_norms, question_norm)

        # Find best match
        best_idx = int(np.argmax(similarities))
        best_similarity = float(similarities[best_idx])

        if best_similarity < min_similarity:
            logger.debug(f"Best embedding similarity {best_similarity:.3f} below threshold {min_similarity}")
            return None

        best_template_id = self._pattern_template_ids[best_idx]
        best_pattern = self._pattern_texts[best_idx]

        # Scale similarity to confidence (0.70 → 0.70, 0.85 → 0.85, etc.)
        confidence = best_similarity

        logger.info(f"Embedding match found: template='{best_template_id}', similarity={best_similarity:.3f}, pattern='{best_pattern}'")

        return TemplateMatchResult(
            matched=True,
            template_id=best_template_id,
            confidence=confidence,
            reasoning=f"Embedding similarity: {best_similarity:.3f} with pattern '{best_pattern}'"
        )
|
|
|
|
|
|
# Singleton instance
|
|
_template_embedding_matcher: Optional[TemplateEmbeddingMatcher] = None
|
|
|
|
def get_template_embedding_matcher() -> TemplateEmbeddingMatcher:
    """Return the process-wide :class:`TemplateEmbeddingMatcher`.

    The matcher is created lazily on first access and reused thereafter, so
    the expensive pattern-embedding cache is shared across all callers.
    """
    global _template_embedding_matcher
    if _template_embedding_matcher is None:
        # First access: instantiate once and cache at module level.
        _template_embedding_matcher = TemplateEmbeddingMatcher()
    return _template_embedding_matcher
|
|
|
|
|
|
# =============================================================================
|
|
# RAG-ENHANCED TEMPLATE MATCHING (TIER 2.5)
|
|
# =============================================================================
|
|
|
|
class RAGEnhancedMatcher:
    """Context-enriched matching using similar Q&A examples from templates.

    This tier sits between embedding matching (Tier 2) and LLM fallback (Tier 3).
    It retrieves similar examples from the template YAML and uses voting to
    determine the best template match.

    Based on SPARQL-LLM (arXiv:2512.14277) and COT-SPARQL (SEMANTICS 2024)
    patterns for RAG-enhanced query generation.

    Architecture:
    1. Embed all Q&A examples from templates (cached)
    2. For incoming question, find top-k most similar examples
    3. Vote: If majority of examples agree on template, use it
    4. Return match with confidence based on vote agreement

    Implementation notes:
    - Singleton (see ``__new__``); the example index is built once per process.
    - NOTE(review): like TemplateEmbeddingMatcher, the index is never
      invalidated when templates change — confirm templates are static.
    """

    _instance = None
    # Parallel caches: row i of _example_embeddings corresponds to
    # _example_template_ids[i], _example_texts[i] and _example_slots[i].
    _example_embeddings: Optional[np.ndarray] = None
    _example_template_ids: Optional[list[str]] = None
    _example_texts: Optional[list[str]] = None
    _example_slots: Optional[list[dict]] = None

    def __new__(cls):
        """Singleton pattern - embeddings are expensive to compute."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def _ensure_examples_indexed(self, templates: dict[str, "TemplateDefinition"]) -> bool:
        """Index all Q&A examples from templates for retrieval.

        Idempotent: returns immediately when the index is already built.

        Returns:
            True if examples are indexed, False otherwise
        """
        if self._example_embeddings is not None:
            return True

        model = _get_embedding_model()
        if model is None:
            return False

        # Collect all examples from templates
        example_texts = []
        template_ids = []
        example_slots = []

        for template_id, template_def in templates.items():
            for example in template_def.examples:
                # Only examples that carry a "question" key are indexable.
                if "question" in example:
                    example_texts.append(example["question"])
                    template_ids.append(template_id)
                    example_slots.append(example.get("slots", {}))

        if not example_texts:
            logger.warning("No examples found for RAG-enhanced matching")
            return False

        # Compute embeddings for all examples
        logger.info(f"Indexing {len(example_texts)} Q&A examples for RAG-enhanced matching...")
        try:
            embeddings = model.encode(example_texts, convert_to_numpy=True, show_progress_bar=False)
            self._example_embeddings = embeddings
            self._example_template_ids = template_ids
            self._example_texts = example_texts
            self._example_slots = example_slots
            logger.info(f"Indexed {len(embeddings)} examples (dim={embeddings.shape[1]})")
            return True
        except Exception as e:
            logger.warning(f"Failed to index examples: {e}")
            return False

    def match(
        self,
        question: str,
        templates: dict[str, "TemplateDefinition"],
        k: int = 5,
        min_agreement: float = 0.6,
        min_similarity: float = 0.65
    ) -> Optional["TemplateMatchResult"]:
        """Find best template using RAG retrieval and voting.

        Args:
            question: Natural language question
            templates: Dictionary of template definitions
            k: Number of similar examples to retrieve
            min_agreement: Minimum fraction of examples that must agree (e.g., 0.6 = 3/5)
            min_similarity: Minimum similarity for retrieved examples

        Returns:
            TemplateMatchResult if voting succeeds, None otherwise
        """
        if not self._ensure_examples_indexed(templates):
            return None

        model = _get_embedding_model()
        if model is None:
            return None

        # Guard against None (should not happen after _ensure_examples_indexed)
        if (self._example_embeddings is None or
            self._example_template_ids is None or
            self._example_texts is None):
            return None

        # Compute question embedding
        try:
            question_embedding = model.encode([question], convert_to_numpy=True)[0]
        except Exception as e:
            logger.warning(f"Failed to compute question embedding: {e}")
            return None

        # Compute cosine similarities (unit-normalize, then dot product)
        question_norm = question_embedding / np.linalg.norm(question_embedding)
        example_norms = self._example_embeddings / np.linalg.norm(
            self._example_embeddings, axis=1, keepdims=True
        )
        similarities = np.dot(example_norms, question_norm)

        # Get top-k indices (argsort ascending, take last k, reverse to descending)
        top_k_indices = np.argsort(similarities)[-k:][::-1]

        # Filter by minimum similarity
        valid_indices = [
            i for i in top_k_indices
            if similarities[i] >= min_similarity
        ]

        if not valid_indices:
            logger.debug(f"RAG: No examples above similarity threshold {min_similarity}")
            return None

        # Vote on template
        from collections import Counter
        template_votes = Counter(
            self._example_template_ids[i] for i in valid_indices
        )

        top_template, vote_count = template_votes.most_common(1)[0]
        # Agreement is measured over retrieved-and-valid examples only.
        agreement = vote_count / len(valid_indices)

        if agreement < min_agreement:
            logger.debug(
                f"RAG: Low agreement {agreement:.2f} < {min_agreement} "
                f"(votes: {dict(template_votes)})"
            )
            return None

        # Calculate confidence based on agreement and average similarity
        avg_similarity = np.mean([similarities[i] for i in valid_indices])
        confidence = 0.70 + (agreement * 0.15) + (avg_similarity * 0.10)
        confidence = min(0.90, confidence)  # Cap at 0.90

        # Log retrieved examples for debugging
        retrieved_examples = [
            (self._example_texts[i], self._example_template_ids[i], similarities[i])
            for i in valid_indices[:3]
        ]
        logger.info(
            f"RAG match: template='{top_template}', agreement={agreement:.2f}, "
            f"confidence={confidence:.2f}, examples={retrieved_examples}"
        )

        return TemplateMatchResult(
            matched=True,
            template_id=top_template,
            confidence=float(confidence),
            reasoning=f"RAG: {vote_count}/{len(valid_indices)} examples vote for {top_template}"
        )
|
|
|
|
|
|
# Singleton instance
|
|
_rag_enhanced_matcher: Optional[RAGEnhancedMatcher] = None
|
|
|
|
def get_rag_enhanced_matcher() -> RAGEnhancedMatcher:
    """Return the process-wide :class:`RAGEnhancedMatcher`.

    Built lazily on first use and then reused, so the example-embedding index
    is shared by all callers in this process.
    """
    global _rag_enhanced_matcher
    if _rag_enhanced_matcher is None:
        # First access: instantiate once and cache at module level.
        _rag_enhanced_matcher = RAGEnhancedMatcher()
    return _rag_enhanced_matcher
|
|
|
|
|
|
# =============================================================================
|
|
# SPARQL VALIDATION (SPARQL-LLM Pattern)
|
|
# =============================================================================
|
|
|
|
class SPARQLValidationResult(BaseModel):
    """Result of SPARQL query validation.

    Produced by :class:`SPARQLValidator`. ``valid`` is True exactly when
    ``errors`` is empty; ``warnings`` and ``suggestions`` never block a query.
    """
    # True when no blocking errors were found.
    valid: bool
    # Blocking problems (unknown predicates with a close known match,
    # unknown classes, syntax issues such as mismatched braces).
    errors: list[str] = Field(default_factory=list)
    # Non-blocking observations (e.g. unrecognized predicates with no close match).
    warnings: list[str] = Field(default_factory=list)
    # "Did you mean ...?" hints paired with the errors above.
    suggestions: list[str] = Field(default_factory=list)
|
|
|
|
|
|
class SPARQLValidator:
    """Validates generated SPARQL against ontology schema.

    Based on SPARQL-LLM (arXiv:2512.14277) validation-correction pattern.

    Dynamically loads predicates and classes from OntologyLoader, which
    reads from LinkML schema files and validation rules JSON.

    Fallback hardcoded sets are kept for robustness when schema files
    are unavailable.
    """

    # Fallback predicates (used when OntologyLoader can't load from schema)
    # These are kept for robustness in case schema files are missing
    _FALLBACK_HC_PREDICATES = {
        "hc:institutionType", "hc:settlementName", "hc:subregionCode",
        "hc:countryCode", "hc:ghcid", "hc:isil", "hc:validFrom", "hc:validTo",
        "hc:changeType", "hc:changeReason", "hc:eventType", "hc:eventDate",
    }

    _FALLBACK_CLASSES = {
        "hcc:Custodian", "hc:class/Budget", "hc:class/FinancialStatement",
        "hc:OrganizationalChangeEvent",
    }

    # Standard external predicates from base ontologies (rarely change)
    VALID_EXTERNAL_PREDICATES = {
        # Schema.org predicates
        "schema:name", "schema:description", "schema:foundingDate",
        "schema:addressCountry", "schema:addressLocality", "schema:affiliation",
        "schema:about", "schema:archivedAt", "schema:areaServed",
        "schema:authenticationType", "schema:conditionsOfAccess",
        "schema:countryOfOrigin", "schema:dateAcquired", "schema:dateModified",

        # FOAF predicates
        "foaf:homepage", "foaf:name",

        # SKOS predicates
        "skos:prefLabel", "skos:altLabel", "skos:broader", "skos:notation",
        "skos:mappingRelation",

        # Dublin Core predicates
        "dcterms:identifier", "dcterms:accessRights", "dcterms:conformsTo",
        "dcterms:hasPart", "dcterms:language", "dcterms:type",

        # W3C Org predicates
        "org:memberOf", "org:hasSubOrganization", "org:subOrganizationOf", "org:hasSite",

        # PROV-O predicates
        "prov:type", "prov:wasInfluencedBy", "prov:influenced", "prov:wasAttributedTo",
        "prov:contributed", "prov:generatedAtTime", "prov:hadReason",

        # CIDOC-CRM predicates
        "crm:P1_is_identified_by", "crm:P2_has_type", "crm:P11_had_participant",
        "crm:P24i_changed_ownership_through", "crm:P82a_begin_of_the_begin",
        "crm:P81b_begin_of_the_end",

        # RiC-O predicates
        "rico:hasProvenance", "rico:hasRecordSetType", "rico:hasOrHadAllMembersWithRecordState",

        # DCAT predicates
        "dcat:endpointURL",

        # WGS84 predicates
        "wgs84:alt",

        # PiCo predicates
        "pico:hasAge",

        # PNV predicates (Person Name Vocabulary)
        "pnv:baseSurname",

        # Wikidata predicates
        "wikidata:P1196", "wdt:P31",

        # RDF/RDFS predicates
        "rdf:value", "rdfs:label",

        # SDO (Schema.org alternate prefix)
        "sdo:birthDate", "sdo:birthPlace",
    }

    def __init__(self):
        """Initialize SPARQLValidator with predicates from OntologyLoader.

        IMPORTANT: The RAG SPARQL queries use custom hc: prefixed predicates
        (hc:institutionType, hc:settlementName, etc.) while the LinkML schema
        uses semantic URIs from base ontologies (org:classification, schema:location).

        To support both, we ALWAYS include:
        1. Core RAG predicates (_FALLBACK_HC_PREDICATES) - used in actual queries
        2. Schema predicates from OntologyLoader - for validation flexibility
        3. External predicates (VALID_EXTERNAL_PREDICATES) - standard ontology URIs
        """
        # Load predicates dynamically from OntologyLoader
        ontology = get_ontology_loader()
        ontology.load()

        # Get predicates from LinkML schema (semantic URIs like org:classification)
        schema_predicates = ontology.get_predicates()
        schema_classes = ontology.get_classes()

        # Start with core RAG predicates (these are what queries actually use)
        # These are ALWAYS included regardless of schema loading
        hc_predicates = set(self._FALLBACK_HC_PREDICATES)

        # Add schema predicates if available (for flexibility)
        if schema_predicates:
            hc_predicates = hc_predicates | schema_predicates
            logger.info(f"SPARQLValidator: loaded {len(schema_predicates)} additional predicates from OntologyLoader")

        logger.info(f"SPARQLValidator: {len(hc_predicates)} total hc predicates (core + schema)")

        # Use schema classes if available, otherwise use fallback
        if schema_classes:
            self._all_classes = schema_classes | self._FALLBACK_CLASSES
            logger.info(f"SPARQLValidator: loaded {len(schema_classes)} classes from OntologyLoader")
        else:
            self._all_classes = self._FALLBACK_CLASSES
            logger.warning("SPARQLValidator: using fallback hardcoded classes")

        # Combine hc predicates with external predicates
        self._all_predicates = hc_predicates | self.VALID_EXTERNAL_PREDICATES

        # Expose as VALID_HC_PREDICATES for backward compatibility with tests
        self.VALID_HC_PREDICATES = hc_predicates

    def validate(self, sparql: str) -> SPARQLValidationResult:
        """Validate SPARQL query against schema.

        Checks prefixed predicates and classes against the known sets and
        performs cheap syntax sanity checks (balanced braces, SELECT/WHERE).
        Queries that use no hc:/hcc: terms are accepted without inspection.

        Args:
            sparql: SPARQL query string

        Returns:
            SPARQLValidationResult with errors and suggestions
        """
        errors: list[str] = []
        warnings: list[str] = []
        suggestions: list[str] = []

        # Skip validation for queries without our predicates
        if "hc:" not in sparql and "hcc:" not in sparql:
            return SPARQLValidationResult(valid=True)

        # Extract predicates used - expanded to capture all known namespaces.
        # BUG FIX: the hc: alternative carries a (?![\w/]) lookahead so that
        # path-style class IRIs like "hc:class/Budget" are NOT partially
        # matched as a bogus predicate "hc:class" (classes are validated
        # separately via class_pattern below).
        predicate_pattern = (
            r'(hc:\w+(?![\w/])|hcc:\w+|schema:\w+|foaf:\w+|skos:\w+|dcterms:\w+|'
            r'org:\w+|prov:\w+|crm:\w+|rico:\w+|dcat:\w+|wgs84:\w+|'
            r'pico:\w+|pnv:\w+|wikidata:\w+|wdt:\w+|rdf:\w+|rdfs:\w+|sdo:\w+)'
        )
        predicates = set(re.findall(predicate_pattern, sparql))

        for pred in predicates:
            if pred.startswith("hc:") or pred.startswith("hcc:"):
                # Classes may legitimately appear in predicate position checks
                # (e.g. after "a"), so membership in either set is accepted.
                if pred not in self._all_predicates and pred not in self._all_classes:
                    # Check for common typos
                    similar = self._find_similar(pred, self._all_predicates)
                    if similar:
                        errors.append(f"Unknown predicate: {pred}")
                        suggestions.append(f"Did you mean: {similar}?")
                    else:
                        warnings.append(f"Unrecognized predicate: {pred}")

        # Extract classes (a hcc:xxx)
        class_pattern = r'a\s+(hcc:\w+|hc:class/\w+)'
        classes = set(re.findall(class_pattern, sparql))

        for cls in classes:
            if cls not in self._all_classes:
                errors.append(f"Unknown class: {cls}")

        # Check for common SPARQL syntax issues
        if sparql.count("{") != sparql.count("}"):
            errors.append("Mismatched braces in query")

        if "SELECT" in sparql.upper() and "WHERE" not in sparql.upper():
            errors.append("SELECT query missing WHERE clause")

        return SPARQLValidationResult(
            valid=len(errors) == 0,
            errors=errors,
            warnings=warnings,
            suggestions=suggestions
        )

    def _find_similar(self, term: str, candidates: set[str], threshold: float = 0.7) -> Optional[str]:
        """Find similar term using fuzzy matching.

        Args:
            term: Unknown term to look up.
            candidates: Known terms to compare against.
            threshold: Minimum similarity ratio (0-1) for a suggestion.

        Returns:
            The closest candidate at or above the threshold, or None.
        """
        if not candidates:
            return None
        match = process.extractOne(
            term,
            list(candidates),
            scorer=fuzz.ratio,
            score_cutoff=int(threshold * 100)
        )
        return match[0] if match else None
|
|
|
|
|
|
# Global validator instance
|
|
_sparql_validator: Optional[SPARQLValidator] = None
|
|
|
|
def get_sparql_validator() -> SPARQLValidator:
    """Return the shared :class:`SPARQLValidator` instance.

    Constructed lazily on first use (construction loads the ontology schema),
    then reused for the lifetime of the process.
    """
    global _sparql_validator
    if _sparql_validator is None:
        # First access: build once and cache at module level.
        _sparql_validator = SPARQLValidator()
    return _sparql_validator
|
|
|
|
|
|
class TemplateClassifier(dspy.Module):
|
|
"""Classifies questions to match SPARQL templates."""
|
|
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.classify = dspy.ChainOfThought(TemplateClassifierSignature)
|
|
self._templates: Optional[dict[str, TemplateDefinition]] = None
|
|
|
|
    def _load_templates(self) -> dict[str, TemplateDefinition]:
        """Load template definitions from YAML.

        Cached: the YAML file is parsed at most once per instance; subsequent
        calls return the cached dict. Individual templates that fail to parse
        are skipped with a warning, so one bad template does not block the rest.

        Returns:
            Mapping of template_id -> TemplateDefinition (possibly empty when
            the YAML file is missing or unreadable).
        """
        if self._templates is not None:
            return self._templates

        self._templates = {}

        if TEMPLATES_PATH.exists():
            try:
                import yaml
                with open(TEMPLATES_PATH) as f:
                    data = yaml.safe_load(f)

                templates = data.get("templates", {})
                for template_id, template_data in templates.items():
                    try:
                        # Convert slots
                        slots = {}
                        for slot_name, slot_data in template_data.get("slots", {}).items():
                            # Non-dict slot entries are silently ignored.
                            if isinstance(slot_data, dict):
                                slot_type = slot_data.get("type", "string")
                                slots[slot_name] = SlotDefinition(
                                    # Unknown type strings degrade to STRING rather than raising.
                                    type=SlotType(slot_type) if slot_type in [e.value for e in SlotType] else SlotType.STRING,
                                    required=slot_data.get("required", True),
                                    default=slot_data.get("default"),
                                    examples=slot_data.get("examples", []),
                                    # Invalid fallback type names are dropped, not errors.
                                    fallback_types=[SlotType(t) for t in slot_data.get("fallback_types", []) if t in [e.value for e in SlotType]],
                                    valid_values=slot_data.get("valid_values", [])
                                )

                        self._templates[template_id] = TemplateDefinition(
                            id=template_id,
                            description=template_data.get("description", ""),
                            intent=template_data.get("intent", []),
                            question_patterns=template_data.get("question_patterns", []),
                            slots=slots,
                            sparql_template=template_data.get("sparql_template", ""),
                            # Optional variant templates (None when absent in YAML).
                            sparql_template_alt=template_data.get("sparql_template_alt"),
                            sparql_template_region=template_data.get("sparql_template_region"),
                            sparql_template_country=template_data.get("sparql_template_country"),
                            sparql_template_isil=template_data.get("sparql_template_isil"),
                            sparql_template_ghcid=template_data.get("sparql_template_ghcid"),
                            examples=template_data.get("examples", []),
                            # Response rendering configuration (template-driven)
                            response_modes=template_data.get("response_modes", ["prose"]),
                            ui_template=template_data.get("ui_template"),
                            # Database routing configuration (template-driven)
                            # Default to both databases for backward compatibility
                            databases=template_data.get("databases", ["oxigraph", "qdrant"]),
                        )
                    except Exception as e:
                        logger.warning(f"Failed to parse template {template_id}: {e}")

            except Exception as e:
                logger.error(f"Failed to load templates: {e}")

        return self._templates
|
|
|
|
def _pattern_to_regex(self, pattern: str) -> tuple[re.Pattern, list[str]]:
|
|
"""Convert a template pattern to a regex for matching.
|
|
|
|
Converts patterns like:
|
|
"Welke instellingen geven meer dan {amount} uit aan {budget_category}?"
|
|
To regex like:
|
|
"Welke instellingen geven meer dan (.+?) uit aan (.+?)\\?"
|
|
|
|
Args:
|
|
pattern: Template pattern with {slot_name} placeholders
|
|
|
|
Returns:
|
|
Tuple of (compiled regex pattern, list of slot names in order)
|
|
"""
|
|
# Extract slot names in order
|
|
slot_names = re.findall(r'\{([^}]+)\}', pattern)
|
|
|
|
# Escape regex special characters (except { and })
|
|
escaped = re.escape(pattern)
|
|
|
|
# Replace escaped braces with capture groups
|
|
# \{...\} becomes (.+?) for non-greedy capture
|
|
regex_str = re.sub(r'\\{[^}]+\\}', r'(.+?)', escaped)
|
|
|
|
# Compile with case-insensitive matching
|
|
return re.compile(regex_str, re.IGNORECASE), slot_names
|
|
|
|
def _validate_slot_value(self, slot_name: str, value: str, template_id: str) -> bool:
|
|
"""Validate a captured slot value against its expected type.
|
|
|
|
This is used to disambiguate between templates that have identical patterns
|
|
but different slot types (e.g., city vs region vs country).
|
|
|
|
Args:
|
|
slot_name: Name of the slot (e.g., "city", "region", "country")
|
|
value: Captured value to validate
|
|
template_id: Template ID for context
|
|
|
|
Returns:
|
|
True if the value is valid for the slot type, False otherwise
|
|
"""
|
|
resolver = get_synonym_resolver()
|
|
|
|
# Slot-specific validation
|
|
if slot_name in ("region", "subregion"):
|
|
# For region slots, check if the value is a known region
|
|
return resolver.is_region(value)
|
|
elif slot_name == "country":
|
|
# For country slots, check if the value is a known country
|
|
return resolver.is_country(value)
|
|
elif slot_name == "city":
|
|
# For city slots, check if the value is NOT a region AND NOT a country
|
|
# This helps disambiguate "Noord-Holland" (region), "Belgium" (country), and "Amsterdam" (city)
|
|
return not resolver.is_region(value) and not resolver.is_country(value)
|
|
|
|
# Default: accept any value
|
|
return True
|
|
|
|
    def _match_by_patterns(
        self,
        question: str,
        templates: dict[str, TemplateDefinition]
    ) -> Optional[TemplateMatchResult]:
        """Try to match question against template patterns using regex.

        This provides a fast, deterministic fallback before using LLM classification.
        Patterns are defined in the YAML template's `question_patterns` field.

        When multiple patterns match, prefers:
        1. Exact (regex fullmatch) over fuzzy matches
        2. Patterns with more literal text (more specific)
        3. Patterns with higher confidence

        Exact matches score 0.95 plus a small specificity bonus (capped at
        0.04); fuzzy matches score 0.70-0.90 scaled from string similarity.
        Only results with confidence >= 0.75 are returned.

        Args:
            question: The natural language question
            templates: Dictionary of template definitions

        Returns:
            TemplateMatchResult if high-confidence match found, None otherwise
        """
        all_matches: list[tuple[str, float, str, int, bool]] = []  # (template_id, confidence, pattern, literal_chars, is_exact)

        # Normalize question for matching
        question_normalized = question.strip()

        for template_id, template_def in templates.items():
            patterns = template_def.question_patterns
            if not patterns:
                continue

            for pattern in patterns:
                try:
                    regex, slot_names = self._pattern_to_regex(pattern)
                    match = regex.fullmatch(question_normalized)

                    # Calculate literal characters (non-slot text) in pattern
                    literal_text = re.sub(r'\{[^}]+\}', '', pattern)
                    literal_chars = len(literal_text.strip())

                    if match:
                        # Validate captured slot values
                        # (rejects e.g. a region name captured into a city slot)
                        captured_values = match.groups()
                        slots_valid = True
                        for slot_name, value in zip(slot_names, captured_values):
                            if not self._validate_slot_value(slot_name, value, template_id):
                                logger.debug(f"Slot validation failed: {slot_name}='{value}' for template {template_id}")
                                slots_valid = False
                                break

                        if not slots_valid:
                            # Skip this match - slot value doesn't match expected type
                            continue

                        # Full match = high confidence, but scaled by specificity
                        # More literal chars = more specific = higher confidence
                        base_confidence = 0.95
                        # Boost confidence slightly for more specific patterns
                        specificity_bonus = min(0.04, literal_chars / 200.0)
                        confidence = base_confidence + specificity_bonus

                        logger.debug(f"Pattern exact match: '{pattern}' -> {template_id} (literal_chars={literal_chars})")
                        all_matches.append((template_id, confidence, pattern, literal_chars, True))  # is_exact=True
                        continue

                    # Try partial/fuzzy matching with lower confidence
                    # Use rapidfuzz to compare pattern structure (with slots replaced)
                    pattern_normalized = re.sub(r'\{[^}]+\}', '___', pattern.lower())
                    question_lower = question_normalized.lower()

                    # Replace common numeric patterns with placeholder
                    # (so "100" in the question aligns with a '___' slot marker)
                    question_for_compare = re.sub(r'\d+(?:[.,]\d+)?', '___', question_lower)

                    # Calculate similarity
                    similarity = fuzz.ratio(pattern_normalized, question_for_compare) / 100.0

                    if similarity >= 0.75:
                        # Good fuzzy match
                        confidence = 0.70 + (similarity - 0.75) * 0.8  # Scale 0.75-1.0 to 0.70-0.90
                        logger.debug(f"Pattern fuzzy match: '{pattern}' -> {template_id} (sim={similarity:.2f}, conf={confidence:.2f})")
                        all_matches.append((template_id, confidence, pattern, literal_chars, False))  # is_exact=False

                except Exception as e:
                    # A malformed pattern must not abort matching of the others.
                    logger.warning(f"Pattern matching error for '{pattern}': {e}")
                    continue

        # Debug: Print match count
        logger.debug(f"Pattern matching found {len(all_matches)} matches for '{question[:50]}...'")

        if not all_matches:
            return None

        # Sort by: 1) is_exact descending (exact matches first), 2) literal_chars descending, 3) confidence descending
        all_matches.sort(key=lambda x: (x[4], x[3], x[1]), reverse=True)

        # Debug: show best match
        best_match = all_matches[0]
        logger.debug(f"Best match after sort: {best_match}")

        template_id, confidence, matched_pattern, literal_chars, is_exact = best_match
        logger.debug(f"Extracted: template_id={template_id}, confidence={confidence}, literal_chars={literal_chars}, is_exact={is_exact}")

        # Final gate: fuzzy matches near the 0.70 floor are rejected here.
        if confidence >= 0.75:
            logger.info(f"Pattern-based match found: template='{template_id}', confidence={confidence:.2f}, pattern='{matched_pattern}' (literal_chars={literal_chars})")
            return TemplateMatchResult(
                matched=True,
                template_id=template_id,
                confidence=confidence,
                reasoning=f"Pattern match: '{matched_pattern}'"
            )

        return None
|
|
|
|
def forward(self, question: str, language: str = "nl") -> TemplateMatchResult:
    """Classify question to find matching template.

    Runs a four-tier cascade, cheapest first, returning as soon as a tier
    produces a sufficiently confident match:

    1. Pattern tier  - deterministic regex patterns (confidence >= 0.75)
    2. Embedding tier - semantic similarity for paraphrases (>= 0.70)
    3. RAG tier      - retrieval + voting over Q&A examples (>= 0.70)
    4. LLM tier      - ChainOfThought classification fallback (>= 0.6)

    Each tier's duration and hit/miss outcome is reported via the optional
    ``_record_template_tier`` metrics hook when it is available.

    Args:
        question: Resolved natural language question
        language: Language code

    Returns:
        TemplateMatchResult with template ID and confidence, or
        matched=False if every tier misses (caller falls back to LLM
        SPARQL generation).
    """
    templates = self._load_templates()

    if not templates:
        # Without a template catalog there is nothing to match against.
        return TemplateMatchResult(
            matched=False,
            reasoning="No templates loaded"
        )

    # TIER 1: Pattern-based matching (fast, deterministic, exact regex)
    tier1_start = time.perf_counter()
    pattern_match = self._match_by_patterns(question, templates)
    tier1_duration = time.perf_counter() - tier1_start

    if pattern_match and pattern_match.confidence >= 0.75:
        logger.info(f"Using pattern-based match: {pattern_match.template_id} (confidence={pattern_match.confidence:.2f})")
        # Record tier 1 success
        if _record_template_tier:
            _record_template_tier(
                tier="pattern",
                matched=True,
                template_id=pattern_match.template_id,
                duration_seconds=tier1_duration,
            )
        return pattern_match
    else:
        # Record tier 1 miss
        if _record_template_tier:
            _record_template_tier(tier="pattern", matched=False, duration_seconds=tier1_duration)

    # TIER 2: Embedding-based matching (semantic similarity, handles paraphrases)
    tier2_start = time.perf_counter()
    embedding_matcher = get_template_embedding_matcher()
    embedding_match = embedding_matcher.match(question, templates, min_similarity=0.70)
    tier2_duration = time.perf_counter() - tier2_start

    if embedding_match and embedding_match.confidence >= 0.70:
        logger.info(f"Using embedding-based match: {embedding_match.template_id} (confidence={embedding_match.confidence:.2f})")
        # Record tier 2 success
        if _record_template_tier:
            _record_template_tier(
                tier="embedding",
                matched=True,
                template_id=embedding_match.template_id,
                duration_seconds=tier2_duration,
            )
        return embedding_match
    else:
        # Record tier 2 miss
        if _record_template_tier:
            _record_template_tier(tier="embedding", matched=False, duration_seconds=tier2_duration)

    # TIER 2.5: RAG-enhanced matching (retrieval + voting from Q&A examples)
    # Based on SPARQL-LLM (arXiv:2512.14277) and COT-SPARQL patterns
    tier2_5_start = time.perf_counter()
    rag_matcher = get_rag_enhanced_matcher()
    rag_match = rag_matcher.match(question, templates, k=5, min_agreement=0.6)
    tier2_5_duration = time.perf_counter() - tier2_5_start

    if rag_match and rag_match.confidence >= 0.70:
        logger.info(f"Using RAG-enhanced match: {rag_match.template_id} (confidence={rag_match.confidence:.2f})")
        # Record tier 2.5 success
        if _record_template_tier:
            _record_template_tier(
                tier="rag",
                matched=True,
                template_id=rag_match.template_id,
                duration_seconds=tier2_5_duration,
            )
        return rag_match
    else:
        # Record tier 2.5 miss
        if _record_template_tier:
            _record_template_tier(tier="rag", matched=False, duration_seconds=tier2_5_duration)

    # TIER 3: LLM classification (fallback for complex/novel queries)
    tier3_start = time.perf_counter()
    try:
        result = self.classify(
            question=question,
            language=language
        )
        tier3_duration = time.perf_counter() - tier3_start

        template_id = result.template_id
        confidence = result.confidence

        # Debug logging to see what LLM returned
        logger.info(f"Template classifier returned: template_id='{template_id}', confidence={confidence}, reasoning='{result.reasoning[:100]}...'")
        logger.debug(f"Available templates: {list(templates.keys())}")

        # Handle numeric IDs (LLM sometimes returns "4" instead of "count_institutions_by_type_location")
        # NOTE(review): this mapping must stay in sync with the template
        # catalog's presentation order — verify when templates change.
        numeric_to_template = {
            "1": "list_institutions_by_type_city",
            "2": "list_institutions_by_type_region",
            "3": "list_institutions_by_type_country",
            "4": "count_institutions_by_type_location",
            "5": "count_institutions_by_type",
            "6": "find_institution_by_name",
            "7": "list_all_institutions_in_city",
            "8": "find_institutions_by_founding_date",
            "9": "find_institution_by_identifier",
            "10": "compare_locations",
            "11": "find_custodians_by_budget_threshold",
        }
        if template_id in numeric_to_template:
            logger.info(f"Converting numeric template_id '{template_id}' to '{numeric_to_template[template_id]}'")
            template_id = numeric_to_template[template_id]

        # Validate template exists
        if template_id != "none" and template_id not in templates:
            # Try fuzzy match on template IDs (tolerates minor LLM typos,
            # e.g. missing underscores); below 70% similarity we give up.
            match = process.extractOne(
                template_id,
                list(templates.keys()),
                scorer=fuzz.ratio,
                score_cutoff=70
            )
            if match:
                template_id = match[0]
            else:
                template_id = "none"
                confidence = 0.0

        if template_id == "none" or confidence < 0.6:
            # LLM declined or was too unsure — report a miss so the caller
            # falls back to free-form SPARQL generation.
            # Record tier 3 miss
            if _record_template_tier:
                _record_template_tier(tier="llm", matched=False, duration_seconds=tier3_duration)
            return TemplateMatchResult(
                matched=False,
                template_id=None,
                confidence=confidence,
                reasoning=result.reasoning
            )

        # Record tier 3 success
        if _record_template_tier:
            _record_template_tier(
                tier="llm",
                matched=True,
                template_id=template_id,
                duration_seconds=tier3_duration,
            )
        return TemplateMatchResult(
            matched=True,
            template_id=template_id,
            confidence=confidence,
            reasoning=result.reasoning
        )

    except Exception as e:
        # LLM call itself failed (network, parsing, provider error):
        # degrade gracefully to "no match" rather than crashing the pipeline.
        tier3_duration = time.perf_counter() - tier3_start
        logger.warning(f"Template classification failed: {e}")
        # Record tier 3 error
        if _record_template_tier:
            _record_template_tier(tier="llm", matched=False, duration_seconds=tier3_duration)
        return TemplateMatchResult(
            matched=False,
            reasoning=f"Classification error: {e}"
        )
|
|
|
|
|
|
class SlotExtractor(dspy.Module):
    """Extracts slot values from questions with synonym resolution.

    Uses an LLM (ChainOfThought over ``SlotExtractorSignature``) to pull raw
    slot values out of the question, then normalizes slot names, merges with
    slots inherited from the conversation, and resolves each value through
    the synonym resolver based on its declared ``SlotType``.
    """

    def __init__(self):
        super().__init__()
        # LLM-backed extractor; emits a JSON object of slot name -> value.
        self.extract = dspy.ChainOfThought(SlotExtractorSignature)
        self.resolver = get_synonym_resolver()
        # Lazily-populated template catalog (shared shape with TemplateClassifier).
        self._templates: Optional[dict[str, TemplateDefinition]] = None

    def _get_template(self, template_id: str) -> Optional[TemplateDefinition]:
        """Get template definition by ID, loading the catalog on first use."""
        if self._templates is None:
            classifier = TemplateClassifier()
            self._templates = classifier._load_templates()
        return self._templates.get(template_id)

    def forward(
        self,
        question: str,
        template_id: str,
        inherited_slots: Optional[dict[str, str]] = None
    ) -> tuple[dict[str, str], Optional[str]]:
        """Extract slot values from question.

        Args:
            question: User's question
            template_id: ID of matched template
            inherited_slots: Slots inherited from conversation context

        Returns:
            Tuple of (resolved slots dict, variant string or None)
            Variant is 'region' or 'country' if location was resolved as such
        """
        template = self._get_template(template_id)
        if not template:
            # Unknown template: nothing to extract against, pass through context.
            return inherited_slots or {}, None

        # Get required slots (passed to the LLM so it knows what to look for)
        required_slots = [
            name for name, slot in template.slots.items()
            if slot.required
        ]

        # Track which variant to use based on resolved slot types
        detected_variant: Optional[str] = None

        try:
            result = self.extract(
                question=question,
                template_id=template_id,
                required_slots=", ".join(required_slots),
                inherited_slots=json.dumps(inherited_slots or {})
            )

            # Parse extracted slots; malformed LLM output degrades to empty.
            try:
                raw_slots = json.loads(result.slots_json)
            except (json.JSONDecodeError, TypeError):
                raw_slots = {}

            # Normalize slot names: LLM sometimes returns "city" when template expects "location"
            # Map common LLM slot names to expected template slot names
            slot_name_aliases = {
                "city": "location",
                "region": "location",
                "province": "location",
                "subregion": "location",
                "country": "location",  # Country is also a location variant
                "type": "institution_type",
                "inst_type": "institution_type",
            }

            normalized_slots = {}
            expected_slot_names = set(template.slots.keys())
            for key, value in raw_slots.items():
                if key in expected_slot_names:
                    # Exact match with the template's slot schema.
                    normalized_slots[key] = value
                elif key in slot_name_aliases and slot_name_aliases[key] in expected_slot_names:
                    expected_name = slot_name_aliases[key]
                    # First alias wins: don't overwrite an already-mapped value.
                    if expected_name not in normalized_slots:
                        normalized_slots[expected_name] = value
                        logger.debug(f"Normalized slot name {key!r} to {expected_name!r}")
                else:
                    # Unknown slot names are kept as-is (harmless extras).
                    normalized_slots[key] = value
            raw_slots = normalized_slots

            # Merge with inherited slots (extracted takes precedence)
            slots = {**(inherited_slots or {}), **raw_slots}

            # Resolve synonyms for each slot
            resolved_slots = {}
            for name, value in slots.items():
                if not value:
                    # Empty/falsy values are dropped from the output entirely.
                    continue

                slot_def = template.slots.get(name)
                if not slot_def:
                    # No schema for this slot: pass the value through unresolved.
                    resolved_slots[name] = value
                    continue

                # Resolve based on slot type, with fallback type handling.
                # resolver misses fall back to the raw value (``resolved or value``).
                if slot_def.type == SlotType.INSTITUTION_TYPE:
                    resolved = self.resolver.resolve_institution_type(value)
                    resolved_slots[name] = resolved or value
                elif slot_def.type == SlotType.SUBREGION:
                    resolved = self.resolver.resolve_subregion(value)
                    resolved_slots[name] = resolved or value
                    detected_variant = "region"
                elif slot_def.type == SlotType.COUNTRY:
                    resolved = self.resolver.resolve_country(value)
                    resolved_slots[name] = resolved or value
                    detected_variant = "country"
                elif slot_def.type == SlotType.CITY:
                    # Check fallback types: if value is a region/country, use that instead
                    fallback_types = slot_def.fallback_types

                    # Check if value is actually a region (subregion)
                    if SlotType.SUBREGION in fallback_types and self.resolver.is_region(value):
                        resolved = self.resolver.resolve_subregion(value)
                        resolved_slots[name] = resolved or value
                        detected_variant = "region"
                        logger.info(f"Location '{value}' resolved as region: {resolved_slots[name]} (variant=region)")
                    # Check if value is actually a country
                    elif SlotType.COUNTRY in fallback_types:
                        country_resolved = self.resolver.resolve_country(value)
                        if country_resolved:
                            resolved_slots[name] = country_resolved
                            detected_variant = "country"
                            logger.info(f"Location '{value}' resolved as country: {country_resolved} (variant=country)")
                        else:
                            # Default to city resolution
                            resolved_slots[name] = self.resolver.resolve_city(value)
                    else:
                        # No fallback types or value doesn't match fallback - use city
                        resolved_slots[name] = self.resolver.resolve_city(value)
                elif slot_def.type == SlotType.BUDGET_CATEGORY:
                    resolved = self.resolver.resolve_budget_category(value)
                    resolved_slots[name] = resolved or value
                else:
                    # Free-form slot types (names, dates, limits, ...) pass through.
                    resolved_slots[name] = value

            # Ensure all slot values are strings (TemplateMatchResult.slots expects dict[str, str])
            # LLM may return integers for limit/offset slots - convert them
            resolved_slots = {k: str(v) if v is not None else "" for k, v in resolved_slots.items()}

            return resolved_slots, detected_variant

        except Exception as e:
            # Extraction is best-effort: on any failure fall back to whatever
            # the conversation already provided, with no variant detected.
            logger.warning(f"Slot extraction failed: {e}")
            return inherited_slots or {}, None
|
|
|
|
|
|
class TemplateInstantiator:
    """Renders SPARQL queries from templates using Jinja2.

    Given a matched template ID, resolved slot values, and an optional
    variant name, produces the final SPARQL query text.
    """

    def __init__(self):
        # Templates are held as plain strings, so a BaseLoader environment suffices.
        self.env = Environment(loader=BaseLoader())
        # Lazily-loaded template catalog, populated on first lookup.
        self._templates: Optional[dict[str, TemplateDefinition]] = None

    def _get_template(self, template_id: str) -> Optional[TemplateDefinition]:
        """Return the template definition for *template_id*, loading the catalog lazily."""
        if self._templates is None:
            self._templates = TemplateClassifier()._load_templates()
        return self._templates.get(template_id)

    def render(
        self,
        template_id: str,
        slots: dict[str, str],
        variant: Optional[str] = None
    ) -> Optional[str]:
        """Render SPARQL query from template and slots.

        Args:
            template_id: Template to use
            slots: Resolved slot values
            variant: Optional variant (e.g., 'region', 'country', 'isil')

        Returns:
            Rendered SPARQL query or None if rendering fails
        """
        template_def = self._get_template(template_id)
        if template_def is None:
            logger.warning(f"Template not found: {template_id}")
            return None

        # Select template variant: prefer the variant-specific SPARQL body
        # when one is defined, otherwise fall back to the default template.
        variant_sources = {
            "region": template_def.sparql_template_region,
            "country": template_def.sparql_template_country,
            "isil": template_def.sparql_template_isil,
            "ghcid": template_def.sparql_template_ghcid,
            "alt": template_def.sparql_template_alt,
        }
        sparql_template = variant_sources.get(variant) or template_def.sparql_template

        if not sparql_template:
            logger.warning(f"No SPARQL template for {template_id} variant {variant}")
            return None

        try:
            # Build the Jinja2 context: prefixes plus all resolved slots.
            # Note: limit is NOT defaulted here - each template decides via Jinja2
            # whether to have a default limit or no limit at all.
            # "Show all X" queries should return ALL results, not just 10.
            context = {"prefixes": SPARQL_PREFIXES, **slots}
            # Only expose a limit when the slots explicitly carry one.
            if "limit" in slots and slots["limit"] is not None:
                context["limit"] = slots["limit"]

            rendered = self.env.from_string(sparql_template).render(**context)

            # Collapse blank-line runs left behind by unused Jinja2 branches.
            return re.sub(r'\n\s*\n', '\n', rendered.strip())

        except Exception as e:
            logger.error(f"Template rendering failed: {e}")
            return None
|
|
|
|
|
|
# =============================================================================
|
|
# MAIN PIPELINE
|
|
# =============================================================================
|
|
|
|
class TemplateSPARQLPipeline(dspy.Module):
    """Complete template-based SPARQL generation pipeline.

    Pipeline order (CRITICAL):
    1. ConversationContextResolver - Expand follow-ups FIRST
    2. FykeFilter - Filter irrelevant on RESOLVED question
    3. TemplateClassifier - Match to template (4 tiers: regex → embedding → RAG → LLM)
    4. SlotExtractor - Extract and resolve slots
    5. TemplateInstantiator - Render SPARQL
    6. SPARQLValidator - Validate against schema (SPARQL-LLM pattern)

    Falls back to LLM generation if no template matches.

    Based on SOTA patterns:
    - SPARQL-LLM (arXiv:2512.14277) - RAG + validation loop
    - COT-SPARQL (SEMANTICS 2024) - Context-enriched matching
    - KGQuest (arXiv:2511.11258) - Deterministic templates + LLM refinement
    """

    def __init__(self, validate_sparql: bool = True):
        """Wire up all pipeline stages.

        Args:
            validate_sparql: When True, rendered queries are checked against
                the schema validator (errors are logged, not fatal).
        """
        super().__init__()
        self.context_resolver = ConversationContextResolver()
        self.fyke_filter = FykeFilter()
        self.template_classifier = TemplateClassifier()
        self.slot_extractor = SlotExtractor()
        self.instantiator = TemplateInstantiator()
        # Validator is only constructed when validation is enabled.
        self.validator = get_sparql_validator() if validate_sparql else None
        self.validate_sparql = validate_sparql

    def forward(
        self,
        question: str,
        conversation_state: Optional[ConversationState] = None,
        language: str = "nl"
    ) -> TemplateMatchResult:
        """Process question through complete pipeline.

        Args:
            question: User's question (may be elliptical follow-up)
            conversation_state: Conversation history and state; when given it
                is mutated in place (a user turn is appended on success)
            language: Language code

        Returns:
            TemplateMatchResult with SPARQL query if successful;
            matched=False results signal either "out of scope" or
            "fall back to LLM generation" (see reasoning field)
        """
        # Step 1: Resolve conversation context FIRST
        # ("En in Enschede?" → "Welke archieven zijn er in Enschede?")
        resolved = self.context_resolver.forward(
            question=question,
            conversation_state=conversation_state
        )

        logger.info(f"Resolved question: '{question}' → '{resolved.resolved}'")

        # Step 2: Fyke filter on RESOLVED question
        # (must run after resolution so elliptical follow-ups aren't rejected)
        fyke_result = self.fyke_filter.forward(
            resolved_question=resolved.resolved,
            conversation_topic="heritage institutions",
            language=language
        )

        if not fyke_result.is_relevant:
            logger.info(f"Question filtered by Fyke: {fyke_result.reasoning}")
            return TemplateMatchResult(
                matched=False,
                reasoning=f"Out of scope: {fyke_result.reasoning}",
                sparql=None  # Will trigger standard response
            )

        # Step 3: Classify to template (tiered: regex → embedding → RAG → LLM)
        match_result = self.template_classifier.forward(
            question=resolved.resolved,
            language=language
        )

        if not match_result.matched:
            logger.info(f"No template match: {match_result.reasoning}")
            return match_result  # Falls back to LLM generation

        # Guard: a matched result should always carry a template ID.
        template_id = match_result.template_id
        if template_id is None:
            return TemplateMatchResult(
                matched=False,
                reasoning="No template ID from classifier"
            )

        # Step 4: Extract slots (returns tuple of slots and detected variant)
        slots, detected_variant = self.slot_extractor.forward(
            question=resolved.resolved,
            template_id=template_id,
            inherited_slots=resolved.inherited_slots
        )

        logger.info(f"Extracted slots: {slots}, detected_variant: {detected_variant}")

        # Step 5: Render SPARQL with appropriate variant
        sparql = self.instantiator.render(
            template_id=template_id,
            slots=slots,
            variant=detected_variant
        )

        if not sparql:
            logger.warning(f"Failed to render template {match_result.template_id}")
            return TemplateMatchResult(
                matched=False,
                template_id=match_result.template_id,
                reasoning="Template rendering failed"
            )

        # Step 6: Validate SPARQL against schema (SPARQL-LLM pattern)
        if self.validate_sparql and self.validator:
            validation = self.validator.validate(sparql)
            if not validation.valid:
                logger.warning(
                    f"SPARQL validation errors: {validation.errors}, "
                    f"suggestions: {validation.suggestions}"
                )
                # Log but don't fail - errors may be false positives
                # In future: could use LLM to correct errors
            elif validation.warnings:
                logger.info(f"SPARQL validation warnings: {validation.warnings}")

        # Update conversation state if provided (side effect: appends a turn
        # so later follow-ups can inherit this template and its slots)
        if conversation_state:
            conversation_state.add_turn(ConversationTurn(
                role="user",
                content=question,
                resolved_question=resolved.resolved,
                template_id=match_result.template_id,
                slots=slots
            ))

        # Get response modes from template definition (template-driven, not hardcoded)
        template_def = self.instantiator._get_template(template_id)
        response_modes = template_def.response_modes if template_def else ["prose"]
        ui_template = template_def.ui_template if template_def else None
        # Default to both databases for backward compatibility with
        # templates that don't declare a 'databases' field.
        databases = template_def.databases if template_def else ["oxigraph", "qdrant"]

        return TemplateMatchResult(
            matched=True,
            template_id=match_result.template_id,
            confidence=match_result.confidence,
            slots=slots,
            sparql=sparql,
            reasoning=match_result.reasoning,
            response_modes=response_modes,
            ui_template=ui_template,
            databases=databases
        )
|
|
|
|
|
|
# =============================================================================
|
|
# FACTORY FUNCTION
|
|
# =============================================================================
|
|
|
|
# Module-level singleton backing get_template_pipeline().
_template_pipeline: Optional[TemplateSPARQLPipeline] = None


def get_template_pipeline() -> TemplateSPARQLPipeline:
    """Get or create template SPARQL pipeline instance.

    The instance is cached at module level so repeated calls reuse the same
    pipeline (construction wires up several DSPy modules and a validator,
    which is not free). Previously a new pipeline was built on every call,
    contradicting this function's "get or create" contract and the other
    ``get_*`` factory helpers in this module.

    Returns:
        The shared TemplateSPARQLPipeline instance.
    """
    global _template_pipeline
    if _template_pipeline is None:
        _template_pipeline = TemplateSPARQLPipeline()
    return _template_pipeline
|