#!/usr/bin/env python3
|
|
"""
|
|
MIPROv2 Optimization for Heritage RAG Pipeline
|
|
|
|
MIPROv2 is DSPy 3.0's recommended optimizer for instruction optimization.
|
|
It generates optimized instructions AND selects optimal few-shot demonstrations.
|
|
|
|
Benefits over BootstrapFewShot:
|
|
- Generates optimized instructions (not just few-shot examples)
|
|
- Uses Bayesian optimization to search instruction space
|
|
- Better at finding optimal prompts for complex tasks
|
|
|
|
Usage:
|
|
cd /Users/kempersc/apps/glam
|
|
source .venv/bin/activate && source .env
|
|
python backend/rag/run_mipro_optimization.py
|
|
|
|
Requirements:
|
|
- SSH tunnel active: ssh -f -N -L 7878:localhost:7878 root@91.98.224.44
|
|
- Environment loaded with OPENAI_API_KEY
|
|
"""
|
|
|
|
import json
|
|
import logging
|
|
import os
|
|
import socket
|
|
import sys
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
|
|
import dspy
|
|
from dspy import Example
|
|
|
|
# Add parent to path for imports
# (three levels up = repo root, so `backend.rag...` absolute imports resolve
# when this file is run directly as a script)
sys.path.insert(0, str(Path(__file__).parent.parent.parent))

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# Module-level logger; handlers/level come from basicConfig above.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
# =============================================================================
|
|
# TRAINING DATA
|
|
# =============================================================================
|
|
|
|
def create_heritage_training_data():
    """Build training and validation Examples for heritage RAG optimization.

    Covers all query intents (statistical, entity_lookup, exploration,
    geographic, temporal, relational, comparative, person_lookup) in both
    Dutch and English.

    Returns:
        Tuple of (trainset, valset); each element is a list of dspy
        Examples with "question" and "language" marked as inputs and the
        remaining fields acting as evaluation labels.
    """

    def build(question, language, intent, entities, sources, contains):
        # Only question/language are pipeline inputs; the rest are labels
        # consumed by the metric.
        return Example(
            question=question,
            language=language,
            expected_intent=intent,
            expected_entities=entities,
            expected_sources=sources,
            answer_contains=contains,
        ).with_inputs("question", "language")

    # Each spec: (question, language, intent, entities, sources, answer_contains)
    train_specs = [
        # Statistical queries
        ("Hoeveel musea zijn er in Amsterdam?", "nl", "statistical",
         ["amsterdam", "musea"], ["sparql", "qdrant"], ["musea", "Amsterdam"]),
        ("How many archives are in the Netherlands?", "en", "statistical",
         ["netherlands", "archives"], ["sparql"], ["archive", "Netherlands"]),
        ("What is the total number of libraries in Utrecht?", "en", "statistical",
         ["utrecht", "libraries"], ["sparql", "qdrant"], ["library", "Utrecht"]),
        # Entity lookup queries
        ("Where is the Rijksmuseum located?", "en", "entity_lookup",
         ["rijksmuseum"], ["sparql", "qdrant"], ["Rijksmuseum", "Amsterdam"]),
        ("Wat is het Nationaal Archief?", "nl", "entity_lookup",
         ["nationaal archief"], ["sparql", "qdrant"], ["Nationaal Archief", "Den Haag"]),
        ("Tell me about the Van Gogh Museum", "en", "entity_lookup",
         ["van gogh museum"], ["sparql", "qdrant"], ["Van Gogh", "museum", "Amsterdam"]),
        # Exploration queries
        ("Show me archives related to World War II", "en", "exploration",
         ["world war ii", "archives"], ["qdrant", "sparql"], ["archive", "war"]),
        ("What heritage institutions focus on maritime history?", "en", "exploration",
         ["maritime history", "heritage"], ["qdrant"], ["maritime", "museum"]),
        # Geographic queries
        ("Which museums are in Noord-Holland province?", "en", "geographic",
         ["noord-holland", "museums"], ["sparql", "qdrant"], ["museum", "Noord-Holland"]),
        ("Welke archieven zijn er in Limburg?", "nl", "geographic",
         ["limburg", "archieven"], ["sparql", "qdrant"], ["archief", "Limburg"]),
        ("List all libraries in Rotterdam", "en", "geographic",
         ["rotterdam", "libraries"], ["sparql", "qdrant"], ["library", "Rotterdam"]),
        # Temporal queries
        ("Welke bibliotheken zijn gefuseerd sinds 2000?", "nl", "temporal",
         ["bibliotheken", "2000"], ["typedb", "sparql"], ["bibliotheek"]),
        ("Which museums opened after 2010?", "en", "temporal",
         ["museums", "2010"], ["sparql", "typedb"], ["museum", "opened"]),
        # Relational queries
        ("What collections does the Nationaal Archief manage?", "en", "relational",
         ["nationaal archief", "collections"], ["typedb", "sparql"],
         ["Nationaal Archief", "collection"]),
        ("Which institutions are part of the KNAW?", "en", "relational",
         ["knaw", "institutions"], ["typedb", "sparql"], ["KNAW"]),
        # Comparative queries
        ("Compare visitor numbers of Rijksmuseum and Van Gogh Museum", "en", "comparative",
         ["rijksmuseum", "van gogh museum"], ["sparql", "qdrant"], ["visitor", "museum"]),
        # Person queries (role-based)
        ("Who are the curators at the Mauritshuis?", "en", "person_lookup",
         ["mauritshuis", "curators"], ["qdrant"], ["Mauritshuis", "curator"]),
        ("Wie is de directeur van het Rijksmuseum?", "nl", "person_lookup",
         ["rijksmuseum", "directeur"], ["qdrant"], ["Rijksmuseum", "directeur"]),
        ("Find archivists working at provincial archives", "en", "person_lookup",
         ["archivists", "provincial archives"], ["qdrant"], ["archivist", "archive"]),
        # Complex multi-hop queries
        ("Which archives in Amsterdam have digitized their collections?", "en",
         "exploration", ["amsterdam", "archives", "digitized"],
         ["qdrant", "sparql"], ["archive", "digital"]),
    ]

    val_specs = [
        ("List all libraries in Utrecht", "en", "geographic",
         ["libraries", "utrecht"], ["sparql", "qdrant"], ["library", "Utrecht"]),
        ("Wat is de geschiedenis van het Anne Frank Huis?", "nl", "entity_lookup",
         ["anne frank huis"], ["sparql", "qdrant"], ["Anne Frank"]),
        ("How many heritage institutions are in the Netherlands?", "en", "statistical",
         ["heritage institutions", "netherlands"], ["sparql"],
         ["institution", "Netherlands"]),
        ("Which museums have paintings by Rembrandt?", "en", "exploration",
         ["museums", "rembrandt"], ["qdrant"], ["Rembrandt", "museum"]),
        ("Wie werkt er bij het Stadsarchief Amsterdam?", "nl", "person_lookup",
         ["stadsarchief amsterdam"], ["qdrant"], ["Stadsarchief", "Amsterdam"]),
    ]

    trainset = [build(*spec) for spec in train_specs]
    valset = [build(*spec) for spec in val_specs]
    return trainset, valset
|
|
|
|
|
|
# =============================================================================
|
|
# METRIC FUNCTION
|
|
# =============================================================================
|
|
|
|
def heritage_metric(example: "Example", prediction, trace=None) -> float:
    """Score a Heritage RAG prediction against a labeled example.

    Weighted rubric:
    - Intent classification accuracy: 35% (partial credit for related intents)
    - Entity extraction overlap: 25% (falls back to matching expected
      entities inside the answer text when no entities were predicted)
    - Answer contains expected keywords: 30%
    - Non-empty, substantive answer: 10% (half credit for short answers)

    Args:
        example: Labeled example; `expected_intent`, `expected_entities`,
            and `answer_contains` attributes are read if present.
        prediction: Pipeline output; `intent`, `entities`, and `answer`
            attributes are read if present.
        trace: Unused; accepted for DSPy metric-signature compatibility.

    Returns:
        Score between 0.0 and 1.0.
    """
    score = 0.0

    # 1. Intent match (35%)
    expected_intent = getattr(example, 'expected_intent', None)
    pred_intent = getattr(prediction, 'intent', None)

    if expected_intent and pred_intent:
        if expected_intent.lower() == pred_intent.lower():
            score += 0.35
        else:
            # Partial credit for confusable intent pairs (order-insensitive).
            related_intents = {
                ('geographic', 'entity_lookup'): 0.15,
                ('statistical', 'comparative'): 0.15,
                ('exploration', 'entity_lookup'): 0.15,
                ('temporal', 'entity_lookup'): 0.15,
                ('person_lookup', 'entity_lookup'): 0.20,
                ('relational', 'exploration'): 0.15,
            }
            pair = (expected_intent.lower(), pred_intent.lower())
            reverse_pair = (pair[1], pair[0])
            score += related_intents.get(pair, related_intents.get(reverse_pair, 0))

    # 2. Entity extraction (25%)
    expected_entities = getattr(example, 'expected_entities', [])
    pred_entities = getattr(prediction, 'entities', [])

    if expected_entities and pred_entities:
        # Case-insensitive set overlap; a non-empty entity list always
        # yields a non-empty set, so no extra emptiness guard is needed.
        expected_lower = {e.lower() for e in expected_entities}
        pred_lower = {str(e).lower() for e in pred_entities}
        score += 0.25 * (len(expected_lower & pred_lower) / len(expected_lower))
    elif expected_entities:
        # No predicted entities: fall back to entity mentions in the answer.
        answer = getattr(prediction, 'answer', '') or ''
        answer_lower = answer.lower()
        matches = sum(1 for e in expected_entities if e.lower() in answer_lower)
        score += 0.25 * (matches / len(expected_entities))

    # 3. Answer contains expected keywords (30%)
    answer_contains = getattr(example, 'answer_contains', [])
    answer = getattr(prediction, 'answer', '') or ''

    if answer_contains and answer:
        answer_lower = answer.lower()
        matches = sum(1 for kw in answer_contains if kw.lower() in answer_lower)
        score += 0.30 * (matches / len(answer_contains))

    # 4. Has non-empty, substantive answer (10%; half credit if short)
    if answer and len(answer.strip()) > 50:
        score += 0.10
    elif answer and len(answer.strip()) > 20:
        score += 0.05

    return score
|
|
|
|
|
|
def heritage_metric_strict(example: Example, prediction, trace=None) -> bool:
    """Boolean wrapper around `heritage_metric` for MIPROv2.

    A prediction passes only when its weighted score reaches the 0.7
    threshold.
    """
    passing_threshold = 0.7
    score = heritage_metric(example, prediction)
    return score >= passing_threshold
|
|
|
|
|
|
# =============================================================================
|
|
# MIPRO OPTIMIZATION
|
|
# =============================================================================
|
|
|
|
def _evaluate_on_examples(pipeline, examples):
    """Score `pipeline` on each example with `heritage_metric`.

    A failing example is logged and scored 0.0 so one error cannot abort
    the whole evaluation run.

    Returns:
        List of per-example scores, in the same order as `examples`.
    """
    scores = []
    for ex in examples:
        try:
            pred = pipeline(question=ex.question, language=ex.language)
            score = heritage_metric(ex, pred)
            scores.append(score)
            logger.info(f"  Q: {ex.question[:50]}...")
            logger.info(f"    Intent: {pred.intent} (expected: {ex.expected_intent}) → Score: {score:.2f}")
        except Exception as e:
            logger.warning(f"  Error: {e}")
            scores.append(0.0)
    return scores


def run_mipro_optimization(
    auto_setting: str = "light",
    model: str = "openai/gpt-4o-mini",
    teacher_model: str = "openai/gpt-4o",
):
    """Run MIPROv2 optimization on the Heritage RAG pipeline.

    Evaluates the baseline pipeline on the validation set, runs MIPROv2 to
    optimize instructions and few-shot demonstrations, re-evaluates on the
    same validation set, and saves the optimized model plus a JSON metadata
    record under `optimized_models/` next to this script.

    Args:
        auto_setting: MIPROv2 auto setting ("light", "medium", or "heavy").
        model: Student model used to execute the pipeline.
        teacher_model: Stronger teacher model used to generate instructions.

    Returns:
        Tuple of (optimized_pipeline, metadata dict).

    Raises:
        Exception: re-raised from `MIPROv2.compile` if optimization fails.
    """
    logger.info("=" * 60)
    logger.info("Heritage RAG - MIPROv2 Optimization")
    logger.info("=" * 60)
    logger.info(f"Auto setting: {auto_setting}")
    logger.info(f"Student model: {model}")
    logger.info(f"Teacher model: {teacher_model}")

    # Check SSH tunnel for Qdrant. Bounded timeout + context manager so a
    # dead route can neither hang the probe nor leak the socket fd.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.settimeout(2.0)
        result = sock.connect_ex(('localhost', 7878))
    if result != 0:
        logger.warning("SSH tunnel not active on port 7878 - Qdrant queries may fail")
        logger.warning("Run: ssh -f -N -L 7878:localhost:7878 root@91.98.224.44")
    else:
        logger.info("✓ SSH tunnel active (port 7878)")

    # Configure DSPy with the student model.
    logger.info("Configuring DSPy...")
    lm = dspy.LM(model, temperature=0.3, max_tokens=1500)
    dspy.configure(lm=lm)

    # Load training data
    logger.info("Loading training data...")
    trainset, valset = create_heritage_training_data()
    logger.info(f"  Training examples: {len(trainset)}")
    logger.info(f"  Validation examples: {len(valset)}")

    # Create baseline pipeline (imported lazily so sys.path is set up first).
    logger.info("Creating baseline HeritageRAGPipeline...")
    from backend.rag.dspy_heritage_rag import HeritageRAGPipeline
    pipeline = HeritageRAGPipeline()

    # Evaluate baseline on the FULL validation set. The previous version
    # scored the baseline on valset[:3] but the optimized pipeline on all
    # of valset, so the reported "improvement" compared different example
    # sets; both averages must come from the same examples.
    logger.info("Evaluating baseline on validation set...")
    baseline_scores = _evaluate_on_examples(pipeline, valset)
    baseline_avg = sum(baseline_scores) / len(baseline_scores) if baseline_scores else 0
    logger.info(f"Baseline average score: {baseline_avg:.3f}")

    # Create MIPROv2 optimizer
    logger.info("-" * 60)
    logger.info(f"Starting MIPROv2 optimization (auto={auto_setting})...")
    logger.info("This will optimize instructions AND select few-shot examples")

    # Teacher LM proposes candidate instructions; higher temperature gives
    # more diverse proposals for the Bayesian search.
    teacher_lm = dspy.LM(teacher_model, temperature=0.7, max_tokens=3000)

    optimizer = dspy.MIPROv2(
        metric=heritage_metric,
        auto=auto_setting,
        prompt_model=teacher_lm,  # stronger model for prompt generation
        task_model=lm,  # standard model for task execution
        num_threads=4,
        # MIPROv2 specific settings
        init_temperature=0.7,
        track_stats=True,
    )

    # Compile optimized pipeline
    logger.info("Compiling optimized pipeline...")
    start_time = datetime.now(timezone.utc)

    try:
        optimized_pipeline = optimizer.compile(
            pipeline,
            trainset=trainset,
            valset=valset,
            requires_permission_to_run=False,  # don't prompt for confirmation
        )
    except Exception as e:
        logger.error(f"Optimization failed: {e}")
        import traceback
        traceback.print_exc()
        raise

    elapsed = (datetime.now(timezone.utc) - start_time).total_seconds()
    logger.info(f"Optimization completed in {elapsed:.1f} seconds")

    # Evaluate optimized pipeline on the same validation set as the baseline.
    logger.info("-" * 60)
    logger.info("Evaluating optimized pipeline on full validation set...")
    optimized_scores = _evaluate_on_examples(optimized_pipeline, valset)
    optimized_avg = sum(optimized_scores) / len(optimized_scores) if optimized_scores else 0
    improvement = optimized_avg - baseline_avg
    improvement_pct = (improvement / baseline_avg * 100) if baseline_avg > 0 else 0

    logger.info(f"Optimized average score: {optimized_avg:.3f}")
    logger.info(f"Improvement: {improvement:.3f} ({improvement_pct:.1f}%)")

    # Save optimized model: timestamped copy plus a stable "latest" alias.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    model_dir = Path(__file__).parent / "optimized_models"
    model_dir.mkdir(exist_ok=True)

    model_path = model_dir / f"heritage_rag_mipro_{timestamp}.json"
    latest_path = model_dir / "heritage_rag_mipro_latest.json"

    optimized_pipeline.save(str(model_path))
    optimized_pipeline.save(str(latest_path))
    logger.info(f"Saved optimized model to: {model_path}")

    # Save metadata describing this run next to the model; the keys here
    # are read back by compare_optimizers().
    metadata = {
        "timestamp": timestamp,
        "optimizer": "MIPROv2",
        "auto_setting": auto_setting,
        "student_model": model,
        "teacher_model": teacher_model,
        "training_examples": len(trainset),
        "validation_examples": len(valset),
        "baseline_score": baseline_avg,
        "optimized_score": optimized_avg,
        "improvement": improvement,
        "improvement_pct": improvement_pct,
        "optimization_seconds": elapsed,
    }

    metadata_path = model_dir / f"metadata_mipro_{timestamp}.json"
    with open(metadata_path, "w") as f:
        json.dump(metadata, f, indent=2)
    logger.info(f"Saved metadata to: {metadata_path}")

    # Summary
    logger.info("=" * 60)
    logger.info("MIPRO OPTIMIZATION SUMMARY")
    logger.info("=" * 60)
    logger.info(f"Baseline Score: {baseline_avg:.3f}")
    logger.info(f"Optimized Score: {optimized_avg:.3f}")
    logger.info(f"Improvement: {improvement:.3f} ({improvement_pct:.1f}%)")
    logger.info(f"Time: {elapsed:.1f}s")
    logger.info(f"Model saved to: {latest_path}")
    logger.info("=" * 60)

    return optimized_pipeline, metadata
|
|
|
|
|
|
# =============================================================================
|
|
# COMPARISON WITH EXISTING OPTIMIZERS
|
|
# =============================================================================
|
|
|
|
def compare_optimizers():
    """Summarize all saved optimization runs and report the best one.

    Reads every `metadata_*.json` file under `optimized_models/` next to
    this script. A missing directory or an unreadable/corrupt metadata
    file is tolerated: the directory case reports "no results" and bad
    files are skipped with a warning instead of crashing the report.
    """
    logger.info("=" * 60)
    logger.info("Heritage RAG - Optimizer Comparison")
    logger.info("=" * 60)

    model_dir = Path(__file__).parent / "optimized_models"
    if not model_dir.is_dir():
        # No runs have been saved yet.
        logger.info("No optimization results found")
        return

    # Load all metadata files (sorted for deterministic ordering; glob
    # order is filesystem-dependent).
    results = []
    for meta_file in sorted(model_dir.glob("metadata_*.json")):
        try:
            with open(meta_file) as f:
                results.append(json.load(f))
        except (OSError, json.JSONDecodeError) as e:
            # One corrupt file should not abort the whole comparison.
            logger.warning(f"Skipping unreadable metadata file {meta_file.name}: {e}")

    if not results:
        logger.info("No optimization results found")
        return

    # Sort by timestamp
    results.sort(key=lambda x: x.get('timestamp', ''))

    logger.info(f"Found {len(results)} optimization runs:\n")

    for r in results:
        optimizer = r.get('optimizer', 'Unknown')
        baseline = r.get('baseline_score', 0)
        optimized = r.get('optimized_score', 0)
        improvement = r.get('improvement', 0)
        # Recompute the percentage if an older metadata file lacks it.
        improvement_pct = r.get(
            'improvement_pct',
            improvement / baseline * 100 if baseline > 0 else 0,
        )
        elapsed = r.get('optimization_seconds', 0)
        timestamp = r.get('timestamp', 'Unknown')

        logger.info(f"  {optimizer} ({timestamp}):")
        logger.info(f"    Baseline: {baseline:.3f}")
        logger.info(f"    Optimized: {optimized:.3f}")
        logger.info(f"    Improvement: {improvement:.3f} ({improvement_pct:.1f}%)")
        logger.info(f"    Time: {elapsed:.1f}s")
        logger.info("")

    # Find best run by optimized score.
    best = max(results, key=lambda x: x.get('optimized_score', 0))
    logger.info(f"Best optimizer: {best.get('optimizer')} with score {best.get('optimized_score', 0):.3f}")
|
|
|
|
|
|
# =============================================================================
|
|
# MAIN
|
|
# =============================================================================
|
|
|
|
if __name__ == "__main__":
    import argparse

    # CLI entry point: parse flags, verify the environment, then dispatch.
    arg_parser = argparse.ArgumentParser(
        description="Heritage RAG MIPROv2 Optimization"
    )
    arg_parser.add_argument(
        "--auto", choices=["light", "medium", "heavy"], default="light",
        help="MIPROv2 optimization intensity",
    )
    arg_parser.add_argument(
        "--model", default="openai/gpt-4o-mini",
        help="Student model for pipeline",
    )
    arg_parser.add_argument(
        "--teacher", default="openai/gpt-4o",
        help="Teacher model for instruction generation",
    )
    arg_parser.add_argument(
        "--compare", action="store_true",
        help="Compare all optimized models",
    )
    cli_args = arg_parser.parse_args()

    # Fail fast when credentials are missing rather than mid-run.
    if not os.environ.get("OPENAI_API_KEY"):
        logger.error("OPENAI_API_KEY not set. Run: source .env")
        sys.exit(1)

    if cli_args.compare:
        compare_optimizers()
    else:
        run_mipro_optimization(
            auto_setting=cli_args.auto,
            model=cli_args.model,
            teacher_model=cli_args.teacher,
        )
|