346 lines
11 KiB
Python
346 lines
11 KiB
Python
"""
DSPy SPARQL Generation Module

Uses DSPy to generate SPARQL queries from natural language questions
about heritage custodian institutions.

Optionally uses a Qdrant vector database for RAG-enhanced query generation.
"""
|
|
|
|
import logging
|
|
from typing import Any
|
|
|
|
import dspy
|
|
|
|
from .config import get_settings
|
|
|
|
logger = logging.getLogger(__name__)

# Lazy-load the retriever so that importing this module never requires Qdrant.
# The module-level cache has three states:
#   None  -> not initialised yet
#   False -> Qdrant disabled in settings, or initialisation failed (not retried)
#   other -> the live retriever instance
_retriever: Any = None


def get_retriever() -> Any:
    """Get or create the shared Qdrant retriever instance.

    The retriever is built on first call from the application settings and
    cached at module level. When Qdrant is disabled, or construction raises,
    the failure is cached (as ``False``) and never retried.

    Returns:
        The ``HeritageCustodianRetriever`` instance, or ``None`` when no
        retriever is available.
    """
    global _retriever

    if _retriever is None:
        settings = get_settings()
        if not settings.qdrant_enabled:
            _retriever = False
        else:
            try:
                from .qdrant_retriever import HeritageCustodianRetriever

                _retriever = HeritageCustodianRetriever(
                    host=settings.qdrant_host,
                    port=settings.qdrant_port,
                    embedding_model=settings.embedding_model,
                    embedding_dim=settings.embedding_dim,
                    api_key=settings.openai_api_key,
                )
                logger.info("Qdrant retriever initialized")
            except Exception as e:
                # RAG is optional: log and fall back to plain generation.
                logger.warning("Failed to initialize Qdrant retriever: %s", e)
                _retriever = False  # Mark as failed, don't retry

    # Compare against the sentinel explicitly: a retriever object that happens
    # to be falsy (e.g. defines __len__ returning 0) is still a valid retriever.
    return None if _retriever is False else _retriever
|
|
|
|
|
|
# DSPy Signature for SPARQL generation.
# NOTE: DSPy signatures use the class docstring as the task instructions in the
# prompt, and field `desc` strings as field descriptions, so editing either
# changes what the model sees.
class QuestionToSPARQL(dspy.Signature):
    """Generate a SPARQL query from a natural language question about heritage institutions.

    The query should use the Heritage Custodian Ontology with the following key classes:
    - glam:HeritageCustodian - Heritage institutions (museums, libraries, archives, etc.)
    - glam:Location - Geographic locations with city, country, coordinates
    - glam:Identifier - External identifiers (ISIL, Wikidata, VIAF)
    - glam:Collection - Collections held by institutions
    - glam:DigitalPlatform - Digital systems used by institutions

    Key properties:
    - glam:name - Institution name
    - glam:institution_type - Type (MUSEUM, LIBRARY, ARCHIVE, etc.)
    - glam:city - City where located
    - glam:country - Country code (ISO 3166-1)
    - glam:latitude, glam:longitude - Coordinates
    - glam:identifier_scheme, glam:identifier_value - External IDs

    Linking properties:
    - glam:locations - Links an institution to its glam:Location nodes
      (e.g. glam:locations/glam:city)
    - glam:identifiers - Links an institution to its glam:Identifier nodes
    """

    question: str = dspy.InputField(desc="Natural language question about heritage institutions")
    language: str = dspy.InputField(desc="Language of the question (nl or en)")
    # Callers put conversation history and, with RAG, retrieved institution
    # records in this field, so the description covers both.
    context: str = dspy.InputField(
        desc="Additional context: previous conversation and/or retrieved institution records (may be empty)",
        default="",
    )

    sparql: str = dspy.OutputField(desc="Valid SPARQL query to answer the question")
    explanation: str = dspy.OutputField(
        desc="Brief explanation of what the query does in the user's language"
    )
|
|
|
|
|
|
class SPARQLGenerator(dspy.Module):
    """Plain (retrieval-free) SPARQL generator backed by a chain-of-thought predictor."""

    def __init__(self) -> None:
        super().__init__()
        # NOTE(review): attribute name kept as `generate`; other modules and any
        # saved DSPy state may refer to it by this name — confirm before renaming.
        self.generate = dspy.ChainOfThought(QuestionToSPARQL)

    def forward(
        self, question: str, language: str = "nl", context: str = ""
    ) -> dspy.Prediction:
        """Turn a natural-language question into a SPARQL query.

        Args:
            question: Question text as asked by the user.
            language: 'nl' (Dutch) or 'en' (English).
            context: Earlier conversation, used to resolve follow-up questions.

        Returns:
            A DSPy Prediction exposing ``sparql`` and ``explanation``.
        """
        inputs = {"question": question, "language": language, "context": context}
        return self.generate(**inputs)
|
|
|
|
|
|
class RAGSPARQLGenerator(dspy.Module):
    """RAG-enhanced DSPy module for SPARQL generation.

    Uses Qdrant vector search to retrieve relevant heritage institution
    context before generating SPARQL queries. Retrieval is best effort: if it
    fails or returns nothing, generation proceeds with the caller's context only.
    """

    def __init__(self, retriever: Any = None, k: int = 5) -> None:
        """Create the module.

        Args:
            retriever: Callable invoked as ``retriever(question, k=k)`` returning
                a sequence of passages. Defaults to the shared module retriever.
            k: Number of passages to request per question.
        """
        super().__init__()
        # Only fall back to the shared retriever when none was supplied, so an
        # explicitly passed retriever is used even if it happens to be falsy.
        self.retriever = retriever if retriever is not None else get_retriever()
        self.k = k
        self.generate = dspy.ChainOfThought(QuestionToSPARQL)

    def _retrieve_context(self, question: str) -> str:
        """Return a numbered "Relevant Institutions" prompt section, or "" if unavailable."""
        if self.retriever is None or self.retriever is False:
            return ""
        try:
            passages = self.retriever(question, k=self.k)
            if not passages:
                return ""
            numbered = "".join(
                f"{i}. {passage}\n" for i, passage in enumerate(passages, 1)
            )
        except Exception as e:
            # Best effort: a retrieval failure must not block query generation.
            logger.warning("RAG retrieval failed: %s", e)
            return ""
        return "\n\n## Relevant Institutions:\n" + numbered

    def forward(
        self, question: str, language: str = "nl", context: str = ""
    ) -> dspy.Prediction:
        """Generate a SPARQL query using RAG.

        Args:
            question: The user's question in natural language
            language: Language code ('nl' for Dutch, 'en' for English)
            context: Previous conversation context for follow-up questions

        Returns:
            DSPy Prediction with sparql and explanation fields
        """
        rag_context = self._retrieve_context(question)

        # Retrieved institutions go first, followed by the caller's context.
        if rag_context and context:
            full_context = rag_context + "\n\n" + context
        elif rag_context:
            full_context = rag_context
        else:
            full_context = context

        return self.generate(question=question, language=language, context=full_context)
|
|
|
|
|
|
# Ontology guidance that generate_sparql prepends to every user question.
# This is prompt text: any edit here changes what the LLM sees.
ONTOLOGY_CONTEXT = """
# Heritage Custodian Ontology - SPARQL Query Guidelines

## Prefixes
PREFIX glam: <https://w3id.org/heritage/custodian/>
PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
PREFIX xsd: <http://www.w3.org/2001/XMLSchema#>

## Main Classes
- glam:HeritageCustodian - Heritage institution (museum, library, archive, etc.)
- glam:Location - Geographic location
- glam:Identifier - External identifier (ISIL, Wikidata, etc.)
- glam:Collection - Collection held by an institution
- glam:DigitalPlatform - Digital system/platform

## Institution Types (glam:institution_type)
- MUSEUM, LIBRARY, ARCHIVE, GALLERY
- RESEARCH_CENTER, BOTANICAL_ZOO, EDUCATION_PROVIDER
- COLLECTING_SOCIETY, HOLY_SITES, DIGITAL_PLATFORM, NGO

## Key Properties
- glam:name - Institution name (string)
- glam:alternative_names - Alternative names (list)
- glam:institution_type - Type of institution
- glam:description - Description text
- glam:homepage - Website URL
- glam:founded_date - Founding date
- glam:closed_date - Closure date (if applicable)

## Relationship Properties
- glam:locations - Links a glam:HeritageCustodian to its glam:Location(s)
- glam:identifiers - Links a glam:HeritageCustodian to its glam:Identifier(s)
Location and identifier properties are reached through these links,
e.g. glam:locations/glam:city or glam:identifiers ?id . ?id glam:identifier_value ?v

## Location Properties
- glam:city - City name
- glam:country - ISO 3166-1 alpha-2 country code
- glam:region - Province/state/region
- glam:street_address - Street address
- glam:postal_code - Postal/ZIP code
- glam:latitude - Latitude coordinate
- glam:longitude - Longitude coordinate
- glam:geonames_id - GeoNames identifier

## Identifier Properties
- glam:identifier_scheme - Type of ID (ISIL, Wikidata, VIAF, etc.)
- glam:identifier_value - The identifier value
- glam:identifier_url - URL for the identifier

## Example Queries

# Find all museums in Amsterdam
SELECT ?institution ?name WHERE {
  ?institution a glam:HeritageCustodian ;
               glam:name ?name ;
               glam:institution_type "MUSEUM" ;
               glam:locations/glam:city "Amsterdam" .
}

# Count institutions by country
# (DISTINCT: an institution with several locations in one country counts once)
SELECT ?country (COUNT(DISTINCT ?institution) AS ?count) WHERE {
  ?institution a glam:HeritageCustodian ;
               glam:locations/glam:country ?country .
}
GROUP BY ?country
ORDER BY DESC(?count)

# Find institutions with Wikidata identifiers
SELECT ?institution ?name ?wikidataId WHERE {
  ?institution a glam:HeritageCustodian ;
               glam:name ?name ;
               glam:identifiers ?id .
  ?id glam:identifier_scheme "Wikidata" ;
      glam:identifier_value ?wikidataId .
}
"""
|
|
|
|
|
|
def configure_dspy(
|
|
provider: str = "anthropic",
|
|
model: str = "claude-sonnet-4-20250514",
|
|
api_key: str | None = None,
|
|
) -> None:
|
|
"""Configure DSPy with the specified LLM provider.
|
|
|
|
Args:
|
|
provider: LLM provider ('anthropic' or 'openai')
|
|
model: Model name to use
|
|
api_key: API key for the provider
|
|
"""
|
|
if provider == "anthropic":
|
|
lm = dspy.LM(
|
|
model=f"anthropic/{model}",
|
|
api_key=api_key,
|
|
max_tokens=4096,
|
|
)
|
|
elif provider == "openai":
|
|
lm = dspy.LM(
|
|
model=f"openai/{model}",
|
|
api_key=api_key,
|
|
max_tokens=4096,
|
|
)
|
|
else:
|
|
raise ValueError(f"Unknown provider: {provider}")
|
|
|
|
dspy.configure(lm=lm)
|
|
logger.info(f"Configured DSPy with {provider}/{model}")
|
|
|
|
|
|
def _format_history(context: list[dict[str, Any]] | None, limit: int = 5) -> str:
    """Render the most recent conversation messages as ``role: content`` lines.

    ``role`` defaults to "user" and a missing or None ``content`` becomes "".
    A truthy ``sparql`` entry is appended as a ``[Generated SPARQL: ...]`` note
    so follow-up questions can refer to earlier queries.
    """
    if not context:
        return ""
    parts = []
    for msg in context[-limit:]:
        role = msg.get("role", "user")
        content = msg.get("content") or ""
        if msg.get("sparql"):
            content += f"\n[Generated SPARQL: {msg['sparql']}]"
        parts.append(f"{role}: {content}")
    return "\n".join(parts)


def _pin_query(retriever: Any, query: str) -> Any:
    """Wrap ``retriever`` so every lookup searches for ``query``.

    RAGSPARQLGenerator searches with whatever it receives as ``question``.
    Here that is the ontology-prefixed prompt, which would swamp the vector
    search with boilerplate. This wrapper ignores that prompt and uses the
    raw user question instead.
    """

    def search(_prompt: str, k: int = 5) -> Any:
        return retriever(query, k=k)

    return search


def _run_generation(
    question: str,
    language: str,
    context: list[dict[str, Any]] | None,
    retriever: Any,
    k: int,
) -> dict[str, Any]:
    """Pick a generator, build the prompt, run it, and normalise the output.

    Uses RAGSPARQLGenerator when ``retriever`` is not None, otherwise
    SPARQLGenerator. Returns a dict with 'sparql', 'explanation' and
    'rag_used'. Generation errors are re-raised as RuntimeError.
    """
    if retriever is not None:
        generator = RAGSPARQLGenerator(retriever=retriever, k=k)
        rag_used = True
    else:
        generator = SPARQLGenerator()
        rag_used = False

    context_str = _format_history(context)
    # The ontology guidance travels with the question so the LLM always sees it.
    enhanced_question = f"{ONTOLOGY_CONTEXT}\n\nQuestion: {question}"

    try:
        result = generator(
            question=enhanced_question,
            language=language,
            context=context_str,
        )
        return {
            "sparql": result.sparql.strip(),
            "explanation": result.explanation.strip(),
            "rag_used": rag_used,
        }
    except Exception as e:
        logger.exception("Error generating SPARQL")
        raise RuntimeError(f"Failed to generate SPARQL: {e}") from e


def generate_sparql(
    question: str,
    language: str = "nl",
    context: list[dict[str, Any]] | None = None,
    use_rag: bool = True,
    k: int = 5,
) -> dict[str, Any]:
    """Generate a SPARQL query from a natural language question.

    Args:
        question: The user's question
        language: Language code ('nl' or 'en')
        context: Previous conversation messages (the last five are used)
        use_rag: Whether to use RAG-enhanced generation (default: True)
        k: Number of passages to retrieve when RAG is used

    Returns:
        Dict with 'sparql', 'explanation', and 'rag_used' keys

    Raises:
        RuntimeError: If the LLM call fails.
    """
    retriever = get_retriever() if use_rag else None
    if retriever is not None:
        # Search on the user's question, not on the ontology-prefixed prompt.
        retriever = _pin_query(retriever, question)
    return _run_generation(question, language, context, retriever, k)


def generate_sparql_with_rag(
    question: str,
    language: str = "nl",
    context: list[dict[str, Any]] | None = None,
    k: int = 5,
) -> dict[str, Any]:
    """Generate a SPARQL query using RAG-enhanced generation.

    This function always attempts to use RAG. Falls back to standard
    generation if Qdrant is unavailable. Passages are retrieved once, and
    exactly those passages are both injected into the prompt and returned.

    Args:
        question: The user's question
        language: Language code ('nl' or 'en')
        context: Previous conversation messages
        k: Number of RAG results to retrieve (also used for generation)

    Returns:
        Dict with 'sparql', 'explanation', 'rag_used', and 'retrieved_passages' keys

    Raises:
        RuntimeError: If the LLM call fails.
    """
    retriever = get_retriever()
    if retriever is None:
        result = _run_generation(question, language, context, None, k)
        result["retrieved_passages"] = []
        return result

    retrieved_passages: Any = []
    try:
        retrieved_passages = retriever(question, k=k)
    except Exception as e:
        # Best effort: still generate (without retrieved context) on failure.
        logger.warning("Failed to retrieve passages: %s", e)

    def replay(_prompt: str, k: int = k) -> Any:
        # Hand the generator the passages already fetched above, so nothing is
        # retrieved a second time and what we report is what was used.
        return retrieved_passages

    result = _run_generation(question, language, context, replay, k)
    result["retrieved_passages"] = retrieved_passages
    return result
|