glam/backend/rag/evaluation/metrics.py
2026-01-09 20:35:19 +01:00

343 lines
10 KiB
Python

"""
Metrics for DSPy RAG Evaluation
This module provides metric functions for evaluating the Heritage RAG system:
- count_accuracy: Checks if COUNT queries return correct counts
- slot_extraction_accuracy: Checks if slots are correctly extracted
- heritage_rag_metric: Composite metric combining structural and semantic checks
"""
import re
from typing import Any
from .dataset_loader import GoldenExample
def extract_count_from_answer(answer: str) -> int | None:
"""
Extract a count number from a natural language answer.
Handles patterns like:
- "Er zijn 10 archieven in Utrecht."
- "In Utrecht zijn er 10 archieven."
- "Het aantal archieven in Utrecht is 10."
- "Utrecht heeft 10 archieven."
Args:
answer: The natural language answer string
Returns:
The extracted count, or None if no count found
"""
if not answer:
return None
# Pattern 1: "Er zijn X [type] in [location]"
match = re.search(r"(?:Er\s+)?(?:zijn|is)\s+(?:er\s+)?(\d+)", answer, re.IGNORECASE)
if match:
return int(match.group(1))
# Pattern 2: "[location] heeft X [type]"
match = re.search(r"heeft\s+(\d+)", answer, re.IGNORECASE)
if match:
return int(match.group(1))
# Pattern 3: "Het aantal ... is X" or "totaal van X"
match = re.search(r"(?:aantal|totaal)\s+(?:van\s+)?(?:is\s+)?(\d+)", answer, re.IGNORECASE)
if match:
return int(match.group(1))
# Pattern 4: Just look for any number (fallback)
numbers = re.findall(r"\b(\d+)\b", answer)
if numbers:
# Return the first number found (usually the count)
return int(numbers[0])
return None
def count_accuracy(
    example: GoldenExample,
    prediction: dict[str, Any],
    tolerance: int = 0,
) -> float:
    """
    Check if the predicted count matches the expected count.

    Args:
        example: The golden example with expected count
        prediction: The RAG response containing 'answer' and optionally 'count'
        tolerance: Allowed deviation from expected count (default 0 = exact match)

    Returns:
        1.0 if count matches (within tolerance), 0.0 otherwise
    """
    expected = example.expected_count
    if expected is None:
        # No expected count defined (e.g., edge cases) -> neutral score
        return 0.5

    # Prefer an explicit count field; otherwise mine it from the answer text.
    predicted = prediction.get("count")
    if predicted is None:
        predicted = extract_count_from_answer(prediction.get("answer", ""))
        if predicted is None:
            return 0.0  # Could not extract count

    deviation = abs(predicted - expected)
    if deviation <= tolerance:
        return 1.0

    # Partial credit for near misses, relative to the expected value.
    relative_error = deviation / max(expected, 1)
    if relative_error <= 0.1:
        return 0.8
    if relative_error <= 0.2:
        return 0.5
    return 0.0
def slot_extraction_accuracy(
    example: GoldenExample,
    prediction: dict[str, Any],
) -> float:
    """
    Check if the RAG system correctly extracted the expected slots.

    Slots include:
    - institution_type: A, M, L, etc.
    - location: Province code or city name
    - location_level: subregion or settlement
    - response_mode: count, list, etc.

    Args:
        example: The golden example with expected slots
        prediction: The RAG response containing slot extraction info

    Returns:
        Score between 0.0 and 1.0 based on slot accuracy
    """
    expected = example.expected_slots
    if not expected:
        return 1.0  # No slots expected

    detected = _detect_slots(prediction)

    # Average the per-slot scores; expected slots set to None are skipped.
    correct = 0.0
    total = 0
    for slot_name, expected_value in expected.items():
        if expected_value is None:
            continue
        total += 1
        correct += _score_slot(expected_value, detected.get(slot_name))

    if total == 0:
        return 1.0
    return correct / total


def _detect_slots(prediction: dict[str, Any]) -> dict[str, Any]:
    """Pull the detected slots out of a RAG response.

    Prefers an explicit 'slots' dict; otherwise falls back to parsing the
    SPARQL query embedded in the 'visualization' payload.
    """
    # Bug fix: the 'slots' key may be present but hold None; the original
    # code then crashed with AttributeError on detected.get(...).
    slots = prediction.get("slots")
    if isinstance(slots, dict):
        return slots

    detected: dict[str, Any] = {}
    vis = prediction.get("visualization")
    if not vis:
        return detected

    sparql_query = vis.get("sparql_query", "")
    # Parse institution type from SPARQL (e.g., 'institutionType "A"')
    type_match = re.search(r'institutionType\s+"([A-Z])"', sparql_query)
    if type_match:
        detected["institution_type"] = type_match.group(1)
    # Parse location from SPARQL (e.g., 'subregionCode "NL-UT"');
    # subregion takes precedence over settlement name.
    location_match = re.search(r'subregionCode\s+"([^"]+)"', sparql_query)
    if location_match:
        detected["location"] = location_match.group(1)
        detected["location_level"] = "subregion"
    else:
        settlement_match = re.search(r'settlementName\s+"([^"]+)"', sparql_query)
        if settlement_match:
            detected["location"] = settlement_match.group(1)
            detected["location_level"] = "settlement"
    # Response mode: 'count' wins over 'list' when both are present
    response_modes = vis.get("response_modes", [])
    if "count" in response_modes:
        detected["response_mode"] = "count"
    elif "list" in response_modes:
        detected["response_mode"] = "list"
    return detected


def _score_slot(expected_value: Any, detected_value: Any) -> float:
    """Score one slot: 1.0 exact, 0.9 case-insensitive, 0.5 partial list hit."""
    # Handle list values (e.g., multiple institution types)
    if isinstance(expected_value, list):
        if isinstance(detected_value, list):
            return 1.0 if set(expected_value) == set(detected_value) else 0.0
        return 0.5 if detected_value in expected_value else 0.0  # Partial credit
    if expected_value == "*":
        # Wildcard - any truthy detected value is acceptable
        return 1.0 if detected_value else 0.0
    if expected_value == detected_value:
        return 1.0
    if str(expected_value).lower() == str(detected_value).lower():
        return 0.9  # Close match (case insensitive)
    return 0.0
def heritage_rag_metric(
    example: GoldenExample,
    prediction: dict[str, Any],
    weights: dict[str, float] | None = None,
) -> float:
    """
    Composite metric for evaluating Heritage RAG responses.

    Combines:
    - Count accuracy (for COUNT queries)
    - Slot extraction accuracy
    - Answer presence and quality checks

    Args:
        example: The golden example
        prediction: The RAG response
        weights: Optional weights for each component metric

    Returns:
        Weighted score between 0.0 and 1.0
    """
    if weights is None:
        weights = {
            "count": 0.5,
            "slots": 0.3,
            "answer_present": 0.2,
        }

    scores = {}

    # Count accuracy (for COUNT queries)
    if example.is_count_query:
        scores["count"] = count_accuracy(example, prediction)
    else:
        scores["count"] = 1.0  # Not a count query, skip this metric

    # Slot extraction accuracy
    scores["slots"] = slot_extraction_accuracy(example, prediction)

    # Answer presence check.
    # Bug fix: the 'answer' key may be present but hold None; the original
    # len(answer) then raised TypeError. 'or ""' normalizes it.
    answer = prediction.get("answer") or ""
    if len(answer) > 10:
        scores["answer_present"] = 1.0
    elif answer:
        scores["answer_present"] = 0.5
    else:
        scores["answer_present"] = 0.0

    # Weighted average.
    # Bug fix: an empty (or all-zero) weights dict previously caused a
    # ZeroDivisionError; treat it as a zero score instead.
    total_weight = sum(weights.values())
    if total_weight <= 0:
        return 0.0
    weighted_score = sum(
        scores.get(key, 0.0) * weight
        for key, weight in weights.items()
    ) / total_weight
    return weighted_score
def format_evaluation_report(
    results: list[dict[str, Any]],
) -> str:
    """
    Format evaluation results as a human-readable report.

    Args:
        results: List of evaluation result dictionaries

    Returns:
        Formatted string report
    """
    rule = "=" * 60
    thin_rule = "-" * 60
    out = [rule, "Heritage RAG Evaluation Report", rule, ""]

    # Overall statistics (max(total, 1) avoids division by zero on empty input)
    total = len(results)
    passed = sum(1 for entry in results if entry.get("score", 0) >= 0.8)
    avg_score = sum(entry.get("score", 0) for entry in results) / max(total, 1)
    out += [
        f"Total examples: {total}",
        f"Passed (score >= 0.8): {passed} ({100*passed/max(total,1):.1f}%)",
        f"Average score: {avg_score:.3f}",
        "",
        thin_rule,
    ]

    # Per-category average scores
    by_category: dict[str, list[float]] = {}
    for entry in results:
        by_category.setdefault(entry.get("category", "unknown"), []).append(entry.get("score", 0))
    out.append("Scores by category:")
    for cat in sorted(by_category):
        scores = by_category[cat]
        out.append(f" {cat}: {sum(scores) / len(scores):.3f} (n={len(scores)})")
    out += ["", thin_rule, ""]

    # Examples below the 0.8 pass threshold (first 10 shown)
    failed = [entry for entry in results if entry.get("score", 0) < 0.8]
    if failed:
        out.append(f"Failed examples ({len(failed)}):")
        for entry in failed[:10]:
            out.append(f" [{entry.get('id')}] {entry.get('question', '')[:40]}...")
            out.append(f" Expected: {entry.get('expected_count')}, Got: {entry.get('actual_count')}")
            out.append(f" Score: {entry.get('score', 0):.3f}")
        if len(failed) > 10:
            out.append(f" ... and {len(failed) - 10} more")
        out.append("")
    out.append(rule)
    return "\n".join(out)
if __name__ == "__main__":
    # Quick smoke test for the count-extraction heuristics.
    class MockExample:
        # Minimal stand-in for GoldenExample with the attributes the
        # metrics read; kept here as a template for manual experiments.
        expected_count = 10
        expected_slots = {"institution_type": "A", "location": "NL-UT"}
        is_count_query = True

    example = MockExample()  # type: ignore

    # Test count extraction
    test_answers = [
        ("Er zijn 10 archieven in Utrecht.", 10),
        ("In Utrecht zijn er 10 archieven.", 10),
        ("Het aantal archieven is 10.", 10),
        ("Utrecht heeft 10 archieven.", 10),
        ("Ik heb 10 resultaten gevonden.", 10),
    ]
    print("Count extraction tests:")
    for answer, expected in test_answers:
        extracted = extract_count_from_answer(answer)
        # Bug fix: both branches of this ternary were the empty string
        # (apparently mojibake-stripped check marks), so pass and fail
        # were indistinguishable in the printed output.
        status = "PASS" if extracted == expected else "FAIL"
        print(f"  {status} '{answer[:40]}...' -> {extracted} (expected {expected})")