"""
Cost and Performance Tracking for Heritage RAG Pipeline

Provides:

1. LLM cost tracking with per-model pricing
2. Performance timing breakdowns for all pipeline stages
3. Accumulated session costs
4. Benchmark utilities for Qdrant and Oxigraph

Usage:

    from backend.rag.cost_tracker import CostTracker, get_tracker

    tracker = get_tracker()

    # Track LLM usage
    with tracker.track_llm_call("gpt-4o-mini"):
        result = dspy_module(question=query)

    # Track retrieval
    with tracker.track_retrieval("qdrant"):
        results = qdrant.search(query)

    # Get accumulated costs
    print(tracker.get_session_summary())
"""
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
import time
|
|
from contextlib import contextmanager, asynccontextmanager
|
|
from dataclasses import dataclass, field
|
|
from datetime import datetime, timezone
|
|
from enum import Enum
|
|
from typing import Any, Generator, AsyncGenerator
|
|
|
|
import dspy
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
# =============================================================================
|
|
# LLM Pricing (USD per 1M tokens as of Dec 2024)
|
|
# =============================================================================
|
|
|
|
class LLMProvider(str, Enum):
    """LLM providers.

    String-valued so members serialize/compare directly as plain strings.
    """

    OPENAI = "openai"
    ANTHROPIC = "anthropic"
    GOOGLE = "google"
    ZAI = "zai"  # Z.AI GLM models (priced as free tier in MODEL_PRICING)
|
|
|
|
|
|
@dataclass
class ModelPricing:
    """Pricing for a specific model, in USD per one million tokens."""

    input_per_1m: float  # USD per 1M input tokens
    output_per_1m: float  # USD per 1M output tokens
    cached_input_per_1m: float | None = None  # Discounted cached-input rate; None = no cache tier


# Pricing data - update as needed
MODEL_PRICING: dict[str, ModelPricing] = {
    # OpenAI models
    "gpt-4o": ModelPricing(2.50, 10.00, 1.25),
    "gpt-4o-mini": ModelPricing(0.15, 0.60, 0.075),
    "gpt-4o-2024-11-20": ModelPricing(2.50, 10.00, 1.25),
    "gpt-4-turbo": ModelPricing(10.00, 30.00),
    "gpt-3.5-turbo": ModelPricing(0.50, 1.50),
    "o1": ModelPricing(15.00, 60.00, 7.50),
    "o1-mini": ModelPricing(3.00, 12.00, 1.50),
    "o1-preview": ModelPricing(15.00, 60.00),
    "text-embedding-3-small": ModelPricing(0.02, 0.0),
    "text-embedding-3-large": ModelPricing(0.13, 0.0),

    # Anthropic models
    "claude-3-5-sonnet-20241022": ModelPricing(3.00, 15.00, 0.30),
    "claude-3-5-haiku-20241022": ModelPricing(0.80, 4.00, 0.08),
    "claude-sonnet-4-20250514": ModelPricing(3.00, 15.00, 0.30),
    "claude-opus-4-5-20251101": ModelPricing(15.00, 75.00, 1.50),

    # Google models
    "gemini-2.0-flash": ModelPricing(0.10, 0.40),
    "gemini-1.5-pro": ModelPricing(1.25, 5.00),
    "gemini-1.5-flash": ModelPricing(0.075, 0.30),

    # Z.AI GLM models (free tier)
    "glm-4.5": ModelPricing(0.0, 0.0),
    "glm-4.5-air": ModelPricing(0.0, 0.0),
    "glm-4.5-flash": ModelPricing(0.0, 0.0),
    "glm-4.6": ModelPricing(0.0, 0.0),
}


def get_model_pricing(model: str) -> ModelPricing:
    """Get pricing for a model, with fallback for unknown models.

    Resolution order:
    1. Exact match after lowercasing and stripping any ``provider/`` prefix.
    2. Longest-key substring match, so a dated variant like
       ``gpt-4o-mini-2024-07-18`` resolves to ``gpt-4o-mini`` rather than
       the shorter (and ~16x more expensive) ``gpt-4o``. The previous
       first-match-in-dict-order scan picked whichever key happened to come
       first, which mispriced such variants.
    3. Fallback to gpt-4o-mini pricing with a warning.
    """
    # Normalize: lowercase and strip any provider prefix
    # ("openai/", "anthropic/", "google/", "zai/", ...).
    model_key = model.lower().rsplit("/", 1)[-1]

    if model_key in MODEL_PRICING:
        return MODEL_PRICING[model_key]

    # Partial matching: prefer the LONGEST matching key so more specific
    # entries win over their prefixes (e.g. "gpt-4o-mini" over "gpt-4o").
    best_key = max(
        (key for key in MODEL_PRICING if key in model_key or model_key in key),
        key=len,
        default=None,
    )
    if best_key is not None:
        return MODEL_PRICING[best_key]

    # Default to gpt-4o-mini pricing as fallback
    logger.warning("Unknown model '%s', using gpt-4o-mini pricing as fallback", model)
    return MODEL_PRICING["gpt-4o-mini"]
|
|
|
|
|
|
# =============================================================================
|
|
# Timing and Cost Data Structures
|
|
# =============================================================================
|
|
|
|
@dataclass
class LLMUsage:
    """Token usage and timing for a single LLM call."""

    model: str  # Model name; resolved to pricing via get_model_pricing()
    input_tokens: int = 0
    output_tokens: int = 0
    cached_tokens: int = 0  # Subset of input_tokens served from the provider cache
    duration_ms: float = 0.0  # Wall-clock duration, set by the tracking context manager
    timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))

    @property
    def total_tokens(self) -> int:
        """Total tokens billed (input + output; cached tokens are part of input)."""
        return self.input_tokens + self.output_tokens

    @property
    def cost_usd(self) -> float:
        """Calculate cost in USD from the model's per-1M-token pricing.

        Cached input tokens are billed at the model's cached rate when one
        is defined; otherwise they fall back to the regular input rate.
        (Previously, models without a cached rate billed cached tokens at
        $0, silently under-reporting cost.)
        """
        pricing = get_model_pricing(self.model)

        # Clamp so a provider reporting cached > input cannot produce a
        # negative regular-token count.
        cached = min(self.cached_tokens, self.input_tokens)
        regular_input = self.input_tokens - cached

        cost = (regular_input / 1_000_000) * pricing.input_per_1m
        cost += (self.output_tokens / 1_000_000) * pricing.output_per_1m

        if cached > 0:
            cached_rate = (
                pricing.cached_input_per_1m
                if pricing.cached_input_per_1m is not None
                else pricing.input_per_1m
            )
            cost += (cached / 1_000_000) * cached_rate

        return cost
|
|
|
|
|
|
@dataclass
class RetrievalTiming:
    """Timing for a single retrieval operation against one backend."""

    source: str  # qdrant, oxigraph, typedb, postgis
    duration_ms: float = 0.0  # Set by context manager in finally block
    result_count: int = 0  # Hit count; left at 0 unless the caller fills it in
    query: str = ""  # Optional raw query text, kept for debugging
    timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))  # UTC creation time
|
|
|
|
|
|
@dataclass
class PipelineTiming:
    """Timing breakdown for a full pipeline execution.

    The per-stage *_ms fields are written by CostTracker.track_stage via
    setattr(f"{stage}_ms"), so their names are part of that contract.
    """

    total_ms: float = 0.0  # Wall-clock total for the whole pipeline run
    query_routing_ms: float = 0.0
    entity_extraction_ms: float = 0.0
    retrieval_ms: float = 0.0
    generation_ms: float = 0.0
    visualization_ms: float = 0.0

    # Sub-breakdowns
    retrieval_breakdown: dict[str, float] = field(default_factory=dict)  # source -> last recorded duration_ms
    llm_calls: list[LLMUsage] = field(default_factory=list)  # Every LLM call made during this run
    retrievals: list[RetrievalTiming] = field(default_factory=list)  # Every retrieval made during this run
|
|
|
|
|
|
@dataclass
class SessionStats:
    """Accumulated session statistics, updated by CostTracker as calls complete."""

    session_id: str  # Short identifier for this tracking session
    started_at: datetime  # UTC session start time
    query_count: int = 0  # Number of completed track_pipeline() runs
    total_input_tokens: int = 0
    total_output_tokens: int = 0
    total_cost_usd: float = 0.0
    total_duration_ms: float = 0.0  # Sum of pipeline total_ms values

    # Per-source timing
    qdrant_calls: int = 0
    qdrant_total_ms: float = 0.0
    oxigraph_calls: int = 0  # Also counts retrievals tagged "sparql"
    oxigraph_total_ms: float = 0.0
    typedb_calls: int = 0
    typedb_total_ms: float = 0.0

    # Per-model costs
    costs_by_model: dict[str, float] = field(default_factory=dict)  # model name -> cumulative USD
    tokens_by_model: dict[str, dict[str, int]] = field(default_factory=dict)  # model -> {"input": n, "output": n}
|
|
|
|
|
|
# =============================================================================
|
|
# Cost Tracker
|
|
# =============================================================================
|
|
|
|
class CostTracker:
    """
    Tracks costs and performance for Heritage RAG pipeline.

    Session-based tracking with:
    - LLM token usage and costs (via dspy.track_usage)
    - Retrieval timing breakdowns per backend
    - Accumulated session totals

    NOTE(review): the original docstring claimed "Thread-safe"; an
    asyncio.Lock is created in __init__ but never acquired by any method,
    so updates are not actually synchronized — confirm intent.
    """

    def __init__(self, session_id: str | None = None):
        # Local import keeps module import side effects minimal.
        import uuid
        # Short random id when the caller does not supply one.
        self.session_id = session_id or str(uuid.uuid4())[:8]
        self.started_at = datetime.now(timezone.utc)
        self._stats = SessionStats(
            session_id=self.session_id,
            started_at=self.started_at,
        )
        # Pipeline currently inside track_pipeline(); None between runs.
        self._current_pipeline: PipelineTiming | None = None
        # Use try/except to safely check for running event loop (compatible with ThreadPoolExecutor)
        try:
            loop = asyncio.get_running_loop()  # NOTE(review): `loop` is unused after this check
            self._lock = asyncio.Lock()  # NOTE(review): created but never acquired anywhere in this class
        except RuntimeError:
            # No event loop running in this thread
            self._lock = None

    def reset(self) -> None:
        """Reset session statistics (keeps session_id, restarts the clock)."""
        self._stats = SessionStats(
            session_id=self.session_id,
            started_at=datetime.now(timezone.utc),
        )

    @contextmanager
    def track_pipeline(self) -> Generator[PipelineTiming, None, None]:
        """Context manager for tracking a full pipeline execution.

        Yields a PipelineTiming that nested track_stage / track_llm_call /
        track_retrieval calls attach their measurements to. Totals are
        committed in the finally block, so they are recorded even if the
        pipeline body raises.
        """
        timing = PipelineTiming()
        self._current_pipeline = timing
        start = time.perf_counter()

        try:
            yield timing
        finally:
            timing.total_ms = (time.perf_counter() - start) * 1000
            self._stats.query_count += 1
            self._stats.total_duration_ms += timing.total_ms
            # Clearing here means get_last_query_summary() only returns data
            # while a pipeline is still active — see NOTE on that method.
            self._current_pipeline = None

    @contextmanager
    def track_stage(self, stage: str) -> Generator[None, None, None]:
        """Track timing for a pipeline stage.

        `stage` should match a PipelineTiming field name minus its "_ms"
        suffix (e.g. "query_routing", "retrieval", "generation").
        NOTE(review): setattr silently creates a new attribute if the name
        does not match an existing field — no validation is performed.
        """
        start = time.perf_counter()
        try:
            yield
        finally:
            duration_ms = (time.perf_counter() - start) * 1000
            # Only recorded when inside an active track_pipeline().
            if self._current_pipeline:
                setattr(self._current_pipeline, f"{stage}_ms", duration_ms)

    @contextmanager
    def track_llm_call(self, model: str) -> Generator[LLMUsage, None, None]:
        """
        Context manager for tracking an LLM call.

        Integrates with DSPy's track_usage for automatic token counting.
        Token counts are summed across all LMs DSPy reports; cost and
        session totals are committed in the finally block even on error.
        """
        usage = LLMUsage(model=model)
        start = time.perf_counter()

        try:
            # Use DSPy's built-in usage tracking
            with dspy.track_usage() as dspy_tracker:
                yield usage

            # Extract token counts from DSPy tracker.
            # Key names differ by provider (OpenAI-style "prompt_tokens"
            # vs "input_tokens"), hence the `or` fallbacks.
            totals = dspy_tracker.get_total_tokens()
            for lm_name, tokens in totals.items():
                if isinstance(tokens, dict):
                    usage.input_tokens += tokens.get("prompt_tokens", 0) or tokens.get("input_tokens", 0)
                    usage.output_tokens += tokens.get("completion_tokens", 0) or tokens.get("output_tokens", 0)
                    usage.cached_tokens += tokens.get("cached_tokens", 0)
        except Exception as e:
            # Best-effort: a tracking failure must never break the LLM call.
            logger.warning(f"Token tracking failed: {e}")
        finally:
            usage.duration_ms = (time.perf_counter() - start) * 1000

            # Update session stats
            self._stats.total_input_tokens += usage.input_tokens
            self._stats.total_output_tokens += usage.output_tokens
            self._stats.total_cost_usd += usage.cost_usd

            # Per-model tracking
            if model not in self._stats.costs_by_model:
                self._stats.costs_by_model[model] = 0.0
                self._stats.tokens_by_model[model] = {"input": 0, "output": 0}
            self._stats.costs_by_model[model] += usage.cost_usd
            self._stats.tokens_by_model[model]["input"] += usage.input_tokens
            self._stats.tokens_by_model[model]["output"] += usage.output_tokens

            # Add to current pipeline if active
            if self._current_pipeline:
                self._current_pipeline.llm_calls.append(usage)

    @contextmanager
    def track_retrieval(self, source: str, query: str = "") -> Generator[RetrievalTiming, None, None]:
        """Track a retrieval operation against one backend.

        `source` selects which session counters are bumped ("qdrant",
        "oxigraph"/"sparql", "typedb"); other sources are still timed and
        attached to the pipeline but not rolled into session counters.
        """
        timing = RetrievalTiming(source=source, query=query)
        start = time.perf_counter()

        try:
            yield timing
        finally:
            timing.duration_ms = (time.perf_counter() - start) * 1000

            # Update session stats
            if source == "qdrant":
                self._stats.qdrant_calls += 1
                self._stats.qdrant_total_ms += timing.duration_ms
            elif source in ("oxigraph", "sparql"):
                self._stats.oxigraph_calls += 1
                self._stats.oxigraph_total_ms += timing.duration_ms
            elif source == "typedb":
                self._stats.typedb_calls += 1
                self._stats.typedb_total_ms += timing.duration_ms

            # Add to current pipeline if active.
            # retrieval_breakdown keeps only the LAST duration per source.
            if self._current_pipeline:
                self._current_pipeline.retrievals.append(timing)
                self._current_pipeline.retrieval_breakdown[source] = timing.duration_ms

    @asynccontextmanager
    async def track_retrieval_async(self, source: str, query: str = "") -> AsyncGenerator[RetrievalTiming, None]:
        """Async version of track_retrieval (same counters and semantics)."""
        timing = RetrievalTiming(source=source, query=query)
        start = time.perf_counter()

        try:
            yield timing
        finally:
            timing.duration_ms = (time.perf_counter() - start) * 1000

            # NOTE(review): original comment said "Thread-safe update", but
            # no lock is taken here — these increments are unsynchronized.
            if source == "qdrant":
                self._stats.qdrant_calls += 1
                self._stats.qdrant_total_ms += timing.duration_ms
            elif source in ("oxigraph", "sparql"):
                self._stats.oxigraph_calls += 1
                self._stats.oxigraph_total_ms += timing.duration_ms
            elif source == "typedb":
                self._stats.typedb_calls += 1
                self._stats.typedb_total_ms += timing.duration_ms

            if self._current_pipeline:
                self._current_pipeline.retrievals.append(timing)
                self._current_pipeline.retrieval_breakdown[source] = timing.duration_ms

    def get_session_summary(self) -> dict[str, Any]:
        """Get comprehensive session summary.

        Returns a nested dict: session metadata, costs (total and per
        model, rounded to 6 decimals), token totals, and timing with
        per-retriever call counts and averages (rounded to 1 decimal).
        """
        elapsed_sec = (datetime.now(timezone.utc) - self._stats.started_at).total_seconds()

        return {
            "session_id": self.session_id,
            "elapsed_seconds": round(elapsed_sec, 1),
            "queries": self._stats.query_count,
            "costs": {
                "total_usd": round(self._stats.total_cost_usd, 6),
                "by_model": {
                    model: round(cost, 6)
                    for model, cost in self._stats.costs_by_model.items()
                },
            },
            "tokens": {
                "total_input": self._stats.total_input_tokens,
                "total_output": self._stats.total_output_tokens,
                "by_model": self._stats.tokens_by_model,
            },
            "timing": {
                "total_ms": round(self._stats.total_duration_ms, 1),
                # Guard every average against division by zero.
                "avg_query_ms": round(
                    self._stats.total_duration_ms / self._stats.query_count, 1
                ) if self._stats.query_count > 0 else 0,
                "retrievers": {
                    "qdrant": {
                        "calls": self._stats.qdrant_calls,
                        "total_ms": round(self._stats.qdrant_total_ms, 1),
                        "avg_ms": round(
                            self._stats.qdrant_total_ms / self._stats.qdrant_calls, 1
                        ) if self._stats.qdrant_calls > 0 else 0,
                    },
                    "oxigraph": {
                        "calls": self._stats.oxigraph_calls,
                        "total_ms": round(self._stats.oxigraph_total_ms, 1),
                        "avg_ms": round(
                            self._stats.oxigraph_total_ms / self._stats.oxigraph_calls, 1
                        ) if self._stats.oxigraph_calls > 0 else 0,
                    },
                    "typedb": {
                        "calls": self._stats.typedb_calls,
                        "total_ms": round(self._stats.typedb_total_ms, 1),
                        "avg_ms": round(
                            self._stats.typedb_total_ms / self._stats.typedb_calls, 1
                        ) if self._stats.typedb_calls > 0 else 0,
                    },
                },
            },
        }

    def format_cost_display(self) -> str:
        """Format the session's total cost for display in CLI/UI.

        Precision scales with magnitude: exact zero -> "$0.00",
        sub-cent -> 4 decimals, sub-dollar -> 3, otherwise 2.
        """
        cost = self._stats.total_cost_usd

        if cost == 0:
            return "$0.00"
        elif cost < 0.01:
            return f"${cost:.4f}"
        elif cost < 1.00:
            return f"${cost:.3f}"
        else:
            return f"${cost:.2f}"

    def get_last_query_summary(self) -> dict[str, Any] | None:
        """Get summary of the last pipeline execution.

        NOTE(review): despite the name, this returns None unless called
        from INSIDE an active track_pipeline() block — track_pipeline's
        finally clears _current_pipeline on exit. Confirm whether the
        last completed pipeline should be retained instead.
        """
        if not self._current_pipeline:
            return None

        p = self._current_pipeline
        llm_cost = sum(u.cost_usd for u in p.llm_calls)

        return {
            "total_ms": round(p.total_ms, 1),
            "breakdown": {
                "query_routing_ms": round(p.query_routing_ms, 1),
                "entity_extraction_ms": round(p.entity_extraction_ms, 1),
                "retrieval_ms": round(p.retrieval_ms, 1),
                "generation_ms": round(p.generation_ms, 1),
            },
            "retrievals": {
                source: round(ms, 1)
                for source, ms in p.retrieval_breakdown.items()
            },
            "llm": {
                "calls": len(p.llm_calls),
                "total_tokens": sum(u.total_tokens for u in p.llm_calls),
                "cost_usd": round(llm_cost, 6),
            },
        }
|
|
|
|
|
|
# =============================================================================
|
|
# Global Tracker Instance
|
|
# =============================================================================
|
|
|
|
_global_tracker: CostTracker | None = None


def get_tracker() -> CostTracker:
    """Return the process-wide cost tracker, creating it lazily on first use."""
    global _global_tracker
    tracker = _global_tracker
    if tracker is None:
        tracker = CostTracker()
        _global_tracker = tracker
    return tracker
|
|
|
|
|
|
def reset_tracker() -> CostTracker:
    """Discard the global tracker, install a fresh one, and return it."""
    global _global_tracker
    fresh = CostTracker()
    _global_tracker = fresh
    return fresh
|
|
|
|
|
|
# =============================================================================
|
|
# Benchmark Utilities
|
|
# =============================================================================
|
|
|
|
@dataclass
class BenchmarkResult:
    """Result of a benchmark run (all times in milliseconds)."""

    name: str
    iterations: int
    total_ms: float
    avg_ms: float
    min_ms: float
    max_ms: float
    p50_ms: float  # Nearest-rank median
    p95_ms: float
    p99_ms: float
    errors: int = 0  # Iterations that raised; their elapsed time is still counted


def _summarize(name: str, iterations: int, times: list[float], errors: int) -> BenchmarkResult:
    """Build a BenchmarkResult from raw per-iteration timings.

    Shared by the sync and async runners so their statistics cannot drift.
    `times` is sorted in place; must be non-empty.
    """
    times.sort()
    n = len(times)
    return BenchmarkResult(
        name=name,
        iterations=iterations,
        total_ms=sum(times),
        avg_ms=sum(times) / n,
        min_ms=times[0],
        max_ms=times[-1],
        # Nearest-rank percentiles; int(n * q) < n for q < 1, so safe.
        p50_ms=times[n // 2],
        p95_ms=times[int(n * 0.95)],
        p99_ms=times[int(n * 0.99)],
        errors=errors,
    )


def run_benchmark(
    name: str,
    func: Any,
    iterations: int = 10,
    warmup: int = 2,
) -> BenchmarkResult:
    """
    Run a synchronous benchmark.

    Args:
        name: Benchmark name
        func: Function to benchmark (no args)
        iterations: Number of iterations (must be >= 1)
        warmup: Warmup iterations (not counted)

    Raises:
        ValueError: if iterations < 1 (previously this surfaced as a
            ZeroDivisionError when computing the average).
    """
    if iterations < 1:
        raise ValueError("iterations must be >= 1")

    # Warmup: prime caches/connections; errors and timings are discarded.
    for _ in range(warmup):
        try:
            func()
        except Exception:
            pass

    # Benchmark: failed iterations count toward `errors` but their
    # elapsed time is still recorded, matching wall-clock behavior.
    times: list[float] = []
    errors = 0

    for _ in range(iterations):
        start = time.perf_counter()
        try:
            func()
        except Exception:
            errors += 1
        times.append((time.perf_counter() - start) * 1000)

    return _summarize(name, iterations, times, errors)


async def run_benchmark_async(
    name: str,
    func: Any,
    iterations: int = 10,
    warmup: int = 2,
) -> BenchmarkResult:
    """
    Run an async benchmark.

    Args:
        name: Benchmark name
        func: Async function to benchmark (no args)
        iterations: Number of iterations (must be >= 1)
        warmup: Warmup iterations (not counted)

    Raises:
        ValueError: if iterations < 1.
    """
    if iterations < 1:
        raise ValueError("iterations must be >= 1")

    # Warmup: prime caches/connections; errors and timings are discarded.
    for _ in range(warmup):
        try:
            await func()
        except Exception:
            pass

    # Benchmark: same error/timing semantics as the sync runner.
    times: list[float] = []
    errors = 0

    for _ in range(iterations):
        start = time.perf_counter()
        try:
            await func()
        except Exception:
            errors += 1
        times.append((time.perf_counter() - start) * 1000)

    return _summarize(name, iterations, times, errors)
|
|
|
|
|
|
def format_benchmark_report(results: list[BenchmarkResult]) -> str:
    """Render benchmark results as a fixed-width text table."""
    rule = "=" * 80
    header_row = (
        f"{'Name':<30} {'Avg (ms)':<12} {'P50 (ms)':<12} {'P95 (ms)':<12} {'Errors':<8}"
    )
    data_rows = [
        f"{res.name:<30} {res.avg_ms:<12.1f} {res.p50_ms:<12.1f} {res.p95_ms:<12.1f} {res.errors:<8}"
        for res in results
    ]
    return "\n".join(
        [rule, "BENCHMARK RESULTS", rule, header_row, "-" * 80, *data_rows, rule]
    )
|