566 lines
19 KiB
Python
566 lines
19 KiB
Python
"""
Session Manager for Multi-Turn Conversation State

Persists ConversationState objects across HTTP requests to support
elliptical follow-ups like "En in Enschede?" resolved via context.

Architecture:
- In-memory storage (primary) - fast access, no dependencies
- Optional Redis/Valkey backend - for distributed deployments
- TTL-based cleanup for inactive sessions

Usage:
    session_manager = await get_session_manager()

    # Get or create session (returns the session ID and its state)
    session_id, state = await session_manager.get_or_create(session_id)

    # Update after query
    await session_manager.update(session_id, state)

    # Session auto-expires after TTL

Author: OpenCode
Created: 2025-01-06
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import uuid
|
|
from datetime import datetime, timezone
|
|
from typing import TYPE_CHECKING, Any, Literal, Optional
|
|
|
|
from pydantic import BaseModel, Field
|
|
|
|
# Import DSPy History for to_dspy_history() conversion
|
|
try:
|
|
from dspy import History
|
|
DSPY_AVAILABLE = True
|
|
except ImportError:
|
|
DSPY_AVAILABLE = False
|
|
History = None # type: ignore
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
# =============================================================================
|
|
# CONVERSATION STATE MODELS (always defined for type checking)
|
|
# =============================================================================
|
|
|
|
class ConversationTurn(BaseModel):
    """A single turn in conversation history.

    Holds the text of one message together with the optional template and
    slot metadata that the caller attached to it.
    """

    # Speaker of this turn; restricted to "user" or "assistant".
    role: Literal["user", "assistant"] = "user"
    # Text of the turn as supplied by the caller.
    content: str = ""
    # Optional resolved form of the question; ConversationState.to_dspy_history()
    # uses it in place of `content` when it is set.
    resolved_question: Optional[str] = None
    # Identifier of the template associated with this turn, if any.
    template_id: Optional[str] = None
    # Slot name -> slot value pairs for this turn (empty by default).
    slots: dict[str, str] = Field(default_factory=dict)
    # Result records attached to this turn (empty by default).
    results: list[dict[str, Any]] = Field(default_factory=list)
|
|
|
|
|
|
class ConversationState(BaseModel):
    """Accumulated state of one conversation.

    Keeps the ordered list of turns plus the slot values and template ID
    carried forward from earlier turns.
    """

    # Every turn recorded so far, oldest first.
    turns: list[ConversationTurn] = Field(default_factory=list)
    # Slot values merged from user turns; later turns overwrite earlier keys.
    current_slots: dict[str, str] = Field(default_factory=dict)
    # Template ID of the most recent turn that carried one.
    current_template_id: Optional[str] = None
    # Language code of the conversation; defaults to "nl".
    language: str = "nl"

    def add_turn(self, turn: ConversationTurn) -> None:
        """Append ``turn`` and fold its metadata into the running state.

        Slots are merged only for user turns; a template ID is adopted from
        any turn that provides one.
        """
        self.turns.append(turn)

        carries_user_slots = turn.role == "user" and bool(turn.slots)
        if carries_user_slots:
            # Inherit slots from user turns
            self.current_slots.update(turn.slots)

        if turn.template_id:
            self.current_template_id = turn.template_id

    def get_previous_user_turn(self) -> Optional[ConversationTurn]:
        """Return the latest turn whose role is "user", or None if there is none."""
        newest_first = reversed(self.turns)
        return next(
            (candidate for candidate in newest_first if candidate.role == "user"),
            None,
        )

    def to_dspy_history(self) -> Any:
        """Build a DSPy ``History`` from the most recent turns.

        Each message carries the turn's role and its resolved question when
        present, otherwise its raw content. Only the last 6 turns are
        included to keep the context small.

        Returns:
            A ``History`` instance, or the plain dict ``{"messages": []}``
            when DSPy could not be imported.
        """
        if DSPY_AVAILABLE and History is not None:
            recent_turns = self.turns[-6:]  # Keep last 6 turns for context
            return History(
                messages=[
                    {
                        "role": item.role,
                        "content": item.resolved_question or item.content,
                    }
                    for item in recent_turns
                ]
            )

        # DSPy missing: hand back a minimal dict-shaped placeholder instead.
        logger.warning("DSPy not available, returning empty history")
        return {"messages": []}
|
|
|
|
|
|
# =============================================================================
|
|
# CONFIGURATION
|
|
# =============================================================================
|
|
|
|
class SessionConfig(BaseModel):
    """Session management configuration.

    Groups the tunables read by SessionManager: expiry, cleanup cadence,
    the in-memory size cap, and the optional Redis/Valkey connection.
    """

    # Session TTL (seconds) - default 30 minutes.
    # Compared against the time since a session was last accessed and also
    # used as the expiry of the Redis key.
    session_ttl: int = 1800

    # Cleanup interval (seconds) - how often to purge expired sessions
    cleanup_interval: int = 300  # 5 minutes

    # Maximum sessions to keep in memory (LRU eviction if exceeded)
    max_sessions: int = 10000

    # Redis/Valkey settings (optional)
    # When False, no Redis connection is attempted at start-up.
    redis_enabled: bool = False
    # Connection URL passed to the Redis client factory.
    redis_url: str = "redis://localhost:6379"
    # Prefix prepended to the session ID to form the Redis key.
    redis_prefix: str = "heritage:session:"
    redis_db: int = 1  # Use different DB than cache to avoid conflicts
|
|
|
|
|
|
# =============================================================================
|
|
# SESSION STORAGE
|
|
# =============================================================================
|
|
|
|
class SessionEntry:
    """One stored session: its state plus access bookkeeping."""

    @staticmethod
    def _utcnow() -> datetime:
        """Return the current time as a timezone-aware UTC datetime."""
        return datetime.now(tz=timezone.utc)

    def __init__(self, state: ConversationState, session_id: str):
        self.session_id = session_id
        self.state = state
        # Both timestamps start at (approximately) the moment of creation.
        self.created_at = self._utcnow()
        self.last_accessed = self._utcnow()
        self.access_count = 0

    def touch(self) -> None:
        """Record an access: refresh the timestamp and bump the counter."""
        self.last_accessed = self._utcnow()
        self.access_count += 1

    def is_expired(self, ttl_seconds: int) -> bool:
        """Return True when more than ``ttl_seconds`` passed since the last access."""
        idle = self._utcnow() - self.last_accessed
        return idle.total_seconds() > ttl_seconds

    def to_dict(self) -> dict[str, Any]:
        """Produce a JSON-friendly dict of this entry (used for Redis storage).

        Timestamps are written as ISO-8601 strings and the state is dumped
        through its ``model_dump()`` method.
        """
        state_payload = self.state.model_dump()
        created_iso = self.created_at.isoformat()
        accessed_iso = self.last_accessed.isoformat()
        return {
            "session_id": self.session_id,
            "state": state_payload,
            "created_at": created_iso,
            "last_accessed": accessed_iso,
            "access_count": self.access_count,
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "SessionEntry":
        """Rebuild an entry from the dict layout produced by ``to_dict``.

        ``access_count`` falls back to 0 when the key is absent.
        """
        restored = cls(ConversationState(**data["state"]), data["session_id"])
        # Overwrite the fresh timestamps set by __init__ with the stored ones.
        for stamp_name in ("created_at", "last_accessed"):
            setattr(restored, stamp_name, datetime.fromisoformat(data[stamp_name]))
        restored.access_count = data.get("access_count", 0)
        return restored
|
|
|
|
|
|
class SessionManager:
    """Manages conversation sessions across requests.

    Features:
    - In-memory storage with LRU eviction
    - Optional Redis/Valkey backend for persistence
    - TTL-based automatic cleanup
    - In-memory store guarded by an ``asyncio.Lock``
    """

    def __init__(self, config: Optional[SessionConfig] = None):
        """Create a manager; no I/O or background work happens until ``start()``.

        Args:
            config: Settings to use. A default ``SessionConfig`` is created
                when omitted.
        """
        self.config = config or SessionConfig()

        # In-memory storage
        self._sessions: dict[str, SessionEntry] = {}
        # Session IDs ordered from least to most recently used (LRU eviction).
        self._access_order: list[str] = []
        self._lock = asyncio.Lock()

        # Redis client (lazy initialized)
        self._redis: Any = None
        self._redis_available = False

        # Cleanup task
        self._cleanup_task: Optional[asyncio.Task[None]] = None
        self._running = False

        # Stats
        self._stats = {
            "sessions_created": 0,
            "sessions_expired": 0,
            "sessions_evicted": 0,
            "cache_hits": 0,
            "cache_misses": 0,
        }

    async def start(self) -> None:
        """Start the session manager and cleanup task.

        Calling it again while already running is a no-op.
        """
        if self._running:
            return

        self._running = True

        # Try to connect to Redis if enabled
        if self.config.redis_enabled:
            await self._connect_redis()

        # Start cleanup task
        self._cleanup_task = asyncio.create_task(self._cleanup_loop())
        logger.info(
            f"SessionManager started (ttl={self.config.session_ttl}s, "
            f"max={self.config.max_sessions}, redis={self._redis_available})"
        )

    async def stop(self) -> None:
        """Stop the cleanup task and close the Redis connection, if any."""
        self._running = False

        if self._cleanup_task:
            self._cleanup_task.cancel()
            try:
                await self._cleanup_task
            except asyncio.CancelledError:
                pass
            # Drop the finished task so a later start() begins from a clean slate.
            self._cleanup_task = None

        if self._redis:
            await self._redis.close()
            self._redis = None
        # Keep the flag consistent with the (now absent) client.
        self._redis_available = False

        logger.info(f"SessionManager stopped (stats: {self._stats})")

    async def _connect_redis(self) -> None:
        """Try to connect to Redis/Valkey.

        Any failure (including the client library being missing) is logged
        and leaves the manager in memory-only mode.
        """
        try:
            # Imported here so the redis package is only needed when enabled.
            import redis.asyncio as redis_async

            self._redis = await redis_async.from_url(
                self.config.redis_url,
                db=self.config.redis_db,
                decode_responses=True,
            )
            await self._redis.ping()
            self._redis_available = True
            logger.info(f"SessionManager connected to Redis: {self.config.redis_url}")
        except Exception as e:
            logger.warning(f"Redis not available, using memory-only: {e}")
            self._redis = None
            self._redis_available = False

    def generate_session_id(self) -> str:
        """Generate a new unique session ID (a random UUID4 string)."""
        return str(uuid.uuid4())

    def _discard_locked(self, session_id: str) -> None:
        """Remove a session from memory and from the LRU order.

        The caller must already hold ``self._lock``.
        """
        self._sessions.pop(session_id, None)
        if session_id in self._access_order:
            self._access_order.remove(session_id)

    async def get(self, session_id: str) -> Optional[ConversationState]:
        """Get session state by ID.

        Looks in memory first, then in Redis when it is available. A hit
        refreshes the session's last-access time.

        Args:
            session_id: The session identifier

        Returns:
            ConversationState if found and not expired, None otherwise
        """
        # Try memory first
        async with self._lock:
            entry = self._sessions.get(session_id)
            if entry is not None:
                if not entry.is_expired(self.config.session_ttl):
                    entry.touch()
                    self._update_access_order(session_id)
                    self._stats["cache_hits"] += 1
                    return entry.state

                # Expired, remove it
                self._discard_locked(session_id)
                self._stats["sessions_expired"] += 1

        # Try Redis if available
        if self._redis_available and self._redis:
            try:
                key = f"{self.config.redis_prefix}{session_id}"
                data = await self._redis.get(key)
                if data:
                    entry = SessionEntry.from_dict(json.loads(data))
                    entry.touch()

                    # Store in memory for faster access
                    async with self._lock:
                        self._sessions[session_id] = entry
                        # De-duplicating insert: a repeated ID in the LRU
                        # list could otherwise evict a live session.
                        self._update_access_order(session_id)
                        await self._evict_if_needed()

                    # Update Redis TTL
                    await self._redis.expire(key, self.config.session_ttl)

                    self._stats["cache_hits"] += 1
                    return entry.state
            except Exception as e:
                # Best effort: a Redis failure degrades to a cache miss.
                logger.warning(f"Redis get failed: {e}")

        self._stats["cache_misses"] += 1
        return None

    async def get_or_create(self, session_id: Optional[str] = None) -> tuple[str, ConversationState]:
        """Get existing session or create a new one.

        Args:
            session_id: Optional session ID. If None, generates a new one.

        Returns:
            Tuple of (session_id, ConversationState)
        """
        if session_id:
            state = await self.get(session_id)
            if state is not None:
                return session_id, state

        # Create new session
        new_id = session_id or self.generate_session_id()
        state = ConversationState()

        entry = SessionEntry(state, new_id)

        async with self._lock:
            self._sessions[new_id] = entry
            # De-duplicating insert keeps the LRU list free of repeated IDs.
            self._update_access_order(new_id)
            await self._evict_if_needed()

        # Store in Redis if available
        if self._redis_available and self._redis:
            try:
                key = f"{self.config.redis_prefix}{new_id}"
                await self._redis.setex(
                    key,
                    self.config.session_ttl,
                    json.dumps(entry.to_dict()),
                )
            except Exception as e:
                logger.warning(f"Redis set failed: {e}")

        self._stats["sessions_created"] += 1
        logger.debug(f"Created new session: {new_id}")

        return new_id, state

    async def update(self, session_id: str, state: ConversationState) -> None:
        """Update session state, creating the session if it is not in memory.

        Args:
            session_id: The session identifier
            state: Updated conversation state
        """
        async with self._lock:
            entry = self._sessions.get(session_id)
            if entry is None:
                # Create new entry
                entry = SessionEntry(state, session_id)
                self._sessions[session_id] = entry
                self._update_access_order(session_id)
                await self._evict_if_needed()
            else:
                entry.state = state
                entry.touch()
                self._update_access_order(session_id)

        # Update Redis if available. The entry captured under the lock is
        # written even if it was evicted from memory in the meantime, so the
        # update is never silently dropped.
        if self._redis_available and self._redis:
            try:
                key = f"{self.config.redis_prefix}{session_id}"
                await self._redis.setex(
                    key,
                    self.config.session_ttl,
                    json.dumps(entry.to_dict()),
                )
            except Exception as e:
                logger.warning(f"Redis update failed: {e}")

    async def delete(self, session_id: str) -> bool:
        """Delete a session from memory and from Redis.

        Args:
            session_id: The session identifier

        Returns:
            True if session was deleted, False if not found
        """
        deleted = False

        async with self._lock:
            if session_id in self._sessions:
                self._discard_locked(session_id)
                deleted = True

        if self._redis_available and self._redis:
            try:
                key = f"{self.config.redis_prefix}{session_id}"
                # DEL reports how many keys it removed; only count a real hit.
                removed = await self._redis.delete(key)
                if removed:
                    deleted = True
            except Exception as e:
                logger.warning(f"Redis delete failed: {e}")

        return deleted

    def _update_access_order(self, session_id: str) -> None:
        """Move session to end of access order (most recently used).

        Inserts the ID when it is not tracked yet, so the list never holds
        the same ID twice.
        """
        if session_id in self._access_order:
            self._access_order.remove(session_id)
        self._access_order.append(session_id)

    async def _evict_if_needed(self) -> None:
        """Evict oldest sessions if over limit (LRU).

        The caller must already hold ``self._lock``.
        """
        while len(self._sessions) > self.config.max_sessions:
            if not self._access_order:
                break
            oldest_id = self._access_order.pop(0)
            if oldest_id in self._sessions:
                del self._sessions[oldest_id]
                self._stats["sessions_evicted"] += 1
                logger.debug(f"Evicted session (LRU): {oldest_id}")

    async def _cleanup_loop(self) -> None:
        """Background task to cleanup expired sessions.

        Sleeps for the configured interval between sweeps; unexpected errors
        are logged and the loop keeps running.
        """
        while self._running:
            try:
                await asyncio.sleep(self.config.cleanup_interval)
                await self._cleanup_expired()
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.exception(f"Cleanup error: {e}")

    async def _cleanup_expired(self) -> None:
        """Remove expired sessions from memory."""
        async with self._lock:
            expired_ids = [
                session_id
                for session_id, entry in self._sessions.items()
                if entry.is_expired(self.config.session_ttl)
            ]

            for session_id in expired_ids:
                self._discard_locked(session_id)
                self._stats["sessions_expired"] += 1

        if expired_ids:
            logger.debug(f"Cleaned up {len(expired_ids)} expired sessions")

    async def get_stats(self) -> dict[str, Any]:
        """Get session manager statistics.

        Returns:
            The counters plus the number of in-memory sessions and the
            Redis availability flag.
        """
        async with self._lock:
            active_count = len(self._sessions)

        return {
            **self._stats,
            "active_sessions": active_count,
            "redis_available": self._redis_available,
        }

    async def add_turn_to_session(
        self,
        session_id: str,
        question: str,
        answer: str,
        resolved_question: Optional[str] = None,
        template_id: Optional[str] = None,
        slots: Optional[dict[str, str]] = None,
        results: Optional[list[dict[str, Any]]] = None,
    ) -> None:
        """Add a conversation turn to a session.

        This is a convenience method that:
        1. Gets the session state (creating the session when missing)
        2. Creates a user turn with the question
        3. Creates an assistant turn with the answer
        4. Updates the session

        Args:
            session_id: The session identifier
            question: User's question
            answer: Assistant's answer
            resolved_question: The resolved form of the question (if different)
            template_id: Template ID used for this query
            slots: Extracted slot values
            results: Query results for context
        """
        # One lookup only: a separate get() followed by get_or_create()
        # would record a missing session as two cache misses.
        _, state = await self.get_or_create(session_id)

        # Add user turn
        user_turn = ConversationTurn(
            role="user",
            content=question,
            resolved_question=resolved_question,
            template_id=template_id,
            slots=slots or {},
            results=results or [],
        )
        state.add_turn(user_turn)

        # Add assistant turn
        assistant_turn = ConversationTurn(
            role="assistant",
            content=answer,
        )
        state.add_turn(assistant_turn)

        await self.update(session_id, state)
|
|
|
|
|
|
# =============================================================================
|
|
# SINGLETON INSTANCE
|
|
# =============================================================================
|
|
|
|
# Module-wide SessionManager instance; None until get_session_manager()
# creates it, and set back to None by shutdown_session_manager().
_session_manager: Optional[SessionManager] = None
# Serializes creation and shutdown of the shared instance.
# NOTE(review): this lock is constructed at import time; on Python versions
# before 3.10 an asyncio.Lock is tied to the event loop that is current at
# construction - confirm the supported runtime version.
_session_manager_lock = asyncio.Lock()
|
|
|
|
|
|
async def get_session_manager(config: Optional[SessionConfig] = None) -> SessionManager:
    """Get or create the singleton session manager.

    The first call builds and starts the manager; later calls return that
    same instance.

    Args:
        config: Optional configuration override. Only honoured by the call
            that creates the instance; ignored afterwards.

    Returns:
        SessionManager instance
    """
    global _session_manager

    async with _session_manager_lock:
        if _session_manager is None:
            # Publish the instance only after start() has succeeded, so a
            # failed start never leaves a half-initialized singleton behind.
            manager = SessionManager(config)
            await manager.start()
            _session_manager = manager
        elif config is not None:
            logger.debug("get_session_manager: config ignored, instance already exists")
        return _session_manager
|
|
|
|
|
|
async def shutdown_session_manager() -> None:
    """Shutdown the session manager (for graceful shutdown).

    Does nothing when no manager has been created. The singleton reference
    is cleared even if stopping raises, so the next get_session_manager()
    call builds a fresh instance instead of returning a stopped one.
    """
    global _session_manager

    async with _session_manager_lock:
        # Detach first; stop() may raise and must not leave a stale instance.
        manager, _session_manager = _session_manager, None
        if manager is not None:
            await manager.stop()
|