# glam/tests/dspy_gitops/metrics/composite.py
"""
Composite Metrics
Combines individual metrics into weighted scores for DSPy evaluation.
"""
from typing import Any, Callable, Optional

from .intent_accuracy import intent_accuracy_metric, intent_similarity_score
from .entity_extraction import entity_f1_metric, fuzzy_entity_f1
from .sparql_correctness import sparql_validation_score
from .answer_relevance import answer_relevance_metric, language_match_score


def heritage_rag_metric(example: Any, pred: Any, trace: Any = None) -> float:
    """Composite metric for Heritage RAG pipeline evaluation.

    Weights:
        - Intent accuracy: 20%
        - Entity extraction: 20%
        - SPARQL validity: 20%
        - Answer relevance: 40%

    Args:
        example: DSPy Example with expected values
        pred: Prediction with generated values
        trace: Optional trace for debugging

    Returns:
        Weighted composite score 0.0-1.0
    """
    scores = {}

    # Intent accuracy (20%)
    expected_intent = getattr(example, "expected_intent", None)
    predicted_intent = getattr(pred, "intent", None)
    if expected_intent and predicted_intent:
        scores["intent"] = intent_similarity_score(expected_intent, predicted_intent)
    else:
        scores["intent"] = 0.0

    # Entity extraction F1 (20%)
    expected_entities = getattr(example, "expected_entities", [])
    predicted_entities = getattr(pred, "entities", [])
    scores["entity_f1"] = fuzzy_entity_f1(expected_entities, predicted_entities)

    # SPARQL validity (20%)
    sparql = getattr(pred, "sparql", None)
    if sparql:
        scores["sparql"] = sparql_validation_score(sparql)
    else:
        scores["sparql"] = 0.0

    # Answer relevance (40%)
    scores["answer"] = answer_relevance_metric(example, pred, trace)

    # Language match penalty: downweight the answer score when the answer
    # is not in the expected language.
    language = getattr(example, "language", "nl")
    answer = getattr(pred, "answer", "")
    if answer:
        lang_score = language_match_score(language, answer)
        if lang_score < 1.0:
            scores["answer"] *= 0.8

    # Weighted combination
    weights = {
        "intent": 0.20,
        "entity_f1": 0.20,
        "sparql": 0.20,
        "answer": 0.40,
    }
    total = sum(scores.get(k, 0.0) * w for k, w in weights.items())
    return total
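

# Usage sketch (assuming the standard dspy.Evaluate API; `devset` and
# `compiled_pipeline` are hypothetical names for your data and program):
#
#     evaluator = dspy.Evaluate(devset=devset, metric=heritage_rag_metric)
#     score = evaluator(compiled_pipeline)
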
def create_weighted_metric(
    weights: Optional[dict[str, float]] = None,
    include_sparql: bool = True,
    include_answer: bool = True,
) -> Callable[[Any, Any, Any], float]:
"""Create custom weighted metric function.
Args:
weights: Custom weights for each component
include_sparql: Whether to include SPARQL validation
include_answer: Whether to include answer evaluation
Returns:
Metric function compatible with dspy.Evaluate
"""
default_weights = {
"intent": 0.25,
"entity_f1": 0.25,
"sparql": 0.25,
"answer": 0.25,
}
if weights:
default_weights.update(weights)
    # Normalize weights so they sum to 1.0 (guard against an all-zero dict)
    total_weight = sum(default_weights.values()) or 1.0
    normalized = {k: v / total_weight for k, v in default_weights.items()}

    def metric(example: Any, pred: Any, trace: Any = None) -> float:
        scores = {}

        # Intent
        expected_intent = getattr(example, "expected_intent", None)
        predicted_intent = getattr(pred, "intent", None)
        if expected_intent and predicted_intent:
            scores["intent"] = intent_similarity_score(expected_intent, predicted_intent)
        else:
            scores["intent"] = 0.0

        # Entities
        expected_entities = getattr(example, "expected_entities", [])
        predicted_entities = getattr(pred, "entities", [])
        scores["entity_f1"] = fuzzy_entity_f1(expected_entities, predicted_entities)

        # SPARQL
        if include_sparql:
            sparql = getattr(pred, "sparql", None)
            scores["sparql"] = sparql_validation_score(sparql) if sparql else 0.0

        # Answer
        if include_answer:
            scores["answer"] = answer_relevance_metric(example, pred, trace)

        return sum(scores.get(k, 0.0) * w for k, w in normalized.items())

    return metric
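

# Example: an answer-heavy custom metric (hypothetical weights; whatever
# is passed gets re-normalized to sum to 1.0 before scoring):
#
#     answer_heavy = create_weighted_metric(weights={"answer": 0.7, "sparql": 0.1})
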
# Pre-defined metric configurations
INTENT_ONLY_METRIC = create_weighted_metric(
    weights={"intent": 1.0, "entity_f1": 0.0, "sparql": 0.0, "answer": 0.0},
    include_sparql=False,
    include_answer=False,
)

CLASSIFICATION_METRIC = create_weighted_metric(
    weights={"intent": 0.5, "entity_f1": 0.5, "sparql": 0.0, "answer": 0.0},
    include_sparql=False,
    include_answer=False,
)

SPARQL_GENERATION_METRIC = create_weighted_metric(
    weights={"intent": 0.2, "entity_f1": 0.2, "sparql": 0.6, "answer": 0.0},
    include_sparql=True,
    include_answer=False,
)

FULL_PIPELINE_METRIC = heritage_rag_metric
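

if __name__ == "__main__":
    # Smoke-test sketch: exercises heritage_rag_metric with stand-in objects.
    # All field values here ("search_artwork", the sample SPARQL, the Dutch
    # strings) are hypothetical. Run as a module so the relative imports
    # resolve, e.g. `python -m glam.tests.dspy_gitops.metrics.composite`.
    from types import SimpleNamespace

    example = SimpleNamespace(
        expected_intent="search_artwork",
        expected_entities=["Rembrandt"],
        language="nl",
        question="Welke schilderijen van Rembrandt zijn er?",
    )
    pred = SimpleNamespace(
        intent="search_artwork",
        entities=["Rembrandt"],
        sparql="SELECT ?work WHERE { ?work ?p ?o } LIMIT 10",
        answer="Er zijn meerdere schilderijen van Rembrandt in de collectie.",
    )
    print(f"heritage_rag_metric: {heritage_rag_metric(example, pred):.3f}")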