glam/tests/dspy_gitops/metrics/entity_extraction.py
"""
Entity Extraction Metrics
Measures precision, recall, and F1 for entity extraction.
"""
from typing import Any
def normalize_entity(entity: str) -> str:
"""Normalize entity for comparison.
Args:
entity: Entity string to normalize
Returns:
Normalized lowercase entity
"""
return entity.lower().strip()


def entity_precision(expected: list[str], predicted: list[str]) -> float:
    """Calculate entity extraction precision.

    Precision = correctly predicted / total predicted

    Args:
        expected: List of expected entities
        predicted: List of predicted entities

    Returns:
        Precision score 0.0-1.0
    """
    if not predicted:
        return 0.0 if expected else 1.0

    expected_normalized = {normalize_entity(e) for e in expected}
    predicted_normalized = [normalize_entity(e) for e in predicted]

    correct = sum(1 for p in predicted_normalized if p in expected_normalized)
    return correct / len(predicted_normalized)


def entity_recall(expected: list[str], predicted: list[str]) -> float:
    """Calculate entity extraction recall.

    Recall = correctly predicted / total expected

    Args:
        expected: List of expected entities
        predicted: List of predicted entities

    Returns:
        Recall score 0.0-1.0
    """
    if not expected:
        return 1.0  # Nothing to recall

    expected_normalized = {normalize_entity(e) for e in expected}
    predicted_normalized = {normalize_entity(p) for p in predicted}

    correct = len(expected_normalized & predicted_normalized)
    return correct / len(expected_normalized)


def entity_f1(expected: list[str], predicted: list[str]) -> float:
    """Calculate entity extraction F1 score.

    F1 = 2 * (precision * recall) / (precision + recall)

    Args:
        expected: List of expected entities
        predicted: List of predicted entities

    Returns:
        F1 score 0.0-1.0
    """
    precision = entity_precision(expected, predicted)
    recall = entity_recall(expected, predicted)

    if precision + recall == 0:
        return 0.0
    return 2 * (precision * recall) / (precision + recall)
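
# Worked example for entity_f1 (illustrative values, not from any fixture):
#   expected = ["Kubernetes", "Helm"], predicted = ["helm", "ArgoCD"]
#   precision = 1/2, recall = 1/2, F1 = 2 * (0.5 * 0.5) / (0.5 + 0.5) = 0.5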


def entity_f1_metric(example: Any, pred: Any, trace: Any = None) -> float:
    """DSPy-compatible entity F1 metric.

    Args:
        example: DSPy Example with expected_entities
        pred: Prediction with entities field
        trace: Optional trace for debugging

    Returns:
        F1 score 0.0-1.0
    """
    expected = getattr(example, "expected_entities", [])
    predicted = getattr(pred, "entities", [])
    return entity_f1(expected, predicted)
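
# Usage sketch (assumes DSPy's Evaluate API; `devset` and `extractor` are
# hypothetical stand-ins for an evaluation set and an extraction module):
#   evaluator = dspy.Evaluate(devset=devset, metric=entity_f1_metric)
#   evaluator(extractor)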


def fuzzy_entity_match(entity1: str, entity2: str, threshold: float = 0.85) -> bool:
    """Check if two entities match using fuzzy matching.

    Args:
        entity1: First entity
        entity2: Second entity
        threshold: Similarity threshold (0.0-1.0)

    Returns:
        True if entities match above threshold
    """
    try:
        from rapidfuzz import fuzz

        similarity = fuzz.ratio(
            normalize_entity(entity1),
            normalize_entity(entity2)
        ) / 100
        return similarity >= threshold
    except ImportError:
        # Fallback to exact match
        return normalize_entity(entity1) == normalize_entity(entity2)
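
# Illustrative behavior (rapidfuzz is an optional dependency): with rapidfuzz
# installed, fuzzy_entity_match("Argo CD", "argocd") is True (ratio roughly 0.92
# against the default 0.85 threshold); the exact-match fallback returns False
# for the same pair.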


def fuzzy_entity_f1(
    expected: list[str],
    predicted: list[str],
    threshold: float = 0.85
) -> float:
    """Calculate F1 with fuzzy matching for entities.

    Args:
        expected: List of expected entities
        predicted: List of predicted entities
        threshold: Similarity threshold for matching

    Returns:
        F1 score with fuzzy matching
    """
    if not expected and not predicted:
        return 1.0
    if not expected or not predicted:
        return 0.0

    # Calculate matches with fuzzy matching
    matched_expected = set()
    matched_predicted = set()
    for i, exp in enumerate(expected):
        for j, pred in enumerate(predicted):
            if j in matched_predicted:
                continue
            if fuzzy_entity_match(exp, pred, threshold):
                matched_expected.add(i)
                matched_predicted.add(j)
                break

    precision = len(matched_predicted) / len(predicted) if predicted else 0.0
    recall = len(matched_expected) / len(expected) if expected else 0.0

    if precision + recall == 0:
        return 0.0
    return 2 * (precision * recall) / (precision + recall)
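

# Minimal smoke check when run directly; the sample entities below are
# illustrative only and not part of the metric definitions above.
if __name__ == "__main__":
    expected = ["Kubernetes", "Helm", "ArgoCD"]
    predicted = ["kubernetes", "Argo CD", "Terraform"]

    print(f"strict F1: {entity_f1(expected, predicted):.3f}")
    print(f"fuzzy F1:  {fuzzy_entity_f1(expected, predicted):.3f}")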