"""Composite Metrics

Combines individual metrics into weighted scores for DSPy evaluation.
"""

from typing import Any, Callable, Optional

from .intent_accuracy import intent_accuracy_metric, intent_similarity_score
from .entity_extraction import entity_f1_metric, fuzzy_entity_f1
from .sparql_correctness import sparql_validation_score
from .answer_relevance import answer_relevance_metric, language_match_score


def _component_scores(
    example: Any,
    pred: Any,
    trace: Any = None,
    *,
    include_sparql: bool = True,
    include_answer: bool = True,
) -> dict[str, float]:
    """Compute the per-component scores shared by all composite metrics.

    Args:
        example: DSPy Example carrying ``expected_intent`` / ``expected_entities``
        pred: Prediction carrying ``intent`` / ``entities`` / ``sparql`` / ``answer``
        trace: Optional trace, forwarded to the answer-relevance metric
        include_sparql: Whether to score SPARQL validity
        include_answer: Whether to score answer relevance

    Returns:
        Mapping of component name ("intent", "entity_f1", and optionally
        "sparql", "answer") to a 0.0-1.0 score.
    """
    scores: dict[str, float] = {}

    # Intent: only scorable when both sides are present; otherwise 0.
    expected_intent = getattr(example, "expected_intent", None)
    predicted_intent = getattr(pred, "intent", None)
    if expected_intent and predicted_intent:
        scores["intent"] = intent_similarity_score(expected_intent, predicted_intent)
    else:
        scores["intent"] = 0.0

    # Entity extraction F1 (fuzzy matching tolerates minor surface differences).
    expected_entities = getattr(example, "expected_entities", [])
    predicted_entities = getattr(pred, "entities", [])
    scores["entity_f1"] = fuzzy_entity_f1(expected_entities, predicted_entities)

    # SPARQL validity: a missing/empty query scores 0 rather than erroring.
    if include_sparql:
        sparql = getattr(pred, "sparql", None)
        scores["sparql"] = sparql_validation_score(sparql) if sparql else 0.0

    if include_answer:
        scores["answer"] = answer_relevance_metric(example, pred, trace)

    return scores


def heritage_rag_metric(example: Any, pred: Any, trace: Any = None) -> float:
    """Composite metric for Heritage RAG pipeline evaluation.

    Weights:
    - Intent accuracy: 20%
    - Entity extraction: 20%
    - SPARQL validity: 20%
    - Answer relevance: 40%

    Args:
        example: DSPy Example with expected values
        pred: Prediction with generated values
        trace: Optional trace for debugging

    Returns:
        Weighted composite score 0.0-1.0
    """
    scores = _component_scores(example, pred, trace)

    # Language match bonus (adjust answer score): penalize an answer in the
    # wrong language with a flat 20% reduction of its relevance score.
    language = getattr(example, "language", "nl")
    answer = getattr(pred, "answer", "")
    if answer:
        lang_score = language_match_score(language, answer)
        if lang_score < 1.0:
            scores["answer"] *= 0.8

    # Weighted combination
    weights = {
        "intent": 0.20,
        "entity_f1": 0.20,
        "sparql": 0.20,
        "answer": 0.40,
    }
    return sum(scores.get(k, 0) * w for k, w in weights.items())


def create_weighted_metric(
    weights: Optional[dict[str, float]] = None,
    include_sparql: bool = True,
    include_answer: bool = True,
) -> Callable[[Any, Any, Any], float]:
    """Create custom weighted metric function.

    Args:
        weights: Custom weights for each component (merged over equal
            0.25 defaults, then normalized to sum to 1.0)
        include_sparql: Whether to include SPARQL validation
        include_answer: Whether to include answer evaluation

    Returns:
        Metric function compatible with dspy.Evaluate

    Raises:
        ValueError: If the merged weights sum to zero (normalization
            would otherwise divide by zero).
    """
    default_weights = {
        "intent": 0.25,
        "entity_f1": 0.25,
        "sparql": 0.25,
        "answer": 0.25,
    }
    if weights:
        default_weights.update(weights)

    # Normalize weights so the composite score stays in 0.0-1.0.
    total_weight = sum(default_weights.values())
    if total_weight == 0:
        raise ValueError("weights must not all be zero")
    normalized = {k: v / total_weight for k, v in default_weights.items()}

    def metric(example: Any, pred: Any, trace: Any = None) -> float:
        scores = _component_scores(
            example,
            pred,
            trace,
            include_sparql=include_sparql,
            include_answer=include_answer,
        )
        # Components excluded above simply contribute 0 via scores.get().
        return sum(scores.get(k, 0) * normalized.get(k, 0) for k in normalized)

    return metric


# Pre-defined metric configurations
INTENT_ONLY_METRIC = create_weighted_metric(
    weights={"intent": 1.0, "entity_f1": 0.0, "sparql": 0.0, "answer": 0.0},
    include_sparql=False,
    include_answer=False,
)

CLASSIFICATION_METRIC = create_weighted_metric(
    weights={"intent": 0.5, "entity_f1": 0.5, "sparql": 0.0, "answer": 0.0},
    include_sparql=False,
    include_answer=False,
)

SPARQL_GENERATION_METRIC = create_weighted_metric(
    weights={"intent": 0.2, "entity_f1": 0.2, "sparql": 0.6, "answer": 0.0},
    include_sparql=True,
    include_answer=False,
)

FULL_PIPELINE_METRIC = heritage_rag_metric