glam/docs/plan/dspy_gitops/02-metrics.md
2026-01-11 18:08:40 +01:00

10 KiB

DSPy Evaluation Metrics

Overview

Metrics are Python functions that score DSPy module outputs. They're used for both evaluation (tracking progress) and optimization (training optimizers).

Metric Function Signature

def metric(example: dspy.Example, pred: dspy.Prediction, trace=None) -> float | bool:
    """
    Score a prediction against expected output.
    
    Args:
        example: The input example with expected outputs
        pred: The module's prediction
        trace: Optional trace for intermediate steps (used in optimization)
        
    Returns:
        Score between 0.0-1.0, or bool for pass/fail
    """
    # Convention (see heritage_rag_metric below): return a float in [0, 1]
    # during evaluation; when a trace is supplied (optimization), return a
    # bool pass/fail instead.
    pass

Core Metrics

1. Intent Accuracy

# tests/dspy_gitops/metrics/intent_accuracy.py

def intent_accuracy(example: "dspy.Example", pred: "dspy.Prediction", trace=None) -> float:
    """
    Check if the predicted intent exactly matches the expected intent.

    Annotations are string literals so this module can be imported without
    dspy: the file never does `import dspy`, so unquoted `dspy.Example`
    annotations would raise NameError at definition time.

    Args:
        example: Example carrying an `expected_intent` attribute.
        pred: Prediction expected to carry an `intent` attribute.
        trace: Unused; present to satisfy the DSPy metric signature.

    Returns:
        1.0 on a case-insensitive, whitespace-trimmed match, 0.0 otherwise.
    """
    expected = example.expected_intent.lower().strip()
    # Missing `intent` attribute scores as empty string -> 0.0, not a crash.
    predicted = getattr(pred, 'intent', '').lower().strip()

    return 1.0 if expected == predicted else 0.0


def intent_accuracy_with_fallback(example: "dspy.Example", pred: "dspy.Prediction", trace=None) -> float:
    """
    Intent accuracy with partial credit for related intents.

    Related intents (0.5 credit, order-independent):
    - statistical <-> exploration
    - entity_lookup <-> relational
    - temporal <-> entity_lookup

    Annotations are string literals so the module imports without dspy.

    Returns:
        1.0 for an exact match, 0.5 for a related pair, 0.0 otherwise.
    """
    # .strip() added for consistency with intent_accuracy: stray whitespace
    # around a label should not void the match or the partial credit.
    expected = example.expected_intent.lower().strip()
    predicted = getattr(pred, 'intent', '').lower().strip()

    if expected == predicted:
        return 1.0

    # Partial credit for related intents; set comparison makes each pair
    # symmetric (a<->b and b<->a both score 0.5).
    related_pairs = [
        ("statistical", "exploration"),
        ("entity_lookup", "relational"),
        ("temporal", "entity_lookup"),
    ]

    for a, b in related_pairs:
        if {expected, predicted} == {a, b}:
            return 0.5

    return 0.0

2. Entity Extraction F1

# tests/dspy_gitops/metrics/entity_extraction.py

def entity_f1(example: "dspy.Example", pred: "dspy.Prediction", trace=None) -> float:
    """
    Calculate F1 score for entity extraction.

    Compares extracted entities against expected entities using fuzzy
    matching (rapidfuzz ratio >= 80) for robustness. Matching is one-to-one:
    each predicted entity may satisfy at most one expected entity.

    Annotations are string literals so the module imports without dspy.

    Args:
        example: Example carrying `expected_entities` (iterable of str).
        pred: Prediction carrying `entities` (list of str, or a
            comma-separated string).
        trace: Unused; present to satisfy the DSPy metric signature.

    Returns:
        F1 in [0, 1]; 1.0 when both sides are empty.
    """
    expected = set(e.lower() for e in example.expected_entities)
    predicted_raw = getattr(pred, 'entities', [])

    if isinstance(predicted_raw, str):
        # Accept a comma-separated string as well as a list.
        predicted_raw = [e.strip() for e in predicted_raw.split(',')]

    predicted = set(e.lower() for e in predicted_raw if e)

    # Both empty: vacuously perfect. Exactly one empty: total miss.
    if not expected and not predicted:
        return 1.0
    if not expected or not predicted:
        return 0.0

    from rapidfuzz import fuzz

    # One-to-one greedy matching. The previous version let a single
    # predicted entity match several expected entities, which could push
    # true_positives above len(predicted) and yield precision/F1 > 1.
    unmatched = set(predicted)
    true_positives = 0
    for exp_entity in expected:
        for pred_entity in unmatched:
            if fuzz.ratio(exp_entity, pred_entity) >= 80:
                true_positives += 1
                unmatched.discard(pred_entity)
                break  # consume this prediction; move to next expected

    precision = true_positives / len(predicted)
    recall = true_positives / len(expected)

    if precision + recall == 0:
        return 0.0

    return 2 * (precision * recall) / (precision + recall)


def entity_type_accuracy(example: "dspy.Example", pred: "dspy.Prediction", trace=None) -> float:
    """
    Check if the entity type (institution vs person) is correct.

    Annotations are string literals so the module imports without dspy
    (the file never imports dspy, so unquoted annotations would raise
    NameError at definition time).

    Returns:
        1.0 on exact match (both sides default to 'institution'), else 0.0.
    """
    expected = getattr(example, 'expected_entity_type', 'institution')
    predicted = getattr(pred, 'entity_type', 'institution')

    return 1.0 if expected == predicted else 0.0

3. SPARQL Correctness

# tests/dspy_gitops/metrics/sparql_correctness.py

import re

def sparql_syntax_valid(example: "dspy.Example", pred: "dspy.Prediction", trace=None) -> float:
    """
    Lightweight SPARQL sanity check (heuristic, not a full parser).

    Checks:
    - Required clauses (SELECT, WHERE)
    - Balanced braces (by count)
    - No leftover '???' placeholder

    Annotations are string literals so the module imports without dspy.

    Returns:
        1.0 when every check passes, 0.0 otherwise (including empty query).
    """
    sparql = getattr(pred, 'sparql_query', '')

    if not sparql:
        return 0.0

    upper = sparql.upper()  # hoisted: shared by both clause checks
    errors = []

    # Check required clauses
    if 'SELECT' not in upper:
        errors.append("Missing SELECT")
    if 'WHERE' not in upper:
        errors.append("Missing WHERE")

    # Check balanced braces
    if sparql.count('{') != sparql.count('}'):
        errors.append("Unbalanced braces")

    # Check for common errors
    if '???' in sparql:
        errors.append("Contains placeholder ???")

    return 1.0 if not errors else 0.0


def sparql_executes(
    example: "dspy.Example",
    pred: "dspy.Prediction",
    trace=None,
    oxigraph_url: str = "http://localhost:7878",
) -> float:
    """
    Check if the predicted SPARQL query executes without error.

    Requires a live Oxigraph endpoint at `oxigraph_url`.

    Note: `trace` must be the third positional parameter. DSPy invokes
    metrics as metric(example, pred, trace); with the previous order
    (oxigraph_url third) the trace object was silently bound to
    `oxigraph_url`. Keyword callers of `oxigraph_url=` are unaffected.

    Returns:
        1.0 on HTTP 200, 0.0 on any error, non-200, or empty query.
    """
    sparql = getattr(pred, 'sparql_query', '')
    if not sparql:
        return 0.0

    # Imported lazily so the module is usable without httpx until a query
    # actually needs to be executed.
    import httpx

    try:
        response = httpx.post(
            f"{oxigraph_url}/query",
            data={"query": sparql},
            headers={"Accept": "application/sparql-results+json"},
            timeout=10.0,
        )
        return 1.0 if response.status_code == 200 else 0.0
    except Exception:
        # Endpoint/network failures score 0 rather than abort evaluation.
        return 0.0


def sparql_returns_results(
    example: "dspy.Example",
    pred: "dspy.Prediction",
    trace=None,
    oxigraph_url: str = "http://localhost:7878",
) -> float:
    """
    Check if the predicted SPARQL query returns non-empty results.

    Requires a live Oxigraph endpoint at `oxigraph_url`.

    Note: `trace` must be the third positional parameter. DSPy invokes
    metrics as metric(example, pred, trace); with the previous order
    (oxigraph_url third) the trace object was silently bound to
    `oxigraph_url`. Keyword callers of `oxigraph_url=` are unaffected.

    Returns:
        1.0 when the query succeeds and has at least one binding, else 0.0.
    """
    sparql = getattr(pred, 'sparql_query', '')
    if not sparql:
        return 0.0

    # Imported lazily so the module is usable without httpx until a query
    # actually needs to be executed.
    import httpx

    try:
        response = httpx.post(
            f"{oxigraph_url}/query",
            data={"query": sparql},
            headers={"Accept": "application/sparql-results+json"},
            timeout=10.0,
        )
        if response.status_code != 200:
            return 0.0

        data = response.json()
        # SPARQL JSON results: {"results": {"bindings": [...]}}
        bindings = data.get("results", {}).get("bindings", [])
        return 1.0 if bindings else 0.0

    except Exception:
        # Endpoint/network failures score 0 rather than abort evaluation.
        return 0.0

4. Answer Relevance (LLM-as-Judge)

# tests/dspy_gitops/metrics/answer_relevance.py

import dspy

class AnswerRelevanceJudge(dspy.Signature):
    """Judge if an answer is relevant to the question."""
    
    # NOTE(review): in DSPy a Signature's docstring typically becomes the
    # LLM instruction prompt — treat it as behavior, not mere documentation.
    question: str = dspy.InputField(desc="The original question")
    answer: str = dspy.InputField(desc="The generated answer")
    
    relevant: bool = dspy.OutputField(desc="Is the answer relevant to the question?")
    reasoning: str = dspy.OutputField(desc="Brief explanation of relevance judgment")


class AnswerGroundednessJudge(dspy.Signature):
    """Judge if an answer is grounded in retrieved context."""
    
    # NOTE(review): in DSPy a Signature's docstring typically becomes the
    # LLM instruction prompt — treat it as behavior, not mere documentation.
    context: str = dspy.InputField(desc="The retrieved context")
    answer: str = dspy.InputField(desc="The generated answer")
    
    grounded: bool = dspy.OutputField(desc="Is the answer grounded in the context?")
    reasoning: str = dspy.OutputField(desc="Brief explanation")


def answer_relevance_llm(example: dspy.Example, pred: dspy.Prediction, trace=None) -> float:
    """
    LLM-as-judge relevance metric: 1.0 if the judge deems the predicted
    answer relevant to the question, 0.0 otherwise.

    More expensive than string-based metrics, but handles nuanced cases.
    """
    answer = getattr(pred, 'answer', '')
    if not answer:
        # No answer produced: trivially irrelevant, skip the LLM call.
        return 0.0

    verdict = dspy.Predict(AnswerRelevanceJudge)(
        question=example.question,
        answer=answer,
    )
    return float(bool(verdict.relevant))


def answer_groundedness_llm(example: dspy.Example, pred: dspy.Prediction, trace=None) -> float:
    """
    LLM-as-judge groundedness metric: 1.0 if the judge deems the answer
    grounded in the retrieved context, 0.0 otherwise.
    """
    answer = getattr(pred, 'answer', '')
    context = getattr(pred, 'context', '')

    # Either piece missing means groundedness cannot be established.
    if not (answer and context):
        return 0.0

    verdict = dspy.Predict(AnswerGroundednessJudge)(
        context=context,
        answer=answer,
    )
    return float(bool(verdict.grounded))

5. Composite Metric

# tests/dspy_gitops/metrics/composite.py

def heritage_rag_metric(example: "dspy.Example", pred: "dspy.Prediction", trace=None) -> float | bool:
    """
    Comprehensive metric for Heritage RAG evaluation.

    Combines multiple aspects:
    - Intent accuracy (25%)
    - Entity extraction (25%)
    - SPARQL validity (20%)
    - Answer presence (15%)
    - Sources used (15%)

    Annotations are string literals so this module imports without dspy.
    Return annotation fixed to `float | bool` (matching the documented
    metric signature): a bool is returned in optimization mode.

    Returns:
        Evaluation mode (trace is None): weighted score in [0, 1].
        Optimization mode (trace given): True iff the weighted score
        reaches 0.8.
    """
    from .intent_accuracy import intent_accuracy
    from .entity_extraction import entity_f1
    from .sparql_correctness import sparql_syntax_valid

    scores = {
        "intent": intent_accuracy(example, pred, trace),
        "entities": entity_f1(example, pred, trace),
        "sparql": sparql_syntax_valid(example, pred, trace),
        "answer": 1.0 if getattr(pred, 'answer', '') else 0.0,
        "sources": 1.0 if getattr(pred, 'sources_used', []) else 0.0,
    }

    weights = {
        "intent": 0.25,
        "entities": 0.25,
        "sparql": 0.20,
        "answer": 0.15,
        "sources": 0.15,
    }

    weighted_score = sum(scores[k] * weights[k] for k in scores)

    # Optimization mode: return pass/fail on a 0.8 weighted threshold.
    # (The previous comment claimed "all components must pass", which is
    # not what this threshold enforces — heritage_rag_strict does that.)
    if trace is not None:
        return weighted_score >= 0.8

    return weighted_score


def heritage_rag_strict(example: "dspy.Example", pred: "dspy.Prediction", trace=None) -> bool:
    """
    Strict pass/fail metric for bootstrapping optimization.

    All components must clear their thresholds for bootstrapping to keep
    this example. Annotations are string literals so this module imports
    without dspy (the file never imports dspy, so unquoted annotations
    would raise NameError at definition time).

    Returns:
        True iff intent >= 0.9, entity F1 >= 0.8, SPARQL syntax valid,
        and a non-empty answer is present.
    """
    from .intent_accuracy import intent_accuracy
    from .entity_extraction import entity_f1
    from .sparql_correctness import sparql_syntax_valid

    return all([
        intent_accuracy(example, pred, trace) >= 0.9,
        entity_f1(example, pred, trace) >= 0.8,
        sparql_syntax_valid(example, pred, trace) >= 1.0,
        bool(getattr(pred, 'answer', '')),
    ])

Using Metrics with dspy.Evaluate

from dspy import Evaluate
from metrics import heritage_rag_metric, intent_accuracy, entity_f1

# Create the evaluator over the dev split. `dev_set` is assumed to be a
# list of dspy.Example objects prepared elsewhere — confirm against the
# dataset-loading code.
evaluator = Evaluate(
    devset=dev_set,
    metric=heritage_rag_metric,
    num_threads=4,
    display_progress=True,
    display_table=5,  # number of result rows to display — presumably the first 5, not top-scoring; verify
)

# Run evaluation against the pipeline (the DSPy module under test,
# constructed elsewhere).
result = evaluator(pipeline)

print(f"Overall score: {result.score}%")

# Detailed per-metric breakdown: re-score each (example, prediction) pair
# with the individual metrics for diagnosis.
for example, pred, score in result.results:
    print(f"Question: {example.question[:50]}...")
    print(f"  Intent: {intent_accuracy(example, pred)}")
    print(f"  Entities: {entity_f1(example, pred):.2f}")
    print(f"  Score: {score:.2f}")

Metric Selection Guidelines

| Use Case | Recommended Metric |
| --- | --- |
| Quick smoke test | `intent_accuracy` |
| CI gate | `heritage_rag_metric` |
| Optimization | `heritage_rag_strict` |
| Deep analysis | `answer_relevance_llm` |
| SPARQL debugging | `sparql_returns_results` |