150 lines
4.7 KiB
Python
150 lines
4.7 KiB
Python
"""
|
|
Golden Dataset Loader for DSPy RAG Evaluation
|
|
|
|
Loads and validates the golden dataset JSON file containing
|
|
test examples for evaluating the Heritage RAG system.
|
|
"""
|
|
|
|
import json
from collections import Counter
from dataclasses import dataclass
from pathlib import Path
from typing import Any
|
|
|
|
|
|
@dataclass
|
|
class GoldenExample:
|
|
"""A single example from the golden dataset."""
|
|
|
|
id: str
|
|
category: str
|
|
subcategory: str
|
|
language: str
|
|
question: str
|
|
expected_count: int | None
|
|
expected_slots: dict[str, Any]
|
|
notes: str | None = None
|
|
|
|
@property
|
|
def is_count_query(self) -> bool:
|
|
"""Check if this is a COUNT query example."""
|
|
return self.category == "count"
|
|
|
|
@property
|
|
def institution_type(self) -> str | list[str] | None:
|
|
"""Get the expected institution type(s)."""
|
|
return self.expected_slots.get("institution_type")
|
|
|
|
@property
|
|
def location(self) -> str | None:
|
|
"""Get the expected location."""
|
|
return self.expected_slots.get("location")
|
|
|
|
@property
|
|
def location_level(self) -> str | None:
|
|
"""Get the expected location level (subregion/settlement)."""
|
|
return self.expected_slots.get("location_level")
|
|
|
|
|
|
def load_golden_dataset(
    dataset_path: str | Path | None = None,
    category_filter: str | None = None,
    subcategory_filter: str | None = None,
    max_examples: int | None = None,
) -> list[GoldenExample]:
    """
    Load the golden dataset from JSON file.

    The file is expected to be a JSON object whose ``"examples"`` key holds
    a list of example objects; a missing key yields an empty list.

    Args:
        dataset_path: Path to the JSON file. Defaults to
            data/rag_eval/golden_dataset.json, resolved four directory levels
            above this module file.
        category_filter: Only include examples with this category
            (e.g., "count"). None or "" disables the filter.
        subcategory_filter: Only include examples with this subcategory.
            None or "" disables the filter.
        max_examples: Maximum number of examples to return. None means no
            limit; 0 (or a negative value) returns an empty list.

    Returns:
        List of GoldenExample objects, in file order.

    Raises:
        FileNotFoundError: If the dataset file does not exist.
        json.JSONDecodeError: If the file is not valid JSON.
        KeyError: If a selected example lacks "id", "category" or "question".
    """
    if dataset_path is None:
        # Default path relative to project root
        dataset_path = (
            Path(__file__).parent.parent.parent.parent
            / "data"
            / "rag_eval"
            / "golden_dataset.json"
        )
    else:
        dataset_path = Path(dataset_path)

    if not dataset_path.exists():
        raise FileNotFoundError(f"Golden dataset not found at {dataset_path}")

    with dataset_path.open("r", encoding="utf-8") as f:
        data = json.load(f)

    examples: list[GoldenExample] = []

    for raw_example in data.get("examples", []):
        # Check the limit before appending so that max_examples=0 yields an
        # empty list instead of being mistaken for "no limit".
        if max_examples is not None and len(examples) >= max_examples:
            break

        # Apply filters
        if category_filter and raw_example.get("category") != category_filter:
            continue
        if subcategory_filter and raw_example.get("subcategory") != subcategory_filter:
            continue

        examples.append(
            GoldenExample(
                id=raw_example["id"],
                category=raw_example["category"],
                subcategory=raw_example.get("subcategory", ""),
                language=raw_example.get("language", "nl"),
                question=raw_example["question"],
                expected_count=raw_example.get("expected_count"),
                # An explicit JSON null would otherwise become None and make
                # the GoldenExample slot properties fail on `.get`.
                expected_slots=raw_example.get("expected_slots") or {},
                notes=raw_example.get("notes"),
            )
        )

    return examples
|
|
|
|
|
|
def get_dataset_stats(dataset_path: str | Path | None = None) -> dict[str, Any]:
    """
    Get statistics about the golden dataset.

    Args:
        dataset_path: Path to the JSON file; None uses the default location
            chosen by load_golden_dataset.

    Returns:
        Dictionary with "total_examples" (int) plus per-value counts under
        "categories", "subcategories" and "languages". Each count mapping is
        a plain dict whose keys appear in order of first occurrence.

    Raises:
        FileNotFoundError: If the dataset file does not exist.
    """
    examples = load_golden_dataset(dataset_path)

    # Counter does the tallying; convert back to plain dicts so callers (and
    # printed output) see ordinary dicts rather than Counter objects.
    return {
        "total_examples": len(examples),
        "categories": dict(Counter(ex.category for ex in examples)),
        "subcategories": dict(Counter(ex.subcategory for ex in examples)),
        "languages": dict(Counter(ex.language for ex in examples)),
    }
|
|
|
|
|
|
if __name__ == "__main__":
    # Quick smoke test: print dataset statistics and the first few examples.
    import sys

    preview_len = 50  # characters of each question shown in the preview

    try:
        stats = get_dataset_stats()
        print("Golden Dataset Statistics:")
        print(f" Total examples: {stats['total_examples']}")
        print(f" Categories: {stats['categories']}")
        print(f" Subcategories: {stats['subcategories']}")
        print(f" Languages: {stats['languages']}")

        # Load a few examples
        examples = load_golden_dataset(max_examples=3)
        print("\nFirst 3 examples:")
        for ex in examples:
            # Only add an ellipsis when the question was actually truncated.
            if len(ex.question) > preview_len:
                preview = ex.question[:preview_len] + "..."
            else:
                preview = ex.question
            print(f" [{ex.id}] {preview} -> {ex.expected_count}")
    except (FileNotFoundError, json.JSONDecodeError) as e:
        # Report a missing or malformed dataset file without a traceback.
        print(f"Error: {e}", file=sys.stderr)
        sys.exit(1)
|