import httpx
import asyncio
import json
import time
import contextvars
from datetime import datetime, timezone
from starlette.requests import Request
from fastapi import HTTPException, Response
from typing import Dict, Any, Optional
from common.log.module_logger import ModuleLogger
from backend.models.models import UsageLogDoc
from backend.infra.api_key_introspect_handler import ApiKeyIntrospectHandler


# Define context data class
class RequestContext:
    """Identity information attached to a request by a validated API key.

    Every field is optional and defaults to None.
    """

    def __init__(self, tenant_name: Optional[str] = None, product_id: Optional[str] = None, key_id: Optional[str] = None):
        self.tenant_name = tenant_name
        self.product_id = product_id
        self.key_id = key_id

    def __repr__(self):
        return f"RequestContext(tenant_name='{self.tenant_name}', product_id='{self.product_id}', key_id='{self.key_id}')"


# Context variable holding the RequestContext of the request being processed.
# When nothing has been set in the current context, get() returns this single
# shared default instance whose fields are all None (it is never mutated).
request_context_var = contextvars.ContextVar('request_context', default=RequestContext())


def _parse_content_length(value: Any) -> int:
    """Return a Content-Length header value (str or bytes) as an int; 0 if missing or malformed."""
    if value is None:
        return 0
    if isinstance(value, bytes):
        value = value.decode("latin-1")
    try:
        return int(value)
    except ValueError:
        return 0


class _ExchangeMeter:
    """Wrap ASGI ``receive``/``send`` callables and record what passes through them.

    Records whether the response has started, its status code, the response's
    declared Content-Length, and the number of request/response body bytes
    actually transferred.
    """

    def __init__(self, receive, send):
        self._receive = receive
        self._send = send
        self.response_started = False
        self.status_code: Optional[int] = None
        self.declared_length = 0  # Content-Length header of the response; 0 if absent/invalid
        self.bytes_in = 0  # request body bytes delivered to the app
        self.bytes_out = 0  # response body bytes sent by the app

    async def receive(self):
        """Forward to the wrapped ``receive``, counting request body bytes."""
        message = await self._receive()
        if message["type"] == "http.request":
            self.bytes_in += len(message.get("body", b""))
        return message

    async def send(self, message):
        """Forward to the wrapped ``send``, recording response status and size."""
        if message["type"] == "http.response.start":
            # Set before sending: once a start message has been handed to the
            # server, a second response must never be attempted.
            self.response_started = True
            self.status_code = message["status"]
            for name, value in message.get("headers", []):
                if isinstance(name, bytes):
                    name = name.decode("latin-1")
                if name.lower() == "content-length":
                    self.declared_length = _parse_content_length(value)
                    break
        elif message["type"] == "http.response.body":
            self.bytes_out += len(message.get("body", b""))
        await self._send(message)


class NotificationServiceMiddleware:
    """
    Notification service API Key middleware (pure ASGI).

    For HTTP requests that carry an ``X-API-KEY`` header and are not on a skip
    path, the key is introspected. Inactive keys receive a 403 JSON response.
    Active keys have their identity stored in ``request_context_var`` for the
    rest of the request, and a UsageLogDoc is saved after the app responds.
    Requests without the header are passed through unvalidated (compatibility).
    """

    # Paths that bypass API key validation. Each entry matches the exact path or
    # any sub-path below it ("/docs" matches "/docs/oauth2-redirect"), but not a
    # sibling that merely shares the prefix ("/docsearch").
    SKIP_PATHS = ("/health", "/metrics", "/docs", "/openapi.json", "/favicon.ico")

    # Methods whose request body size is recorded as bytes_in; others record 0.
    _BODY_METHODS = ("POST", "PUT", "PATCH")

    def __init__(self, app):
        self.app = app
        self.api_key_introspect_handler = ApiKeyIntrospectHandler()
        # NOTE(review): sender_id receives the class object itself rather than a
        # string name — confirm this is what ModuleLogger expects.
        self.module_logger = ModuleLogger(sender_id=NotificationServiceMiddleware)

    async def __call__(self, scope, receive, send):
        """
        Middleware main function, executed around every request.

        Non-HTTP scopes are forwarded untouched. If an exception escapes, a 500
        JSON response is sent, unless the downstream response has already
        started, in which case the exception is re-raised (ASGI allows only one
        ``http.response.start`` per request).
        """
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        request = Request(scope, receive)
        start_time = time.time()
        # All downstream traffic goes through the meter so we know whether a
        # response has started and how many bytes flowed in each direction.
        meter = _ExchangeMeter(receive, send)
        context_token = None

        try:
            # 1. Skip paths that do not need validation
            if self._should_skip_validation(request.url.path):
                await self.app(scope, meter.receive, meter.send)
                return

            # 2. Extract API Key from request header.
            # If the header is missing or empty, the request is processed
            # directly without validation (for compatibility).
            api_key = request.headers.get("X-API-KEY")
            if not api_key:
                await self.app(scope, meter.receive, meter.send)
                return

            # 3. Call freeleaps_auth to validate API Key
            validation_result = await self.api_key_introspect_handler.api_key_introspect(api_key)

            # 4. Validate API Key status
            if not validation_result.get("active"):
                # str() keeps the previous wire format (a missing field is sent
                # as "None"); json.dumps guarantees the body is valid JSON.
                response = self._json_response(403, {
                    "error": str(validation_result.get("error")),
                    "message": str(validation_result.get("message")),
                })
                await response(scope, meter.receive, meter.send)
                return

            # 5. Store validation result in contextvars for later use; the token
            # lets us restore the previous value once the request is done.
            context_token = request_context_var.set(RequestContext(
                tenant_name=validation_result.get("tenant_name"),
                product_id=validation_result.get("product_id"),
                key_id=validation_result.get("key_id"),
            ))

            # 6. Process request
            await self.app(scope, meter.receive, meter.send)

            # 7. Record usage log after request processing
            if meter.status_code is not None:
                await self._log_usage(
                    validation_result,
                    request,
                    Response(status_code=meter.status_code),
                    start_time,
                    bytes_in=meter.bytes_in if request.method in self._BODY_METHODS else 0,
                    # Fall back to the declared length when no body bytes were
                    # seen (e.g. HEAD responses or server-side file sends).
                    bytes_out=meter.bytes_out or meter.declared_length,
                )
        except Exception as e:
            await self.module_logger.log_error(f"Middleware error: {str(e)}")
            if meter.response_started:
                # A response is already in flight; sending another start message
                # would break the ASGI protocol, so let the server handle it.
                raise
            # NOTE(review): "details" exposes the raw exception text to clients;
            # consider removing it if that may leak internals.
            response = self._json_response(500, {
                "error": "Internal error",
                "message": "Failed to process request",
                "details": str(e),
            })
            await response(scope, receive, send)
        finally:
            if context_token is not None:
                request_context_var.reset(context_token)

    @staticmethod
    def _json_response(status_code: int, payload: Dict[str, Any]) -> Response:
        """Build an ``application/json`` Response whose body is ``json.dumps(payload)``."""
        return Response(
            status_code=status_code,
            content=json.dumps(payload),
            media_type="application/json",
        )

    def _should_skip_validation(self, path: str) -> bool:
        """
        Check if the path should be skipped for validation.

        Returns True when ``path`` equals one of SKIP_PATHS or lies beneath one
        of them (separated by "/").
        """
        return any(path == skip_path or path.startswith(skip_path + "/") for skip_path in self.SKIP_PATHS)

    async def _log_usage(self, validation_result: Dict[str, Any], request: Request, response: Response, start_time: float, bytes_in: Optional[int] = None, bytes_out: Optional[int] = None) -> None:
        """
        Record API usage log as a UsageLogDoc and save it.

        Args:
            validation_result: Introspection result for the request's API key;
                ``tenant_name``, ``key_id`` and ``scopes`` are read from it.
            request: The incoming request (method, path, headers, client).
            response: Object exposing the final ``status_code`` and, when
                ``bytes_out`` is not given, a ``headers`` mapping.
            start_time: ``time.time()`` taken when the request arrived.
            bytes_in: Request body size. When None, the request's
                Content-Length header is used for POST/PUT/PATCH, else 0.
            bytes_out: Response body size. When None, the response's
                Content-Length header is used (0 if absent).

        Exceptions are caught and reported via ``module_logger.log_error``
        rather than raised.
        """
        try:
            # calculate processing time in milliseconds
            process_time = (time.time() - start_time) * 1000

            if bytes_in is None:
                # The body cannot be re-read here once the downstream app has
                # consumed it, so rely on the declared size instead.
                if request.method in self._BODY_METHODS:
                    bytes_in = _parse_content_length(request.headers.get("content-length"))
                else:
                    bytes_in = 0
            if bytes_out is None:
                headers = getattr(response, "headers", None)
                bytes_out = _parse_content_length(headers.get("content-length")) if headers is not None else 0

            context = request_context_var.get()

            # create usage log document
            usage_log_doc = UsageLogDoc(
                timestamp=datetime.now(timezone.utc),
                tenant_id=validation_result.get("tenant_name"),
                operation=f"{request.method} {request.url.path}",
                request_id=request.headers.get("X-Request-ID", "unknown"),
                units=1,  # TODO: adjust according to business logic
                status="success" if response.status_code < 400 else "error",
                latency_ms=int(process_time),
                bytes_in=bytes_in,
                bytes_out=bytes_out,
                key_id=validation_result.get("key_id"),
                extra={
                    "tenant_name": context.tenant_name,
                    "product_id": context.product_id,
                    "scopes": validation_result.get("scopes"),
                    "user_agent": request.headers.get("User-Agent"),
                    "ip_address": request.client.host if request.client else "unknown",
                    "response_status": response.status_code,
                },
            )

            # save to database
            await usage_log_doc.save()
            await self.module_logger.log_info(f"API Usage logged: {usage_log_doc.operation} for tenant {usage_log_doc.tenant_id}")

        except Exception as e:
            await self.module_logger.log_error(f"Failed to log usage: {str(e)}")