glam/backend/rag/session_manager.py
2026-01-02 02:11:04 +01:00

566 lines
19 KiB
Python

"""
Session Manager for Multi-Turn Conversation State
Persists ConversationState objects across HTTP requests to support
elliptical follow-ups like "En in Enschede?" resolved via context.
Architecture:
- In-memory storage (primary) - fast access, no dependencies
- Optional Redis/Valkey backend - for distributed deployments
- TTL-based cleanup for inactive sessions
Usage:
    session_manager = await get_session_manager()
    # Get or create session (returns the session ID together with its state)
    session_id, state = await session_manager.get_or_create(session_id)
# Update after query
await session_manager.update(session_id, state)
# Session auto-expires after TTL
Author: OpenCode
Created: 2025-01-06
"""
from __future__ import annotations
import asyncio
import json
import logging
import uuid
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, Literal, Optional
from pydantic import BaseModel, Field
# Import DSPy History for to_dspy_history() conversion.
# DSPy is treated as an optional dependency: when the import fails,
# DSPY_AVAILABLE is set to False and History is bound to None so that later
# references to the name still resolve.
try:
    from dspy import History
    DSPY_AVAILABLE = True
except ImportError:
    DSPY_AVAILABLE = False
    History = None # type: ignore

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# =============================================================================
# CONVERSATION STATE MODELS (always defined for type checking)
# =============================================================================
class ConversationTurn(BaseModel):
    """A single turn in conversation history.

    Every field has a default, so a turn can be constructed with no
    arguments.
    """

    # Speaker of the turn; restricted to "user" or "assistant".
    role: Literal["user", "assistant"] = "user"
    # Raw text of the turn.
    content: str = ""
    # Resolved form of the question, when one exists; preferred over
    # `content` when the conversation is converted to a DSPy history.
    resolved_question: Optional[str] = None
    # Identifier of the template associated with this turn, if any.
    template_id: Optional[str] = None
    # Slot name -> slot value pairs attached to this turn.
    slots: dict[str, str] = Field(default_factory=dict)
    # Result records attached to this turn (free-form dicts).
    results: list[dict[str, Any]] = Field(default_factory=list)
class ConversationState(BaseModel):
    """State tracked across the turns of one conversation.

    Holds the ordered turn history together with the slot values and
    template ID carried forward from earlier turns.
    """

    # Ordered history, oldest turn first.
    turns: list[ConversationTurn] = Field(default_factory=list)
    # Slot values accumulated from user turns; later turns overwrite
    # earlier values for the same key.
    current_slots: dict[str, str] = Field(default_factory=dict)
    # Template ID of the most recent turn that carried one.
    current_template_id: Optional[str] = None
    # Language code of the conversation.
    language: str = "nl"

    def add_turn(self, turn: ConversationTurn) -> None:
        """Append a turn and fold its slots/template into the current state.

        Args:
            turn: The turn to record.
        """
        self.turns.append(turn)

        # Only user turns contribute slot values to the running context.
        if turn.role == "user" and turn.slots:
            self.current_slots.update(turn.slots)

        # NOTE(review): the original indentation was lost in the reviewed
        # copy of this file; the template update is treated here as
        # independent of the slot check above - confirm against the repo.
        if turn.template_id:
            self.current_template_id = turn.template_id

    def get_previous_user_turn(self) -> Optional[ConversationTurn]:
        """Return the most recent user turn, or None if there is none."""
        return next(
            (turn for turn in reversed(self.turns) if turn.role == "user"),
            None,
        )

    def to_dspy_history(self, max_turns: int = 6) -> Any:
        """Convert to a DSPy History object for LLM context.

        This method is required by ConversationContextResolver in
        template_sparql.py to enable follow-up question resolution
        (e.g., "en in Limburg?" -> "Welke archieven zijn er in Limburg?").

        Args:
            max_turns: Number of most recent turns to include. Defaults to 6
                to keep the context window small. Zero or a negative value
                yields an empty history.

        Returns:
            A DSPy History whose messages are ``{"role", "content"}`` dicts,
            using each turn's resolved question when present and its raw
            content otherwise. When DSPy is not installed, a plain
            ``{"messages": []}`` dict is returned instead.
        """
        if not DSPY_AVAILABLE or History is None:
            # Return a minimal dict-like structure as fallback
            logger.warning("DSPy not available, returning empty history")
            return {"messages": []}

        # A slice of [-0:] would return the whole list, so guard explicitly.
        recent_turns = self.turns[-max_turns:] if max_turns > 0 else []
        messages = [
            {
                "role": turn.role,
                "content": turn.resolved_question or turn.content,
            }
            for turn in recent_turns
        ]
        return History(messages=messages)
# =============================================================================
# CONFIGURATION
# =============================================================================
class SessionConfig(BaseModel):
    """Session management configuration.

    All fields have defaults, so ``SessionConfig()`` describes a memory-only
    setup (``redis_enabled`` is False).
    """

    # Session TTL (seconds) - default 30 minutes
    session_ttl: int = 1800
    # Cleanup interval (seconds) - how often to purge expired sessions
    cleanup_interval: int = 300 # 5 minutes
    # Maximum sessions to keep in memory (LRU eviction if exceeded)
    max_sessions: int = 10000
    # Redis/Valkey settings (optional)
    # When True, SessionManager.start() attempts to connect to Redis.
    redis_enabled: bool = False
    # Connection URL handed to the Redis client factory.
    redis_url: str = "redis://localhost:6379"
    # Prefix prepended to the session ID to form the Redis key.
    redis_prefix: str = "heritage:session:"
    redis_db: int = 1 # Use different DB than cache to avoid conflicts
# =============================================================================
# SESSION STORAGE
# =============================================================================
class SessionEntry:
    """Conversation state bundled with its access bookkeeping."""

    def __init__(self, state: "ConversationState", session_id: str):
        """Wrap *state* under *session_id* with fresh UTC timestamps."""
        self.session_id = session_id
        self.state = state
        self.created_at = datetime.now(timezone.utc)
        self.last_accessed = datetime.now(timezone.utc)
        self.access_count = 0

    def touch(self) -> None:
        """Record an access: bump the counter and refresh the timestamp."""
        self.access_count = self.access_count + 1
        self.last_accessed = datetime.now(timezone.utc)

    def is_expired(self, ttl_seconds: int) -> bool:
        """Return True when the idle time exceeds *ttl_seconds*."""
        idle = datetime.now(timezone.utc) - self.last_accessed
        return idle.total_seconds() > ttl_seconds

    def to_dict(self) -> dict[str, Any]:
        """Build a JSON-friendly dict (timestamps as ISO-8601 strings)."""
        payload: dict[str, Any] = {"session_id": self.session_id}
        payload["state"] = self.state.model_dump()
        payload["created_at"] = self.created_at.isoformat()
        payload["last_accessed"] = self.last_accessed.isoformat()
        payload["access_count"] = self.access_count
        return payload

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "SessionEntry":
        """Rebuild an entry from the dict produced by ``to_dict``.

        ``access_count`` falls back to 0 when absent; the other keys are
        required.
        """
        restored = cls(ConversationState(**data["state"]), data["session_id"])
        # Overwrite the timestamps set by __init__ with the persisted ones.
        restored.created_at = datetime.fromisoformat(data["created_at"])
        restored.last_accessed = datetime.fromisoformat(data["last_accessed"])
        restored.access_count = data.get("access_count", 0)
        return restored
class SessionManager:
    """Manages conversation sessions across requests.

    Features:
    - In-memory storage with LRU eviction
    - Optional Redis/Valkey backend for persistence
    - TTL-based automatic cleanup
    - In-memory bookkeeping guarded by an ``asyncio.Lock`` (safe across
      coroutines on one event loop; not a cross-thread guarantee)
    """

    def __init__(self, config: "Optional[SessionConfig]" = None):
        """Create a manager; call ``start()`` before use.

        Args:
            config: Session settings. Defaults to ``SessionConfig()``.
        """
        self.config = config or SessionConfig()

        # In-memory storage
        self._sessions: dict[str, SessionEntry] = {}
        self._access_order: list[str] = []  # For LRU eviction (oldest first)
        self._lock = asyncio.Lock()

        # Redis client (lazy initialized)
        self._redis: Any = None
        self._redis_available = False

        # Cleanup task
        self._cleanup_task: Optional[asyncio.Task[None]] = None
        self._running = False

        # Stats
        self._stats = {
            "sessions_created": 0,
            "sessions_expired": 0,
            "sessions_evicted": 0,
            "cache_hits": 0,
            "cache_misses": 0,
        }

    async def start(self) -> None:
        """Start the session manager and its background cleanup task.

        Calling this on an already-running manager is a no-op.
        """
        if self._running:
            return
        self._running = True

        # Try to connect to Redis if enabled
        if self.config.redis_enabled:
            await self._connect_redis()

        # Start cleanup task
        self._cleanup_task = asyncio.create_task(self._cleanup_loop())
        logger.info(
            f"SessionManager started (ttl={self.config.session_ttl}s, "
            f"max={self.config.max_sessions}, redis={self._redis_available})"
        )

    async def stop(self) -> None:
        """Stop the cleanup task and release the Redis connection."""
        self._running = False

        # Detach the task first so a repeated stop() does not await it again.
        task, self._cleanup_task = self._cleanup_task, None
        if task:
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass

        if self._redis:
            try:
                await self._redis.close()
            finally:
                # Reset connection state even if close() raises, so the
                # manager never reports a Redis backend it no longer holds.
                self._redis = None
                self._redis_available = False
        else:
            self._redis_available = False

        logger.info(f"SessionManager stopped (stats: {self._stats})")

    async def _connect_redis(self) -> None:
        """Try to connect to Redis/Valkey; fall back to memory-only on failure."""
        try:
            import redis.asyncio as redis_async

            self._redis = await redis_async.from_url(
                self.config.redis_url,
                db=self.config.redis_db,
                decode_responses=True,
            )
            await self._redis.ping()
            self._redis_available = True
            logger.info(f"SessionManager connected to Redis: {self.config.redis_url}")
        except Exception as e:
            # Best effort: any failure (missing package, refused connection)
            # leaves the manager in memory-only mode.
            logger.warning(f"Redis not available, using memory-only: {e}")
            self._redis = None
            self._redis_available = False

    def generate_session_id(self) -> str:
        """Generate a new unique session ID (random UUID4 string)."""
        return str(uuid.uuid4())

    async def get(self, session_id: str) -> "Optional[ConversationState]":
        """Get session state by ID.

        Memory is consulted first; on a miss (or an expired in-memory entry)
        Redis is tried when available and the entry is cached in memory.

        Args:
            session_id: The session identifier

        Returns:
            ConversationState if found and not expired, None otherwise
        """
        # Try memory first
        async with self._lock:
            entry = self._sessions.get(session_id)
            if entry is not None:
                if not entry.is_expired(self.config.session_ttl):
                    entry.touch()
                    self._update_access_order(session_id)
                    self._stats["cache_hits"] += 1
                    return entry.state

                # Expired, remove it and fall through to Redis
                del self._sessions[session_id]
                if session_id in self._access_order:
                    self._access_order.remove(session_id)
                self._stats["sessions_expired"] += 1

        # Try Redis if available
        if self._redis_available and self._redis:
            try:
                key = f"{self.config.redis_prefix}{session_id}"
                data = await self._redis.get(key)
                if data:
                    entry = SessionEntry.from_dict(json.loads(data))
                    entry.touch()

                    # Store in memory for faster access. Another coroutine
                    # may have registered this ID while we awaited Redis, so
                    # go through _update_access_order to avoid a duplicate
                    # entry in the LRU list.
                    async with self._lock:
                        self._sessions[session_id] = entry
                        self._update_access_order(session_id)
                        await self._evict_if_needed()

                    # Update Redis TTL
                    await self._redis.expire(key, self.config.session_ttl)
                    self._stats["cache_hits"] += 1
                    return entry.state
            except Exception as e:
                logger.warning(f"Redis get failed: {e}")

        self._stats["cache_misses"] += 1
        return None

    async def get_or_create(
        self, session_id: Optional[str] = None
    ) -> "tuple[str, ConversationState]":
        """Get existing session or create a new one.

        Args:
            session_id: Optional session ID. If None (or empty), a new ID is
                generated. If given but unknown/expired, a fresh session is
                created under that same ID.

        Returns:
            Tuple of (session_id, ConversationState)
        """
        if session_id:
            state = await self.get(session_id)
            if state is not None:
                return session_id, state

        # Create new session
        new_id = session_id or self.generate_session_id()
        state = ConversationState()
        entry = SessionEntry(state, new_id)

        async with self._lock:
            self._sessions[new_id] = entry
            # De-duplicating insert: the ID may already be tracked if a
            # concurrent request created the same session meanwhile.
            self._update_access_order(new_id)
            await self._evict_if_needed()

        # Store in Redis if available
        if self._redis_available and self._redis:
            try:
                key = f"{self.config.redis_prefix}{new_id}"
                await self._redis.setex(
                    key,
                    self.config.session_ttl,
                    json.dumps(entry.to_dict()),
                )
            except Exception as e:
                logger.warning(f"Redis set failed: {e}")

        self._stats["sessions_created"] += 1
        logger.debug(f"Created new session: {new_id}")
        return new_id, state

    async def update(self, session_id: str, state: "ConversationState") -> None:
        """Update session state, creating the session if it is unknown.

        Args:
            session_id: The session identifier
            state: Updated conversation state
        """
        async with self._lock:
            entry = self._sessions.get(session_id)
            if entry is not None:
                entry.state = state
                entry.touch()
            else:
                # Create new entry
                entry = SessionEntry(state, session_id)
                self._sessions[session_id] = entry
            self._update_access_order(session_id)
            await self._evict_if_needed()

        # Update Redis if available. The entry captured under the lock is
        # used so the write does not depend on a second, unlocked lookup.
        if self._redis_available and self._redis:
            try:
                key = f"{self.config.redis_prefix}{session_id}"
                await self._redis.setex(
                    key,
                    self.config.session_ttl,
                    json.dumps(entry.to_dict()),
                )
            except Exception as e:
                logger.warning(f"Redis update failed: {e}")

    async def delete(self, session_id: str) -> bool:
        """Delete a session from memory and, when available, from Redis.

        Args:
            session_id: The session identifier

        Returns:
            True if session was deleted, False if not found
        """
        deleted = False
        async with self._lock:
            if self._sessions.pop(session_id, None) is not None:
                deleted = True
            if session_id in self._access_order:
                self._access_order.remove(session_id)

        if self._redis_available and self._redis:
            try:
                key = f"{self.config.redis_prefix}{session_id}"
                # DEL reports how many keys were removed; only a non-zero
                # count means the session actually existed in Redis.
                removed = await self._redis.delete(key)
                if removed:
                    deleted = True
            except Exception as e:
                logger.warning(f"Redis delete failed: {e}")

        return deleted

    def _update_access_order(self, session_id: str) -> None:
        """Move session to end of access order (most recently used)."""
        if session_id in self._access_order:
            self._access_order.remove(session_id)
        self._access_order.append(session_id)

    async def _evict_if_needed(self) -> None:
        """Evict oldest sessions if over limit (LRU).

        Callers are expected to hold ``self._lock``. Only the in-memory copy
        is dropped; nothing is removed from Redis here.
        """
        while len(self._sessions) > self.config.max_sessions and self._access_order:
            oldest_id = self._access_order.pop(0)
            # The ID may be stale (already removed from _sessions); skip it.
            if self._sessions.pop(oldest_id, None) is not None:
                self._stats["sessions_evicted"] += 1
                logger.debug(f"Evicted session (LRU): {oldest_id}")

    async def _cleanup_loop(self) -> None:
        """Background task to cleanup expired sessions."""
        while self._running:
            try:
                await asyncio.sleep(self.config.cleanup_interval)
                await self._cleanup_expired()
            except asyncio.CancelledError:
                break
            except Exception as e:
                # Keep the loop alive; one failed sweep must not end cleanup.
                logger.exception(f"Cleanup error: {e}")

    async def _cleanup_expired(self) -> None:
        """Remove expired sessions from memory."""
        ttl = self.config.session_ttl
        async with self._lock:
            expired_ids = [
                session_id
                for session_id, entry in self._sessions.items()
                if entry.is_expired(ttl)
            ]
            for session_id in expired_ids:
                del self._sessions[session_id]
                if session_id in self._access_order:
                    self._access_order.remove(session_id)
                self._stats["sessions_expired"] += 1

        if expired_ids:
            logger.debug(f"Cleaned up {len(expired_ids)} expired sessions")

    async def get_stats(self) -> dict[str, Any]:
        """Get session manager statistics.

        Returns:
            The counters plus ``active_sessions`` (in-memory count) and
            ``redis_available``.
        """
        async with self._lock:
            active_count = len(self._sessions)
        return {
            **self._stats,
            "active_sessions": active_count,
            "redis_available": self._redis_available,
        }

    async def add_turn_to_session(
        self,
        session_id: str,
        question: str,
        answer: str,
        resolved_question: Optional[str] = None,
        template_id: Optional[str] = None,
        slots: Optional[dict[str, str]] = None,
        results: Optional[list[dict[str, Any]]] = None,
    ) -> None:
        """Add a conversation turn to a session.

        This is a convenience method that:
        1. Gets the session state (creating the session if needed)
        2. Creates a user turn with the question
        3. Creates an assistant turn with the answer
        4. Updates the session

        Args:
            session_id: The session identifier
            question: User's question
            answer: Assistant's answer
            resolved_question: The resolved form of the question (if different)
            template_id: Template ID used for this query
            slots: Extracted slot values
            results: Query results for context
        """
        state = await self.get(session_id)
        if state is None:
            _, state = await self.get_or_create(session_id)

        # Add user turn
        user_turn = ConversationTurn(
            role="user",
            content=question,
            resolved_question=resolved_question,
            template_id=template_id,
            slots=slots or {},
            results=results or [],
        )
        state.add_turn(user_turn)

        # Add assistant turn
        assistant_turn = ConversationTurn(
            role="assistant",
            content=answer,
        )
        state.add_turn(assistant_turn)

        await self.update(session_id, state)
# =============================================================================
# SINGLETON INSTANCE
# =============================================================================
# Process-wide SessionManager instance; stays None until
# get_session_manager() creates it and is reset by shutdown_session_manager().
_session_manager: Optional[SessionManager] = None
# Serializes creation and shutdown of the singleton across coroutines.
_session_manager_lock = asyncio.Lock()
async def get_session_manager(config: Optional[SessionConfig] = None) -> SessionManager:
    """Return the shared session manager, creating and starting it on first use.

    Args:
        config: Configuration used only when the manager is first created;
            it has no effect once an instance already exists.

    Returns:
        The singleton SessionManager instance
    """
    global _session_manager

    async with _session_manager_lock:
        # Fast path: an instance has already been built.
        if _session_manager is not None:
            return _session_manager

        _session_manager = SessionManager(config)
        await _session_manager.start()
        return _session_manager
async def shutdown_session_manager() -> None:
    """Stop and discard the shared session manager, if one exists.

    Intended for graceful shutdown; a later get_session_manager() call
    builds a fresh instance.
    """
    global _session_manager

    async with _session_manager_lock:
        # Nothing to do when no manager was ever created.
        if not _session_manager:
            return
        await _session_manager.stop()
        _session_manager = None