feat(middleware): add the middleware for auth service
This commit is contained in:
parent
bf1e476c0b
commit
6256b3377d
3
apps/notification/webapi/middleware/__init__.py
Normal file
3
apps/notification/webapi/middleware/__init__.py
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
from .api_key_middleware import NotificationServiceMiddleware
|
||||||
|
|
||||||
|
__all__ = ['NotificationServiceMiddleware']
|
||||||
183
apps/notification/webapi/middleware/api_key_middleware.py
Normal file
183
apps/notification/webapi/middleware/api_key_middleware.py
Normal file
@ -0,0 +1,183 @@
|
|||||||
|
import httpx
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
import contextvars
|
||||||
|
from datetime import datetime
|
||||||
|
from starlette.requests import Request
|
||||||
|
from fastapi import HTTPException, Response
|
||||||
|
from typing import Dict, Any, Optional
|
||||||
|
from common.log.module_logger import ModuleLogger
|
||||||
|
|
||||||
|
from backend.models.models import UsageLogDoc
|
||||||
|
from backend.infra.api_key_introspect_handler import ApiKeyIntrospectHandler
|
||||||
|
|
||||||
|
# Define context data class
|
||||||
|
class RequestContext:
|
||||||
|
def __init__(self, tenant_name: str = None, product_id: str = None, key_id: str = None):
|
||||||
|
self.tenant_name = tenant_name
|
||||||
|
self.product_id = product_id
|
||||||
|
self.key_id = key_id
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"RequestContext(tenant_name='{self.tenant_name}', product_id='{self.product_id}', key_id='{self.key_id}')"
|
||||||
|
|
||||||
|
# Create context variable, store RequestContext object
|
||||||
|
request_context_var = contextvars.ContextVar('request_context', default=RequestContext())
|
||||||
|
|
||||||
|
class NotificationServiceMiddleware:
|
||||||
|
"""
|
||||||
|
Notification service API Key middleware
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, app):
|
||||||
|
self.app = app
|
||||||
|
self.api_key_introspect_handler = ApiKeyIntrospectHandler()
|
||||||
|
self.module_logger = ModuleLogger(sender_id=NotificationServiceMiddleware)
|
||||||
|
|
||||||
|
async def __call__(self, scope, receive, send):
|
||||||
|
"""
|
||||||
|
Middleware main function, execute before and after request processing
|
||||||
|
"""
|
||||||
|
if scope["type"] != "http":
|
||||||
|
await self.app(scope, receive, send)
|
||||||
|
return
|
||||||
|
|
||||||
|
request = Request(scope, receive)
|
||||||
|
start_time = time.time()
|
||||||
|
validation_result = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 1. Skip paths that do not need validation
|
||||||
|
if self._should_skip_validation(request.url.path):
|
||||||
|
await self.app(scope, receive, send)
|
||||||
|
return
|
||||||
|
|
||||||
|
# 2. Extract API Key from request header
|
||||||
|
api_key = request.headers.get("X-API-Key")
|
||||||
|
# if the API_KEY field is empty, the request can be processed directly without validation.
|
||||||
|
# for compatibility
|
||||||
|
if not api_key or api_key == "":
|
||||||
|
await self.app(scope, receive, send)
|
||||||
|
return
|
||||||
|
|
||||||
|
# 3. Call freeleaps_auth to validate API Key
|
||||||
|
validation_result = await self.api_key_introspect_handler.api_key_introspect(api_key)
|
||||||
|
|
||||||
|
# 4. Validate API Key status
|
||||||
|
if not validation_result.get("active"):
|
||||||
|
response = Response(
|
||||||
|
status_code=403,
|
||||||
|
content=f'{{"error": "{validation_result.get("error")}", "message": "{validation_result.get("message")}"}}',
|
||||||
|
media_type="application/json"
|
||||||
|
)
|
||||||
|
await response(scope, receive, send)
|
||||||
|
return
|
||||||
|
|
||||||
|
# 5. Store validation result in contextvars for later use
|
||||||
|
request_context = RequestContext(
|
||||||
|
tenant_name=validation_result.get("tenant_name"),
|
||||||
|
product_id=validation_result.get("product_id"),
|
||||||
|
key_id=validation_result.get("key_id")
|
||||||
|
)
|
||||||
|
request_context_var.set(request_context)
|
||||||
|
|
||||||
|
# 6. Process request and capture response
|
||||||
|
response_captured = None
|
||||||
|
|
||||||
|
async def send_wrapper(message):
|
||||||
|
nonlocal response_captured
|
||||||
|
if message["type"] == "http.response.start":
|
||||||
|
# Convert bytes headers to string headers
|
||||||
|
headers = {}
|
||||||
|
for header_name, header_value in message.get("headers", []):
|
||||||
|
if isinstance(header_name, bytes):
|
||||||
|
header_name = header_name.decode('latin-1')
|
||||||
|
if isinstance(header_value, bytes):
|
||||||
|
header_value = header_value.decode('latin-1')
|
||||||
|
headers[header_name] = header_value
|
||||||
|
|
||||||
|
response_captured = Response(
|
||||||
|
status_code=message["status"],
|
||||||
|
headers=headers,
|
||||||
|
media_type="application/json"
|
||||||
|
)
|
||||||
|
await send(message)
|
||||||
|
|
||||||
|
await self.app(scope, receive, send_wrapper)
|
||||||
|
|
||||||
|
# 7. Record usage log after request processing
|
||||||
|
if validation_result and response_captured:
|
||||||
|
await self._log_usage(validation_result, request, response_captured, start_time)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
await self.module_logger.log_error(f"Middleware error: {str(e)}")
|
||||||
|
response = Response(
|
||||||
|
status_code=500,
|
||||||
|
content=f'{{"error": "Internal error", "message": "Failed to process request", "details": "{str(e)}"}}',
|
||||||
|
media_type="application/json"
|
||||||
|
)
|
||||||
|
await response(scope, receive, send)
|
||||||
|
|
||||||
|
def _should_skip_validation(self, path: str) -> bool:
|
||||||
|
"""
|
||||||
|
Check if the path should be skipped for validation
|
||||||
|
"""
|
||||||
|
skip_paths = [
|
||||||
|
"/health",
|
||||||
|
"/metrics",
|
||||||
|
"/docs",
|
||||||
|
"/openapi.json",
|
||||||
|
"/favicon.ico"
|
||||||
|
]
|
||||||
|
return any(path.startswith(skip_path) for skip_path in skip_paths)
|
||||||
|
|
||||||
|
async def _log_usage(self, validation_result: Dict[str, Any], request: Request,
|
||||||
|
response: Response, start_time: float) -> None:
|
||||||
|
"""
|
||||||
|
Record API usage log
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# calculate processing time
|
||||||
|
process_time = (time.time() - start_time) * 1000
|
||||||
|
|
||||||
|
# get request body size
|
||||||
|
try:
|
||||||
|
request_body = await request.body()
|
||||||
|
bytes_in = len(request_body) if request.method in ["POST", "PUT", "PATCH"] else 0
|
||||||
|
except Exception:
|
||||||
|
bytes_in = 0
|
||||||
|
|
||||||
|
bytes_out = 0
|
||||||
|
if hasattr(response, 'headers'):
|
||||||
|
content_length = response.headers.get('content-length')
|
||||||
|
if content_length:
|
||||||
|
bytes_out = int(content_length)
|
||||||
|
|
||||||
|
# create usage log document
|
||||||
|
usage_log_doc = UsageLogDoc(
|
||||||
|
timestamp=datetime.utcnow(),
|
||||||
|
tenant_id=validation_result.get("tenant_name"),
|
||||||
|
operation=f"{request.method} {request.url.path}",
|
||||||
|
request_id=request.headers.get("X-Request-ID", "unknown"),
|
||||||
|
units=1, # TODO: adjust according to business logic
|
||||||
|
status="success" if response.status_code < 400 else "error",
|
||||||
|
latency_ms=int(process_time),
|
||||||
|
bytes_in=bytes_in,
|
||||||
|
bytes_out=bytes_out,
|
||||||
|
key_id=validation_result.get("key_id"),
|
||||||
|
extra={
|
||||||
|
"tenant_name": request_context_var.get().tenant_name,
|
||||||
|
"product_id": request_context_var.get().product_id,
|
||||||
|
"scopes": validation_result.get("scopes"),
|
||||||
|
"user_agent": request.headers.get("User-Agent"),
|
||||||
|
"ip_address": request.client.host if request.client else "unknown",
|
||||||
|
"response_status": response.status_code
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# save to database
|
||||||
|
await usage_log_doc.save()
|
||||||
|
await self.module_logger.log_info(f"API Usage logged: {usage_log_doc.operation} for tenant {usage_log_doc.tenant_id}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
await self.module_logger.log_error(f"Failed to log usage: {str(e)}")
|
||||||
Loading…
Reference in New Issue
Block a user