Merge pull request 'fix: change notification service model layer to not using Beanie' (#94) from fix/auth-model into dev
Reviewed-on: freeleaps/freeleaps-service-hub#94 Reviewed-by: jingyao1991 <jingyao1991@noreply.gitea.freeleaps.mathmast.com>
This commit is contained in:
commit
d90df6bf83
415
apps/notification/backend/models/base_doc.py
Normal file
415
apps/notification/backend/models/base_doc.py
Normal file
@ -0,0 +1,415 @@
|
|||||||
|
"""
|
||||||
|
BaseDoc - A custom document class that provides Beanie-like interface using direct MongoDB operations
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Optional, List, Dict, Any, Type, Union
|
||||||
|
from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorDatabase
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from pydantic._internal._model_construction import ModelMetaclass
|
||||||
|
from common.config.app_settings import app_settings
|
||||||
|
|
||||||
|
|
||||||
|
class QueryExpression:
|
||||||
|
"""Query expression for field comparisons"""
|
||||||
|
def __init__(self, field_name: str):
|
||||||
|
self.field_name = field_name
|
||||||
|
|
||||||
|
def __eq__(self, other: Any) -> Dict[str, Any]:
|
||||||
|
"""Handle field == value comparisons"""
|
||||||
|
return {self.field_name: other}
|
||||||
|
|
||||||
|
def __ne__(self, other: Any) -> Dict[str, Any]:
|
||||||
|
"""Handle field != value comparisons"""
|
||||||
|
return {self.field_name: {"$ne": other}}
|
||||||
|
|
||||||
|
def __gt__(self, other: Any) -> Dict[str, Any]:
|
||||||
|
"""Handle field > value comparisons"""
|
||||||
|
return {self.field_name: {"$gt": other}}
|
||||||
|
|
||||||
|
def __lt__(self, other: Any) -> Dict[str, Any]:
|
||||||
|
"""Handle field < value comparisons"""
|
||||||
|
return {self.field_name: {"$lt": other}}
|
||||||
|
|
||||||
|
def __ge__(self, other: Any) -> Dict[str, Any]:
|
||||||
|
"""Handle field >= value comparisons"""
|
||||||
|
return {self.field_name: {"$gte": other}}
|
||||||
|
|
||||||
|
def __le__(self, other: Any) -> Dict[str, Any]:
|
||||||
|
"""Handle field <= value comparisons"""
|
||||||
|
return {self.field_name: {"$lte": other}}
|
||||||
|
|
||||||
|
|
||||||
|
class FieldDescriptor:
|
||||||
|
"""Descriptor for field access like Beanie's field == value pattern"""
|
||||||
|
def __init__(self, field_name: str, field_type: type):
|
||||||
|
self.field_name = field_name
|
||||||
|
self.field_type = field_type
|
||||||
|
|
||||||
|
def __get__(self, instance: Any, owner: type) -> Any:
|
||||||
|
"""
|
||||||
|
- Class access (instance is None): return QueryExpression for building queries
|
||||||
|
- Instance access (instance is not None): return the actual field value
|
||||||
|
"""
|
||||||
|
if instance is None:
|
||||||
|
return QueryExpression(self.field_name)
|
||||||
|
return instance.__dict__.get(self.field_name)
|
||||||
|
|
||||||
|
def __set__(self, instance: Any, value: Any) -> None:
|
||||||
|
"""Set instance field value with type validation (compatible with Pydantic validation)"""
|
||||||
|
if not isinstance(value, self.field_type):
|
||||||
|
raise TypeError(f"Field {self.field_name} must be {self.field_type}")
|
||||||
|
instance.__dict__[self.field_name] = value
|
||||||
|
|
||||||
|
|
||||||
|
class FieldCondition:
|
||||||
|
"""Represents a field condition for MongoDB queries"""
|
||||||
|
def __init__(self, field_name: str, value: Any, operator: str = "$eq"):
|
||||||
|
self.field_name = field_name
|
||||||
|
self.value = value
|
||||||
|
self.operator = operator
|
||||||
|
self.left = self # For compatibility with existing condition parsing
|
||||||
|
self.right = value
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level variables for database connection
|
||||||
|
_db: Optional[AsyncIOMotorDatabase] = None
|
||||||
|
_client: Optional[AsyncIOMotorClient] = None
|
||||||
|
|
||||||
|
# Context variable for tenant database
|
||||||
|
import contextvars
|
||||||
|
_tenant_db_context: contextvars.ContextVar[Optional[AsyncIOMotorDatabase]] = contextvars.ContextVar('tenant_db', default=None)
|
||||||
|
|
||||||
|
|
||||||
|
class QueryModelMeta(ModelMetaclass):
|
||||||
|
"""Metaclass: automatically create FieldDescriptor for model fields"""
|
||||||
|
def __new__(cls, name: str, bases: tuple, namespace: dict):
|
||||||
|
# Get model field annotations (like name: str -> "name" and str)
|
||||||
|
annotations = namespace.get("__annotations__", {})
|
||||||
|
|
||||||
|
# Create the class first using Pydantic's metaclass
|
||||||
|
new_class = super().__new__(cls, name, bases, namespace)
|
||||||
|
|
||||||
|
# After Pydantic processes the fields, add the descriptors as class attributes
|
||||||
|
for field_name, field_type in annotations.items():
|
||||||
|
if field_name != 'id': # Skip the id field as it's handled specially
|
||||||
|
# Add the descriptor as a class attribute
|
||||||
|
setattr(new_class, field_name, FieldDescriptor(field_name, field_type))
|
||||||
|
|
||||||
|
return new_class
|
||||||
|
|
||||||
|
def __getattr__(cls, name: str):
|
||||||
|
"""Handle field access like Doc.field_name for query building"""
|
||||||
|
# Check if this is a field that exists in the model
|
||||||
|
if hasattr(cls, 'model_fields') and name in cls.model_fields:
|
||||||
|
return QueryExpression(name)
|
||||||
|
raise AttributeError(f"'{cls.__name__}' object has no attribute '{name}'")
|
||||||
|
|
||||||
|
class BaseDoc(BaseModel, metaclass=QueryModelMeta):
|
||||||
|
"""
|
||||||
|
Base document class that provides Beanie-like interface using direct MongoDB operations.
|
||||||
|
All model classes should inherit from this instead of Beanie's Document.
|
||||||
|
"""
|
||||||
|
|
||||||
|
id: Optional[str] = None # MongoDB _id field
|
||||||
|
|
||||||
|
def model_dump(self, **kwargs):
|
||||||
|
"""Override model_dump to exclude field descriptors"""
|
||||||
|
# Get the default model_dump result
|
||||||
|
result = super().model_dump(**kwargs)
|
||||||
|
|
||||||
|
# Remove any field descriptors that might have been included
|
||||||
|
filtered_result = {}
|
||||||
|
for key, value in result.items():
|
||||||
|
if not isinstance(value, FieldDescriptor):
|
||||||
|
filtered_result[key] = value
|
||||||
|
|
||||||
|
return filtered_result
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def field(cls, field_name: str) -> QueryExpression:
|
||||||
|
"""Get a field expression for query building"""
|
||||||
|
return QueryExpression(field_name)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def _get_database(cls) -> AsyncIOMotorDatabase:
|
||||||
|
"""Get database connection using pure AsyncIOMotorClient"""
|
||||||
|
# Try to get tenant database from context first
|
||||||
|
tenant_db = _tenant_db_context.get()
|
||||||
|
if tenant_db is not None:
|
||||||
|
return tenant_db
|
||||||
|
|
||||||
|
# Fallback to global database connection
|
||||||
|
global _db, _client
|
||||||
|
if _db is None:
|
||||||
|
_client = AsyncIOMotorClient(app_settings.MONGODB_URI)
|
||||||
|
_db = _client[app_settings.MONGODB_NAME]
|
||||||
|
return _db
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def set_tenant_database(cls, db: AsyncIOMotorDatabase):
|
||||||
|
"""Set the tenant database for this context"""
|
||||||
|
_tenant_db_context.set(db)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_collection_name(cls) -> str:
|
||||||
|
"""Get collection name from Settings or class name"""
|
||||||
|
if hasattr(cls, 'Settings') and hasattr(cls.Settings, 'name'):
|
||||||
|
return cls.Settings.name
|
||||||
|
else:
|
||||||
|
# Convert class name to snake_case for collection name
|
||||||
|
import re
|
||||||
|
name = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', cls.__name__)
|
||||||
|
return re.sub('([a-z0-9])([A-Z])', r'\1_\2', name).lower()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def find(cls, *conditions) -> 'QueryBuilder':
|
||||||
|
"""Find documents matching conditions - returns QueryBuilder for chaining"""
|
||||||
|
return QueryBuilder(cls, conditions)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def find_one(cls, *conditions) -> Optional['BaseDoc']:
|
||||||
|
"""Find one document matching conditions"""
|
||||||
|
db = await cls._get_database()
|
||||||
|
collection_name = cls._get_collection_name()
|
||||||
|
collection = db[collection_name]
|
||||||
|
|
||||||
|
# Convert Beanie-style conditions to MongoDB query
|
||||||
|
query = cls._convert_conditions_to_query(conditions)
|
||||||
|
|
||||||
|
doc = await collection.find_one(query)
|
||||||
|
if doc:
|
||||||
|
# Extract MongoDB _id and convert to string
|
||||||
|
mongo_id = doc.pop('_id', None)
|
||||||
|
# Filter doc to only include fields defined in the model
|
||||||
|
model_fields = set(cls.model_fields.keys())
|
||||||
|
filtered_doc = {k: v for k, v in doc.items() if k in model_fields}
|
||||||
|
# Add the id field
|
||||||
|
if mongo_id:
|
||||||
|
filtered_doc['id'] = str(mongo_id)
|
||||||
|
return cls(**filtered_doc)
|
||||||
|
return None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get(cls, doc_id: str) -> Optional['BaseDoc']:
|
||||||
|
"""Get document by ID"""
|
||||||
|
from bson import ObjectId
|
||||||
|
try:
|
||||||
|
object_id = ObjectId(doc_id)
|
||||||
|
except:
|
||||||
|
return None
|
||||||
|
|
||||||
|
db = await cls._get_database()
|
||||||
|
collection_name = cls._get_collection_name()
|
||||||
|
collection = db[collection_name]
|
||||||
|
|
||||||
|
doc = await collection.find_one({"_id": object_id})
|
||||||
|
if doc:
|
||||||
|
# Extract MongoDB _id and convert to string
|
||||||
|
mongo_id = doc.pop('_id', None)
|
||||||
|
# Filter doc to only include fields defined in the model
|
||||||
|
model_fields = set(cls.model_fields.keys())
|
||||||
|
filtered_doc = {k: v for k, v in doc.items() if k in model_fields}
|
||||||
|
# Add the id field
|
||||||
|
if mongo_id:
|
||||||
|
filtered_doc['id'] = str(mongo_id)
|
||||||
|
return cls(**filtered_doc)
|
||||||
|
return None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _convert_conditions_to_query(cls, conditions) -> Dict[str, Any]:
|
||||||
|
"""Convert Beanie-style conditions to MongoDB query"""
|
||||||
|
if not conditions:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
query = {}
|
||||||
|
for condition in conditions:
|
||||||
|
if isinstance(condition, dict):
|
||||||
|
# Handle QueryExpression results (dictionaries) and direct dictionary queries
|
||||||
|
query.update(condition)
|
||||||
|
elif isinstance(condition, FieldCondition):
|
||||||
|
# Handle legacy FieldCondition objects
|
||||||
|
if condition.operator == "$eq":
|
||||||
|
query[condition.field_name] = condition.value
|
||||||
|
else:
|
||||||
|
query[condition.field_name] = {condition.operator: condition.value}
|
||||||
|
elif hasattr(condition, 'left') and hasattr(condition, 'right'):
|
||||||
|
# Handle field == value conditions
|
||||||
|
field_name = condition.left.name
|
||||||
|
value = condition.right
|
||||||
|
query[field_name] = value
|
||||||
|
elif hasattr(condition, '__dict__'):
|
||||||
|
# Handle complex conditions like FLID.identity == value
|
||||||
|
if hasattr(condition, 'left') and hasattr(condition, 'right'):
|
||||||
|
left = condition.left
|
||||||
|
if hasattr(left, 'name') and hasattr(left, 'left'):
|
||||||
|
# Nested field access like FLID.identity
|
||||||
|
field_name = f"{left.left.name}.{left.name}"
|
||||||
|
value = condition.right
|
||||||
|
query[field_name] = value
|
||||||
|
else:
|
||||||
|
field_name = left.name
|
||||||
|
value = condition.right
|
||||||
|
query[field_name] = value
|
||||||
|
|
||||||
|
return query
|
||||||
|
|
||||||
|
def _convert_decimals_to_float(self, obj):
|
||||||
|
"""Convert Decimal objects to float for MongoDB compatibility"""
|
||||||
|
from decimal import Decimal
|
||||||
|
|
||||||
|
if isinstance(obj, Decimal):
|
||||||
|
return float(obj)
|
||||||
|
elif isinstance(obj, dict):
|
||||||
|
return {key: self._convert_decimals_to_float(value) for key, value in obj.items()}
|
||||||
|
elif isinstance(obj, list):
|
||||||
|
return [self._convert_decimals_to_float(item) for item in obj]
|
||||||
|
else:
|
||||||
|
return obj
|
||||||
|
|
||||||
|
async def create(self) -> 'BaseDoc':
|
||||||
|
"""Create this document in the database"""
|
||||||
|
db = await self._get_database()
|
||||||
|
collection_name = self._get_collection_name()
|
||||||
|
collection = db[collection_name]
|
||||||
|
|
||||||
|
# Convert to dict and insert, excluding field descriptors
|
||||||
|
doc_dict = self.model_dump(exclude={'id'})
|
||||||
|
|
||||||
|
# Convert Decimal objects to float for MongoDB compatibility
|
||||||
|
doc_dict = self._convert_decimals_to_float(doc_dict)
|
||||||
|
|
||||||
|
result = await collection.insert_one(doc_dict)
|
||||||
|
|
||||||
|
# Set the id field from the inserted document
|
||||||
|
if result.inserted_id:
|
||||||
|
self.id = str(result.inserted_id)
|
||||||
|
|
||||||
|
# Return the created document
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def save(self) -> 'BaseDoc':
|
||||||
|
"""Save this document to the database (update if exists, create if not)"""
|
||||||
|
db = await self._get_database()
|
||||||
|
collection_name = self._get_collection_name()
|
||||||
|
collection = db[collection_name]
|
||||||
|
|
||||||
|
# Convert to dict, excluding field descriptors
|
||||||
|
doc_dict = self.model_dump(exclude={'id'})
|
||||||
|
|
||||||
|
# Convert Decimal objects to float for MongoDB compatibility
|
||||||
|
doc_dict = self._convert_decimals_to_float(doc_dict)
|
||||||
|
|
||||||
|
# Try to find existing document by user_id or other unique fields
|
||||||
|
query = {}
|
||||||
|
if hasattr(self, 'user_id'):
|
||||||
|
query['user_id'] = self.user_id
|
||||||
|
elif hasattr(self, 'email'):
|
||||||
|
query['email'] = self.email
|
||||||
|
elif hasattr(self, 'mobile'):
|
||||||
|
query['mobile'] = self.mobile
|
||||||
|
elif hasattr(self, 'auth_code'):
|
||||||
|
query['auth_code'] = self.auth_code
|
||||||
|
|
||||||
|
if query:
|
||||||
|
# Update existing document
|
||||||
|
result = await collection.update_one(query, {"$set": doc_dict}, upsert=True)
|
||||||
|
# If it was an insert, set the id field
|
||||||
|
if result.upserted_id:
|
||||||
|
self.id = str(result.upserted_id)
|
||||||
|
else:
|
||||||
|
# Insert new document
|
||||||
|
result = await collection.insert_one(doc_dict)
|
||||||
|
if result.inserted_id:
|
||||||
|
self.id = str(result.inserted_id)
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def delete(self) -> bool:
|
||||||
|
"""Delete this document from the database"""
|
||||||
|
db = await self._get_database()
|
||||||
|
collection_name = self._get_collection_name()
|
||||||
|
collection = db[collection_name]
|
||||||
|
|
||||||
|
# Try to find existing document by user_id or other unique fields
|
||||||
|
query = {}
|
||||||
|
if hasattr(self, 'user_id'):
|
||||||
|
query['user_id'] = self.user_id
|
||||||
|
elif hasattr(self, 'email'):
|
||||||
|
query['email'] = self.email
|
||||||
|
elif hasattr(self, 'mobile'):
|
||||||
|
query['mobile'] = self.mobile
|
||||||
|
elif hasattr(self, 'auth_code'):
|
||||||
|
query['auth_code'] = self.auth_code
|
||||||
|
|
||||||
|
if query:
|
||||||
|
result = await collection.delete_one(query)
|
||||||
|
return result.deleted_count > 0
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class QueryBuilder:
|
||||||
|
"""Query builder for chaining operations like Beanie's QueryBuilder"""
|
||||||
|
|
||||||
|
def __init__(self, model_class: Type[BaseDoc], conditions: tuple):
|
||||||
|
self.model_class = model_class
|
||||||
|
self.conditions = conditions
|
||||||
|
self._limit_value: Optional[int] = None
|
||||||
|
self._skip_value: Optional[int] = None
|
||||||
|
|
||||||
|
def limit(self, n: int) -> 'QueryBuilder':
|
||||||
|
"""Limit number of results"""
|
||||||
|
self._limit_value = n
|
||||||
|
return self
|
||||||
|
|
||||||
|
def skip(self, n: int) -> 'QueryBuilder':
|
||||||
|
"""Skip number of results"""
|
||||||
|
self._skip_value = n
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def to_list(self) -> List[BaseDoc]:
|
||||||
|
"""Convert query to list of documents"""
|
||||||
|
db = await self.model_class._get_database()
|
||||||
|
collection_name = self.model_class._get_collection_name()
|
||||||
|
collection = db[collection_name]
|
||||||
|
|
||||||
|
# Convert conditions to MongoDB query
|
||||||
|
query = self.model_class._convert_conditions_to_query(self.conditions)
|
||||||
|
|
||||||
|
# Build cursor
|
||||||
|
cursor = collection.find(query)
|
||||||
|
|
||||||
|
if self._skip_value:
|
||||||
|
cursor = cursor.skip(self._skip_value)
|
||||||
|
if self._limit_value:
|
||||||
|
cursor = cursor.limit(self._limit_value)
|
||||||
|
|
||||||
|
# Execute query and convert to model instances
|
||||||
|
docs = await cursor.to_list(length=None)
|
||||||
|
results = []
|
||||||
|
for doc in docs:
|
||||||
|
# Extract MongoDB _id and convert to string
|
||||||
|
mongo_id = doc.pop('_id', None)
|
||||||
|
# Filter doc to only include fields defined in the model
|
||||||
|
model_fields = set(self.model_class.model_fields.keys())
|
||||||
|
filtered_doc = {k: v for k, v in doc.items() if k in model_fields}
|
||||||
|
# Add the id field
|
||||||
|
if mongo_id:
|
||||||
|
filtered_doc['id'] = str(mongo_id)
|
||||||
|
results.append(self.model_class(**filtered_doc))
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
async def first_or_none(self) -> Optional[BaseDoc]:
|
||||||
|
"""Get first result or None"""
|
||||||
|
results = await self.limit(1).to_list()
|
||||||
|
return results[0] if results else None
|
||||||
|
|
||||||
|
async def count(self) -> int:
|
||||||
|
"""Count number of matching documents"""
|
||||||
|
db = await self.model_class._get_database()
|
||||||
|
collection_name = self.model_class._get_collection_name()
|
||||||
|
collection = db[collection_name]
|
||||||
|
|
||||||
|
query = self.model_class._convert_conditions_to_query(self.conditions)
|
||||||
|
return await collection.count_documents(query)
|
||||||
@ -1,11 +1,11 @@
|
|||||||
from beanie import Document
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional, List
|
from typing import Optional, List
|
||||||
|
|
||||||
from common.constants.region import UserRegion
|
from common.constants.region import UserRegion
|
||||||
from common.constants.email import EmailSendStatus, BounceType
|
from common.constants.email import EmailSendStatus, BounceType
|
||||||
|
from backend.models.base_doc import BaseDoc
|
||||||
|
|
||||||
class MessageTemplateDoc(Document):
|
class MessageTemplateDoc(BaseDoc):
|
||||||
template_id: str
|
template_id: str
|
||||||
tenant_id: Optional[str] = None
|
tenant_id: Optional[str] = None
|
||||||
region: UserRegion
|
region: UserRegion
|
||||||
@ -19,20 +19,20 @@ class MessageTemplateDoc(Document):
|
|||||||
name = "message_templates_doc"
|
name = "message_templates_doc"
|
||||||
indexes = [
|
indexes = [
|
||||||
"template_id",
|
"template_id",
|
||||||
"tenant_id",
|
"tenant_id",
|
||||||
"region"
|
"region"
|
||||||
]
|
]
|
||||||
|
|
||||||
class EmailSenderDoc(Document):
|
class EmailSenderDoc(BaseDoc):
|
||||||
tenant_id: str
|
tenant_id: str
|
||||||
email_sender: Optional[str] = None
|
email_sender: Optional[str] = None
|
||||||
is_active: bool = True
|
is_active: bool = True
|
||||||
|
|
||||||
class Settings:
|
class Settings:
|
||||||
name = "email_sender_doc"
|
name = "email_sender_doc"
|
||||||
indexes = ["tenant_id"]
|
indexes = ["tenant_id"]
|
||||||
|
|
||||||
class EmailSendStatusDoc(Document):
|
class EmailSendStatusDoc(BaseDoc):
|
||||||
email_id: str
|
email_id: str
|
||||||
tenant_id: str
|
tenant_id: str
|
||||||
email_sender: Optional[str] = None
|
email_sender: Optional[str] = None
|
||||||
@ -57,7 +57,7 @@ class EmailSendStatusDoc(Document):
|
|||||||
"tenant_id"
|
"tenant_id"
|
||||||
]
|
]
|
||||||
|
|
||||||
class EmailTrackingDoc(Document):
|
class EmailTrackingDoc(BaseDoc):
|
||||||
email_id: str
|
email_id: str
|
||||||
tenant_id: str
|
tenant_id: str
|
||||||
recipient_email: str
|
recipient_email: str
|
||||||
@ -85,10 +85,10 @@ class EmailTrackingDoc(Document):
|
|||||||
"tenant_id"
|
"tenant_id"
|
||||||
]
|
]
|
||||||
|
|
||||||
class EmailBounceDoc(Document):
|
class EmailBounceDoc(BaseDoc):
|
||||||
email: str
|
email: str
|
||||||
tenant_id: str
|
tenant_id: str
|
||||||
email_id: Optional[str] = None
|
email_id: Optional[str] = None
|
||||||
template_id: Optional[str] = None
|
template_id: Optional[str] = None
|
||||||
bounce_type: BounceType
|
bounce_type: BounceType
|
||||||
reason: str
|
reason: str
|
||||||
@ -105,8 +105,8 @@ class EmailBounceDoc(Document):
|
|||||||
"email",
|
"email",
|
||||||
"tenant_id"
|
"tenant_id"
|
||||||
]
|
]
|
||||||
|
|
||||||
class UsageLogDoc(Document):
|
class UsageLogDoc(BaseDoc):
|
||||||
timestamp: datetime = datetime.utcnow() # timestamp
|
timestamp: datetime = datetime.utcnow() # timestamp
|
||||||
tenant_id: str # tenant id
|
tenant_id: str # tenant id
|
||||||
operation: str # operation type
|
operation: str # operation type
|
||||||
@ -117,7 +117,7 @@ class UsageLogDoc(Document):
|
|||||||
bytes_out: int # output bytes
|
bytes_out: int # output bytes
|
||||||
key_id: Optional[str] = None # API Key ID
|
key_id: Optional[str] = None # API Key ID
|
||||||
extra: dict = {} # extra information
|
extra: dict = {} # extra information
|
||||||
|
|
||||||
class Settings:
|
class Settings:
|
||||||
name = "usage_log_doc"
|
name = "usage_log_doc"
|
||||||
indexes = [
|
indexes = [
|
||||||
|
|||||||
@ -39,41 +39,42 @@ class DatabaseMiddleware:
|
|||||||
return await response(scope, receive, send)
|
return await response(scope, receive, send)
|
||||||
|
|
||||||
if not product_id:
|
if not product_id:
|
||||||
# Compatibility / public routes: use main database with tenant models initialized
|
# Compatibility / public routes: use main database with BaseDoc context set
|
||||||
await self.module_logger.log_info(f"No product_id - using main database for path: {request.url.path}")
|
await self.module_logger.log_info(f"No product_id - using main database for path: {request.url.path}")
|
||||||
|
|
||||||
# Get main database with Beanie initialized for tenant models
|
# Get main database with BaseDoc context set for all models
|
||||||
main_db_initialized = await tenant_cache.get_main_db_initialized()
|
main_db_initialized = await tenant_cache.get_main_db_initialized()
|
||||||
|
|
||||||
request.state.db = main_db_initialized
|
request.state.db = main_db_initialized
|
||||||
request.state.product_id = None
|
request.state.product_id = None
|
||||||
await self.module_logger.log_info(f"Successfully initialized main database with tenant models")
|
await self.module_logger.log_info(f"Successfully set BaseDoc context for main database")
|
||||||
return await self.app(scope, receive, send)
|
return await self.app(scope, receive, send)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Get tenant-specific database with Beanie already initialized (cached)
|
# Get tenant-specific database with BaseDoc context set (cached)
|
||||||
await self.module_logger.log_info(f"Attempting to get tenant database for product_id: {product_id}")
|
await self.module_logger.log_info(f"Attempting to get tenant database for product_id: {product_id}")
|
||||||
tenant_db = await tenant_cache.get_initialized_db(product_id)
|
tenant_db = await tenant_cache.get_initialized_db(product_id)
|
||||||
|
|
||||||
request.state.db = tenant_db
|
request.state.db = tenant_db
|
||||||
request.state.product_id = product_id
|
request.state.product_id = product_id
|
||||||
await self.module_logger.log_info(f"Successfully retrieved cached tenant database with Beanie for product_id: {product_id}")
|
await self.module_logger.log_info(f"Successfully retrieved cached tenant database with BaseDoc context for product_id: {product_id}")
|
||||||
|
return await self.app(scope, receive, send)
|
||||||
|
|
||||||
except HTTPException as e:
|
except ValueError as e:
|
||||||
# Handle tenant not found or inactive (HTTPException from TenantDBCache)
|
# Handle tenant not found or inactive (ValueError from TenantDBCache)
|
||||||
await self.module_logger.log_error(f"Tenant error for {product_id}: [{e.status_code}] {e.detail}")
|
await self.module_logger.log_error(f"Tenant error for {product_id}: {str(e)}")
|
||||||
response = JSONResponse(
|
response = JSONResponse(
|
||||||
status_code=e.status_code,
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
content={"detail": e.detail}
|
content={"detail": str(e)}
|
||||||
)
|
)
|
||||||
return await response(scope, receive, send)
|
return await response(scope, receive, send)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await self.module_logger.log_error(f"Database error for tenant {product_id}: {str(e)}")
|
await self.module_logger.log_error(f"Database error for tenant {product_id}: {str(e)}")
|
||||||
response = JSONResponse(
|
response = JSONResponse(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
content={"detail": "Database connection error"}
|
content={"detail": "Database connection error"}
|
||||||
)
|
)
|
||||||
return await response(scope, receive, send)
|
return await response(scope, receive, send)
|
||||||
|
|
||||||
return await self.app(scope, receive, send)
|
return await self.app(scope, receive, send)
|
||||||
@ -1,5 +1,4 @@
|
|||||||
from webapi.config.site_settings import site_settings
|
from webapi.config.site_settings import site_settings
|
||||||
from beanie import init_beanie
|
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorDatabase
|
from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorDatabase
|
||||||
from backend.models.models import MessageTemplateDoc, EmailSenderDoc, EmailSendStatusDoc, EmailTrackingDoc, EmailBounceDoc, UsageLogDoc
|
from backend.models.models import MessageTemplateDoc, EmailSenderDoc, EmailSendStatusDoc, EmailTrackingDoc, EmailBounceDoc, UsageLogDoc
|
||||||
@ -15,7 +14,7 @@ import os
|
|||||||
MAIN_CLIENT: Optional[AsyncIOMotorClient] = None
|
MAIN_CLIENT: Optional[AsyncIOMotorClient] = None
|
||||||
TENANT_CACHE: Optional['TenantDBCache'] = None
|
TENANT_CACHE: Optional['TenantDBCache'] = None
|
||||||
|
|
||||||
# Define document models
|
# Define document models
|
||||||
document_models = [
|
document_models = [
|
||||||
MessageTemplateDoc,
|
MessageTemplateDoc,
|
||||||
EmailSenderDoc,
|
EmailSenderDoc,
|
||||||
@ -37,7 +36,7 @@ class TenantDBCache:
|
|||||||
Uses main_db.tenant_doc to resolve mongodb_uri; caches clients with LRU.
|
Uses main_db.tenant_doc to resolve mongodb_uri; caches clients with LRU.
|
||||||
Database instances are created fresh each time from cached clients.
|
Database instances are created fresh each time from cached clients.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, main_db: AsyncIOMotorDatabase, max_size: int = 64):
|
def __init__(self, main_db: AsyncIOMotorDatabase, max_size: int = 64):
|
||||||
self.main_db = main_db
|
self.main_db = main_db
|
||||||
self.max_size = max_size
|
self.max_size = max_size
|
||||||
@ -47,20 +46,22 @@ class TenantDBCache:
|
|||||||
self.module_logger = ModuleLogger(sender_id="TenantDBCache")
|
self.module_logger = ModuleLogger(sender_id="TenantDBCache")
|
||||||
|
|
||||||
async def get_initialized_db(self, product_id: str) -> AsyncIOMotorDatabase:
|
async def get_initialized_db(self, product_id: str) -> AsyncIOMotorDatabase:
|
||||||
"""Get tenant database with Beanie already initialized"""
|
"""Get tenant database with BaseDoc context set"""
|
||||||
|
|
||||||
# fast-path: check if client is cached
|
# fast-path: check if client is cached
|
||||||
cached_client = self._cache.get(product_id)
|
cached_client = self._cache.get(product_id)
|
||||||
if cached_client:
|
if cached_client:
|
||||||
await self.module_logger.log_info(f"Found cached client for {product_id}")
|
await self.module_logger.log_info(f"Found cached client for {product_id}")
|
||||||
self._cache.move_to_end(product_id)
|
self._cache.move_to_end(product_id)
|
||||||
|
|
||||||
# Get fresh database instance from cached client
|
# Get fresh database instance from cached client
|
||||||
db = cached_client.get_default_database()
|
db = cached_client.get_default_database()
|
||||||
if db is not None:
|
if db is not None:
|
||||||
# Initialize Beanie for this fresh database instance
|
# Set tenant database context for BaseDoc
|
||||||
await init_beanie(database=db, document_models=tenant_document_models)
|
MessageTemplateDoc.set_tenant_database(db)
|
||||||
await self.module_logger.log_info(f"Beanie initialization completed for {product_id} using cached client")
|
EmailSenderDoc.set_tenant_database(db)
|
||||||
|
EmailSendStatusDoc.set_tenant_database(db)
|
||||||
|
await self.module_logger.log_info(f"BaseDoc tenant context set for {product_id} using cached client")
|
||||||
return db
|
return db
|
||||||
else:
|
else:
|
||||||
await self.module_logger.log_error(f"No default database found for cached client {product_id}")
|
await self.module_logger.log_error(f"No default database found for cached client {product_id}")
|
||||||
@ -74,13 +75,15 @@ class TenantDBCache:
|
|||||||
if cached_client:
|
if cached_client:
|
||||||
await self.module_logger.log_info(f"Double-check found cached client for {product_id}")
|
await self.module_logger.log_info(f"Double-check found cached client for {product_id}")
|
||||||
self._cache.move_to_end(product_id)
|
self._cache.move_to_end(product_id)
|
||||||
|
|
||||||
# Get fresh database instance from cached client
|
# Get fresh database instance from cached client
|
||||||
db = cached_client.get_default_database()
|
db = cached_client.get_default_database()
|
||||||
if db is not None:
|
if db is not None:
|
||||||
# Initialize Beanie for this fresh database instance
|
# Set tenant database context for BaseDoc
|
||||||
await init_beanie(database=db, document_models=tenant_document_models)
|
MessageTemplateDoc.set_tenant_database(db)
|
||||||
await self.module_logger.log_info(f"Beanie initialization completed for {product_id} using cached client (double-check)")
|
EmailSenderDoc.set_tenant_database(db)
|
||||||
|
EmailSendStatusDoc.set_tenant_database(db)
|
||||||
|
await self.module_logger.log_info(f"BaseDoc tenant context set for {product_id} using cached client (double-check)")
|
||||||
return db
|
return db
|
||||||
else:
|
else:
|
||||||
await self.module_logger.log_error(f"No default database found for cached client {product_id}")
|
await self.module_logger.log_error(f"No default database found for cached client {product_id}")
|
||||||
@ -128,10 +131,12 @@ class TenantDBCache:
|
|||||||
detail=f"No default database found for tenant {product_id}",
|
detail=f"No default database found for tenant {product_id}",
|
||||||
headers={"X-Error-Message": f"No default database found for tenant {product_id}"}
|
headers={"X-Error-Message": f"No default database found for tenant {product_id}"}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize Beanie for this tenant database
|
# Set tenant database context for BaseDoc
|
||||||
await init_beanie(database=db, document_models=tenant_document_models)
|
MessageTemplateDoc.set_tenant_database(db)
|
||||||
await self.module_logger.log_info(f"Beanie initialization completed for new tenant database {product_id}")
|
EmailSenderDoc.set_tenant_database(db)
|
||||||
|
EmailSendStatusDoc.set_tenant_database(db)
|
||||||
|
await self.module_logger.log_info(f"BaseDoc tenant context set for new tenant database {product_id}")
|
||||||
|
|
||||||
# Cache only the client
|
# Cache only the client
|
||||||
await self._lru_put(product_id, client)
|
await self._lru_put(product_id, client)
|
||||||
@ -139,10 +144,15 @@ class TenantDBCache:
|
|||||||
return db
|
return db
|
||||||
|
|
||||||
async def get_main_db_initialized(self) -> AsyncIOMotorDatabase:
|
async def get_main_db_initialized(self) -> AsyncIOMotorDatabase:
|
||||||
"""Get main database with Beanie initialized for tenant models"""
|
"""Get main database with BaseDoc context set for all models"""
|
||||||
# Re-initialize Beanie for main database with business models
|
# Set main database context for all BaseDoc models
|
||||||
await init_beanie(database=self.main_db, document_models=document_models)
|
MessageTemplateDoc.set_tenant_database(self.main_db)
|
||||||
await self.module_logger.log_info("Beanie initialization completed for main database")
|
EmailSenderDoc.set_tenant_database(self.main_db)
|
||||||
|
EmailSendStatusDoc.set_tenant_database(self.main_db)
|
||||||
|
EmailTrackingDoc.set_tenant_database(self.main_db)
|
||||||
|
EmailBounceDoc.set_tenant_database(self.main_db)
|
||||||
|
UsageLogDoc.set_tenant_database(self.main_db)
|
||||||
|
await self.module_logger.log_info("BaseDoc context set for main database")
|
||||||
return self.main_db
|
return self.main_db
|
||||||
|
|
||||||
async def _lru_put(self, key: str, client: AsyncIOMotorClient):
|
async def _lru_put(self, key: str, client: AsyncIOMotorClient):
|
||||||
@ -180,7 +190,7 @@ def register(app):
|
|||||||
@app.on_event("startup")
|
@app.on_event("startup")
|
||||||
async def start_database():
|
async def start_database():
|
||||||
await initiate_database(app)
|
await initiate_database(app)
|
||||||
|
|
||||||
@app.on_event("shutdown")
|
@app.on_event("shutdown")
|
||||||
async def shutdown_database():
|
async def shutdown_database():
|
||||||
await cleanup_database()
|
await cleanup_database()
|
||||||
@ -189,15 +199,20 @@ def register(app):
|
|||||||
async def initiate_database(app):
|
async def initiate_database(app):
|
||||||
"""Initialize main database and tenant cache"""
|
"""Initialize main database and tenant cache"""
|
||||||
global MAIN_CLIENT, TENANT_CACHE
|
global MAIN_CLIENT, TENANT_CACHE
|
||||||
|
|
||||||
module_logger = ModuleLogger(sender_id="DatabaseInit")
|
module_logger = ModuleLogger(sender_id="DatabaseInit")
|
||||||
|
|
||||||
# 1) Create main/catalog client + DB
|
# 1) Create main/catalog client + DB
|
||||||
MAIN_CLIENT = AsyncIOMotorClient(app_settings.MONGODB_URI)
|
MAIN_CLIENT = AsyncIOMotorClient(app_settings.MONGODB_URI)
|
||||||
main_db = MAIN_CLIENT[app_settings.MONGODB_NAME]
|
main_db = MAIN_CLIENT[app_settings.MONGODB_NAME]
|
||||||
|
|
||||||
# 2) Initialize Beanie for main DB with business document models
|
# 2) Set BaseDoc context for main database
|
||||||
await init_beanie(database=main_db, document_models=document_models)
|
MessageTemplateDoc.set_tenant_database(main_db)
|
||||||
|
EmailSenderDoc.set_tenant_database(main_db)
|
||||||
|
EmailSendStatusDoc.set_tenant_database(main_db)
|
||||||
|
EmailTrackingDoc.set_tenant_database(main_db)
|
||||||
|
EmailBounceDoc.set_tenant_database(main_db)
|
||||||
|
UsageLogDoc.set_tenant_database(main_db)
|
||||||
|
|
||||||
# 3) Create tenant cache that uses main_db lookups to resolve product_id -> tenant db
|
# 3) Create tenant cache that uses main_db lookups to resolve product_id -> tenant db
|
||||||
max_cache_size = getattr(app_settings, 'TENANT_CACHE_MAX', 64)
|
max_cache_size = getattr(app_settings, 'TENANT_CACHE_MAX', 64)
|
||||||
@ -206,20 +221,20 @@ async def initiate_database(app):
|
|||||||
# 4) Store on app state for middleware to access
|
# 4) Store on app state for middleware to access
|
||||||
app.state.main_db = main_db
|
app.state.main_db = main_db
|
||||||
app.state.tenant_cache = TENANT_CACHE
|
app.state.tenant_cache = TENANT_CACHE
|
||||||
|
|
||||||
await module_logger.log_info("Database and tenant cache initialized successfully")
|
await module_logger.log_info("Database and tenant cache initialized successfully")
|
||||||
|
|
||||||
|
|
||||||
async def cleanup_database():
|
async def cleanup_database():
|
||||||
"""Cleanup database connections and cache"""
|
"""Cleanup database connections and cache"""
|
||||||
global MAIN_CLIENT, TENANT_CACHE
|
global MAIN_CLIENT, TENANT_CACHE
|
||||||
|
|
||||||
module_logger = ModuleLogger(sender_id="DatabaseCleanup")
|
module_logger = ModuleLogger(sender_id="DatabaseCleanup")
|
||||||
|
|
||||||
if TENANT_CACHE:
|
if TENANT_CACHE:
|
||||||
await TENANT_CACHE.aclose()
|
await TENANT_CACHE.aclose()
|
||||||
|
|
||||||
if MAIN_CLIENT:
|
if MAIN_CLIENT:
|
||||||
MAIN_CLIENT.close()
|
MAIN_CLIENT.close()
|
||||||
|
|
||||||
await module_logger.log_info("Database connections closed successfully")
|
await module_logger.log_info("Database connections closed successfully")
|
||||||
Loading…
Reference in New Issue
Block a user