#!/usr/bin/env python3
|
|
"""
|
|
Standalone GEPA optimization script for Heritage RAG pipeline.
|
|
|
|
This script runs GEPA optimization with aggressive timeout handling and
|
|
saves results incrementally. Designed to complete within reasonable time.
|
|
|
|
Usage:
|
|
# Activate environment first
|
|
source .venv/bin/activate && source .env
|
|
|
|
# Run with default settings (light budget, 5 train examples)
|
|
python backend/rag/run_gepa_optimization.py
|
|
|
|
# Run with custom settings
|
|
python backend/rag/run_gepa_optimization.py --budget light --train-size 5 --val-size 3
|
|
"""
|
|
|
|
import argparse
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import os
|
|
import sys
|
|
import time
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
|
|
# Add project root to path
|
|
project_root = Path(__file__).parent.parent.parent
|
|
sys.path.insert(0, str(project_root))
|
|
|
|
import dspy
|
|
from dspy import Example
|
|
|
|
# Configure logging for the whole script (timestamped, INFO level).
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
# Module-level logger named after this module, per stdlib convention.
logger = logging.getLogger(__name__)

# Output directory for optimized models; created eagerly at import time so
# later save() calls cannot fail on a missing directory.
OUTPUT_DIR = Path(__file__).parent / "optimized_models"
OUTPUT_DIR.mkdir(exist_ok=True)
|
|
|
|
|
|
def create_minimal_training_data(train_size: int = 5, val_size: int = 3):
    """Create minimal training data for faster optimization.

    Uses a representative subset covering different query intents.

    Args:
        train_size: Number of training examples to return (at most 7).
        val_size: Number of validation examples to return (at most 3).

    Returns:
        Tuple of (train examples, validation examples).
    """

    def _make(question, language, intent, entities, sources, contains):
        # Build one labelled example. Only question/language are inputs to
        # the model; the remaining fields are gold labels read by the metric.
        return Example(
            question=question,
            language=language,
            expected_intent=intent,
            expected_entities=entities,
            expected_sources=sources,
            answer_contains=contains,
        ).with_inputs("question", "language")

    # Core training examples covering different intents.
    all_train = [
        # Statistical (Dutch)
        _make(
            "Hoeveel musea zijn er in Amsterdam?", "nl", "statistical",
            ["amsterdam", "musea"], ["sparql", "qdrant"],
            ["musea", "Amsterdam"],
        ),
        # Entity lookup (English)
        _make(
            "Where is the Rijksmuseum located?", "en", "entity_lookup",
            ["rijksmuseum"], ["sparql", "qdrant"],
            ["Rijksmuseum", "Amsterdam"],
        ),
        # Exploration
        _make(
            "Show me archives related to World War II", "en", "exploration",
            ["world war ii", "archives"], ["qdrant", "sparql"],
            ["archive", "war"],
        ),
        # Temporal (Dutch)
        _make(
            "Welke bibliotheken zijn gefuseerd sinds 2000?", "nl", "temporal",
            ["bibliotheken", "2000"], ["typedb", "sparql"],
            ["bibliotheek"],
        ),
        # Geographic
        _make(
            "Which museums are in Noord-Holland province?", "en", "geographic",
            ["noord-holland", "museums"], ["sparql", "qdrant"],
            ["museum", "Noord-Holland"],
        ),
        # Relational
        _make(
            "What collections does the Nationaal Archief manage?", "en",
            "relational",
            ["nationaal archief", "collections"], ["typedb", "sparql"],
            ["Nationaal Archief", "collection"],
        ),
        # Comparative
        _make(
            "Compare visitor numbers of Rijksmuseum and Van Gogh Museum", "en",
            "comparative",
            ["rijksmuseum", "van gogh museum"], ["sparql", "qdrant"],
            ["visitor", "museum"],
        ),
    ]

    # Validation examples.
    all_val = [
        _make(
            "List all libraries in Utrecht", "en", "geographic",
            ["libraries", "utrecht"], ["sparql", "qdrant"],
            ["library", "Utrecht"],
        ),
        _make(
            "Wat is de geschiedenis van het Anne Frank Huis?", "nl",
            "entity_lookup",
            ["anne frank huis"], ["sparql", "qdrant"],
            ["Anne Frank"],
        ),
        _make(
            "How many heritage institutions are in the Netherlands?", "en",
            "statistical",
            ["heritage institutions", "netherlands"], ["sparql"],
            ["institution", "Netherlands"],
        ),
    ]

    # Trim to the requested sizes; slicing tolerates oversized requests.
    return all_train[:train_size], all_val[:val_size]
|
|
|
|
|
|
def create_gepa_metric():
    """Create simplified GEPA metric for heritage RAG.

    Returns:
        A metric callable with the 5-argument signature DSPy 3.0.4 GEPA
        expects, producing a dspy.Prediction(score, feedback).
    """

    def heritage_metric(gold: Example, pred, trace=None, pred_name=None, pred_trace=None) -> dspy.Prediction:
        """Simplified metric that scores routing and answer quality.

        DSPy 3.0.4 GEPA requires 5 arguments:
        - gold: The gold example (Example object)
        - pred: The prediction (Prediction object)
        - trace: The trace of the prediction
        - pred_name: Name of the predictor being evaluated
        - pred_trace: Trace specific to this predictor

        Scoring weights (total 1.0): intent 0.30, entities 0.25,
        sources 0.20, answer terms 0.25.

        Returns:
            dspy.Prediction with score (float in [0, 1]) and feedback (str)
        """
        # Use gold as the example
        example = gold
        score = 0.0
        feedback_parts = []

        # 1. Intent matching (30 points)
        expected_intent = getattr(example, "expected_intent", None)
        pred_intent = getattr(pred, "intent", None)

        if expected_intent:
            if not pred_intent:
                # Previously this case emitted no feedback, starving GEPA's
                # reflection step of signal.
                feedback_parts.append("No intent predicted. Improve intent classification.")
            elif pred_intent.lower() == expected_intent.lower():
                score += 0.30
                feedback_parts.append("Intent correctly identified.")
            else:
                feedback_parts.append(
                    f"Intent mismatch: expected '{expected_intent}', got '{pred_intent}'. "
                    "Improve intent classification."
                )

        # 2. Entity extraction (25 points)
        expected_entities = getattr(example, "expected_entities", [])
        pred_entities = getattr(pred, "entities", [])

        if expected_entities:
            if not pred_entities:
                feedback_parts.append(
                    f"No entities extracted; expected {expected_entities}. Improve entity extraction."
                )
            else:
                # Normalize for case-insensitive set comparison.
                expected_lower = {e.lower() for e in expected_entities}
                pred_lower = {str(e).lower() for e in pred_entities}

                # expected_lower is non-empty here (expected_entities is
                # truthy), so the ratio is well defined.
                overlap = expected_lower & pred_lower
                entity_score = len(overlap) / len(expected_lower)
                score += 0.25 * entity_score

                if entity_score == 1.0:
                    feedback_parts.append("All expected entities extracted.")
                else:
                    missing = expected_lower - pred_lower
                    feedback_parts.append(
                        f"Missing entities: {missing}. Improve entity extraction."
                    )

        # 3. Source selection (20 points)
        expected_sources = getattr(example, "expected_sources", [])
        pred_sources = getattr(pred, "sources_used", [])

        if expected_sources:
            if not pred_sources:
                feedback_parts.append(
                    f"No sources selected; expected {expected_sources}."
                )
            else:
                expected_set = set(expected_sources)
                pred_set = set(pred_sources)

                if expected_set == pred_set:
                    score += 0.20
                    feedback_parts.append("Correct sources selected.")
                elif expected_set & pred_set:
                    # Partial credit proportional to expected-source coverage.
                    overlap_ratio = len(expected_set & pred_set) / len(expected_set)
                    score += 0.20 * overlap_ratio
                    feedback_parts.append(
                        f"Partially correct sources. Expected: {expected_sources}, got: {pred_sources}"
                    )
                else:
                    feedback_parts.append(
                        f"Wrong sources. Expected: {expected_sources}, got: {pred_sources}"
                    )

        # 4. Answer quality (25 points)
        answer_contains = getattr(example, "answer_contains", [])
        answer = getattr(pred, "answer", "")

        if answer_contains:
            if not answer:
                feedback_parts.append("Empty answer. Improve answer generation.")
            else:
                answer_lower = answer.lower()
                matches = sum(1 for term in answer_contains if term.lower() in answer_lower)
                answer_score = matches / len(answer_contains)
                score += 0.25 * answer_score

                if answer_score == 1.0:
                    feedback_parts.append("Answer contains all expected terms.")
                else:
                    missing = [t for t in answer_contains if t.lower() not in answer_lower]
                    feedback_parts.append(
                        f"Answer missing terms: {missing}. Improve answer generation."
                    )

        feedback = "\n".join(f"- {p}" for p in feedback_parts)
        return dspy.Prediction(score=score, feedback=feedback)

    return heritage_metric
|
|
|
|
|
|
def run_optimization(
    budget: str = "light",
    train_size: int = 5,
    val_size: int = 3,
    model: str = "openai/gpt-4o-mini",
    reflection_model: str = "openai/gpt-4o",
):
    """Run GEPA optimization and save results.

    Args:
        budget: GEPA auto budget (light, medium, heavy)
        train_size: Number of training examples
        val_size: Number of validation examples
        model: Student model for pipeline
        reflection_model: Teacher model for GEPA reflection

    Returns:
        Tuple of (optimized pipeline, path to the saved pipeline JSON).

    Raises:
        Exception: Re-raises whatever GEPA compilation fails with, after
            logging the full traceback.
    """
    start_time = time.time()
    timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")

    logger.info("=" * 60)
    logger.info("GEPA Optimization for Heritage RAG Pipeline")
    logger.info("=" * 60)
    logger.info(f"Budget: {budget}")
    logger.info(f"Train size: {train_size}, Val size: {val_size}")
    logger.info(f"Student model: {model}")
    logger.info(f"Reflection model: {reflection_model}")

    # Configure DSPy with the student model the pipeline will run on.
    logger.info("Configuring DSPy...")
    student_lm = dspy.LM(model=model, temperature=0.7, max_tokens=2000)
    dspy.configure(lm=student_lm)

    # Create training data
    logger.info("Creating training data...")
    trainset, valset = create_minimal_training_data(train_size, val_size)
    logger.info(f"Created {len(trainset)} train, {len(valset)} val examples")

    # Create pipeline (lazy import keeps module import cheap).
    logger.info("Creating pipeline...")
    from backend.rag.dspy_heritage_rag import HeritageRAGPipeline
    pipeline = HeritageRAGPipeline()

    # Create optimizer; the reflection LM is the (stronger) teacher model.
    logger.info("Creating GEPA optimizer...")
    reflection_lm = dspy.LM(
        model=reflection_model,
        temperature=1.0,
        max_tokens=16000,
    )

    metric = create_gepa_metric()

    optimizer = dspy.GEPA(
        metric=metric,
        auto=budget,
        reflection_lm=reflection_lm,
        candidate_selection_strategy="pareto",
        track_stats=True,
        track_best_outputs=True,
        use_merge=True,
        max_merge_invocations=3,  # Reduced for speed
        skip_perfect_score=True,
        seed=42,
    )

    # Run optimization
    logger.info("Starting optimization (this may take 5-15 minutes)...")
    try:
        optimized = optimizer.compile(
            student=pipeline,
            trainset=trainset,
            valset=valset,
        )

        elapsed = time.time() - start_time
        logger.info(f"Optimization completed in {elapsed:.1f} seconds")

        # Log results. best_score is pre-initialized so the metadata dict
        # below never depends on a conditionally-bound local.
        best_score = None
        if hasattr(optimized, "detailed_results"):
            results = optimized.detailed_results
            best_score = results.val_aggregate_scores[results.best_idx]
            logger.info(f"Best validation score: {best_score:.3f}")
            logger.info(f"Total candidates: {len(results.candidates)}")
            logger.info(f"Metric calls: {results.total_metric_calls}")

        # Save optimized pipeline
        output_path = OUTPUT_DIR / f"heritage_rag_{timestamp}.json"
        optimized.save(str(output_path))
        logger.info(f"Saved optimized pipeline to: {output_path}")

        # Also save as "latest" so callers have a stable path.
        latest_path = OUTPUT_DIR / "heritage_rag_latest.json"
        optimized.save(str(latest_path))
        logger.info(f"Saved as latest: {latest_path}")

        # Save run metadata next to the model for reproducibility.
        metadata = {
            "timestamp": timestamp,
            "budget": budget,
            "train_size": train_size,
            "val_size": val_size,
            "student_model": model,
            "reflection_model": reflection_model,
            "elapsed_seconds": elapsed,
            "best_score": best_score,
        }

        metadata_path = OUTPUT_DIR / f"metadata_{timestamp}.json"
        with open(metadata_path, "w") as f:
            json.dump(metadata, f, indent=2)
        logger.info(f"Saved metadata to: {metadata_path}")

        return optimized, output_path

    except Exception:
        # logger.exception keeps the traceback; logger.error(f"...{e}") did not.
        logger.exception("Optimization failed")
        raise
|
|
|
|
|
|
def test_optimized_pipeline(model_path=None):
    """Test the optimized pipeline with sample queries.

    Args:
        model_path: Path (str or Path) to a saved optimized pipeline JSON.
            Defaults to the "latest" model in OUTPUT_DIR when None.
            (The original annotation ``str = None`` contradicted its own
            default, so the parameter is now documented instead.)
    """
    if model_path is None:
        model_path = OUTPUT_DIR / "heritage_rag_latest.json"
    # Normalize once so str and Path callers behave identically.
    model_path = Path(model_path)

    if not model_path.exists():
        logger.error(f"No optimized model found at {model_path}")
        return

    logger.info(f"Loading optimized pipeline from {model_path}")

    # Configure DSPy with the same student model used for optimization.
    lm = dspy.LM(model="openai/gpt-4o-mini", temperature=0.7, max_tokens=2000)
    dspy.configure(lm=lm)

    # Load pipeline (lazy import keeps module import cheap).
    from backend.rag.dspy_heritage_rag import HeritageRAGPipeline
    pipeline = HeritageRAGPipeline()
    pipeline.load(str(model_path))

    # Representative smoke-test queries (mixed Dutch/English).
    test_queries = [
        ("Hoeveel musea zijn er in Amsterdam?", "nl"),
        ("Where is the Rijksmuseum located?", "en"),
        ("List archives in Noord-Holland", "en"),
    ]

    logger.info("Testing optimized pipeline...")
    for question, lang in test_queries:
        logger.info(f"\nQuery: {question} (lang={lang})")
        result = pipeline(question=question, language=lang)
        logger.info(f"  Intent: {result.intent}")
        logger.info(f"  Sources: {result.sources_used}")
        logger.info(f"  Answer: {result.answer[:200]}...")
|
|
|
|
|
|
def main():
    """CLI entry point: parse arguments, then optimize and/or smoke-test."""
    parser = argparse.ArgumentParser(description="Run GEPA optimization for Heritage RAG")
    parser.add_argument("--budget", choices=["light", "medium", "heavy"],
                        default="light", help="GEPA optimization budget")
    parser.add_argument("--train-size", type=int, default=5,
                        help="Number of training examples (max 7)")
    parser.add_argument("--val-size", type=int, default=3,
                        help="Number of validation examples (max 3)")
    parser.add_argument("--model", default="openai/gpt-4o-mini",
                        help="Student model for pipeline")
    parser.add_argument("--reflection-model", default="openai/gpt-4o",
                        help="Reflection model for GEPA")
    parser.add_argument("--test-only", action="store_true",
                        help="Only test existing optimized pipeline")
    args = parser.parse_args()

    # Fail fast when credentials are missing.
    if not os.environ.get("OPENAI_API_KEY"):
        logger.error("OPENAI_API_KEY not set. Run: source .env")
        sys.exit(1)

    # Guard clause: test-only mode skips optimization entirely.
    if args.test_only:
        test_optimized_pipeline()
        return

    # Sizes are clamped to the number of hand-written examples available.
    optimized, output_path = run_optimization(
        budget=args.budget,
        train_size=min(args.train_size, 7),
        val_size=min(args.val_size, 3),
        model=args.model,
        reflection_model=args.reflection_model,
    )

    # Run a quick smoke test against the freshly saved pipeline.
    logger.info("\n" + "=" * 60)
    logger.info("Running quick test of optimized pipeline...")
    test_optimized_pipeline(str(output_path))


if __name__ == "__main__":
    main()
|