glam/backend/postgres/main.py
2025-12-06 19:50:04 +01:00

372 lines
11 KiB
Python

"""
PostgreSQL REST API for Heritage Custodian Data
FastAPI backend providing SQL query interface for bronhouder.nl
Endpoints:
- GET / - Health check and statistics
- POST /query - Execute SQL query
- GET /tables - List all tables with metadata
- GET /schema/:table - Get table schema
"""
import json
import os
from contextlib import asynccontextmanager
from datetime import date, datetime, time
from typing import Any, Dict, List, Optional

import asyncpg
from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
# ============================================================================
# Configuration
# ============================================================================
class Settings(BaseModel):
    """Runtime configuration taken from environment variables.

    Every default is computed once, when this class body is executed at
    import time; later changes to the environment are not picked up.
    """

    # --- PostgreSQL connection ---
    host: str = Field(default=os.getenv("POSTGRES_HOST", "localhost"))
    port: int = Field(default=int(os.getenv("POSTGRES_PORT", "5432")))
    database: str = Field(default=os.getenv("POSTGRES_DB", "glam"))
    user: str = Field(default=os.getenv("POSTGRES_USER", "glam_api"))
    # NOTE(review): hard-coded fallback password committed to source; always
    # set POSTGRES_PASSWORD outside local development.
    password: str = Field(default=os.getenv("POSTGRES_PASSWORD", "glam_secret_2025"))

    # --- HTTP server (consumed by the __main__ entry point) ---
    api_host: str = Field(default=os.getenv("API_HOST", "0.0.0.0"))
    api_port: int = Field(default=int(os.getenv("API_PORT", "8001")))


settings = Settings()
# ============================================================================
# Pydantic Models
# ============================================================================
class QueryRequest(BaseModel):
    """Body of POST /query: SQL text plus optional positional parameters ($1, $2, ...)."""

    sql: str = Field(default=..., description="SQL query to execute")
    params: Optional[List[Any]] = Field(default=None, description="Query parameters")
class QueryResponse(BaseModel):
    """Result of POST /query: column names, row values, row count and timing."""

    columns: List[str] = Field(...)  # result column names, in SELECT order
    rows: List[List[Any]] = Field(...)  # one list of values per result row
    row_count: int = Field(...)  # number of entries in `rows`
    execution_time_ms: float = Field(...)  # elapsed time, rounded to 0.01 ms
class TableInfo(BaseModel):
    """One entry of GET /tables: a table and its size/shape metadata."""

    name: str = Field(...)
    schema_name: str = Field(...)
    row_count: int = Field(...)  # approximate, taken from pg_stat_user_tables
    column_count: int = Field(...)
    size_bytes: Optional[int] = Field(default=None)  # total on-disk size, if known
class ColumnInfo(BaseModel):
    """One entry of GET /schema/{table_name}: a single column's definition."""

    name: str = Field(...)
    data_type: str = Field(...)  # information_schema.columns.data_type
    is_nullable: bool = Field(...)
    default_value: Optional[str] = Field(default=None)  # default expression text, if any
    description: Optional[str] = Field(default=None)  # column comment, if any
class StatusResponse(BaseModel):
    """Payload of the GET / health check."""

    status: str = Field(...)
    database: str = Field(...)  # configured database name
    tables: int = Field(...)  # number of base tables
    total_rows: int = Field(...)  # approximate total row count
    uptime_seconds: float = Field(...)  # seconds since the module was imported
    postgres_version: str = Field(...)  # shortened version() banner
# ============================================================================
# Global State
# ============================================================================
_pool: Optional[asyncpg.Pool] = None  # shared pool, created lazily by get_pool()
_start_time: datetime = datetime.now()  # module import time; basis for uptime


async def get_pool() -> asyncpg.Pool:
    """Return the shared asyncpg pool, creating it on first call.

    The pool connects with the credentials in ``settings`` and keeps
    between 2 and 10 connections.
    """
    global _pool
    if _pool is not None:
        return _pool

    connect_kwargs = dict(
        host=settings.host,
        port=settings.port,
        database=settings.database,
        user=settings.user,
        password=settings.password,
    )
    _pool = await asyncpg.create_pool(**connect_kwargs, min_size=2, max_size=10)
    return _pool
# ============================================================================
# FastAPI App
# ============================================================================
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Open the connection pool at startup and close it at shutdown.

    The teardown sits in ``finally`` so the pool is closed even when the
    lifespan is exited with an exception, not only on a clean shutdown.
    """
    global _pool
    # Startup: create the pool eagerly so the first request doesn't pay for it
    await get_pool()
    try:
        yield
    finally:
        # Shutdown: release all pooled connections
        if _pool is not None:
            await _pool.close()
            _pool = None
app = FastAPI(
    title="PostgreSQL Heritage API",
    description="REST API for heritage institution SQL queries",
    version="1.0.0",
    lifespan=lifespan,
)

# CORS middleware.
# Allowed origins come from the comma-separated CORS_ORIGINS env var
# (default "*"). Credentials are enabled only for an explicit origin list:
# with a wildcard plus allow_credentials=True, Starlette echoes back the
# requesting origin for cookie-bearing requests, which grants credentialed
# cross-site access to every site.
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials="*" not in _cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ============================================================================
# Helper Functions
# ============================================================================
def serialize_value(val: Any) -> Any:
    """Convert a value returned by asyncpg into a JSON-friendly form.

    - ``None`` stays ``None``.
    - ``datetime``, ``date`` and ``time`` become ISO 8601 strings.
    - ``bytes`` (e.g. ``bytea``) are decoded as UTF-8; undecodable bytes are
      replaced with U+FFFD.
    - Lists (e.g. PostgreSQL arrays) and dicts are converted element-wise, so
      nested values get the same treatment.
    - Anything else is returned unchanged and left to the response encoder.
    """
    if val is None:
        return None
    # `date` also matches `datetime`, which is a subclass of it
    if isinstance(val, (date, time)):
        return val.isoformat()
    if isinstance(val, bytes):
        return val.decode('utf-8', errors='replace')
    if isinstance(val, list):
        return [serialize_value(item) for item in val]
    if isinstance(val, dict):
        return {key: serialize_value(item) for key, item in val.items()}
    return val
# ============================================================================
# API Endpoints
# ============================================================================
@app.get("/", response_model=StatusResponse)
async def get_status() -> StatusResponse:
    """Health check: confirm connectivity and report basic statistics.

    Both ``tables`` and ``total_rows`` cover only the ``public`` schema, so
    they describe the same set of tables. ``total_rows`` is the sum of
    PostgreSQL's live-row estimates (``n_live_tup``), the figure ``/stats``
    also uses; it is approximate.
    """
    pool = await get_pool()
    async with pool.acquire() as conn:
        # PostgreSQL version banner
        version = await conn.fetchval("SELECT version()")
        # Number of base tables in the public schema
        tables = await conn.fetchval("""
            SELECT COUNT(*) FROM information_schema.tables
            WHERE table_schema = 'public' AND table_type = 'BASE TABLE'
        """)
        # Approximate row total. n_live_tup is reset by TRUNCATE, whereas the
        # cumulative n_tup_ins - n_tup_del difference is not; the schemaname
        # filter keeps this aligned with the table count above.
        total_rows = await conn.fetchval("""
            SELECT COALESCE(SUM(n_live_tup), 0)::bigint
            FROM pg_stat_user_tables
            WHERE schemaname = 'public'
        """)
    uptime = (datetime.now() - _start_time).total_seconds()
    return StatusResponse(
        status="healthy",
        database=settings.database,
        tables=tables or 0,
        total_rows=total_rows or 0,
        uptime_seconds=uptime,
        # Keep only the banner text before its first comma
        postgres_version=version.split(',')[0] if version else "unknown",
    )
@app.post("/query", response_model=QueryResponse)
async def execute_query(request: QueryRequest) -> QueryResponse:
    """Execute a read-only SQL query and return its columns and rows.

    Safety is layered:

    1. A prefix check rejects statements that do not start with ``SELECT``
       or ``WITH`` (HTTP 403).
    2. The statement runs inside a ``READ ONLY`` transaction, so PostgreSQL
       itself refuses writes that pass the prefix check, e.g. data-modifying
       CTEs (``WITH d AS (DELETE ... RETURNING *) SELECT * FROM d``) or
       ``SELECT ... INTO``.

    NOTE(review): a read-only transaction does not restrict side-effecting
    functions (e.g. ``pg_terminate_backend``); the database role's
    privileges remain the real security boundary.

    Args:
        request: SQL text plus optional positional parameters ($1, $2, ...).

    Returns:
        QueryResponse with the column names (reported even when no rows
        match), JSON-friendly row values, the row count and the elapsed time.

    Raises:
        HTTPException: 403 for statements not starting with SELECT/WITH;
            400 for any error reported by PostgreSQL (syntax error,
            read-only violation, ...).
    """
    from time import perf_counter  # monotonic clock for timing

    pool = await get_pool()

    # First line of defense: only SELECT or WITH ... SELECT statements
    sql_upper = request.sql.strip().upper()
    if not sql_upper.startswith(("SELECT", "WITH")):
        raise HTTPException(
            status_code=403,
            detail="Only SELECT queries are allowed. Use WITH...SELECT for CTEs."
        )

    start = perf_counter()
    try:
        async with pool.acquire() as conn:
            # Second line of defense: the server rejects writes in this transaction
            async with conn.transaction(readonly=True):
                stmt = await conn.prepare(request.sql)
                # Column names come from the statement, so they are known
                # even when the query returns no rows
                columns = [attr.name for attr in stmt.get_attributes()]
                result = await stmt.fetch(*(request.params or []))
    except asyncpg.PostgresError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e

    # Read values positionally: looking them up by name would repeat the
    # first value when a query returns duplicate column names (SELECT a.id, b.id)
    rows = [[serialize_value(value) for value in record] for record in result]
    execution_time = (perf_counter() - start) * 1000
    return QueryResponse(
        columns=columns,
        rows=rows,
        row_count=len(rows),
        execution_time_ms=round(execution_time, 2),
    )
@app.get("/tables", response_model=List[TableInfo])
async def list_tables() -> List[TableInfo]:
    """List base tables in the ``public`` schema with basic metadata.

    ``row_count`` is PostgreSQL's live-row estimate (``n_live_tup``), the
    same figure ``/stats`` reports. The cumulative ``n_tup_ins - n_tup_del``
    difference is not used because TRUNCATE does not adjust it. Tables with
    no statistics entry report 0. ``size_bytes`` comes from
    ``pg_total_relation_size`` (table plus indexes and TOAST).
    """
    pool = await get_pool()
    async with pool.acquire() as conn:
        tables = await conn.fetch("""
            SELECT
                t.table_name,
                t.table_schema,
                (SELECT COUNT(*) FROM information_schema.columns c
                 WHERE c.table_name = t.table_name AND c.table_schema = t.table_schema) AS column_count,
                COALESCE(s.n_live_tup, 0) AS row_count,
                pg_total_relation_size(
                    (quote_ident(t.table_schema) || '.' || quote_ident(t.table_name))::regclass
                ) AS size_bytes
            FROM information_schema.tables t
            LEFT JOIN pg_stat_user_tables s
                ON s.schemaname = t.table_schema AND s.relname = t.table_name
            WHERE t.table_schema = 'public'
              AND t.table_type = 'BASE TABLE'
            ORDER BY t.table_name
        """)
    return [
        TableInfo(
            name=row['table_name'],
            schema_name=row['table_schema'],
            column_count=row['column_count'],
            row_count=row['row_count'] or 0,
            size_bytes=row['size_bytes'],
        )
        for row in tables
    ]
@app.get("/schema/{table_name}", response_model=List[ColumnInfo])
async def get_table_schema(table_name: str) -> List[ColumnInfo]:
    """Return the column definitions of ``public.<table_name>``.

    Columns come back in ordinal order, each with its data type, nullability,
    default expression and column comment (via ``col_description``).

    Raises:
        HTTPException: 404 when ``information_schema.tables`` has no entry
            named ``table_name`` in the ``public`` schema.
    """
    exists_query = """
        SELECT EXISTS (
            SELECT 1 FROM information_schema.tables
            WHERE table_schema = 'public' AND table_name = $1
        )
    """
    columns_query = """
        SELECT
            column_name,
            data_type,
            is_nullable,
            column_default,
            col_description(
                (quote_ident(table_schema) || '.' || quote_ident(table_name))::regclass,
                ordinal_position
            ) as description
        FROM information_schema.columns
        WHERE table_schema = 'public' AND table_name = $1
        ORDER BY ordinal_position
    """

    pool = await get_pool()
    async with pool.acquire() as conn:
        table_found = await conn.fetchval(exists_query, table_name)
        if not table_found:
            raise HTTPException(status_code=404, detail=f"Table '{table_name}' not found")
        column_records = await conn.fetch(columns_query, table_name)

    # information_schema reports nullability as the strings 'YES' / 'NO'
    return [
        ColumnInfo(
            name=rec['column_name'],
            data_type=rec['data_type'],
            is_nullable=(rec['is_nullable'] == 'YES'),
            default_value=rec['column_default'],
            description=rec['description'],
        )
        for rec in column_records
    ]
@app.get("/stats")
async def get_database_stats() -> Dict[str, Any]:
    """Report the database's size and its ten largest user tables.

    Returns a dict with ``database`` (configured name), ``size``
    (``pg_size_pretty`` of the whole database) and ``largest_tables``: up to
    ten ``{"name", "size", "rows"}`` entries ordered by
    ``pg_total_relation_size`` descending, where ``rows`` is the
    ``n_live_tup`` estimate from ``pg_stat_user_tables``.
    """
    size_query = """
        SELECT pg_size_pretty(pg_database_size($1))
    """
    largest_query = """
        SELECT
            relname as table_name,
            pg_size_pretty(pg_total_relation_size(relid)) as total_size,
            n_live_tup as row_count
        FROM pg_stat_user_tables
        ORDER BY pg_total_relation_size(relid) DESC
        LIMIT 10
    """

    pool = await get_pool()
    async with pool.acquire() as conn:
        pretty_size = await conn.fetchval(size_query, settings.database)
        table_records = await conn.fetch(largest_query)

    largest_tables = [
        {"name": rec['table_name'], "size": rec['total_size'], "rows": rec['row_count']}
        for rec in table_records
    ]
    return {
        "database": settings.database,
        "size": pretty_size,
        "largest_tables": largest_tables,
    }
# ============================================================================
# Main Entry Point
# ============================================================================
if __name__ == "__main__":
    # Local development runner; auto-reload is always on here.
    import uvicorn

    run_options = {
        "host": settings.api_host,
        "port": settings.api_port,
        "reload": True,
    }
    uvicorn.run("main:app", **run_options)