#!/usr/bin/env python3
"""
BootstrapFewShot Optimization for Heritage RAG Pipeline

This script uses DSPy's BootstrapFewShot optimizer - a faster alternative to GEPA
that works by creating few-shot demonstrations from the training data.

Benefits over GEPA:
- Much faster (minutes vs hours)
- Lower LLM API usage
- Good for initial optimization before fine-tuning

Usage:
    python run_bootstrap_optimization.py

Requirements:
- SSH tunnel active: ssh -f -N -L 7878:localhost:7878 root@91.98.224.44
- Environment loaded: source .venv/bin/activate && source .env
"""
|
|
|
|
import json
import logging
import os
import sys
from datetime import datetime, timezone
from pathlib import Path

import dspy
from dspy.teleprompt import BootstrapFewShot, BootstrapFewShotWithRandomSearch

# Add the repository root (three levels up) to sys.path so the absolute
# `backend.*` imports below resolve when this file is run as a script.
sys.path.insert(0, str(Path(__file__).parent.parent.parent))

from backend.rag.dspy_heritage_rag import HeritageRAGPipeline
from backend.rag.gepa_training_extended import get_extended_training_data
|
# Configure logging for the whole script: INFO level with timestamped lines.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# Module-level logger, per stdlib logging convention.
logger = logging.getLogger(__name__)
|
# =============================================================================
# METRIC FUNCTION
# =============================================================================

def heritage_metric(example: dspy.Example, prediction: dspy.Prediction, trace=None) -> float:
    """Score a Heritage RAG prediction against an example.

    The score is a weighted sum of three components:
    - intent classification accuracy (40%, with partial credit for
      related intent pairs in either direction)
    - answer relevance, i.e. fraction of expected keywords present (40%)
    - presence of a non-empty answer (20%)

    Returns:
        Score between 0 and 1
    """
    total = 0.0

    # 1. Intent match (40%)
    if hasattr(example, 'expected_intent') and hasattr(prediction, 'intent'):
        expected = example.expected_intent.lower()
        predicted = prediction.intent.lower()
        if expected == predicted:
            total += 0.4
        else:
            # Partial credit for related intents; the table is symmetric,
            # so try both orderings of the pair.
            related_intents = {
                ('geographic', 'entity_lookup'): 0.2,
                ('statistical', 'comparative'): 0.2,
                ('exploration', 'entity_lookup'): 0.2,
                ('temporal', 'entity_lookup'): 0.2,
            }
            credit = related_intents.get((expected, predicted))
            if credit is None:
                credit = related_intents.get((predicted, expected), 0)
            total += credit

    # 2. Answer contains expected keywords (40%), scored as hit fraction.
    if hasattr(example, 'answer_contains') and hasattr(prediction, 'answer'):
        answer_lower = (prediction.answer or "").lower()
        keywords = example.answer_contains
        if keywords:
            hits = sum(1 for kw in keywords if kw.lower() in answer_lower)
            total += 0.4 * (hits / len(keywords))

    # 3. Non-empty answer bonus: 20% if substantive (> 20 chars), 10% if
    #    merely non-blank.
    answer = getattr(prediction, 'answer', None)
    if answer:
        stripped = answer.strip()
        if len(stripped) > 20:
            total += 0.2
        elif stripped:
            total += 0.1

    return total
|
def heritage_metric_strict(example: dspy.Example, prediction: dspy.Prediction, trace=None) -> bool:
    """Strict boolean metric for BootstrapFewShot.

    A prediction passes only when the graded heritage_metric score
    reaches the 0.6 threshold.
    """
    score = heritage_metric(example, prediction)
    return score >= 0.6
|
# =============================================================================
# OPTIMIZATION RUNNER
# =============================================================================

def run_bootstrap_optimization(
    use_random_search: bool = False,
    max_bootstrapped_demos: int = 4,
    max_labeled_demos: int = 8,
    num_candidates: int = 10,
):
    """Run BootstrapFewShot optimization on Heritage RAG pipeline.

    Evaluates a baseline pipeline on a small validation sample, compiles an
    optimized pipeline with the chosen optimizer, evaluates it on the full
    validation set, and persists the model plus a metadata record.

    Args:
        use_random_search: If True, use BootstrapFewShotWithRandomSearch
        max_bootstrapped_demos: Max demonstrations to bootstrap
        max_labeled_demos: Max labeled examples to include
        num_candidates: Number of candidate programs for random search
            (ignored when use_random_search is False)

    Returns:
        Tuple of (optimized_pipeline, metadata dict), or None when the
        SSH tunnel to the backend is not active.

    Raises:
        Exception: re-raises whatever the optimizer's compile step raised.
    """
    logger.info("=" * 60)
    logger.info("Heritage RAG - BootstrapFewShot Optimization")
    logger.info("=" * 60)

    # Fail fast if the SSH tunnel is down; connect_ex returns 0 on success.
    import socket
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    result = sock.connect_ex(('localhost', 7878))
    sock.close()
    if result != 0:
        logger.error("SSH tunnel not active! Run: ssh -f -N -L 7878:localhost:7878 root@91.98.224.44")
        return None
    logger.info("✓ SSH tunnel active (port 7878)")

    # Configure DSPy's global language model.
    logger.info("Configuring DSPy with GPT-4o-mini...")
    lm = dspy.LM('openai/gpt-4o-mini', temperature=0.3, max_tokens=1000)
    dspy.configure(lm=lm)

    # Load training data
    logger.info("Loading training data...")
    trainset, valset = get_extended_training_data()
    logger.info(f"  Training examples: {len(trainset)}")
    logger.info(f"  Validation examples: {len(valset)}")

    # Create baseline pipeline
    logger.info("Creating baseline HeritageRAGPipeline...")
    pipeline = HeritageRAGPipeline()

    # Evaluate the unoptimized baseline on a small sample to keep the
    # pre-optimization LLM cost low.
    logger.info("Evaluating baseline on validation set...")
    baseline_scores = []
    for ex in valset[:5]:  # Quick eval on 5 examples
        try:
            pred = pipeline(question=ex.question, language=ex.language)
            score = heritage_metric(ex, pred)
            baseline_scores.append(score)
            logger.info(f"  Q: {ex.question[:50]}... → Score: {score:.2f}")
        except Exception as e:
            # A failed example counts as 0 so averages stay comparable.
            logger.warning(f"  Error on example: {e}")
            baseline_scores.append(0.0)

    baseline_avg = sum(baseline_scores) / len(baseline_scores) if baseline_scores else 0
    logger.info(f"Baseline average score: {baseline_avg:.3f}")

    # Build the chosen optimizer.
    logger.info("-" * 60)
    if use_random_search:
        logger.info("Starting BootstrapFewShotWithRandomSearch optimization...")
        logger.info(f"  max_bootstrapped_demos: {max_bootstrapped_demos}")
        logger.info(f"  max_labeled_demos: {max_labeled_demos}")
        logger.info(f"  num_candidates: {num_candidates}")

        optimizer = BootstrapFewShotWithRandomSearch(
            metric=heritage_metric_strict,
            max_bootstrapped_demos=max_bootstrapped_demos,
            max_labeled_demos=max_labeled_demos,
            num_candidate_programs=num_candidates,
            num_threads=4,
        )
    else:
        logger.info("Starting BootstrapFewShot optimization...")
        logger.info(f"  max_bootstrapped_demos: {max_bootstrapped_demos}")
        logger.info(f"  max_labeled_demos: {max_labeled_demos}")

        optimizer = BootstrapFewShot(
            metric=heritage_metric_strict,
            max_bootstrapped_demos=max_bootstrapped_demos,
            max_labeled_demos=max_labeled_demos,
        )

    # Compile optimized pipeline
    logger.info("Compiling optimized pipeline (this may take a few minutes)...")
    start_time = datetime.now(timezone.utc)

    try:
        optimized_pipeline = optimizer.compile(
            pipeline,
            trainset=trainset,
        )
    except Exception as e:
        logger.error(f"Optimization failed: {e}")
        raise

    elapsed = (datetime.now(timezone.utc) - start_time).total_seconds()
    logger.info(f"Optimization completed in {elapsed:.1f} seconds")

    # Evaluate the optimized pipeline on the full validation set.
    logger.info("-" * 60)
    logger.info("Evaluating optimized pipeline on validation set...")
    optimized_scores = []
    for ex in valset:
        try:
            pred = optimized_pipeline(question=ex.question, language=ex.language)
            score = heritage_metric(ex, pred)
            optimized_scores.append(score)
            logger.info(f"  Q: {ex.question[:50]}... → Score: {score:.2f}, Intent: {pred.intent}")
        except Exception as e:
            logger.warning(f"  Error on example: {e}")
            optimized_scores.append(0.0)

    optimized_avg = sum(optimized_scores) / len(optimized_scores) if optimized_scores else 0
    logger.info(f"Optimized average score: {optimized_avg:.3f}")
    improvement = optimized_avg - baseline_avg
    if baseline_avg > 0:
        logger.info(f"Improvement: {improvement:.3f} ({improvement / baseline_avg * 100:.1f}%)")
    else:
        # BUGFIX: guard against ZeroDivisionError when the baseline scored 0
        # on every sampled example (e.g. all five quick-eval calls errored).
        logger.info(f"Improvement: {improvement:.3f} (baseline score was 0; percentage undefined)")

    # Save the optimized model, both timestamped and as "latest".
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    model_dir = Path(__file__).parent / "optimized_models"
    model_dir.mkdir(exist_ok=True)

    model_path = model_dir / f"heritage_rag_bootstrap_{timestamp}.json"
    latest_path = model_dir / "heritage_rag_bootstrap_latest.json"

    optimized_pipeline.save(str(model_path))
    optimized_pipeline.save(str(latest_path))
    logger.info(f"Saved optimized model to: {model_path}")

    # Persist a metadata record alongside the model for later comparison.
    metadata = {
        "timestamp": timestamp,
        "optimizer": "BootstrapFewShotWithRandomSearch" if use_random_search else "BootstrapFewShot",
        "max_bootstrapped_demos": max_bootstrapped_demos,
        "max_labeled_demos": max_labeled_demos,
        "training_examples": len(trainset),
        "validation_examples": len(valset),
        "baseline_score": baseline_avg,
        "optimized_score": optimized_avg,
        "improvement": improvement,
        "optimization_seconds": elapsed,
    }

    metadata_path = model_dir / f"metadata_bootstrap_{timestamp}.json"
    with open(metadata_path, "w") as f:
        json.dump(metadata, f, indent=2)
    logger.info(f"Saved metadata to: {metadata_path}")

    # Summary
    logger.info("=" * 60)
    logger.info("OPTIMIZATION SUMMARY")
    logger.info("=" * 60)
    logger.info(f"Baseline Score: {baseline_avg:.3f}")
    logger.info(f"Optimized Score: {optimized_avg:.3f}")
    logger.info(f"Improvement: {improvement:.3f}")
    logger.info(f"Time: {elapsed:.1f}s")
    logger.info(f"Model saved to: {latest_path}")
    logger.info("=" * 60)

    return optimized_pipeline, metadata
|
# =============================================================================
# QUICK EVALUATION
# =============================================================================

def evaluate_saved_model(model_path: str = None):
    """Load a saved optimized model and score it on the validation set.

    When model_path is None, falls back to the "latest" bootstrap model
    next to this script. Returns the average heritage_metric score.
    """
    if model_path is None:
        model_path = Path(__file__).parent / "optimized_models" / "heritage_rag_bootstrap_latest.json"

    logger.info(f"Loading model from: {model_path}")

    # Configure DSPy's global language model.
    lm = dspy.LM('openai/gpt-4o-mini', temperature=0.3, max_tokens=1000)
    dspy.configure(lm=lm)

    # Reconstruct the pipeline and restore the optimized state from disk.
    pipeline = HeritageRAGPipeline()
    pipeline.load(str(model_path))

    # Only the validation split is needed here.
    _, valset = get_extended_training_data()

    # Score every validation example; failures count as 0.
    scores = []
    for ex in valset:
        try:
            pred = pipeline(question=ex.question, language=ex.language)
            ex_score = heritage_metric(ex, pred)
            scores.append(ex_score)
            logger.info(f"Q: {ex.question[:50]}...")
            logger.info(f"  Intent: {pred.intent} (expected: {ex.expected_intent})")
            logger.info(f"  Score: {ex_score:.2f}")
        except Exception as e:
            logger.warning(f"Error: {e}")
            scores.append(0.0)

    avg_score = sum(scores) / len(scores) if scores else 0
    logger.info(f"\nAverage score: {avg_score:.3f}")
    return avg_score
|
# =============================================================================
# MAIN
# =============================================================================

if __name__ == "__main__":
    import argparse

    # CLI: by default runs a fresh optimization; with --evaluate it only
    # scores a previously saved model.
    parser = argparse.ArgumentParser(description="Heritage RAG Bootstrap Optimization")
    parser.add_argument("--random-search", action="store_true", help="Use random search variant")
    parser.add_argument("--evaluate", action="store_true", help="Just evaluate saved model")
    parser.add_argument("--model-path", type=str, help="Model path for evaluation")
    parser.add_argument("--demos", type=int, default=4, help="Max bootstrapped demos")
    parser.add_argument("--labeled", type=int, default=8, help="Max labeled demos")
    parser.add_argument("--candidates", type=int, default=10, help="Num candidates for random search")

    args = parser.parse_args()

    if args.evaluate:
        # --model-path may be omitted; evaluate_saved_model falls back to the
        # "latest" saved model in that case.
        evaluate_saved_model(args.model_path)
    else:
        run_bootstrap_optimization(
            use_random_search=args.random_search,
            max_bootstrapped_demos=args.demos,
            max_labeled_demos=args.labeled,
            num_candidates=args.candidates,
        )