# DSPy Evaluation Metrics
## Overview

Metrics are Python functions that score DSPy module outputs. They are used both for evaluation (tracking progress) and for optimization (guiding DSPy optimizers).
## Metric Function Signature
def metric(example: dspy.Example, pred: dspy.Prediction, trace=None) -> float | bool:
    """Score one prediction against its expected output.

    Args:
        example: Input example carrying the expected outputs.
        pred: The module's prediction to be scored.
        trace: Optional trace of intermediate steps; supplied by DSPy
            optimizers, ``None`` during plain evaluation.

    Returns:
        A score in [0.0, 1.0], or a bool for pass/fail.
    """
    pass
## Core Metrics

### 1. Intent Accuracy
# tests/dspy_gitops/metrics/intent_accuracy.py
def intent_accuracy(example: dspy.Example, pred: dspy.Prediction, trace=None) -> float:
    """Exact-match score between predicted and expected intent.

    Both sides are lowercased and stripped first, so the comparison is
    insensitive to case and surrounding whitespace. A prediction with no
    ``intent`` attribute is treated as an empty string.

    Returns:
        1.0 on a match, 0.0 otherwise.
    """
    want = example.expected_intent.strip().lower()
    got = getattr(pred, 'intent', '').strip().lower()
    return float(want == got)
def intent_accuracy_with_fallback(example: dspy.Example, pred: dspy.Prediction, trace=None) -> float:
    """Intent accuracy with partial credit for related intents.

    Related intents (0.5 credit, symmetric):
        - statistical <-> exploration
        - entity_lookup <-> relational
        - temporal <-> entity_lookup

    Returns:
        1.0 on an exact match, 0.5 for a related pair, 0.0 otherwise.
    """
    # Fix: also strip whitespace, consistent with intent_accuracy —
    # previously a trailing space defeated both the exact match and the
    # related-pair lookup below.
    expected = example.expected_intent.lower().strip()
    predicted = getattr(pred, 'intent', '').lower().strip()
    if expected == predicted:
        return 1.0
    # Partial credit for related intents (order-insensitive pairs).
    related_pairs = [
        ("statistical", "exploration"),
        ("entity_lookup", "relational"),
        ("temporal", "entity_lookup"),
    ]
    for a, b in related_pairs:
        if {expected, predicted} == {a, b}:
            return 0.5
    return 0.0
### 2. Entity Extraction F1
# tests/dspy_gitops/metrics/entity_extraction.py
def entity_f1(example: dspy.Example, pred: dspy.Prediction, trace=None) -> float:
    """F1 score for entity extraction, with fuzzy matching.

    Predicted entities may arrive as a list or a comma-separated string.
    Comparison is case-insensitive; two entities match when their
    rapidfuzz ratio is >= 80.

    Fix: each predicted entity may now satisfy at most one expected
    entity. Previously a single predicted entity could match several
    near-identical expected entities (it was never consumed), letting
    true_positives exceed len(predicted) and pushing precision — and
    therefore F1 — above 1.0.

    Returns:
        F1 in [0.0, 1.0]; 1.0 when both sides are empty.
    """
    expected = set(e.lower() for e in example.expected_entities)
    predicted_raw = getattr(pred, 'entities', [])
    if isinstance(predicted_raw, str):
        # Parse comma-separated string
        predicted_raw = [e.strip() for e in predicted_raw.split(',')]
    predicted = set(e.lower() for e in predicted_raw if e)
    if not expected and not predicted:
        return 1.0
    if not expected or not predicted:
        return 0.0
    # Fuzzy matching: greedily pair each expected entity with one still-
    # unmatched predicted entity, consuming it so it cannot be reused.
    from rapidfuzz import fuzz
    unmatched = set(predicted)
    true_positives = 0
    for exp_entity in expected:
        match = next(
            (p for p in unmatched if fuzz.ratio(exp_entity, p) >= 80),
            None,
        )
        if match is not None:
            unmatched.remove(match)
            true_positives += 1
    # Both denominators are non-zero here (guarded above).
    precision = true_positives / len(predicted)
    recall = true_positives / len(expected)
    if precision + recall == 0:
        return 0.0
    return 2 * (precision * recall) / (precision + recall)
def entity_type_accuracy(example: dspy.Example, pred: dspy.Prediction, trace=None) -> float:
    """Score whether the predicted entity type matches the expected one.

    Both sides default to ``'institution'`` when the attribute is absent.

    Returns:
        1.0 on a match, 0.0 otherwise.
    """
    want = getattr(example, 'expected_entity_type', 'institution')
    got = getattr(pred, 'entity_type', 'institution')
    return float(want == got)
### 3. SPARQL Correctness
# tests/dspy_gitops/metrics/sparql_correctness.py
import re
def sparql_syntax_valid(example: dspy.Example, pred: dspy.Prediction, trace=None) -> float:
    """Lightweight structural validation of a SPARQL query string.

    Checks performed:
        - required clauses (SELECT, WHERE), case-insensitively
        - balanced curly braces
        - no leftover ``???`` placeholder

    Returns:
        1.0 when every check passes, 0.0 otherwise (including when the
        prediction has no query at all).
    """
    sparql = getattr(pred, 'sparql_query', '')
    if not sparql:
        return 0.0
    upper = sparql.upper()
    checks = (
        'SELECT' in upper,                       # required clause
        'WHERE' in upper,                        # required clause
        sparql.count('{') == sparql.count('}'),  # balanced braces
        '???' not in sparql,                     # no unfilled placeholder
    )
    return 1.0 if all(checks) else 0.0
def sparql_executes(
    example: dspy.Example,
    pred: dspy.Prediction,
    trace=None,
    oxigraph_url: str = "http://localhost:7878",
) -> float:
    """Check that the predicted SPARQL query executes without error.

    Requires a live Oxigraph endpoint.

    Fix: ``trace`` is now the third parameter, matching the standard
    metric signature ``metric(example, pred, trace)``. Previously it sat
    after ``oxigraph_url``, so a trace passed positionally by a DSPy
    optimizer would have been bound to ``oxigraph_url``. Callers passing
    ``oxigraph_url=`` by keyword are unaffected.

    Returns:
        1.0 when the endpoint answers with HTTP 200, 0.0 on any other
        status, a missing query, or a network failure.
    """
    sparql = getattr(pred, 'sparql_query', '')
    if not sparql:
        return 0.0
    # Imported lazily so the metric module loads without httpx installed.
    import httpx
    try:
        response = httpx.post(
            f"{oxigraph_url}/query",
            data={"query": sparql},
            headers={"Accept": "application/sparql-results+json"},
            timeout=10.0,
        )
        return 1.0 if response.status_code == 200 else 0.0
    except Exception:
        # Connection/timeout problems count as a failed execution.
        return 0.0
def sparql_returns_results(
    example: dspy.Example,
    pred: dspy.Prediction,
    trace=None,
    oxigraph_url: str = "http://localhost:7878",
) -> float:
    """Check that the predicted SPARQL query returns non-empty results.

    Requires a live Oxigraph endpoint.

    Fix: ``trace`` is now the third parameter, matching the standard
    metric signature ``metric(example, pred, trace)``. Previously it sat
    after ``oxigraph_url``, so a trace passed positionally by a DSPy
    optimizer would have been bound to ``oxigraph_url``. Callers passing
    ``oxigraph_url=`` by keyword are unaffected.

    Returns:
        1.0 when the query succeeds and yields at least one binding,
        0.0 otherwise (missing query, non-200 status, empty results, or
        network failure).
    """
    sparql = getattr(pred, 'sparql_query', '')
    if not sparql:
        return 0.0
    # Imported lazily so the metric module loads without httpx installed.
    import httpx
    try:
        response = httpx.post(
            f"{oxigraph_url}/query",
            data={"query": sparql},
            headers={"Accept": "application/sparql-results+json"},
            timeout=10.0,
        )
        if response.status_code != 200:
            return 0.0
        data = response.json()
        bindings = data.get("results", {}).get("bindings", [])
        return 1.0 if bindings else 0.0
    except Exception:
        # Connection/timeout problems count as no results.
        return 0.0
### 4. Answer Relevance (LLM-as-Judge)
# tests/dspy_gitops/metrics/answer_relevance.py
import dspy
class AnswerRelevanceJudge(dspy.Signature):
    """Judge if an answer is relevant to the question."""

    # Inputs: the original user question and the pipeline's answer.
    question: str = dspy.InputField(desc="The original question")
    answer: str = dspy.InputField(desc="The generated answer")
    # Outputs: boolean verdict plus a short justification (useful when
    # debugging why an answer was scored 0).
    relevant: bool = dspy.OutputField(desc="Is the answer relevant to the question?")
    reasoning: str = dspy.OutputField(desc="Brief explanation of relevance judgment")
class AnswerGroundednessJudge(dspy.Signature):
    """Judge if an answer is grounded in retrieved context."""

    # Inputs: the retrieved context and the pipeline's answer.
    context: str = dspy.InputField(desc="The retrieved context")
    answer: str = dspy.InputField(desc="The generated answer")
    # Outputs: boolean verdict plus a short justification.
    grounded: bool = dspy.OutputField(desc="Is the answer grounded in the context?")
    reasoning: str = dspy.OutputField(desc="Brief explanation")
def answer_relevance_llm(example: dspy.Example, pred: dspy.Prediction, trace=None) -> float:
    """LLM-as-judge relevance score.

    Costs one extra LLM call per example, but handles nuance that plain
    string comparison cannot.

    Returns:
        1.0 when the judge deems the answer relevant, 0.0 otherwise
        (including when the prediction has no answer at all).
    """
    answer = getattr(pred, 'answer', '')
    if not answer:
        return 0.0
    verdict = dspy.Predict(AnswerRelevanceJudge)(
        question=example.question,
        answer=answer,
    )
    return 1.0 if verdict.relevant else 0.0
def answer_groundedness_llm(example: dspy.Example, pred: dspy.Prediction, trace=None) -> float:
    """LLM-as-judge groundedness score.

    Checks that the answer is supported by the retrieved context
    attached to the prediction.

    Returns:
        1.0 when the judge deems the answer grounded, 0.0 otherwise
        (including when either the answer or the context is missing).
    """
    answer = getattr(pred, 'answer', '')
    context = getattr(pred, 'context', '')
    if not answer or not context:
        return 0.0
    verdict = dspy.Predict(AnswerGroundednessJudge)(
        context=context,
        answer=answer,
    )
    return 1.0 if verdict.grounded else 0.0
### 5. Composite Metric
# tests/dspy_gitops/metrics/composite.py
def heritage_rag_metric(example: dspy.Example, pred: dspy.Prediction, trace=None) -> float:
    """Comprehensive weighted metric for Heritage RAG evaluation.

    Component weights:
        intent accuracy 25%, entity F1 25%, SPARQL syntax 20%,
        non-empty answer 15%, sources used 15%.

    During optimization (``trace is not None``) the weighted score is
    collapsed to a pass/fail bool with a 0.8 threshold.
    """
    from .intent_accuracy import intent_accuracy
    from .entity_extraction import entity_f1
    from .sparql_correctness import sparql_syntax_valid

    weighted_components = (
        (0.25, intent_accuracy(example, pred, trace)),
        (0.25, entity_f1(example, pred, trace)),
        (0.20, sparql_syntax_valid(example, pred, trace)),
        (0.15, 1.0 if getattr(pred, 'answer', '') else 0.0),
        (0.15, 1.0 if getattr(pred, 'sources_used', []) else 0.0),
    )
    total = sum(weight * score for weight, score in weighted_components)
    # Optimizers expect a boolean gate rather than the raw score.
    if trace is not None:
        return total >= 0.8
    return total
def heritage_rag_strict(example: dspy.Example, pred: dspy.Prediction, trace=None) -> bool:
    """Strict pass/fail metric for bootstrapping optimization.

    An example is usable for bootstrapping only when every component
    clears its threshold.
    """
    from .intent_accuracy import intent_accuracy
    from .entity_extraction import entity_f1
    from .sparql_correctness import sparql_syntax_valid

    checks = (
        intent_accuracy(example, pred, trace) >= 0.9,
        entity_f1(example, pred, trace) >= 0.8,
        sparql_syntax_valid(example, pred, trace) >= 1.0,
        bool(getattr(pred, 'answer', '')),
    )
    return all(checks)
## Using Metrics with dspy.Evaluate
from dspy import Evaluate
from metrics import heritage_rag_metric, intent_accuracy, entity_f1
# Create evaluator
# NOTE(review): assumes `dev_set` (list of dspy.Example) and `pipeline`
# (the DSPy module under test) are defined earlier — confirm in context.
evaluator = Evaluate(
    devset=dev_set,
    metric=heritage_rag_metric,
    num_threads=4,  # parallel evaluation workers
    display_progress=True,
    display_table=5,  # Show top 5 results
)
# Run evaluation
result = evaluator(pipeline)
print(f"Overall score: {result.score}%")
# Detailed per-metric breakdown: re-run the cheap component metrics on
# each (example, prediction) pair to see where points were lost.
for example, pred, score in result.results:
    print(f"Question: {example.question[:50]}...")
    print(f" Intent: {intent_accuracy(example, pred)}")
    print(f" Entities: {entity_f1(example, pred):.2f}")
    print(f" Score: {score:.2f}")
## Metric Selection Guidelines
| Use Case | Recommended Metric |
|---|---|
| Quick smoke test | intent_accuracy |
| CI gate | heritage_rag_metric |
| Optimization | heritage_rag_strict |
| Deep analysis | answer_relevance_llm |
| SPARQL debugging | sparql_returns_results |