# glam/backend/rag/cost_tracker.py
# 2026-01-14 09:05:54 +01:00
#
# 609 lines / 20 KiB / Python

"""
Cost and Performance Tracking for Heritage RAG Pipeline
Provides:
1. LLM cost tracking with per-model pricing
2. Performance timing breakdowns for all pipeline stages
3. Accumulated session costs
4. Benchmark utilities for Qdrant and Oxigraph
Usage:
from backend.rag.cost_tracker import CostTracker, get_tracker
tracker = get_tracker()
# Track LLM usage
with tracker.track_llm_call("gpt-4o-mini"):
result = dspy_module(question=query)
# Track retrieval
with tracker.track_retrieval("qdrant"):
results = qdrant.search(query)
# Get accumulated costs
print(tracker.get_session_summary())
"""
from __future__ import annotations
import asyncio
import logging
import time
from contextlib import contextmanager, asynccontextmanager
from dataclasses import dataclass, field
from datetime import datetime, timezone
from enum import Enum
from typing import Any, Generator, AsyncGenerator
import dspy
logger = logging.getLogger(__name__)
# =============================================================================
# LLM Pricing (USD per 1M tokens as of Dec 2024)
# =============================================================================
class LLMProvider(str, Enum):
    """LLM providers recognized by the cost tracker.

    Subclasses ``str`` so members compare equal to their plain string
    values (e.g. ``LLMProvider.OPENAI == "openai"``).
    """
    OPENAI = "openai"
    ANTHROPIC = "anthropic"
    GOOGLE = "google"
    ZAI = "zai"
@dataclass
class ModelPricing:
    """Pricing for a specific model, in USD per 1M tokens."""
    input_per_1m: float  # USD per 1M input (prompt) tokens
    output_per_1m: float  # USD per 1M output (completion) tokens
    cached_input_per_1m: float | None = None  # Discounted rate for cached input tokens; None = provider has no cache discount


# Pricing data - update as needed (USD per 1M tokens as of Dec 2024)
MODEL_PRICING: dict[str, ModelPricing] = {
    # OpenAI models
    "gpt-4o": ModelPricing(2.50, 10.00, 1.25),
    "gpt-4o-mini": ModelPricing(0.15, 0.60, 0.075),
    "gpt-4o-2024-11-20": ModelPricing(2.50, 10.00, 1.25),
    "gpt-4-turbo": ModelPricing(10.00, 30.00),
    "gpt-3.5-turbo": ModelPricing(0.50, 1.50),
    "o1": ModelPricing(15.00, 60.00, 7.50),
    "o1-mini": ModelPricing(3.00, 12.00, 1.50),
    "o1-preview": ModelPricing(15.00, 60.00),
    "text-embedding-3-small": ModelPricing(0.02, 0.0),
    "text-embedding-3-large": ModelPricing(0.13, 0.0),
    # Anthropic models
    "claude-3-5-sonnet-20241022": ModelPricing(3.00, 15.00, 0.30),
    "claude-3-5-haiku-20241022": ModelPricing(0.80, 4.00, 0.08),
    "claude-sonnet-4-20250514": ModelPricing(3.00, 15.00, 0.30),
    "claude-opus-4-5-20251101": ModelPricing(15.00, 75.00, 1.50),
    # Google models
    "gemini-2.0-flash": ModelPricing(0.10, 0.40),
    "gemini-1.5-pro": ModelPricing(1.25, 5.00),
    "gemini-1.5-flash": ModelPricing(0.075, 0.30),
    # Z.AI GLM models (free tier)
    "glm-4.5": ModelPricing(0.0, 0.0),
    "glm-4.5-air": ModelPricing(0.0, 0.0),
    "glm-4.5-flash": ModelPricing(0.0, 0.0),
    "glm-4.6": ModelPricing(0.0, 0.0),
}


def get_model_pricing(model: str) -> ModelPricing:
    """Return pricing for *model*, with a fallback for unknown models.

    Any ``provider/`` prefix (openai/, anthropic/, google/, ...) is
    stripped before lookup; previously only the openai/anthropic prefixes
    were handled. When there is no exact entry, the LONGEST partially
    matching key wins, so e.g. "gpt-4o-mini-2024-07-18" resolves to
    "gpt-4o-mini" rather than to whichever shorter key ("gpt-4o") happened
    to come first in dict order. Unknown models fall back to gpt-4o-mini
    pricing with a warning.
    """
    # Normalize: lower-case and drop any "provider/" prefix generically.
    model_key = model.lower().rsplit("/", 1)[-1]
    if model_key in MODEL_PRICING:
        return MODEL_PRICING[model_key]
    # Partial matching, longest key first so the most specific entry wins.
    candidates = [
        (len(key), pricing)
        for key, pricing in MODEL_PRICING.items()
        if key in model_key or model_key in key
    ]
    if candidates:
        return max(candidates, key=lambda kp: kp[0])[1]
    # Default to gpt-4o-mini pricing as fallback
    logger.warning(f"Unknown model '{model}', using gpt-4o-mini pricing as fallback")
    return MODEL_PRICING["gpt-4o-mini"]
# =============================================================================
# Timing and Cost Data Structures
# =============================================================================
@dataclass
class LLMUsage:
    """Token usage and timing for a single LLM call."""
    model: str  # Model identifier, resolved via get_model_pricing()
    input_tokens: int = 0
    output_tokens: int = 0
    cached_tokens: int = 0  # Subset of input_tokens served from the provider prompt cache
    duration_ms: float = 0.0
    timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))

    @property
    def total_tokens(self) -> int:
        """Input plus output tokens."""
        return self.input_tokens + self.output_tokens

    @property
    def cost_usd(self) -> float:
        """Calculate cost in USD.

        Cached input tokens are billed at the model's cached rate when one
        is defined, otherwise at the regular input rate. (Previously they
        were subtracted from the input count but never billed at all when
        no cached rate existed, and a legitimate 0.0 cached rate was
        treated as "no cached pricing" by the truthiness check.)
        """
        pricing = get_model_pricing(self.model)
        # Cached tokens are a subset of input tokens; never let the regular
        # portion go negative if the counts are inconsistent.
        cached = min(self.cached_tokens, self.input_tokens)
        regular_input = self.input_tokens - cached
        cost = (regular_input / 1_000_000) * pricing.input_per_1m
        cost += (self.output_tokens / 1_000_000) * pricing.output_per_1m
        if cached > 0:
            # `is not None` rather than truthiness: 0.0 means "free", not
            # "undiscounted".
            cached_rate = (
                pricing.cached_input_per_1m
                if pricing.cached_input_per_1m is not None
                else pricing.input_per_1m
            )
            cost += (cached / 1_000_000) * cached_rate
        return cost
@dataclass
class RetrievalTiming:
    """Timing for a single retrieval operation."""
    source: str  # Backend that served the query: qdrant, oxigraph, typedb, postgis
    duration_ms: float = 0.0  # Set by the tracking context manager in its finally block
    result_count: int = 0  # Number of hits; not set by the tracker, left to the caller
    query: str = ""  # Optional query text, kept for diagnostics
    timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
@dataclass
class PipelineTiming:
    """Timing breakdown for a full pipeline execution."""
    total_ms: float = 0.0  # Wall-clock time of the whole pipeline
    query_routing_ms: float = 0.0
    entity_extraction_ms: float = 0.0
    retrieval_ms: float = 0.0
    generation_ms: float = 0.0
    visualization_ms: float = 0.0
    # Sub-breakdowns
    retrieval_breakdown: dict[str, float] = field(default_factory=dict)  # per-source retrieval time (ms)
    llm_calls: list[LLMUsage] = field(default_factory=list)  # every LLM call made during this run
    retrievals: list[RetrievalTiming] = field(default_factory=list)  # every retrieval made during this run
@dataclass
class SessionStats:
    """Accumulated statistics for one tracking session."""
    session_id: str  # Short id shared with the owning CostTracker
    started_at: datetime
    query_count: int = 0  # Completed track_pipeline() executions
    total_input_tokens: int = 0
    total_output_tokens: int = 0
    total_cost_usd: float = 0.0
    total_duration_ms: float = 0.0
    # Per-source timing
    qdrant_calls: int = 0
    qdrant_total_ms: float = 0.0
    oxigraph_calls: int = 0
    oxigraph_total_ms: float = 0.0
    typedb_calls: int = 0
    typedb_total_ms: float = 0.0
    # Per-model costs
    costs_by_model: dict[str, float] = field(default_factory=dict)  # model name -> USD
    tokens_by_model: dict[str, dict[str, int]] = field(default_factory=dict)  # model name -> {"input": n, "output": n}
# =============================================================================
# Cost Tracker
# =============================================================================
class CostTracker:
    """
    Tracks costs and performance for the Heritage RAG pipeline.

    Session-based tracking with:
    - LLM token usage and costs
    - Retrieval timing breakdowns
    - Accumulated session totals

    NOTE(review): state updates are plain attribute mutations with no
    locking. That is fine on a single asyncio event loop, but this class
    is NOT thread-safe — confirm before sharing one instance across
    threads. (A previous asyncio.Lock attribute was never acquired
    anywhere and has been removed as dead code.)
    """

    def __init__(self, session_id: str | None = None):
        """Create a tracker.

        Args:
            session_id: Optional externally supplied id. When omitted, a
                short random id (first 8 chars of a UUID4) is generated.
        """
        import uuid
        self.session_id = session_id or str(uuid.uuid4())[:8]
        self.started_at = datetime.now(timezone.utc)
        self._stats = SessionStats(
            session_id=self.session_id,
            started_at=self.started_at,
        )
        # Pipeline currently being tracked (inside track_pipeline) and the
        # most recently finished one, kept for get_last_query_summary().
        self._current_pipeline: PipelineTiming | None = None
        self._last_pipeline: PipelineTiming | None = None

    def reset(self) -> None:
        """Reset session statistics (the session id is preserved)."""
        self._stats = SessionStats(
            session_id=self.session_id,
            started_at=datetime.now(timezone.utc),
        )
        self._last_pipeline = None

    @contextmanager
    def track_pipeline(self) -> Generator[PipelineTiming, None, None]:
        """Context manager for tracking a full pipeline execution.

        Yields a PipelineTiming that nested track_stage / track_llm_call /
        track_retrieval calls attach their results to.
        """
        timing = PipelineTiming()
        self._current_pipeline = timing
        start = time.perf_counter()
        try:
            yield timing
        finally:
            timing.total_ms = (time.perf_counter() - start) * 1000
            self._stats.query_count += 1
            self._stats.total_duration_ms += timing.total_ms
            # Retain the finished pipeline so get_last_query_summary()
            # still works after the context exits (previously it always
            # returned None at that point).
            self._last_pipeline = timing
            self._current_pipeline = None

    @contextmanager
    def track_stage(self, stage: str) -> Generator[None, None, None]:
        """Track timing for a named pipeline stage.

        `stage` must match a PipelineTiming field prefix: e.g. "retrieval"
        sets `retrieval_ms` on the active pipeline (no-op when no pipeline
        is being tracked).
        """
        start = time.perf_counter()
        try:
            yield
        finally:
            duration_ms = (time.perf_counter() - start) * 1000
            if self._current_pipeline:
                setattr(self._current_pipeline, f"{stage}_ms", duration_ms)

    @contextmanager
    def track_llm_call(self, model: str) -> Generator[LLMUsage, None, None]:
        """
        Context manager for tracking an LLM call.

        Integrates with DSPy's track_usage for automatic token counting.
        Exceptions raised by the tracked caller code now propagate
        (previously the broad except wrapped around the yield silently
        swallowed them, suppressing errors from the caller's own body);
        timing and session stats are still recorded on the way out.
        """
        usage = LLMUsage(model=model)
        start = time.perf_counter()
        try:
            # Use DSPy's built-in usage tracking around the caller's code.
            with dspy.track_usage() as dspy_tracker:
                yield usage
            # Only reached when the tracked code succeeded; token
            # extraction is best-effort and must not break the pipeline.
            try:
                totals = dspy_tracker.get_total_tokens()
                for _lm_name, tokens in totals.items():
                    if isinstance(tokens, dict):
                        usage.input_tokens += tokens.get("prompt_tokens", 0) or tokens.get("input_tokens", 0)
                        usage.output_tokens += tokens.get("completion_tokens", 0) or tokens.get("output_tokens", 0)
                        usage.cached_tokens += tokens.get("cached_tokens", 0)
            except Exception as e:
                logger.warning(f"Token tracking failed: {e}")
        finally:
            usage.duration_ms = (time.perf_counter() - start) * 1000
            # Update session stats
            self._stats.total_input_tokens += usage.input_tokens
            self._stats.total_output_tokens += usage.output_tokens
            self._stats.total_cost_usd += usage.cost_usd
            # Per-model tracking
            if model not in self._stats.costs_by_model:
                self._stats.costs_by_model[model] = 0.0
                self._stats.tokens_by_model[model] = {"input": 0, "output": 0}
            self._stats.costs_by_model[model] += usage.cost_usd
            self._stats.tokens_by_model[model]["input"] += usage.input_tokens
            self._stats.tokens_by_model[model]["output"] += usage.output_tokens
            # Add to current pipeline if active
            if self._current_pipeline:
                self._current_pipeline.llm_calls.append(usage)

    def _record_retrieval(self, source: str, timing: RetrievalTiming) -> None:
        """Fold a finished retrieval into session stats and the active pipeline.

        Shared by the sync and async tracking context managers, which
        previously duplicated this logic.
        """
        if source == "qdrant":
            self._stats.qdrant_calls += 1
            self._stats.qdrant_total_ms += timing.duration_ms
        elif source in ("oxigraph", "sparql"):
            self._stats.oxigraph_calls += 1
            self._stats.oxigraph_total_ms += timing.duration_ms
        elif source == "typedb":
            self._stats.typedb_calls += 1
            self._stats.typedb_total_ms += timing.duration_ms
        if self._current_pipeline:
            self._current_pipeline.retrievals.append(timing)
            # Accumulate per-source time instead of overwriting, so several
            # retrievals against the same source within one pipeline are
            # all reflected in the breakdown.
            breakdown = self._current_pipeline.retrieval_breakdown
            breakdown[source] = breakdown.get(source, 0.0) + timing.duration_ms

    @contextmanager
    def track_retrieval(self, source: str, query: str = "") -> Generator[RetrievalTiming, None, None]:
        """Track a retrieval operation against one backend."""
        timing = RetrievalTiming(source=source, query=query)
        start = time.perf_counter()
        try:
            yield timing
        finally:
            timing.duration_ms = (time.perf_counter() - start) * 1000
            self._record_retrieval(source, timing)

    @asynccontextmanager
    async def track_retrieval_async(self, source: str, query: str = "") -> AsyncGenerator[RetrievalTiming, None]:
        """Async version of track_retrieval."""
        timing = RetrievalTiming(source=source, query=query)
        start = time.perf_counter()
        try:
            yield timing
        finally:
            timing.duration_ms = (time.perf_counter() - start) * 1000
            self._record_retrieval(source, timing)

    @staticmethod
    def _retriever_summary(calls: int, total_ms: float) -> dict[str, Any]:
        """Per-retriever calls/total/avg block (avg is 0 when never called)."""
        return {
            "calls": calls,
            "total_ms": round(total_ms, 1),
            "avg_ms": round(total_ms / calls, 1) if calls > 0 else 0,
        }

    def get_session_summary(self) -> dict[str, Any]:
        """Get a comprehensive session summary (costs, tokens, timing)."""
        elapsed_sec = (datetime.now(timezone.utc) - self._stats.started_at).total_seconds()
        return {
            "session_id": self.session_id,
            "elapsed_seconds": round(elapsed_sec, 1),
            "queries": self._stats.query_count,
            "costs": {
                "total_usd": round(self._stats.total_cost_usd, 6),
                "by_model": {
                    model: round(cost, 6)
                    for model, cost in self._stats.costs_by_model.items()
                },
            },
            "tokens": {
                "total_input": self._stats.total_input_tokens,
                "total_output": self._stats.total_output_tokens,
                "by_model": self._stats.tokens_by_model,
            },
            "timing": {
                "total_ms": round(self._stats.total_duration_ms, 1),
                "avg_query_ms": round(
                    self._stats.total_duration_ms / self._stats.query_count, 1
                ) if self._stats.query_count > 0 else 0,
                "retrievers": {
                    "qdrant": self._retriever_summary(
                        self._stats.qdrant_calls, self._stats.qdrant_total_ms
                    ),
                    "oxigraph": self._retriever_summary(
                        self._stats.oxigraph_calls, self._stats.oxigraph_total_ms
                    ),
                    "typedb": self._retriever_summary(
                        self._stats.typedb_calls, self._stats.typedb_total_ms
                    ),
                },
            },
        }

    def format_cost_display(self) -> str:
        """Format the accumulated cost for display in CLI/UI.

        Precision scales with magnitude so sub-cent costs stay visible.
        """
        cost = self._stats.total_cost_usd
        if cost == 0:
            return "$0.00"
        elif cost < 0.01:
            return f"${cost:.4f}"
        elif cost < 1.00:
            return f"${cost:.3f}"
        else:
            return f"${cost:.2f}"

    def get_last_query_summary(self) -> dict[str, Any] | None:
        """Get a summary of the active or most recently finished pipeline.

        Returns None only when no pipeline has ever run. (Previously this
        always returned None once track_pipeline() exited, because only
        the in-flight pipeline reference was consulted and that is cleared
        on exit.)
        """
        p = self._current_pipeline or self._last_pipeline
        if p is None:
            return None
        llm_cost = sum(u.cost_usd for u in p.llm_calls)
        return {
            "total_ms": round(p.total_ms, 1),
            "breakdown": {
                "query_routing_ms": round(p.query_routing_ms, 1),
                "entity_extraction_ms": round(p.entity_extraction_ms, 1),
                "retrieval_ms": round(p.retrieval_ms, 1),
                "generation_ms": round(p.generation_ms, 1),
            },
            "retrievals": {
                source: round(ms, 1)
                for source, ms in p.retrieval_breakdown.items()
            },
            "llm": {
                "calls": len(p.llm_calls),
                "total_tokens": sum(u.total_tokens for u in p.llm_calls),
                "cost_usd": round(llm_cost, 6),
            },
        }
# =============================================================================
# Global Tracker Instance
# =============================================================================
# Lazily created process-wide tracker shared by get_tracker()/reset_tracker().
_global_tracker: CostTracker | None = None


def get_tracker() -> CostTracker:
    """Return the process-wide cost tracker, creating it on first use."""
    global _global_tracker
    tracker = _global_tracker
    if tracker is None:
        tracker = CostTracker()
        _global_tracker = tracker
    return tracker


def reset_tracker() -> CostTracker:
    """Discard the global tracker, install a fresh one, and return it."""
    global _global_tracker
    _global_tracker = CostTracker()
    return _global_tracker
# =============================================================================
# Benchmark Utilities
# =============================================================================
@dataclass
class BenchmarkResult:
    """Result of a benchmark run."""
    name: str
    iterations: int  # Timed iterations (warmup excluded)
    total_ms: float
    avg_ms: float
    min_ms: float
    max_ms: float
    p50_ms: float  # Nearest-rank percentiles over the sorted timings
    p95_ms: float
    p99_ms: float
    errors: int = 0  # Iterations that raised; their durations still count


def _percentile(sorted_times: list[float], fraction: float) -> float:
    """Nearest-rank percentile of an ascending list, index clamped in range."""
    index = min(int(len(sorted_times) * fraction), len(sorted_times) - 1)
    return sorted_times[index]


def run_benchmark(
    name: str,
    func: Any,
    iterations: int = 10,
    warmup: int = 2,
) -> BenchmarkResult:
    """
    Run a synchronous benchmark.

    Args:
        name: Benchmark name
        func: Function to benchmark (no args)
        iterations: Number of timed iterations (must be >= 1)
        warmup: Warmup iterations (not counted)

    Raises:
        ValueError: If iterations < 1 (previously this crashed later with
            ZeroDivisionError / IndexError while computing statistics).
    """
    if iterations < 1:
        raise ValueError("iterations must be >= 1")
    # Warmup: exceptions ignored so a flaky first call cannot abort the run.
    for _ in range(warmup):
        try:
            func()
        except Exception:
            pass
    # Timed iterations; failures are counted but their durations still count.
    times: list[float] = []
    errors = 0
    for _ in range(iterations):
        start = time.perf_counter()
        try:
            func()
        except Exception:
            errors += 1
        times.append((time.perf_counter() - start) * 1000)
    times.sort()
    return BenchmarkResult(
        name=name,
        iterations=iterations,
        total_ms=sum(times),
        avg_ms=sum(times) / len(times),
        min_ms=times[0],
        max_ms=times[-1],
        p50_ms=_percentile(times, 0.50),
        p95_ms=_percentile(times, 0.95),
        p99_ms=_percentile(times, 0.99),
        errors=errors,
    )
async def run_benchmark_async(
    name: str,
    func: Any,
    iterations: int = 10,
    warmup: int = 2,
) -> BenchmarkResult:
    """
    Run an async benchmark.

    Args:
        name: Benchmark name
        func: Async function to benchmark (no args)
        iterations: Number of timed iterations (must be >= 1)
        warmup: Warmup iterations (not counted)

    Raises:
        ValueError: If iterations < 1 (previously this crashed later with
            ZeroDivisionError / IndexError while computing statistics).
    """
    if iterations < 1:
        raise ValueError("iterations must be >= 1")
    # Warmup: exceptions ignored so a flaky first call cannot abort the run.
    for _ in range(warmup):
        try:
            await func()
        except Exception:
            pass
    # Timed iterations; failures are counted but their durations still count.
    times: list[float] = []
    errors = 0
    for _ in range(iterations):
        start = time.perf_counter()
        try:
            await func()
        except Exception:
            errors += 1
        times.append((time.perf_counter() - start) * 1000)
    times.sort()
    last = len(times) - 1  # clamp nearest-rank indices into range
    return BenchmarkResult(
        name=name,
        iterations=iterations,
        total_ms=sum(times),
        avg_ms=sum(times) / len(times),
        min_ms=times[0],
        max_ms=times[-1],
        p50_ms=times[len(times) // 2],
        p95_ms=times[min(int(len(times) * 0.95), last)],
        p99_ms=times[min(int(len(times) * 0.99), last)],
        errors=errors,
    )
def format_benchmark_report(results: list[BenchmarkResult]) -> str:
    """Render benchmark results as a fixed-width text table.

    One row per result under a header band; returns the joined string.
    """
    rule = "=" * 80
    header = f"{'Name':<30} {'Avg (ms)':<12} {'P50 (ms)':<12} {'P95 (ms)':<12} {'Errors':<8}"
    rows = [
        f"{r.name:<30} {r.avg_ms:<12.1f} {r.p50_ms:<12.1f} {r.p95_ms:<12.1f} {r.errors:<8}"
        for r in results
    ]
    return "\n".join([rule, "BENCHMARK RESULTS", rule, header, "-" * 80, *rows, rule])