glam/backend/typedb/main.py
kempersc 35066eb5eb fix(typedb): improve concept serialization for frontend compatibility
- Add 'type' alias alongside '_type' for frontend compatibility
- Handle both bytes and string IID formats from driver
- Add 'id' alias alongside '_iid' for consistency
2025-12-08 14:58:35 +01:00

493 lines
16 KiB
Python

"""
TypeDB REST API for Heritage Custodian Data
FastAPI backend providing TypeQL query interface for bronhouder.nl
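Endpoints:
- GET / and GET /status - Health check, database list and uptime
- POST /query - Execute a read-only TypeQL match query
- GET /databases - List all databases
- POST /databases/{name} - Create a database
- GET /schema - Get database schema (type labels)
- GET /stats - Get type-level statistics with instance counts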
"""
import asyncio
import os
import time
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager
from datetime import datetime
from typing import Optional, List, Dict, Any

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from typedb.driver import TypeDB, SessionType, TransactionType
# ============================================================================
# Configuration
# ============================================================================
class Settings(BaseModel):
    """Runtime configuration, resolved from environment variables at import time.

    Every field falls back to a local-development default when its variable
    is unset. Port values are converted with ``int()``, so a non-numeric value
    fails immediately at import.
    """

    # --- TypeDB connection -------------------------------------------------
    host: str = os.environ.get("TYPEDB_HOST", "localhost")
    port: int = int(os.environ.get("TYPEDB_PORT", "1729"))
    database: str = os.environ.get("TYPEDB_DATABASE", "glam")

    # --- HTTP binding for this API (used by the __main__ entry point) ------
    api_host: str = os.environ.get("API_HOST", "0.0.0.0")
    api_port: int = int(os.environ.get("API_PORT", "8002"))


# Module-wide settings instance shared by all endpoints.
settings = Settings()
# ============================================================================
# Pydantic Models
# ============================================================================
class QueryRequest(BaseModel):
    """Request body for ``POST /query``: TypeQL text plus an optional database."""

    # Required: the TypeQL query string.
    query: str = Field(default=..., description="TypeQL query to execute")
    # Optional: target database; when None, the configured default is used.
    database: Optional[str] = Field(default=None, description="Database name (defaults to 'glam')")
class QueryResponse(BaseModel):
    """Envelope returned by ``POST /query``."""

    # Serialized answer rows, one dict per answer.
    results: List[Dict[str, Any]]
    # Number of entries in ``results``.
    result_count: int
    # Time spent executing the query, in milliseconds.
    execution_time_ms: float
    # Classification of the query by its leading keyword (e.g. "match").
    query_type: str
class DatabaseInfo(BaseModel):
    """A single entry in the ``GET /databases`` listing."""

    # Database name as reported by the TypeDB server.
    name: str
class StatusResponse(BaseModel):
    """Payload of the status endpoints (``GET /`` and ``GET /status``)."""

    status: str                 # overall health label
    databases: List[str]        # database names visible on the server
    default_database: str       # database used when a request names none
    uptime_seconds: float       # seconds since this process started
    typedb_version: str         # TypeDB version string reported to clients

    # Extra fields read by the frontend; all optional with empty defaults.
    connected: bool = Field(default=False)
    database: str = Field(default="")
    version: str = Field(default="")
# ============================================================================
# Global State
# ============================================================================
# Shared TypeDB driver; created lazily by get_driver() and closed on shutdown.
_driver: Optional[Any] = None
# Worker threads that run the blocking driver calls off the event loop.
_executor = ThreadPoolExecutor(max_workers=4)
# Process start time (naive local time), used to compute uptime.
_start_time: datetime = datetime.now()


def get_driver():
    """Return the shared TypeDB driver, connecting to ``host:port`` on first use."""
    global _driver
    if _driver is not None:
        return _driver
    address = f"{settings.host}:{settings.port}"
    _driver = TypeDB.core_driver(address)
    return _driver
# ============================================================================
# Helper Functions
# ============================================================================
def serialize_concept(concept) -> Dict[str, Any]:
    """Turn a TypeDB concept into a JSON-friendly dict.

    Keys are only emitted when the concept exposes the matching accessor:

    * ``_type`` / ``type``: label name of ``concept.get_type()``. ``type`` is
      a duplicate kept for the frontend.
    * ``_iid`` / ``id``: the concept IID. Objects with ``.hex()`` (bytes from
      older drivers) are hex-encoded; anything else goes through ``str()``;
      ``None`` stays ``None``.
    * ``value``: ``concept.get_value()``.
    """
    serialized: Dict[str, Any] = {}

    if hasattr(concept, 'get_type'):
        thing_type = concept.get_type()
        if hasattr(thing_type, 'get_label'):
            label = thing_type.get_label()
            type_name = label.name if hasattr(label, 'name') else str(label)
            serialized['_type'] = type_name
            serialized['type'] = type_name  # frontend alias

    if hasattr(concept, 'get_iid'):
        raw_iid = concept.get_iid()
        if raw_iid is None:
            iid_text = None
        elif hasattr(raw_iid, 'hex'):
            # Older driver: bytes IID
            iid_text = raw_iid.hex()
        else:
            # Newer driver: IID is already a string
            iid_text = str(raw_iid)
        serialized['_iid'] = iid_text
        serialized['id'] = iid_text

    if hasattr(concept, 'get_value'):
        serialized['value'] = concept.get_value()

    return serialized
def serialize_concept_map(concept_map) -> Dict[str, Any]:
    """Serialize one answer row (a TypeDB concept map) into a plain dict.

    Each variable maps to one of:

    * ``{'value': ..., 'type': label-or-None}`` for concepts with ``get_value``,
    * ``serialize_concept(concept)`` for concepts with ``get_iid``,
    * ``str(concept)`` for anything else.

    Variables whose concept is falsy (e.g. ``None``) are omitted.
    """
    row: Dict[str, Any] = {}
    for name in concept_map.variables():
        concept = concept_map.get(name)
        if not concept:
            continue

        if not hasattr(concept, 'get_value'):
            # Entities/relations get full serialization; everything else is stringified
            row[name] = serialize_concept(concept) if hasattr(concept, 'get_iid') else str(concept)
            continue

        # Attribute value: include its type label for the frontend
        type_name = None
        if hasattr(concept, 'get_type'):
            attr_type = concept.get_type()
            if hasattr(attr_type, 'get_label'):
                label = attr_type.get_label()
                type_name = label.name if hasattr(label, 'name') else str(label)
        row[name] = {'value': concept.get_value(), 'type': type_name}
    return row
def execute_read_query(database: str, query: str) -> tuple:
    """Run ``query`` in a READ data transaction on ``database`` (blocking).

    Answers are fetched with ``tx.query.get`` and each one is passed through
    ``serialize_concept_map``.

    Returns:
        ``(results, query_type)``. ``results`` is the list of serialized rows.
        ``query_type`` is "match", "define" or "insert" depending on the query's
        leading keyword (case-insensitive), else "unknown".
    """
    driver = get_driver()

    head = query.strip().lower()
    query_type = next(
        (keyword for keyword in ("match", "define", "insert") if head.startswith(keyword)),
        "unknown",
    )

    with driver.session(database, SessionType.DATA) as session:
        with session.transaction(TransactionType.READ) as tx:
            results = [serialize_concept_map(row) for row in tx.query.get(query)]

    return results, query_type
def get_databases() -> List[str]:
    """Return the names of every database on the TypeDB server (blocking)."""
    all_databases = get_driver().databases.all()
    return [database.name for database in all_databases]
def get_schema_types(database: str) -> Dict[str, Any]:
    """List entity, relation and attribute type labels defined in ``database``.

    Root labels ("entity", "relation", "attribute") are excluded. Errors raised
    while reading the schema are not propagated. They are stored as
    ``schema["error"]``, alongside whatever labels were collected before the
    failure. Errors from ``get_driver()`` itself do propagate.
    """
    driver = get_driver()
    schema: Dict[str, Any] = {"entity_types": [], "relation_types": [], "attribute_types": []}

    # (result key, ConceptManager accessor for the root type, root label to skip)
    sections = (
        ("entity_types", "get_root_entity_type", "entity"),
        ("relation_types", "get_root_relation_type", "relation"),
        ("attribute_types", "get_root_attribute_type", "attribute"),
    )

    try:
        with driver.session(database, SessionType.SCHEMA) as session:
            with session.transaction(TransactionType.READ) as tx:
                for key, root_getter, root_label in sections:
                    root_type = getattr(tx.concepts, root_getter)()
                    for sub_type in root_type.get_subtypes(tx):
                        label = sub_type.get_label()
                        name = label.name if hasattr(label, 'name') else str(label)
                        if name != root_label:
                            schema[key].append(name)
    except Exception as exc:
        schema["error"] = str(exc)
    return schema
# ============================================================================
# FastAPI App
# ============================================================================
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the TypeDB driver on startup and release resources on shutdown.

    Shutdown order:

    1. Drain the thread pool, so no worker thread is still using the driver.
    2. Close the driver.

    Both steps sit in ``finally`` clauses. Cleanup therefore still runs if the
    application raises during its lifetime, and a failure in one step cannot
    skip the other.
    """
    global _driver
    # Startup: create the driver up front instead of on the first request
    get_driver()
    try:
        yield
    finally:
        try:
            # Wait for in-flight blocking TypeDB calls before the driver goes away
            _executor.shutdown(wait=True)
        finally:
            if _driver:
                _driver.close()
            _driver = None
app = FastAPI(
    title="TypeDB Heritage API",
    description="REST API for heritage institution TypeQL queries",
    version="1.0.0",
    lifespan=lifespan,
)

# CORS allowed origins come from CORS_ALLOW_ORIGINS, a comma-separated list
# (e.g. "https://bronhouder.nl,https://www.bronhouder.nl"). When the variable
# is unset or contains no origins, the permissive "*" default is kept.
# NOTE(review): "*" combined with allow_credentials=True is very permissive;
# set explicit origins in production.
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ============================================================================
# API Endpoints
# ============================================================================
@app.get("/", response_model=StatusResponse)
async def get_root_status() -> StatusResponse:
    """Root health check; responds exactly like ``GET /status``."""
    status_payload = await get_status()
    return status_payload
@app.get("/status")
async def get_status() -> StatusResponse:
    """Report server health, visible databases and process uptime.

    Fetching the database list doubles as the TypeDB reachability check:

    * success: ``status="healthy"`` and ``connected=True``, even if the server
      holds no databases;
    * failure: ``status="error"``, ``connected=False`` and ``databases=[]``.
    """
    loop = asyncio.get_running_loop()
    try:
        databases = await loop.run_in_executor(_executor, get_databases)
        reachable = True
    except Exception:
        # Treat any failure to list databases as TypeDB being unavailable.
        databases = []
        reachable = False

    uptime = (datetime.now() - _start_time).total_seconds()
    # Hard-coded version string; it is not queried from the server.
    typedb_version = "2.28.0"
    return StatusResponse(
        status="healthy" if reachable else "error",
        databases=databases,
        default_database=settings.database,
        uptime_seconds=uptime,
        typedb_version=typedb_version,
        connected=reachable,
        database=settings.database,
        version=typedb_version,
    )
@app.post("/query", response_model=QueryResponse)
async def execute_query(request: QueryRequest) -> QueryResponse:
    """Execute a read-only TypeQL query and return the serialized answers.

    Only queries whose first keyword is ``match`` are accepted. The query also
    runs inside a READ transaction.

    Raises:
        HTTPException 403: the query does not start with ``match``.
        HTTPException 400: the driver rejected or failed the query; the error
            text is returned as the detail.
    """
    database = request.database or settings.database

    # Security: Only allow match queries for now
    if not request.query.strip().lower().startswith("match"):
        raise HTTPException(
            status_code=403,
            detail="Only match queries are allowed for read operations."
        )

    loop = asyncio.get_running_loop()
    # perf_counter is monotonic, so the duration is immune to wall-clock changes
    started = time.perf_counter()
    try:
        results, query_type = await loop.run_in_executor(
            _executor,
            execute_read_query,
            database,
            request.query,
        )
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    execution_time = (time.perf_counter() - started) * 1000

    return QueryResponse(
        results=results,
        result_count=len(results),
        execution_time_ms=round(execution_time, 2),
        query_type=query_type,
    )
@app.get("/databases", response_model=List[DatabaseInfo])
async def list_databases() -> List[DatabaseInfo]:
    """Return every database on the TypeDB server.

    Any failure while listing is reported as HTTP 500 with the error text.
    """
    running_loop = asyncio.get_running_loop()
    try:
        names = await running_loop.run_in_executor(_executor, get_databases)
        return [DatabaseInfo(name=name) for name in names]
    except Exception as err:
        raise HTTPException(status_code=500, detail=str(err))
@app.get("/schema")
async def get_schema(database: Optional[str] = None) -> Dict[str, Any]:
    """Return the type labels of ``database``, or of the default database.

    Responds 404 if the database does not exist and 500 on other failures.
    Schema-read errors caught by ``get_schema_types`` come back inside the
    ``schema`` payload rather than as an HTTP error.
    """
    running_loop = asyncio.get_running_loop()
    target = database or settings.database
    try:
        known = await running_loop.run_in_executor(_executor, get_databases)
        if target not in known:
            raise HTTPException(status_code=404, detail=f"Database '{target}' not found")
        types_by_kind = await running_loop.run_in_executor(_executor, get_schema_types, target)
    except HTTPException:
        raise
    except Exception as err:
        raise HTTPException(status_code=500, detail=str(err))
    return {"database": target, "schema": types_by_kind}
@app.post("/databases/{name}")
async def create_database(name: str) -> Dict[str, str]:
    """Create a database called ``name`` on the TypeDB server.

    Returns ``{"status": "created", "database": name}``. Any driver error is
    reported as HTTP 400 with the error text.

    NOTE(review): POST /query is restricted to read-only match queries, but
    this mutating endpoint has no guard or authentication in this file;
    confirm that is intended.
    """
    running_loop = asyncio.get_running_loop()
    try:
        await running_loop.run_in_executor(
            _executor, lambda: get_driver().databases.create(name)
        )
    except Exception as err:
        raise HTTPException(status_code=400, detail=str(err))
    return {"status": "created", "database": name}
def _stats_type_label(schema_type) -> str:
    """Return the label of a TypeDB schema type (or role type) as a plain string."""
    label = schema_type.get_label()
    return label.name if hasattr(label, 'name') else str(label)


def _stats_count_instances(tx, label: str) -> int:
    """Count the answers of ``match $x isa <label>`` (subtype instances included).

    Driver errors propagate; callers decide how to degrade.
    """
    query = f"match $x isa {label}; get $x;"
    return sum(1 for _ in tx.query.get(query))


def _stats_collect_schema(tx, stats: Dict[str, Any]) -> None:
    """Append entity, relation (with role labels) and attribute types to ``stats``."""
    for entity_type in tx.concepts.get_root_entity_type().get_subtypes(tx):
        name = _stats_type_label(entity_type)
        if name != "entity":
            # NOTE(review): abstractness is not queried; always reported as False
            stats["entity_types"].append({"label": name, "abstract": False})

    for rel_type in tx.concepts.get_root_relation_type().get_subtypes(tx):
        name = _stats_type_label(rel_type)
        if name == "relation":
            continue
        roles = []
        try:
            for role in rel_type.get_relates(tx):
                roles.append(_stats_type_label(role))
        except Exception:
            # Best effort: roles are informational; keep what was collected
            pass
        stats["relation_types"].append({"label": name, "roles": roles})

    for attr_type in tx.concepts.get_root_attribute_type().get_subtypes(tx):
        name = _stats_type_label(attr_type)
        if name == "attribute":
            continue
        value_type = "unknown"
        try:
            vt = attr_type.get_value_type()
            value_type = str(vt.name) if vt else "unknown"
        except Exception:
            pass
        stats["attribute_types"].append({"label": name, "valueType": value_type})


def _stats_total(tx, root_label: str, type_infos: List[Dict[str, Any]]) -> int:
    """Count every instance of a root type exactly once.

    If the root query fails, fall back to summing the per-type counts (the
    previous behaviour, which can over-count when types have subtypes).
    """
    try:
        return _stats_count_instances(tx, root_label)
    except Exception:
        return sum(info["count"] for info in type_infos)


def _stats_collect_counts(tx, stats: Dict[str, Any]) -> None:
    """Add a ``count`` to every type in ``stats`` and fill in the totals.

    A failing per-type count is recorded as 0 instead of aborting the request.
    """
    for kind in ("entity_types", "relation_types", "attribute_types"):
        for type_info in stats[kind]:
            try:
                type_info["count"] = _stats_count_instances(tx, type_info["label"])
            except Exception:
                type_info["count"] = 0

    # Per-type counts use ``isa``, which includes subtypes. Summing them would
    # count an instance once for itself and once more for each supertype.
    # Totals are therefore counted against the root type instead.
    stats["total_entities"] = _stats_total(tx, "entity", stats["entity_types"])
    stats["total_relations"] = _stats_total(tx, "relation", stats["relation_types"])


@app.get("/stats")
async def get_stats(database: Optional[str] = None) -> Dict[str, Any]:
    """Get database statistics including entity/relation/attribute counts.

    Response fields:

    * ``entity_types``: ``{"label", "abstract", "count"}`` per type.
    * ``relation_types``: ``{"label", "roles", "count"}`` per type.
    * ``attribute_types``: ``{"label", "valueType", "count"}`` per type.
    * ``total_entities`` / ``total_relations``: deduplicated instance totals.

    Per-type counts include instances of subtypes. Any failure outside the
    best-effort parts is reported as HTTP 500.
    """
    loop = asyncio.get_running_loop()
    db = database or settings.database

    def _get_stats():
        driver = get_driver()
        stats = {
            "database": db,
            "entity_types": [],
            "relation_types": [],
            "attribute_types": [],
            "total_entities": 0,
            "total_relations": 0,
        }
        with driver.session(db, SessionType.SCHEMA) as session:
            with session.transaction(TransactionType.READ) as tx:
                _stats_collect_schema(tx, stats)
        with driver.session(db, SessionType.DATA) as session:
            with session.transaction(TransactionType.READ) as tx:
                _stats_collect_counts(tx, stats)
        return stats

    try:
        return await loop.run_in_executor(_executor, _get_stats)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
# ============================================================================
# Main Entry Point
# ============================================================================
if __name__ == "__main__":
    # Development entry point: serve this module's app with auto-reload,
    # bound to the host/port from settings.
    from uvicorn import run as run_server

    run_server(
        "main:app",
        reload=True,
        host=settings.api_host,
        port=settings.api_port,
    )