# freeleaps-service-hub/apps/notification/webapi/middleware/freeleaps_auth_middleware.py
# 2025-09-17 17:56:24 +08:00
#
# 193 lines
# 8.1 KiB
# Python

import asyncio
import contextvars
import json
import time
from datetime import datetime
from typing import Dict, Any, Optional

import httpx
from fastapi import HTTPException, Response
from starlette.requests import Request

from common.log.module_logger import ModuleLogger
from backend.models.models import UsageLogDoc
from backend.infra.api_key_introspect_handler import ApiKeyIntrospectHandler
# Define context data class
class RequestContext:
    """Identity data resolved from a validated API key for one request.

    Every field defaults to ``None``.
    """

    def __init__(
        self,
        tenant_name: Optional[str] = None,
        product_id: Optional[str] = None,
        key_id: Optional[str] = None,
    ):
        self.tenant_name = tenant_name
        self.product_id = product_id
        self.key_id = key_id

    def __repr__(self):
        # !r prints a missing value as None (not the string 'None') and
        # escapes any quotes contained in the values.
        return (
            f"RequestContext(tenant_name={self.tenant_name!r}, "
            f"product_id={self.product_id!r}, key_id={self.key_id!r})"
        )
# Context variable that holds the RequestContext of the request being handled.
# Unless a value has been set in the current context, .get() returns the
# all-None default instance below.
_DEFAULT_REQUEST_CONTEXT = RequestContext()
request_context_var = contextvars.ContextVar("request_context", default=_DEFAULT_REQUEST_CONTEXT)
class FreeleapsAuthMiddleware:
    """
    Notification service API Key middleware.

    Pure ASGI middleware. For HTTP requests that carry an ``X-API-KEY`` header
    it introspects the key through ``ApiKeyIntrospectHandler``, stores the
    resulting tenant/product/key ids in ``request_context_var`` and, once the
    downstream app has responded, saves a ``UsageLogDoc``.

    Requests on skip-listed paths and requests without an ``X-API-KEY``
    header (kept for compatibility) are forwarded without validation.
    """

    # Paths that bypass API key validation. A path matches when it equals an
    # entry or continues it with "/" (so "/docs" also covers
    # "/docs/oauth2-redirect" but not "/docs-internal").
    _SKIP_PATHS = (
        "/api/_/healthz",  # Health check endpoint
        "/api/_/readyz",  # Readiness check endpoint
        "/api/_/livez",  # Liveness check endpoint
        "/metrics",  # Metrics endpoint
        "/docs",  # API documentation
        "/openapi.json",  # OpenAPI specification
        "/favicon.ico",  # Website icon
    )

    def __init__(self, app):
        """
        Args:
            app: the downstream ASGI application to wrap.
        """
        self.app = app
        self.api_key_introspect_handler = ApiKeyIntrospectHandler()
        self.module_logger = ModuleLogger(sender_id=FreeleapsAuthMiddleware)

    async def __call__(self, scope, receive, send):
        """
        ASGI entry point: validate the API key, run the downstream app and
        record a usage log entry for validated requests.

        Non-HTTP scopes are forwarded untouched. An ``HTTPException`` is turned
        into a JSON error response with its status code. Any other exception
        becomes a JSON 500. No error response is attempted once the downstream
        app has started its response, because ASGI allows only one
        ``http.response.start``; the exception is re-raised instead.
        """
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        request = Request(scope, receive)
        start_time = time.time()
        response_started = False

        async def tracking_send(message):
            # Remember whether a response has begun so the error handlers
            # below never emit a second http.response.start.
            nonlocal response_started
            if message["type"] == "http.response.start":
                response_started = True
            await send(message)

        try:
            # 1. Skip paths that do not need validation
            if self._should_skip_validation(request.url.path):
                await self.module_logger.log_info(f"Path skipped validation: {request.url.path}")
                await self.app(scope, receive, tracking_send)
                return

            # 2. Extract API Key from request header.
            # A missing or empty key is forwarded without validation, for
            # compatibility with existing clients.
            api_key = request.headers.get("X-API-KEY")
            if not api_key:
                await self.module_logger.log_info(f"API Key is empty: {request.url.path}")
                await self.app(scope, receive, tracking_send)
                return

            # 3. Call freeleaps_auth to validate API Key
            validation_result = await self.api_key_introspect_handler.api_key_introspect(api_key)

            # 4. Store validation result in contextvars for later use
            request_context_var.set(RequestContext(
                tenant_name=validation_result.get("tenant_name"),
                product_id=validation_result.get("product_id"),
                key_id=validation_result.get("key_id"),
            ))

            # 5. Process the request. Capture the response status/headers and
            # count the body bytes that actually flow in each direction.
            response_captured = None
            bytes_in = 0
            bytes_out = 0

            async def counting_receive():
                # The body can be read from the ASGI channel only once, so it
                # is measured while the downstream app consumes it.
                nonlocal bytes_in
                message = await receive()
                if message["type"] == "http.request":
                    bytes_in += len(message.get("body", b""))
                return message

            async def capturing_send(message):
                nonlocal response_captured, bytes_out
                if message["type"] == "http.response.start":
                    # Convert bytes headers to string headers
                    headers = {}
                    for header_name, header_value in message.get("headers", []):
                        if isinstance(header_name, bytes):
                            header_name = header_name.decode("latin-1")
                        if isinstance(header_value, bytes):
                            header_value = header_value.decode("latin-1")
                        headers[header_name] = header_value
                    response_captured = Response(
                        status_code=message["status"],
                        headers=headers,
                        media_type="application/json",
                    )
                elif message["type"] == "http.response.body":
                    bytes_out += len(message.get("body", b""))
                await tracking_send(message)

            await self.app(scope, counting_receive, capturing_send)

            # 6. Record usage log after request processing
            if response_captured is not None:
                await self._log_usage(
                    validation_result,
                    request,
                    response_captured,
                    start_time,
                    bytes_in=bytes_in,
                    bytes_out=bytes_out,
                )
        except HTTPException as http_exc:
            # Pass through HTTP exceptions (401, 403, etc.) from auth service
            await self.module_logger.log_info(f"API Key validation failed: {http_exc.status_code} - {http_exc.detail}")
            if response_started:
                # A response is already on the wire; it cannot be replaced.
                raise
            # json.dumps keeps the body valid JSON even when the detail
            # contains quotes or backslashes.
            response = Response(
                status_code=http_exc.status_code,
                content=json.dumps({"error": "Authentication failed", "message": str(http_exc.detail)}),
                media_type="application/json",
            )
            await response(scope, receive, send)
        except Exception as e:
            await self.module_logger.log_error(f"Middleware error: {str(e)}")
            if response_started:
                # A response is already on the wire; it cannot be replaced.
                raise
            # NOTE(review): "details" echoes the raw exception text to the
            # client, which may expose internals - consider removing it.
            response = Response(
                status_code=500,
                content=json.dumps({
                    "error": "Internal error",
                    "message": "Failed to process request",
                    "details": str(e),
                }),
                media_type="application/json",
            )
            await response(scope, receive, send)

    def _should_skip_validation(self, path: str) -> bool:
        """
        Return True if ``path`` is served without API key validation.

        The root path "/" must match exactly. Each entry of ``_SKIP_PATHS``
        matches itself and any sub-path under it (``entry + "/..."``).
        """
        if path == "/":
            return True
        return any(
            path == skip_path or path.startswith(skip_path + "/")
            for skip_path in self._SKIP_PATHS
        )

    async def _log_usage(self, validation_result: Dict[str, Any], request: "Request",
                         response: "Response", start_time: float,
                         bytes_in: Optional[int] = None,
                         bytes_out: Optional[int] = None) -> None:
        """
        Record API usage log.

        This is best-effort: any failure is logged and swallowed, so a logging
        problem never affects the client response.

        Args:
            validation_result: result of ``api_key_introspect``. ``tenant_name``,
                ``key_id`` and ``scopes`` are read from it.
            request: the incoming request (method, path, headers, client).
            response: carries the downstream status code (and headers).
            start_time: ``time.time()`` value taken when the request arrived.
            bytes_in: request body bytes actually received. If omitted, the
                body is re-read from ``request`` (counted for POST/PUT/PATCH
                only). That read typically fails once the downstream app has
                consumed the stream, and then 0 is used.
            bytes_out: response body bytes actually sent. If omitted, the
                ``content-length`` header of ``response`` is used (0 when
                absent or not an integer).
        """
        try:
            # calculate processing time
            process_time = (time.time() - start_time) * 1000

            if bytes_in is None:
                try:
                    request_body = await request.body()
                    bytes_in = len(request_body) if request.method in ["POST", "PUT", "PATCH"] else 0
                except Exception:
                    bytes_in = 0

            if bytes_out is None:
                bytes_out = 0
                content_length = response.headers.get("content-length") if hasattr(response, "headers") else None
                if content_length:
                    try:
                        bytes_out = int(content_length)
                    except ValueError:
                        # A malformed header should not cost us the whole log entry.
                        bytes_out = 0

            context = request_context_var.get()
            # create usage log document
            # NOTE(review): datetime.utcnow() yields a naive UTC timestamp and
            # is deprecated since Python 3.12; kept for storage compatibility.
            usage_log_doc = UsageLogDoc(
                timestamp=datetime.utcnow(),
                tenant_id=validation_result.get("tenant_name"),
                operation=f"{request.method} {request.url.path}",
                request_id=request.headers.get("X-Request-ID", "unknown"),
                units=1,  # TODO: adjust according to business logic
                status="success" if response.status_code < 400 else "error",
                latency_ms=int(process_time),
                bytes_in=bytes_in,
                bytes_out=bytes_out,
                key_id=validation_result.get("key_id"),
                extra={
                    "tenant_name": context.tenant_name,
                    "product_id": context.product_id,
                    "scopes": validation_result.get("scopes"),
                    "user_agent": request.headers.get("User-Agent"),
                    "ip_address": request.client.host if request.client else "unknown",
                    "response_status": response.status_code,
                },
            )
            # save to database
            await usage_log_doc.save()
            await self.module_logger.log_info(f"API Usage logged: {usage_log_doc.operation} for tenant {usage_log_doc.tenant_id}")
        except Exception as e:
            await self.module_logger.log_error(f"Failed to log usage: {str(e)}")