"""
|
|
Metrics for DSPy RAG Evaluation
|
|
|
|
This module provides metric functions for evaluating the Heritage RAG system:
|
|
- count_accuracy: Checks if COUNT queries return correct counts
|
|
- slot_extraction_accuracy: Checks if slots are correctly extracted
|
|
- heritage_rag_metric: Composite metric combining structural and semantic checks
|
|
"""
|
|
|
|
import re
from typing import Any

from .dataset_loader import GoldenExample


def extract_count_from_answer(answer: str) -> int | None:
|
|
"""
|
|
Extract a count number from a natural language answer.
|
|
|
|
Handles patterns like:
|
|
- "Er zijn 10 archieven in Utrecht."
|
|
- "In Utrecht zijn er 10 archieven."
|
|
- "Het aantal archieven in Utrecht is 10."
|
|
- "Utrecht heeft 10 archieven."
|
|
|
|
Args:
|
|
answer: The natural language answer string
|
|
|
|
Returns:
|
|
The extracted count, or None if no count found
|
|
"""
|
|
if not answer:
|
|
return None
|
|
|
|
# Pattern 1: "Er zijn X [type] in [location]"
|
|
match = re.search(r"(?:Er\s+)?(?:zijn|is)\s+(?:er\s+)?(\d+)", answer, re.IGNORECASE)
|
|
if match:
|
|
return int(match.group(1))
|
|
|
|
# Pattern 2: "[location] heeft X [type]"
|
|
match = re.search(r"heeft\s+(\d+)", answer, re.IGNORECASE)
|
|
if match:
|
|
return int(match.group(1))
|
|
|
|
# Pattern 3: "Het aantal ... is X" or "totaal van X"
|
|
match = re.search(r"(?:aantal|totaal)\s+(?:van\s+)?(?:is\s+)?(\d+)", answer, re.IGNORECASE)
|
|
if match:
|
|
return int(match.group(1))
|
|
|
|
# Pattern 4: Just look for any number (fallback)
|
|
numbers = re.findall(r"\b(\d+)\b", answer)
|
|
if numbers:
|
|
# Return the first number found (usually the count)
|
|
return int(numbers[0])
|
|
|
|
return None
|
|
|
|
|
|
def count_accuracy(
    example: GoldenExample,
    prediction: dict[str, Any],
    tolerance: int = 0,
) -> float:
    """Check if the predicted count matches the expected count.

    Args:
        example: The golden example with expected count
        prediction: The RAG response containing 'answer' and optionally 'count'
        tolerance: Allowed deviation from expected count (default 0 = exact match)

    Returns:
        1.0 if count matches (within tolerance), 0.0 otherwise
    """
    expected = example.expected_count
    if expected is None:
        # No expected count defined (e.g., edge cases) -> neutral score.
        return 0.5

    # Prefer the structured count field; fall back to parsing the answer text.
    predicted = prediction.get("count")
    if predicted is None:
        predicted = extract_count_from_answer(prediction.get("answer", ""))
    if predicted is None:
        return 0.0  # Could not extract a count at all

    deviation = abs(predicted - expected)
    if deviation <= tolerance:
        return 1.0

    # Graded partial credit for near misses (within 10% / 20%).
    relative = deviation / max(expected, 1)
    if relative <= 0.1:
        return 0.8
    if relative <= 0.2:
        return 0.5
    return 0.0
def slot_extraction_accuracy(
    example: GoldenExample,
    prediction: dict[str, Any],
) -> float:
    """Check if the RAG system correctly extracted the expected slots.

    Slots include:
    - institution_type: A, M, L, etc.
    - location: Province code or city name
    - location_level: subregion or settlement
    - response_mode: count, list, etc.

    Args:
        example: The golden example with expected slots
        prediction: The RAG response containing slot extraction info

    Returns:
        Score between 0.0 and 1.0 based on slot accuracy
    """
    expected = example.expected_slots
    if not expected:
        return 1.0  # No slots expected

    # Get detected slots from prediction.
    # These might be in different locations depending on API response structure.
    detected: dict[str, Any] = {}

    # Fix: the original read prediction["slots"] whenever the key existed, so a
    # response with "slots": None crashed with AttributeError on detected.get().
    # Only accept a real dict; otherwise fall back to parsing the SPARQL query.
    slots = prediction.get("slots")
    if isinstance(slots, dict):
        detected = slots
    elif prediction.get("visualization"):
        vis = prediction["visualization"]

        # Extract institution_type and location from the SPARQL query
        sparql_query = vis.get("sparql_query", "")

        # Parse institution type from SPARQL (e.g., 'institutionType "A"')
        type_match = re.search(r'institutionType\s+"([A-Z])"', sparql_query)
        if type_match:
            detected["institution_type"] = type_match.group(1)

        # Parse location from SPARQL (e.g., 'subregionCode "NL-UT"')
        location_match = re.search(r'subregionCode\s+"([^"]+)"', sparql_query)
        if location_match:
            detected["location"] = location_match.group(1)
            detected["location_level"] = "subregion"
        else:
            # Try settlement name
            settlement_match = re.search(r'settlementName\s+"([^"]+)"', sparql_query)
            if settlement_match:
                detected["location"] = settlement_match.group(1)
                detected["location_level"] = "settlement"

        # Derive response mode from the visualization's declared modes
        response_modes = vis.get("response_modes", [])
        if "count" in response_modes:
            detected["response_mode"] = "count"
        elif "list" in response_modes:
            detected["response_mode"] = "list"

    # Calculate accuracy: each non-None expected slot contributes equally.
    correct = 0.0
    total = 0

    for slot_name, expected_value in expected.items():
        if expected_value is None:
            continue

        total += 1
        detected_value = detected.get(slot_name)

        # Handle list values (e.g., multiple institution types)
        if isinstance(expected_value, list):
            if isinstance(detected_value, list):
                if set(expected_value) == set(detected_value):
                    correct += 1
            elif detected_value in expected_value:
                correct += 0.5  # Partial credit: one of several expected values
        elif expected_value == "*":
            # Wildcard - any truthy value is acceptable
            if detected_value:
                correct += 1
        elif expected_value == detected_value:
            correct += 1
        elif str(expected_value).lower() == str(detected_value).lower():
            correct += 0.9  # Close match (case insensitive)

    if total == 0:
        return 1.0

    return correct / total
def heritage_rag_metric(
    example: GoldenExample,
    prediction: dict[str, Any],
    weights: dict[str, float] | None = None,
) -> float:
    """Composite metric for evaluating Heritage RAG responses.

    Combines:
    - Count accuracy (for COUNT queries)
    - Slot extraction accuracy
    - Answer presence and quality checks

    Args:
        example: The golden example
        prediction: The RAG response
        weights: Optional weights for each component metric

    Returns:
        Weighted score between 0.0 and 1.0
    """
    if weights is None:
        weights = {
            "count": 0.5,
            "slots": 0.3,
            "answer_present": 0.2,
        }

    scores: dict[str, float] = {}

    # Count accuracy (only meaningful for COUNT queries)
    if example.is_count_query:
        scores["count"] = count_accuracy(example, prediction)
    else:
        scores["count"] = 1.0  # Not a count query, skip this metric

    # Slot extraction accuracy
    scores["slots"] = slot_extraction_accuracy(example, prediction)

    # Answer presence check: a trivially short answer gets half credit
    answer = prediction.get("answer", "")
    if answer and len(answer) > 10:
        scores["answer_present"] = 1.0
    elif answer:
        scores["answer_present"] = 0.5
    else:
        scores["answer_present"] = 0.0

    # Weighted average
    total_weight = sum(weights.values())
    # Fix: the original divided by total_weight unconditionally, raising
    # ZeroDivisionError for an empty or all-zero weights dict.
    if total_weight <= 0:
        return 0.0

    weighted_score = sum(
        scores.get(key, 0.0) * weight
        for key, weight in weights.items()
    ) / total_weight

    return weighted_score
def format_evaluation_report(
|
|
results: list[dict[str, Any]],
|
|
) -> str:
|
|
"""
|
|
Format evaluation results as a human-readable report.
|
|
|
|
Args:
|
|
results: List of evaluation result dictionaries
|
|
|
|
Returns:
|
|
Formatted string report
|
|
"""
|
|
lines = [
|
|
"=" * 60,
|
|
"Heritage RAG Evaluation Report",
|
|
"=" * 60,
|
|
"",
|
|
]
|
|
|
|
# Overall statistics
|
|
total = len(results)
|
|
passed = sum(1 for r in results if r.get("score", 0) >= 0.8)
|
|
avg_score = sum(r.get("score", 0) for r in results) / max(total, 1)
|
|
|
|
lines.extend([
|
|
f"Total examples: {total}",
|
|
f"Passed (score >= 0.8): {passed} ({100*passed/max(total,1):.1f}%)",
|
|
f"Average score: {avg_score:.3f}",
|
|
"",
|
|
"-" * 60,
|
|
])
|
|
|
|
# Category breakdown
|
|
by_category: dict[str, list[float]] = {}
|
|
for r in results:
|
|
cat = r.get("category", "unknown")
|
|
if cat not in by_category:
|
|
by_category[cat] = []
|
|
by_category[cat].append(r.get("score", 0))
|
|
|
|
lines.append("Scores by category:")
|
|
for cat, scores in sorted(by_category.items()):
|
|
avg = sum(scores) / len(scores)
|
|
lines.append(f" {cat}: {avg:.3f} (n={len(scores)})")
|
|
|
|
lines.extend(["", "-" * 60, ""])
|
|
|
|
# Failed examples
|
|
failed = [r for r in results if r.get("score", 0) < 0.8]
|
|
if failed:
|
|
lines.append(f"Failed examples ({len(failed)}):")
|
|
for r in failed[:10]: # Show first 10
|
|
lines.append(f" [{r.get('id')}] {r.get('question', '')[:40]}...")
|
|
lines.append(f" Expected: {r.get('expected_count')}, Got: {r.get('actual_count')}")
|
|
lines.append(f" Score: {r.get('score', 0):.3f}")
|
|
if len(failed) > 10:
|
|
lines.append(f" ... and {len(failed) - 10} more")
|
|
|
|
lines.append("")
|
|
lines.append("=" * 60)
|
|
|
|
return "\n".join(lines)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
# Quick test
|
|
class MockExample:
|
|
expected_count = 10
|
|
expected_slots = {"institution_type": "A", "location": "NL-UT"}
|
|
is_count_query = True
|
|
|
|
example = MockExample() # type: ignore
|
|
|
|
# Test count extraction
|
|
test_answers = [
|
|
("Er zijn 10 archieven in Utrecht.", 10),
|
|
("In Utrecht zijn er 10 archieven.", 10),
|
|
("Het aantal archieven is 10.", 10),
|
|
("Utrecht heeft 10 archieven.", 10),
|
|
("Ik heb 10 resultaten gevonden.", 10),
|
|
]
|
|
|
|
print("Count extraction tests:")
|
|
for answer, expected in test_answers:
|
|
extracted = extract_count_from_answer(answer)
|
|
status = "✓" if extracted == expected else "✗"
|
|
print(f" {status} '{answer[:40]}...' -> {extracted} (expected {expected})")
|