diff --git a/apps/notification/webapi/middleware/__init__.py b/apps/notification/webapi/middleware/__init__.py new file mode 100644 index 0000000..2eb59b0 --- /dev/null +++ b/apps/notification/webapi/middleware/__init__.py @@ -0,0 +1,3 @@ +from .api_key_middleware import NotificationServiceMiddleware + +__all__ = ['NotificationServiceMiddleware'] \ No newline at end of file diff --git a/apps/notification/webapi/middleware/api_key_middleware.py b/apps/notification/webapi/middleware/api_key_middleware.py new file mode 100644 index 0000000..7eaae0d --- /dev/null +++ b/apps/notification/webapi/middleware/api_key_middleware.py @@ -0,0 +1,183 @@ +import httpx +import asyncio +import time +import contextvars +from datetime import datetime +from starlette.requests import Request +from fastapi import HTTPException, Response +from typing import Dict, Any, Optional +from common.log.module_logger import ModuleLogger + +from backend.models.models import UsageLogDoc +from backend.infra.api_key_introspect_handler import ApiKeyIntrospectHandler + +# Define context data class +class RequestContext: + def __init__(self, tenant_name: str = None, product_id: str = None, key_id: str = None): + self.tenant_name = tenant_name + self.product_id = product_id + self.key_id = key_id + + def __repr__(self): + return f"RequestContext(tenant_name='{self.tenant_name}', product_id='{self.product_id}', key_id='{self.key_id}')" + +# Create context variable, store RequestContext object +request_context_var = contextvars.ContextVar('request_context', default=RequestContext()) + +class NotificationServiceMiddleware: + """ + Notification service API Key middleware + """ + + def __init__(self, app): + self.app = app + self.api_key_introspect_handler = ApiKeyIntrospectHandler() + self.module_logger = ModuleLogger(sender_id=NotificationServiceMiddleware) + + async def __call__(self, scope, receive, send): + """ + Middleware main function, execute before and after request processing + """ + if scope["type"] != "http": + await self.app(scope, receive, send) + return + + request = Request(scope, receive) + start_time = time.time() + validation_result = None + + try: + # 1. Skip paths that do not need validation + if self._should_skip_validation(request.url.path): + await self.app(scope, receive, send) + return + + # 2. Extract API Key from request header + api_key = request.headers.get("X-API-Key") + # if the API_KEY field is empty, the request can be processed directly without validation. + # for compatibility + if not api_key or api_key == "": + await self.app(scope, receive, send) + return + + # 3. Call freeleaps_auth to validate API Key + validation_result = await self.api_key_introspect_handler.api_key_introspect(api_key) + + # 4. Validate API Key status + if not validation_result.get("active"): + response = Response( + status_code=403, + content=f'{{"error": "{validation_result.get("error")}", "message": "{validation_result.get("message")}"}}', + media_type="application/json" + ) + await response(scope, receive, send) + return + + # 5. Store validation result in contextvars for later use + request_context = RequestContext( + tenant_name=validation_result.get("tenant_name"), + product_id=validation_result.get("product_id"), + key_id=validation_result.get("key_id") + ) + request_context_var.set(request_context) + + # 6. Process request and capture response + response_captured = None + + async def send_wrapper(message): + nonlocal response_captured + if message["type"] == "http.response.start": + # Convert bytes headers to string headers + headers = {} + for header_name, header_value in message.get("headers", []): + if isinstance(header_name, bytes): + header_name = header_name.decode('latin-1') + if isinstance(header_value, bytes): + header_value = header_value.decode('latin-1') + headers[header_name] = header_value + + response_captured = Response( + status_code=message["status"], + headers=headers, + media_type="application/json" + ) + await send(message) + + await self.app(scope, receive, send_wrapper) + + # 7. Record usage log after request processing + if validation_result and response_captured: + await self._log_usage(validation_result, request, response_captured, start_time) + + except Exception as e: + await self.module_logger.log_error(f"Middleware error: {str(e)}") + response = Response( + status_code=500, + content=f'{{"error": "Internal error", "message": "Failed to process request", "details": "{str(e)}"}}', + media_type="application/json" + ) + await response(scope, receive, send) + + def _should_skip_validation(self, path: str) -> bool: + """ + Check if the path should be skipped for validation + """ + skip_paths = [ + "/health", + "/metrics", + "/docs", + "/openapi.json", + "/favicon.ico" + ] + return any(path.startswith(skip_path) for skip_path in skip_paths) + + async def _log_usage(self, validation_result: Dict[str, Any], request: Request, + response: Response, start_time: float) -> None: + """ + Record API usage log + """ + try: + # calculate processing time + process_time = (time.time() - start_time) * 1000 + + # get request body size + try: + request_body = await request.body() + bytes_in = len(request_body) if request.method in ["POST", "PUT", "PATCH"] else 0 + except Exception: + bytes_in = 0 + + bytes_out = 0 + if hasattr(response, 'headers'): + content_length = response.headers.get('content-length') + if content_length: + bytes_out = int(content_length) + + # create usage log document + usage_log_doc = UsageLogDoc( + timestamp=datetime.utcnow(), + tenant_id=validation_result.get("tenant_name"), + operation=f"{request.method} {request.url.path}", + request_id=request.headers.get("X-Request-ID", "unknown"), + units=1, # TODO: adjust according to business logic + status="success" if response.status_code < 400 else "error", + latency_ms=int(process_time), + bytes_in=bytes_in, + bytes_out=bytes_out, + key_id=validation_result.get("key_id"), + extra={ + "tenant_name": request_context_var.get().tenant_name, + "product_id": request_context_var.get().product_id, + "scopes": validation_result.get("scopes"), + "user_agent": request.headers.get("User-Agent"), + "ip_address": request.client.host if request.client else "unknown", + "response_status": response.status_code + } + ) + + # save to database + await usage_log_doc.save() + await self.module_logger.log_info(f"API Usage logged: {usage_log_doc.operation} for tenant {usage_log_doc.tenant_id}") + + except Exception as e: + await self.module_logger.log_error(f"Failed to log usage: {str(e)}")