glam/backend/rag/evaluation/dataset_loader.py
2026-01-09 20:35:19 +01:00

150 lines
4.7 KiB
Python

"""
Golden Dataset Loader for DSPy RAG Evaluation
Loads and validates the golden dataset JSON file containing
test examples for evaluating the Heritage RAG system.
"""
import json
from collections import Counter
from dataclasses import dataclass
from pathlib import Path
from typing import Any
@dataclass
class GoldenExample:
"""A single example from the golden dataset."""
id: str
category: str
subcategory: str
language: str
question: str
expected_count: int | None
expected_slots: dict[str, Any]
notes: str | None = None
@property
def is_count_query(self) -> bool:
"""Check if this is a COUNT query example."""
return self.category == "count"
@property
def institution_type(self) -> str | list[str] | None:
"""Get the expected institution type(s)."""
return self.expected_slots.get("institution_type")
@property
def location(self) -> str | None:
"""Get the expected location."""
return self.expected_slots.get("location")
@property
def location_level(self) -> str | None:
"""Get the expected location level (subregion/settlement)."""
return self.expected_slots.get("location_level")
def load_golden_dataset(
    dataset_path: str | Path | None = None,
    category_filter: str | None = None,
    subcategory_filter: str | None = None,
    max_examples: int | None = None,
) -> list[GoldenExample]:
    """
    Load the golden dataset from JSON file.

    Args:
        dataset_path: Path to the JSON file. Defaults to
            data/rag_eval/golden_dataset.json relative to the project root.
        category_filter: Only include examples with this category (e.g., "count").
        subcategory_filter: Only include examples with this subcategory.
            An empty string matches examples whose subcategory is "" (the
            per-example default).
        max_examples: Maximum number of examples to return. 0 returns an
            empty list; None (the default) returns all matches.

    Returns:
        List of GoldenExample objects in file order.

    Raises:
        FileNotFoundError: If the dataset file does not exist.
    """
    if dataset_path is None:
        # Default path relative to project root (four levels up from this module).
        dataset_path = Path(__file__).parent.parent.parent.parent / "data" / "rag_eval" / "golden_dataset.json"
    else:
        dataset_path = Path(dataset_path)
    if not dataset_path.exists():
        raise FileNotFoundError(f"Golden dataset not found at {dataset_path}")
    with open(dataset_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    examples: list[GoldenExample] = []
    for raw_example in data.get("examples", []):
        # Stop before parsing further rows once the cap is reached. The
        # explicit None check (rather than truthiness) makes max_examples=0
        # yield an empty list instead of the whole dataset.
        if max_examples is not None and len(examples) >= max_examples:
            break
        # Explicit None checks so "" is a usable filter value.
        if category_filter is not None and raw_example.get("category") != category_filter:
            continue
        if subcategory_filter is not None and raw_example.get("subcategory") != subcategory_filter:
            continue
        examples.append(
            GoldenExample(
                id=raw_example["id"],
                category=raw_example["category"],
                subcategory=raw_example.get("subcategory", ""),
                language=raw_example.get("language", "nl"),
                question=raw_example["question"],
                expected_count=raw_example.get("expected_count"),
                expected_slots=raw_example.get("expected_slots", {}),
                notes=raw_example.get("notes"),
            )
        )
    return examples
def get_dataset_stats(dataset_path: str | Path | None = None) -> dict[str, Any]:
    """
    Get statistics about the golden dataset.

    Args:
        dataset_path: Optional path to the dataset JSON; defaults to the
            standard location used by load_golden_dataset.

    Returns:
        Dictionary with the total example count plus frequency tables
        (plain str -> int dicts) by category, subcategory, and language.

    Raises:
        FileNotFoundError: If the dataset file does not exist.
    """
    examples = load_golden_dataset(dataset_path)
    # Counter replaces the hand-rolled dict.get(..., 0) + 1 loops; convert
    # back to plain dicts to keep the return type unchanged for callers.
    return {
        "total_examples": len(examples),
        "categories": dict(Counter(ex.category for ex in examples)),
        "subcategories": dict(Counter(ex.subcategory for ex in examples)),
        "languages": dict(Counter(ex.language for ex in examples)),
    }
if __name__ == "__main__":
    # Smoke test: print dataset statistics plus a small sample of examples.
    import sys

    try:
        stats = get_dataset_stats()
        summary = [
            "Golden Dataset Statistics:",
            f" Total examples: {stats['total_examples']}",
            f" Categories: {stats['categories']}",
            f" Subcategories: {stats['subcategories']}",
            f" Languages: {stats['languages']}",
        ]
        for line in summary:
            print(line)
        sample = load_golden_dataset(max_examples=3)
        print("\nFirst 3 examples:")
        for example in sample:
            print(f" [{example.id}] {example.question[:50]}... -> {example.expected_count}")
    except FileNotFoundError as err:
        print(f"Error: {err}", file=sys.stderr)
        sys.exit(1)