#!/usr/bin/env python3
|
|
"""
|
|
Benchmark OpenAI Prompt Caching for Heritage RAG Pipeline
|
|
|
|
This script measures:
|
|
1. Token counts of schema-aware DSPy signatures (need 1024+ for caching)
|
|
2. Actual cache hit rates via OpenAI API
|
|
3. Latency improvements from prompt caching
|
|
|
|
OpenAI Prompt Caching Requirements:
|
|
- Minimum 1024 tokens in prompt for caching
|
|
- Cache lifetime: 5-10 minutes active, up to 24h extended
|
|
- 50% discount on cached input tokens
|
|
- Up to 80% latency reduction
|
|
|
|
Date: 2025-12-20
|
|
"""
|
|
|
|
import json
|
|
import os
|
|
import time
|
|
import sys
|
|
from pathlib import Path
|
|
from dataclasses import dataclass, field
|
|
from typing import Optional
|
|
|
|
# Add project root to path
|
|
project_root = Path(__file__).parent.parent.parent
|
|
sys.path.insert(0, str(project_root))
|
|
|
|
from dotenv import load_dotenv
|
|
load_dotenv()
|
|
|
|
# Token counting
|
|
try:
|
|
import tiktoken
|
|
TIKTOKEN_AVAILABLE = True
|
|
except ImportError:
|
|
TIKTOKEN_AVAILABLE = False
|
|
print("Warning: tiktoken not installed. Install with: pip install tiktoken")
|
|
|
|
# DSPy
|
|
try:
|
|
import dspy
|
|
DSPY_AVAILABLE = True
|
|
except ImportError:
|
|
DSPY_AVAILABLE = False
|
|
print("Warning: dspy not installed")
|
|
|
|
# OpenAI
|
|
try:
|
|
from openai import OpenAI
|
|
OPENAI_AVAILABLE = True
|
|
except ImportError:
|
|
OPENAI_AVAILABLE = False
|
|
print("Warning: openai not installed")
|
|
|
|
|
|
@dataclass
class SignatureTokenStats:
    """Per-signature token accounting used to judge OpenAI prompt cacheability.

    OpenAI prompt caching requires at least 1024 prompt tokens, so
    ``cacheable`` is derived from ``total_tokens >= 1024``.
    """

    # Human-readable signature name (e.g. "HeritageQueryIntent (base)").
    name: str
    # Tokens in the signature class docstring.
    docstring_tokens: int = 0
    # Tokens in the concatenated input-field descriptions.
    input_fields_tokens: int = 0
    # Tokens in the concatenated output-field descriptions.
    output_fields_tokens: int = 0
    # Sum of the three component counts above.
    total_tokens: int = 0
    cacheable: bool = False  # True if >= 1024 tokens
    # Free-form analysis notes (e.g. distance to the caching threshold).
    notes: list[str] = field(default_factory=list)
|
|
|
|
|
|
@dataclass
class CacheTestResult:
    """Result of a cache hit test.

    Populated from one OpenAI chat-completion response's ``usage`` block
    plus wall-clock timing measured around the API call.
    """

    # 1-based index of the request within the test run.
    request_num: int
    # Wall-clock round trip for the API call, in milliseconds.
    latency_ms: float
    # Total prompt tokens reported by the API.
    prompt_tokens: int
    # Prompt tokens served from OpenAI's prompt cache.
    cached_tokens: int
    # Tokens generated in the completion.
    completion_tokens: int
    cache_hit_rate: float  # 0.0 to 1.0 (cached_tokens / prompt_tokens)
|
|
|
|
|
|
def count_tokens(text: str, model: str = "gpt-4o") -> int:
    """Return the number of tokens in *text* for the given OpenAI model.

    Uses tiktoken when installed; otherwise falls back to a rough
    4-characters-per-token estimate.
    """
    if not TIKTOKEN_AVAILABLE:
        # Heuristic fallback: English text averages ~4 characters per token.
        return len(text) // 4

    try:
        enc = tiktoken.encoding_for_model(model)
    except KeyError:
        # Model unknown to this tiktoken version; cl100k_base is the
        # fallback encoding for newer GPT-4-family models.
        enc = tiktoken.get_encoding("cl100k_base")
    return len(enc.encode(text))
|
|
|
|
|
|
def analyze_signature_tokens(signature_class, name: str) -> SignatureTokenStats:
    """Analyze token counts for a DSPy signature class.

    Sums the tokens in the class docstring plus the input/output field
    descriptions, then flags the signature as cacheable when the total
    meets OpenAI's 1024-token prompt-caching minimum.

    Args:
        signature_class: DSPy signature (pydantic model) class to inspect.
        name: Display name stored on the returned stats.

    Returns:
        A populated SignatureTokenStats.
    """
    stats = SignatureTokenStats(name=name)

    # Get docstring
    docstring = signature_class.__doc__ or ""
    stats.docstring_tokens = count_tokens(docstring)

    input_desc = ""
    output_desc = ""

    # BUG FIX: pydantic v2 classes expose the same fields under both
    # `model_fields` and the deprecated `__fields__` alias, so inspecting
    # both mappings (as the original code did) token-counted every field
    # twice. Inspect exactly one of the two.
    if hasattr(signature_class, 'model_fields'):
        for field_name, field_info in signature_class.model_fields.items():
            desc = str(field_info.default) if hasattr(field_info, 'default') else ""
            # Heuristic input/output split based on the FieldInfo repr.
            if 'input' in str(type(field_info)).lower() or 'Input' in str(field_info.default):
                input_desc += f"{field_name}: {desc}\n"
            else:
                output_desc += f"{field_name}: {desc}\n"
    elif hasattr(signature_class, '__fields__'):
        # Legacy path for signature classes without pydantic v2 metadata.
        for field_name, field_info in signature_class.__fields__.items():
            field_obj = getattr(signature_class, field_name, None)
            if field_obj:
                desc = getattr(field_obj, 'desc', str(field_obj))
                # Guard the dspy reference: the original raised NameError
                # here when dspy was not installed.
                if DSPY_AVAILABLE and isinstance(field_obj, dspy.InputField):
                    input_desc += f"{field_name}: {desc}\n"
                else:
                    output_desc += f"{field_name}: {desc}\n"

    stats.input_fields_tokens = count_tokens(input_desc)
    stats.output_fields_tokens = count_tokens(output_desc)

    # Total tokens in the full prompt template
    stats.total_tokens = stats.docstring_tokens + stats.input_fields_tokens + stats.output_fields_tokens

    # Check if cacheable (>= 1024 tokens)
    stats.cacheable = stats.total_tokens >= 1024

    if stats.total_tokens < 1024:
        stats.notes.append(f"Need {1024 - stats.total_tokens} more tokens for caching")
    else:
        stats.notes.append(f"Cacheable! {stats.total_tokens - 1024} tokens over threshold")

    return stats
|
|
|
|
|
|
def test_openai_caching(num_requests: int = 5, model: str = "gpt-4o") -> list[CacheTestResult]:
    """Test OpenAI prompt caching with identical requests.

    Makes multiple identical requests and measures cache hit rates.
    Only the system prompt is long enough (>1024 tokens) to be cached;
    the user message is held constant across requests.

    Args:
        num_requests: How many identical chat completions to send.
        model: OpenAI chat model name.

    Returns:
        One CacheTestResult per successful request (failed requests are
        skipped); empty list when the SDK or API key is unavailable.
    """
    # Both the SDK and a key are required for a live test; bail out cleanly
    # so the rest of the benchmark can still run without them.
    if not OPENAI_AVAILABLE:
        print("OpenAI not available, skipping cache test")
        return []

    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        print("OPENAI_API_KEY not set, skipping cache test")
        return []

    client = OpenAI(api_key=api_key)
    results = []

    # Create a long system prompt (>1024 tokens) to test caching
    # This simulates the schema-aware context
    system_prompt = """You are an expert in GLAM (Galleries, Libraries, Archives, Museums)
heritage institutions. Your role is to classify user queries and route them to appropriate data sources.

HERITAGE CUSTODIAN ONTOLOGY CONTEXT
============================================================

Hub Architecture:
- Custodian (crm:E39_Actor): Central hub entity
- CustodianObservation: Evidence from sources
- CustodianName: Standardized emic names
- CustodianLegalStatus: Formal legal entity
- CustodianPlace: Geographic location
- CustodianCollection: Heritage collections

Heritage Custodian Types (GLAMORCUBESFIXPHDNT taxonomy):
- GALLERY: Art gallery or exhibition space for visual arts
- LIBRARY: Institution maintaining collections of books, periodicals, and other media
- ARCHIVE: Institution maintaining historical records and documents
- MUSEUM: Institution collecting, preserving, and displaying objects of cultural or scientific significance
- OFFICIAL_INSTITUTION: Government-operated heritage institution or registry
- RESEARCH_CENTER: Academic or research institution focused on heritage studies
- COMMERCIAL: For-profit organization involved in heritage sector
- UNSPECIFIED: Heritage institution whose specific type is not yet classified
- BIO_CUSTODIAN: Institution focused on botanical, zoological, or natural heritage
- EDUCATION_PROVIDER: Educational institution with significant heritage focus
- HERITAGE_SOCIETY: Non-profit society or association focused on heritage preservation
- FEATURE_CUSTODIAN: Institution responsible for physical landscape features
- INTANGIBLE_HERITAGE_GROUP: Organization preserving intangible cultural heritage
- MIXED: Institution combining multiple heritage types
- PERSONAL_COLLECTION: Private collection maintained by an individual
- HOLY_SACRED_SITE: Religious or sacred site with heritage significance
- DIGITAL_PLATFORM: Online platform or digital repository for heritage content
- NON_PROFIT: Non-governmental organization in heritage sector
- TASTE_SCENT_HERITAGE: Institution preserving culinary, olfactory, or gustatory heritage

Key Properties:
- hc:hc_id: Global Heritage Custodian Identifier (GHCID)
- hc:preferred_label: Primary name of the institution
- hc:custodian_type: Type classification from GLAMORCUBESFIXPHDNT taxonomy
- hc:legal_status: Legal form and registration status
- hc:place_designation: Geographic location and address
- hc:has_collection: Collections maintained by the institution
- hc:identifiers: External identifiers (ISIL, Wikidata, VIAF)
- hc:organizational_structure: Internal organization and departments
- hc:encompassing_body: Parent organization or governing body

Staff Role Categories (13 categories):
- CURATORIAL: Curator, Collections Manager, Registrar
- ARCHIVAL: Archivist, Records Manager, Digital Archivist
- LIBRARY: Librarian, Cataloger, Reference Specialist
- CONSERVATION: Conservator, Restorer, Conservation Scientist
- DIGITAL: Data Engineer, Digital Curator, Software Developer
- EDUCATION: Museum Educator, Public Programs Manager
- RESEARCH: Researcher, Historian, Archaeologist
- ADMINISTRATIVE: Director, Manager, Coordinator
- VISITOR_SERVICES: Front Desk, Tour Guide, Security
- DEVELOPMENT: Fundraiser, Grant Writer, Donor Relations
- COMMUNICATIONS: Marketing, Public Relations, Social Media
- FACILITIES: Building Manager, Exhibition Installer
- VOLUNTEER: Volunteer Coordinator, Docent

Key Ontology Prefixes:
PREFIX hc: <https://nde.nl/ontology/hc/>
PREFIX crm: <http://www.cidoc-crm.org/cidoc-crm/>
PREFIX prov: <http://www.w3.org/ns/prov#>
PREFIX schema: <http://schema.org/>
PREFIX cpov: <http://data.europa.eu/m8g/>
PREFIX rico: <https://www.ica.org/standards/RiC/ontology#>
PREFIX foaf: <http://xmlns.com/foaf/0.1/>
PREFIX tooi: <https://identifier.overheid.nl/tooi/def/ont/>
PREFIX org: <http://www.w3.org/ns/org#>
PREFIX skos: <http://www.w3.org/2004/02/skos/core#>
PREFIX dcterms: <http://purl.org/dc/terms/>
PREFIX wdt: <http://www.wikidata.org/prop/direct/>
PREFIX wikidata: <http://www.wikidata.org/entity/>
PREFIX geo: <http://www.opengis.net/ont/geosparql#>
PREFIX ghcid: <https://nde.nl/ontology/hc/>

MULTILINGUAL SYNONYMS:
- MUSEUM: "museum", "musea", "museo", "musée", "музей" (ru), "博物馆" (zh)
- LIBRARY: "library", "bibliotheek", "bibliothèque", "biblioteca", "библиотека" (ru)
- ARCHIVE: "archive", "archief", "archiv", "archivo", "архив" (ru), "档案馆" (zh)
- GALLERY: "gallery", "galerie", "galería", "galleria", "галерея" (ru)

When classifying queries:
1. Identify the primary intent (geographic, statistical, relational, temporal, entity_lookup, comparative, exploration)
2. Extract named entities (institution names, places, dates)
3. Recommend data sources (qdrant, sparql, typedb, postgis)
4. Classify entity type (person, institution, or both)
5. If person-related, identify the role category and specific role
6. If institution-related, identify the custodian type

============================================================
"""

    # Count tokens in system prompt
    system_tokens = count_tokens(system_prompt)
    print(f"\nSystem prompt tokens: {system_tokens}")
    print(f"Cacheable: {system_tokens >= 1024}")

    # Make identical requests
    user_message = "How many museums are there in Amsterdam?"

    print(f"\nMaking {num_requests} identical requests to test caching...")
    print("-" * 60)

    for i in range(num_requests):
        start_time = time.time()

        try:
            response = client.chat.completions.create(
                model=model,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_message}
                ],
                max_tokens=150,
            )

            # Latency covers the full round trip (network + generation).
            latency_ms = (time.time() - start_time) * 1000

            # Extract usage stats
            usage = response.usage
            prompt_tokens = usage.prompt_tokens
            completion_tokens = usage.completion_tokens

            # Check for cached tokens (new OpenAI feature)
            # prompt_tokens_details / cached_tokens may be absent or None on
            # older API responses, hence the defensive getattr + `or 0`.
            cached_tokens = 0
            if hasattr(usage, 'prompt_tokens_details') and usage.prompt_tokens_details:
                cached_tokens = getattr(usage.prompt_tokens_details, 'cached_tokens', 0) or 0

            cache_hit_rate = cached_tokens / prompt_tokens if prompt_tokens > 0 else 0

            result = CacheTestResult(
                request_num=i + 1,
                latency_ms=latency_ms,
                prompt_tokens=prompt_tokens,
                cached_tokens=cached_tokens,
                completion_tokens=completion_tokens,
                cache_hit_rate=cache_hit_rate,
            )
            results.append(result)

            print(f"Request {i+1}: {latency_ms:.0f}ms | "
                  f"Prompt: {prompt_tokens} | "
                  f"Cached: {cached_tokens} ({cache_hit_rate:.1%}) | "
                  f"Completion: {completion_tokens}")

            # Small delay between requests
            if i < num_requests - 1:
                time.sleep(0.5)

        except Exception as e:
            # Best-effort benchmark: report the failure and keep going with
            # the remaining requests (the failed request yields no result).
            print(f"Request {i+1} failed: {e}")

    return results
|
|
|
|
|
|
def _report_doc_tokens(label: str, err_name: str, producer) -> None:
    """Print a one-line token count + cacheability report for producer()'s text.

    Args:
        label: Display label for the success line.
        err_name: Phrase used in the failure line ("Could not get <err_name>: ...").
        producer: Zero-argument callable returning the text to measure.
    """
    try:
        tokens = count_tokens(producer())
        print(f"{label}: {tokens:,} tokens {'[CACHEABLE]' if tokens >= 1024 else '[TOO SHORT]'}")
    except Exception as e:
        print(f"Could not get {err_name}: {e}")


def analyze_heritage_rag_signatures():
    """Analyze all DSPy signatures in the Heritage RAG pipeline.

    Prints token counts for the schema-aware context builders and each DSPy
    signature, followed by prompt-restructuring recommendations. Returns
    early (None) when the project modules cannot be imported.
    """
    print("\n" + "=" * 70)
    print("HERITAGE RAG SIGNATURE TOKEN ANALYSIS")
    print("=" * 70)

    try:
        from backend.rag.dspy_heritage_rag import (
            get_schema_aware_query_intent_signature,
            HeritageQueryIntent,
            HeritageEntityExtraction,
            HeritageAnswerGeneration,
        )
        from backend.rag.schema_loader import (
            create_schema_aware_sparql_docstring,
            create_schema_aware_entity_docstring,
            get_ontology_context,
        )
    except ImportError as e:
        print(f"Could not import heritage RAG modules: {e}")
        return

    print("\n1. SCHEMA-AWARE CONTEXT TOKEN COUNTS")
    print("-" * 50)

    # Each context builder is reported through the shared helper (this was
    # previously three copy-pasted try/except blocks).
    _report_doc_tokens("Ontology Context", "ontology context", get_ontology_context)
    _report_doc_tokens("SPARQL Docstring", "SPARQL docstring", create_schema_aware_sparql_docstring)
    _report_doc_tokens("Entity Extractor Doc", "entity docstring", create_schema_aware_entity_docstring)

    print("\n2. DSPY SIGNATURE TOKEN COUNTS")
    print("-" * 50)

    # Analyze each signature
    signatures_to_analyze = [
        ("HeritageQueryIntent (base)", HeritageQueryIntent),
    ]

    # The schema-aware variant is built dynamically and may fail.
    try:
        schema_aware_sig = get_schema_aware_query_intent_signature()
        signatures_to_analyze.append(("SchemaAwareQueryIntent", schema_aware_sig))
    except Exception as e:
        print(f"Could not get schema-aware signature: {e}")

    signatures_to_analyze.extend([
        ("HeritageEntityExtraction", HeritageEntityExtraction),
        ("HeritageAnswerGeneration", HeritageAnswerGeneration),
    ])

    for name, sig_class in signatures_to_analyze:
        try:
            stats = analyze_signature_tokens(sig_class, name)
            status = "[CACHEABLE]" if stats.cacheable else "[TOO SHORT]"
            print(f"{name}: {stats.total_tokens:,} tokens {status}")
            print(f"  Docstring: {stats.docstring_tokens}, Fields: {stats.input_fields_tokens + stats.output_fields_tokens}")
            for note in stats.notes:
                print(f"  Note: {note}")
        except Exception as e:
            print(f"{name}: Error - {e}")

    print("\n3. RECOMMENDATIONS")
    print("-" * 50)

    recommendations = [
        "1. Restructure prompts with STATIC content first (ontology context)",
        "   - Move schema definitions, type lists, prefix declarations to prompt start",
        "   - Put user query and dynamic content at the end",
        "",
        "2. Consider consolidating short signatures:",
        "   - If docstrings are under 1024 tokens, merge related signatures",
        "   - Or add more static ontology context to reach caching threshold",
        "",
        "3. Prompt structure for caching:",
        "   [SYSTEM - STATIC] Ontology context, type definitions (~1500 tokens)",
        "   [SYSTEM - STATIC] Role categories, properties (~500 tokens)",
        "   [USER - DYNAMIC] Actual query (variable)",
        "",
        "4. Cache hit optimization:",
        "   - Use identical system prompts across requests",
        "   - Vary only user message content",
        "   - Keep prompt prefix consistent for 1024+ tokens",
    ]

    for rec in recommendations:
        print(rec)
|
|
|
|
|
|
def main():
    """Run all prompt caching benchmarks.

    Step 1 analyzes static signature token counts; step 2 sends identical
    live requests and summarizes latency, cache hit rate, and cost savings.
    """
    banner = "=" * 70
    print("\n" + banner)
    print("OPENAI PROMPT CACHING BENCHMARK FOR HERITAGE RAG")
    print(banner)
    for label, flag in (
        ("Tiktoken available", TIKTOKEN_AVAILABLE),
        ("DSPy available", DSPY_AVAILABLE),
        ("OpenAI available", OPENAI_AVAILABLE),
    ):
        print(f"{label}: {flag}")

    # Step 1: static token analysis of the DSPy signatures.
    analyze_heritage_rag_signatures()

    # Step 2: live cache-hit measurement against the OpenAI API.
    print("\n" + banner)
    print("OPENAI CACHE HIT TEST")
    print(banner)

    results = test_openai_caching(num_requests=5)
    if not results:
        return

    print("\n4. CACHE TEST SUMMARY")
    print("-" * 50)

    latencies = [r.latency_ms for r in results]
    avg_latency = sum(latencies) / len(latencies)
    avg_cache_rate = sum(r.cache_hit_rate for r in results) / len(results)

    # The first request is always a cache miss; compare it against the
    # average of the remaining (potentially cached) requests.
    first_latency = latencies[0]
    rest = latencies[1:]
    avg_subsequent = sum(rest) / len(rest) if rest else 0

    print(f"Average latency: {avg_latency:.0f}ms")
    print(f"First request: {first_latency:.0f}ms")
    print(f"Subsequent avg: {avg_subsequent:.0f}ms")
    if first_latency > 0:
        improvement = (first_latency - avg_subsequent) / first_latency * 100
        print(f"Latency improvement: {improvement:.1f}%")
    print(f"Average cache hit rate: {avg_cache_rate:.1%}")

    total_prompt_tokens = sum(r.prompt_tokens for r in results)
    total_cached_tokens = sum(r.cached_tokens for r in results)

    print(f"\nTotal prompt tokens: {total_prompt_tokens:,}")
    print(f"Total cached tokens: {total_cached_tokens:,}")
    if total_prompt_tokens > 0:
        savings_pct = total_cached_tokens / total_prompt_tokens * 50  # 50% discount on cached
        print(f"Estimated cost savings: {savings_pct:.1f}%")
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the full benchmark suite.
    main()
|