""" BaseDoc - A custom document class that provides Beanie-like interface using direct MongoDB operations """ import asyncio from datetime import datetime, timezone from typing import Optional, List, Dict, Any, Type, Union from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorDatabase from pydantic import BaseModel from common.config.app_settings import app_settings class QueryExpression: """Query expression for field comparisons""" def __init__(self, field_name: str): self.field_name = field_name def __eq__(self, other) -> dict: """Handle field == value comparisons""" return {self.field_name: other} def __ne__(self, other) -> dict: """Handle field != value comparisons""" return {self.field_name: {"$ne": other}} def __gt__(self, other) -> dict: """Handle field > value comparisons""" return {self.field_name: {"$gt": other}} def __lt__(self, other) -> dict: """Handle field < value comparisons""" return {self.field_name: {"$lt": other}} def __ge__(self, other) -> dict: """Handle field >= value comparisons""" return {self.field_name: {"$gte": other}} def __le__(self, other) -> dict: """Handle field <= value comparisons""" return {self.field_name: {"$lte": other}} class FieldDescriptor: """Descriptor for field access like Beanie's field == value pattern""" def __init__(self, field_name: str, field_type: type): self.field_name = field_name self.field_type = field_type def __get__(self, instance: Any, owner: type) -> Any: """ - Class access (instance is None): return QueryExpression for building queries - Instance access (instance is not None): return the actual field value """ if instance is None: return QueryExpression(self.field_name) return instance.__dict__.get(self.field_name) def __set__(self, instance: Any, value: Any) -> None: """Set instance field value with type validation (compatible with Pydantic validation)""" if not isinstance(value, self.field_type): raise TypeError(f"Field {self.field_name} must be {self.field_type}") instance.__dict__[self.field_name] = value class FieldCondition: """Represents a field condition for MongoDB queries""" def __init__(self, field_name: str, value: Any, operator: str = "$eq"): self.field_name = field_name self.value = value self.operator = operator self.left = self # For compatibility with existing condition parsing self.right = value # Module-level variables for database connection _db: Optional[AsyncIOMotorDatabase] = None _client: Optional[AsyncIOMotorClient] = None # Context variable for tenant database import contextvars _tenant_db_context: contextvars.ContextVar[Optional[AsyncIOMotorDatabase]] = contextvars.ContextVar('tenant_db', default=None) class QueryModelMeta(type(BaseModel)): """Metaclass: automatically create FieldDescriptor for model fields""" def __new__(cls, name: str, bases: tuple, namespace: dict): # Get model field annotations (like name: str -> "name" and str) annotations = namespace.get("__annotations__", {}) # Create the class first new_class = super().__new__(cls, name, bases, namespace) # After Pydantic processes the fields, add the descriptors as class attributes for field_name, field_type in annotations.items(): if field_name != 'id': # Skip the id field as it's handled specially # Add the descriptor as a class attribute setattr(new_class, field_name, FieldDescriptor(field_name, field_type)) return new_class class BaseDoc(BaseModel, metaclass=QueryModelMeta): """ Base document class that provides Beanie-like interface using direct MongoDB operations. All model classes should inherit from this instead of Beanie's Document. """ id: Optional[str] = None # MongoDB _id field def model_dump(self, **kwargs): """Override model_dump to exclude field descriptors""" # Get the default model_dump result result = super().model_dump(**kwargs) # Remove any field descriptors that might have been included filtered_result = {} for key, value in result.items(): if not isinstance(value, FieldDescriptor): filtered_result[key] = value return filtered_result @classmethod async def _get_database(cls) -> AsyncIOMotorDatabase: """Get database connection using pure AsyncIOMotorClient""" # Try to get tenant database from context first tenant_db = _tenant_db_context.get() if tenant_db is not None: return tenant_db # Fallback to global database connection global _db, _client if _db is None: _client = AsyncIOMotorClient(app_settings.MONGODB_URI) _db = _client[app_settings.MONGODB_NAME] return _db @classmethod def set_tenant_database(cls, db: AsyncIOMotorDatabase): """Set the tenant database for this context""" _tenant_db_context.set(db) @classmethod def _get_collection_name(cls) -> str: """Get collection name from Settings or class name""" if hasattr(cls, 'Settings') and hasattr(cls.Settings, 'name'): return cls.Settings.name else: # Convert class name to snake_case for collection name import re name = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', cls.__name__) return re.sub('([a-z0-9])([A-Z])', r'\1_\2', name).lower() @classmethod def find(cls, *conditions) -> 'QueryBuilder': """Find documents matching conditions - returns QueryBuilder for chaining""" return QueryBuilder(cls, conditions) @classmethod async def find_one(cls, *conditions) -> Optional['BaseDoc']: """Find one document matching conditions""" db = await cls._get_database() collection_name = cls._get_collection_name() collection = db[collection_name] # Convert Beanie-style conditions to MongoDB query query = cls._convert_conditions_to_query(conditions) doc = await collection.find_one(query) if doc: # Extract MongoDB _id and convert to string mongo_id = doc.pop('_id', None) # Filter doc to only include fields defined in the model model_fields = set(cls.model_fields.keys()) filtered_doc = {k: v for k, v in doc.items() if k in model_fields} # Add the id field if mongo_id: filtered_doc['id'] = str(mongo_id) return cls(**filtered_doc) return None @classmethod async def get(cls, doc_id: str) -> Optional['BaseDoc']: """Get document by ID""" from bson import ObjectId try: object_id = ObjectId(doc_id) except: return None db = await cls._get_database() collection_name = cls._get_collection_name() collection = db[collection_name] doc = await collection.find_one({"_id": object_id}) if doc: # Extract MongoDB _id and convert to string mongo_id = doc.pop('_id', None) # Filter doc to only include fields defined in the model model_fields = set(cls.model_fields.keys()) filtered_doc = {k: v for k, v in doc.items() if k in model_fields} # Add the id field if mongo_id: filtered_doc['id'] = str(mongo_id) return cls(**filtered_doc) return None @classmethod def _convert_conditions_to_query(cls, conditions) -> Dict[str, Any]: """Convert Beanie-style conditions to MongoDB query""" if not conditions: return {} query = {} for condition in conditions: if isinstance(condition, dict): # Handle QueryExpression results (dictionaries) and direct dictionary queries query.update(condition) elif isinstance(condition, FieldCondition): # Handle legacy FieldCondition objects if condition.operator == "$eq": query[condition.field_name] = condition.value else: query[condition.field_name] = {condition.operator: condition.value} elif hasattr(condition, 'left') and hasattr(condition, 'right'): # Handle field == value conditions field_name = condition.left.name value = condition.right query[field_name] = value elif hasattr(condition, '__dict__'): # Handle complex conditions like FLID.identity == value if hasattr(condition, 'left') and hasattr(condition, 'right'): left = condition.left if hasattr(left, 'name') and hasattr(left, 'left'): # Nested field access like FLID.identity field_name = f"{left.left.name}.{left.name}" value = condition.right query[field_name] = value else: field_name = left.name value = condition.right query[field_name] = value return query def _convert_decimals_to_float(self, obj): """Convert Decimal objects to float for MongoDB compatibility""" from decimal import Decimal if isinstance(obj, Decimal): return float(obj) elif isinstance(obj, dict): return {key: self._convert_decimals_to_float(value) for key, value in obj.items()} elif isinstance(obj, list): return [self._convert_decimals_to_float(item) for item in obj] else: return obj async def create(self) -> 'BaseDoc': """Create this document in the database""" db = await self._get_database() collection_name = self._get_collection_name() collection = db[collection_name] # Convert to dict and insert, excluding field descriptors doc_dict = self.model_dump(exclude={'id'}) # Convert Decimal objects to float for MongoDB compatibility doc_dict = self._convert_decimals_to_float(doc_dict) result = await collection.insert_one(doc_dict) # Set the id field from the inserted document if result.inserted_id: self.id = str(result.inserted_id) # Return the created document return self async def save(self) -> 'BaseDoc': """Save this document to the database (update if exists, create if not)""" db = await self._get_database() collection_name = self._get_collection_name() collection = db[collection_name] # Convert to dict, excluding field descriptors doc_dict = self.model_dump(exclude={'id'}) # Convert Decimal objects to float for MongoDB compatibility doc_dict = self._convert_decimals_to_float(doc_dict) # Try to find existing document by user_id or other unique fields query = {} if hasattr(self, 'user_id'): query['user_id'] = self.user_id elif hasattr(self, 'email'): query['email'] = self.email elif hasattr(self, 'mobile'): query['mobile'] = self.mobile elif hasattr(self, 'auth_code'): query['auth_code'] = self.auth_code if query: # Update existing document result = await collection.update_one(query, {"$set": doc_dict}, upsert=True) # If it was an insert, set the id field if result.upserted_id: self.id = str(result.upserted_id) else: # Insert new document result = await collection.insert_one(doc_dict) if result.inserted_id: self.id = str(result.inserted_id) return self async def delete(self) -> bool: """Delete this document from the database""" db = await self._get_database() collection_name = self._get_collection_name() collection = db[collection_name] # Try to find existing document by user_id or other unique fields query = {} if hasattr(self, 'user_id'): query['user_id'] = self.user_id elif hasattr(self, 'email'): query['email'] = self.email elif hasattr(self, 'mobile'): query['mobile'] = self.mobile elif hasattr(self, 'auth_code'): query['auth_code'] = self.auth_code if query: result = await collection.delete_one(query) return result.deleted_count > 0 return False class QueryBuilder: """Query builder for chaining operations like Beanie's QueryBuilder""" def __init__(self, model_class: Type[BaseDoc], conditions: tuple): self.model_class = model_class self.conditions = conditions self._limit_value = None self._skip_value = None def limit(self, n: int) -> 'QueryBuilder': """Limit number of results""" self._limit_value = n return self def skip(self, n: int) -> 'QueryBuilder': """Skip number of results""" self._skip_value = n return self async def to_list(self) -> List[BaseDoc]: """Convert query to list of documents""" db = await self.model_class._get_database() collection_name = self.model_class._get_collection_name() collection = db[collection_name] # Convert conditions to MongoDB query query = self.model_class._convert_conditions_to_query(self.conditions) # Build cursor cursor = collection.find(query) if self._skip_value: cursor = cursor.skip(self._skip_value) if self._limit_value: cursor = cursor.limit(self._limit_value) # Execute query and convert to model instances docs = await cursor.to_list(length=None) results = [] for doc in docs: # Extract MongoDB _id and convert to string mongo_id = doc.pop('_id', None) # Filter doc to only include fields defined in the model model_fields = set(self.model_class.model_fields.keys()) filtered_doc = {k: v for k, v in doc.items() if k in model_fields} # Add the id field if mongo_id: filtered_doc['id'] = str(mongo_id) results.append(self.model_class(**filtered_doc)) return results async def first_or_none(self) -> Optional[BaseDoc]: """Get first result or None""" results = await self.limit(1).to_list() return results[0] if results else None async def count(self) -> int: """Count number of matching documents""" db = await self.model_class._get_database() collection_name = self.model_class._get_collection_name() collection = db[collection_name] query = self.model_class._convert_conditions_to_query(self.conditions) return await collection.count_documents(query)