# DSPy Compatibility

## Overview

This document describes how the template-based SPARQL system integrates with DSPy 2.6+. The key requirement is that template classification should be a **DSPy module** that can be optimized alongside existing modules using GEPA or other optimizers.

## DSPy Integration Points

### Current DSPy Architecture (dspy_heritage_rag.py)

The existing RAG system uses DSPy for:

1. **Query Intent Classification** - `ClassifyQueryIntent` Signature
2. **Entity Extraction** - `ExtractHeritageEntities` Signature
3. **SPARQL Generation** - `GenerateSPARQL` Signature (LLM-based)
4. **Answer Generation** - `GenerateHeritageAnswer` Signature

### Template Integration Strategy

We add template-based query generation as a **pre-filter** before LLM-based generation:

```
               User Question
                     |
                     v
         +------------------------+
         | TemplateClassifier     |  <-- NEW: DSPy Module
         | (DSPy Signature)       |
         +------------------------+
                     |
                     v
             [Template Match?]
                     |
       Yes                        No
        |                         |
        v                         v
+----------------+        +----------------+
| Template       |        | GenerateSPARQL |  <-- Existing
| Instantiation  |        | (LLM-based)    |
+----------------+        +----------------+
        |                         |
        +------------+------------+
                     |
                     v
               Valid SPARQL
```

## DSPy Signatures

### 1. TemplateClassifier Signature

```python
import dspy
from typing import Optional, Literal
from pydantic import BaseModel, Field


class TemplateMatch(BaseModel):
    """Output model for template classification."""

    template_id: str = Field(
        description="ID of the matched template, or 'none' if no match"
    )
    confidence: float = Field(
        description="Confidence score between 0.0 and 1.0",
        ge=0.0,
        le=1.0
    )
    extracted_slots: dict[str, str] = Field(
        default_factory=dict,
        description="Extracted slot values from the question"
    )
    reasoning: str = Field(
        description="Brief explanation of why this template was selected"
    )


class ClassifyTemplateSignature(dspy.Signature):
    """Classify a heritage question and match it to a SPARQL template.

    Given a user question about Dutch heritage institutions, determine which
    predefined SPARQL template best matches the query intent. If no template
    matches, return template_id='none' to fall back to LLM generation.

    Available Templates:
    - region_institution_search: Find institutions of a type in a province
      Example: "Welke archieven zijn er in Drenthe?"
      Slots: institution_type (archieven/musea/bibliotheken), province (Dutch province name)

    - count_by_type: Count institutions by type
      Example: "Hoeveel musea zijn er in Nederland?"
      Slots: institution_type

    - count_by_type_region: Count institutions by type in a region
      Example: "Hoeveel archieven zijn er in Noord-Holland?"
      Slots: institution_type, province

    - entity_lookup: Look up a specific institution by name
      Example: "Wat is het Nationaal Archief?"
      Slots: institution_name

    - entity_lookup_by_ghcid: Look up by GHCID
      Example: "Details van NL-HaNA"
      Slots: ghcid

    - list_by_type: List all institutions of a type
      Example: "Toon alle bibliotheken"
      Slots: institution_type

    Dutch Institution Types:
    - archief/archieven -> type code "A"
    - museum/musea -> type code "M"
    - bibliotheek/bibliotheken -> type code "L"
    - galerie/galerijen -> type code "G"

    Dutch Provinces:
    - Drenthe, Flevoland, Friesland, Gelderland, Groningen,
      Limburg, Noord-Brabant, Noord-Holland, Overijssel,
      Utrecht, Zeeland, Zuid-Holland
    """

    question: str = dspy.InputField(
        desc="The user's question about heritage institutions"
    )
    language: str = dspy.InputField(
        desc="Language code: 'nl' for Dutch, 'en' for English",
        default="nl"
    )

    template_match: TemplateMatch = dspy.OutputField(
        desc="The matched template and extracted slots"
    )


class TemplateClassifier(dspy.Module):
    """DSPy module for template classification."""

    def __init__(self):
        super().__init__()
        self.classify = dspy.ChainOfThought(ClassifyTemplateSignature)

    def forward(self, question: str, language: str = "nl") -> TemplateMatch:
        """Classify question and return template match."""
        result = self.classify(question=question, language=language)
        return result.template_match
```

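Because `template_id` is a free-form string, the model can return an ID that is not among the templates listed in the signature. A minimal sketch of a guard follows. `KNOWN_TEMPLATE_IDS`, `Match`, and `sanitize_match` are hypothetical names that do not exist in the codebase, and a plain dataclass stands in for `TemplateMatch` so the example runs without DSPy or pydantic.

```python
from dataclasses import dataclass, field, replace

# Hypothetical constant: the template IDs listed in the signature docstring.
KNOWN_TEMPLATE_IDS = frozenset({
    "region_institution_search",
    "count_by_type",
    "count_by_type_region",
    "entity_lookup",
    "entity_lookup_by_ghcid",
    "list_by_type",
})


@dataclass(frozen=True)
class Match:
    """Stand-in for TemplateMatch, used only in this sketch."""
    template_id: str
    confidence: float
    extracted_slots: dict = field(default_factory=dict)


def sanitize_match(match: Match) -> Match:
    """Map unknown template IDs to 'none' so callers fall back to the LLM."""
    if match.template_id in KNOWN_TEMPLATE_IDS:
        return match
    # An unknown ID is treated like "no match", with zero confidence.
    return replace(match, template_id="none", confidence=0.0)
```

Running such a check right after classification keeps a hallucinated ID from reaching the registry lookup, where it would otherwise surface as a lookup error.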
### 2. SlotExtractor Signature

```python
class ExtractedSlots(BaseModel):
    """Output model for slot extraction."""

    slots: dict[str, str] = Field(
        description="Mapping of slot names to extracted values"
    )
    normalized_slots: dict[str, str] = Field(
        description="Mapping of slot names to normalized values (codes)"
    )
    missing_slots: list[str] = Field(
        default_factory=list,
        description="List of required slots that could not be extracted"
    )


class ExtractSlotsSignature(dspy.Signature):
    """Extract slot values from a question for a specific template.

    Given a question and a template definition, extract the values for each
    slot defined in the template. Normalize values to their standard codes.

    Normalization Rules:
    - Province names -> ISO 3166-2 codes (e.g., "Drenthe" -> "NL-DR")
    - Institution types -> Single-letter codes (e.g., "archieven" -> "A")
    - GHCID -> Keep as-is if valid format (e.g., "NL-HaNA")

    Province Code Mappings:
    - Drenthe: NL-DR
    - Flevoland: NL-FL
    - Friesland: NL-FR
    - Gelderland: NL-GE
    - Groningen: NL-GR
    - Limburg: NL-LI
    - Noord-Brabant: NL-NB
    - Noord-Holland: NL-NH
    - Overijssel: NL-OV
    - Utrecht: NL-UT
    - Zeeland: NL-ZE
    - Zuid-Holland: NL-ZH

    Institution Type Codes:
    - A: Archive (archief, archieven)
    - M: Museum (museum, musea)
    - L: Library (bibliotheek, bibliotheken)
    - G: Gallery (galerie, galerijen)
    """

    question: str = dspy.InputField(
        desc="The user's question"
    )
    template_id: str = dspy.InputField(
        desc="ID of the matched template"
    )
    required_slots: list[str] = dspy.InputField(
        desc="List of required slot names for this template"
    )

    extracted_slots: ExtractedSlots = dspy.OutputField(
        desc="Extracted and normalized slot values"
    )


class SlotExtractor(dspy.Module):
    """DSPy module for slot extraction."""

    def __init__(self):
        super().__init__()
        self.extract = dspy.ChainOfThought(ExtractSlotsSignature)

    def forward(
        self,
        question: str,
        template_id: str,
        required_slots: list[str]
    ) -> ExtractedSlots:
        """Extract slots from question."""
        result = self.extract(
            question=question,
            template_id=template_id,
            required_slots=required_slots
        )
        return result.extracted_slots
```

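The normalization rules in the signature are a fixed lookup, so the values an LLM returns can be cross-checked (or replaced) deterministically. The sketch below assumes hypothetical names (`PROVINCE_CODES`, `TYPE_CODES`, `normalize_slots`); the mappings themselves are copied from the docstring above.

```python
# Province names -> ISO 3166-2 codes, as listed in ExtractSlotsSignature.
_PROVINCES = {
    "Drenthe": "NL-DR", "Flevoland": "NL-FL", "Friesland": "NL-FR",
    "Gelderland": "NL-GE", "Groningen": "NL-GR", "Limburg": "NL-LI",
    "Noord-Brabant": "NL-NB", "Noord-Holland": "NL-NH", "Overijssel": "NL-OV",
    "Utrecht": "NL-UT", "Zeeland": "NL-ZE", "Zuid-Holland": "NL-ZH",
}
# Keys are lowercased so lookups are case-insensitive.
PROVINCE_CODES = {name.lower(): code for name, code in _PROVINCES.items()}

TYPE_CODES = {
    "archief": "A", "archieven": "A",
    "museum": "M", "musea": "M",
    "bibliotheek": "L", "bibliotheken": "L",
    "galerie": "G", "galerijen": "G",
}


def normalize_slots(slots: dict[str, str]) -> tuple[dict[str, str], list[str]]:
    """Return (normalized_slots, missing_slots) for raw slot values."""
    normalized: dict[str, str] = {}
    missing: list[str] = []
    for name, value in slots.items():
        cleaned = value.strip()
        if name == "province":
            code = PROVINCE_CODES.get(cleaned.lower())
        elif name == "institution_type":
            code = TYPE_CODES.get(cleaned.lower())
        else:
            # Other slots (e.g. ghcid, institution_name) are kept as-is.
            code = cleaned or None
        if code is None:
            missing.append(name)
        else:
            normalized[name] = code
    return normalized, missing
```

A value the lookup cannot resolve is reported as missing, which matches the fallback condition the combined module checks through `missing_slots`.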
### 3. Combined TemplateSPARQL Module

```python
class TemplateSPARQL(dspy.Module):
    """Combined DSPy module for template-based SPARQL generation.

    This module orchestrates:
    1. Template classification
    2. Slot extraction (if template matches)
    3. Template instantiation (if slots valid)
    4. Fallback to LLM (if no match or invalid slots)
    """

    def __init__(
        self,
        template_registry: "TemplateRegistry",
        fallback_module: Optional[dspy.Module] = None
    ):
        super().__init__()
        self.classifier = TemplateClassifier()
        self.slot_extractor = SlotExtractor()
        self.template_registry = template_registry
        self.fallback_module = fallback_module

    def forward(self, question: str, language: str = "nl") -> str:
        """Generate SPARQL query from question.

        Args:
            question: User's natural language question
            language: Language code ('nl' or 'en')

        Returns:
            Valid SPARQL query string
        """
        # Step 1: Classify template
        match = self.classifier(question=question, language=language)

        # Step 2: Check if template matched
        if match.template_id == "none" or match.confidence < 0.7:
            # Fall back to LLM generation
            if self.fallback_module:
                # Return the query string, not the Prediction. This assumes the
                # fallback exposes `sparql_query`, as GenerateSPARQL does.
                return self.fallback_module(question=question).sparql_query
            raise ValueError("No template match and no fallback configured")

        # Step 3: Get template and extract slots
        template = self.template_registry.get(match.template_id)
        required_slots = list(template.slots.keys())

        slots = self.slot_extractor(
            question=question,
            template_id=match.template_id,
            required_slots=required_slots
        )

        # Step 4: Check for missing required slots
        if slots.missing_slots:
            # Fall back if required slots are missing
            if self.fallback_module:
                return self.fallback_module(question=question).sparql_query
            raise ValueError(f"Missing required slots: {slots.missing_slots}")

        # Step 5: Instantiate template
        return template.instantiate(slots.normalized_slots)
```

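`TemplateSPARQL` relies on a `TemplateRegistry` whose templates expose `slots` and `instantiate()`, but this document does not define them. One possible shape is sketched below using `string.Template`. The class layout, the `hc:` prefix, and the property names in the query are illustrative assumptions, not the project's actual schema.

```python
import re
from dataclasses import dataclass
from string import Template


@dataclass(frozen=True)
class SPARQLTemplate:
    """Hypothetical template: a query with $slot placeholders."""
    template_id: str
    query: str
    slots: dict  # slot name -> regex that the value must fully match

    def instantiate(self, values: dict) -> str:
        # Validate every slot before substitution so that malformed or
        # injected values never reach the query text.
        for name, pattern in self.slots.items():
            value = values.get(name)
            if value is None or not re.fullmatch(pattern, value):
                raise ValueError(f"Invalid or missing slot: {name}")
        return Template(self.query).substitute(values)


class TemplateRegistry:
    """Hypothetical registry: lookup of templates by ID."""

    def __init__(self, templates):
        self._templates = {t.template_id: t for t in templates}

    def get(self, template_id: str) -> SPARQLTemplate:
        return self._templates[template_id]


# Illustrative query; the vocabulary is made up for this sketch.
registry = TemplateRegistry([
    SPARQLTemplate(
        template_id="count_by_type_region",
        query=(
            "SELECT (COUNT(?inst) AS ?n) WHERE { "
            '?inst hc:institutionType "$institution_type" ; '
            'hc:subregionCode "$province" . }'
        ),
        slots={"institution_type": r"[A-Z]", "province": r"NL-[A-Z]{2}"},
    )
])

query = registry.get("count_by_type_region").instantiate(
    {"institution_type": "A", "province": "NL-NH"}
)
print(query)
```

Validating slot values against a pattern before substitution is what lets the template path promise syntactically valid SPARQL, since only normalized codes can be interpolated.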
## GEPA Optimization

The template classifier can be optimized with GEPA. GEPA expects a metric that accepts five arguments, exactly one budget setting, and a reflection LM:

```python
import dspy
from dspy.teleprompt import GEPA


def template_accuracy_metric(example, prediction, trace=None, pred_name=None, pred_trace=None):
    """Metric for template classification accuracy.

    GEPA calls metrics with five arguments; the last three are unused here.
    """
    # TemplateClassifier.forward returns the TemplateMatch itself, so accept
    # either that or a Prediction carrying a `template_match` field.
    match = getattr(prediction, "template_match", prediction)

    # Check if correct template was selected
    if match.template_id != example.expected_template_id:
        return 0.0

    # Check slot extraction accuracy
    expected_slots = example.expected_slots
    predicted_slots = match.extracted_slots

    if not expected_slots:
        return 1.0  # No slots to check

    correct_slots = sum(
        1 for k, v in expected_slots.items()
        if predicted_slots.get(k) == v
    )
    return correct_slots / len(expected_slots)


def create_training_examples():
    """Create training examples for GEPA optimization."""
    return [
        dspy.Example(
            question="Welke archieven zijn er in Drenthe?",
            language="nl",
            expected_template_id="region_institution_search",
            expected_slots={
                "institution_type": "archieven",
                "province": "Drenthe"
            }
        ).with_inputs("question", "language"),

        dspy.Example(
            question="Welke musea zijn er in Noord-Holland?",
            language="nl",
            expected_template_id="region_institution_search",
            expected_slots={
                "institution_type": "musea",
                "province": "Noord-Holland"
            }
        ).with_inputs("question", "language"),

        dspy.Example(
            question="Hoeveel bibliotheken zijn er in Nederland?",
            language="nl",
            expected_template_id="count_by_type",
            expected_slots={
                "institution_type": "bibliotheken"
            }
        ).with_inputs("question", "language"),

        # Add more examples...
    ]


def optimize_template_classifier():
    """Run GEPA optimization on template classifier."""

    # Create training data
    trainset = create_training_examples()

    # Initialize classifier
    classifier = TemplateClassifier()

    # Configure GEPA optimizer
    optimizer = GEPA(
        metric=template_accuracy_metric,
        # Exactly one budget option may be set (auto, max_full_evals, or
        # max_metric_calls); use max_metric_calls=100 instead for a fixed cap.
        auto="light",
        # GEPA needs a reflection LM; this model name is a placeholder.
        reflection_lm=dspy.LM("openai/gpt-4o", temperature=1.0),
    )

    # Run optimization
    optimized_classifier = optimizer.compile(
        classifier,
        trainset=trainset,
    )

    # Save optimized module
    optimized_classifier.save("optimized_template_classifier.json")

    return optimized_classifier
```

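The metric awards partial credit: a correct template with one of two slots right scores 0.5, while a wrong template scores 0.0 regardless of the slots. The sketch below restates that scoring rule as a standalone function (the name `slot_score` is hypothetical), with `SimpleNamespace` objects standing in for `TemplateMatch` so it runs without DSPy.

```python
from types import SimpleNamespace


def slot_score(expected_id: str, expected_slots: dict, match) -> float:
    """Same scoring rule as template_accuracy_metric, without DSPy types."""
    if match.template_id != expected_id:
        return 0.0
    if not expected_slots:
        return 1.0
    correct = sum(
        1 for k, v in expected_slots.items()
        if match.extracted_slots.get(k) == v
    )
    return correct / len(expected_slots)


expected = {"institution_type": "archieven", "province": "Drenthe"}

# Right template, but the province was extracted incorrectly.
partial = SimpleNamespace(
    template_id="region_institution_search",
    extracted_slots={"institution_type": "archieven", "province": "Groningen"},
)
# Wrong template, even though the slots happen to be right.
wrong = SimpleNamespace(template_id="count_by_type", extracted_slots=expected)

print(slot_score("region_institution_search", expected, partial))  # 0.5
print(slot_score("region_institution_search", expected, wrong))    # 0.0
```

Note that the comparison is on raw slot strings, so training examples must use the same surface form (for example `"archieven"`, not `"A"`) that the classifier is expected to emit.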
## Integration with Existing HeritageRAG

### Modified HeritageRAG Class

```python
# In dspy_heritage_rag.py

import logging

logger = logging.getLogger(__name__)


class HeritageRAG(dspy.Module):
    """Heritage RAG with template-based SPARQL generation."""

    def __init__(
        self,
        template_registry: Optional["TemplateRegistry"] = None,
        use_templates: bool = True,
        template_confidence_threshold: float = 0.7,
        **kwargs
    ):
        super().__init__()

        # Existing components
        self.query_intent = dspy.ChainOfThought(ClassifyQueryIntent)
        self.entity_extractor = dspy.ChainOfThought(ExtractHeritageEntities)
        self.sparql_generator = dspy.ChainOfThought(GenerateSPARQL)
        self.answer_generator = dspy.ChainOfThought(GenerateHeritageAnswer)

        # NEW: Template-based components
        self.use_templates = use_templates
        self.template_confidence_threshold = template_confidence_threshold

        if use_templates:
            self.template_classifier = TemplateClassifier()
            self.slot_extractor = SlotExtractor()
            self.template_registry = template_registry or TemplateRegistry.load_default()

    def generate_sparql(self, question: str, language: str = "nl") -> str:
        """Generate SPARQL query, trying templates first."""

        if self.use_templates:
            # Try template-based generation
            try:
                match = self.template_classifier(question=question, language=language)

                if (match.template_id != "none" and
                        match.confidence >= self.template_confidence_threshold):

                    template = self.template_registry.get(match.template_id)
                    required_slots = list(template.slots.keys())

                    slots = self.slot_extractor(
                        question=question,
                        template_id=match.template_id,
                        required_slots=required_slots
                    )

                    if not slots.missing_slots:
                        # Successfully matched template
                        logger.info(
                            f"Using template '{match.template_id}' "
                            f"(confidence: {match.confidence:.2f})"
                        )
                        return template.instantiate(slots.normalized_slots)

            except Exception as e:
                logger.warning(f"Template generation failed: {e}, falling back to LLM")

        # Fall back to LLM-based generation
        logger.info("Using LLM-based SPARQL generation")
        return self.sparql_generator(question=question).sparql_query
```

## Signature Caching for OpenAI

OpenAI's prompt caching applies only to prompts of at least 1,024 tokens and matches on the prompt prefix. To benefit from it, we generate a long, stable signature docstring:

```python
def get_cacheable_template_classifier_docstring() -> str:
    """Generate >1024 token docstring for template classifier."""

    # Load template definitions
    templates = TemplateRegistry.load_default()

    # Build comprehensive docstring
    parts = [
        "Classify a heritage question and match it to a SPARQL template.",
        "",
        "## Available Templates",
        ""
    ]

    for template_id, template in templates.items():
        parts.extend([
            f"### {template_id}",
            f"Description: {template.description}",
            "Example questions:",
        ])
        for pattern in template.question_patterns[:3]:
            parts.append(f"  - {pattern}")
        parts.extend([
            f"Required slots: {list(template.slots.keys())}",
            ""
        ])

    parts.extend([
        "## Slot Value Mappings",
        "",
        "### Province Codes (ISO 3166-2)",
    ])

    # Add all province mappings
    # (get_subregion_mappings is assumed to be defined elsewhere in the project)
    for province, code in get_subregion_mappings().items():
        parts.append(f"- {province}: {code}")

    parts.extend([
        "",
        "### Institution Type Codes",
    ])

    # (get_institution_type_mappings is likewise assumed to exist elsewhere)
    for type_name, code in get_institution_type_mappings().items():
        parts.append(f"- {type_name}: {code}")

    return "\n".join(parts)
```

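Whether the generated docstring actually clears the 1,024-token threshold depends on the tokenizer. A rough pre-check is sketched below, assuming about four characters per token. That figure is a common rule of thumb for English text and not an exact count, and Dutch terms and identifier-like codes may tokenize differently. The helper names are hypothetical.

```python
CACHE_THRESHOLD_TOKENS = 1024
CHARS_PER_TOKEN = 4  # heuristic assumption, not a tokenizer-accurate value


def estimate_tokens(text: str) -> int:
    """Approximate the token count from the character count."""
    return len(text) // CHARS_PER_TOKEN


def is_probably_cacheable(text: str, margin: float = 1.2) -> bool:
    """True if the estimate exceeds the threshold by a safety margin."""
    return estimate_tokens(text) >= CACHE_THRESHOLD_TOKENS * margin


docstring = "Classify a heritage question and match it to a SPARQL template.\n" * 100
print(estimate_tokens(docstring), is_probably_cacheable(docstring))
```

For an exact count, run the text through the tokenizer of the target model, for example with the `tiktoken` package.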
## Testing DSPy Integration

```python
# tests/template_sparql/test_dspy_integration.py

import pytest
import dspy


class TestDSPyIntegration:
    """Test DSPy module integration."""

    @pytest.fixture
    def classifier(self):
        """Create template classifier backed by a real LM (not a mock)."""
        # Requires API credentials for the configured model
        dspy.configure(lm=dspy.LM("gpt-4o-mini"))
        return TemplateClassifier()

    def test_classifier_forward(self, classifier):
        """Test classifier forward pass."""
        result = classifier(
            question="Welke archieven zijn er in Drenthe?",
            language="nl"
        )

        assert result.template_id == "region_institution_search"
        assert result.confidence > 0.5
        assert "institution_type" in result.extracted_slots
        assert "province" in result.extracted_slots

    def test_classifier_is_optimizable(self, classifier):
        """Test that classifier can be compiled with GEPA."""
        from dspy.teleprompt import GEPA

        trainset = [
            dspy.Example(
                question="Welke archieven zijn er in Drenthe?",
                language="nl",
                expected_template_id="region_institution_search"
            ).with_inputs("question", "language")
        ]

        def metric(example, prediction, trace=None, pred_name=None, pred_trace=None):
            # GEPA metrics take five arguments; the classifier returns a
            # TemplateMatch directly, so accept that or a wrapping Prediction.
            match = getattr(prediction, "template_match", prediction)
            return 1.0 if match.template_id == example.expected_template_id else 0.0

        optimizer = GEPA(
            metric=metric,
            max_metric_calls=5,
            reflection_lm=dspy.LM("gpt-4o-mini"),  # GEPA requires a reflection LM
        )

        # Should not raise
        compiled = optimizer.compile(classifier, trainset=trainset)
        assert compiled is not None
```

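## Performance Considerations

### Caching Strategy

```python
from collections import OrderedDict


class CachedTemplateClassifier(TemplateClassifier):
    """Template classifier with a small in-memory LRU cache.

    Cached results are tied to the current prompt, so create a fresh
    instance after re-optimizing the classifier.
    """

    def __init__(self, maxsize: int = 1000):
        super().__init__()
        self._maxsize = maxsize
        self._cache = OrderedDict()  # (normalized question, language) -> TemplateMatch

    def forward(self, question: str, language: str = "nl") -> TemplateMatch:
        # Normalize only the cache key for better cache hits; the LM still
        # sees the original question, so names such as "Drenthe" keep their case.
        key = (question.lower().strip(), language)
        if key in self._cache:
            self._cache.move_to_end(key)
            return self._cache[key]

        result = super().forward(question, language)
        self._cache[key] = result
        if len(self._cache) > self._maxsize:
            self._cache.popitem(last=False)  # Evict the least recently used entry
        return result
```
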
### Batch Processing

```python
import asyncio


class BatchTemplateClassifier(dspy.Module):
    """Batch classification for efficiency."""

    def __init__(self):
        super().__init__()
        self.classifier = TemplateClassifier()

    async def forward_batch(
        self,
        questions: list[str],
        language: str = "nl"
    ) -> list[TemplateMatch]:
        """Classify multiple questions concurrently."""

        async def classify_one(q: str) -> TemplateMatch:
            # The classifier call is blocking, so run it in a worker thread;
            # calling it directly inside a coroutine would serialize the batch.
            return await asyncio.to_thread(
                self.classifier, question=q, language=language
            )

        return await asyncio.gather(*[classify_one(q) for q in questions])
```

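`asyncio.gather` only overlaps work that yields to the event loop, so a blocking classifier call has to be moved off the loop thread. The pattern is shown below in a self-contained form with a stub in place of the LLM call. `slow_classify` and `classify_batch` are hypothetical names, and the concurrency limit of 4 is an arbitrary choice.

```python
import asyncio
import time


def slow_classify(question: str) -> str:
    """Stub for a blocking classifier call (stands in for the LLM request)."""
    time.sleep(0.1)
    return "count_by_type" if question.lower().startswith("hoeveel") else "none"


async def classify_batch(questions: list[str], max_concurrency: int = 4) -> list[str]:
    """Run blocking classifications in worker threads, a few at a time."""
    semaphore = asyncio.Semaphore(max_concurrency)

    async def classify_one(q: str) -> str:
        async with semaphore:
            # to_thread keeps the event loop free while the call blocks.
            return await asyncio.to_thread(slow_classify, q)

    # gather returns results in the same order as the input questions.
    return await asyncio.gather(*(classify_one(q) for q in questions))


if __name__ == "__main__":
    results = asyncio.run(
        classify_batch(["Hoeveel musea zijn er?", "Wat is het Nationaal Archief?"])
    )
    print(results)
```

The semaphore bounds the number of in-flight requests, which helps stay within provider rate limits when batches are large.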
## References

- DSPy Documentation: https://dspy-docs.vercel.app/
- GEPA Paper: https://arxiv.org/abs/2507.19457
- DSPy Signatures: https://dspy-docs.vercel.app/docs/building-blocks/signatures
- DSPy Modules: https://dspy-docs.vercel.app/docs/building-blocks/modules