glam/backend/typedb/main.py
2025-12-11 22:32:09 +01:00

774 lines
29 KiB
Python

"""
TypeDB REST API for Heritage Custodian Data
FastAPI backend providing TypeQL query interface for bronhouder.nl
Endpoints:
- GET / - Health check and statistics
- GET /status - Server status
- POST /query - Execute TypeQL match query (read-only)
- GET /databases - List all databases
- POST /databases/{name} - Create new database
- GET /schema - Get database schema types
- POST /schema/load - Load TypeQL schema from file
- POST /data/insert - Execute insert query
- POST /database/reset/{name} - Reset (delete/recreate) database
- GET /stats - Get detailed statistics
- GET /graph/{entity_type} - Get graph data (nodes and edges) for visualization
Updated for TypeDB 2.x driver API (with sessions)
"""
import os
from datetime import datetime
from typing import Optional, List, Dict, Any, Union
from contextlib import asynccontextmanager
import asyncio
from concurrent.futures import ThreadPoolExecutor
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from typedb.driver import TypeDB, SessionType, TransactionType
# ============================================================================
# Configuration
# ============================================================================
class Settings(BaseModel):
    """TypeDB connection and API server settings read from the environment.

    Every default is computed from ``os.getenv`` when the class body is
    evaluated, i.e. once at import time.
    """

    # TypeDB server address (TYPEDB_HOST / TYPEDB_PORT)
    host: str = os.getenv("TYPEDB_HOST", "localhost")
    port: int = int(os.getenv("TYPEDB_PORT", "1729"))
    # Database used when a request does not name one (TYPEDB_DATABASE)
    database: str = os.getenv("TYPEDB_DATABASE", "glam")
    # Server settings: bind address and port handed to uvicorn (API_HOST / API_PORT)
    api_host: str = os.getenv("API_HOST", "0.0.0.0")
    api_port: int = int(os.getenv("API_PORT", "8003"))


# Module-wide settings instance used by the helpers and endpoints in this file
settings = Settings()
# ============================================================================
# Pydantic Models
# ============================================================================
class QueryRequest(BaseModel):
    """Request body carrying a TypeQL query and an optional target database."""

    # Required TypeQL text
    query: str = Field(..., description="TypeQL query to execute")
    # Optional database name; the endpoints fall back to settings.database
    # when this is unset or empty
    database: Optional[str] = Field(None, description="Database name (defaults to 'glam')")
class QueryResponse(BaseModel):
    """Response body of the read-only query endpoint."""

    # One dict per answer, keyed by query variable
    results: List[Dict[str, Any]]
    # Number of entries in `results`
    result_count: int
    # Elapsed time of the query in milliseconds
    execution_time_ms: float
    # Leading keyword detected in the query ("match", "define", "insert" or "unknown")
    query_type: str
class DatabaseInfo(BaseModel):
    """Metadata describing a single TypeDB database."""

    # Database name as reported by the driver
    name: str
class StatusResponse(BaseModel):
    """Server status payload returned by the ``/`` and ``/status`` endpoints."""

    # Overall health indicator string
    status: str
    # Names of the databases found on the TypeDB server
    databases: List[str]
    # Database used when a request does not specify one
    default_database: str
    # Seconds elapsed since this module was imported
    uptime_seconds: float
    # TypeDB version string reported to clients
    typedb_version: str
    # Fields for frontend compatibility
    connected: bool = False
    database: str = ""
    version: str = ""
# ============================================================================
# Global State
# ============================================================================
# Shared TypeDB driver; created on first use by get_driver() and cleared on shutdown
_driver: Any = None
# Thread pool on which the endpoints run the blocking driver calls
_executor = ThreadPoolExecutor(max_workers=4)
# Moment this module was imported; used to compute the reported uptime
_start_time: datetime = datetime.now()
def get_driver() -> Any:
    """Return the shared TypeDB driver, connecting on first use.

    The driver is cached in the module-level ``_driver`` so later calls reuse
    the same connection.
    """
    global _driver
    if _driver is not None:
        return _driver
    # TypeDB 2.x: TypeDB.core_driver() returns a Driver instance
    address = f"{settings.host}:{settings.port}"
    _driver = TypeDB.core_driver(address)
    return _driver
# ============================================================================
# Helper Functions
# ============================================================================
def serialize_concept(concept: Any, tx: Any = None) -> Dict[str, Any]:
    """Convert a TypeDB concept to a JSON-serializable dict.

    The concept is inspected by duck typing, so only the parts it actually
    supports end up in the result:

    * ``_type`` / ``type``  - label of the concept's type
    * ``_iid`` / ``id``     - instance id (hex string when the id offers ``hex()``)
    * ``value``             - value of an attribute concept
    * ``<attribute label>`` - value of each owned attribute

    Args:
        concept: Concept object returned by the TypeDB driver.
        tx: Optional open transaction. The driver's ``get_has`` is called with
            a transaction elsewhere in this file (``instance.get_has(tx)``);
            when ``tx`` is omitted, ``get_has()`` is called without arguments
            as before, which fails against such a driver and leaves owned
            attributes out of the result.

    Returns:
        A dict with the keys described above; empty if nothing is supported.
    """
    result: Dict[str, Any] = {}

    # Type label, exposed under both the underscore key and its plain alias
    if hasattr(concept, 'get_type'):
        concept_type = concept.get_type()
        if hasattr(concept_type, 'get_label'):
            label = concept_type.get_label()
            result['_type'] = label.name if hasattr(label, 'name') else str(label)
            result['type'] = result['_type']

    # Instance id
    if hasattr(concept, 'get_iid'):
        iid = concept.get_iid()
        if iid is not None:
            result['_iid'] = iid.hex() if hasattr(iid, 'hex') else str(iid)
            result['id'] = result['_iid']

    # Value (for attributes)
    if hasattr(concept, 'get_value'):
        result['value'] = concept.get_value()

    # Owned attributes (entities/relations)
    if hasattr(concept, 'get_has'):
        try:
            owned = concept.get_has(tx) if tx is not None else concept.get_has()
            # Materialize first so a failure while streaming adds nothing partial
            for attr in list(owned):
                attr_label = attr.get_type().get_label()
                attr_name = attr_label.name if hasattr(attr_label, 'name') else str(attr_label)
                result[attr_name] = attr.get_value()
        except Exception:
            # Deliberate best effort: a concept whose attributes cannot be
            # fetched is still returned with its type/id/value fields.
            pass
    return result
def serialize_concept_map(concept_map: Any) -> Dict[str, Any]:
    """Turn one query answer into a dict of serialized concepts keyed by variable.

    Objects without a ``variables()`` method produce an empty dict, and
    variables bound to a falsy concept are left out.
    """
    # TypeDB 2.x uses ConceptMap with variables()
    if not hasattr(concept_map, 'variables'):
        return {}
    serialized: Dict[str, Any] = {}
    for variable in concept_map.variables():
        bound = concept_map.get(variable)
        if not bound:
            continue
        serialized[variable] = serialize_concept(bound)
    return serialized
def execute_read_query(database: str, query: str) -> tuple:
    """Run a query in a READ transaction of a DATA session (blocking).

    Returns:
        ``(results, query_type)`` where ``results`` holds one serialized dict
        per answer and ``query_type`` is the leading keyword of the query
        ("match", "define" or "insert"), or "unknown".
    """
    driver = get_driver()

    # Classify the query by its first keyword, case-insensitively
    normalized = query.strip().lower()
    query_type = next(
        (keyword for keyword in ("match", "define", "insert") if normalized.startswith(keyword)),
        "unknown",
    )

    # TypeDB 2.x: Session -> Transaction -> Query
    with driver.session(database, SessionType.DATA) as session:
        with session.transaction(TransactionType.READ) as tx:
            # tx.query.get() serves "match ... get ..." queries; the answers
            # are serialized while the transaction is still open
            answer = tx.query.get(query)
            results: List[Dict[str, Any]] = [
                serialize_concept_map(concept_map) for concept_map in answer
            ]
    return results, query_type
def get_databases() -> List[str]:
    """Return the names of all databases reported by the TypeDB driver."""
    all_databases = get_driver().databases.all()
    return [database.name for database in all_databases]
def get_schema_types(database: str) -> Dict[str, Any]:
    """Collect entity, relation and attribute type labels via the concepts API.

    Each kind is read independently and best effort: a failure in one kind is
    ignored so the others are still reported. A failure opening the session or
    transaction is reported under the ``"error"`` key instead of being raised.

    Returns:
        Dict with ``entity_types``, ``relation_types`` and ``attribute_types``
        lists (root types excluded), plus ``error`` when the session failed.
    """
    driver = get_driver()
    schema: Dict[str, Any] = {"entity_types": [], "relation_types": [], "attribute_types": []}
    # (result key, concepts-API getter for the root type, root label to skip)
    kinds = (
        ("entity_types", "get_root_entity_type", "entity"),
        ("relation_types", "get_root_relation_type", "relation"),
        ("attribute_types", "get_root_attribute_type", "attribute"),
    )
    try:
        # TypeDB 2.x: Use SCHEMA session type with concepts API
        with driver.session(database, SessionType.SCHEMA) as session:
            with session.transaction(TransactionType.READ) as tx:
                for key, getter_name, root_label in kinds:
                    try:
                        root_type = getattr(tx.concepts, getter_name)()
                        if not root_type:
                            continue
                        for subtype in root_type.get_subtypes(tx):
                            label = subtype.get_label()
                            label_str = label.name if hasattr(label, 'name') else str(label)
                            if label_str != root_label:
                                schema[key].append(label_str)
                    except Exception:
                        # Keep whatever was gathered and move on to the next kind
                        pass
    except Exception as e:
        schema["error"] = str(e)
    return schema
# ============================================================================
# FastAPI App
# ============================================================================
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage the driver and executor over the application's lifetime.

    Before the application starts serving, the TypeDB driver is created.
    After it stops, the driver (if any) is closed and cleared, and the thread
    pool is shut down, waiting for running work to finish.
    """
    global _driver

    # Startup: make sure the driver exists
    get_driver()

    yield

    # Shutdown: release the driver, then the worker threads
    if _driver:
        _driver.close()
        _driver = None
    _executor.shutdown(wait=True)
# FastAPI application; `lifespan` handles driver startup and shutdown
app = FastAPI(
    title="TypeDB Heritage API",
    description="REST API for heritage institution TypeQL queries",
    version="1.0.0",
    lifespan=lifespan,
)
# CORS middleware
# NOTE(review): every origin, method and header is allowed here together with
# allow_credentials=True - presumably a development default; confirm the
# intended origins before deploying to production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Configure for production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ============================================================================
# Schema and Data Loading Functions
# ============================================================================
def load_schema_from_file(database: str, filepath: str) -> Dict[str, Any]:
    """Load a TypeQL schema from a .tql file into a database (blocking).

    Args:
        database: Name of the database to define the schema in.
        filepath: Path of the UTF-8 encoded schema file.

    Returns:
        ``{"status": "success", "file": filepath, "database": database}``.

    Raises:
        FileNotFoundError: If ``filepath`` does not exist.
        Exception: Any error raised by the driver while defining or committing.
    """
    driver = get_driver()
    # Read the schema file
    with open(filepath, 'r', encoding='utf-8') as f:
        schema_content = f.read()
    # TypeDB 2.x: Schema operations use SCHEMA session
    with driver.session(database, SessionType.SCHEMA) as session:
        with session.transaction(TransactionType.WRITE) as tx:
            # The query call returns a Promise (as the concept getters used
            # elsewhere in this file do); resolve it so the define has
            # completed, and any error has surfaced, before committing.
            tx.query.define(schema_content).resolve()
            tx.commit()
    return {"status": "success", "file": filepath, "database": database}
def execute_write_query(database: str, query: str) -> Dict[str, Any]:
    """Execute an insert or delete query in a WRITE transaction (blocking).

    Supported shapes, matched case-insensitively on the stripped query:

    * starts with ``insert``
    * starts with ``match`` and contains ``insert`` (match-insert)
    * starts with ``delete``

    Args:
        database: Name of the database to write to.
        query: TypeQL query text.

    Returns:
        ``{"status": "success", "inserted": <answer count>}``; delete queries
        additionally set ``"deleted": True``.

    Raises:
        ValueError: If the query matches none of the supported shapes. This is
            checked before a session is opened.
        Exception: Any error raised by the driver.
    """
    query_stripped = query.strip().lower()
    is_insert = query_stripped.startswith("insert") or (
        query_stripped.startswith("match") and "insert" in query_stripped
    )
    is_delete = query_stripped.startswith("delete")
    # Reject unsupported queries up front rather than after opening a transaction
    if not (is_insert or is_delete):
        raise ValueError(f"Unsupported write query type: {query[:50]}...")

    driver = get_driver()
    result: Dict[str, Any] = {"status": "success", "inserted": 0}
    # TypeDB 2.x: Data operations use DATA session
    with driver.session(database, SessionType.DATA) as session:
        with session.transaction(TransactionType.WRITE) as tx:
            if is_insert:
                # Plain insert and match-insert share the same driver call;
                # count the answers it streams back
                result["inserted"] = sum(1 for _ in tx.query.insert(query))
            else:
                # delete returns a Promise; resolve it before committing
                tx.query.delete(query).resolve()
                result["deleted"] = True
            tx.commit()
    return result
def reset_database(database: str) -> Dict[str, Any]:
    """Drop a database if it exists, then create it afresh (blocking).

    Errors while checking for or deleting the existing database are ignored;
    errors from the subsequent create are raised to the caller.
    """
    driver = get_driver()

    # Best-effort removal of any existing database with this name
    try:
        already_there = driver.databases.contains(database)
        if already_there:
            existing = driver.databases.get(database)
            existing.delete()
    except Exception:
        pass

    # Create new
    driver.databases.create(database)
    return dict(status="success", database=database, action="reset")
# ============================================================================
# API Endpoints
# ============================================================================
@app.get("/", response_model=StatusResponse)
async def get_root_status() -> StatusResponse:
    """Root endpoint: return the same payload as ``/status``."""
    status_response = await get_status()
    return status_response
@app.get("/status")
async def get_status() -> StatusResponse:
    """Get server status and statistics.

    The database list is fetched on the worker pool. If that fails, the
    response reports ``status="error"``, ``connected=False`` and an empty
    database list rather than raising. Otherwise it reports
    ``status="healthy"`` and ``connected=True``, even when the server holds
    no databases yet.
    """
    loop = asyncio.get_event_loop()
    try:
        databases = await loop.run_in_executor(_executor, get_databases)
        reachable = True
    except Exception:
        # Remember the failure explicitly: an empty list on its own cannot be
        # told apart from a reachable server without databases.
        databases = []
        reachable = False
    uptime = (datetime.now() - _start_time).total_seconds()
    return StatusResponse(
        status="healthy" if reachable else "error",
        databases=databases,
        default_database=settings.database,
        uptime_seconds=uptime,
        typedb_version="2.28.0",
        connected=reachable,
        database=settings.database,
        version="2.28.0",
    )
@app.post("/query", response_model=QueryResponse)
async def execute_query(request: QueryRequest) -> QueryResponse:
    """Run a read-only TypeQL query and return its serialized answers.

    Queries that do not start with ``match`` are rejected with HTTP 403.
    Any failure while executing the query or building the response becomes
    HTTP 400 carrying the error text.
    """
    loop = asyncio.get_event_loop()
    database = request.database or settings.database

    # Security: Only allow match queries for now
    is_match = request.query.strip().lower().startswith("match")
    if not is_match:
        raise HTTPException(
            status_code=403,
            detail="Only match queries are allowed for read operations."
        )

    started = datetime.now()
    try:
        # The blocking driver call runs on the worker pool
        results, query_type = await loop.run_in_executor(
            _executor, execute_read_query, database, request.query
        )
        elapsed_ms = (datetime.now() - started).total_seconds() * 1000
        return QueryResponse(
            results=results,
            result_count=len(results),
            execution_time_ms=round(elapsed_ms, 2),
            query_type=query_type,
        )
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))
@app.get("/databases", response_model=List[DatabaseInfo])
async def list_databases() -> List[DatabaseInfo]:
    """Return one ``DatabaseInfo`` per database; failures become HTTP 500."""
    loop = asyncio.get_event_loop()
    try:
        names = await loop.run_in_executor(_executor, get_databases)
        infos = [DatabaseInfo(name=name) for name in names]
        return infos
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/schema")
async def get_schema(database: Optional[str] = None) -> Dict[str, Any]:
    """Return the schema type labels of a database.

    Uses the default database when none is given. Responds with HTTP 404 if
    the database does not exist and HTTP 500 on any other failure.
    """
    loop = asyncio.get_event_loop()
    db = database or settings.database
    try:
        # Check database exists
        known_databases = await loop.run_in_executor(_executor, get_databases)
        if db not in known_databases:
            raise HTTPException(status_code=404, detail=f"Database '{db}' not found")
        schema = await loop.run_in_executor(_executor, get_schema_types, db)
        return {"database": db, "schema": schema}
    except Exception as e:
        # Let the 404 raised above through unchanged; wrap everything else
        if isinstance(e, HTTPException):
            raise
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/databases/{name}")
async def create_database(name: str) -> Dict[str, str]:
    """Create a database called ``name``; driver errors become HTTP 400."""
    loop = asyncio.get_event_loop()
    try:
        # The driver call blocks, so it runs on the worker pool
        await loop.run_in_executor(
            _executor, lambda: get_driver().databases.create(name)
        )
        return {"status": "created", "database": name}
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))
@app.post("/schema/load")
async def load_schema(filepath: str = "/var/www/backend/typedb/01_custodian_name.tql", database: Optional[str] = None) -> Dict[str, Any]:
    """Load a TypeQL schema from a file on the server.

    Args:
        filepath: Path of the schema file; it must lie inside
            ``/var/www/backend/typedb/`` both as written and after ``..``
            segments and symlinks have been resolved.
        database: Target database; the default database is used when omitted.

    Returns:
        The result dict produced by ``load_schema_from_file``.

    Raises:
        HTTPException: 403 if the path is outside the allowed directory,
            404 if the file does not exist, 400 for any other failure.
    """
    loop = asyncio.get_event_loop()
    db = database or settings.database

    # Security: Only allow files in the backend directory
    allowed_prefix = "/var/www/backend/typedb/"
    forbidden = HTTPException(status_code=403, detail="Schema files must be in /var/www/backend/typedb/")
    if not filepath.startswith(allowed_prefix):
        raise forbidden
    # The prefix test alone is satisfied by paths such as
    # "/var/www/backend/typedb/../../../etc/passwd", so the canonical path
    # must be checked as well.
    try:
        allowed_root = os.path.realpath(allowed_prefix)
        resolved_path = os.path.realpath(filepath)
    except ValueError as e:
        # e.g. an embedded NUL byte in the path
        raise HTTPException(status_code=400, detail=str(e))
    if not resolved_path.startswith(allowed_root + os.sep):
        raise forbidden

    try:
        # Open the canonical path that was just validated
        result = await loop.run_in_executor(_executor, load_schema_from_file, db, resolved_path)
        return result
    except FileNotFoundError:
        raise HTTPException(status_code=404, detail=f"Schema file not found: {filepath}")
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))
@app.post("/data/insert")
async def insert_data(request: QueryRequest) -> Dict[str, Any]:
    """Run an insert (or match-insert) query and report what was written.

    Queries of any other shape are rejected with HTTP 403. Failures during
    execution become HTTP 400. The returned dict is the write result extended
    with ``execution_time_ms``.
    """
    loop = asyncio.get_event_loop()
    database = request.database or settings.database

    # Security: Only allow insert queries
    normalized = request.query.strip().lower()
    plain_insert = normalized.startswith("insert")
    match_insert = normalized.startswith("match") and "insert" in normalized
    if not plain_insert and not match_insert:
        raise HTTPException(
            status_code=403,
            detail="Only insert queries are allowed for this endpoint."
        )

    started = datetime.now()
    try:
        outcome = await loop.run_in_executor(
            _executor, execute_write_query, database, request.query
        )
        elapsed_ms = (datetime.now() - started).total_seconds() * 1000
        outcome["execution_time_ms"] = round(elapsed_ms, 2)
        return outcome
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))
@app.post("/database/reset/{name}")
async def reset_db(name: str) -> Dict[str, Any]:
    """Delete and recreate the database ``name``; failures become HTTP 400."""
    loop = asyncio.get_event_loop()
    try:
        # reset_database blocks on the driver, so it runs on the worker pool
        outcome = await loop.run_in_executor(_executor, reset_database, name)
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))
    return outcome
@app.get("/stats")
async def get_stats(database: Optional[str] = None) -> Dict[str, Any]:
    """Get database statistics including entity/relation/attribute counts.

    Args:
        database: Database to inspect; the default database when omitted.

    Returns:
        Dict with ``database``, the ``entity_types`` / ``relation_types`` /
        ``attribute_types`` lists (one dict per type with its ``label``, a
        ``count`` and kind-specific fields) and the ``total_entities`` /
        ``total_relations`` sums.

    Raises:
        HTTPException: 500 if a session or transaction cannot be opened.
            Failures for an individual type are tolerated: its count is 0.
    """
    loop = asyncio.get_event_loop()
    db = database or settings.database

    def _label_name(type_concept: Any) -> str:
        """Return the label of a type concept as a plain string."""
        label = type_concept.get_label()
        return label.name if hasattr(label, 'name') else str(label)

    def _count_instances(tx: Any, type_concept: Any) -> int:
        """Count instances by streaming them instead of building a list."""
        if not type_concept:
            return 0
        return sum(1 for _ in type_concept.get_instances(tx))

    def _get_stats() -> Dict[str, Any]:
        driver = get_driver()
        stats: Dict[str, Any] = {
            "database": db,
            "entity_types": [],
            "relation_types": [],
            "attribute_types": [],
            "total_entities": 0,
            "total_relations": 0,
        }
        # Get schema types using SCHEMA session with concepts API.
        # Each kind is best effort: a failure keeps what was gathered so far.
        with driver.session(db, SessionType.SCHEMA) as session:
            with session.transaction(TransactionType.READ) as tx:
                # Entity types
                try:
                    entity_root = tx.concepts.get_root_entity_type()
                    if entity_root:
                        for sub in entity_root.get_subtypes(tx):
                            label_str = _label_name(sub)
                            if label_str != "entity":
                                stats["entity_types"].append({
                                    "label": label_str,
                                    "abstract": sub.is_abstract(),
                                })
                except Exception:
                    pass
                # Relation types
                try:
                    relation_root = tx.concepts.get_root_relation_type()
                    if relation_root:
                        for sub in relation_root.get_subtypes(tx):
                            label_str = _label_name(sub)
                            if label_str != "relation":
                                stats["relation_types"].append({
                                    "label": label_str,
                                    "roles": [],
                                })
                except Exception:
                    pass
                # Attribute types
                try:
                    attribute_root = tx.concepts.get_root_attribute_type()
                    if attribute_root:
                        for sub in attribute_root.get_subtypes(tx):
                            label_str = _label_name(sub)
                            if label_str != "attribute":
                                # Get value type
                                value_type = sub.get_value_type()
                                stats["attribute_types"].append({
                                    "label": label_str,
                                    "valueType": value_type.name if value_type else "unknown",
                                })
                except Exception:
                    pass
        # Count instances using DATA session
        # Note: TypeDB 2.x returns Promises that need .resolve()
        with driver.session(db, SessionType.DATA) as session:
            with session.transaction(TransactionType.READ) as tx:
                for et in stats["entity_types"]:
                    # Abstract types are reported with a count of 0
                    if et.get("abstract"):
                        et["count"] = 0
                        continue
                    try:
                        resolved = tx.concepts.get_entity_type(et["label"]).resolve()
                        et["count"] = _count_instances(tx, resolved)
                    except Exception:
                        et["count"] = 0
                    stats["total_entities"] += et["count"]
                for rt in stats["relation_types"]:
                    try:
                        resolved = tx.concepts.get_relation_type(rt["label"]).resolve()
                        rt["count"] = _count_instances(tx, resolved)
                    except Exception:
                        rt["count"] = 0
                    stats["total_relations"] += rt["count"]
                for at in stats["attribute_types"]:
                    try:
                        resolved = tx.concepts.get_attribute_type(at["label"]).resolve()
                        at["count"] = _count_instances(tx, resolved)
                    except Exception:
                        at["count"] = 0
        return stats

    try:
        stats = await loop.run_in_executor(_executor, _get_stats)
        return stats
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/graph/{entity_type}")
async def get_graph_data(entity_type: str, limit: int = 100, database: Optional[str] = None) -> Dict[str, Any]:
    """Get graph data for visualization (nodes + edges) for a specific entity type.

    Args:
        entity_type: Label of the entity type whose instances become nodes.
        limit: Maximum number of instances to read; values below 0 are
            treated as 0.
        database: Database to read; the default database when omitted.

    Returns:
        Dict with ``nodes``, ``edges``, ``nodeCount`` and ``edgeCount``. An
        unknown entity type yields an empty graph. An edge is produced for
        every relation with at least two role players of which at least one
        is among the returned nodes; it connects the first two players.

    Raises:
        HTTPException: 500 if the graph cannot be read. Failures while reading
            one instance's attributes or the relations are tolerated.
    """
    # Local import: only this endpoint needs islice
    from itertools import islice

    loop = asyncio.get_event_loop()
    db = database or settings.database

    def _iid_of(thing: Any) -> str:
        """Return the instance id of a thing as a string."""
        iid = thing.get_iid()
        return iid.hex() if hasattr(iid, 'hex') else str(iid)

    def _label_name(type_concept: Any) -> str:
        """Return the label of a type concept as a plain string."""
        label = type_concept.get_label()
        return label.name if hasattr(label, 'name') else str(label)

    def _get_graph() -> Dict[str, Any]:
        driver = get_driver()
        nodes: List[Dict[str, Any]] = []
        edges: List[Dict[str, Any]] = []
        node_ids: set = set()
        with driver.session(db, SessionType.DATA) as session:
            with session.transaction(TransactionType.READ) as tx:
                # Get entity instances with their attributes
                entity_type_obj = tx.concepts.get_entity_type(entity_type).resolve()
                if not entity_type_obj:
                    return {"nodes": [], "edges": [], "nodeCount": 0, "edgeCount": 0}
                # Stop reading after `limit` instances instead of loading every
                # instance and slicing afterwards
                instances = list(islice(entity_type_obj.get_instances(tx), max(limit, 0)))
                for instance in instances:
                    node_id = _iid_of(instance)
                    # Get attributes for this instance
                    attributes: Dict[str, Any] = {}
                    label = node_id[:8]  # Default label
                    try:
                        for attr in instance.get_has(tx):
                            attr_name = _label_name(attr.get_type())
                            attr_value = attr.get_value()
                            attributes[attr_name] = attr_value
                            # Use certain attributes as label
                            if attr_name in ('name', 'label', 'observed-name', 'legal-name', 'id'):
                                label = str(attr_value)
                    except Exception:
                        # Best effort: keep the node with whatever was read
                        pass
                    node_ids.add(node_id)
                    nodes.append({
                        "id": node_id,
                        "label": label,
                        "type": "entity",
                        "entityType": entity_type,
                        "attributes": attributes,
                    })
                # Get relations involving these entities
                # NOTE(review): if relation types form a hierarchy and
                # get_instances() also returns instances of subtypes, the same
                # relation may be emitted once per (super)type - confirm, and
                # de-duplicate by id if so.
                try:
                    relation_root = tx.concepts.get_root_relation_type()
                    for rel_type in relation_root.get_subtypes(tx):
                        rel_name = _label_name(rel_type)
                        if rel_name == "relation":
                            continue
                        for rel_instance in rel_type.get_instances(tx):
                            rel_id = _iid_of(rel_instance)
                            # Get role players
                            players = []
                            try:
                                # TypeDB 2.x: get_players(tx) returns a dict {role_type: [player_list]}
                                role_players_dict = rel_instance.get_players(tx)
                                for role_type, player_list in role_players_dict.items():
                                    role_name = _label_name(role_type)
                                    for player in player_list:
                                        players.append({"role": role_name, "id": _iid_of(player)})
                            except Exception as e:
                                print(f"Error getting role players for relation {rel_name}: {e}")
                            # An edge needs two players, one of them a known node
                            touches_nodes = any(p["id"] in node_ids for p in players)
                            if touches_nodes and len(players) >= 2:
                                # Create edge between first two players
                                first, second = players[0], players[1]
                                edges.append({
                                    "id": rel_id,
                                    "source": first["id"],
                                    "target": second["id"],
                                    "relationType": rel_name,
                                    "role": f"{first['role']} -> {second['role']}",
                                    "attributes": {},
                                })
                except Exception as e:
                    print(f"Error getting relations: {e}")
        return {
            "nodes": nodes,
            "edges": edges,
            "nodeCount": len(nodes),
            "edgeCount": len(edges),
        }

    try:
        result = await loop.run_in_executor(_executor, _get_graph)
        return result
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
# ============================================================================
# Main Entry Point
# ============================================================================
if __name__ == "__main__":
    # Runs only when this file is executed directly; uvicorn is imported here
    # rather than at the top of the module.
    import uvicorn
    # Serve the `app` object of module `main` on the configured host and port.
    # NOTE(review): reload=True is presumably meant for development - confirm
    # before using this entry point in production.
    uvicorn.run(
        "main:app",
        host=settings.api_host,
        port=settings.api_port,
        reload=True,
    )