glam/src/glam_extractor/api/dspy_sparql.py
2025-12-09 07:56:35 +01:00

346 lines
11 KiB
Python

"""
DSPy SPARQL Generation Module
Uses DSPy to generate SPARQL queries from natural language questions
about heritage custodian institutions.
Optionally uses a Qdrant vector database for retrieval-augmented (RAG) query generation.
"""
import logging
from typing import Any
import dspy
from .config import get_settings
# Module-level logger for this module.
logger = logging.getLogger(__name__)
# Cached Qdrant retriever, created lazily by get_retriever() so that importing
# this module does not fail when Qdrant is not configured.
# States: None = not initialized yet; False = disabled in settings or
# initialization failed (not retried); anything else = the retriever instance.
_retriever: Any = None
def get_retriever() -> Any:
    """Return the shared Qdrant retriever, or ``None`` when it is unavailable.

    The retriever is built on the first call only. If Qdrant is disabled in
    the settings, or if building it fails, that outcome is cached as ``False``.
    Later calls then return ``None`` without trying again.
    """
    global _retriever
    if _retriever is None:
        _retriever = _create_retriever()
    # The cached ``False`` marker (disabled / failed) is reported as ``None``.
    return _retriever or None


def _create_retriever() -> Any:
    """Build the retriever from settings; return ``False`` if that is not possible."""
    settings = get_settings()
    if not settings.qdrant_enabled:
        return False
    try:
        # Imported here so the module still loads when Qdrant is not set up.
        from .qdrant_retriever import HeritageCustodianRetriever

        instance = HeritageCustodianRetriever(
            host=settings.qdrant_host,
            port=settings.qdrant_port,
            embedding_model=settings.embedding_model,
            embedding_dim=settings.embedding_dim,
            api_key=settings.openai_api_key,
        )
        logger.info("Qdrant retriever initialized")
    except Exception as e:
        logger.warning(f"Failed to initialize Qdrant retriever: {e}")
        return False
    return instance
# DSPy signature for the SPARQL-generation task. It has three text inputs
# (question, language, context) and two text outputs (sparql, explanation).
# NOTE: DSPy uses a Signature's docstring as the task instructions in the LM
# prompt (per DSPy docs). The docstring below is therefore runtime prompt
# text, and editing it changes model behavior, not just documentation.
# NOTE(review): the docstring lists glam:city / glam:country and
# glam:identifier_* but not the linking properties glam:locations and
# glam:identifiers, which the ONTOLOGY_CONTEXT example queries traverse.
# Consider adding them.
class QuestionToSPARQL(dspy.Signature):
    """Generate a SPARQL query from a natural language question about heritage institutions.
    The query should use the Heritage Custodian Ontology with the following key classes:
    - glam:HeritageCustodian - Heritage institutions (museums, libraries, archives, etc.)
    - glam:Location - Geographic locations with city, country, coordinates
    - glam:Identifier - External identifiers (ISIL, Wikidata, VIAF)
    - glam:Collection - Collections held by institutions
    - glam:DigitalPlatform - Digital systems used by institutions
    Key properties:
    - glam:name - Institution name
    - glam:institution_type - Type (MUSEUM, LIBRARY, ARCHIVE, etc.)
    - glam:city - City where located
    - glam:country - Country code (ISO 3166-1)
    - glam:latitude, glam:longitude - Coordinates
    - glam:identifier_scheme, glam:identifier_value - External IDs
    """
    # The user's question. generate_sparql() passes it prefixed with ONTOLOGY_CONTEXT.
    question: str = dspy.InputField(desc="Natural language question about heritage institutions")
    # Language code of the question; callers in this module default to "nl".
    language: str = dspy.InputField(desc="Language of the question (nl or en)")
    # Free-text context: formatted conversation history and/or retrieved passages.
    # NOTE(review): confirm DSPy honours default= on InputField; the callers
    # in this module always pass context explicitly.
    context: str = dspy.InputField(
        desc="Previous conversation context (if any)", default=""
    )
    # Generated SPARQL; callers strip() it before returning it.
    sparql: str = dspy.OutputField(desc="Valid SPARQL query to answer the question")
    # Short description of the query, requested in the user's language.
    explanation: str = dspy.OutputField(
        desc="Brief explanation of what the query does in the user's language"
    )
class SPARQLGenerator(dspy.Module):
    """Plain (non-RAG) DSPy module that turns a question into a SPARQL query."""

    def __init__(self) -> None:
        super().__init__()
        # Chain-of-thought predictor over the QuestionToSPARQL signature.
        self.generate = dspy.ChainOfThought(QuestionToSPARQL)

    def forward(
        self, question: str, language: str = "nl", context: str = ""
    ) -> dspy.Prediction:
        """Produce a SPARQL query for ``question``.

        Args:
            question: Natural-language question from the user.
            language: Language code of the question ('nl' or 'en').
            context: Earlier conversation text, used for follow-up questions.

        Returns:
            The DSPy Prediction, carrying ``sparql`` and ``explanation``.
        """
        inputs = {"question": question, "language": language, "context": context}
        return self.generate(**inputs)
class RAGSPARQLGenerator(dspy.Module):
    """DSPy module that adds retrieved institution passages before generating SPARQL.

    The retriever is called with the question it receives. The top-``k``
    passages are rendered as a numbered "Relevant Institutions" section,
    which is placed in front of any caller-supplied context.
    """

    def __init__(self, retriever=None, k: int = 5) -> None:
        super().__init__()
        # Fall back to the shared module-level retriever (which may be None).
        self.retriever = retriever if retriever else get_retriever()
        self.k = k
        self.generate = dspy.ChainOfThought(QuestionToSPARQL)

    def forward(
        self, question: str, language: str = "nl", context: str = ""
    ) -> dspy.Prediction:
        """Generate a SPARQL query with retrieval augmentation.

        Args:
            question: Natural-language question; also used as the retrieval query.
            language: Language code of the question ('nl' or 'en').
            context: Earlier conversation text, used for follow-up questions.

        Returns:
            The DSPy Prediction, carrying ``sparql`` and ``explanation``.
        """
        pieces: list[str] = []
        if self.retriever:
            try:
                hits = self.retriever(question, k=self.k)
                if hits:
                    pieces.append("\n\n## Relevant Institutions:\n")
                    for rank, hit in enumerate(hits, start=1):
                        pieces.append(f"{rank}. {hit}\n")
            except Exception as e:
                # Retrieval is best-effort: generation proceeds without it.
                logger.warning(f"RAG retrieval failed: {e}")
        rag_context = "".join(pieces)

        # Retrieved passages come first, then the caller's context (if any).
        if not rag_context:
            merged = context
        elif context:
            merged = rag_context + "\n\n" + context
        else:
            merged = rag_context
        return self.generate(question=question, language=language, context=merged)
# Ontology reference text. generate_sparql() prepends it to the user's
# question ("<ONTOLOGY_CONTEXT>\n\nQuestion: <question>"), so it is part of
# the LM prompt: edits here change generated queries.
# NOTE(review): the property lists do not mention glam:locations or
# glam:identifiers, although the example queries below traverse them.
ONTOLOGY_CONTEXT = """
# Heritage Custodian Ontology - SPARQL Query Guidelines
## Prefixes
PREFIX glam: <https://w3id.org/heritage/custodian/>
PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
PREFIX xsd: <http://www.w3.org/2001/XMLSchema#>
## Main Classes
- glam:HeritageCustodian - Heritage institution (museum, library, archive, etc.)
- glam:Location - Geographic location
- glam:Identifier - External identifier (ISIL, Wikidata, etc.)
- glam:Collection - Collection held by an institution
- glam:DigitalPlatform - Digital system/platform
## Institution Types (glam:institution_type)
- MUSEUM, LIBRARY, ARCHIVE, GALLERY
- RESEARCH_CENTER, BOTANICAL_ZOO, EDUCATION_PROVIDER
- COLLECTING_SOCIETY, HOLY_SITES, DIGITAL_PLATFORM, NGO
## Key Properties
- glam:name - Institution name (string)
- glam:alternative_names - Alternative names (list)
- glam:institution_type - Type of institution
- glam:description - Description text
- glam:homepage - Website URL
- glam:founded_date - Founding date
- glam:closed_date - Closure date (if applicable)
## Location Properties
- glam:city - City name
- glam:country - ISO 3166-1 alpha-2 country code
- glam:region - Province/state/region
- glam:street_address - Street address
- glam:postal_code - Postal/ZIP code
- glam:latitude - Latitude coordinate
- glam:longitude - Longitude coordinate
- glam:geonames_id - GeoNames identifier
## Identifier Properties
- glam:identifier_scheme - Type of ID (ISIL, Wikidata, VIAF, etc.)
- glam:identifier_value - The identifier value
- glam:identifier_url - URL for the identifier
## Example Queries
# Find all museums in Amsterdam
SELECT ?institution ?name WHERE {
?institution a glam:HeritageCustodian ;
glam:name ?name ;
glam:institution_type "MUSEUM" ;
glam:locations/glam:city "Amsterdam" .
}
# Count institutions by country
SELECT ?country (COUNT(?institution) as ?count) WHERE {
?institution a glam:HeritageCustodian ;
glam:locations/glam:country ?country .
}
GROUP BY ?country
ORDER BY DESC(?count)
# Find institutions with Wikidata identifiers
SELECT ?institution ?name ?wikidataId WHERE {
?institution a glam:HeritageCustodian ;
glam:name ?name ;
glam:identifiers ?id .
?id glam:identifier_scheme "Wikidata" ;
glam:identifier_value ?wikidataId .
}
"""
def configure_dspy(
provider: str = "anthropic",
model: str = "claude-sonnet-4-20250514",
api_key: str | None = None,
) -> None:
"""Configure DSPy with the specified LLM provider.
Args:
provider: LLM provider ('anthropic' or 'openai')
model: Model name to use
api_key: API key for the provider
"""
if provider == "anthropic":
lm = dspy.LM(
model=f"anthropic/{model}",
api_key=api_key,
max_tokens=4096,
)
elif provider == "openai":
lm = dspy.LM(
model=f"openai/{model}",
api_key=api_key,
max_tokens=4096,
)
else:
raise ValueError(f"Unknown provider: {provider}")
dspy.configure(lm=lm)
logger.info(f"Configured DSPy with {provider}/{model}")
def generate_sparql(
    question: str,
    language: str = "nl",
    context: list[dict[str, Any]] | None = None,
    use_rag: bool = True,
) -> dict[str, Any]:
    """Generate a SPARQL query from a natural language question.

    The LM receives ``ONTOLOGY_CONTEXT`` followed by the question. When RAG
    is used, the vector search is run on the raw ``question`` only.
    Searching on the ontology-prefixed prompt would make the retrieval query
    dominated by the same constant ontology text for every question.

    Args:
        question: The user's question.
        language: Language code ('nl' or 'en').
        context: Previous conversation messages, as dicts with optional
            "role", "content" and "sparql" keys. Only the last 5 are used.
        use_rag: Whether to use RAG-enhanced generation (default: True).
            RAG is only used if a retriever is actually available.

    Returns:
        Dict with 'sparql', 'explanation', and 'rag_used' keys.

    Raises:
        RuntimeError: If generation fails; the original error is chained.
    """
    # Choose the generator based on RAG availability.
    retriever = get_retriever() if use_rag else None
    rag_used = bool(retriever)
    if rag_used:
        def _retrieve_for_question(_prompt: str, k: int = 5) -> Any:
            # RAGSPARQLGenerator calls ``retriever(<prompt>, k=...)`` with the
            # ontology-prefixed prompt built below. Search on the user's
            # actual question instead.
            return retriever(question, k=k)

        generator = RAGSPARQLGenerator(retriever=_retrieve_for_question)
    else:
        generator = SPARQLGenerator()

    # Build the context string from the most recent conversation messages.
    history_lines = []
    for msg in (context or [])[-5:]:
        role = msg.get("role", "user")
        # A present-but-None "content" would otherwise crash the += below.
        content = msg.get("content") or ""
        if msg.get("sparql"):
            content += f"\n[Generated SPARQL: {msg['sparql']}]"
        history_lines.append(f"{role}: {content}")
    context_str = "\n".join(history_lines)

    # Add the ontology reference to the question shown to the LM.
    enhanced_question = f"{ONTOLOGY_CONTEXT}\n\nQuestion: {question}"
    try:
        result = generator(
            question=enhanced_question,
            language=language,
            context=context_str,
        )
        return {
            "sparql": result.sparql.strip(),
            "explanation": result.explanation.strip(),
            "rag_used": rag_used,
        }
    except Exception as e:
        logger.exception("Error generating SPARQL")
        raise RuntimeError(f"Failed to generate SPARQL: {e}") from e
def generate_sparql_with_rag(
    question: str,
    language: str = "nl",
    context: list[dict[str, Any]] | None = None,
    k: int = 5,
) -> dict[str, Any]:
    """Generate SPARQL with RAG and also return the passages retrieved for display.

    If no retriever is available, generation falls back to the non-RAG path
    and ``retrieved_passages`` is an empty list.

    Note that the passages returned here come from a separate retriever call
    (raw ``question``, top-``k``). ``generate_sparql`` runs its own retrieval
    during generation and does not receive ``k``, so ``k`` only controls the
    returned passages.

    Args:
        question: The user's question.
        language: Language code ('nl' or 'en').
        context: Previous conversation messages.
        k: Number of passages to retrieve for the returned list.

    Returns:
        Dict with 'sparql', 'explanation', 'rag_used', and 'retrieved_passages' keys.
    """
    retriever = get_retriever()

    # Passages fetched purely for transparency; failures are logged, not raised.
    shown_passages = []
    if retriever:
        try:
            shown_passages = retriever(question, k=k)
        except Exception as e:
            logger.warning(f"Failed to retrieve passages: {e}")

    response = generate_sparql(question, language, context, use_rag=bool(retriever))
    response["retrieved_passages"] = shown_passages
    return response