glam/docs/plan/dspy_gitops/02-metrics.md
2026-01-11 18:08:40 +01:00

386 lines
10 KiB
Markdown

# DSPy Evaluation Metrics
## Overview
Metrics are Python functions that score DSPy module outputs. They're used both for evaluation (tracking progress) and for optimization (guiding optimizers as they select prompts and demonstrations).
## Metric Function Signature
```python
def metric(example: dspy.Example, pred: dspy.Prediction, trace=None) -> float | bool:
    """
    Score a prediction against expected output.

    This is the canonical dspy metric signature: `trace` is always the
    third positional argument, so concrete metrics must keep it there.

    Args:
        example: The input example with expected outputs
        pred: The module's prediction
        trace: Optional trace for intermediate steps (used in optimization)

    Returns:
        Score between 0.0-1.0, or bool for pass/fail
    """
    pass
```
## Core Metrics
### 1. Intent Accuracy
```python
# tests/dspy_gitops/metrics/intent_accuracy.py
def intent_accuracy(example: dspy.Example, pred: dspy.Prediction, trace=None) -> float:
    """
    Exact-match score for the predicted intent.

    Both sides are lowercased and stripped before comparison.

    Returns:
        1.0 when normalized intents match, 0.0 otherwise.
    """
    want = example.expected_intent.strip().lower()
    got = getattr(pred, 'intent', '').strip().lower()
    return float(want == got)
def intent_accuracy_with_fallback(example: dspy.Example, pred: dspy.Prediction, trace=None) -> float:
    """
    Intent accuracy with partial credit for related intents.

    Related intents (0.5 credit, symmetric):
    - statistical <-> exploration
    - entity_lookup <-> relational
    - temporal <-> entity_lookup

    Returns:
        1.0 for an exact match, 0.5 for a related pair, 0.0 otherwise.
    """
    # Normalize the same way intent_accuracy does (lower + strip) so the
    # two metrics agree on what counts as an exact match.
    expected = example.expected_intent.lower().strip()
    predicted = getattr(pred, 'intent', '').lower().strip()
    if expected == predicted:
        return 1.0
    # Partial credit for related intents; set comparison makes each pair
    # symmetric (order of expected/predicted doesn't matter).
    related_pairs = [
        ("statistical", "exploration"),
        ("entity_lookup", "relational"),
        ("temporal", "entity_lookup"),
    ]
    for a, b in related_pairs:
        if {expected, predicted} == {a, b}:
            return 0.5
    return 0.0
```
### 2. Entity Extraction F1
```python
# tests/dspy_gitops/metrics/entity_extraction.py
def entity_f1(example: dspy.Example, pred: dspy.Prediction, trace=None) -> float:
    """
    Calculate F1 score for entity extraction.

    Compares extracted entities against expected entities using fuzzy
    matching (rapidfuzz ratio >= 80) for robustness. Accepts `pred.entities`
    as either a list or a comma-separated string.

    Matching is one-to-one: each predicted entity can satisfy at most one
    expected entity, so true positives never exceed either set's size and
    precision/recall stay in [0, 1]. (The previous version let several
    expected entities match the same predicted entity, which could push
    F1 above 1.0.)
    """
    expected = set(e.lower() for e in example.expected_entities)
    predicted_raw = getattr(pred, 'entities', [])
    if isinstance(predicted_raw, str):
        # Parse comma-separated string
        predicted_raw = [e.strip() for e in predicted_raw.split(',')]
    predicted = set(e.lower() for e in predicted_raw if e)
    # Edge cases: both empty is a perfect match; only one empty is a miss.
    if not expected and not predicted:
        return 1.0
    if not expected or not predicted:
        return 0.0
    # Fuzzy one-to-one matching: consume a predicted entity once matched.
    from rapidfuzz import fuzz
    remaining = set(predicted)
    true_positives = 0
    for exp_entity in expected:
        match = next(
            (p for p in remaining if fuzz.ratio(exp_entity, p) >= 80),
            None,
        )
        if match is not None:
            true_positives += 1
            remaining.remove(match)
    precision = true_positives / len(predicted)
    recall = true_positives / len(expected)
    if precision + recall == 0:
        return 0.0
    return 2 * (precision * recall) / (precision + recall)
def entity_type_accuracy(example: dspy.Example, pred: dspy.Prediction, trace=None) -> float:
    """
    Score 1.0 when the predicted entity type matches the expected one.

    Both sides default to 'institution' when the attribute is absent.
    """
    want = getattr(example, 'expected_entity_type', 'institution')
    got = getattr(pred, 'entity_type', 'institution')
    return float(want == got)
```
### 3. SPARQL Correctness
```python
# tests/dspy_gitops/metrics/sparql_correctness.py
import re
def sparql_syntax_valid(example: dspy.Example, pred: dspy.Prediction, trace=None) -> float:
    """
    Cheap structural validation of a generated SPARQL query.

    Checks:
    - Required clauses (SELECT, WHERE)
    - Balanced braces
    - No leftover '???' placeholders

    Returns:
        1.0 only when every check passes, 0.0 otherwise.
    """
    sparql = getattr(pred, 'sparql_query', '')
    if not sparql:
        return 0.0
    upper = sparql.upper()
    problems = []
    # Required clauses (case-insensitive)
    if 'SELECT' not in upper:
        problems.append("Missing SELECT")
    if 'WHERE' not in upper:
        problems.append("Missing WHERE")
    # Structural sanity: every opening brace needs a closing one
    if sparql.count('{') != sparql.count('}'):
        problems.append("Unbalanced braces")
    # Common LLM failure mode: unfilled placeholder tokens
    if '???' in sparql:
        problems.append("Contains placeholder ???")
    return 0.0 if problems else 1.0
def sparql_executes(
    example: dspy.Example,
    pred: dspy.Prediction,
    trace=None,
    oxigraph_url: str = "http://localhost:7878",
) -> float:
    """
    Check if SPARQL query executes without error.

    Requires a live Oxigraph connection.

    NOTE: `trace` is the third positional parameter so this function
    conforms to the dspy metric signature (example, pred, trace).
    Previously `oxigraph_url` occupied that slot, so dspy would have
    passed the trace object in as the endpoint URL.
    """
    sparql = getattr(pred, 'sparql_query', '')
    if not sparql:
        return 0.0
    # Imported lazily (and after the empty-query short-circuit) so the
    # metric module stays importable without httpx installed.
    import httpx
    try:
        response = httpx.post(
            f"{oxigraph_url}/query",
            data={"query": sparql},
            headers={"Accept": "application/sparql-results+json"},
            timeout=10.0,
        )
        return 1.0 if response.status_code == 200 else 0.0
    except Exception:
        # Any transport-level failure counts as a failed execution.
        return 0.0
def sparql_returns_results(
    example: dspy.Example,
    pred: dspy.Prediction,
    trace=None,
    oxigraph_url: str = "http://localhost:7878",
) -> float:
    """
    Check if SPARQL query returns non-empty results.

    Requires a live Oxigraph connection.

    NOTE: `trace` is the third positional parameter so this function
    conforms to the dspy metric signature (example, pred, trace);
    previously `oxigraph_url` occupied that slot.
    """
    sparql = getattr(pred, 'sparql_query', '')
    if not sparql:
        return 0.0
    # Lazy import after the short-circuit; see sparql_executes.
    import httpx
    try:
        response = httpx.post(
            f"{oxigraph_url}/query",
            data={"query": sparql},
            headers={"Accept": "application/sparql-results+json"},
            timeout=10.0,
        )
        if response.status_code != 200:
            return 0.0
        data = response.json()
        # SPARQL JSON results: results.bindings is the list of solution rows.
        bindings = data.get("results", {}).get("bindings", [])
        return 1.0 if bindings else 0.0
    except Exception:
        return 0.0
```
### 4. Answer Relevance (LLM-as-Judge)
```python
# tests/dspy_gitops/metrics/answer_relevance.py
import dspy
class AnswerRelevanceJudge(dspy.Signature):
    """Judge if an answer is relevant to the question."""

    # Field descs are rendered into the judge LM's prompt; keep them short.
    question: str = dspy.InputField(desc="The original question")
    answer: str = dspy.InputField(desc="The generated answer")
    relevant: bool = dspy.OutputField(desc="Is the answer relevant to the question?")
    reasoning: str = dspy.OutputField(desc="Brief explanation of relevance judgment")
class AnswerGroundednessJudge(dspy.Signature):
    """Judge if an answer is grounded in retrieved context."""

    # Field descs are rendered into the judge LM's prompt; keep them short.
    context: str = dspy.InputField(desc="The retrieved context")
    answer: str = dspy.InputField(desc="The generated answer")
    grounded: bool = dspy.OutputField(desc="Is the answer grounded in the context?")
    reasoning: str = dspy.OutputField(desc="Brief explanation")
def answer_relevance_llm(example: dspy.Example, pred: dspy.Prediction, trace=None) -> float:
    """
    LLM-as-judge relevance score.

    More expensive than string metrics but handles nuanced evaluation.
    An empty answer short-circuits to 0.0 without calling the LM.

    Returns:
        1.0 if the judge deems the answer relevant, 0.0 otherwise.
    """
    answer = getattr(pred, 'answer', '')
    if not answer:
        return 0.0
    verdict = dspy.Predict(AnswerRelevanceJudge)(
        question=example.question,
        answer=answer,
    )
    return 1.0 if verdict.relevant else 0.0
def answer_groundedness_llm(example: dspy.Example, pred: dspy.Prediction, trace=None) -> float:
    """
    LLM-as-judge groundedness score.

    Checks whether the generated answer is supported by the retrieved
    context. Missing answer or context short-circuits to 0.0 without
    calling the LM.

    Returns:
        1.0 if the judge deems the answer grounded, 0.0 otherwise.
    """
    answer = getattr(pred, 'answer', '')
    context = getattr(pred, 'context', '')
    if not (answer and context):
        return 0.0
    verdict = dspy.Predict(AnswerGroundednessJudge)(
        context=context,
        answer=answer,
    )
    return 1.0 if verdict.grounded else 0.0
```
### 5. Composite Metric
```python
# tests/dspy_gitops/metrics/composite.py
def heritage_rag_metric(example: dspy.Example, pred: dspy.Prediction, trace=None) -> float | bool:
    """
    Comprehensive metric for Heritage RAG evaluation.

    Combines multiple aspects:
    - Intent accuracy (25%)
    - Entity extraction (25%)
    - SPARQL validity (20%)
    - Answer presence (15%)
    - Sources used (15%)

    Returns:
        The weighted score (float) during evaluation. During optimization
        (when `trace` is not None) returns a bool: weighted score >= 0.8.
    """
    from .intent_accuracy import intent_accuracy
    from .entity_extraction import entity_f1
    from .sparql_correctness import sparql_syntax_valid

    scores = {
        "intent": intent_accuracy(example, pred, trace),
        "entities": entity_f1(example, pred, trace),
        "sparql": sparql_syntax_valid(example, pred, trace),
        "answer": 1.0 if getattr(pred, 'answer', '') else 0.0,
        "sources": 1.0 if getattr(pred, 'sources_used', []) else 0.0,
    }
    weights = {
        "intent": 0.25,
        "entities": 0.25,
        "sparql": 0.20,
        "answer": 0.15,
        "sources": 0.15,
    }
    weighted_score = sum(scores[k] * weights[k] for k in scores)
    # During optimization dspy passes a trace and expects a pass/fail
    # decision; 0.8 is the acceptance threshold on the weighted score.
    # (Use heritage_rag_strict for an all-components-must-pass gate.)
    if trace is not None:
        return weighted_score >= 0.8
    return weighted_score
def heritage_rag_strict(example: dspy.Example, pred: dspy.Prediction, trace=None) -> bool:
    """
    Strict pass/fail metric for bootstrapping optimization.

    An example is usable for bootstrapping only when every component
    clears its threshold and a non-empty answer is present.
    """
    from .intent_accuracy import intent_accuracy
    from .entity_extraction import entity_f1
    from .sparql_correctness import sparql_syntax_valid

    checks = (
        intent_accuracy(example, pred, trace) >= 0.9,
        entity_f1(example, pred, trace) >= 0.8,
        sparql_syntax_valid(example, pred, trace) >= 1.0,
        bool(getattr(pred, 'answer', '')),
    )
    return all(checks)
```
## Using Metrics with dspy.Evaluate
```python
from dspy import Evaluate
from metrics import heritage_rag_metric, intent_accuracy, entity_f1

# Create evaluator over the held-out dev split; the composite metric
# drives the headline score.
evaluator = Evaluate(
    devset=dev_set,
    metric=heritage_rag_metric,
    num_threads=4,  # parallel metric/LM calls
    display_progress=True,
    display_table=5,  # Show top 5 results
)

# Run evaluation
result = evaluator(pipeline)
print(f"Overall score: {result.score}%")

# Detailed per-metric breakdown (results are (example, prediction, score) triples)
for example, pred, score in result.results:
    print(f"Question: {example.question[:50]}...")
    print(f" Intent: {intent_accuracy(example, pred)}")
    print(f" Entities: {entity_f1(example, pred):.2f}")
    print(f" Score: {score:.2f}")
## Metric Selection Guidelines
| Use Case | Recommended Metric |
|----------|-------------------|
| Quick smoke test | `intent_accuracy` |
| CI gate | `heritage_rag_metric` |
| Optimization | `heritage_rag_strict` |
| Deep analysis | `answer_relevance_llm` |
| SPARQL debugging | `sparql_returns_results` |