""" Metrics for DSPy RAG Evaluation This module provides metric functions for evaluating the Heritage RAG system: - count_accuracy: Checks if COUNT queries return correct counts - slot_extraction_accuracy: Checks if slots are correctly extracted - heritage_rag_metric: Composite metric combining structural and semantic checks """ import re from typing import Any from .dataset_loader import GoldenExample def extract_count_from_answer(answer: str) -> int | None: """ Extract a count number from a natural language answer. Handles patterns like: - "Er zijn 10 archieven in Utrecht." - "In Utrecht zijn er 10 archieven." - "Het aantal archieven in Utrecht is 10." - "Utrecht heeft 10 archieven." Args: answer: The natural language answer string Returns: The extracted count, or None if no count found """ if not answer: return None # Pattern 1: "Er zijn X [type] in [location]" match = re.search(r"(?:Er\s+)?(?:zijn|is)\s+(?:er\s+)?(\d+)", answer, re.IGNORECASE) if match: return int(match.group(1)) # Pattern 2: "[location] heeft X [type]" match = re.search(r"heeft\s+(\d+)", answer, re.IGNORECASE) if match: return int(match.group(1)) # Pattern 3: "Het aantal ... is X" or "totaal van X" match = re.search(r"(?:aantal|totaal)\s+(?:van\s+)?(?:is\s+)?(\d+)", answer, re.IGNORECASE) if match: return int(match.group(1)) # Pattern 4: Just look for any number (fallback) numbers = re.findall(r"\b(\d+)\b", answer) if numbers: # Return the first number found (usually the count) return int(numbers[0]) return None def count_accuracy( example: GoldenExample, prediction: dict[str, Any], tolerance: int = 0, ) -> float: """ Check if the predicted count matches the expected count. Args: example: The golden example with expected count prediction: The RAG response containing 'answer' and optionally 'count' tolerance: Allowed deviation from expected count (default 0 = exact match) Returns: 1.0 if count matches (within tolerance), 0.0 otherwise """ if example.expected_count is None: # No expected count defined (e.g., edge cases) return 0.5 # Neutral score # Try to get count directly from response predicted_count = prediction.get("count") # If not available, try to extract from answer text if predicted_count is None: answer = prediction.get("answer", "") predicted_count = extract_count_from_answer(answer) if predicted_count is None: return 0.0 # Could not extract count # Check if within tolerance diff = abs(predicted_count - example.expected_count) if diff <= tolerance: return 1.0 # Partial credit for close answers (within 10%) percent_diff = diff / max(example.expected_count, 1) if percent_diff <= 0.1: return 0.8 elif percent_diff <= 0.2: return 0.5 return 0.0 def slot_extraction_accuracy( example: GoldenExample, prediction: dict[str, Any], ) -> float: """ Check if the RAG system correctly extracted the expected slots. Slots include: - institution_type: A, M, L, etc. - location: Province code or city name - location_level: subregion or settlement - response_mode: count, list, etc. Args: example: The golden example with expected slots prediction: The RAG response containing slot extraction info Returns: Score between 0.0 and 1.0 based on slot accuracy """ expected = example.expected_slots if not expected: return 1.0 # No slots expected # Get detected slots from prediction # These might be in different locations depending on API response structure detected = {} # Try various locations where slot info might be if "slots" in prediction: detected = prediction["slots"] elif "visualization" in prediction and prediction["visualization"]: vis = prediction["visualization"] # Extract institution_type and location from SPARQL query sparql_query = vis.get("sparql_query", "") # Parse institution type from SPARQL (e.g., 'institutionType "A"') type_match = re.search(r'institutionType\s+"([A-Z])"', sparql_query) if type_match: detected["institution_type"] = type_match.group(1) # Parse location from SPARQL (e.g., 'subregionCode "NL-UT"') location_match = re.search(r'subregionCode\s+"([^"]+)"', sparql_query) if location_match: detected["location"] = location_match.group(1) detected["location_level"] = "subregion" else: # Try settlement name settlement_match = re.search(r'settlementName\s+"([^"]+)"', sparql_query) if settlement_match: detected["location"] = settlement_match.group(1) detected["location_level"] = "settlement" # Check response mode response_modes = vis.get("response_modes", []) if "count" in response_modes: detected["response_mode"] = "count" elif "list" in response_modes: detected["response_mode"] = "list" # Calculate accuracy correct = 0 total = 0 for slot_name, expected_value in expected.items(): if expected_value is None: continue total += 1 detected_value = detected.get(slot_name) # Handle list values (e.g., multiple institution types) if isinstance(expected_value, list): if isinstance(detected_value, list): if set(expected_value) == set(detected_value): correct += 1 elif detected_value in expected_value: correct += 0.5 # Partial credit elif expected_value == "*": # Wildcard - any value is acceptable if detected_value: correct += 1 elif expected_value == detected_value: correct += 1 elif str(expected_value).lower() == str(detected_value).lower(): correct += 0.9 # Close match (case insensitive) if total == 0: return 1.0 return correct / total def heritage_rag_metric( example: GoldenExample, prediction: dict[str, Any], weights: dict[str, float] | None = None, ) -> float: """ Composite metric for evaluating Heritage RAG responses. Combines: - Count accuracy (for COUNT queries) - Slot extraction accuracy - Answer presence and quality checks Args: example: The golden example prediction: The RAG response weights: Optional weights for each component metric Returns: Weighted score between 0.0 and 1.0 """ if weights is None: weights = { "count": 0.5, "slots": 0.3, "answer_present": 0.2, } scores = {} # Count accuracy (for COUNT queries) if example.is_count_query: scores["count"] = count_accuracy(example, prediction) else: scores["count"] = 1.0 # Not a count query, skip this metric # Slot extraction accuracy scores["slots"] = slot_extraction_accuracy(example, prediction) # Answer presence check answer = prediction.get("answer", "") if answer and len(answer) > 10: scores["answer_present"] = 1.0 elif answer: scores["answer_present"] = 0.5 else: scores["answer_present"] = 0.0 # Weighted average total_weight = sum(weights.values()) weighted_score = sum( scores.get(key, 0.0) * weight for key, weight in weights.items() ) / total_weight return weighted_score def format_evaluation_report( results: list[dict[str, Any]], ) -> str: """ Format evaluation results as a human-readable report. Args: results: List of evaluation result dictionaries Returns: Formatted string report """ lines = [ "=" * 60, "Heritage RAG Evaluation Report", "=" * 60, "", ] # Overall statistics total = len(results) passed = sum(1 for r in results if r.get("score", 0) >= 0.8) avg_score = sum(r.get("score", 0) for r in results) / max(total, 1) lines.extend([ f"Total examples: {total}", f"Passed (score >= 0.8): {passed} ({100*passed/max(total,1):.1f}%)", f"Average score: {avg_score:.3f}", "", "-" * 60, ]) # Category breakdown by_category: dict[str, list[float]] = {} for r in results: cat = r.get("category", "unknown") if cat not in by_category: by_category[cat] = [] by_category[cat].append(r.get("score", 0)) lines.append("Scores by category:") for cat, scores in sorted(by_category.items()): avg = sum(scores) / len(scores) lines.append(f" {cat}: {avg:.3f} (n={len(scores)})") lines.extend(["", "-" * 60, ""]) # Failed examples failed = [r for r in results if r.get("score", 0) < 0.8] if failed: lines.append(f"Failed examples ({len(failed)}):") for r in failed[:10]: # Show first 10 lines.append(f" [{r.get('id')}] {r.get('question', '')[:40]}...") lines.append(f" Expected: {r.get('expected_count')}, Got: {r.get('actual_count')}") lines.append(f" Score: {r.get('score', 0):.3f}") if len(failed) > 10: lines.append(f" ... and {len(failed) - 10} more") lines.append("") lines.append("=" * 60) return "\n".join(lines) if __name__ == "__main__": # Quick test class MockExample: expected_count = 10 expected_slots = {"institution_type": "A", "location": "NL-UT"} is_count_query = True example = MockExample() # type: ignore # Test count extraction test_answers = [ ("Er zijn 10 archieven in Utrecht.", 10), ("In Utrecht zijn er 10 archieven.", 10), ("Het aantal archieven is 10.", 10), ("Utrecht heeft 10 archieven.", 10), ("Ik heb 10 resultaten gevonden.", 10), ] print("Count extraction tests:") for answer, expected in test_answers: extracted = extract_count_from_answer(answer) status = "✓" if extracted == expected else "✗" print(f" {status} '{answer[:40]}...' -> {extracted} (expected {expected})")