glam/src/glam_extractor/api/dspy_sparql.py

516 lines
18 KiB
Python

"""
DSPy SPARQL Generation Module
Uses DSPy to generate SPARQL queries from natural language questions
about heritage custodian institutions.
Optionally uses Qdrant vector database for RAG-enhanced query generation.
"""
import logging
from typing import Any
import dspy
from .config import get_settings
logger = logging.getLogger(__name__)
# The retriever is created lazily so that importing this module does not fail
# when Qdrant is not configured. `_retriever` acts as a three-state cache:
#   None  -> initialization has not been attempted yet
#   False -> Qdrant is disabled or initialization failed (never retried)
#   other -> the live retriever instance
_retriever: Any = None


def get_retriever() -> Any:
    """Return the shared Qdrant retriever, creating it on first use.

    Returns:
        The cached retriever instance, or None when Qdrant is disabled in
        the settings or a previous/current initialization attempt failed.
    """
    global _retriever

    # Already resolved (either a live instance or the False marker).
    if _retriever is not None:
        return _retriever or None

    config = get_settings()
    if not config.qdrant_enabled:
        _retriever = False
    else:
        try:
            # Imported here so the Qdrant dependency is only needed when enabled.
            from .qdrant_retriever import HeritageCustodianRetriever

            _retriever = HeritageCustodianRetriever(
                host=config.qdrant_host,
                port=config.qdrant_port,
                embedding_model=config.embedding_model,
                embedding_dim=config.embedding_dim,
                api_key=config.openai_api_key,
            )
            logger.info("Qdrant retriever initialized")
        except Exception as e:
            logger.warning(f"Failed to initialize Qdrant retriever: {e}")
            _retriever = False  # Mark as failed, don't retry

    return _retriever or None
# DSPy Signature for SPARQL generation
# NOTE(review): DSPy is expected to use the class docstring below as the task
# instructions sent to the LM, so treat it as prompt text rather than plain
# documentation - confirm against the installed DSPy version before rewording.
class QuestionToSPARQL(dspy.Signature):
    """Generate a SPARQL query from a natural language question about heritage institutions.
    Use the Heritage Custodian Ontology with SPARQL endpoint at bronhouder.nl/sparql.
    Key class: hc:Custodian (https://nde.nl/ontology/hc/class/Custodian)
    Key properties (use hcp: prefix = https://nde.nl/ontology/hc/):
    - hcp:institutionType - Single letter: "M"=Museum, "L"=Library, "A"=Archive, "G"=Gallery, "S"=Society
    - hcp:ghcid, hcp:isil, hcp:wikidataId - Identifiers
    - skos:prefLabel - Institution name
    - schema:addressCountry - Country as Wikidata URI (e.g., wd:Q55 = Netherlands)
    For Dutch provinces, filter on URI pattern (e.g., FILTER(CONTAINS(STR(?s), "NL-NH")) for Noord-Holland).
    """
    # --- Inputs ---
    # The user's question. Callers in this module may prepend extra prompt text.
    question: str = dspy.InputField(desc="Natural language question about heritage institutions")
    # Language code of the question; the module's callers default to "nl".
    language: str = dspy.InputField(desc="Language of the question (nl or en)")
    # Free-form context string; declared with an empty-string default.
    context: str = dspy.InputField(
        desc="Previous conversation context (if any)", default=""
    )
    # --- Outputs ---
    # The generated query text.
    sparql: str = dspy.OutputField(desc="Valid SPARQL query to answer the question")
    # Short human-readable description of the query, in the user's language.
    explanation: str = dspy.OutputField(
        desc="Brief explanation of what the query does in the user's language"
    )
class SPARQLGenerator(dspy.Module):
    """Plain (non-RAG) DSPy module that turns a question into a SPARQL query."""

    def __init__(self) -> None:
        super().__init__()
        # Chain-of-thought predictor driven by the QuestionToSPARQL signature.
        self.generate = dspy.ChainOfThought(QuestionToSPARQL)

    def forward(
        self, question: str, language: str = "nl", context: str = ""
    ) -> dspy.Prediction:
        """Run the predictor on a single question.

        Args:
            question: The user's question in natural language
            language: Language code ('nl' for Dutch, 'en' for English)
            context: Previous conversation context for follow-up questions

        Returns:
            DSPy Prediction with sparql and explanation fields
        """
        signature_inputs = {
            "question": question,
            "language": language,
            "context": context,
        }
        return self.generate(**signature_inputs)
class RAGSPARQLGenerator(dspy.Module):
    """DSPy module that enriches SPARQL generation with retrieved passages.

    Before calling the predictor, the configured retriever is queried and the
    returned passages are prepended to the conversation context as a numbered
    "Relevant Institutions" section.
    """

    def __init__(self, retriever=None, k: int = 5) -> None:
        super().__init__()
        # Fall back to the module-wide retriever when none is supplied.
        self.retriever = retriever or get_retriever()
        self.k = k
        self.generate = dspy.ChainOfThought(QuestionToSPARQL)

    def forward(
        self, question: str, language: str = "nl", context: str = ""
    ) -> dspy.Prediction:
        """Retrieve supporting passages, then generate the SPARQL query.

        Retrieval is best-effort: if the retriever raises, a warning is logged
        and generation continues with whatever context was built so far.

        Args:
            question: The user's question in natural language (also used as
                the retrieval query)
            language: Language code ('nl' for Dutch, 'en' for English)
            context: Previous conversation context for follow-up questions

        Returns:
            DSPy Prediction with sparql and explanation fields
        """
        # Collect the RAG section piece by piece; joined after the try so a
        # failure part-way through still keeps the entries gathered so far.
        rag_pieces: list[str] = []
        if self.retriever:
            try:
                hits = self.retriever(question, k=self.k)
                if hits:
                    rag_pieces.append("\n\n## Relevant Institutions:\n")
                    for rank, hit in enumerate(hits, 1):
                        rag_pieces.append(f"{rank}. {hit}\n")
            except Exception as e:
                logger.warning(f"RAG retrieval failed: {e}")
        rag_section = "".join(rag_pieces)

        # Merge the retrieved section with the caller-provided context.
        if not rag_section:
            merged_context = context
        elif context:
            merged_context = rag_section + "\n\n" + context
        else:
            merged_context = rag_section

        return self.generate(
            question=question, language=language, context=merged_context
        )
# Ontology context to inject into prompts: generate_sparql() prepends this
# text to every question before it is handed to the generator.
# IMPORTANT: This must match the ACTUAL RDF data in Oxigraph at bronhouder.nl/sparql
# NOTE: everything inside the literal is runtime prompt text, not documentation;
# changing wording, prefixes or example queries changes what the LM is told.
ONTOLOGY_CONTEXT = """
# Heritage Custodian Ontology - SPARQL Query Guidelines
## CRITICAL: Prefixes (USE THESE EXACTLY!)
PREFIX hc: <https://nde.nl/ontology/hc/class/> # For classes (Custodian)
PREFIX hcp: <https://nde.nl/ontology/hc/> # For properties (institutionType, ghcid, isil, etc.)
PREFIX schema: <http://schema.org/>
PREFIX skos: <http://www.w3.org/2004/02/skos/core#>
PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
PREFIX wd: <http://www.wikidata.org/entity/>
PREFIX wdt: <http://www.wikidata.org/prop/direct/>
PREFIX foaf: <http://xmlns.com/foaf/0.1/>
PREFIX dct: <http://purl.org/dc/terms/>
## Main Class (USE THIS!)
- hc:Custodian - The ONLY main class for heritage institutions
Full URI: <https://nde.nl/ontology/hc/class/Custodian>
## Institution Type Property (CRITICAL!)
- hcp:institutionType - Single-letter type codes (NOT full words!)
- "M" = Museum
- "L" = Library
- "A" = Archive
- "G" = Gallery
- "S" = Collecting Society
- "B" = Botanical/Zoo
- "R" = Research Center
- "E" = Education Provider
- "O" = Official Institution
- "D" = Digital Platform
- "N" = NGO
- "H" = Holy Site
- "F" = Feature
- "I" = Intangible Heritage
- "C" = Corporation
- "U" = Unknown
## CRITICAL: Province/Region Filtering
The institution URI contains encoded location information!
URI pattern: https://nde.nl/ontology/hc/{COUNTRY}-{PROVINCE}-{CITY}-{TYPE}-{NAME}
Example URIs:
- https://nde.nl/ontology/hc/NL-NH-AMS-M-RIJKS (Rijksmuseum in Amsterdam, Noord-Holland)
- https://nde.nl/ontology/hc/NL-ZH-RTD-M-BOIJM (Boijmans in Rotterdam, Zuid-Holland)
**To filter by Dutch province, use FILTER on the URI string:**
- FILTER(CONTAINS(STR(?s), "NL-NH")) = Noord-Holland
- FILTER(CONTAINS(STR(?s), "NL-ZH")) = Zuid-Holland
- FILTER(CONTAINS(STR(?s), "NL-NB")) = Noord-Brabant
- FILTER(CONTAINS(STR(?s), "NL-GE")) = Gelderland
- FILTER(CONTAINS(STR(?s), "NL-UT")) = Utrecht
- FILTER(CONTAINS(STR(?s), "NL-OV")) = Overijssel
- FILTER(CONTAINS(STR(?s), "NL-LI")) = Limburg
- FILTER(CONTAINS(STR(?s), "NL-FR")) = Friesland (Fryslân)
- FILTER(CONTAINS(STR(?s), "NL-GR")) = Groningen
- FILTER(CONTAINS(STR(?s), "NL-DR")) = Drenthe
- FILTER(CONTAINS(STR(?s), "NL-FL")) = Flevoland
- FILTER(CONTAINS(STR(?s), "NL-ZE")) = Zeeland
## Key Properties (use hcp: prefix)
- hcp:institutionType - Single-letter type code
- hcp:ghcid - Global Heritage Custodian ID
- hcp:ghcidUUID - GHCID as UUID
- hcp:isil - ISIL code (e.g., "NL-AmRMA")
- hcp:wikidataId - Wikidata Q-number (e.g., "Q190804")
- hcp:viaf - VIAF ID
- hcp:gnd - GND ID
- hcp:foundingYear - Founding year
## Name Properties
- skos:prefLabel - Primary name (PREFERRED - use this!)
- schema:name - Institution name
- rdfs:label - Alternative label
- skos:altLabel - Alternative names
- foaf:name - FOAF name
## Description & URL Properties
- schema:description - Description text
- dct:description - DC Terms description
- schema:url - Website URL
- foaf:homepage - Homepage URL
## Location Properties
- schema:addressCountry - Country as Wikidata URI (e.g., wd:Q55 = Netherlands)
- schema:location - Links to Place
- schema:containedInPlace - Parent region
- wdt:P17 - Country (Wikidata property)
- wdt:P131 - Located in administrative entity
## Country Codes (as Wikidata URIs) - Top countries in dataset:
- wd:Q213 = Czech Republic (6,481 institutions)
- wd:Q17 = Japan (4,346 institutions)
- wd:Q55 = Netherlands (1,123 institutions)
- wd:Q31 = Belgium (97 institutions)
- wd:Q40 = Austria (86 institutions)
- wd:Q298 = Chile (73 institutions)
- wd:Q96 = Mexico (65 institutions)
- wd:Q155 = Brazil (47 institutions)
- wd:Q183 = Germany (40 institutions)
- wd:Q145 = United Kingdom (31 institutions)
- wd:Q142 = France (29 institutions)
- wd:Q30 = United States (22 institutions)
## EXAMPLE QUERIES (COPY THESE PATTERNS!)
# Count all museums in the Netherlands
SELECT (COUNT(DISTINCT ?s) as ?count) WHERE {
?s a hc:Custodian ;
hcp:institutionType "M" ;
schema:addressCountry wd:Q55 .
}
# Count museums in Noord-Holland (use URI filter!)
SELECT (COUNT(?s) as ?count) WHERE {
?s a hc:Custodian ;
hcp:institutionType "M" .
FILTER(CONTAINS(STR(?s), "NL-NH"))
}
# List museums in Amsterdam with names
SELECT ?museum ?name WHERE {
?museum a hc:Custodian ;
hcp:institutionType "M" ;
skos:prefLabel ?name .
FILTER(CONTAINS(STR(?museum), "NL-NH-AMS"))
}
# Count institutions by type
SELECT ?type (COUNT(?s) as ?count) WHERE {
?s a hc:Custodian ;
hcp:institutionType ?type .
} GROUP BY ?type ORDER BY DESC(?count)
# Find all archives in the Netherlands
SELECT ?archive ?name WHERE {
?archive a hc:Custodian ;
hcp:institutionType "A" ;
skos:prefLabel ?name ;
schema:addressCountry wd:Q55 .
} ORDER BY ?name
# Find institution by ISIL code
SELECT ?institution ?name WHERE {
?institution a hc:Custodian ;
hcp:isil "NL-AmRMA" ;
skos:prefLabel ?name .
}
# Find institution by Wikidata ID
SELECT ?institution ?name WHERE {
?institution a hc:Custodian ;
hcp:wikidataId "Q190804" ;
skos:prefLabel ?name .
}
# List all Dutch libraries
SELECT ?library ?name WHERE {
?library a hc:Custodian ;
hcp:institutionType "L" ;
skos:prefLabel ?name ;
schema:addressCountry wd:Q55 .
} ORDER BY ?name
# Count institutions per country
SELECT ?country (COUNT(?s) as ?count) WHERE {
?s a hc:Custodian ;
schema:addressCountry ?country .
} GROUP BY ?country ORDER BY DESC(?count)
## COMMON MISTAKES TO AVOID:
1. DO NOT use crm:E39_Actor - use hc:Custodian
2. DO NOT use hc:institutionType - use hcp:institutionType
3. DO NOT use full type names like "Museum" - use "M"
4. DO NOT use schema:addressLocality for provinces - use FILTER on URI
5. DO NOT forget wd: prefix for Wikidata country codes
"""
def configure_dspy(
provider: str = "anthropic",
model: str = "claude-sonnet-4-20250514",
api_key: str | None = None,
) -> None:
"""Configure DSPy with the specified LLM provider.
Args:
provider: LLM provider ('anthropic', 'openai', or 'zai')
model: Model name to use
api_key: API key for the provider
"""
if provider == "anthropic":
lm = dspy.LM(
model=f"anthropic/{model}",
api_key=api_key,
max_tokens=4096,
)
elif provider == "openai":
lm = dspy.LM(
model=f"openai/{model}",
api_key=api_key,
max_tokens=4096,
)
elif provider == "zai":
# Z.AI Coding Plan uses OpenAI-compatible API with GLM models
# Endpoint: https://api.z.ai/api/coding/paas/v4/chat/completions
lm = dspy.LM(
model=f"openai/{model}", # GLM models use OpenAI-compatible format
api_key=api_key,
api_base="https://api.z.ai/api/coding/paas/v4",
max_tokens=4096,
)
else:
raise ValueError(f"Unknown provider: {provider}")
dspy.configure(lm=lm)
logger.info(f"Configured DSPy with {provider}/{model}")
def _lint_result_to_dict(lint_result: Any) -> dict[str, Any] | None:
    """Convert a linter result object into a plain dict (None when linting was skipped)."""
    if lint_result is None:
        return None
    return {
        "valid": lint_result.valid,
        "error_count": lint_result.error_count,
        "warning_count": lint_result.warning_count,
        "issues": [
            {
                "severity": issue.severity.value,
                "code": issue.code,
                "message": issue.message,
                "suggestion": issue.suggestion,
            }
            for issue in lint_result.issues
        ],
    }


def generate_sparql(
    question: str,
    language: str = "nl",
    context: list[dict[str, Any]] | None = None,
    use_rag: bool = True,
    validate: bool = True,
    max_retries: int = 2,
    k: int = 5,
) -> dict[str, Any]:
    """Generate a SPARQL query from a natural language question.

    The question is prefixed with ONTOLOGY_CONTEXT before it is sent to the
    generator. When validation is enabled and the linter reports the query as
    invalid, generation is retried with the linter feedback appended, up to
    ``max_retries`` times; the last attempt is returned even if still invalid.

    Args:
        question: The user's question
        language: Language code ('nl' or 'en')
        context: Previous conversation messages (only the last 5 are used)
        use_rag: Whether to use RAG-enhanced generation (default: True)
        validate: Whether to validate with SHACL-based linter (default: True)
        max_retries: Maximum retries if validation fails (default: 2)
        k: Number of passages the RAG generator retrieves (default: 5)

    Returns:
        Dict with 'sparql', 'explanation', 'rag_used', 'lint_result'
        (None when ``validate`` is False) and 'retries' keys

    Raises:
        RuntimeError: If generation, auto-correction or linting raises, or if
            no attempt was made (negative ``max_retries``).
    """
    from .sparql_linter import lint_sparql, get_lint_context_for_llm, auto_correct_sparql

    # Choose generator based on RAG availability
    retriever = get_retriever() if use_rag else None
    if retriever:

        def search_user_question(_prompt: str, k: int = k) -> Any:
            # RAGSPARQLGenerator uses the `question` it receives as the search
            # query, but below that field carries the whole ontology prompt
            # (plus lint feedback on retries). Search on the user's actual
            # question instead so the vector lookup stays on-topic.
            return retriever(question, k=k)

        generator = RAGSPARQLGenerator(retriever=search_user_question, k=k)
        rag_used = True
    else:
        generator = SPARQLGenerator()
        rag_used = False

    # Build context string from conversation history
    context_str = ""
    if context:
        context_parts = []
        for msg in context[-5:]:  # Last 5 messages for context
            role = msg.get("role", "user")
            # Treat a missing or None content as empty text.
            content = msg.get("content") or ""
            if msg.get("sparql"):
                content = f"{content}\n[Generated SPARQL: {msg['sparql']}]"
            context_parts.append(f"{role}: {content}")
        context_str = "\n".join(context_parts)

    # Add ontology context to the question
    enhanced_question = f"{ONTOLOGY_CONTEXT}\n\nQuestion: {question}"

    retries = 0
    lint_feedback = ""
    while retries <= max_retries:
        try:
            # Include lint feedback if this is a retry
            retry_question = enhanced_question
            if lint_feedback:
                retry_question = f"{enhanced_question}\n\n{lint_feedback}"

            result = generator(
                question=retry_question,
                language=language,
                context=context_str,
            )
            sparql = result.sparql.strip()
            explanation = result.explanation.strip()

            # Auto-correct common errors FIRST (fast, <1ms)
            corrected_sparql, was_corrected = auto_correct_sparql(sparql)
            if was_corrected:
                logger.info("SPARQL auto-corrected by linter")
                sparql = corrected_sparql

            # Validate with SHACL-based linter if enabled
            lint_result = None
            if validate:
                lint_result = lint_sparql(sparql)
                # If there are errors and we haven't exhausted retries, try again
                if not lint_result.valid and retries < max_retries:
                    lint_feedback = get_lint_context_for_llm(lint_result)
                    logger.warning(
                        "SPARQL validation failed (attempt %d), retrying...",
                        retries + 1,
                    )
                    retries += 1
                    continue
                # Log lint issues even if we proceed
                if lint_result.issues:
                    logger.info(
                        "SPARQL lint: %s errors, %s warnings",
                        lint_result.error_count,
                        lint_result.warning_count,
                    )

            return {
                "sparql": sparql,
                "explanation": explanation,
                "rag_used": rag_used,
                "lint_result": _lint_result_to_dict(lint_result),
                "retries": retries,
            }
        except Exception as e:
            logger.exception("Error generating SPARQL")
            raise RuntimeError(f"Failed to generate SPARQL: {e}") from e

    # Only reachable when the loop body never ran (negative max_retries).
    raise RuntimeError("Failed to generate valid SPARQL after retries")


def generate_sparql_with_rag(
    question: str,
    language: str = "nl",
    context: list[dict[str, Any]] | None = None,
    k: int = 5,
) -> dict[str, Any]:
    """Generate a SPARQL query using RAG-enhanced generation.

    This function always attempts to use RAG. Falls back to standard
    generation if Qdrant is unavailable. The passages are fetched here in a
    separate call, using the same query (the raw question) and the same ``k``
    that generate_sparql() hands to the RAG generator.

    Args:
        question: The user's question
        language: Language code ('nl' or 'en')
        context: Previous conversation messages
        k: Number of RAG results to retrieve

    Returns:
        The dict returned by generate_sparql() plus a 'retrieved_passages'
        key (empty list when no retriever is available or retrieval failed)
    """
    retriever = get_retriever()
    retrieved_passages = []
    if retriever:
        try:
            # Get passages for transparency
            retrieved_passages = retriever(question, k=k)
        except Exception as e:
            logger.warning("Failed to retrieve passages: %s", e)

    result = generate_sparql(
        question, language, context, use_rag=bool(retriever), k=k
    )
    result["retrieved_passages"] = retrieved_passages
    return result