Merge pull request 'fix: change notification service model layer to not using Beanie' (#94) from fix/auth-model into dev

Reviewed-on: freeleaps/freeleaps-service-hub#94
Reviewed-by: jingyao1991 <jingyao1991@noreply.gitea.freeleaps.mathmast.com>
This commit is contained in:
icecheng 2025-10-27 01:52:52 +00:00
commit d90df6bf83
4 changed files with 491 additions and 60 deletions

View File

@ -0,0 +1,415 @@
"""
BaseDoc - A custom document class that provides Beanie-like interface using direct MongoDB operations
"""
import asyncio
from datetime import datetime, timezone
from typing import Optional, List, Dict, Any, Type, Union
from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorDatabase
from pydantic import BaseModel
from pydantic._internal._model_construction import ModelMetaclass
from common.config.app_settings import app_settings
class QueryExpression:
"""Query expression for field comparisons"""
def __init__(self, field_name: str):
self.field_name = field_name
def __eq__(self, other: Any) -> Dict[str, Any]:
"""Handle field == value comparisons"""
return {self.field_name: other}
def __ne__(self, other: Any) -> Dict[str, Any]:
"""Handle field != value comparisons"""
return {self.field_name: {"$ne": other}}
def __gt__(self, other: Any) -> Dict[str, Any]:
"""Handle field > value comparisons"""
return {self.field_name: {"$gt": other}}
def __lt__(self, other: Any) -> Dict[str, Any]:
"""Handle field < value comparisons"""
return {self.field_name: {"$lt": other}}
def __ge__(self, other: Any) -> Dict[str, Any]:
"""Handle field >= value comparisons"""
return {self.field_name: {"$gte": other}}
def __le__(self, other: Any) -> Dict[str, Any]:
"""Handle field <= value comparisons"""
return {self.field_name: {"$lte": other}}
class FieldDescriptor:
"""Descriptor for field access like Beanie's field == value pattern"""
def __init__(self, field_name: str, field_type: type):
self.field_name = field_name
self.field_type = field_type
def __get__(self, instance: Any, owner: type) -> Any:
"""
- Class access (instance is None): return QueryExpression for building queries
- Instance access (instance is not None): return the actual field value
"""
if instance is None:
return QueryExpression(self.field_name)
return instance.__dict__.get(self.field_name)
def __set__(self, instance: Any, value: Any) -> None:
"""Set instance field value with type validation (compatible with Pydantic validation)"""
if not isinstance(value, self.field_type):
raise TypeError(f"Field {self.field_name} must be {self.field_type}")
instance.__dict__[self.field_name] = value
class FieldCondition:
"""Represents a field condition for MongoDB queries"""
def __init__(self, field_name: str, value: Any, operator: str = "$eq"):
self.field_name = field_name
self.value = value
self.operator = operator
self.left = self # For compatibility with existing condition parsing
self.right = value
# Module-level variables for database connection
_db: Optional[AsyncIOMotorDatabase] = None
_client: Optional[AsyncIOMotorClient] = None
# Context variable for tenant database
import contextvars
_tenant_db_context: contextvars.ContextVar[Optional[AsyncIOMotorDatabase]] = contextvars.ContextVar('tenant_db', default=None)
class QueryModelMeta(ModelMetaclass):
"""Metaclass: automatically create FieldDescriptor for model fields"""
def __new__(cls, name: str, bases: tuple, namespace: dict):
# Get model field annotations (like name: str -> "name" and str)
annotations = namespace.get("__annotations__", {})
# Create the class first using Pydantic's metaclass
new_class = super().__new__(cls, name, bases, namespace)
# After Pydantic processes the fields, add the descriptors as class attributes
for field_name, field_type in annotations.items():
if field_name != 'id': # Skip the id field as it's handled specially
# Add the descriptor as a class attribute
setattr(new_class, field_name, FieldDescriptor(field_name, field_type))
return new_class
def __getattr__(cls, name: str):
"""Handle field access like Doc.field_name for query building"""
# Check if this is a field that exists in the model
if hasattr(cls, 'model_fields') and name in cls.model_fields:
return QueryExpression(name)
raise AttributeError(f"'{cls.__name__}' object has no attribute '{name}'")
class BaseDoc(BaseModel, metaclass=QueryModelMeta):
"""
Base document class that provides Beanie-like interface using direct MongoDB operations.
All model classes should inherit from this instead of Beanie's Document.
"""
id: Optional[str] = None # MongoDB _id field
def model_dump(self, **kwargs):
"""Override model_dump to exclude field descriptors"""
# Get the default model_dump result
result = super().model_dump(**kwargs)
# Remove any field descriptors that might have been included
filtered_result = {}
for key, value in result.items():
if not isinstance(value, FieldDescriptor):
filtered_result[key] = value
return filtered_result
@classmethod
def field(cls, field_name: str) -> QueryExpression:
"""Get a field expression for query building"""
return QueryExpression(field_name)
@classmethod
async def _get_database(cls) -> AsyncIOMotorDatabase:
"""Get database connection using pure AsyncIOMotorClient"""
# Try to get tenant database from context first
tenant_db = _tenant_db_context.get()
if tenant_db is not None:
return tenant_db
# Fallback to global database connection
global _db, _client
if _db is None:
_client = AsyncIOMotorClient(app_settings.MONGODB_URI)
_db = _client[app_settings.MONGODB_NAME]
return _db
@classmethod
def set_tenant_database(cls, db: AsyncIOMotorDatabase):
"""Set the tenant database for this context"""
_tenant_db_context.set(db)
@classmethod
def _get_collection_name(cls) -> str:
"""Get collection name from Settings or class name"""
if hasattr(cls, 'Settings') and hasattr(cls.Settings, 'name'):
return cls.Settings.name
else:
# Convert class name to snake_case for collection name
import re
name = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', cls.__name__)
return re.sub('([a-z0-9])([A-Z])', r'\1_\2', name).lower()
@classmethod
def find(cls, *conditions) -> 'QueryBuilder':
"""Find documents matching conditions - returns QueryBuilder for chaining"""
return QueryBuilder(cls, conditions)
@classmethod
async def find_one(cls, *conditions) -> Optional['BaseDoc']:
"""Find one document matching conditions"""
db = await cls._get_database()
collection_name = cls._get_collection_name()
collection = db[collection_name]
# Convert Beanie-style conditions to MongoDB query
query = cls._convert_conditions_to_query(conditions)
doc = await collection.find_one(query)
if doc:
# Extract MongoDB _id and convert to string
mongo_id = doc.pop('_id', None)
# Filter doc to only include fields defined in the model
model_fields = set(cls.model_fields.keys())
filtered_doc = {k: v for k, v in doc.items() if k in model_fields}
# Add the id field
if mongo_id:
filtered_doc['id'] = str(mongo_id)
return cls(**filtered_doc)
return None
@classmethod
async def get(cls, doc_id: str) -> Optional['BaseDoc']:
"""Get document by ID"""
from bson import ObjectId
try:
object_id = ObjectId(doc_id)
except:
return None
db = await cls._get_database()
collection_name = cls._get_collection_name()
collection = db[collection_name]
doc = await collection.find_one({"_id": object_id})
if doc:
# Extract MongoDB _id and convert to string
mongo_id = doc.pop('_id', None)
# Filter doc to only include fields defined in the model
model_fields = set(cls.model_fields.keys())
filtered_doc = {k: v for k, v in doc.items() if k in model_fields}
# Add the id field
if mongo_id:
filtered_doc['id'] = str(mongo_id)
return cls(**filtered_doc)
return None
@classmethod
def _convert_conditions_to_query(cls, conditions) -> Dict[str, Any]:
"""Convert Beanie-style conditions to MongoDB query"""
if not conditions:
return {}
query = {}
for condition in conditions:
if isinstance(condition, dict):
# Handle QueryExpression results (dictionaries) and direct dictionary queries
query.update(condition)
elif isinstance(condition, FieldCondition):
# Handle legacy FieldCondition objects
if condition.operator == "$eq":
query[condition.field_name] = condition.value
else:
query[condition.field_name] = {condition.operator: condition.value}
elif hasattr(condition, 'left') and hasattr(condition, 'right'):
# Handle field == value conditions
field_name = condition.left.name
value = condition.right
query[field_name] = value
elif hasattr(condition, '__dict__'):
# Handle complex conditions like FLID.identity == value
if hasattr(condition, 'left') and hasattr(condition, 'right'):
left = condition.left
if hasattr(left, 'name') and hasattr(left, 'left'):
# Nested field access like FLID.identity
field_name = f"{left.left.name}.{left.name}"
value = condition.right
query[field_name] = value
else:
field_name = left.name
value = condition.right
query[field_name] = value
return query
def _convert_decimals_to_float(self, obj):
"""Convert Decimal objects to float for MongoDB compatibility"""
from decimal import Decimal
if isinstance(obj, Decimal):
return float(obj)
elif isinstance(obj, dict):
return {key: self._convert_decimals_to_float(value) for key, value in obj.items()}
elif isinstance(obj, list):
return [self._convert_decimals_to_float(item) for item in obj]
else:
return obj
async def create(self) -> 'BaseDoc':
"""Create this document in the database"""
db = await self._get_database()
collection_name = self._get_collection_name()
collection = db[collection_name]
# Convert to dict and insert, excluding field descriptors
doc_dict = self.model_dump(exclude={'id'})
# Convert Decimal objects to float for MongoDB compatibility
doc_dict = self._convert_decimals_to_float(doc_dict)
result = await collection.insert_one(doc_dict)
# Set the id field from the inserted document
if result.inserted_id:
self.id = str(result.inserted_id)
# Return the created document
return self
async def save(self) -> 'BaseDoc':
"""Save this document to the database (update if exists, create if not)"""
db = await self._get_database()
collection_name = self._get_collection_name()
collection = db[collection_name]
# Convert to dict, excluding field descriptors
doc_dict = self.model_dump(exclude={'id'})
# Convert Decimal objects to float for MongoDB compatibility
doc_dict = self._convert_decimals_to_float(doc_dict)
# Try to find existing document by user_id or other unique fields
query = {}
if hasattr(self, 'user_id'):
query['user_id'] = self.user_id
elif hasattr(self, 'email'):
query['email'] = self.email
elif hasattr(self, 'mobile'):
query['mobile'] = self.mobile
elif hasattr(self, 'auth_code'):
query['auth_code'] = self.auth_code
if query:
# Update existing document
result = await collection.update_one(query, {"$set": doc_dict}, upsert=True)
# If it was an insert, set the id field
if result.upserted_id:
self.id = str(result.upserted_id)
else:
# Insert new document
result = await collection.insert_one(doc_dict)
if result.inserted_id:
self.id = str(result.inserted_id)
return self
async def delete(self) -> bool:
"""Delete this document from the database"""
db = await self._get_database()
collection_name = self._get_collection_name()
collection = db[collection_name]
# Try to find existing document by user_id or other unique fields
query = {}
if hasattr(self, 'user_id'):
query['user_id'] = self.user_id
elif hasattr(self, 'email'):
query['email'] = self.email
elif hasattr(self, 'mobile'):
query['mobile'] = self.mobile
elif hasattr(self, 'auth_code'):
query['auth_code'] = self.auth_code
if query:
result = await collection.delete_one(query)
return result.deleted_count > 0
return False
class QueryBuilder:
"""Query builder for chaining operations like Beanie's QueryBuilder"""
def __init__(self, model_class: Type[BaseDoc], conditions: tuple):
self.model_class = model_class
self.conditions = conditions
self._limit_value: Optional[int] = None
self._skip_value: Optional[int] = None
def limit(self, n: int) -> 'QueryBuilder':
"""Limit number of results"""
self._limit_value = n
return self
def skip(self, n: int) -> 'QueryBuilder':
"""Skip number of results"""
self._skip_value = n
return self
async def to_list(self) -> List[BaseDoc]:
"""Convert query to list of documents"""
db = await self.model_class._get_database()
collection_name = self.model_class._get_collection_name()
collection = db[collection_name]
# Convert conditions to MongoDB query
query = self.model_class._convert_conditions_to_query(self.conditions)
# Build cursor
cursor = collection.find(query)
if self._skip_value:
cursor = cursor.skip(self._skip_value)
if self._limit_value:
cursor = cursor.limit(self._limit_value)
# Execute query and convert to model instances
docs = await cursor.to_list(length=None)
results = []
for doc in docs:
# Extract MongoDB _id and convert to string
mongo_id = doc.pop('_id', None)
# Filter doc to only include fields defined in the model
model_fields = set(self.model_class.model_fields.keys())
filtered_doc = {k: v for k, v in doc.items() if k in model_fields}
# Add the id field
if mongo_id:
filtered_doc['id'] = str(mongo_id)
results.append(self.model_class(**filtered_doc))
return results
async def first_or_none(self) -> Optional[BaseDoc]:
"""Get first result or None"""
results = await self.limit(1).to_list()
return results[0] if results else None
async def count(self) -> int:
"""Count number of matching documents"""
db = await self.model_class._get_database()
collection_name = self.model_class._get_collection_name()
collection = db[collection_name]
query = self.model_class._convert_conditions_to_query(self.conditions)
return await collection.count_documents(query)

View File

@ -1,11 +1,11 @@
from beanie import Document
from datetime import datetime from datetime import datetime
from typing import Optional, List from typing import Optional, List
from common.constants.region import UserRegion from common.constants.region import UserRegion
from common.constants.email import EmailSendStatus, BounceType from common.constants.email import EmailSendStatus, BounceType
from backend.models.base_doc import BaseDoc
class MessageTemplateDoc(Document): class MessageTemplateDoc(BaseDoc):
template_id: str template_id: str
tenant_id: Optional[str] = None tenant_id: Optional[str] = None
region: UserRegion region: UserRegion
@ -23,7 +23,7 @@ class MessageTemplateDoc(Document):
"region" "region"
] ]
class EmailSenderDoc(Document): class EmailSenderDoc(BaseDoc):
tenant_id: str tenant_id: str
email_sender: Optional[str] = None email_sender: Optional[str] = None
is_active: bool = True is_active: bool = True
@ -32,7 +32,7 @@ class EmailSenderDoc(Document):
name = "email_sender_doc" name = "email_sender_doc"
indexes = ["tenant_id"] indexes = ["tenant_id"]
class EmailSendStatusDoc(Document): class EmailSendStatusDoc(BaseDoc):
email_id: str email_id: str
tenant_id: str tenant_id: str
email_sender: Optional[str] = None email_sender: Optional[str] = None
@ -57,7 +57,7 @@ class EmailSendStatusDoc(Document):
"tenant_id" "tenant_id"
] ]
class EmailTrackingDoc(Document): class EmailTrackingDoc(BaseDoc):
email_id: str email_id: str
tenant_id: str tenant_id: str
recipient_email: str recipient_email: str
@ -85,7 +85,7 @@ class EmailTrackingDoc(Document):
"tenant_id" "tenant_id"
] ]
class EmailBounceDoc(Document): class EmailBounceDoc(BaseDoc):
email: str email: str
tenant_id: str tenant_id: str
email_id: Optional[str] = None email_id: Optional[str] = None
@ -106,7 +106,7 @@ class EmailBounceDoc(Document):
"tenant_id" "tenant_id"
] ]
class UsageLogDoc(Document): class UsageLogDoc(BaseDoc):
timestamp: datetime = datetime.utcnow() # timestamp timestamp: datetime = datetime.utcnow() # timestamp
tenant_id: str # tenant id tenant_id: str # tenant id
operation: str # operation type operation: str # operation type

View File

@ -39,32 +39,33 @@ class DatabaseMiddleware:
return await response(scope, receive, send) return await response(scope, receive, send)
if not product_id: if not product_id:
# Compatibility / public routes: use main database with tenant models initialized # Compatibility / public routes: use main database with BaseDoc context set
await self.module_logger.log_info(f"No product_id - using main database for path: {request.url.path}") await self.module_logger.log_info(f"No product_id - using main database for path: {request.url.path}")
# Get main database with Beanie initialized for tenant models # Get main database with BaseDoc context set for all models
main_db_initialized = await tenant_cache.get_main_db_initialized() main_db_initialized = await tenant_cache.get_main_db_initialized()
request.state.db = main_db_initialized request.state.db = main_db_initialized
request.state.product_id = None request.state.product_id = None
await self.module_logger.log_info(f"Successfully initialized main database with tenant models") await self.module_logger.log_info(f"Successfully set BaseDoc context for main database")
return await self.app(scope, receive, send) return await self.app(scope, receive, send)
try: try:
# Get tenant-specific database with Beanie already initialized (cached) # Get tenant-specific database with BaseDoc context set (cached)
await self.module_logger.log_info(f"Attempting to get tenant database for product_id: {product_id}") await self.module_logger.log_info(f"Attempting to get tenant database for product_id: {product_id}")
tenant_db = await tenant_cache.get_initialized_db(product_id) tenant_db = await tenant_cache.get_initialized_db(product_id)
request.state.db = tenant_db request.state.db = tenant_db
request.state.product_id = product_id request.state.product_id = product_id
await self.module_logger.log_info(f"Successfully retrieved cached tenant database with Beanie for product_id: {product_id}") await self.module_logger.log_info(f"Successfully retrieved cached tenant database with BaseDoc context for product_id: {product_id}")
return await self.app(scope, receive, send)
except HTTPException as e: except ValueError as e:
# Handle tenant not found or inactive (HTTPException from TenantDBCache) # Handle tenant not found or inactive (ValueError from TenantDBCache)
await self.module_logger.log_error(f"Tenant error for {product_id}: [{e.status_code}] {e.detail}") await self.module_logger.log_error(f"Tenant error for {product_id}: {str(e)}")
response = JSONResponse( response = JSONResponse(
status_code=e.status_code, status_code=status.HTTP_404_NOT_FOUND,
content={"detail": e.detail} content={"detail": str(e)}
) )
return await response(scope, receive, send) return await response(scope, receive, send)

View File

@ -1,5 +1,4 @@
from webapi.config.site_settings import site_settings from webapi.config.site_settings import site_settings
from beanie import init_beanie
from fastapi import HTTPException from fastapi import HTTPException
from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorDatabase from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorDatabase
from backend.models.models import MessageTemplateDoc, EmailSenderDoc, EmailSendStatusDoc, EmailTrackingDoc, EmailBounceDoc, UsageLogDoc from backend.models.models import MessageTemplateDoc, EmailSenderDoc, EmailSendStatusDoc, EmailTrackingDoc, EmailBounceDoc, UsageLogDoc
@ -47,7 +46,7 @@ class TenantDBCache:
self.module_logger = ModuleLogger(sender_id="TenantDBCache") self.module_logger = ModuleLogger(sender_id="TenantDBCache")
async def get_initialized_db(self, product_id: str) -> AsyncIOMotorDatabase: async def get_initialized_db(self, product_id: str) -> AsyncIOMotorDatabase:
"""Get tenant database with Beanie already initialized""" """Get tenant database with BaseDoc context set"""
# fast-path: check if client is cached # fast-path: check if client is cached
cached_client = self._cache.get(product_id) cached_client = self._cache.get(product_id)
@ -58,9 +57,11 @@ class TenantDBCache:
# Get fresh database instance from cached client # Get fresh database instance from cached client
db = cached_client.get_default_database() db = cached_client.get_default_database()
if db is not None: if db is not None:
# Initialize Beanie for this fresh database instance # Set tenant database context for BaseDoc
await init_beanie(database=db, document_models=tenant_document_models) MessageTemplateDoc.set_tenant_database(db)
await self.module_logger.log_info(f"Beanie initialization completed for {product_id} using cached client") EmailSenderDoc.set_tenant_database(db)
EmailSendStatusDoc.set_tenant_database(db)
await self.module_logger.log_info(f"BaseDoc tenant context set for {product_id} using cached client")
return db return db
else: else:
await self.module_logger.log_error(f"No default database found for cached client {product_id}") await self.module_logger.log_error(f"No default database found for cached client {product_id}")
@ -78,9 +79,11 @@ class TenantDBCache:
# Get fresh database instance from cached client # Get fresh database instance from cached client
db = cached_client.get_default_database() db = cached_client.get_default_database()
if db is not None: if db is not None:
# Initialize Beanie for this fresh database instance # Set tenant database context for BaseDoc
await init_beanie(database=db, document_models=tenant_document_models) MessageTemplateDoc.set_tenant_database(db)
await self.module_logger.log_info(f"Beanie initialization completed for {product_id} using cached client (double-check)") EmailSenderDoc.set_tenant_database(db)
EmailSendStatusDoc.set_tenant_database(db)
await self.module_logger.log_info(f"BaseDoc tenant context set for {product_id} using cached client (double-check)")
return db return db
else: else:
await self.module_logger.log_error(f"No default database found for cached client {product_id}") await self.module_logger.log_error(f"No default database found for cached client {product_id}")
@ -129,9 +132,11 @@ class TenantDBCache:
headers={"X-Error-Message": f"No default database found for tenant {product_id}"} headers={"X-Error-Message": f"No default database found for tenant {product_id}"}
) )
# Initialize Beanie for this tenant database # Set tenant database context for BaseDoc
await init_beanie(database=db, document_models=tenant_document_models) MessageTemplateDoc.set_tenant_database(db)
await self.module_logger.log_info(f"Beanie initialization completed for new tenant database {product_id}") EmailSenderDoc.set_tenant_database(db)
EmailSendStatusDoc.set_tenant_database(db)
await self.module_logger.log_info(f"BaseDoc tenant context set for new tenant database {product_id}")
# Cache only the client # Cache only the client
await self._lru_put(product_id, client) await self._lru_put(product_id, client)
@ -139,10 +144,15 @@ class TenantDBCache:
return db return db
async def get_main_db_initialized(self) -> AsyncIOMotorDatabase: async def get_main_db_initialized(self) -> AsyncIOMotorDatabase:
"""Get main database with Beanie initialized for tenant models""" """Get main database with BaseDoc context set for all models"""
# Re-initialize Beanie for main database with business models # Set main database context for all BaseDoc models
await init_beanie(database=self.main_db, document_models=document_models) MessageTemplateDoc.set_tenant_database(self.main_db)
await self.module_logger.log_info("Beanie initialization completed for main database") EmailSenderDoc.set_tenant_database(self.main_db)
EmailSendStatusDoc.set_tenant_database(self.main_db)
EmailTrackingDoc.set_tenant_database(self.main_db)
EmailBounceDoc.set_tenant_database(self.main_db)
UsageLogDoc.set_tenant_database(self.main_db)
await self.module_logger.log_info("BaseDoc context set for main database")
return self.main_db return self.main_db
async def _lru_put(self, key: str, client: AsyncIOMotorClient): async def _lru_put(self, key: str, client: AsyncIOMotorClient):
@ -196,8 +206,13 @@ async def initiate_database(app):
MAIN_CLIENT = AsyncIOMotorClient(app_settings.MONGODB_URI) MAIN_CLIENT = AsyncIOMotorClient(app_settings.MONGODB_URI)
main_db = MAIN_CLIENT[app_settings.MONGODB_NAME] main_db = MAIN_CLIENT[app_settings.MONGODB_NAME]
# 2) Initialize Beanie for main DB with business document models # 2) Set BaseDoc context for main database
await init_beanie(database=main_db, document_models=document_models) MessageTemplateDoc.set_tenant_database(main_db)
EmailSenderDoc.set_tenant_database(main_db)
EmailSendStatusDoc.set_tenant_database(main_db)
EmailTrackingDoc.set_tenant_database(main_db)
EmailBounceDoc.set_tenant_database(main_db)
UsageLogDoc.set_tenant_database(main_db)
# 3) Create tenant cache that uses main_db lookups to resolve product_id -> tenant db # 3) Create tenant cache that uses main_db lookups to resolve product_id -> tenant db
max_cache_size = getattr(app_settings, 'TENANT_CACHE_MAX', 64) max_cache_size = getattr(app_settings, 'TENANT_CACHE_MAX', 64)