glam/backend/rag/run_mipro_optimization.py
#!/usr/bin/env python3
"""
MIPROv2 Optimization for Heritage RAG Pipeline
MIPROv2 is DSPy 3.0's recommended optimizer for instruction optimization.
It generates optimized instructions AND selects optimal few-shot demonstrations.
Benefits over BootstrapFewShot:
- Generates optimized instructions (not just few-shot examples)
- Uses Bayesian optimization to search instruction space
- Better at finding optimal prompts for complex tasks
Usage:
cd /Users/kempersc/apps/glam
source .venv/bin/activate && source .env
python backend/rag/run_mipro_optimization.py
Requirements:
- SSH tunnel active: ssh -f -N -L 7878:localhost:7878 root@91.98.224.44
- Environment loaded with OPENAI_API_KEY
"""
import json
import logging
import os
import socket
import sys
from datetime import datetime, timezone
from pathlib import Path
import dspy
from dspy import Example
# Add parent to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# =============================================================================
# TRAINING DATA
# =============================================================================
def create_heritage_training_data():
    """Create comprehensive training data for heritage RAG optimization.

    Covers all query intents and both languages (Dutch/English).
    """
    trainset = [
        # Statistical queries
        Example(
            question="Hoeveel musea zijn er in Amsterdam?",
            language="nl",
            expected_intent="statistical",
            expected_entities=["amsterdam", "musea"],
            expected_sources=["sparql", "qdrant"],
            answer_contains=["musea", "Amsterdam"],
        ).with_inputs("question", "language"),
        Example(
            question="How many archives are in the Netherlands?",
            language="en",
            expected_intent="statistical",
            expected_entities=["netherlands", "archives"],
            expected_sources=["sparql"],
            answer_contains=["archive", "Netherlands"],
        ).with_inputs("question", "language"),
        Example(
            question="What is the total number of libraries in Utrecht?",
            language="en",
            expected_intent="statistical",
            expected_entities=["utrecht", "libraries"],
            expected_sources=["sparql", "qdrant"],
            answer_contains=["library", "Utrecht"],
        ).with_inputs("question", "language"),
        # Entity lookup queries
        Example(
            question="Where is the Rijksmuseum located?",
            language="en",
            expected_intent="entity_lookup",
            expected_entities=["rijksmuseum"],
            expected_sources=["sparql", "qdrant"],
            answer_contains=["Rijksmuseum", "Amsterdam"],
        ).with_inputs("question", "language"),
        Example(
            question="Wat is het Nationaal Archief?",
            language="nl",
            expected_intent="entity_lookup",
            expected_entities=["nationaal archief"],
            expected_sources=["sparql", "qdrant"],
            answer_contains=["Nationaal Archief", "Den Haag"],
        ).with_inputs("question", "language"),
        Example(
            question="Tell me about the Van Gogh Museum",
            language="en",
            expected_intent="entity_lookup",
            expected_entities=["van gogh museum"],
            expected_sources=["sparql", "qdrant"],
            answer_contains=["Van Gogh", "museum", "Amsterdam"],
        ).with_inputs("question", "language"),
        # Exploration queries
        Example(
            question="Show me archives related to World War II",
            language="en",
            expected_intent="exploration",
            expected_entities=["world war ii", "archives"],
            expected_sources=["qdrant", "sparql"],
            answer_contains=["archive", "war"],
        ).with_inputs("question", "language"),
        Example(
            question="What heritage institutions focus on maritime history?",
            language="en",
            expected_intent="exploration",
            expected_entities=["maritime history", "heritage"],
            expected_sources=["qdrant"],
            answer_contains=["maritime", "museum"],
        ).with_inputs("question", "language"),
        # Geographic queries
        Example(
            question="Which museums are in Noord-Holland province?",
            language="en",
            expected_intent="geographic",
            expected_entities=["noord-holland", "museums"],
            expected_sources=["sparql", "qdrant"],
            answer_contains=["museum", "Noord-Holland"],
        ).with_inputs("question", "language"),
        Example(
            question="Welke archieven zijn er in Limburg?",
            language="nl",
            expected_intent="geographic",
            expected_entities=["limburg", "archieven"],
            expected_sources=["sparql", "qdrant"],
            answer_contains=["archief", "Limburg"],
        ).with_inputs("question", "language"),
        Example(
            question="List all libraries in Rotterdam",
            language="en",
            expected_intent="geographic",
            expected_entities=["rotterdam", "libraries"],
            expected_sources=["sparql", "qdrant"],
            answer_contains=["library", "Rotterdam"],
        ).with_inputs("question", "language"),
        # Temporal queries
        Example(
            question="Welke bibliotheken zijn gefuseerd sinds 2000?",
            language="nl",
            expected_intent="temporal",
            expected_entities=["bibliotheken", "2000"],
            expected_sources=["typedb", "sparql"],
            answer_contains=["bibliotheek"],
        ).with_inputs("question", "language"),
        Example(
            question="Which museums opened after 2010?",
            language="en",
            expected_intent="temporal",
            expected_entities=["museums", "2010"],
            expected_sources=["sparql", "typedb"],
            answer_contains=["museum", "opened"],
        ).with_inputs("question", "language"),
        # Relational queries
        Example(
            question="What collections does the Nationaal Archief manage?",
            language="en",
            expected_intent="relational",
            expected_entities=["nationaal archief", "collections"],
            expected_sources=["typedb", "sparql"],
            answer_contains=["Nationaal Archief", "collection"],
        ).with_inputs("question", "language"),
        Example(
            question="Which institutions are part of the KNAW?",
            language="en",
            expected_intent="relational",
            expected_entities=["knaw", "institutions"],
            expected_sources=["typedb", "sparql"],
            answer_contains=["KNAW"],
        ).with_inputs("question", "language"),
        # Comparative queries
        Example(
            question="Compare visitor numbers of Rijksmuseum and Van Gogh Museum",
            language="en",
            expected_intent="comparative",
            expected_entities=["rijksmuseum", "van gogh museum"],
            expected_sources=["sparql", "qdrant"],
            answer_contains=["visitor", "museum"],
        ).with_inputs("question", "language"),
        # Person queries (role-based)
        Example(
            question="Who are the curators at the Mauritshuis?",
            language="en",
            expected_intent="person_lookup",
            expected_entities=["mauritshuis", "curators"],
            expected_sources=["qdrant"],
            answer_contains=["Mauritshuis", "curator"],
        ).with_inputs("question", "language"),
        Example(
            question="Wie is de directeur van het Rijksmuseum?",
            language="nl",
            expected_intent="person_lookup",
            expected_entities=["rijksmuseum", "directeur"],
            expected_sources=["qdrant"],
            answer_contains=["Rijksmuseum", "directeur"],
        ).with_inputs("question", "language"),
        Example(
            question="Find archivists working at provincial archives",
            language="en",
            expected_intent="person_lookup",
            expected_entities=["archivists", "provincial archives"],
            expected_sources=["qdrant"],
            answer_contains=["archivist", "archive"],
        ).with_inputs("question", "language"),
        # Complex multi-hop queries
        Example(
            question="Which archives in Amsterdam have digitized their collections?",
            language="en",
            expected_intent="exploration",
            expected_entities=["amsterdam", "archives", "digitized"],
            expected_sources=["qdrant", "sparql"],
            answer_contains=["archive", "digital"],
        ).with_inputs("question", "language"),
    ]

    valset = [
        Example(
            question="List all libraries in Utrecht",
            language="en",
            expected_intent="geographic",
            expected_entities=["libraries", "utrecht"],
            expected_sources=["sparql", "qdrant"],
            answer_contains=["library", "Utrecht"],
        ).with_inputs("question", "language"),
        Example(
            question="Wat is de geschiedenis van het Anne Frank Huis?",
            language="nl",
            expected_intent="entity_lookup",
            expected_entities=["anne frank huis"],
            expected_sources=["sparql", "qdrant"],
            answer_contains=["Anne Frank"],
        ).with_inputs("question", "language"),
        Example(
            question="How many heritage institutions are in the Netherlands?",
            language="en",
            expected_intent="statistical",
            expected_entities=["heritage institutions", "netherlands"],
            expected_sources=["sparql"],
            answer_contains=["institution", "Netherlands"],
        ).with_inputs("question", "language"),
        Example(
            question="Which museums have paintings by Rembrandt?",
            language="en",
            expected_intent="exploration",
            expected_entities=["museums", "rembrandt"],
            expected_sources=["qdrant"],
            answer_contains=["Rembrandt", "museum"],
        ).with_inputs("question", "language"),
        Example(
            question="Wie werkt er bij het Stadsarchief Amsterdam?",
            language="nl",
            expected_intent="person_lookup",
            expected_entities=["stadsarchief amsterdam"],
            expected_sources=["qdrant"],
            answer_contains=["Stadsarchief", "Amsterdam"],
        ).with_inputs("question", "language"),
    ]
    return trainset, valset
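
# Note: `.with_inputs("question", "language")` marks which Example fields are
# pipeline inputs; everything else is treated as a label. The metric below
# reads expected_intent, expected_entities, and answer_contains;
# expected_sources documents the intended retrieval backends but is not
# scored in this script.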
# =============================================================================
# METRIC FUNCTION
# =============================================================================
def heritage_metric(example: Example, prediction, trace=None) -> float:
    """Evaluate Heritage RAG predictions.

    Scores based on:
    - Intent classification accuracy (35%)
    - Entity extraction (25%)
    - Answer relevance (30%)
    - Has non-empty answer (10%)

    Returns:
        Score between 0 and 1
    """
    score = 0.0

    # 1. Intent match (35%)
    expected_intent = getattr(example, 'expected_intent', None)
    pred_intent = getattr(prediction, 'intent', None)
    if expected_intent and pred_intent:
        if expected_intent.lower() == pred_intent.lower():
            score += 0.35
        else:
            # Partial credit for related intents (pairs checked in both orders)
            related_intents = {
                ('geographic', 'entity_lookup'): 0.15,
                ('statistical', 'comparative'): 0.15,
                ('exploration', 'entity_lookup'): 0.15,
                ('temporal', 'entity_lookup'): 0.15,
                ('person_lookup', 'entity_lookup'): 0.20,
                ('relational', 'exploration'): 0.15,
            }
            pair = (expected_intent.lower(), pred_intent.lower())
            reverse_pair = (pred_intent.lower(), expected_intent.lower())
            score += related_intents.get(pair, related_intents.get(reverse_pair, 0))

    # 2. Entity extraction (25%): recall of expected entities
    expected_entities = getattr(example, 'expected_entities', [])
    pred_entities = getattr(prediction, 'entities', [])
    if expected_entities and pred_entities:
        expected_lower = {e.lower() for e in expected_entities}
        pred_lower = {str(e).lower() for e in pred_entities}
        overlap = expected_lower & pred_lower
        score += 0.25 * (len(overlap) / len(expected_lower))
    elif expected_entities:
        # No extracted entities: fall back to checking the answer text
        answer = getattr(prediction, 'answer', '') or ''
        answer_lower = answer.lower()
        matches = sum(1 for e in expected_entities if e.lower() in answer_lower)
        score += 0.25 * (matches / len(expected_entities))

    # 3. Answer contains expected keywords (30%)
    answer_contains = getattr(example, 'answer_contains', [])
    answer = getattr(prediction, 'answer', '') or ''
    if answer_contains and answer:
        answer_lower = answer.lower()
        matches = sum(1 for kw in answer_contains if kw.lower() in answer_lower)
        keyword_score = matches / len(answer_contains)
        score += 0.30 * keyword_score

    # 4. Has a non-empty, substantive answer (10%)
    if answer and len(answer.strip()) > 50:
        score += 0.10
    elif answer and len(answer.strip()) > 20:
        score += 0.05

    return score
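
# A minimal worked example of the metric (a sketch; `SimpleNamespace` stands in
# for whatever prediction object the pipeline returns -- only the `intent`,
# `entities`, and `answer` attributes are read above):
#
#     from types import SimpleNamespace
#     trainset, _ = create_heritage_training_data()
#     ex = trainset[0]  # "Hoeveel musea zijn er in Amsterdam?"
#     pred = SimpleNamespace(
#         intent="statistical",             # exact intent match -> +0.35
#         entities=["amsterdam", "musea"],  # full entity recall -> +0.25
#         answer="Er zijn tientallen musea in Amsterdam, waaronder het "
#                "Rijksmuseum en het Van Gogh Museum.",
#     )
#     heritage_metric(ex, pred)  # 0.35 + 0.25 + 0.30 (both keywords) + 0.10 = 1.0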

def heritage_metric_strict(example: Example, prediction, trace=None) -> bool:
    """Strict boolean metric for MIPROv2.

    Returns True only if score >= 0.7
    """
    return heritage_metric(example, prediction) >= 0.7
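
# Not wired into run_mipro_optimization() by default; the run below passes the
# graded heritage_metric. DSPy optimizers treat the metric's return value as a
# pass/fail signal when bootstrapping demonstrations, so passing this strict
# variant as `metric=` is one way to keep only high-scoring traces as few-shot
# candidates.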
# =============================================================================
# MIPRO OPTIMIZATION
# =============================================================================
def run_mipro_optimization(
    auto_setting: str = "light",
    model: str = "openai/gpt-4o-mini",
    teacher_model: str = "openai/gpt-4o",
):
    """Run MIPROv2 optimization on the Heritage RAG pipeline.

    Args:
        auto_setting: MIPROv2 auto setting (light, medium, heavy)
        model: Student model for the pipeline
        teacher_model: Teacher model for instruction generation
    """
    logger.info("=" * 60)
    logger.info("Heritage RAG - MIPROv2 Optimization")
    logger.info("=" * 60)
    logger.info(f"Auto setting: {auto_setting}")
    logger.info(f"Student model: {model}")
    logger.info(f"Teacher model: {teacher_model}")

    # Check SSH tunnel for Qdrant
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    result = sock.connect_ex(('localhost', 7878))
    sock.close()
    if result != 0:
        logger.warning("SSH tunnel not active on port 7878 - Qdrant queries may fail")
        logger.warning("Run: ssh -f -N -L 7878:localhost:7878 root@91.98.224.44")
    else:
        logger.info("✓ SSH tunnel active (port 7878)")

    # Configure DSPy
    logger.info("Configuring DSPy...")
    lm = dspy.LM(model, temperature=0.3, max_tokens=1500)
    dspy.configure(lm=lm)

    # Load training data
    logger.info("Loading training data...")
    trainset, valset = create_heritage_training_data()
    logger.info(f" Training examples: {len(trainset)}")
    logger.info(f" Validation examples: {len(valset)}")

    # Create baseline pipeline
    logger.info("Creating baseline HeritageRAGPipeline...")
    from backend.rag.dspy_heritage_rag import HeritageRAGPipeline
    pipeline = HeritageRAGPipeline()

    # Evaluate baseline: a quick spot check on 3 examples to save LM calls.
    # Note that the optimized pipeline below is scored on the full validation
    # set, so the improvement figure is indicative rather than a strict
    # like-for-like comparison.
    logger.info("Evaluating baseline on validation set...")
    baseline_scores = []
    for ex in valset[:3]:
        try:
            pred = pipeline(question=ex.question, language=ex.language)
            score = heritage_metric(ex, pred)
            baseline_scores.append(score)
            logger.info(f" Q: {ex.question[:50]}...")
            logger.info(f" Intent: {pred.intent} (expected: {ex.expected_intent}) → Score: {score:.2f}")
        except Exception as e:
            logger.warning(f" Error: {e}")
            baseline_scores.append(0.0)
    baseline_avg = sum(baseline_scores) / len(baseline_scores) if baseline_scores else 0
    logger.info(f"Baseline average score: {baseline_avg:.3f}")

    # Create MIPROv2 optimizer
    logger.info("-" * 60)
    logger.info(f"Starting MIPROv2 optimization (auto={auto_setting})...")
    logger.info("This will optimize instructions AND select few-shot examples")

    # Teacher LM for instruction generation
    teacher_lm = dspy.LM(teacher_model, temperature=0.7, max_tokens=3000)
    optimizer = dspy.MIPROv2(
        metric=heritage_metric,
        auto=auto_setting,
        prompt_model=teacher_lm,  # Stronger model for prompt generation
        task_model=lm,            # Standard model for task execution
        num_threads=4,
        # MIPROv2-specific settings
        init_temperature=0.7,
        track_stats=True,
    )

    # Compile optimized pipeline
    logger.info("Compiling optimized pipeline...")
    start_time = datetime.now(timezone.utc)
    try:
        optimized_pipeline = optimizer.compile(
            pipeline,
            trainset=trainset,
            valset=valset,
            requires_permission_to_run=False,  # Don't prompt for confirmation
        )
    except Exception as e:
        logger.error(f"Optimization failed: {e}")
        import traceback
        traceback.print_exc()
        raise
    elapsed = (datetime.now(timezone.utc) - start_time).total_seconds()
    logger.info(f"Optimization completed in {elapsed:.1f} seconds")

    # Evaluate optimized pipeline
    logger.info("-" * 60)
    logger.info("Evaluating optimized pipeline on full validation set...")
    optimized_scores = []
    for ex in valset:
        try:
            pred = optimized_pipeline(question=ex.question, language=ex.language)
            score = heritage_metric(ex, pred)
            optimized_scores.append(score)
            logger.info(f" Q: {ex.question[:50]}...")
            logger.info(f" Intent: {pred.intent} (expected: {ex.expected_intent}) → Score: {score:.2f}")
        except Exception as e:
            logger.warning(f" Error: {e}")
            optimized_scores.append(0.0)
    optimized_avg = sum(optimized_scores) / len(optimized_scores) if optimized_scores else 0
    improvement = optimized_avg - baseline_avg
    improvement_pct = (improvement / baseline_avg * 100) if baseline_avg > 0 else 0
    logger.info(f"Optimized average score: {optimized_avg:.3f}")
    logger.info(f"Improvement: {improvement:.3f} ({improvement_pct:.1f}%)")

    # Save optimized model: a timestamped copy plus a "latest" alias
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    model_dir = Path(__file__).parent / "optimized_models"
    model_dir.mkdir(exist_ok=True)
    model_path = model_dir / f"heritage_rag_mipro_{timestamp}.json"
    latest_path = model_dir / "heritage_rag_mipro_latest.json"
    optimized_pipeline.save(str(model_path))
    optimized_pipeline.save(str(latest_path))
    logger.info(f"Saved optimized model to: {model_path}")

    # Save metadata
    metadata = {
        "timestamp": timestamp,
        "optimizer": "MIPROv2",
        "auto_setting": auto_setting,
        "student_model": model,
        "teacher_model": teacher_model,
        "training_examples": len(trainset),
        "validation_examples": len(valset),
        "baseline_score": baseline_avg,
        "optimized_score": optimized_avg,
        "improvement": improvement,
        "improvement_pct": improvement_pct,
        "optimization_seconds": elapsed,
    }
    metadata_path = model_dir / f"metadata_mipro_{timestamp}.json"
    with open(metadata_path, "w") as f:
        json.dump(metadata, f, indent=2)
    logger.info(f"Saved metadata to: {metadata_path}")

    # Summary
    logger.info("=" * 60)
    logger.info("MIPRO OPTIMIZATION SUMMARY")
    logger.info("=" * 60)
    logger.info(f"Baseline Score: {baseline_avg:.3f}")
    logger.info(f"Optimized Score: {optimized_avg:.3f}")
    logger.info(f"Improvement: {improvement:.3f} ({improvement_pct:.1f}%)")
    logger.info(f"Time: {elapsed:.1f}s")
    logger.info(f"Model saved to: {latest_path}")
    logger.info("=" * 60)
    return optimized_pipeline, metadata
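
# Reusing a saved run elsewhere (a sketch; assumes the directory layout above
# and that HeritageRAGPipeline is importable the same way). DSPy modules can
# restore their optimized instructions and demos from the saved JSON state:
#
#     pipeline = HeritageRAGPipeline()
#     pipeline.load("backend/rag/optimized_models/heritage_rag_mipro_latest.json")
#     pred = pipeline(question="Where is the Rijksmuseum located?", language="en")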
# =============================================================================
# COMPARISON WITH EXISTING OPTIMIZERS
# =============================================================================
def compare_optimizers():
    """Compare all optimized models against baseline."""
    logger.info("=" * 60)
    logger.info("Heritage RAG - Optimizer Comparison")
    logger.info("=" * 60)
    model_dir = Path(__file__).parent / "optimized_models"

    # Load all metadata files
    results = []
    for meta_file in model_dir.glob("metadata_*.json"):
        with open(meta_file) as f:
            meta = json.load(f)
        results.append(meta)
    if not results:
        logger.info("No optimization results found")
        return

    # Sort by timestamp
    results.sort(key=lambda x: x.get('timestamp', ''))
    logger.info(f"Found {len(results)} optimization runs:\n")
    for r in results:
        optimizer = r.get('optimizer', 'Unknown')
        baseline = r.get('baseline_score', 0)
        optimized = r.get('optimized_score', 0)
        improvement = r.get('improvement', 0)
        improvement_pct = r.get('improvement_pct', improvement / baseline * 100 if baseline > 0 else 0)
        elapsed = r.get('optimization_seconds', 0)
        timestamp = r.get('timestamp', 'Unknown')
        logger.info(f" {optimizer} ({timestamp}):")
        logger.info(f" Baseline: {baseline:.3f}")
        logger.info(f" Optimized: {optimized:.3f}")
        logger.info(f" Improvement: {improvement:.3f} ({improvement_pct:.1f}%)")
        logger.info(f" Time: {elapsed:.1f}s")
        logger.info("")

    # Find the best run
    best = max(results, key=lambda x: x.get('optimized_score', 0))
    logger.info(f"Best optimizer: {best.get('optimizer')} with score {best.get('optimized_score', 0):.3f}")
# =============================================================================
# MAIN
# =============================================================================
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Heritage RAG MIPROv2 Optimization")
parser.add_argument(
"--auto",
choices=["light", "medium", "heavy"],
default="light",
help="MIPROv2 optimization intensity"
)
parser.add_argument(
"--model",
default="openai/gpt-4o-mini",
help="Student model for pipeline"
)
parser.add_argument(
"--teacher",
default="openai/gpt-4o",
help="Teacher model for instruction generation"
)
parser.add_argument(
"--compare",
action="store_true",
help="Compare all optimized models"
)
args = parser.parse_args()
# Check environment
if not os.environ.get("OPENAI_API_KEY"):
logger.error("OPENAI_API_KEY not set. Run: source .env")
sys.exit(1)
if args.compare:
compare_optimizers()
else:
run_mipro_optimization(
auto_setting=args.auto,
model=args.model,
teacher_model=args.teacher,
)
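
# Example invocations (from the repo root, with the venv and .env loaded):
#   python backend/rag/run_mipro_optimization.py --auto light
#   python backend/rag/run_mipro_optimization.py --auto medium --teacher openai/gpt-4o
#   python backend/rag/run_mipro_optimization.py --compare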