"""Entity extraction metrics.

Measures precision, recall, and F1 for entity extraction.
"""

from typing import Any


def normalize_entity(entity: str) -> str:
    """Return *entity* in canonical form for comparison.

    Canonicalization lowercases the string and removes surrounding
    whitespace, so "  Apple " and "apple" compare equal.

    Args:
        entity: Raw entity string.

    Returns:
        Lowercased, whitespace-stripped entity.
    """
    return entity.strip().lower()


def entity_precision(expected: list[str], predicted: list[str]) -> float:
    """Compute precision for entity extraction.

    Precision = correctly predicted / total predicted.  Entities are
    compared case-insensitively with surrounding whitespace ignored.

    Args:
        expected: Gold-standard entities.
        predicted: Extracted entities.

    Returns:
        Precision in [0.0, 1.0].  An empty prediction list scores 1.0
        only when nothing was expected (vacuously perfect), else 0.0.
    """
    if not predicted:
        return 1.0 if not expected else 0.0

    gold = {e.lower().strip() for e in expected}
    hits = [p for p in predicted if p.lower().strip() in gold]
    # Duplicates in `predicted` each count toward the denominator.
    return len(hits) / len(predicted)


def entity_recall(expected: list[str], predicted: list[str]) -> float:
    """Compute recall for entity extraction.

    Recall = correctly predicted / total expected.  Entities are
    compared case-insensitively with surrounding whitespace ignored,
    and duplicates are collapsed before counting.

    Args:
        expected: Gold-standard entities.
        predicted: Extracted entities.

    Returns:
        Recall in [0.0, 1.0].  An empty expected list scores 1.0
        (nothing to recall).
    """
    if not expected:
        return 1.0  # Vacuously perfect: there was nothing to find.

    gold = {e.lower().strip() for e in expected}
    found = {p.lower().strip() for p in predicted}
    # Denominator is the deduplicated expected set.
    return len(gold & found) / len(gold)


def entity_f1(expected: list[str], predicted: list[str]) -> float:
    """Harmonic mean (F1) of entity precision and recall.

    F1 = 2 * (precision * recall) / (precision + recall).

    Args:
        expected: Gold-standard entities.
        predicted: Extracted entities.

    Returns:
        F1 score in [0.0, 1.0]; 0.0 when both components are zero.
    """
    p = entity_precision(expected, predicted)
    r = entity_recall(expected, predicted)

    denom = p + r
    if denom == 0:
        # Avoid division by zero when neither metric scored.
        return 0.0
    return 2 * (p * r) / denom


def entity_f1_metric(example: Any, pred: Any, trace: Any = None) -> float:
    """DSPy-compatible entity F1 metric.

    Reads ``expected_entities`` from *example* and ``entities`` from
    *pred*, defaulting either to an empty list when absent, then
    delegates to :func:`entity_f1`.

    Args:
        example: DSPy Example carrying ``expected_entities``.
        pred: Prediction carrying an ``entities`` field.
        trace: Optional trace for debugging (unused).

    Returns:
        F1 score in [0.0, 1.0].
    """
    gold = getattr(example, "expected_entities", [])
    guessed = getattr(pred, "entities", [])
    return entity_f1(gold, guessed)


def fuzzy_entity_match(entity1: str, entity2: str, threshold: float = 0.85) -> bool:
    """Decide whether two entities match via fuzzy string similarity.

    Both entities are lowercased and stripped before comparison.  When
    the optional ``rapidfuzz`` package is unavailable, degrades
    gracefully to exact (normalized) equality.

    Args:
        entity1: First entity.
        entity2: Second entity.
        threshold: Similarity threshold in [0.0, 1.0].

    Returns:
        True when similarity meets or exceeds *threshold*.
    """
    a = entity1.lower().strip()
    b = entity2.lower().strip()

    try:
        from rapidfuzz import fuzz
    except ImportError:
        # Fallback to exact match when rapidfuzz is not installed.
        return a == b

    # fuzz.ratio is 0-100; rescale to 0.0-1.0 before thresholding.
    return fuzz.ratio(a, b) / 100 >= threshold


def fuzzy_entity_f1(
    expected: list[str],
    predicted: list[str],
    threshold: float = 0.85
) -> float:
    """Calculate F1 where entities count as equal when fuzzily similar.

    Uses a greedy one-to-one assignment: scanning expected entities in
    order, each claims the first still-unclaimed predicted entity that
    matches it above *threshold*.

    Args:
        expected: Gold-standard entities.
        predicted: Extracted entities.
        threshold: Similarity threshold for matching.

    Returns:
        F1 score in [0.0, 1.0].
    """
    if not expected and not predicted:
        return 1.0
    if not expected or not predicted:
        return 0.0

    claimed: set[int] = set()  # indices of predicted entities already paired
    pairs = 0                  # count of matched (expected, predicted) pairs
    for gold in expected:
        for idx, guess in enumerate(predicted):
            if idx in claimed:
                continue
            if fuzzy_entity_match(gold, guess, threshold):
                claimed.add(idx)
                pairs += 1
                break  # each expected entity pairs at most once

    # Both lists are non-empty here (guarded above).
    precision = pairs / len(predicted)
    recall = pairs / len(expected)

    if precision + recall == 0:
        return 0.0

    return 2 * (precision * recall) / (precision + recall)