""" Session Manager for Multi-Turn Conversation State Persists ConversationState objects across HTTP requests to support elliptical follow-ups like "En in Enschede?" resolved via context. Architecture: - In-memory storage (primary) - fast access, no dependencies - Optional Redis/Valkey backend - for distributed deployments - TTL-based cleanup for inactive sessions Usage: session_manager = get_session_manager() # Get or create session state = await session_manager.get_or_create(session_id) # Update after query await session_manager.update(session_id, state) # Session auto-expires after TTL Author: OpenCode Created: 2025-01-06 """ from __future__ import annotations import asyncio import json import logging import uuid from datetime import datetime, timezone from typing import TYPE_CHECKING, Any, Literal, Optional from pydantic import BaseModel, Field # Import DSPy History for to_dspy_history() conversion try: from dspy import History DSPY_AVAILABLE = True except ImportError: DSPY_AVAILABLE = False History = None # type: ignore logger = logging.getLogger(__name__) # ============================================================================= # CONVERSATION STATE MODELS (always defined for type checking) # ============================================================================= class ConversationTurn(BaseModel): """A single turn in conversation history.""" role: Literal["user", "assistant"] = "user" content: str = "" resolved_question: Optional[str] = None template_id: Optional[str] = None slots: dict[str, str] = Field(default_factory=dict) results: list[dict[str, Any]] = Field(default_factory=list) class ConversationState(BaseModel): """State tracking across conversation turns.""" turns: list[ConversationTurn] = Field(default_factory=list) current_slots: dict[str, str] = Field(default_factory=dict) current_template_id: Optional[str] = None language: str = "nl" def add_turn(self, turn: ConversationTurn) -> None: """Add a turn and update current state.""" self.turns.append(turn) if turn.role == "user" and turn.slots: # Inherit slots from user turns self.current_slots.update(turn.slots) if turn.template_id: self.current_template_id = turn.template_id def get_previous_user_turn(self) -> Optional[ConversationTurn]: """Get the most recent user turn.""" for turn in reversed(self.turns): if turn.role == "user": return turn return None def to_dspy_history(self) -> Any: """Convert to DSPy History object for LLM context. This method is required by ConversationContextResolver in template_sparql.py to enable follow-up question resolution (e.g., "en in Limburg?" → "Welke archieven zijn er in Limburg?"). Returns: DSPy History object with last 6 turns for context window efficiency """ if not DSPY_AVAILABLE or History is None: # Return a minimal dict-like structure as fallback logger.warning("DSPy not available, returning empty history") return {"messages": []} messages = [] for turn in self.turns[-6:]: # Keep last 6 turns for context messages.append({ "role": turn.role, "content": turn.resolved_question or turn.content }) return History(messages=messages) # ============================================================================= # CONFIGURATION # ============================================================================= class SessionConfig(BaseModel): """Session management configuration.""" # Session TTL (seconds) - default 30 minutes session_ttl: int = 1800 # Cleanup interval (seconds) - how often to purge expired sessions cleanup_interval: int = 300 # 5 minutes # Maximum sessions to keep in memory (LRU eviction if exceeded) max_sessions: int = 10000 # Redis/Valkey settings (optional) redis_enabled: bool = False redis_url: str = "redis://localhost:6379" redis_prefix: str = "heritage:session:" redis_db: int = 1 # Use different DB than cache to avoid conflicts # ============================================================================= # SESSION STORAGE # ============================================================================= class SessionEntry: """A session with metadata.""" def __init__(self, state: ConversationState, session_id: str): self.session_id = session_id self.state = state self.created_at = datetime.now(timezone.utc) self.last_accessed = datetime.now(timezone.utc) self.access_count = 0 def touch(self) -> None: """Update last accessed time.""" self.last_accessed = datetime.now(timezone.utc) self.access_count += 1 def is_expired(self, ttl_seconds: int) -> bool: """Check if session has expired.""" age = (datetime.now(timezone.utc) - self.last_accessed).total_seconds() return age > ttl_seconds def to_dict(self) -> dict[str, Any]: """Serialize for Redis storage.""" return { "session_id": self.session_id, "state": self.state.model_dump(), "created_at": self.created_at.isoformat(), "last_accessed": self.last_accessed.isoformat(), "access_count": self.access_count, } @classmethod def from_dict(cls, data: dict[str, Any]) -> "SessionEntry": """Deserialize from Redis storage.""" state = ConversationState(**data["state"]) entry = cls(state, data["session_id"]) entry.created_at = datetime.fromisoformat(data["created_at"]) entry.last_accessed = datetime.fromisoformat(data["last_accessed"]) entry.access_count = data.get("access_count", 0) return entry class SessionManager: """Manages conversation sessions across requests. Features: - In-memory storage with LRU eviction - Optional Redis/Valkey backend for persistence - TTL-based automatic cleanup - Thread-safe async operations """ def __init__(self, config: Optional[SessionConfig] = None): self.config = config or SessionConfig() # In-memory storage self._sessions: dict[str, SessionEntry] = {} self._access_order: list[str] = [] # For LRU eviction self._lock = asyncio.Lock() # Redis client (lazy initialized) self._redis: Any = None self._redis_available = False # Cleanup task self._cleanup_task: Optional[asyncio.Task[None]] = None self._running = False # Stats self._stats = { "sessions_created": 0, "sessions_expired": 0, "sessions_evicted": 0, "cache_hits": 0, "cache_misses": 0, } async def start(self) -> None: """Start the session manager and cleanup task.""" if self._running: return self._running = True # Try to connect to Redis if enabled if self.config.redis_enabled: await self._connect_redis() # Start cleanup task self._cleanup_task = asyncio.create_task(self._cleanup_loop()) logger.info( f"SessionManager started (ttl={self.config.session_ttl}s, " f"max={self.config.max_sessions}, redis={self._redis_available})" ) async def stop(self) -> None: """Stop the session manager and cleanup.""" self._running = False if self._cleanup_task: self._cleanup_task.cancel() try: await self._cleanup_task except asyncio.CancelledError: pass if self._redis: await self._redis.close() self._redis = None logger.info(f"SessionManager stopped (stats: {self._stats})") async def _connect_redis(self) -> None: """Try to connect to Redis/Valkey.""" try: import redis.asyncio as redis_async self._redis = await redis_async.from_url( self.config.redis_url, db=self.config.redis_db, decode_responses=True, ) await self._redis.ping() self._redis_available = True logger.info(f"SessionManager connected to Redis: {self.config.redis_url}") except Exception as e: logger.warning(f"Redis not available, using memory-only: {e}") self._redis = None self._redis_available = False def generate_session_id(self) -> str: """Generate a new unique session ID.""" return str(uuid.uuid4()) async def get(self, session_id: str) -> Optional[ConversationState]: """Get session state by ID. Args: session_id: The session identifier Returns: ConversationState if found and not expired, None otherwise """ # Try memory first async with self._lock: if session_id in self._sessions: entry = self._sessions[session_id] if not entry.is_expired(self.config.session_ttl): entry.touch() self._update_access_order(session_id) self._stats["cache_hits"] += 1 return entry.state else: # Expired, remove it del self._sessions[session_id] if session_id in self._access_order: self._access_order.remove(session_id) self._stats["sessions_expired"] += 1 # Try Redis if available if self._redis_available and self._redis: try: key = f"{self.config.redis_prefix}{session_id}" data = await self._redis.get(key) if data: entry = SessionEntry.from_dict(json.loads(data)) entry.touch() # Store in memory for faster access async with self._lock: self._sessions[session_id] = entry self._access_order.append(session_id) await self._evict_if_needed() # Update Redis TTL await self._redis.expire(key, self.config.session_ttl) self._stats["cache_hits"] += 1 return entry.state except Exception as e: logger.warning(f"Redis get failed: {e}") self._stats["cache_misses"] += 1 return None async def get_or_create(self, session_id: Optional[str] = None) -> tuple[str, ConversationState]: """Get existing session or create a new one. Args: session_id: Optional session ID. If None, generates a new one. Returns: Tuple of (session_id, ConversationState) """ if session_id: state = await self.get(session_id) if state: return session_id, state # Create new session new_id = session_id or self.generate_session_id() state = ConversationState() entry = SessionEntry(state, new_id) async with self._lock: self._sessions[new_id] = entry self._access_order.append(new_id) await self._evict_if_needed() # Store in Redis if available if self._redis_available and self._redis: try: key = f"{self.config.redis_prefix}{new_id}" await self._redis.setex( key, self.config.session_ttl, json.dumps(entry.to_dict()), ) except Exception as e: logger.warning(f"Redis set failed: {e}") self._stats["sessions_created"] += 1 logger.debug(f"Created new session: {new_id}") return new_id, state async def update(self, session_id: str, state: ConversationState) -> None: """Update session state. Args: session_id: The session identifier state: Updated conversation state """ async with self._lock: if session_id in self._sessions: self._sessions[session_id].state = state self._sessions[session_id].touch() self._update_access_order(session_id) else: # Create new entry entry = SessionEntry(state, session_id) self._sessions[session_id] = entry self._access_order.append(session_id) await self._evict_if_needed() # Update Redis if available if self._redis_available and self._redis: try: key = f"{self.config.redis_prefix}{session_id}" entry = self._sessions.get(session_id) if entry: await self._redis.setex( key, self.config.session_ttl, json.dumps(entry.to_dict()), ) except Exception as e: logger.warning(f"Redis update failed: {e}") async def delete(self, session_id: str) -> bool: """Delete a session. Args: session_id: The session identifier Returns: True if session was deleted, False if not found """ deleted = False async with self._lock: if session_id in self._sessions: del self._sessions[session_id] if session_id in self._access_order: self._access_order.remove(session_id) deleted = True if self._redis_available and self._redis: try: key = f"{self.config.redis_prefix}{session_id}" await self._redis.delete(key) deleted = True except Exception as e: logger.warning(f"Redis delete failed: {e}") return deleted def _update_access_order(self, session_id: str) -> None: """Move session to end of access order (most recently used).""" if session_id in self._access_order: self._access_order.remove(session_id) self._access_order.append(session_id) async def _evict_if_needed(self) -> None: """Evict oldest sessions if over limit (LRU).""" while len(self._sessions) > self.config.max_sessions: if not self._access_order: break oldest_id = self._access_order.pop(0) if oldest_id in self._sessions: del self._sessions[oldest_id] self._stats["sessions_evicted"] += 1 logger.debug(f"Evicted session (LRU): {oldest_id}") async def _cleanup_loop(self) -> None: """Background task to cleanup expired sessions.""" while self._running: try: await asyncio.sleep(self.config.cleanup_interval) await self._cleanup_expired() except asyncio.CancelledError: break except Exception as e: logger.exception(f"Cleanup error: {e}") async def _cleanup_expired(self) -> None: """Remove expired sessions from memory.""" expired_ids = [] async with self._lock: for session_id, entry in list(self._sessions.items()): if entry.is_expired(self.config.session_ttl): expired_ids.append(session_id) for session_id in expired_ids: del self._sessions[session_id] if session_id in self._access_order: self._access_order.remove(session_id) self._stats["sessions_expired"] += 1 if expired_ids: logger.debug(f"Cleaned up {len(expired_ids)} expired sessions") async def get_stats(self) -> dict[str, Any]: """Get session manager statistics.""" async with self._lock: active_count = len(self._sessions) return { **self._stats, "active_sessions": active_count, "redis_available": self._redis_available, } async def add_turn_to_session( self, session_id: str, question: str, answer: str, resolved_question: Optional[str] = None, template_id: Optional[str] = None, slots: Optional[dict[str, str]] = None, results: Optional[list[dict[str, Any]]] = None, ) -> None: """Add a conversation turn to a session. This is a convenience method that: 1. Gets the session state 2. Creates a user turn with the question 3. Creates an assistant turn with the answer 4. Updates the session Args: session_id: The session identifier question: User's question answer: Assistant's answer resolved_question: The resolved form of the question (if different) template_id: Template ID used for this query slots: Extracted slot values results: Query results for context """ state = await self.get(session_id) if not state: _, state = await self.get_or_create(session_id) # Add user turn user_turn = ConversationTurn( role="user", content=question, resolved_question=resolved_question, template_id=template_id, slots=slots or {}, results=results or [], ) state.add_turn(user_turn) # Add assistant turn assistant_turn = ConversationTurn( role="assistant", content=answer, ) state.add_turn(assistant_turn) await self.update(session_id, state) # ============================================================================= # SINGLETON INSTANCE # ============================================================================= _session_manager: Optional[SessionManager] = None _session_manager_lock = asyncio.Lock() async def get_session_manager(config: Optional[SessionConfig] = None) -> SessionManager: """Get or create the singleton session manager. Args: config: Optional configuration override Returns: SessionManager instance """ global _session_manager async with _session_manager_lock: if _session_manager is None: _session_manager = SessionManager(config) await _session_manager.start() return _session_manager async def shutdown_session_manager() -> None: """Shutdown the session manager (for graceful shutdown).""" global _session_manager async with _session_manager_lock: if _session_manager: await _session_manager.stop() _session_manager = None