glam/backend/rag/run_bootstrap_optimization.py
2025-12-11 22:32:09 +01:00

339 lines
12 KiB
Python

#!/usr/bin/env python3
"""
BootstrapFewShot Optimization for Heritage RAG Pipeline

This script uses DSPy's BootstrapFewShot optimizer - a faster alternative to GEPA
that works by creating few-shot demonstrations from the training data.

Benefits over GEPA:
- Much faster (minutes vs hours)
- Lower LLM API usage
- Good for initial optimization before fine-tuning

Usage:
    python run_bootstrap_optimization.py
    python run_bootstrap_optimization.py --random-search
    python run_bootstrap_optimization.py --evaluate [--model-path PATH]

Requirements:
- SSH tunnel active: ssh -f -N -L 7878:localhost:7878 root@91.98.224.44
- Environment loaded: source .venv/bin/activate && source .env
"""
import json
import logging
import os
import sys
from datetime import datetime, timezone
from pathlib import Path
import dspy
from dspy.teleprompt import BootstrapFewShot, BootstrapFewShotWithRandomSearch
# Add parent to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from backend.rag.dspy_heritage_rag import HeritageRAGPipeline
from backend.rag.gepa_training_extended import get_extended_training_data
# Configure logging
# Root logger is set to INFO; every record carries timestamp, logger name and level.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# Module-level logger for this script's progress and result messages.
logger = logging.getLogger(__name__)
# =============================================================================
# METRIC FUNCTION
# =============================================================================
# Partial credit awarded when the predicted intent differs from the expected
# one but the two are closely related. Keys are unordered pairs, so a single
# lookup covers both (expected, predicted) and (predicted, expected).
_RELATED_INTENT_CREDIT = {
    frozenset(('geographic', 'entity_lookup')): 0.2,
    frozenset(('statistical', 'comparative')): 0.2,
    frozenset(('exploration', 'entity_lookup')): 0.2,
    frozenset(('temporal', 'entity_lookup')): 0.2,
}


def heritage_metric(example: dspy.Example, prediction: dspy.Prediction, trace=None) -> float:
    """Evaluate Heritage RAG predictions.

    Scores based on:
    - Intent classification accuracy (40%); related-but-different intents
      earn partial credit (20%)
    - Answer relevance - fraction of expected keywords found in the answer (40%)
    - Has non-empty answer (20%; 10% when the answer is 20 characters or fewer)

    Each component is skipped (contributes 0) when the attribute it needs is
    missing. Malformed model output - an intent that is ``None`` or not a
    string, or an answer that is not a string - no longer raises: such an
    intent scores 0 and such an answer is scored on its ``str()`` form.

    Args:
        example: Gold example; ``expected_intent`` (str) and
            ``answer_contains`` (list of keyword strings, or a single string)
            are read when present.
        prediction: Pipeline output; ``intent`` and ``answer`` are read when
            present.
        trace: Unused; accepted so the function matches the metric call
            signature used by the optimizers.

    Returns:
        Score between 0 and 1
    """
    score = 0.0

    # 1. Intent match (40%)
    expected_intent = getattr(example, 'expected_intent', None)
    predicted_intent = getattr(prediction, 'intent', None)
    if isinstance(expected_intent, str) and isinstance(predicted_intent, str):
        # Normalise case and surrounding whitespace: LLM output frequently
        # carries stray spaces or newlines around the label.
        expected_norm = expected_intent.strip().lower()
        predicted_norm = predicted_intent.strip().lower()
        if expected_norm == predicted_norm:
            score += 0.4
        else:
            # Partial credit for related intents
            pair = frozenset((expected_norm, predicted_norm))
            score += _RELATED_INTENT_CREDIT.get(pair, 0)

    # Coerce the answer to text once; a missing/None/empty answer becomes "".
    has_answer = hasattr(prediction, 'answer')
    raw_answer = prediction.answer if has_answer else None
    if isinstance(raw_answer, str):
        answer_text = raw_answer
    else:
        answer_text = str(raw_answer) if raw_answer else ""

    # 2. Answer contains expected keywords (40%)
    if hasattr(example, 'answer_contains') and has_answer:
        keywords = example.answer_contains
        if isinstance(keywords, str):
            # A bare string is one keyword, not a sequence of characters.
            keywords = [keywords]
        if keywords:
            answer_lower = answer_text.lower()
            matches = sum(1 for kw in keywords if kw.lower() in answer_lower)
            score += 0.4 * (matches / len(keywords))

    # 3. Has non-empty answer (20%)
    stripped_length = len(answer_text.strip())
    if stripped_length > 20:
        score += 0.2
    elif stripped_length > 0:
        score += 0.1

    return score
def heritage_metric_strict(example: dspy.Example, prediction: dspy.Prediction, trace=None) -> bool:
    """Pass/fail wrapper around ``heritage_metric`` for BootstrapFewShot.

    Args:
        example: Gold example forwarded to ``heritage_metric``.
        prediction: Pipeline output forwarded to ``heritage_metric``.
        trace: Accepted for signature compatibility; not forwarded.

    Returns:
        True when the graded score reaches 0.6, otherwise False.
    """
    passing_threshold = 0.6
    graded_score = heritage_metric(example, prediction)
    return graded_score >= passing_threshold
# =============================================================================
# OPTIMIZATION RUNNER
# =============================================================================
def _mean_score(scores):
    """Return the arithmetic mean of ``scores``, or 0.0 for an empty list."""
    return sum(scores) / len(scores) if scores else 0.0


def _format_relative_change(delta, baseline):
    """Format ``delta`` as a percentage of ``baseline``; 'n/a' when baseline is 0."""
    if baseline == 0:
        return "n/a"
    return f"{delta / baseline * 100:.1f}%"


def _score_examples(pipeline, examples, show_intent=False):
    """Run ``pipeline`` on each example and return exactly one score per example."""
    scores = []
    for ex in examples:
        try:
            pred = pipeline(question=ex.question, language=ex.language)
            score = heritage_metric(ex, pred)
        except Exception as e:
            # Best-effort evaluation: a failing example counts as 0.
            logger.warning(f" Error on example: {e}")
            scores.append(0.0)
            continue
        scores.append(score)
        # Logging happens outside the try block so a prediction without an
        # ``intent`` attribute cannot add a second (0.0) entry for this example.
        line = f" Q: {str(ex.question)[:50]}... → Score: {score:.2f}"
        if show_intent:
            line += f", Intent: {getattr(pred, 'intent', None)}"
        logger.info(line)
    return scores


def run_bootstrap_optimization(
    use_random_search: bool = False,
    max_bootstrapped_demos: int = 4,
    max_labeled_demos: int = 8,
    num_candidates: int = 10,
    baseline_eval_size: int = 5,
):
    """Run BootstrapFewShot optimization on Heritage RAG pipeline.

    Steps: check that local port 7878 accepts connections, configure DSPy,
    score the un-optimized pipeline on a validation sample, compile the
    optimized pipeline on the training set, score it on the full validation
    set, then save the model (timestamped and "latest") plus a metadata JSON
    file under ``optimized_models/`` next to this script.

    Args:
        use_random_search: If True, use BootstrapFewShotWithRandomSearch
        max_bootstrapped_demos: Max demonstrations to bootstrap
        max_labeled_demos: Max labeled examples to include
        num_candidates: Number of candidate programs for random search
        baseline_eval_size: Number of leading validation examples used for
            the quick baseline evaluation; ``None`` evaluates all of them.

    Returns:
        ``(optimized_pipeline, metadata)`` on success, or ``None`` when port
        7878 is not reachable. ``metadata["improvement"]`` compares baseline
        and optimized scores on the same examples.

    Raises:
        Exception: Whatever ``optimizer.compile`` raises is logged and re-raised.
    """
    logger.info("=" * 60)
    logger.info("Heritage RAG - BootstrapFewShot Optimization")
    logger.info("=" * 60)

    # Check SSH tunnel
    import socket
    # The context manager closes the socket even if connect_ex raises.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        tunnel_status = sock.connect_ex(('localhost', 7878))
    if tunnel_status != 0:
        logger.error("SSH tunnel not active! Run: ssh -f -N -L 7878:localhost:7878 root@91.98.224.44")
        return None
    logger.info("✓ SSH tunnel active (port 7878)")

    # Configure DSPy
    logger.info("Configuring DSPy with GPT-4o-mini...")
    lm = dspy.LM('openai/gpt-4o-mini', temperature=0.3, max_tokens=1000)
    dspy.configure(lm=lm)

    # Load training data
    logger.info("Loading training data...")
    trainset, valset = get_extended_training_data()
    logger.info(f" Training examples: {len(trainset)}")
    logger.info(f" Validation examples: {len(valset)}")

    # Create baseline pipeline
    logger.info("Creating baseline HeritageRAGPipeline...")
    pipeline = HeritageRAGPipeline()

    # Evaluate baseline (quick eval on a leading sample of the validation set)
    logger.info("Evaluating baseline on validation set...")
    baseline_scores = _score_examples(pipeline, valset[:baseline_eval_size])
    baseline_avg = _mean_score(baseline_scores)
    logger.info(f"Baseline average score: {baseline_avg:.3f}")

    # Run optimization
    logger.info("-" * 60)
    if use_random_search:
        logger.info("Starting BootstrapFewShotWithRandomSearch optimization...")
        logger.info(f" max_bootstrapped_demos: {max_bootstrapped_demos}")
        logger.info(f" max_labeled_demos: {max_labeled_demos}")
        logger.info(f" num_candidates: {num_candidates}")
        optimizer = BootstrapFewShotWithRandomSearch(
            metric=heritage_metric_strict,
            max_bootstrapped_demos=max_bootstrapped_demos,
            max_labeled_demos=max_labeled_demos,
            num_candidate_programs=num_candidates,
            num_threads=4,
        )
    else:
        logger.info("Starting BootstrapFewShot optimization...")
        logger.info(f" max_bootstrapped_demos: {max_bootstrapped_demos}")
        logger.info(f" max_labeled_demos: {max_labeled_demos}")
        optimizer = BootstrapFewShot(
            metric=heritage_metric_strict,
            max_bootstrapped_demos=max_bootstrapped_demos,
            max_labeled_demos=max_labeled_demos,
        )

    # Compile optimized pipeline
    logger.info("Compiling optimized pipeline (this may take a few minutes)...")
    start_time = datetime.now(timezone.utc)
    try:
        optimized_pipeline = optimizer.compile(
            pipeline,
            trainset=trainset,
        )
    except Exception as e:
        # logger.exception keeps the traceback in the log before re-raising.
        logger.exception(f"Optimization failed: {e}")
        raise
    elapsed = (datetime.now(timezone.utc) - start_time).total_seconds()
    logger.info(f"Optimization completed in {elapsed:.1f} seconds")

    # Evaluate optimized pipeline
    logger.info("-" * 60)
    logger.info("Evaluating optimized pipeline on validation set...")
    optimized_scores = _score_examples(optimized_pipeline, valset, show_intent=True)
    optimized_avg = _mean_score(optimized_scores)
    logger.info(f"Optimized average score: {optimized_avg:.3f}")

    # Compare like with like: the baseline only covered the leading sample, so
    # the improvement is measured on those same examples. Scores are
    # index-aligned because _score_examples yields one entry per example.
    comparable_avg = _mean_score(optimized_scores[:len(baseline_scores)])
    improvement = comparable_avg - baseline_avg
    logger.info(
        f"Optimized score on the {len(baseline_scores)} baseline examples: {comparable_avg:.3f}"
    )
    # A zero baseline has no meaningful relative change; avoid dividing by it
    # (which previously crashed here, before the model was saved).
    logger.info(
        f"Improvement: {improvement:.3f} ({_format_relative_change(improvement, baseline_avg)})"
    )

    # Save optimized model (timestamp is local time, used only in file names)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    model_dir = Path(__file__).parent / "optimized_models"
    model_dir.mkdir(exist_ok=True)
    model_path = model_dir / f"heritage_rag_bootstrap_{timestamp}.json"
    latest_path = model_dir / "heritage_rag_bootstrap_latest.json"
    optimized_pipeline.save(str(model_path))
    optimized_pipeline.save(str(latest_path))
    logger.info(f"Saved optimized model to: {model_path}")

    # Save metadata
    metadata = {
        "timestamp": timestamp,
        "optimizer": "BootstrapFewShotWithRandomSearch" if use_random_search else "BootstrapFewShot",
        "max_bootstrapped_demos": max_bootstrapped_demos,
        "max_labeled_demos": max_labeled_demos,
        "training_examples": len(trainset),
        "validation_examples": len(valset),
        "baseline_examples_evaluated": len(baseline_scores),
        "baseline_score": baseline_avg,
        "optimized_score": optimized_avg,
        "optimized_score_on_baseline_subset": comparable_avg,
        "improvement": improvement,
        "optimization_seconds": elapsed,
    }
    metadata_path = model_dir / f"metadata_bootstrap_{timestamp}.json"
    with open(metadata_path, "w", encoding="utf-8") as f:
        json.dump(metadata, f, indent=2)
    logger.info(f"Saved metadata to: {metadata_path}")

    # Summary
    logger.info("=" * 60)
    logger.info("OPTIMIZATION SUMMARY")
    logger.info("=" * 60)
    logger.info(f"Baseline Score: {baseline_avg:.3f}")
    logger.info(f"Optimized Score: {optimized_avg:.3f}")
    logger.info(f"Improvement: {improvement:.3f}")
    logger.info(f"Time: {elapsed:.1f}s")
    logger.info(f"Model saved to: {latest_path}")
    logger.info("=" * 60)

    return optimized_pipeline, metadata
# =============================================================================
# QUICK EVALUATION
# =============================================================================
def evaluate_saved_model(model_path: str = None):
    """Load and evaluate a saved model.

    Configures DSPy, loads the saved pipeline state into a fresh
    ``HeritageRAGPipeline`` and scores it with ``heritage_metric`` on every
    validation example. An example whose prediction or scoring raises is
    logged as a warning and counted as 0.

    Args:
        model_path: Path of the saved model file. When ``None``, the
            ``heritage_rag_bootstrap_latest.json`` file in the
            ``optimized_models`` directory next to this script is used.

    Returns:
        Average score over the validation set (0.0 if the set is empty).
    """
    if model_path is None:
        model_path = Path(__file__).parent / "optimized_models" / "heritage_rag_bootstrap_latest.json"
    logger.info(f"Loading model from: {model_path}")

    # Configure DSPy
    lm = dspy.LM('openai/gpt-4o-mini', temperature=0.3, max_tokens=1000)
    dspy.configure(lm=lm)

    # Load pipeline
    pipeline = HeritageRAGPipeline()
    pipeline.load(str(model_path))

    # Load validation data
    _, valset = get_extended_training_data()

    # Evaluate
    scores = []
    for ex in valset:
        try:
            pred = pipeline(question=ex.question, language=ex.language)
            score = heritage_metric(ex, pred)
        except Exception as e:
            logger.warning(f"Error: {e}")
            scores.append(0.0)
            continue
        scores.append(score)
        # Logging is kept outside the try block and uses getattr: previously a
        # missing ``intent``/``expected_intent`` raised after the score had
        # been recorded, adding a second 0.0 entry for the same example.
        predicted_intent = getattr(pred, 'intent', None)
        expected_intent = getattr(ex, 'expected_intent', None)
        logger.info(f"Q: {str(ex.question)[:50]}...")
        logger.info(f" Intent: {predicted_intent} (expected: {expected_intent})")
        logger.info(f" Score: {score:.2f}")

    avg_score = sum(scores) / len(scores) if scores else 0.0
    logger.info(f"\nAverage score: {avg_score:.3f}")
    return avg_score
# =============================================================================
# MAIN
# =============================================================================
if __name__ == "__main__":
    import argparse

    # Command-line interface: flag name -> add_argument keyword options,
    # registered in this order.
    cli_options = (
        ("--random-search", {"action": "store_true", "help": "Use random search variant"}),
        ("--evaluate", {"action": "store_true", "help": "Just evaluate saved model"}),
        ("--model-path", {"type": str, "help": "Model path for evaluation"}),
        ("--demos", {"type": int, "default": 4, "help": "Max bootstrapped demos"}),
        ("--labeled", {"type": int, "default": 8, "help": "Max labeled demos"}),
        ("--candidates", {"type": int, "default": 10, "help": "Num candidates for random search"}),
    )
    cli = argparse.ArgumentParser(description="Heritage RAG Bootstrap Optimization")
    for flag, options in cli_options:
        cli.add_argument(flag, **options)
    opts = cli.parse_args()

    # Default mode is optimization; --evaluate only scores a saved model.
    if not opts.evaluate:
        run_bootstrap_optimization(
            use_random_search=opts.random_search,
            max_bootstrapped_demos=opts.demos,
            max_labeled_demos=opts.labeled,
            num_candidates=opts.candidates,
        )
    else:
        evaluate_saved_model(opts.model_path)