""" Cost and Performance Tracking for Heritage RAG Pipeline Provides: 1. LLM cost tracking with per-model pricing 2. Performance timing breakdowns for all pipeline stages 3. Accumulated session costs 4. Benchmark utilities for Qdrant and Oxigraph Usage: from backend.rag.cost_tracker import CostTracker, get_tracker tracker = get_tracker() # Track LLM usage with tracker.track_llm_call("gpt-4o-mini"): result = dspy_module(question=query) # Track retrieval with tracker.track_retrieval("qdrant"): results = qdrant.search(query) # Get accumulated costs print(tracker.get_session_summary()) """ from __future__ import annotations import asyncio import logging import time from contextlib import contextmanager, asynccontextmanager from dataclasses import dataclass, field from datetime import datetime, timezone from enum import Enum from typing import Any, Generator, AsyncGenerator import dspy logger = logging.getLogger(__name__) # ============================================================================= # LLM Pricing (USD per 1M tokens as of Dec 2024) # ============================================================================= class LLMProvider(str, Enum): """LLM providers.""" OPENAI = "openai" ANTHROPIC = "anthropic" GOOGLE = "google" ZAI = "zai" @dataclass class ModelPricing: """Pricing for a specific model.""" input_per_1m: float # USD per 1M input tokens output_per_1m: float # USD per 1M output tokens cached_input_per_1m: float | None = None # Cached input discount # Pricing data - update as needed MODEL_PRICING: dict[str, ModelPricing] = { # OpenAI models "gpt-4o": ModelPricing(2.50, 10.00, 1.25), "gpt-4o-mini": ModelPricing(0.15, 0.60, 0.075), "gpt-4o-2024-11-20": ModelPricing(2.50, 10.00, 1.25), "gpt-4-turbo": ModelPricing(10.00, 30.00), "gpt-3.5-turbo": ModelPricing(0.50, 1.50), "o1": ModelPricing(15.00, 60.00, 7.50), "o1-mini": ModelPricing(3.00, 12.00, 1.50), "o1-preview": ModelPricing(15.00, 60.00), "text-embedding-3-small": ModelPricing(0.02, 0.0), "text-embedding-3-large": ModelPricing(0.13, 0.0), # Anthropic models "claude-3-5-sonnet-20241022": ModelPricing(3.00, 15.00, 0.30), "claude-3-5-haiku-20241022": ModelPricing(0.80, 4.00, 0.08), "claude-sonnet-4-20250514": ModelPricing(3.00, 15.00, 0.30), "claude-opus-4-5-20251101": ModelPricing(15.00, 75.00, 1.50), # Google models "gemini-2.0-flash": ModelPricing(0.10, 0.40), "gemini-1.5-pro": ModelPricing(1.25, 5.00), "gemini-1.5-flash": ModelPricing(0.075, 0.30), # Z.AI GLM models (free tier) "glm-4.5": ModelPricing(0.0, 0.0), "glm-4.5-air": ModelPricing(0.0, 0.0), "glm-4.5-flash": ModelPricing(0.0, 0.0), "glm-4.6": ModelPricing(0.0, 0.0), } def get_model_pricing(model: str) -> ModelPricing: """Get pricing for a model, with fallback for unknown models.""" # Normalize model name model_key = model.lower().replace("openai/", "").replace("anthropic/", "") if model_key in MODEL_PRICING: return MODEL_PRICING[model_key] # Try partial matching for key, pricing in MODEL_PRICING.items(): if key in model_key or model_key in key: return pricing # Default to gpt-4o-mini pricing as fallback logger.warning(f"Unknown model '{model}', using gpt-4o-mini pricing as fallback") return MODEL_PRICING["gpt-4o-mini"] # ============================================================================= # Timing and Cost Data Structures # ============================================================================= @dataclass class LLMUsage: """Token usage for an LLM call.""" model: str input_tokens: int = 0 output_tokens: int = 0 cached_tokens: int = 0 duration_ms: float = 0.0 timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) @property def total_tokens(self) -> int: return self.input_tokens + self.output_tokens @property def cost_usd(self) -> float: """Calculate cost in USD.""" pricing = get_model_pricing(self.model) # Regular input tokens (minus cached) regular_input = self.input_tokens - self.cached_tokens cost = (regular_input / 1_000_000) * pricing.input_per_1m cost += (self.output_tokens / 1_000_000) * pricing.output_per_1m if pricing.cached_input_per_1m and self.cached_tokens > 0: cost += (self.cached_tokens / 1_000_000) * pricing.cached_input_per_1m return cost @dataclass class RetrievalTiming: """Timing for a retrieval operation.""" source: str # qdrant, oxigraph, typedb, postgis duration_ms: float = 0.0 # Set by context manager in finally block result_count: int = 0 query: str = "" timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) @dataclass class PipelineTiming: """Timing breakdown for a full pipeline execution.""" total_ms: float = 0.0 query_routing_ms: float = 0.0 entity_extraction_ms: float = 0.0 retrieval_ms: float = 0.0 generation_ms: float = 0.0 visualization_ms: float = 0.0 # Sub-breakdowns retrieval_breakdown: dict[str, float] = field(default_factory=dict) llm_calls: list[LLMUsage] = field(default_factory=list) retrievals: list[RetrievalTiming] = field(default_factory=list) @dataclass class SessionStats: """Accumulated session statistics.""" session_id: str started_at: datetime query_count: int = 0 total_input_tokens: int = 0 total_output_tokens: int = 0 total_cost_usd: float = 0.0 total_duration_ms: float = 0.0 # Per-source timing qdrant_calls: int = 0 qdrant_total_ms: float = 0.0 oxigraph_calls: int = 0 oxigraph_total_ms: float = 0.0 typedb_calls: int = 0 typedb_total_ms: float = 0.0 # Per-model costs costs_by_model: dict[str, float] = field(default_factory=dict) tokens_by_model: dict[str, dict[str, int]] = field(default_factory=dict) # ============================================================================= # Cost Tracker # ============================================================================= class CostTracker: """ Tracks costs and performance for Heritage RAG pipeline. Thread-safe session-based tracking with: - LLM token usage and costs - Retrieval timing breakdowns - Accumulated session totals """ def __init__(self, session_id: str | None = None): import uuid self.session_id = session_id or str(uuid.uuid4())[:8] self.started_at = datetime.now(timezone.utc) self._stats = SessionStats( session_id=self.session_id, started_at=self.started_at, ) self._current_pipeline: PipelineTiming | None = None # Use try/except to safely check for running event loop (compatible with ThreadPoolExecutor) try: loop = asyncio.get_running_loop() self._lock = asyncio.Lock() except RuntimeError: # No event loop running in this thread self._lock = None def reset(self) -> None: """Reset session statistics.""" self._stats = SessionStats( session_id=self.session_id, started_at=datetime.now(timezone.utc), ) @contextmanager def track_pipeline(self) -> Generator[PipelineTiming, None, None]: """Context manager for tracking a full pipeline execution.""" timing = PipelineTiming() self._current_pipeline = timing start = time.perf_counter() try: yield timing finally: timing.total_ms = (time.perf_counter() - start) * 1000 self._stats.query_count += 1 self._stats.total_duration_ms += timing.total_ms self._current_pipeline = None @contextmanager def track_stage(self, stage: str) -> Generator[None, None, None]: """Track timing for a pipeline stage.""" start = time.perf_counter() try: yield finally: duration_ms = (time.perf_counter() - start) * 1000 if self._current_pipeline: setattr(self._current_pipeline, f"{stage}_ms", duration_ms) @contextmanager def track_llm_call(self, model: str) -> Generator[LLMUsage, None, None]: """ Context manager for tracking an LLM call. Integrates with DSPy's track_usage for automatic token counting. """ usage = LLMUsage(model=model) start = time.perf_counter() try: # Use DSPy's built-in usage tracking with dspy.track_usage() as dspy_tracker: yield usage # Extract token counts from DSPy tracker totals = dspy_tracker.get_total_tokens() for lm_name, tokens in totals.items(): if isinstance(tokens, dict): usage.input_tokens += tokens.get("prompt_tokens", 0) or tokens.get("input_tokens", 0) usage.output_tokens += tokens.get("completion_tokens", 0) or tokens.get("output_tokens", 0) usage.cached_tokens += tokens.get("cached_tokens", 0) except Exception as e: logger.warning(f"Token tracking failed: {e}") finally: usage.duration_ms = (time.perf_counter() - start) * 1000 # Update session stats self._stats.total_input_tokens += usage.input_tokens self._stats.total_output_tokens += usage.output_tokens self._stats.total_cost_usd += usage.cost_usd # Per-model tracking if model not in self._stats.costs_by_model: self._stats.costs_by_model[model] = 0.0 self._stats.tokens_by_model[model] = {"input": 0, "output": 0} self._stats.costs_by_model[model] += usage.cost_usd self._stats.tokens_by_model[model]["input"] += usage.input_tokens self._stats.tokens_by_model[model]["output"] += usage.output_tokens # Add to current pipeline if active if self._current_pipeline: self._current_pipeline.llm_calls.append(usage) @contextmanager def track_retrieval(self, source: str, query: str = "") -> Generator[RetrievalTiming, None, None]: """Track a retrieval operation.""" timing = RetrievalTiming(source=source, query=query) start = time.perf_counter() try: yield timing finally: timing.duration_ms = (time.perf_counter() - start) * 1000 # Update session stats if source == "qdrant": self._stats.qdrant_calls += 1 self._stats.qdrant_total_ms += timing.duration_ms elif source in ("oxigraph", "sparql"): self._stats.oxigraph_calls += 1 self._stats.oxigraph_total_ms += timing.duration_ms elif source == "typedb": self._stats.typedb_calls += 1 self._stats.typedb_total_ms += timing.duration_ms # Add to current pipeline if active if self._current_pipeline: self._current_pipeline.retrievals.append(timing) self._current_pipeline.retrieval_breakdown[source] = timing.duration_ms @asynccontextmanager async def track_retrieval_async(self, source: str, query: str = "") -> AsyncGenerator[RetrievalTiming, None]: """Async version of track_retrieval.""" timing = RetrievalTiming(source=source, query=query) start = time.perf_counter() try: yield timing finally: timing.duration_ms = (time.perf_counter() - start) * 1000 # Thread-safe update if source == "qdrant": self._stats.qdrant_calls += 1 self._stats.qdrant_total_ms += timing.duration_ms elif source in ("oxigraph", "sparql"): self._stats.oxigraph_calls += 1 self._stats.oxigraph_total_ms += timing.duration_ms elif source == "typedb": self._stats.typedb_calls += 1 self._stats.typedb_total_ms += timing.duration_ms if self._current_pipeline: self._current_pipeline.retrievals.append(timing) self._current_pipeline.retrieval_breakdown[source] = timing.duration_ms def get_session_summary(self) -> dict[str, Any]: """Get comprehensive session summary.""" elapsed_sec = (datetime.now(timezone.utc) - self._stats.started_at).total_seconds() return { "session_id": self.session_id, "elapsed_seconds": round(elapsed_sec, 1), "queries": self._stats.query_count, "costs": { "total_usd": round(self._stats.total_cost_usd, 6), "by_model": { model: round(cost, 6) for model, cost in self._stats.costs_by_model.items() }, }, "tokens": { "total_input": self._stats.total_input_tokens, "total_output": self._stats.total_output_tokens, "by_model": self._stats.tokens_by_model, }, "timing": { "total_ms": round(self._stats.total_duration_ms, 1), "avg_query_ms": round( self._stats.total_duration_ms / self._stats.query_count, 1 ) if self._stats.query_count > 0 else 0, "retrievers": { "qdrant": { "calls": self._stats.qdrant_calls, "total_ms": round(self._stats.qdrant_total_ms, 1), "avg_ms": round( self._stats.qdrant_total_ms / self._stats.qdrant_calls, 1 ) if self._stats.qdrant_calls > 0 else 0, }, "oxigraph": { "calls": self._stats.oxigraph_calls, "total_ms": round(self._stats.oxigraph_total_ms, 1), "avg_ms": round( self._stats.oxigraph_total_ms / self._stats.oxigraph_calls, 1 ) if self._stats.oxigraph_calls > 0 else 0, }, "typedb": { "calls": self._stats.typedb_calls, "total_ms": round(self._stats.typedb_total_ms, 1), "avg_ms": round( self._stats.typedb_total_ms / self._stats.typedb_calls, 1 ) if self._stats.typedb_calls > 0 else 0, }, }, }, } def format_cost_display(self) -> str: """Format cost for display in CLI/UI.""" cost = self._stats.total_cost_usd if cost == 0: return "$0.00" elif cost < 0.01: return f"${cost:.4f}" elif cost < 1.00: return f"${cost:.3f}" else: return f"${cost:.2f}" def get_last_query_summary(self) -> dict[str, Any] | None: """Get summary of the last pipeline execution.""" if not self._current_pipeline: return None p = self._current_pipeline llm_cost = sum(u.cost_usd for u in p.llm_calls) return { "total_ms": round(p.total_ms, 1), "breakdown": { "query_routing_ms": round(p.query_routing_ms, 1), "entity_extraction_ms": round(p.entity_extraction_ms, 1), "retrieval_ms": round(p.retrieval_ms, 1), "generation_ms": round(p.generation_ms, 1), }, "retrievals": { source: round(ms, 1) for source, ms in p.retrieval_breakdown.items() }, "llm": { "calls": len(p.llm_calls), "total_tokens": sum(u.total_tokens for u in p.llm_calls), "cost_usd": round(llm_cost, 6), }, } # ============================================================================= # Global Tracker Instance # ============================================================================= _global_tracker: CostTracker | None = None def get_tracker() -> CostTracker: """Get or create the global cost tracker.""" global _global_tracker if _global_tracker is None: _global_tracker = CostTracker() return _global_tracker def reset_tracker() -> CostTracker: """Reset and return a new global tracker.""" global _global_tracker _global_tracker = CostTracker() return _global_tracker # ============================================================================= # Benchmark Utilities # ============================================================================= @dataclass class BenchmarkResult: """Result of a benchmark run.""" name: str iterations: int total_ms: float avg_ms: float min_ms: float max_ms: float p50_ms: float p95_ms: float p99_ms: float errors: int = 0 def run_benchmark( name: str, func: Any, iterations: int = 10, warmup: int = 2, ) -> BenchmarkResult: """ Run a synchronous benchmark. Args: name: Benchmark name func: Function to benchmark (no args) iterations: Number of iterations warmup: Warmup iterations (not counted) """ # Warmup for _ in range(warmup): try: func() except Exception: pass # Benchmark times: list[float] = [] errors = 0 for _ in range(iterations): start = time.perf_counter() try: func() except Exception: errors += 1 times.append((time.perf_counter() - start) * 1000) times.sort() return BenchmarkResult( name=name, iterations=iterations, total_ms=sum(times), avg_ms=sum(times) / len(times), min_ms=times[0], max_ms=times[-1], p50_ms=times[len(times) // 2], p95_ms=times[int(len(times) * 0.95)], p99_ms=times[int(len(times) * 0.99)], errors=errors, ) async def run_benchmark_async( name: str, func: Any, iterations: int = 10, warmup: int = 2, ) -> BenchmarkResult: """ Run an async benchmark. Args: name: Benchmark name func: Async function to benchmark (no args) iterations: Number of iterations warmup: Warmup iterations (not counted) """ # Warmup for _ in range(warmup): try: await func() except Exception: pass # Benchmark times: list[float] = [] errors = 0 for _ in range(iterations): start = time.perf_counter() try: await func() except Exception: errors += 1 times.append((time.perf_counter() - start) * 1000) times.sort() return BenchmarkResult( name=name, iterations=iterations, total_ms=sum(times), avg_ms=sum(times) / len(times), min_ms=times[0], max_ms=times[-1], p50_ms=times[len(times) // 2], p95_ms=times[int(len(times) * 0.95)], p99_ms=times[int(len(times) * 0.99)], errors=errors, ) def format_benchmark_report(results: list[BenchmarkResult]) -> str: """Format benchmark results as a table.""" lines = [ "=" * 80, "BENCHMARK RESULTS", "=" * 80, f"{'Name':<30} {'Avg (ms)':<12} {'P50 (ms)':<12} {'P95 (ms)':<12} {'Errors':<8}", "-" * 80, ] for r in results: lines.append( f"{r.name:<30} {r.avg_ms:<12.1f} {r.p50_ms:<12.1f} {r.p95_ms:<12.1f} {r.errors:<8}" ) lines.append("=" * 80) return "\n".join(lines)