"""Composite Metrics.

Combines individual metrics into weighted scores for DSPy evaluation.
"""
from typing import Any, Callable, Optional
|
|
from .intent_accuracy import intent_accuracy_metric, intent_similarity_score
|
|
from .entity_extraction import entity_f1_metric, fuzzy_entity_f1
|
|
from .sparql_correctness import sparql_validation_score
|
|
from .answer_relevance import answer_relevance_metric, language_match_score
|
|
|
|
|
|
def heritage_rag_metric(example: Any, pred: Any, trace: Any = None) -> float:
    """Composite metric for Heritage RAG pipeline evaluation.

    Weights:
        - Intent accuracy: 20%
        - Entity extraction: 20%
        - SPARQL validity: 20%
        - Answer relevance: 40%

    Args:
        example: DSPy Example with expected values
        pred: Prediction with generated values
        trace: Optional trace for debugging

    Returns:
        Weighted composite score 0.0-1.0
    """
    # Intent accuracy: only scored when both sides are present.
    gold_intent = getattr(example, "expected_intent", None)
    guessed_intent = getattr(pred, "intent", None)
    intent_score = (
        intent_similarity_score(gold_intent, guessed_intent)
        if gold_intent and guessed_intent
        else 0.0
    )

    # Entity extraction F1 (fuzzy matching tolerates minor spelling drift).
    entity_score = fuzzy_entity_f1(
        getattr(example, "expected_entities", []),
        getattr(pred, "entities", []),
    )

    # SPARQL validity: a missing query scores zero.
    query = getattr(pred, "sparql", None)
    sparql_score = sparql_validation_score(query) if query else 0.0

    # Answer relevance (largest weight).
    answer_score = answer_relevance_metric(example, pred, trace)

    # Penalize the answer component when it is not in the expected language
    # (default expected language is Dutch, "nl").
    expected_lang = getattr(example, "language", "nl")
    answer_text = getattr(pred, "answer", "")
    if answer_text and language_match_score(expected_lang, answer_text) < 1.0:
        answer_score *= 0.8

    # Weighted combination of the four components.
    return (
        0.20 * intent_score
        + 0.20 * entity_score
        + 0.20 * sparql_score
        + 0.40 * answer_score
    )
|
def create_weighted_metric(
    weights: Optional[dict[str, float]] = None,
    include_sparql: bool = True,
    include_answer: bool = True,
) -> Callable[[Any, Any, Any], float]:
    """Create custom weighted metric function.

    Args:
        weights: Custom weights for each component ("intent", "entity_f1",
            "sparql", "answer"). Missing keys keep their 0.25 defaults;
            the final weights are normalized to sum to 1.0.
        include_sparql: Whether to include SPARQL validation
        include_answer: Whether to include answer evaluation

    Returns:
        Metric function compatible with dspy.Evaluate
    """
    merged = {
        "intent": 0.25,
        "entity_f1": 0.25,
        "sparql": 0.25,
        "answer": 0.25,
    }
    if weights:
        merged.update(weights)

    # Normalize weights so the composite stays in [0, 1]. Guard against an
    # all-zero weight dict, which previously raised ZeroDivisionError here;
    # in that degenerate case the metric simply returns 0.0.
    total_weight = sum(merged.values())
    if total_weight > 0:
        normalized = {k: v / total_weight for k, v in merged.items()}
    else:
        normalized = {k: 0.0 for k in merged}

    def metric(example: Any, pred: Any, trace: Any = None) -> float:
        """Score a prediction with the configured component weights."""
        scores = {}

        # Intent: only scored when both expected and predicted are present.
        expected_intent = getattr(example, "expected_intent", None)
        predicted_intent = getattr(pred, "intent", None)
        if expected_intent and predicted_intent:
            scores["intent"] = intent_similarity_score(expected_intent, predicted_intent)
        else:
            scores["intent"] = 0.0

        # Entity extraction F1.
        expected_entities = getattr(example, "expected_entities", [])
        predicted_entities = getattr(pred, "entities", [])
        scores["entity_f1"] = fuzzy_entity_f1(expected_entities, predicted_entities)

        # SPARQL validity (optional component).
        if include_sparql:
            sparql = getattr(pred, "sparql", None)
            scores["sparql"] = sparql_validation_score(sparql) if sparql else 0.0

        # Answer relevance (optional component).
        if include_answer:
            scores["answer"] = answer_relevance_metric(example, pred, trace)

        # Excluded components contribute 0 via scores.get(k, 0).
        return sum(scores.get(k, 0) * normalized.get(k, 0) for k in normalized)

    return metric
# Pre-defined metric configurations for common evaluation stages.

# Scores only intent classification — useful when tuning the intent module
# in isolation.
INTENT_ONLY_METRIC = create_weighted_metric(
    weights={"intent": 1.0, "entity_f1": 0.0, "sparql": 0.0, "answer": 0.0},
    include_sparql=False,
    include_answer=False,
)

# Balances intent and entity extraction; skips query and answer stages.
CLASSIFICATION_METRIC = create_weighted_metric(
    weights={"intent": 0.5, "entity_f1": 0.5, "sparql": 0.0, "answer": 0.0},
    include_sparql=False,
    include_answer=False,
)

# Emphasizes SPARQL validity while still rewarding the upstream
# intent/entity steps; skips answer evaluation.
SPARQL_GENERATION_METRIC = create_weighted_metric(
    weights={"intent": 0.2, "entity_f1": 0.2, "sparql": 0.6, "answer": 0.0},
    include_sparql=True,
    include_answer=False,
)

# End-to-end evaluation of the full pipeline (fixed 20/20/20/40 weights).
FULL_PIPELINE_METRIC = heritage_rag_metric