445 lines
12 KiB
Python
445 lines
12 KiB
Python
"""
|
|
de Aa Archiefassistent - Authentication Backend

FastAPI-based JWT authentication service with SQLite persistence.
|
|
"""
|
|
|
|
import os
|
|
import secrets
|
|
import sqlite3
|
|
from contextlib import asynccontextmanager
|
|
from datetime import datetime, timedelta, timezone
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
|
|
from fastapi import FastAPI, HTTPException, Depends, status
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
|
from pydantic import BaseModel, EmailStr
|
|
from jose import JWTError, jwt
|
|
import bcrypt
|
|
import uvicorn
|
|
|
|
# Configuration
def _secret_from_env(var_name: str) -> str:
    """Return the signing secret stored in environment variable ``var_name``.

    Falls back to a random per-process secret when the variable is unset *or*
    empty/whitespace-only. ``os.getenv(name, default)`` only uses its default
    when the variable is absent, so a blank ``JWT_SECRET_KEY=`` line in a .env
    file would otherwise sign every token with an empty HMAC key.
    """
    value = os.getenv(var_name, "")
    if value.strip():
        return value
    # A random fallback means tokens stop verifying whenever the process
    # restarts (including auto-reload), so make that visible.
    print(
        f"⚠️ WARNING: {var_name} is not set; using a random secret. "
        "Issued tokens will become invalid when the process restarts."
    )
    return secrets.token_urlsafe(32)


SECRET_KEY = _secret_from_env("JWT_SECRET_KEY")
REFRESH_SECRET_KEY = _secret_from_env("JWT_REFRESH_SECRET_KEY")
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
REFRESH_TOKEN_EXPIRE_DAYS = 7

# Database path - store in same directory as main.py
DB_PATH = Path(__file__).parent / "users.db"
|
|
|
|
# Security: bearer-token extractor injected into protected endpoints via
# Depends(security). auto_error=True is HTTPBearer's default; per FastAPI's
# semantics a request without a usable "Authorization: Bearer ..." header is
# rejected before the endpoint body runs.
security = HTTPBearer(auto_error=True)
|
|
|
|
|
|
def get_db_connection() -> sqlite3.Connection:
    """Open a new connection to the users database located at ``DB_PATH``.

    Rows are returned as ``sqlite3.Row`` objects, so columns can be read by
    name and a row converts cleanly with ``dict(row)``. The caller owns the
    returned connection and must close it.
    """
    db = sqlite3.connect(DB_PATH)
    db.row_factory = sqlite3.Row  # name-addressable rows
    return db
|
|
|
|
|
|
def init_database():
    """Create the ``users`` table if it does not exist yet.

    Safe to call on every startup: ``CREATE TABLE IF NOT EXISTS`` leaves an
    existing table and its rows untouched. The connection is always closed,
    even if the statement or commit fails.
    """
    conn = get_db_connection()
    try:
        cursor = conn.cursor()

        # Create users table if not exists
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS users (
                id TEXT PRIMARY KEY,
                email TEXT UNIQUE NOT NULL,
                name TEXT NOT NULL,
                password_hash TEXT NOT NULL,
                role TEXT NOT NULL DEFAULT 'user',
                created_at TEXT NOT NULL,
                updated_at TEXT NOT NULL
            )
        """)
        conn.commit()
    finally:
        # Previously the connection leaked whenever execute/commit raised.
        conn.close()

    print(f"✓ Database initialized: {DB_PATH}")
|
|
|
|
|
|
def hash_password(password: str) -> str:
    """Return a bcrypt hash of ``password`` using a freshly generated salt.

    bcrypt only consumes the first 72 bytes of its input (newer bcrypt
    releases may reject longer input outright), so the UTF-8 encoding is cut
    to 72 bytes first. ``verify_password`` applies the same cut. Any bytes
    past the limit therefore do not influence the hash.
    """
    encoded = password.encode('utf-8')
    hashed = bcrypt.hashpw(encoded[:72], bcrypt.gensalt())
    return hashed.decode('utf-8')
|
|
|
|
|
|
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check ``plain_password`` against a stored bcrypt hash.

    Applies the same 72-byte truncation as ``hash_password`` so both sides
    see identical input. Returns False, instead of raising, when bcrypt
    fails (for example on a malformed stored hash).
    """
    candidate = plain_password.encode('utf-8')[:72]
    stored = hashed_password.encode('utf-8')
    try:
        matches = bcrypt.checkpw(candidate, stored)
    except Exception:
        # Any bcrypt failure counts as "does not match", never as an error.
        return False
    return matches
|
|
|
|
|
|
def _parse_user_env(value: str) -> Optional[tuple]:
    """Split a ``DE_AA_USER_<n>`` value into ``(email, password, name, role)``.

    Format is ``email:password:name[:role]``; role defaults to ``"user"``.
    With more than four fields, the password is assumed to contain ``:``.
    The first field is the email, the last two are name and role, and
    everything in between is re-joined as the password. Returns None when
    fewer than three fields are present.
    """
    parts = value.split(":")
    if len(parts) < 3:
        return None
    if len(parts) == 3:
        email, password, name = parts
        return email, password, name, "user"
    email = parts[0]
    *password_parts, name, role = parts[1:]
    return email, ":".join(password_parts), name, role


def _unique_user_id(cursor: sqlite3.Cursor, preferred: str) -> str:
    """Return ``preferred`` if no row uses that id yet, else a random ``user_<hex>`` id."""
    cursor.execute("SELECT 1 FROM users WHERE id = ?", (preferred,))
    if cursor.fetchone() is None:
        return preferred
    # The positional id is taken (e.g. env entries were reordered or an
    # entry was replaced with a new email). Inserting it again would violate
    # the PRIMARY KEY and abort startup.
    return f"user_{secrets.token_hex(8)}"


def load_users_from_env():
    """Load users from environment variables into the SQLite database.

    Reads ``DE_AA_USER_1``, ``DE_AA_USER_2``, ... in order and stops at the
    first missing or empty variable, so numbering must be contiguous. Only
    users whose email is not yet in the database are inserted; existing rows
    are left untouched, which preserves password changes made through the API.
    Malformed entries are reported (without echoing the password) and skipped.
    """
    now = datetime.now(timezone.utc).isoformat()
    new_users = 0
    existing_users = 0

    conn = get_db_connection()
    try:
        cursor = conn.cursor()
        index = 1
        while True:
            user_env = os.getenv(f"DE_AA_USER_{index}")
            if not user_env:
                break

            parsed = _parse_user_env(user_env)
            if parsed is None:
                print(f"⚠️ Skipping DE_AA_USER_{index}: expected email:password:name[:role]")
            else:
                email, password, name, role = parsed
                cursor.execute("SELECT id FROM users WHERE email = ?", (email,))
                if cursor.fetchone() is None:
                    user_id = _unique_user_id(cursor, f"user_{index}")
                    password_hash = hash_password(password)
                    cursor.execute(
                        """
                        INSERT INTO users (id, email, name, password_hash, role, created_at, updated_at)
                        VALUES (?, ?, ?, ?, ?, ?, ?)
                        """,
                        (user_id, email, name, password_hash, role, now, now),
                    )
                    print(f"✓ Created user: {email} ({role})")
                    new_users += 1
                else:
                    print(f"○ User exists: {email} (password preserved)")
                    existing_users += 1
            index += 1

        conn.commit()
    finally:
        # Closing without commit on error discards partial inserts.
        conn.close()

    print(f"Users: {new_users} new, {existing_users} existing")
|
|
|
|
|
|
def get_user(email: str) -> Optional[dict]:
    """Fetch a user by email.

    The match is exact. The schema declares no NOCASE collation, so SQLite
    compares case-sensitively.

    Returns:
        A dict with the ``users`` columns (id, email, name, password_hash,
        role, created_at, updated_at), or None when no row matches.
    """
    conn = get_db_connection()
    try:
        row = conn.execute("SELECT * FROM users WHERE email = ?", (email,)).fetchone()
    finally:
        # Always release the connection, even if the query raises.
        conn.close()
    return dict(row) if row is not None else None
|
|
|
|
|
|
def get_user_count() -> int:
    """Return the total number of rows in the ``users`` table."""
    conn = get_db_connection()
    try:
        (count,) = conn.execute("SELECT COUNT(*) FROM users").fetchone()
    finally:
        # Always release the connection, even if the query raises.
        conn.close()
    return count
|
|
|
|
|
|
def update_user_password(email: str, new_password_hash: str) -> bool:
    """Store a new password hash for ``email`` and bump ``updated_at``.

    Args:
        email: email of the user to update (exact match).
        new_password_hash: an already-hashed password (see ``hash_password``).

    Returns:
        True if a row was updated, False if no user has that email.
    """
    now = datetime.now(timezone.utc).isoformat()
    conn = get_db_connection()
    try:
        cursor = conn.execute(
            "UPDATE users SET password_hash = ?, updated_at = ? WHERE email = ?",
            (new_password_hash, now, email),
        )
        updated = cursor.rowcount > 0
        conn.commit()
    finally:
        # Always release the connection, even if the update/commit raises.
        conn.close()
    return updated
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook: prepare the user store before serving requests.

    On startup the ``users`` table is created if needed and then seeded from
    the ``DE_AA_USER_<n>`` environment variables. A warning is printed when
    the database ends up empty. Shutdown only logs a message.
    """
    startup_steps = (
        ("Initializing database...", init_database),
        ("Loading users from environment...", load_users_from_env),
    )
    for announcement, step in startup_steps:
        print(announcement)
        step()

    total = get_user_count()
    if total:
        print(f"Total users in database: {total}")
    else:
        print("⚠️ WARNING: No users configured! Add DE_AA_USER_1=email:password:name:role to .env")

    yield

    print("Shutting down...")
|
|
|
|
|
|
# App: the FastAPI instance; `lifespan` runs database setup and user seeding
# before the first request is served.
app = FastAPI(
    lifespan=lifespan,
    title="de Aa Authentication API",
    version="1.1.0",
    description="JWT authentication service for de Aa Archiefassistent",
)
|
|
|
|
# CORS: browsers may call this API only from the production frontend or the
# local Vite dev (5173) / preview (4173) servers. All methods and request
# headers are permitted for those origins, and credentials are allowed.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    allow_origins=[
        "https://archief.support",
        "http://localhost:5173",  # Vite dev server
        "http://localhost:4173",  # Vite preview
    ],
)
|
|
|
|
|
|
# Models
class UserCreate(BaseModel):
    """Request body shape for creating a user (email, password, name).

    NOTE(review): no endpoint in this file accepts this model; users are
    seeded from DE_AA_USER_<n> environment variables instead.
    """
    email: EmailStr  # validated as an email address by pydantic
    password: str  # plaintext; would need hash_password() before storage
    name: str  # display name
|
|
|
|
|
|
class UserLogin(BaseModel):
    """Request body for ``POST /auth/login``."""
    email: EmailStr  # looked up in the users table
    password: str  # plaintext, checked against the stored bcrypt hash
|
|
|
|
|
|
class UserResponse(BaseModel):
    """Public user profile returned by /auth/login and /auth/me.

    Deliberately excludes password_hash and the timestamp columns.
    """
    id: str
    email: str
    name: str
    role: str  # e.g. 'user' (the column default) or a role set via DE_AA_USER_<n>
|
|
|
|
|
|
class TokenResponse(BaseModel):
    """Access/refresh token pair; the camelCase field names are the JSON keys clients see."""
    accessToken: str  # short-lived JWT signed with SECRET_KEY
    refreshToken: str  # longer-lived JWT signed with REFRESH_SECRET_KEY
    expiresIn: int  # access-token lifetime in seconds
|
|
|
|
|
|
class LoginResponse(BaseModel):
    """Response of ``POST /auth/login``: the user's profile plus a fresh token pair."""
    user: UserResponse
    tokens: TokenResponse
|
|
|
|
|
|
class RefreshRequest(BaseModel):
    """Request body for ``POST /auth/refresh``."""
    refreshToken: str  # a refresh JWT previously issued by this service
|
|
|
|
|
|
class ChangePasswordRequest(BaseModel):
    """Request body for ``POST /auth/change-password``.

    The endpoint, not this model, enforces the 8-character minimum for
    newPassword.
    """
    currentPassword: str  # must match the stored hash before any change is made
    newPassword: str  # plaintext new password
|
|
|
|
|
|
# bcrypt hash of a random throwaway password, created on first use. It lets
# authenticate_user spend the same bcrypt work on unknown emails as on
# known ones.
_DUMMY_PASSWORD_HASH: Optional[str] = None


def _dummy_password_hash() -> str:
    """Return a lazily created bcrypt hash that, in practice, no submitted password matches."""
    global _DUMMY_PASSWORD_HASH
    if _DUMMY_PASSWORD_HASH is None:
        _DUMMY_PASSWORD_HASH = hash_password(secrets.token_urlsafe(32))
    return _DUMMY_PASSWORD_HASH


def authenticate_user(email: str, password: str) -> Optional[dict]:
    """Return the user row for ``email`` if ``password`` is correct, else None.

    When the email is unknown, a bcrypt check still runs against a dummy hash.
    Without it, unknown emails would return measurably faster than wrong
    passwords, letting callers enumerate registered accounts by timing.
    """
    user = get_user(email)
    if user is None:
        verify_password(password, _dummy_password_hash())
        return None
    if not verify_password(password, user["password_hash"]):
        return None
    return user
|
|
|
|
|
|
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Sign an access JWT with ``data`` plus ``exp`` and ``type="access"`` claims.

    Args:
        data: claims to embed (e.g. ``{"sub": email}``); the dict is not modified.
        expires_delta: token lifetime; defaults to ACCESS_TOKEN_EXPIRE_MINUTES
            when None. An explicit ``timedelta(0)`` is honored (the token is
            already expired) instead of being mistaken for "unset", because
            a zero timedelta is falsy.

    Returns:
        The encoded token, signed with SECRET_KEY using ALGORITHM.
    """
    lifetime = (
        timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
        if expires_delta is None
        else expires_delta
    )
    claims = {**data, "exp": datetime.now(timezone.utc) + lifetime, "type": "access"}
    return jwt.encode(claims, SECRET_KEY, algorithm=ALGORITHM)
|
|
|
|
|
|
def create_refresh_token(data: dict) -> str:
    """Sign a refresh JWT with ``data`` plus ``exp`` and ``type="refresh"`` claims.

    The token lives REFRESH_TOKEN_EXPIRE_DAYS days and is signed with
    REFRESH_SECRET_KEY, which is separate from the access-token key. ``data``
    itself is left unmodified.
    """
    claims = dict(data)
    claims["exp"] = datetime.now(timezone.utc) + timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
    claims["type"] = "refresh"
    return jwt.encode(claims, REFRESH_SECRET_KEY, algorithm=ALGORITHM)
|
|
|
|
|
|
def verify_access_token(token: str) -> Optional[dict]:
    """Decode an access token and return its claims, or None if it is unusable.

    None is returned when decoding fails (any JWTError, e.g. a bad signature
    or an expired token). It is also returned when the ``type`` claim is not
    ``"access"``, so a refresh token cannot stand in for an access token.
    """
    try:
        claims = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        return None
    return claims if claims.get("type") == "access" else None
|
|
|
|
|
|
def verify_refresh_token(token: str) -> Optional[dict]:
    """Decode a refresh token and return its claims, or None if it is unusable.

    The token must decode against REFRESH_SECRET_KEY; any JWTError yields
    None. It must also carry ``type == "refresh"``, so an access token
    cannot be presented here instead.
    """
    try:
        claims = jwt.decode(token, REFRESH_SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        return None
    if claims.get("type") == "refresh":
        return claims
    return None
|
|
|
|
|
|
async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(security)) -> dict:
    """FastAPI dependency that resolves the bearer access token to a user row.

    Returns:
        The user dict from ``get_user`` (including ``password_hash``).

    Raises:
        HTTPException(401): the token is invalid, expired or not an access
            token; it has no ``sub`` claim; or no user has that email. Every
            401 carries ``WWW-Authenticate: Bearer``, as RFC 7235 requires.
            Previously only the first case sent the header.
    """
    def unauthorized(detail: str) -> HTTPException:
        return HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=detail,
            headers={"WWW-Authenticate": "Bearer"},
        )

    payload = verify_access_token(credentials.credentials)
    if not payload:
        raise unauthorized("Invalid or expired token")

    email = payload.get("sub")
    if not email:
        raise unauthorized("Invalid token")

    user = get_user(str(email))
    if not user:
        raise unauthorized("User not found")
    return user
|
|
|
|
|
|
# Routes
@app.get("/health")
async def health_check():
    """Liveness probe that also reports how many users are stored."""
    user_total = get_user_count()
    return {
        "status": "healthy",
        "service": "de-aa-auth",
        "users_loaded": user_total,
    }
|
|
|
|
|
|
@app.post("/auth/login", response_model=LoginResponse)
async def login(credentials: UserLogin):
    """Authenticate with email/password and issue an access/refresh token pair.

    Raises:
        HTTPException(401): unknown email or wrong password. Both cases use
            the same message ("invalid login credentials" in Dutch), so the
            response does not reveal which one occurred.
    """
    # run_in_threadpool ships with FastAPI (re-exported from Starlette).
    from fastapi.concurrency import run_in_threadpool

    # bcrypt is deliberately slow. Calling it directly in this async handler
    # would block the event loop and stall every other request meanwhile.
    user = await run_in_threadpool(authenticate_user, credentials.email, credentials.password)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Ongeldige inloggegevens",
        )

    access_token = create_access_token(data={"sub": user["email"]})
    refresh_token = create_refresh_token(data={"sub": user["email"]})

    return LoginResponse(
        user=UserResponse(
            id=user["id"],
            email=user["email"],
            name=user["name"],
            role=user["role"],
        ),
        tokens=TokenResponse(
            accessToken=access_token,
            refreshToken=refresh_token,
            expiresIn=ACCESS_TOKEN_EXPIRE_MINUTES * 60,
        ),
    )
|
|
|
|
|
|
@app.post("/auth/refresh", response_model=TokenResponse)
async def refresh_tokens(request: RefreshRequest):
    """Exchange a valid refresh token for a new access/refresh token pair.

    The refresh token must be valid, carry a ``sub`` claim, and name a user
    that still exists; otherwise a 401 is raised. Nothing is revoked
    server-side, so the presented refresh token stays usable until it expires.
    """
    def reject(detail: str) -> HTTPException:
        return HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=detail)

    payload = verify_refresh_token(request.refreshToken)
    if not payload:
        raise reject("Invalid or expired refresh token")

    email = payload.get("sub")
    if not email:
        raise reject("Invalid refresh token")

    if not get_user(str(email)):
        raise reject("User not found")

    subject = str(email)
    return TokenResponse(
        accessToken=create_access_token(data={"sub": subject}),
        refreshToken=create_refresh_token(data={"sub": subject}),
        expiresIn=ACCESS_TOKEN_EXPIRE_MINUTES * 60,
    )
|
|
|
|
|
|
@app.post("/auth/logout")
async def logout(current_user: dict = Depends(get_current_user)):
    """Acknowledge a logout from an authenticated user.

    Tokens are stateless JWTs and nothing is invalidated here, so the client
    presumably has to discard its tokens itself. Real revocation would need a
    refresh-token denylist or allowlist checked at refresh time.
    """
    # current_user is unused; the dependency only enforces a valid access token.
    return {"message": "Logged out successfully"}
|
|
|
|
|
|
@app.get("/auth/me", response_model=UserResponse)
async def get_me(current_user: dict = Depends(get_current_user)):
    """Return the authenticated user's public profile (never the password hash)."""
    public_fields = ("id", "email", "name", "role")
    return UserResponse(**{field: current_user[field] for field in public_fields})
|
|
|
|
|
|
@app.post("/auth/change-password")
async def change_password(
    request: ChangePasswordRequest,
    current_user: dict = Depends(get_current_user),
):
    """Change the authenticated user's password after verifying the current one.

    The new hash is stored persistently in the SQLite database. Tokens issued
    before the change are not revoked and stay valid until they expire.

    Raises:
        HTTPException(400): the current password is wrong, or the new password
            is shorter than 8 characters (messages are in Dutch).
        HTTPException(500): the database update matched no row.
    """
    # run_in_threadpool ships with FastAPI (re-exported from Starlette).
    from fastapi.concurrency import run_in_threadpool

    # Both bcrypt calls below are CPU-heavy by design; keep them off the event
    # loop so other requests are not stalled.
    current_ok = await run_in_threadpool(
        verify_password, request.currentPassword, current_user["password_hash"]
    )
    if not current_ok:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Huidig wachtwoord is onjuist",
        )

    # Length counts characters; note bcrypt only uses the first 72 UTF-8 bytes.
    if len(request.newPassword) < 8:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Nieuw wachtwoord moet minimaal 8 tekens bevatten",
        )

    email = current_user["email"]
    new_password_hash = await run_in_threadpool(hash_password, request.newPassword)

    if not update_user_password(email, new_password_hash):
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Fout bij het opslaan van het wachtwoord",
        )

    print(f"✓ Password changed for user: {email}")
    return {"message": "Wachtwoord succesvol gewijzigd"}
|
|
|
|
|
|
if __name__ == "__main__":
    # Development entry point: listen on all interfaces, port 8080, with
    # auto-reload enabled. The "main:app" import string assumes this file is
    # importable as module `main`.
    server_options = {"host": "0.0.0.0", "port": 8080, "reload": True}
    uvicorn.run("main:app", **server_options)
|