# DSPy Evaluation Metrics

## Overview

Metrics are Python functions that score DSPy module outputs. They're used for both evaluation (tracking progress) and optimization (training optimizers).

## Metric Function Signature
|
|
|
|
```python
|
|
def metric(example: dspy.Example, pred: dspy.Prediction, trace=None) -> float | bool:
    """
    Score a prediction against expected output.

    This is the canonical metric signature: every metric in this suite
    takes (example, pred, trace) and returns a numeric or boolean score.

    Args:
        example: The input example with expected outputs
        pred: The module's prediction
        trace: Optional trace for intermediate steps (used in optimization)

    Returns:
        Score between 0.0-1.0, or bool for pass/fail
    """
    pass
|
|
```
|
|
|
|
## Core Metrics
|
|
|
|
### 1. Intent Accuracy
|
|
|
|
```python
|
|
# tests/dspy_gitops/metrics/intent_accuracy.py
|
|
|
|
def intent_accuracy(example: dspy.Example, pred: dspy.Prediction, trace=None) -> float:
    """
    Exact-match check between predicted and expected intent.

    Comparison is case-insensitive and ignores surrounding whitespace.
    A missing ``intent`` attribute on the prediction counts as a miss.

    Returns: 1.0 if match, 0.0 otherwise
    """
    want = example.expected_intent.strip().lower()
    got = getattr(pred, 'intent', '').strip().lower()
    return 1.0 if want == got else 0.0
|
|
|
|
|
|
def intent_accuracy_with_fallback(example: dspy.Example, pred: dspy.Prediction, trace=None) -> float:
    """
    Intent accuracy with partial credit for related intents.

    Normalization (lowercase + strip) matches ``intent_accuracy`` so the
    two metrics agree on exact matches regardless of padding.

    Related intents (worth 0.5):
    - statistical <-> exploration
    - entity_lookup <-> relational
    - temporal <-> entity_lookup

    Returns:
        1.0 for an exact match, 0.5 for a related pair, else 0.0.
    """
    expected = example.expected_intent.lower().strip()
    predicted = getattr(pred, 'intent', '').lower().strip()

    if expected == predicted:
        return 1.0

    # Partial credit for related intents; set comparison makes each
    # pair order-insensitive.
    related_pairs = [
        ("statistical", "exploration"),
        ("entity_lookup", "relational"),
        ("temporal", "entity_lookup"),
    ]
    for a, b in related_pairs:
        if {expected, predicted} == {a, b}:
            return 0.5

    return 0.0
|
|
```
|
|
|
|
### 2. Entity Extraction F1
|
|
|
|
```python
|
|
# tests/dspy_gitops/metrics/entity_extraction.py
|
|
|
|
def entity_f1(example: dspy.Example, pred: dspy.Prediction, trace=None) -> float:
    """
    F1 score for entity extraction.

    Expected and predicted entities are lowercased, then matched with
    fuzzy string comparison (rapidfuzz ratio >= 80) for robustness to
    minor spelling differences.
    """
    expected = {e.lower() for e in example.expected_entities}

    raw = getattr(pred, 'entities', [])
    if isinstance(raw, str):
        # The model may emit a comma-separated string instead of a list.
        raw = [part.strip() for part in raw.split(',')]
    predicted = {e.lower() for e in raw if e}

    # Degenerate cases: both empty is a perfect score; one side empty is zero.
    if not expected and not predicted:
        return 1.0
    if not expected or not predicted:
        return 0.0

    # Fuzzy matching: an expected entity counts as found if any predicted
    # entity is at least 80% similar.
    from rapidfuzz import fuzz

    true_positives = sum(
        1
        for want in expected
        if any(fuzz.ratio(want, got) >= 80 for got in predicted)
    )

    precision = true_positives / len(predicted)
    recall = true_positives / len(expected)
    if precision + recall == 0:
        return 0.0
    return 2 * (precision * recall) / (precision + recall)
|
|
|
|
|
|
def entity_type_accuracy(example: dspy.Example, pred: dspy.Prediction, trace=None) -> float:
    """Score 1.0 when the predicted entity type matches the expected one.

    Both sides default to 'institution' when the attribute is absent.
    """
    want = getattr(example, 'expected_entity_type', 'institution')
    got = getattr(pred, 'entity_type', 'institution')
    return float(want == got)
|
|
```
|
|
|
|
### 3. SPARQL Correctness
|
|
|
|
```python
|
|
# tests/dspy_gitops/metrics/sparql_correctness.py
|
|
|
|
import re
|
|
|
|
def sparql_syntax_valid(example: dspy.Example, pred: dspy.Prediction, trace=None) -> float:
    """
    Heuristic SPARQL syntax validation.

    Checks:
    - Required clauses (SELECT, WHERE)
    - Balanced braces
    - No ``???`` placeholder left in the query

    Returns 1.0 only when every check passes; an empty query scores 0.0.
    """
    sparql = getattr(pred, 'sparql_query', '')
    if not sparql:
        return 0.0

    upper = sparql.upper()
    checks = (
        'SELECT' in upper,                        # required clause
        'WHERE' in upper,                         # required clause
        sparql.count('{') == sparql.count('}'),   # balanced braces
        '???' not in sparql,                      # unfilled placeholder
    )
    return 1.0 if all(checks) else 0.0
|
|
|
|
|
|
def sparql_executes(
    example: dspy.Example,
    pred: dspy.Prediction,
    oxigraph_url: str = "http://localhost:7878",
    trace=None,
) -> float:
    """
    Check if the predicted SPARQL query executes without error.

    Requires a live Oxigraph connection; any transport failure or
    non-200 response scores 0.0 rather than raising.

    Args:
        example: Unused; present to satisfy the metric signature.
        pred: Prediction expected to carry a ``sparql_query`` attribute.
        oxigraph_url: Base URL of the Oxigraph SPARQL endpoint.
        trace: Unused optimization trace (metric-signature convention).

    Returns:
        1.0 when the endpoint answers HTTP 200, else 0.0.
    """
    sparql = getattr(pred, 'sparql_query', '')
    if not sparql:
        return 0.0

    # Imported lazily so the empty-query guard above works (and scores 0.0)
    # even in environments where httpx is not installed.
    import httpx

    try:
        response = httpx.post(
            f"{oxigraph_url}/query",
            data={"query": sparql},
            headers={"Accept": "application/sparql-results+json"},
            timeout=10.0,
        )
        return 1.0 if response.status_code == 200 else 0.0
    except Exception:
        # Connection errors, timeouts, etc. count as a failed execution.
        return 0.0
|
|
|
|
|
|
def sparql_returns_results(
    example: dspy.Example,
    pred: dspy.Prediction,
    oxigraph_url: str = "http://localhost:7878",
    trace=None,
) -> float:
    """
    Check if the predicted SPARQL query returns non-empty results.

    Args:
        example: Unused; present to satisfy the metric signature.
        pred: Prediction expected to carry a ``sparql_query`` attribute.
        oxigraph_url: Base URL of the Oxigraph SPARQL endpoint.
        trace: Unused optimization trace (metric-signature convention).

    Returns:
        1.0 when the endpoint answers HTTP 200 with at least one result
        binding; 0.0 for empty results, non-200 responses, or any
        transport/parse error.
    """
    sparql = getattr(pred, 'sparql_query', '')
    if not sparql:
        return 0.0

    # Imported lazily so the empty-query guard above works (and scores 0.0)
    # even in environments where httpx is not installed.
    import httpx

    try:
        response = httpx.post(
            f"{oxigraph_url}/query",
            data={"query": sparql},
            headers={"Accept": "application/sparql-results+json"},
            timeout=10.0,
        )
        if response.status_code != 200:
            return 0.0

        # SPARQL JSON results shape: {"results": {"bindings": [...]}}
        bindings = response.json().get("results", {}).get("bindings", [])
        return 1.0 if bindings else 0.0
    except Exception:
        return 0.0
|
|
```
|
|
|
|
### 4. Answer Relevance (LLM-as-Judge)
|
|
|
|
```python
|
|
# tests/dspy_gitops/metrics/answer_relevance.py
|
|
|
|
import dspy
|
|
|
|
class AnswerRelevanceJudge(dspy.Signature):
    """Judge if an answer is relevant to the question.

    LLM-as-judge signature: given the original question and a generated
    answer, emit a boolean relevance verdict plus a brief justification.
    """

    question: str = dspy.InputField(desc="The original question")
    answer: str = dspy.InputField(desc="The generated answer")

    relevant: bool = dspy.OutputField(desc="Is the answer relevant to the question?")
    reasoning: str = dspy.OutputField(desc="Brief explanation of relevance judgment")
|
|
|
|
|
|
class AnswerGroundednessJudge(dspy.Signature):
    """Judge if an answer is grounded in retrieved context.

    LLM-as-judge signature: given the retrieved context and a generated
    answer, emit a boolean groundedness verdict plus a brief explanation.
    """

    context: str = dspy.InputField(desc="The retrieved context")
    answer: str = dspy.InputField(desc="The generated answer")

    grounded: bool = dspy.OutputField(desc="Is the answer grounded in the context?")
    reasoning: str = dspy.OutputField(desc="Brief explanation")
|
|
|
|
|
|
def answer_relevance_llm(example: dspy.Example, pred: dspy.Prediction, trace=None) -> float:
    """
    Use an LLM judge to score answer relevance.

    More expensive than string metrics but handles nuanced evaluation.
    A prediction without an answer scores 0.0 without calling the judge.
    """
    answer = getattr(pred, 'answer', '')
    if not answer:
        return 0.0

    verdict = dspy.Predict(AnswerRelevanceJudge)(
        question=example.question,
        answer=answer,
    )
    return 1.0 if verdict.relevant else 0.0
|
|
|
|
|
|
def answer_groundedness_llm(example: dspy.Example, pred: dspy.Prediction, trace=None) -> float:
    """
    Use an LLM judge to score whether the answer is grounded in context.

    Scores 0.0 without calling the judge when either the answer or the
    retrieved context is missing from the prediction.
    """
    answer = getattr(pred, 'answer', '')
    context = getattr(pred, 'context', '')
    if not (answer and context):
        return 0.0

    verdict = dspy.Predict(AnswerGroundednessJudge)(
        context=context,
        answer=answer,
    )
    return 1.0 if verdict.grounded else 0.0
|
|
```
|
|
|
|
### 5. Composite Metric
|
|
|
|
```python
|
|
# tests/dspy_gitops/metrics/composite.py
|
|
|
|
def heritage_rag_metric(example: dspy.Example, pred: dspy.Prediction, trace=None) -> float:
    """
    Comprehensive metric for Heritage RAG evaluation.

    Weighted blend of component scores:
    - Intent accuracy (25%)
    - Entity extraction (25%)
    - SPARQL validity (20%)
    - Answer presence (15%)
    - Sources used (15%)

    During optimization (``trace`` is not None) the metric degrades to a
    strict pass/fail: True only when the weighted score reaches 0.8.
    """
    from .intent_accuracy import intent_accuracy
    from .entity_extraction import entity_f1
    from .sparql_correctness import sparql_syntax_valid

    # (score, weight) per component; weights sum to 1.0.
    components = {
        "intent": (intent_accuracy(example, pred, trace), 0.25),
        "entities": (entity_f1(example, pred, trace), 0.25),
        "sparql": (sparql_syntax_valid(example, pred, trace), 0.20),
        "answer": (1.0 if getattr(pred, 'answer', '') else 0.0, 0.15),
        "sources": (1.0 if getattr(pred, 'sources_used', []) else 0.0, 0.15),
    }

    weighted_score = sum(score * weight for score, weight in components.values())

    # For trace (optimization), require all components to pass
    if trace is not None:
        return weighted_score >= 0.8

    return weighted_score
|
|
|
|
|
|
def heritage_rag_strict(example: dspy.Example, pred: dspy.Prediction, trace=None) -> bool:
    """
    Strict pass/fail metric for bootstrapping optimization.

    Every component must clear its threshold for bootstrapping to use
    this example.
    """
    from .intent_accuracy import intent_accuracy
    from .entity_extraction import entity_f1
    from .sparql_correctness import sparql_syntax_valid

    checks = (
        intent_accuracy(example, pred, trace) >= 0.9,
        entity_f1(example, pred, trace) >= 0.8,
        sparql_syntax_valid(example, pred, trace) >= 1.0,
        bool(getattr(pred, 'answer', '')),
    )
    return all(checks)
|
|
```
|
|
|
|
## Using Metrics with dspy.Evaluate
|
|
|
|
```python
|
|
from dspy import Evaluate
from metrics import heritage_rag_metric, intent_accuracy, entity_f1

# Create evaluator over the dev split; the composite metric drives the score.
evaluator = Evaluate(
    devset=dev_set,
    metric=heritage_rag_metric,
    num_threads=4,  # parallel workers for faster evaluation
    display_progress=True,
    display_table=5,  # Show top 5 results
)

# Run evaluation against the compiled pipeline.
result = evaluator(pipeline)

print(f"Overall score: {result.score}%")

# Detailed per-metric breakdown: re-score each (example, prediction) pair
# with the individual metrics for diagnosis.
for example, pred, score in result.results:
    print(f"Question: {example.question[:50]}...")
    print(f"  Intent: {intent_accuracy(example, pred)}")
    print(f"  Entities: {entity_f1(example, pred):.2f}")
    print(f"  Score: {score:.2f}")
```
|
|
|
|
## Metric Selection Guidelines
|
|
|
|
| Use Case | Recommended Metric |
|----------|-------------------|
| Quick smoke test | `intent_accuracy` |
| CI gate | `heritage_rag_metric` |
| Optimization | `heritage_rag_strict` |
| Deep analysis | `answer_relevance_llm` |
| SPARQL debugging | `sparql_returns_results` |