feat(middleware): add middleware for authentication

This commit is contained in:
YuehuCao 2025-10-11 09:51:02 +08:00
parent 7a832371c7
commit 80a6beb1ed
5 changed files with 290 additions and 0 deletions

View File

@ -11,6 +11,7 @@ from webapi.providers import metrics
# from webapi.providers import scheduler
from webapi.providers import exception_handler
from webapi.providers import middleware
from .freeleaps_app import FreeleapsApp
from common.config.app_settings import app_settings
@ -20,6 +21,10 @@ def create_app() -> FastAPI:
app = FreeleapsApp()
register_logger()
# 1. Register middleware firstly
register(app, middleware)
# 2. Register other providers
register(app, exception_handler)
register(app, database)
register(app, router)

View File

@ -0,0 +1,4 @@
from .freeleaps_auth_middleware import FreeleapsAuthMiddleware
from .database_middleware import DatabaseMiddleware
__all__ = ['FreeleapsAuthMiddleware', 'DatabaseMiddleware']

View File

@ -0,0 +1,78 @@
from fastapi import Request, status
from fastapi.responses import JSONResponse
from webapi.middleware.freeleaps_auth_middleware import request_context_var
from common.log.module_logger import ModuleLogger
class DatabaseMiddleware:
def __init__(self, app):
self.app = app
self.module_logger = ModuleLogger(sender_id=DatabaseMiddleware)
async def __call__(self, scope, receive, send):
if scope["type"] != "http":
return await self.app(scope, receive, send)
request = Request(scope, receive)
# Get tenant id from auth context (set by FreeleapsAuthMiddleware)
product_id = None
try:
ctx = request_context_var.get()
product_id = getattr(ctx, "product_id", None)
await self.module_logger.log_info(f"Retrieved product_id from auth context: {product_id}")
except Exception as e:
await self.module_logger.log_error(f"Failed to get auth context: {str(e)}")
product_id = None
# Get tenant cache and main database from app state
try:
tenant_cache = request.app.state.tenant_cache
main_db = request.app.state.main_db
await self.module_logger.log_info(f"Retrieved app state - tenant_cache: {'success' if tenant_cache is not None else 'fail'}, main_db: {'success' if main_db is not None else 'fail'}")
except Exception as e:
await self.module_logger.log_error(f"Failed to get app state: {str(e)}")
response = JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content={"detail": "Database not properly initialized"}
)
return await response(scope, receive, send)
if not product_id:
# Compatibility / public routes: use main database with tenant models initialized
await self.module_logger.log_info(f"No product_id - using main database for path: {request.url.path}")
# Get main database with Beanie initialized for tenant models
main_db_initialized = await tenant_cache.get_main_db_initialized()
request.state.db = main_db_initialized
request.state.product_id = None
await self.module_logger.log_info(f"Successfully initialized main database with tenant models")
return await self.app(scope, receive, send)
try:
# Get tenant-specific database with Beanie already initialized (cached)
await self.module_logger.log_info(f"Attempting to get tenant database for product_id: {product_id}")
tenant_db = await tenant_cache.get_initialized_db(product_id)
request.state.db = tenant_db
request.state.product_id = product_id
await self.module_logger.log_info(f"Successfully retrieved cached tenant database with Beanie for product_id: {product_id}")
return await self.app(scope, receive, send)
except ValueError as e:
# Handle tenant not found or inactive (ValueError from TenantDBCache)
await self.module_logger.log_error(f"Tenant error for {product_id}: {str(e)}")
response = JSONResponse(
status_code=status.HTTP_404_NOT_FOUND,
content={"detail": str(e)}
)
return await response(scope, receive, send)
except Exception as e:
await self.module_logger.log_error(f"Database error for tenant {product_id}: {str(e)}")
response = JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content={"detail": "Database connection error"}
)
return await response(scope, receive, send)

View File

@ -0,0 +1,192 @@
import httpx
import asyncio
import time
import contextvars
from datetime import datetime
from starlette.requests import Request
from fastapi import HTTPException, Response
from typing import Dict, Any, Optional
from common.log.module_logger import ModuleLogger
from backend.models.user.models import UsageLogDoc
from backend.infra.api_key_introspect_handler import ApiKeyIntrospectHandler
# Define context data class
class RequestContext:
def __init__(self, tenant_name: str = None, product_id: str = None, key_id: str = None):
self.tenant_name = tenant_name
self.product_id = product_id
self.key_id = key_id
def __repr__(self):
return f"RequestContext(tenant_name='{self.tenant_name}', product_id='{self.product_id}', key_id='{self.key_id}')"
# Create context variable, store RequestContext object
request_context_var = contextvars.ContextVar('request_context', default=RequestContext())
class FreeleapsAuthMiddleware:
"""
Notification service API Key middleware
"""
def __init__(self, app):
self.app = app
self.api_key_introspect_handler = ApiKeyIntrospectHandler()
self.module_logger = ModuleLogger(sender_id=FreeleapsAuthMiddleware)
async def __call__(self, scope, receive, send):
"""
Middleware main function, execute before and after request processing
"""
if scope["type"] != "http":
await self.app(scope, receive, send)
return
request = Request(scope, receive)
start_time = time.time()
validation_result = None
try:
# 1. Skip paths that do not need validation
if self._should_skip_validation(request.url.path):
await self.module_logger.log_info(f"Path skipped validation: {request.url.path}")
await self.app(scope, receive, send)
return
# 2. Extract API Key from request header
api_key = request.headers.get("X-API-KEY")
# if the API_KEY field is empty, the request can be processed directly without validation.
# for compatibility
if not api_key or api_key == "":
await self.module_logger.log_info(f"API Key is empty: {request.url.path}")
await self.app(scope, receive, send)
return
# 3. Call freeleaps_auth to validate API Key
validation_result = await self.api_key_introspect_handler.api_key_introspect(api_key)
# 4. Store validation result in contextvars for later use
request_context = RequestContext(
tenant_name=validation_result.get("tenant_name"),
product_id=validation_result.get("product_id"),
key_id=validation_result.get("key_id")
)
request_context_var.set(request_context)
# 6. Process request and capture response
response_captured = None
async def send_wrapper(message):
nonlocal response_captured
if message["type"] == "http.response.start":
# Convert bytes headers to string headers
headers = {}
for header_name, header_value in message.get("headers", []):
if isinstance(header_name, bytes):
header_name = header_name.decode('latin-1')
if isinstance(header_value, bytes):
header_value = header_value.decode('latin-1')
headers[header_name] = header_value
response_captured = Response(
status_code=message["status"],
headers=headers,
media_type="application/json"
)
await send(message)
await self.app(scope, receive, send_wrapper)
# 7. Record usage log after request processing
if response_captured:
await self._log_usage(validation_result, request, response_captured, start_time)
except HTTPException as http_exc:
# Pass through HTTP exceptions (401, 403, etc.) from auth service
await self.module_logger.log_info(f"API Key validation failed: {http_exc.status_code} - {http_exc.detail}")
response = Response(
status_code=http_exc.status_code,
content=f'{{"error": "Authentication failed", "message": "{str(http_exc.detail)}"}}',
media_type="application/json"
)
await response(scope, receive, send)
except Exception as e:
await self.module_logger.log_error(f"Middleware error: {str(e)}")
response = Response(
status_code=500,
content=f'{{"error": "Internal error", "message": "Failed to process request", "details": "{str(e)}"}}',
media_type="application/json"
)
await response(scope, receive, send)
def _should_skip_validation(self, path: str) -> bool:
"""
Check if the path should be skipped for validation
"""
skip_paths = [
"/api/_/healthz", # Health check endpoint
"/api/_/readyz", # Readiness check endpoint
"/api/_/livez", # Liveness check endpoint
"/metrics", # Metrics endpoint
"/docs", # API documentation
"/openapi.json", # OpenAPI specification
"/favicon.ico" # Website icon
]
# Check exact match for root path
if path == "/":
return True
# Check startswith for other paths
return any(path.startswith(skip_path) for skip_path in skip_paths)
async def _log_usage(self, validation_result: Dict[str, Any], request: Request,
response: Response, start_time: float) -> None:
"""
Record API usage log
"""
try:
# calculate processing time
process_time = (time.time() - start_time) * 1000
# get request body size
try:
request_body = await request.body()
bytes_in = len(request_body) if request.method in ["POST", "PUT", "PATCH"] else 0
except Exception:
bytes_in = 0
bytes_out = 0
if hasattr(response, 'headers'):
content_length = response.headers.get('content-length')
if content_length:
bytes_out = int(content_length)
# create usage log document
usage_log_doc = UsageLogDoc(
timestamp=datetime.utcnow(),
tenant_id=validation_result.get("tenant_name"),
operation=f"{request.method} {request.url.path}",
request_id=request.headers.get("X-Request-ID", "unknown"),
units=1, # TODO: adjust according to business logic
status="success" if response.status_code < 400 else "error",
latency_ms=int(process_time),
bytes_in=bytes_in,
bytes_out=bytes_out,
key_id=validation_result.get("key_id"),
extra={
"tenant_name": request_context_var.get().tenant_name,
"product_id": request_context_var.get().product_id,
"scopes": validation_result.get("scopes"),
"user_agent": request.headers.get("User-Agent"),
"ip_address": request.client.host if request.client else "unknown",
"response_status": response.status_code
}
)
# save to database
await usage_log_doc.save()
await self.module_logger.log_info(f"API Usage logged: {usage_log_doc.operation} for tenant {usage_log_doc.tenant_id}")
except Exception as e:
await self.module_logger.log_error(f"Failed to log usage: {str(e)}")

View File

@ -0,0 +1,11 @@
from webapi.middleware.freeleaps_auth_middleware import FreeleapsAuthMiddleware
from webapi.middleware.database_middleware import DatabaseMiddleware
def register(app):
"""
Register middleware to FastAPI application
"""
# Register middlewares
app.add_middleware(DatabaseMiddleware)
app.add_middleware(FreeleapsAuthMiddleware)