372 lines
11 KiB
Python
372 lines
11 KiB
Python
"""
|
|
PostgreSQL REST API for Heritage Custodian Data
|
|
FastAPI backend providing SQL query interface for bronhouder.nl
|
|
|
|
Endpoints:
|
|
- GET / - Health check and statistics
|
|
- POST /query - Execute SQL query
|
|
- GET /tables - List all tables with metadata
|
|
- GET /schema/:table - Get table schema
|
|
"""
|
|
|
|
import os
import json
from contextlib import asynccontextmanager
from datetime import date, datetime, time
from time import perf_counter
from typing import Optional, List, Dict, Any

from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
import asyncpg
|
|
|
|
|
|
# ============================================================================
|
|
# Configuration
|
|
# ============================================================================
|
|
|
|
class Settings(BaseModel):
    """Runtime configuration taken from environment variables.

    Every default below is evaluated once, when this class body runs at import
    time, so changing the environment afterwards does not affect ``settings``.
    """

    # --- PostgreSQL connection ---------------------------------------------
    host: str = os.environ.get("POSTGRES_HOST", "localhost")
    port: int = int(os.environ.get("POSTGRES_PORT", "5432"))
    database: str = os.environ.get("POSTGRES_DB", "glam")
    user: str = os.environ.get("POSTGRES_USER", "glam_api")
    # NOTE(review): hard-coded fallback password; set POSTGRES_PASSWORD in any
    # real deployment instead of relying on this default.
    password: str = os.environ.get("POSTGRES_PASSWORD", "glam_secret_2025")

    # --- HTTP server (consumed by the __main__ entry point) -----------------
    api_host: str = os.environ.get("API_HOST", "0.0.0.0")
    api_port: int = int(os.environ.get("API_PORT", "8001"))


# Single module-wide settings instance used by the pool factory and endpoints.
settings = Settings()
|
|
|
|
|
|
# ============================================================================
|
|
# Pydantic Models
|
|
# ============================================================================
|
|
|
|
class QueryRequest(BaseModel):
    """Body of ``POST /query``: a SQL statement plus optional bind parameters."""

    # SQL text to execute; execute_query only accepts statements that start
    # with SELECT or WITH.
    sql: str = Field(..., description="SQL query to execute")

    # Positional values handed to asyncpg together with ``sql`` (presumably
    # bound to $1, $2, ... placeholders). None or an empty list means the
    # query is executed without parameters.
    params: Optional[List[Any]] = Field(None, description="Query parameters")
|
|
|
|
|
|
class QueryResponse(BaseModel):
    """Result set returned by ``POST /query``."""

    # Column names of the result set.
    columns: List[str]

    # One inner list per result row; values have been passed through
    # serialize_value so they are JSON-friendly.
    rows: List[List[Any]]

    # Number of entries in ``rows``.
    row_count: int

    # Elapsed time measured around query execution, in milliseconds
    # (execute_query rounds it to two decimals).
    execution_time_ms: float
|
|
|
|
|
|
class TableInfo(BaseModel):
    """Metadata for one table, as returned by ``GET /tables``."""

    # Unqualified table name.
    name: str

    # Schema containing the table (list_tables only queries 'public').
    schema_name: str

    # Approximate row count derived from PostgreSQL's table statistics,
    # not an exact COUNT(*).
    row_count: int

    # Number of columns listed in information_schema.columns for the table.
    column_count: int

    # Total relation size in bytes as reported by pg_total_relation_size;
    # optional in the schema.
    size_bytes: Optional[int] = None
|
|
|
|
|
|
class ColumnInfo(BaseModel):
    """Metadata for one column, as returned by ``GET /schema/{table_name}``."""

    # Column name.
    name: str

    # Type name exactly as reported by information_schema.columns.data_type.
    data_type: str

    # True when the column accepts NULL (information_schema reports 'YES').
    is_nullable: bool

    # Default expression as SQL text, or None when the column has no default.
    default_value: Optional[str] = None

    # Column comment obtained via col_description(), or None when unset.
    description: Optional[str] = None
|
|
|
|
|
|
class StatusResponse(BaseModel):
    """Health-check payload returned by ``GET /``."""

    # Health indicator; get_status reports "healthy" whenever it responds.
    status: str

    # Name of the configured database (settings.database).
    database: str

    # Number of base tables in the 'public' schema.
    tables: int

    # Approximate total row count taken from PostgreSQL statistics.
    total_rows: int

    # Seconds elapsed since the module-level start timestamp.
    uptime_seconds: float

    # First comma-separated segment of SELECT version().
    postgres_version: str
|
|
|
|
|
|
# ============================================================================
|
|
# Global State
|
|
# ============================================================================
|
|
|
|
# Connection pool shared by every request handler; created lazily by get_pool().
_pool: Optional[asyncpg.Pool] = None
# Captured once at import time; get_status reports uptime relative to it.
_start_time: datetime = datetime.now()


async def get_pool() -> asyncpg.Pool:
    """Return the shared asyncpg pool, creating it on the first call.

    The pool holds between 2 and 10 connections and is configured from the
    module-level ``settings``.

    NOTE(review): two coroutines hitting the very first call concurrently
    could each create a pool; lifespan() creates it at startup, which
    presumably avoids that in practice.
    """
    global _pool

    if _pool is not None:
        return _pool

    _pool = await asyncpg.create_pool(
        host=settings.host,
        port=settings.port,
        database=settings.database,
        user=settings.user,
        password=settings.password,
        min_size=2,
        max_size=10,
    )
    return _pool
|
|
|
|
|
|
# ============================================================================
|
|
# FastAPI App
|
|
# ============================================================================
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Open the connection pool on startup and close it on shutdown.

    The ``try/finally`` ensures the pool is closed even when an exception
    (for example a cancellation) is thrown into this generator at ``yield``;
    without it the shutdown code after ``yield`` would be skipped.
    """
    global _pool

    # Startup: create the shared pool eagerly so the first request is fast.
    await get_pool()
    try:
        yield
    finally:
        # Shutdown: release all pooled connections and reset the global so a
        # later get_pool() call would build a fresh pool.
        if _pool is not None:
            await _pool.close()
            _pool = None
|
|
|
|
|
|
app = FastAPI(
    title="PostgreSQL Heritage API",
    description="REST API for heritage institution SQL queries",
    version="1.0.0",
    lifespan=lifespan,
)

# CORS middleware.
# Allowed origins are read from the comma-separated CORS_ORIGINS environment
# variable (e.g. "https://bronhouder.nl,https://www.bronhouder.nl"); the
# default "*" allows any origin.
# NOTE(review): with allow_credentials=True and a wildcard origin, Starlette
# may answer by echoing the caller's Origin header, effectively allowing
# credentialed requests from any site — verify against the installed
# Starlette version and restrict CORS_ORIGINS in production.
cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ORIGINS", "*").split(",")
    if origin.strip()
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
|
# ============================================================================
|
|
# Helper Functions
|
|
# ============================================================================
|
|
|
|
def serialize_value(val: Any) -> Any:
    """Convert a value returned by asyncpg into a JSON-friendly Python value.

    - ``date``, ``time`` and ``datetime`` become ISO-8601 strings
      (``datetime`` is a subclass of ``date``, so one check covers both).
    - ``bytes``-like values are decoded as UTF-8, replacing invalid bytes
      with U+FFFD.
    - Lists (e.g. array columns) and dicts are converted element-wise, so
      nested dates or bytes are handled too.
    - Everything else, including ``None``, is returned unchanged.

    Args:
        val: A single column value from an asyncpg ``Record``.

    Returns:
        The converted value.
    """
    if val is None:
        return None
    if isinstance(val, (date, time)):
        return val.isoformat()
    if isinstance(val, (bytes, bytearray, memoryview)):
        return bytes(val).decode('utf-8', errors='replace')
    if isinstance(val, list):
        return [serialize_value(item) for item in val]
    if isinstance(val, dict):
        return {key: serialize_value(item) for key, item in val.items()}
    return val
|
|
|
|
|
|
# ============================================================================
|
|
# API Endpoints
|
|
# ============================================================================
|
|
|
|
@app.get("/", response_model=StatusResponse)
|
|
async def get_status() -> StatusResponse:
|
|
"""Get server status and statistics"""
|
|
pool = await get_pool()
|
|
|
|
async with pool.acquire() as conn:
|
|
# Get PostgreSQL version
|
|
version = await conn.fetchval("SELECT version()")
|
|
|
|
# Get table count
|
|
tables = await conn.fetchval("""
|
|
SELECT COUNT(*) FROM information_schema.tables
|
|
WHERE table_schema = 'public' AND table_type = 'BASE TABLE'
|
|
""")
|
|
|
|
# Get total row count (approximate)
|
|
total_rows = await conn.fetchval("""
|
|
SELECT COALESCE(SUM(n_tup_ins - n_tup_del), 0)::bigint
|
|
FROM pg_stat_user_tables
|
|
""")
|
|
|
|
uptime = (datetime.now() - _start_time).total_seconds()
|
|
|
|
return StatusResponse(
|
|
status="healthy",
|
|
database=settings.database,
|
|
tables=tables or 0,
|
|
total_rows=total_rows or 0,
|
|
uptime_seconds=uptime,
|
|
postgres_version=version.split(',')[0] if version else "unknown",
|
|
)
|
|
|
|
|
|
@app.post("/query", response_model=QueryResponse)
|
|
async def execute_query(request: QueryRequest) -> QueryResponse:
|
|
"""Execute a SQL query (read-only)"""
|
|
pool = await get_pool()
|
|
|
|
# Security: Only allow SELECT queries for now
|
|
sql_upper = request.sql.strip().upper()
|
|
if not sql_upper.startswith("SELECT") and not sql_upper.startswith("WITH"):
|
|
raise HTTPException(
|
|
status_code=403,
|
|
detail="Only SELECT queries are allowed. Use WITH...SELECT for CTEs."
|
|
)
|
|
|
|
start_time = datetime.now()
|
|
|
|
try:
|
|
async with pool.acquire() as conn:
|
|
if request.params:
|
|
result = await conn.fetch(request.sql, *request.params)
|
|
else:
|
|
result = await conn.fetch(request.sql)
|
|
|
|
if result:
|
|
columns = list(result[0].keys())
|
|
rows = [[serialize_value(row[col]) for col in columns] for row in result]
|
|
else:
|
|
columns = []
|
|
rows = []
|
|
|
|
execution_time = (datetime.now() - start_time).total_seconds() * 1000
|
|
|
|
return QueryResponse(
|
|
columns=columns,
|
|
rows=rows,
|
|
row_count=len(rows),
|
|
execution_time_ms=round(execution_time, 2),
|
|
)
|
|
|
|
except asyncpg.PostgresError as e:
|
|
raise HTTPException(status_code=400, detail=str(e))
|
|
|
|
|
|
@app.get("/tables", response_model=List[TableInfo])
|
|
async def list_tables() -> List[TableInfo]:
|
|
"""List all tables with metadata"""
|
|
pool = await get_pool()
|
|
|
|
async with pool.acquire() as conn:
|
|
tables = await conn.fetch("""
|
|
SELECT
|
|
t.table_name,
|
|
t.table_schema,
|
|
(SELECT COUNT(*) FROM information_schema.columns c
|
|
WHERE c.table_name = t.table_name AND c.table_schema = t.table_schema) as column_count,
|
|
COALESCE(s.n_tup_ins - s.n_tup_del, 0) as row_count,
|
|
pg_total_relation_size(quote_ident(t.table_schema) || '.' || quote_ident(t.table_name)) as size_bytes
|
|
FROM information_schema.tables t
|
|
LEFT JOIN pg_stat_user_tables s
|
|
ON s.schemaname = t.table_schema AND s.relname = t.table_name
|
|
WHERE t.table_schema = 'public'
|
|
AND t.table_type = 'BASE TABLE'
|
|
ORDER BY t.table_name
|
|
""")
|
|
|
|
return [
|
|
TableInfo(
|
|
name=row['table_name'],
|
|
schema_name=row['table_schema'],
|
|
column_count=row['column_count'],
|
|
row_count=row['row_count'] or 0,
|
|
size_bytes=row['size_bytes'],
|
|
)
|
|
for row in tables
|
|
]
|
|
|
|
|
|
@app.get("/schema/{table_name}", response_model=List[ColumnInfo])
|
|
async def get_table_schema(table_name: str) -> List[ColumnInfo]:
|
|
"""Get schema for a specific table"""
|
|
pool = await get_pool()
|
|
|
|
async with pool.acquire() as conn:
|
|
# Check table exists
|
|
exists = await conn.fetchval("""
|
|
SELECT EXISTS (
|
|
SELECT 1 FROM information_schema.tables
|
|
WHERE table_schema = 'public' AND table_name = $1
|
|
)
|
|
""", table_name)
|
|
|
|
if not exists:
|
|
raise HTTPException(status_code=404, detail=f"Table '{table_name}' not found")
|
|
|
|
columns = await conn.fetch("""
|
|
SELECT
|
|
column_name,
|
|
data_type,
|
|
is_nullable,
|
|
column_default,
|
|
col_description(
|
|
(quote_ident(table_schema) || '.' || quote_ident(table_name))::regclass,
|
|
ordinal_position
|
|
) as description
|
|
FROM information_schema.columns
|
|
WHERE table_schema = 'public' AND table_name = $1
|
|
ORDER BY ordinal_position
|
|
""", table_name)
|
|
|
|
return [
|
|
ColumnInfo(
|
|
name=col['column_name'],
|
|
data_type=col['data_type'],
|
|
is_nullable=col['is_nullable'] == 'YES',
|
|
default_value=col['column_default'],
|
|
description=col['description'],
|
|
)
|
|
for col in columns
|
|
]
|
|
|
|
|
|
@app.get("/stats")
|
|
async def get_database_stats() -> Dict[str, Any]:
|
|
"""Get detailed database statistics"""
|
|
pool = await get_pool()
|
|
|
|
async with pool.acquire() as conn:
|
|
# Database size
|
|
db_size = await conn.fetchval("""
|
|
SELECT pg_size_pretty(pg_database_size($1))
|
|
""", settings.database)
|
|
|
|
# Table sizes
|
|
table_sizes = await conn.fetch("""
|
|
SELECT
|
|
relname as table_name,
|
|
pg_size_pretty(pg_total_relation_size(relid)) as total_size,
|
|
n_live_tup as row_count
|
|
FROM pg_stat_user_tables
|
|
ORDER BY pg_total_relation_size(relid) DESC
|
|
LIMIT 10
|
|
""")
|
|
|
|
return {
|
|
"database": settings.database,
|
|
"size": db_size,
|
|
"largest_tables": [
|
|
{
|
|
"name": t['table_name'],
|
|
"size": t['total_size'],
|
|
"rows": t['row_count']
|
|
}
|
|
for t in table_sizes
|
|
]
|
|
}
|
|
|
|
|
|
# ============================================================================
|
|
# Main Entry Point
|
|
# ============================================================================
|
|
|
|
if __name__ == "__main__":
|
|
import uvicorn
|
|
uvicorn.run(
|
|
"main:app",
|
|
host=settings.api_host,
|
|
port=settings.api_port,
|
|
reload=True,
|
|
)
|