Merge pull request 'fix: change the authentication service model layer to use direct MongoDB things, not the Benie ODM' (#89) from fix/auth-model into dev
Reviewed-on: freeleaps/freeleaps-service-hub#89
This commit is contained in:
commit
cad00ce490
@ -3,7 +3,7 @@ from typing import Optional, List, Tuple
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
|
||||
from backend.models.permission.models import PermissionDoc, RoleDoc
|
||||
from beanie import PydanticObjectId
|
||||
from bson import ObjectId
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
@ -28,10 +28,10 @@ class PermissionHandler:
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now()
|
||||
)
|
||||
await doc.insert()
|
||||
await doc.create()
|
||||
return doc
|
||||
|
||||
async def update_permission(self, permission_id: PydanticObjectId, permission_key: Optional[str] = None,
|
||||
async def update_permission(self, permission_id: str, permission_key: Optional[str] = None,
|
||||
permission_name: Optional[str] = None, description: Optional[str] = None) -> Optional[
|
||||
PermissionDoc]:
|
||||
"""Update an existing permission document by id, ensuring permission_key is unique"""
|
||||
@ -67,7 +67,7 @@ class PermissionHandler:
|
||||
# Input validation
|
||||
if not permission_key or not permission_name:
|
||||
raise RequestValidationError("permission_key and permission_name are required.")
|
||||
|
||||
|
||||
def create_new_doc():
|
||||
return PermissionDoc(
|
||||
permission_key=permission_key,
|
||||
@ -76,14 +76,14 @@ class PermissionHandler:
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now()
|
||||
)
|
||||
|
||||
|
||||
def update_doc_fields(doc):
|
||||
doc.permission_key = permission_key
|
||||
doc.permission_name = permission_name
|
||||
doc.description = description
|
||||
doc.updated_at = datetime.now()
|
||||
|
||||
try:
|
||||
|
||||
try:
|
||||
# Check if permission with this key already exists
|
||||
existing_doc = await PermissionDoc.find_one(
|
||||
{str(PermissionDoc.permission_key): permission_key}
|
||||
@ -98,10 +98,10 @@ class PermissionHandler:
|
||||
id_conflict = await PermissionDoc.get(custom_permission_id)
|
||||
if id_conflict:
|
||||
raise RequestValidationError("Permission with the provided ID already exists.")
|
||||
|
||||
|
||||
new_doc = create_new_doc()
|
||||
new_doc.id = PydanticObjectId(custom_permission_id)
|
||||
await new_doc.insert()
|
||||
new_doc.id = custom_permission_id
|
||||
await new_doc.create()
|
||||
await existing_doc.delete()
|
||||
return new_doc
|
||||
else:
|
||||
@ -112,16 +112,16 @@ class PermissionHandler:
|
||||
else:
|
||||
# If no existing document with this key, create new document
|
||||
new_doc = create_new_doc()
|
||||
|
||||
|
||||
if custom_permission_id:
|
||||
id_conflict = await PermissionDoc.get(custom_permission_id)
|
||||
if id_conflict:
|
||||
raise RequestValidationError("Permission with the provided ID already exists.")
|
||||
new_doc.id = PydanticObjectId(custom_permission_id)
|
||||
|
||||
await new_doc.insert()
|
||||
new_doc.id = custom_permission_id
|
||||
|
||||
await new_doc.create()
|
||||
return new_doc
|
||||
|
||||
|
||||
async def query_permissions(
|
||||
self,
|
||||
permission_key: Optional[str] = None,
|
||||
@ -141,16 +141,16 @@ class PermissionHandler:
|
||||
return docs, total
|
||||
|
||||
async def query_permissions_no_pagination(
|
||||
self,
|
||||
permission_id: Optional[str] = None,
|
||||
permission_key: Optional[str] = None,
|
||||
self,
|
||||
permission_id: Optional[str] = None,
|
||||
permission_key: Optional[str] = None,
|
||||
permission_name: Optional[str] = None
|
||||
) -> Tuple[List[PermissionDoc], int]:
|
||||
"""Query permissions fuzzy search"""
|
||||
query = {}
|
||||
if permission_id:
|
||||
try:
|
||||
query[str(PermissionDoc.id)] = PydanticObjectId(permission_id)
|
||||
query[str(PermissionDoc.id)] = permission_id
|
||||
except Exception:
|
||||
raise RequestValidationError("Invalid permission_id format. Must be a valid ObjectId.")
|
||||
if permission_key:
|
||||
@ -162,7 +162,7 @@ class PermissionHandler:
|
||||
docs = await cursor.to_list()
|
||||
return docs, total
|
||||
|
||||
async def delete_permission(self, permission_id: PydanticObjectId) -> None:
|
||||
async def delete_permission(self, permission_id: str) -> None:
|
||||
"""Delete a permission document after checking if it is referenced by any role and is not default"""
|
||||
if not permission_id:
|
||||
raise RequestValidationError("permission_id is required.")
|
||||
|
||||
@ -3,7 +3,7 @@ from typing import Optional, List, Tuple
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
|
||||
from backend.models.permission.models import RoleDoc, PermissionDoc, UserRoleDoc
|
||||
from beanie import PydanticObjectId
|
||||
from bson import ObjectId
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
@ -27,10 +27,10 @@ class RoleHandler:
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now()
|
||||
)
|
||||
await doc.insert()
|
||||
await doc.create()
|
||||
return doc
|
||||
|
||||
async def update_role(self, role_id: PydanticObjectId, role_key: str, role_name: str,
|
||||
async def update_role(self, role_id: str, role_key: str, role_name: str,
|
||||
role_description: Optional[str], role_level: int) -> Optional[
|
||||
RoleDoc]:
|
||||
"""Update an existing role, ensuring role_key and role_name are unique and not empty"""
|
||||
@ -83,7 +83,7 @@ class RoleHandler:
|
||||
doc.role_description = role_description
|
||||
doc.role_level = role_level
|
||||
doc.updated_at = datetime.now()
|
||||
|
||||
|
||||
# Check if role with this key already exists
|
||||
existing_doc = await RoleDoc.find_one(
|
||||
{str(RoleDoc.role_key): role_key}
|
||||
@ -96,13 +96,13 @@ class RoleHandler:
|
||||
id_conflict = await RoleDoc.get(custom_role_id)
|
||||
if id_conflict:
|
||||
raise RequestValidationError("Role with the provided ID already exists.")
|
||||
|
||||
|
||||
new_doc = create_new_doc()
|
||||
new_doc.id = PydanticObjectId(custom_role_id)
|
||||
await new_doc.insert()
|
||||
new_doc.id = custom_role_id
|
||||
await new_doc.create()
|
||||
await existing_doc.delete()
|
||||
return new_doc
|
||||
|
||||
|
||||
else:
|
||||
# Same ID or no ID provided - update existing document
|
||||
update_doc_fields(existing_doc)
|
||||
@ -111,13 +111,13 @@ class RoleHandler:
|
||||
else:
|
||||
# If no existing document with this key, create new document
|
||||
new_doc = create_new_doc()
|
||||
|
||||
|
||||
if custom_role_id:
|
||||
id_conflict = await RoleDoc.get(custom_role_id)
|
||||
if id_conflict:
|
||||
raise RequestValidationError("Role with the provided ID already exists.")
|
||||
new_doc.id = PydanticObjectId(custom_role_id)
|
||||
|
||||
new_doc.id = custom_role_id
|
||||
|
||||
await new_doc.insert()
|
||||
return new_doc
|
||||
|
||||
@ -135,16 +135,16 @@ class RoleHandler:
|
||||
return docs, total
|
||||
|
||||
async def query_roles_no_pagination(
|
||||
self,
|
||||
role_id: Optional[str] = None,
|
||||
role_key: Optional[str] = None,
|
||||
self,
|
||||
role_id: Optional[str] = None,
|
||||
role_key: Optional[str] = None,
|
||||
role_name: Optional[str] = None
|
||||
) -> Tuple[List[RoleDoc], int]:
|
||||
"""Query roles fuzzy search without pagination"""
|
||||
query = {}
|
||||
if role_id:
|
||||
try:
|
||||
query[str(RoleDoc.id)] = PydanticObjectId(role_id)
|
||||
query[str(RoleDoc.id)] = role_id
|
||||
except Exception:
|
||||
raise RequestValidationError("Invalid role_id format. Must be a valid ObjectId.")
|
||||
if role_key:
|
||||
@ -156,29 +156,29 @@ class RoleHandler:
|
||||
docs = await cursor.to_list()
|
||||
return docs, total
|
||||
|
||||
async def assign_permissions_to_role(self, role_id: PydanticObjectId, permission_ids: List[str]) -> Optional[RoleDoc]:
|
||||
async def assign_permissions_to_role(self, role_id: str, permission_ids: List[str]) -> Optional[RoleDoc]:
|
||||
"""Assign permissions to a role by updating the permission_ids field"""
|
||||
if not role_id or not permission_ids:
|
||||
raise RequestValidationError("role_id and permission_ids are required.")
|
||||
doc = await RoleDoc.get(role_id)
|
||||
if not doc:
|
||||
raise RequestValidationError("Role not found.")
|
||||
|
||||
|
||||
# Validate that all permission_ids exist in the permission collection
|
||||
for permission_id in permission_ids:
|
||||
permission_doc = await PermissionDoc.get(PydanticObjectId(permission_id))
|
||||
permission_doc = await PermissionDoc.get(permission_id)
|
||||
if not permission_doc:
|
||||
raise RequestValidationError(f"Permission with id {permission_id} not found.")
|
||||
|
||||
|
||||
# Remove duplicates from permission_ids
|
||||
unique_permission_ids = list(dict.fromkeys(permission_ids))
|
||||
|
||||
|
||||
doc.permission_ids = unique_permission_ids
|
||||
doc.updated_at = datetime.now()
|
||||
await doc.save()
|
||||
return doc
|
||||
|
||||
async def delete_role(self, role_id: PydanticObjectId) -> None:
|
||||
async def delete_role(self, role_id: str) -> None:
|
||||
"""Delete a role document after checking if it is referenced by any user and is not default"""
|
||||
if not role_id:
|
||||
raise RequestValidationError("role_id is required.")
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from typing import Optional, List
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from backend.models.permission.models import RoleDoc, UserRoleDoc, PermissionDoc
|
||||
from beanie import PydanticObjectId
|
||||
from bson import ObjectId
|
||||
|
||||
|
||||
class UserRoleHandler:
|
||||
@ -15,7 +15,7 @@ class UserRoleHandler:
|
||||
|
||||
# Validate that all role_ids exist in the role collection
|
||||
for role_id in role_ids:
|
||||
role_doc = await RoleDoc.get(PydanticObjectId(role_id))
|
||||
role_doc = await RoleDoc.get(role_id)
|
||||
if not role_doc:
|
||||
raise RequestValidationError(f"Role with id {role_id} not found.")
|
||||
|
||||
@ -47,7 +47,7 @@ class UserRoleHandler:
|
||||
# No roles assigned
|
||||
return [], []
|
||||
# Query all roles by role_ids
|
||||
roles = await RoleDoc.find({"_id": {"$in": [PydanticObjectId(rid) for rid in user_role_doc.role_ids]}}).to_list()
|
||||
roles = await RoleDoc.find({"_id": {"$in": user_role_doc.role_ids}}).to_list()
|
||||
role_names = [role.role_name for role in roles]
|
||||
# Collect all permission_ids from all roles
|
||||
all_permission_ids = []
|
||||
@ -58,7 +58,7 @@ class UserRoleHandler:
|
||||
unique_permission_ids = list(dict.fromkeys(all_permission_ids))
|
||||
# Query all permissions by permission_ids
|
||||
if unique_permission_ids:
|
||||
permissions = await PermissionDoc.find({"_id": {"$in": [PydanticObjectId(pid) for pid in unique_permission_ids]}}).to_list()
|
||||
permissions = await PermissionDoc.find({"_id": {"$in": unique_permission_ids}}).to_list()
|
||||
permission_keys = [perm.permission_key for perm in permissions]
|
||||
else:
|
||||
permission_keys = []
|
||||
|
||||
405
apps/authentication/backend/models/base_doc.py
Normal file
405
apps/authentication/backend/models/base_doc.py
Normal file
@ -0,0 +1,405 @@
|
||||
"""
|
||||
BaseDoc - A custom document class that provides Beanie-like interface using direct MongoDB operations
|
||||
"""
|
||||
import asyncio
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, List, Dict, Any, Type, Union
|
||||
from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorDatabase
|
||||
from pydantic import BaseModel
|
||||
from common.config.app_settings import app_settings
|
||||
|
||||
|
||||
class QueryExpression:
|
||||
"""Query expression for field comparisons"""
|
||||
def __init__(self, field_name: str):
|
||||
self.field_name = field_name
|
||||
|
||||
def __eq__(self, other) -> dict:
|
||||
"""Handle field == value comparisons"""
|
||||
return {self.field_name: other}
|
||||
|
||||
def __ne__(self, other) -> dict:
|
||||
"""Handle field != value comparisons"""
|
||||
return {self.field_name: {"$ne": other}}
|
||||
|
||||
def __gt__(self, other) -> dict:
|
||||
"""Handle field > value comparisons"""
|
||||
return {self.field_name: {"$gt": other}}
|
||||
|
||||
def __lt__(self, other) -> dict:
|
||||
"""Handle field < value comparisons"""
|
||||
return {self.field_name: {"$lt": other}}
|
||||
|
||||
def __ge__(self, other) -> dict:
|
||||
"""Handle field >= value comparisons"""
|
||||
return {self.field_name: {"$gte": other}}
|
||||
|
||||
def __le__(self, other) -> dict:
|
||||
"""Handle field <= value comparisons"""
|
||||
return {self.field_name: {"$lte": other}}
|
||||
|
||||
|
||||
class FieldDescriptor:
|
||||
"""Descriptor for field access like Beanie's field == value pattern"""
|
||||
def __init__(self, field_name: str, field_type: type):
|
||||
self.field_name = field_name
|
||||
self.field_type = field_type
|
||||
|
||||
def __get__(self, instance: Any, owner: type) -> Any:
|
||||
"""
|
||||
- Class access (instance is None): return QueryExpression for building queries
|
||||
- Instance access (instance is not None): return the actual field value
|
||||
"""
|
||||
if instance is None:
|
||||
return QueryExpression(self.field_name)
|
||||
return instance.__dict__.get(self.field_name)
|
||||
|
||||
def __set__(self, instance: Any, value: Any) -> None:
|
||||
"""Set instance field value with type validation (compatible with Pydantic validation)"""
|
||||
if not isinstance(value, self.field_type):
|
||||
raise TypeError(f"Field {self.field_name} must be {self.field_type}")
|
||||
instance.__dict__[self.field_name] = value
|
||||
|
||||
|
||||
class FieldCondition:
|
||||
"""Represents a field condition for MongoDB queries"""
|
||||
def __init__(self, field_name: str, value: Any, operator: str = "$eq"):
|
||||
self.field_name = field_name
|
||||
self.value = value
|
||||
self.operator = operator
|
||||
self.left = self # For compatibility with existing condition parsing
|
||||
self.right = value
|
||||
|
||||
|
||||
# Module-level variables for database connection
|
||||
_db: Optional[AsyncIOMotorDatabase] = None
|
||||
_client: Optional[AsyncIOMotorClient] = None
|
||||
|
||||
# Context variable for tenant database
|
||||
import contextvars
|
||||
_tenant_db_context: contextvars.ContextVar[Optional[AsyncIOMotorDatabase]] = contextvars.ContextVar('tenant_db', default=None)
|
||||
|
||||
|
||||
class QueryModelMeta(type(BaseModel)):
|
||||
"""Metaclass: automatically create FieldDescriptor for model fields"""
|
||||
def __new__(cls, name: str, bases: tuple, namespace: dict):
|
||||
# Get model field annotations (like name: str -> "name" and str)
|
||||
annotations = namespace.get("__annotations__", {})
|
||||
|
||||
# Create the class first
|
||||
new_class = super().__new__(cls, name, bases, namespace)
|
||||
|
||||
# After Pydantic processes the fields, add the descriptors as class attributes
|
||||
for field_name, field_type in annotations.items():
|
||||
if field_name != 'id': # Skip the id field as it's handled specially
|
||||
# Add the descriptor as a class attribute
|
||||
setattr(new_class, field_name, FieldDescriptor(field_name, field_type))
|
||||
|
||||
return new_class
|
||||
|
||||
class BaseDoc(BaseModel, metaclass=QueryModelMeta):
|
||||
"""
|
||||
Base document class that provides Beanie-like interface using direct MongoDB operations.
|
||||
All model classes should inherit from this instead of Beanie's Document.
|
||||
"""
|
||||
|
||||
id: Optional[str] = None # MongoDB _id field
|
||||
|
||||
def model_dump(self, **kwargs):
|
||||
"""Override model_dump to exclude field descriptors"""
|
||||
# Get the default model_dump result
|
||||
result = super().model_dump(**kwargs)
|
||||
|
||||
# Remove any field descriptors that might have been included
|
||||
filtered_result = {}
|
||||
for key, value in result.items():
|
||||
if not isinstance(value, FieldDescriptor):
|
||||
filtered_result[key] = value
|
||||
|
||||
return filtered_result
|
||||
|
||||
|
||||
|
||||
|
||||
@classmethod
|
||||
async def _get_database(cls) -> AsyncIOMotorDatabase:
|
||||
"""Get database connection using pure AsyncIOMotorClient"""
|
||||
# Try to get tenant database from context first
|
||||
tenant_db = _tenant_db_context.get()
|
||||
if tenant_db is not None:
|
||||
return tenant_db
|
||||
|
||||
# Fallback to global database connection
|
||||
global _db, _client
|
||||
if _db is None:
|
||||
_client = AsyncIOMotorClient(app_settings.MONGODB_URI)
|
||||
_db = _client[app_settings.MONGODB_NAME]
|
||||
return _db
|
||||
|
||||
@classmethod
|
||||
def set_tenant_database(cls, db: AsyncIOMotorDatabase):
|
||||
"""Set the tenant database for this context"""
|
||||
_tenant_db_context.set(db)
|
||||
|
||||
@classmethod
|
||||
def _get_collection_name(cls) -> str:
|
||||
"""Get collection name from Settings or class name"""
|
||||
if hasattr(cls, 'Settings') and hasattr(cls.Settings, 'name'):
|
||||
return cls.Settings.name
|
||||
else:
|
||||
# Convert class name to snake_case for collection name
|
||||
import re
|
||||
name = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', cls.__name__)
|
||||
return re.sub('([a-z0-9])([A-Z])', r'\1_\2', name).lower()
|
||||
|
||||
@classmethod
|
||||
def find(cls, *conditions) -> 'QueryBuilder':
|
||||
"""Find documents matching conditions - returns QueryBuilder for chaining"""
|
||||
return QueryBuilder(cls, conditions)
|
||||
|
||||
@classmethod
|
||||
async def find_one(cls, *conditions) -> Optional['BaseDoc']:
|
||||
"""Find one document matching conditions"""
|
||||
db = await cls._get_database()
|
||||
collection_name = cls._get_collection_name()
|
||||
collection = db[collection_name]
|
||||
|
||||
# Convert Beanie-style conditions to MongoDB query
|
||||
query = cls._convert_conditions_to_query(conditions)
|
||||
|
||||
doc = await collection.find_one(query)
|
||||
if doc:
|
||||
# Extract MongoDB _id and convert to string
|
||||
mongo_id = doc.pop('_id', None)
|
||||
# Filter doc to only include fields defined in the model
|
||||
model_fields = set(cls.model_fields.keys())
|
||||
filtered_doc = {k: v for k, v in doc.items() if k in model_fields}
|
||||
# Add the id field
|
||||
if mongo_id:
|
||||
filtered_doc['id'] = str(mongo_id)
|
||||
return cls(**filtered_doc)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
async def get(cls, doc_id: str) -> Optional['BaseDoc']:
|
||||
"""Get document by ID"""
|
||||
from bson import ObjectId
|
||||
try:
|
||||
object_id = ObjectId(doc_id)
|
||||
except:
|
||||
return None
|
||||
|
||||
db = await cls._get_database()
|
||||
collection_name = cls._get_collection_name()
|
||||
collection = db[collection_name]
|
||||
|
||||
doc = await collection.find_one({"_id": object_id})
|
||||
if doc:
|
||||
# Extract MongoDB _id and convert to string
|
||||
mongo_id = doc.pop('_id', None)
|
||||
# Filter doc to only include fields defined in the model
|
||||
model_fields = set(cls.model_fields.keys())
|
||||
filtered_doc = {k: v for k, v in doc.items() if k in model_fields}
|
||||
# Add the id field
|
||||
if mongo_id:
|
||||
filtered_doc['id'] = str(mongo_id)
|
||||
return cls(**filtered_doc)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _convert_conditions_to_query(cls, conditions) -> Dict[str, Any]:
|
||||
"""Convert Beanie-style conditions to MongoDB query"""
|
||||
if not conditions:
|
||||
return {}
|
||||
|
||||
query = {}
|
||||
for condition in conditions:
|
||||
if isinstance(condition, dict):
|
||||
# Handle QueryExpression results (dictionaries) and direct dictionary queries
|
||||
query.update(condition)
|
||||
elif isinstance(condition, FieldCondition):
|
||||
# Handle legacy FieldCondition objects
|
||||
if condition.operator == "$eq":
|
||||
query[condition.field_name] = condition.value
|
||||
else:
|
||||
query[condition.field_name] = {condition.operator: condition.value}
|
||||
elif hasattr(condition, 'left') and hasattr(condition, 'right'):
|
||||
# Handle field == value conditions
|
||||
field_name = condition.left.name
|
||||
value = condition.right
|
||||
query[field_name] = value
|
||||
elif hasattr(condition, '__dict__'):
|
||||
# Handle complex conditions like FLID.identity == value
|
||||
if hasattr(condition, 'left') and hasattr(condition, 'right'):
|
||||
left = condition.left
|
||||
if hasattr(left, 'name') and hasattr(left, 'left'):
|
||||
# Nested field access like FLID.identity
|
||||
field_name = f"{left.left.name}.{left.name}"
|
||||
value = condition.right
|
||||
query[field_name] = value
|
||||
else:
|
||||
field_name = left.name
|
||||
value = condition.right
|
||||
query[field_name] = value
|
||||
|
||||
return query
|
||||
|
||||
def _convert_decimals_to_float(self, obj):
|
||||
"""Convert Decimal objects to float for MongoDB compatibility"""
|
||||
from decimal import Decimal
|
||||
|
||||
if isinstance(obj, Decimal):
|
||||
return float(obj)
|
||||
elif isinstance(obj, dict):
|
||||
return {key: self._convert_decimals_to_float(value) for key, value in obj.items()}
|
||||
elif isinstance(obj, list):
|
||||
return [self._convert_decimals_to_float(item) for item in obj]
|
||||
else:
|
||||
return obj
|
||||
|
||||
async def create(self) -> 'BaseDoc':
|
||||
"""Create this document in the database"""
|
||||
db = await self._get_database()
|
||||
collection_name = self._get_collection_name()
|
||||
collection = db[collection_name]
|
||||
|
||||
# Convert to dict and insert, excluding field descriptors
|
||||
doc_dict = self.model_dump(exclude={'id'})
|
||||
|
||||
# Convert Decimal objects to float for MongoDB compatibility
|
||||
doc_dict = self._convert_decimals_to_float(doc_dict)
|
||||
|
||||
result = await collection.insert_one(doc_dict)
|
||||
|
||||
# Set the id field from the inserted document
|
||||
if result.inserted_id:
|
||||
self.id = str(result.inserted_id)
|
||||
|
||||
# Return the created document
|
||||
return self
|
||||
|
||||
async def save(self) -> 'BaseDoc':
|
||||
"""Save this document to the database (update if exists, create if not)"""
|
||||
db = await self._get_database()
|
||||
collection_name = self._get_collection_name()
|
||||
collection = db[collection_name]
|
||||
|
||||
# Convert to dict, excluding field descriptors
|
||||
doc_dict = self.model_dump(exclude={'id'})
|
||||
|
||||
# Convert Decimal objects to float for MongoDB compatibility
|
||||
doc_dict = self._convert_decimals_to_float(doc_dict)
|
||||
|
||||
# Try to find existing document by user_id or other unique fields
|
||||
query = {}
|
||||
if hasattr(self, 'user_id'):
|
||||
query['user_id'] = self.user_id
|
||||
elif hasattr(self, 'email'):
|
||||
query['email'] = self.email
|
||||
elif hasattr(self, 'mobile'):
|
||||
query['mobile'] = self.mobile
|
||||
elif hasattr(self, 'auth_code'):
|
||||
query['auth_code'] = self.auth_code
|
||||
|
||||
if query:
|
||||
# Update existing document
|
||||
result = await collection.update_one(query, {"$set": doc_dict}, upsert=True)
|
||||
# If it was an insert, set the id field
|
||||
if result.upserted_id:
|
||||
self.id = str(result.upserted_id)
|
||||
else:
|
||||
# Insert new document
|
||||
result = await collection.insert_one(doc_dict)
|
||||
if result.inserted_id:
|
||||
self.id = str(result.inserted_id)
|
||||
|
||||
return self
|
||||
|
||||
async def delete(self) -> bool:
|
||||
"""Delete this document from the database"""
|
||||
db = await self._get_database()
|
||||
collection_name = self._get_collection_name()
|
||||
collection = db[collection_name]
|
||||
|
||||
# Try to find existing document by user_id or other unique fields
|
||||
query = {}
|
||||
if hasattr(self, 'user_id'):
|
||||
query['user_id'] = self.user_id
|
||||
elif hasattr(self, 'email'):
|
||||
query['email'] = self.email
|
||||
elif hasattr(self, 'mobile'):
|
||||
query['mobile'] = self.mobile
|
||||
elif hasattr(self, 'auth_code'):
|
||||
query['auth_code'] = self.auth_code
|
||||
|
||||
if query:
|
||||
result = await collection.delete_one(query)
|
||||
return result.deleted_count > 0
|
||||
return False
|
||||
|
||||
|
||||
class QueryBuilder:
|
||||
"""Query builder for chaining operations like Beanie's QueryBuilder"""
|
||||
|
||||
def __init__(self, model_class: Type[BaseDoc], conditions: tuple):
|
||||
self.model_class = model_class
|
||||
self.conditions = conditions
|
||||
self._limit_value = None
|
||||
self._skip_value = None
|
||||
|
||||
def limit(self, n: int) -> 'QueryBuilder':
|
||||
"""Limit number of results"""
|
||||
self._limit_value = n
|
||||
return self
|
||||
|
||||
def skip(self, n: int) -> 'QueryBuilder':
|
||||
"""Skip number of results"""
|
||||
self._skip_value = n
|
||||
return self
|
||||
|
||||
async def to_list(self) -> List[BaseDoc]:
|
||||
"""Convert query to list of documents"""
|
||||
db = await self.model_class._get_database()
|
||||
collection_name = self.model_class._get_collection_name()
|
||||
collection = db[collection_name]
|
||||
|
||||
# Convert conditions to MongoDB query
|
||||
query = self.model_class._convert_conditions_to_query(self.conditions)
|
||||
|
||||
# Build cursor
|
||||
cursor = collection.find(query)
|
||||
|
||||
if self._skip_value:
|
||||
cursor = cursor.skip(self._skip_value)
|
||||
if self._limit_value:
|
||||
cursor = cursor.limit(self._limit_value)
|
||||
|
||||
# Execute query and convert to model instances
|
||||
docs = await cursor.to_list(length=None)
|
||||
results = []
|
||||
for doc in docs:
|
||||
# Extract MongoDB _id and convert to string
|
||||
mongo_id = doc.pop('_id', None)
|
||||
# Filter doc to only include fields defined in the model
|
||||
model_fields = set(self.model_class.model_fields.keys())
|
||||
filtered_doc = {k: v for k, v in doc.items() if k in model_fields}
|
||||
# Add the id field
|
||||
if mongo_id:
|
||||
filtered_doc['id'] = str(mongo_id)
|
||||
results.append(self.model_class(**filtered_doc))
|
||||
|
||||
return results
|
||||
|
||||
async def first_or_none(self) -> Optional[BaseDoc]:
|
||||
"""Get first result or None"""
|
||||
results = await self.limit(1).to_list()
|
||||
return results[0] if results else None
|
||||
|
||||
async def count(self) -> int:
|
||||
"""Count number of matching documents"""
|
||||
db = await self.model_class._get_database()
|
||||
collection_name = self.model_class._get_collection_name()
|
||||
collection = db[collection_name]
|
||||
|
||||
query = self.model_class._convert_conditions_to_query(self.conditions)
|
||||
return await collection.count_documents(query)
|
||||
@ -1,18 +1,18 @@
|
||||
from beanie import Document
|
||||
from datetime import datetime
|
||||
from typing import Optional, List
|
||||
from ..base_doc import BaseDoc
|
||||
|
||||
|
||||
class PermissionDoc(Document):
|
||||
class PermissionDoc(BaseDoc):
|
||||
permission_name: str
|
||||
permission_key: str
|
||||
description: Optional[str] = None # Description of the permission, optional
|
||||
created_at: datetime = datetime.now() # Creation timestamp, auto-generated
|
||||
updated_at: datetime = datetime.now() # Last update timestamp, auto-updated
|
||||
is_default: bool = False
|
||||
|
||||
|
||||
class Settings:
|
||||
# Default collections created by Freeleaps for tenant databases use '_' prefix
|
||||
# Default collections created by Freeleaps for tenant databases use '_' prefix
|
||||
# to prevent naming conflicts with tenant-created collections
|
||||
name = "_permission"
|
||||
indexes = [
|
||||
@ -20,7 +20,7 @@ class PermissionDoc(Document):
|
||||
]
|
||||
|
||||
|
||||
class RoleDoc(Document):
|
||||
class RoleDoc(BaseDoc):
|
||||
role_key: str
|
||||
role_name: str
|
||||
role_description: Optional[str] = None
|
||||
@ -32,14 +32,14 @@ class RoleDoc(Document):
|
||||
is_default: bool = False
|
||||
|
||||
class Settings:
|
||||
# Default collections created by Freeleaps for tenant databases use '_' prefix
|
||||
# Default collections created by Freeleaps for tenant databases use '_' prefix
|
||||
# to prevent naming conflicts with tenant-created collections
|
||||
name = "_role"
|
||||
indexes = [
|
||||
"role_level"
|
||||
]
|
||||
|
||||
class UserRoleDoc(Document):
|
||||
class UserRoleDoc(BaseDoc):
|
||||
"""User role doc"""
|
||||
user_id: str
|
||||
role_ids: Optional[List[str]]
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from typing import Optional, List
|
||||
|
||||
from beanie import Document
|
||||
from ..base_doc import BaseDoc
|
||||
|
||||
from .constants import UserAccountProperty
|
||||
from backend.models.permission.constants import (
|
||||
@ -12,7 +12,7 @@ from common.constants.region import UserRegion
|
||||
from .constants import AuthType
|
||||
|
||||
|
||||
class UserAccountDoc(Document):
|
||||
class UserAccountDoc(BaseDoc):
|
||||
profile_id: Optional[str]
|
||||
account_id: Optional[str]
|
||||
service_plan_id: Optional[str]
|
||||
@ -25,7 +25,7 @@ class UserAccountDoc(Document):
|
||||
name = "user_account"
|
||||
|
||||
|
||||
class UserPasswordDoc(Document):
|
||||
class UserPasswordDoc(BaseDoc):
|
||||
user_id: str
|
||||
password: str
|
||||
|
||||
@ -33,7 +33,7 @@ class UserPasswordDoc(Document):
|
||||
name = "user_password"
|
||||
|
||||
|
||||
class UserEmailDoc(Document):
|
||||
class UserEmailDoc(BaseDoc):
|
||||
user_id: str
|
||||
email: str
|
||||
|
||||
@ -41,7 +41,7 @@ class UserEmailDoc(Document):
|
||||
name = "user_email"
|
||||
|
||||
|
||||
class UserMobileDoc(Document):
|
||||
class UserMobileDoc(BaseDoc):
|
||||
user_id: str
|
||||
mobile: str
|
||||
|
||||
@ -49,16 +49,17 @@ class UserMobileDoc(Document):
|
||||
name = "user_mobile"
|
||||
|
||||
|
||||
class AuthCodeDoc(Document):
|
||||
class AuthCodeDoc(BaseDoc):
|
||||
auth_code: str
|
||||
method: str
|
||||
method_type: AuthType
|
||||
expiry: datetime
|
||||
used: bool = False
|
||||
|
||||
class Settings:
|
||||
name = "user_auth_code"
|
||||
|
||||
class UsageLogDoc(Document):
|
||||
class UsageLogDoc(BaseDoc):
|
||||
timestamp: datetime = datetime.utcnow() # timestamp
|
||||
tenant_id: str # tenant id
|
||||
operation: str # operation type
|
||||
@ -69,7 +70,7 @@ class UsageLogDoc(Document):
|
||||
bytes_out: int # output bytes
|
||||
key_id: Optional[str] = None # API Key ID
|
||||
extra: dict = {} # extra information
|
||||
|
||||
|
||||
class Settings:
|
||||
name = "usage_log_doc"
|
||||
indexes = [
|
||||
|
||||
@ -1,11 +1,11 @@
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
from beanie import Document, Indexed
|
||||
from pydantic import BaseModel, EmailStr
|
||||
import re
|
||||
|
||||
from decimal import Decimal
|
||||
from common.constants.region import UserRegion
|
||||
from ..base_doc import BaseDoc
|
||||
|
||||
|
||||
class Tags(BaseModel):
|
||||
@ -47,10 +47,10 @@ class Password(BaseModel):
|
||||
expiry: datetime
|
||||
|
||||
|
||||
class BasicProfileDoc(Document):
|
||||
class BasicProfileDoc(BaseDoc):
|
||||
user_id: str
|
||||
first_name: Indexed(str) = "" # type: ignore
|
||||
last_name: Indexed(str) = "" # type: ignore Index for faster search
|
||||
first_name: str = ""
|
||||
last_name: str = ""
|
||||
spoken_language: List[str] = []
|
||||
self_intro: SelfIntro
|
||||
photo: Photo
|
||||
@ -94,7 +94,7 @@ class ExpectedSalary(BaseModel):
|
||||
hourly: Decimal = 0.0
|
||||
|
||||
|
||||
class ProviderProfileDoc(Document):
|
||||
class ProviderProfileDoc(BaseDoc):
|
||||
user_id: str
|
||||
expected_salary: ExpectedSalary
|
||||
accepting_request: bool = False
|
||||
|
||||
@ -4,7 +4,7 @@ from fastapi.exceptions import RequestValidationError
|
||||
|
||||
from backend.infra.permission.permission_handler import PermissionHandler
|
||||
from backend.models.permission.models import PermissionDoc
|
||||
from beanie import PydanticObjectId
|
||||
from bson import ObjectId
|
||||
|
||||
class PermissionService:
|
||||
def __init__(self):
|
||||
@ -16,12 +16,12 @@ class PermissionService:
|
||||
|
||||
async def update_permission(self, permission_id: str, permission_key: Optional[str] = None, permission_name: Optional[str] = None, description: Optional[str] = None) -> PermissionDoc:
|
||||
"""Update an existing permission document by id"""
|
||||
return await self.permission_handler.update_permission(PydanticObjectId(permission_id), permission_key, permission_name, description)
|
||||
return await self.permission_handler.update_permission(permission_id, permission_key, permission_name, description)
|
||||
|
||||
async def create_or_update_permission(self, permission_key: str, permission_name: str, custom_permission_id: Optional[str], description: Optional[str] = None) -> PermissionDoc:
|
||||
"""Create or update a permission document"""
|
||||
return await self.permission_handler.create_or_update_permission(permission_key, permission_name, custom_permission_id, description)
|
||||
|
||||
|
||||
async def query_permissions(self, permission_key: Optional[str] = None, permission_name: Optional[str] = None, page: int = 1, page_size: int = 10) -> Dict[str, Any]:
|
||||
"""Query permissions with pagination and fuzzy search"""
|
||||
if page < 1 or page_size < 1:
|
||||
@ -29,7 +29,7 @@ class PermissionService:
|
||||
skip = (page - 1) * page_size
|
||||
docs, total = await self.permission_handler.query_permissions(permission_key, permission_name, skip, page_size)
|
||||
return {
|
||||
"items": [doc.dict() for doc in docs],
|
||||
"items": [doc.model_dump() for doc in docs],
|
||||
"total": total,
|
||||
"page": page,
|
||||
"page_size": page_size
|
||||
@ -38,10 +38,10 @@ class PermissionService:
|
||||
"""Query permissions fuzzy search"""
|
||||
docs, total = await self.permission_handler.query_permissions_no_pagination(permission_id, permission_key, permission_name)
|
||||
return {
|
||||
"items": [doc.dict() for doc in docs],
|
||||
"items": [doc.model_dump() for doc in docs],
|
||||
"total": total
|
||||
}
|
||||
|
||||
async def delete_permission(self, permission_id: str) -> None:
|
||||
"""Delete a permission document after checking if it is referenced by any role"""
|
||||
return await self.permission_handler.delete_permission(PydanticObjectId(permission_id))
|
||||
return await self.permission_handler.delete_permission(permission_id)
|
||||
@ -4,7 +4,7 @@ from fastapi.exceptions import RequestValidationError
|
||||
|
||||
from backend.infra.permission.role_handler import RoleHandler
|
||||
from backend.models.permission.models import RoleDoc
|
||||
from beanie import PydanticObjectId
|
||||
from bson import ObjectId
|
||||
|
||||
class RoleService:
|
||||
def __init__(self):
|
||||
@ -19,7 +19,7 @@ class RoleService:
|
||||
async def update_role(self, role_id: str, role_key: str, role_name: str, role_description: Optional[str], role_level: int) -> RoleDoc:
|
||||
"""Update an existing role, ensuring role_key and role_name are unique and not empty"""
|
||||
|
||||
doc = await self.role_handler.update_role(PydanticObjectId(role_id), role_key, role_name, role_description, role_level)
|
||||
doc = await self.role_handler.update_role(role_id, role_key, role_name, role_description, role_level)
|
||||
return doc
|
||||
|
||||
async def create_or_update_role(self, role_key: str, role_name: str, role_level: int, custom_role_id: Optional[str], role_description: Optional[str] = None) -> RoleDoc:
|
||||
@ -33,7 +33,7 @@ class RoleService:
|
||||
skip = (page - 1) * page_size
|
||||
docs, total = await self.role_handler.query_roles(role_key, role_name, skip, page_size)
|
||||
return {
|
||||
"items": [doc.dict() for doc in docs],
|
||||
"items": [doc.model_dump() for doc in docs],
|
||||
"total": total,
|
||||
"page": page,
|
||||
"page_size": page_size
|
||||
@ -43,14 +43,14 @@ class RoleService:
|
||||
"""Query roles fuzzy search without pagination"""
|
||||
docs, total = await self.role_handler.query_roles_no_pagination(role_id, role_key, role_name)
|
||||
return {
|
||||
"items": [doc.dict() for doc in docs],
|
||||
"items": [doc.model_dump() for doc in docs],
|
||||
"total": total
|
||||
}
|
||||
|
||||
|
||||
async def assign_permissions_to_role(self, role_id: str, permission_ids: List[str]) -> RoleDoc:
|
||||
"""Assign permissions to a role by updating the permission_ids field"""
|
||||
return await self.role_handler.assign_permissions_to_role(PydanticObjectId(role_id), permission_ids)
|
||||
return await self.role_handler.assign_permissions_to_role(role_id, permission_ids)
|
||||
|
||||
async def delete_role(self, role_id: str) -> None:
|
||||
"""Delete a role document after checking if it is referenced by any user"""
|
||||
return await self.role_handler.delete_role(PydanticObjectId(role_id))
|
||||
return await self.role_handler.delete_role(role_id)
|
||||
@ -3,7 +3,7 @@ from unittest.mock import AsyncMock, patch, MagicMock
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from backend.infra.permission.permission_handler import PermissionHandler
|
||||
from backend.models.permission.models import PermissionDoc, RoleDoc
|
||||
from beanie import PydanticObjectId
|
||||
from bson import ObjectId
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_db():
|
||||
@ -51,7 +51,7 @@ class TestPermissionHandler:
|
||||
self.MockPermissionDoc.get = AsyncMock(return_value=mock_doc)
|
||||
self.MockPermissionDoc.find_one = AsyncMock(return_value=None)
|
||||
mock_doc.save = AsyncMock()
|
||||
result = await self.handler.update_permission(PydanticObjectId('507f1f77bcf86cd799439011'), 'key', 'name', 'desc')
|
||||
result = await self.handler.update_permission('507f1f77bcf86cd799439011', 'key', 'name', 'desc')
|
||||
assert result == mock_doc
|
||||
mock_doc.save.assert_awaited_once()
|
||||
|
||||
@ -60,15 +60,15 @@ class TestPermissionHandler:
|
||||
with pytest.raises(RequestValidationError):
|
||||
await self.handler.update_permission(None, 'key', 'name', 'desc')
|
||||
with pytest.raises(RequestValidationError):
|
||||
await self.handler.update_permission(PydanticObjectId('507f1f77bcf86cd799439011'), '', 'name', 'desc')
|
||||
await self.handler.update_permission('507f1f77bcf86cd799439011', '', 'name', 'desc')
|
||||
with pytest.raises(RequestValidationError):
|
||||
await self.handler.update_permission(PydanticObjectId('507f1f77bcf86cd799439011'), 'key', '', 'desc')
|
||||
await self.handler.update_permission('507f1f77bcf86cd799439011', 'key', '', 'desc')
|
||||
|
||||
async def test_update_permission_not_found(self):
|
||||
# Test updating a non-existent permission raises validation error
|
||||
self.MockPermissionDoc.get = AsyncMock(return_value=None)
|
||||
with pytest.raises(RequestValidationError):
|
||||
await self.handler.update_permission(PydanticObjectId('507f1f77bcf86cd799439011'), 'key', 'name', 'desc')
|
||||
await self.handler.update_permission('507f1f77bcf86cd799439011', 'key', 'name', 'desc')
|
||||
|
||||
async def test_update_permission_is_default(self):
|
||||
# Test updating a default permission raises validation error
|
||||
@ -76,7 +76,7 @@ class TestPermissionHandler:
|
||||
mock_doc.is_default = True
|
||||
self.MockPermissionDoc.get = AsyncMock(return_value=mock_doc)
|
||||
with pytest.raises(RequestValidationError):
|
||||
await self.handler.update_permission(PydanticObjectId('507f1f77bcf86cd799439011'), 'key', 'name', 'desc')
|
||||
await self.handler.update_permission('507f1f77bcf86cd799439011', 'key', 'name', 'desc')
|
||||
|
||||
async def test_update_permission_conflict(self):
|
||||
# Test updating a permission with duplicate key or name raises validation error
|
||||
@ -85,7 +85,7 @@ class TestPermissionHandler:
|
||||
self.MockPermissionDoc.get = AsyncMock(return_value=mock_doc)
|
||||
self.MockPermissionDoc.find_one = AsyncMock(return_value=MagicMock())
|
||||
with pytest.raises(RequestValidationError):
|
||||
await self.handler.update_permission(PydanticObjectId('507f1f77bcf86cd799439011'), 'key', 'name', 'desc')
|
||||
await self.handler.update_permission('507f1f77bcf86cd799439011', 'key', 'name', 'desc')
|
||||
|
||||
async def test_query_permissions_success(self):
|
||||
# Test querying permissions returns docs and total
|
||||
@ -106,7 +106,7 @@ class TestPermissionHandler:
|
||||
mock_doc.is_default = False
|
||||
self.MockPermissionDoc.get = AsyncMock(return_value=mock_doc)
|
||||
mock_doc.delete = AsyncMock()
|
||||
await self.handler.delete_permission(PydanticObjectId('507f1f77bcf86cd799439011'))
|
||||
await self.handler.delete_permission('507f1f77bcf86cd799439011')
|
||||
mock_doc.delete.assert_awaited_once()
|
||||
|
||||
async def test_delete_permission_missing_id(self):
|
||||
@ -118,14 +118,14 @@ class TestPermissionHandler:
|
||||
# Test deleting a permission referenced by a role raises validation error
|
||||
self.MockRoleDoc.find_one = AsyncMock(return_value=MagicMock())
|
||||
with pytest.raises(RequestValidationError):
|
||||
await self.handler.delete_permission(PydanticObjectId('507f1f77bcf86cd799439011'))
|
||||
await self.handler.delete_permission('507f1f77bcf86cd799439011')
|
||||
|
||||
async def test_delete_permission_not_found(self):
|
||||
# Test deleting a non-existent permission raises validation error
|
||||
self.MockRoleDoc.find_one = AsyncMock(return_value=None)
|
||||
self.MockPermissionDoc.get = AsyncMock(return_value=None)
|
||||
with pytest.raises(RequestValidationError):
|
||||
await self.handler.delete_permission(PydanticObjectId('507f1f77bcf86cd799439011'))
|
||||
await self.handler.delete_permission('507f1f77bcf86cd799439011')
|
||||
|
||||
async def test_delete_permission_is_default(self):
|
||||
# Test deleting a default permission raises validation error
|
||||
@ -134,4 +134,4 @@ class TestPermissionHandler:
|
||||
mock_doc.is_default = True
|
||||
self.MockPermissionDoc.get = AsyncMock(return_value=mock_doc)
|
||||
with pytest.raises(RequestValidationError):
|
||||
await self.handler.delete_permission(PydanticObjectId('507f1f77bcf86cd799439011'))
|
||||
await self.handler.delete_permission('507f1f77bcf86cd799439011')
|
||||
@ -3,7 +3,7 @@ from unittest.mock import AsyncMock, patch, MagicMock
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from backend.infra.permission.role_handler import RoleHandler
|
||||
from backend.models.permission.models import RoleDoc, PermissionDoc, UserRoleDoc
|
||||
from beanie import PydanticObjectId
|
||||
from bson import ObjectId
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_db():
|
||||
@ -52,7 +52,7 @@ class TestRoleHandler:
|
||||
self.MockRoleDoc.get = AsyncMock(return_value=mock_doc)
|
||||
self.MockRoleDoc.find_one = AsyncMock(return_value=None)
|
||||
mock_doc.save = AsyncMock()
|
||||
result = await self.handler.update_role(PydanticObjectId('507f1f77bcf86cd799439011'), 'key', 'name', 'desc', 1)
|
||||
result = await self.handler.update_role('507f1f77bcf86cd799439011', 'key', 'name', 'desc', 1)
|
||||
assert result == mock_doc
|
||||
mock_doc.save.assert_awaited_once()
|
||||
|
||||
@ -61,15 +61,15 @@ class TestRoleHandler:
|
||||
with pytest.raises(RequestValidationError):
|
||||
await self.handler.update_role(None, 'key', 'name', 'desc', 1)
|
||||
with pytest.raises(RequestValidationError):
|
||||
await self.handler.update_role(PydanticObjectId('507f1f77bcf86cd799439011'), '', 'name', 'desc', 1)
|
||||
await self.handler.update_role('507f1f77bcf86cd799439011', '', 'name', 'desc', 1)
|
||||
with pytest.raises(RequestValidationError):
|
||||
await self.handler.update_role(PydanticObjectId('507f1f77bcf86cd799439011'), 'key', '', 'desc', 1)
|
||||
await self.handler.update_role('507f1f77bcf86cd799439011', 'key', '', 'desc', 1)
|
||||
|
||||
async def test_update_role_not_found(self):
|
||||
# Test updating a non-existent role raises validation error
|
||||
self.MockRoleDoc.get = AsyncMock(return_value=None)
|
||||
with pytest.raises(RequestValidationError):
|
||||
await self.handler.update_role(PydanticObjectId('507f1f77bcf86cd799439011'), 'key', 'name', 'desc', 1)
|
||||
await self.handler.update_role('507f1f77bcf86cd799439011', 'key', 'name', 'desc', 1)
|
||||
|
||||
async def test_update_role_is_default(self):
|
||||
# Test updating a default role raises validation error
|
||||
@ -77,7 +77,7 @@ class TestRoleHandler:
|
||||
mock_doc.is_default = True
|
||||
self.MockRoleDoc.get = AsyncMock(return_value=mock_doc)
|
||||
with pytest.raises(RequestValidationError):
|
||||
await self.handler.update_role(PydanticObjectId('507f1f77bcf86cd799439011'), 'key', 'name', 'desc', 1)
|
||||
await self.handler.update_role('507f1f77bcf86cd799439011', 'key', 'name', 'desc', 1)
|
||||
|
||||
async def test_update_role_conflict(self):
|
||||
# Test updating a role with duplicate key or name raises validation error
|
||||
@ -86,7 +86,7 @@ class TestRoleHandler:
|
||||
self.MockRoleDoc.get = AsyncMock(return_value=mock_doc)
|
||||
self.MockRoleDoc.find_one = AsyncMock(return_value=MagicMock())
|
||||
with pytest.raises(RequestValidationError):
|
||||
await self.handler.update_role(PydanticObjectId('507f1f77bcf86cd799439011'), 'key', 'name', 'desc', 1)
|
||||
await self.handler.update_role('507f1f77bcf86cd799439011', 'key', 'name', 'desc', 1)
|
||||
|
||||
async def test_query_roles_success(self):
|
||||
# Test querying roles returns docs and total
|
||||
@ -106,7 +106,7 @@ class TestRoleHandler:
|
||||
self.MockRoleDoc.get = AsyncMock(return_value=mock_doc)
|
||||
self.MockPermissionDoc.get = AsyncMock(return_value=MagicMock())
|
||||
mock_doc.save = AsyncMock()
|
||||
result = await self.handler.assign_permissions_to_role(PydanticObjectId('507f1f77bcf86cd799439011'), ['507f1f77bcf86cd799439011', '507f1f77bcf86cd799439011'])
|
||||
result = await self.handler.assign_permissions_to_role('507f1f77bcf86cd799439011', ['507f1f77bcf86cd799439011', '507f1f77bcf86cd799439011'])
|
||||
assert result == mock_doc
|
||||
mock_doc.save.assert_awaited_once()
|
||||
|
||||
@ -115,13 +115,13 @@ class TestRoleHandler:
|
||||
with pytest.raises(RequestValidationError):
|
||||
await self.handler.assign_permissions_to_role(None, ['pid1'])
|
||||
with pytest.raises(RequestValidationError):
|
||||
await self.handler.assign_permissions_to_role(PydanticObjectId('507f1f77bcf86cd799439011'), None)
|
||||
await self.handler.assign_permissions_to_role('507f1f77bcf86cd799439011', None)
|
||||
|
||||
async def test_assign_permissions_to_role_role_not_found(self):
|
||||
# Test assigning permissions to a non-existent role raises validation error
|
||||
self.MockRoleDoc.get = AsyncMock(return_value=None)
|
||||
with pytest.raises(RequestValidationError):
|
||||
await self.handler.assign_permissions_to_role(PydanticObjectId('507f1f77bcf86cd799439011'), ['507f1f77bcf86cd799439011'])
|
||||
await self.handler.assign_permissions_to_role('507f1f77bcf86cd799439011', ['507f1f77bcf86cd799439011'])
|
||||
|
||||
async def test_assign_permissions_to_role_permission_not_found(self):
|
||||
# Test assigning a non-existent permission raises validation error
|
||||
@ -129,7 +129,7 @@ class TestRoleHandler:
|
||||
self.MockRoleDoc.get = AsyncMock(return_value=mock_doc)
|
||||
self.MockPermissionDoc.get = AsyncMock(side_effect=[None, MagicMock()])
|
||||
with pytest.raises(RequestValidationError):
|
||||
await self.handler.assign_permissions_to_role(PydanticObjectId('507f1f77bcf86cd799439011'), ['507f1f77bcf86cd799439011', '507f1f77bcf86cd799439011'])
|
||||
await self.handler.assign_permissions_to_role('507f1f77bcf86cd799439011', ['507f1f77bcf86cd799439011', '507f1f77bcf86cd799439011'])
|
||||
|
||||
async def test_delete_role_success(self):
|
||||
# Test deleting a role successfully
|
||||
@ -138,7 +138,7 @@ class TestRoleHandler:
|
||||
mock_doc.is_default = False
|
||||
self.MockRoleDoc.get = AsyncMock(return_value=mock_doc)
|
||||
mock_doc.delete = AsyncMock()
|
||||
await self.handler.delete_role(PydanticObjectId('507f1f77bcf86cd799439011'))
|
||||
await self.handler.delete_role('507f1f77bcf86cd799439011')
|
||||
mock_doc.delete.assert_awaited_once()
|
||||
|
||||
async def test_delete_role_missing_id(self):
|
||||
@ -150,14 +150,14 @@ class TestRoleHandler:
|
||||
# Test deleting a role referenced by a user raises validation error
|
||||
self.MockUserRoleDoc.find_one = AsyncMock(return_value=MagicMock())
|
||||
with pytest.raises(RequestValidationError):
|
||||
await self.handler.delete_role(PydanticObjectId('507f1f77bcf86cd799439011'))
|
||||
await self.handler.delete_role('507f1f77bcf86cd799439011')
|
||||
|
||||
async def test_delete_role_not_found(self):
|
||||
# Test deleting a non-existent role raises validation error
|
||||
self.MockUserRoleDoc.find_one = AsyncMock(return_value=None)
|
||||
self.MockRoleDoc.get = AsyncMock(return_value=None)
|
||||
with pytest.raises(RequestValidationError):
|
||||
await self.handler.delete_role(PydanticObjectId('507f1f77bcf86cd799439011'))
|
||||
await self.handler.delete_role('507f1f77bcf86cd799439011')
|
||||
|
||||
async def test_delete_role_is_default(self):
|
||||
# Test deleting a default role raises validation error
|
||||
@ -166,4 +166,4 @@ class TestRoleHandler:
|
||||
mock_doc.is_default = True
|
||||
self.MockRoleDoc.get = AsyncMock(return_value=mock_doc)
|
||||
with pytest.raises(RequestValidationError):
|
||||
await self.handler.delete_role(PydanticObjectId('507f1f77bcf86cd799439011'))
|
||||
await self.handler.delete_role('507f1f77bcf86cd799439011')
|
||||
@ -2,6 +2,7 @@ from fastapi import Request, status, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from webapi.middleware.freeleaps_auth_middleware import request_context_var
|
||||
from common.log.module_logger import ModuleLogger
|
||||
from backend.models.base_doc import BaseDoc
|
||||
|
||||
|
||||
class DatabaseMiddleware:
|
||||
@ -39,41 +40,50 @@ class DatabaseMiddleware:
|
||||
return await response(scope, receive, send)
|
||||
|
||||
if not product_id:
|
||||
# Compatibility / public routes: use main database with tenant models initialized
|
||||
# Compatibility / public routes: use main database
|
||||
await self.module_logger.log_info(f"No product_id - using main database for path: {request.url.path}")
|
||||
|
||||
# Get main database with Beanie initialized for tenant models
|
||||
|
||||
# Get main database (no Beanie initialization needed with BaseDoc)
|
||||
main_db_initialized = await tenant_cache.get_main_db_initialized()
|
||||
|
||||
|
||||
request.state.db = main_db_initialized
|
||||
request.state.product_id = None
|
||||
await self.module_logger.log_info(f"Successfully initialized main database with tenant models")
|
||||
|
||||
# Set the database context for BaseDoc models
|
||||
BaseDoc.set_tenant_database(main_db_initialized)
|
||||
|
||||
await self.module_logger.log_info(f"Successfully initialized main database")
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
try:
|
||||
# Get tenant-specific database with Beanie already initialized (cached)
|
||||
# Get tenant-specific database (no Beanie initialization needed with BaseDoc)
|
||||
await self.module_logger.log_info(f"Attempting to get tenant database for product_id: {product_id}")
|
||||
tenant_db = await tenant_cache.get_initialized_db(product_id)
|
||||
|
||||
|
||||
request.state.db = tenant_db
|
||||
request.state.product_id = product_id
|
||||
await self.module_logger.log_info(f"Successfully retrieved cached tenant database with Beanie for product_id: {product_id}")
|
||||
|
||||
except HTTPException as e:
|
||||
# Handle tenant not found or inactive (HTTPException from TenantDBCache)
|
||||
await self.module_logger.log_error(f"Tenant error for {product_id}: [{e.status_code}] {e.detail}")
|
||||
# Set the database context for BaseDoc models
|
||||
BaseDoc.set_tenant_database(tenant_db)
|
||||
|
||||
await self.module_logger.log_info(f"Successfully retrieved tenant database for product_id: {product_id}")
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
except ValueError as e:
|
||||
# Handle tenant not found or inactive (ValueError from TenantDBCache)
|
||||
await self.module_logger.log_error(f"Tenant error for {product_id}: {str(e)}")
|
||||
response = JSONResponse(
|
||||
status_code=e.status_code,
|
||||
content={"detail": e.detail}
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
content={"detail": str(e)}
|
||||
)
|
||||
return await response(scope, receive, send)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
await self.module_logger.log_error(f"Database error for tenant {product_id}: {str(e)}")
|
||||
response = JSONResponse(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
content={"detail": "Database connection error"}
|
||||
)
|
||||
return await response(scope, receive, send)
|
||||
|
||||
|
||||
return await self.app(scope, receive, send)
|
||||
@ -1,5 +1,4 @@
|
||||
from webapi.config.site_settings import site_settings
|
||||
from beanie import init_beanie
|
||||
from fastapi import HTTPException
|
||||
from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorDatabase
|
||||
from backend.models.user.models import (
|
||||
@ -70,7 +69,7 @@ class TenantDBCache:
|
||||
self.module_logger = ModuleLogger(sender_id="TenantDBCache")
|
||||
|
||||
async def get_initialized_db(self, product_id: str) -> AsyncIOMotorDatabase:
|
||||
"""Get tenant database with Beanie already initialized"""
|
||||
"""Get tenant database (no Beanie initialization needed with BaseDoc)"""
|
||||
|
||||
# fast-path: check if client is cached
|
||||
cached_client = self._cache.get(product_id)
|
||||
@ -81,9 +80,7 @@ class TenantDBCache:
|
||||
# Get fresh database instance from cached client
|
||||
db = cached_client.get_default_database()
|
||||
if db is not None:
|
||||
# Initialize Beanie for this fresh database instance
|
||||
await init_beanie(database=db, document_models=tenant_document_models)
|
||||
await self.module_logger.log_info(f"Beanie initialization completed for {product_id} using cached client")
|
||||
await self.module_logger.log_info(f"Using cached client for {product_id}")
|
||||
return db
|
||||
else:
|
||||
await self.module_logger.log_error(f"No default database found for cached client {product_id}")
|
||||
@ -101,9 +98,7 @@ class TenantDBCache:
|
||||
# Get fresh database instance from cached client
|
||||
db = cached_client.get_default_database()
|
||||
if db is not None:
|
||||
# Initialize Beanie for this fresh database instance
|
||||
await init_beanie(database=db, document_models=tenant_document_models)
|
||||
await self.module_logger.log_info(f"Beanie initialization completed for {product_id} using cached client (double-check)")
|
||||
await self.module_logger.log_info(f"Using cached client for {product_id} (double-check)")
|
||||
return db
|
||||
else:
|
||||
await self.module_logger.log_error(f"No default database found for cached client {product_id}")
|
||||
@ -152,20 +147,14 @@ class TenantDBCache:
|
||||
headers={"X-Error-Message": f"No default database found for tenant {product_id}"}
|
||||
)
|
||||
|
||||
# Initialize Beanie for this tenant database
|
||||
await init_beanie(database=db, document_models=tenant_document_models)
|
||||
await self.module_logger.log_info(f"Beanie initialization completed for new tenant database {product_id}")
|
||||
|
||||
# Cache only the client
|
||||
await self._lru_put(product_id, client)
|
||||
await self.module_logger.log_info(f"Tenant client {product_id} cached successfully")
|
||||
return db
|
||||
|
||||
async def get_main_db_initialized(self) -> AsyncIOMotorDatabase:
|
||||
"""Get main database with Beanie initialized for tenant models"""
|
||||
# Re-initialize Beanie for main database with business models
|
||||
await init_beanie(database=self.main_db, document_models=document_models)
|
||||
await self.module_logger.log_info("Beanie initialization completed for main database")
|
||||
"""Get main database (no Beanie initialization needed with BaseDoc)"""
|
||||
await self.module_logger.log_info("Main database ready (using BaseDoc)")
|
||||
return self.main_db
|
||||
|
||||
async def _lru_put(self, key: str, client: AsyncIOMotorClient):
|
||||
@ -226,14 +215,11 @@ async def initiate_database(app):
|
||||
MAIN_CLIENT = AsyncIOMotorClient(app_settings.MONGODB_URI)
|
||||
main_db = MAIN_CLIENT[app_settings.MONGODB_NAME]
|
||||
|
||||
# 2) Initialize Beanie for main DB with business document models
|
||||
await init_beanie(database=main_db, document_models=document_models)
|
||||
|
||||
# 3) Create tenant cache that uses main_db lookups to resolve product_id -> tenant db
|
||||
# 2) Create tenant cache that uses main_db lookups to resolve product_id -> tenant db
|
||||
max_cache_size = getattr(app_settings, 'TENANT_CACHE_MAX', 64)
|
||||
TENANT_CACHE = TenantDBCache(main_db, max_size=max_cache_size)
|
||||
|
||||
# 4) Store on app state for middleware to access
|
||||
# 3) Store on app state for middleware to access
|
||||
app.state.main_db = main_db
|
||||
app.state.tenant_cache = TENANT_CACHE
|
||||
|
||||
|
||||
@ -42,4 +42,4 @@ async def create_or_update_permission(
|
||||
) -> PermissionResponse:
|
||||
doc = await permission_service.create_or_update_permission(req.permission_key, req.permission_name, req.custom_permission_id,
|
||||
req.description)
|
||||
return PermissionResponse(**doc.dict())
|
||||
return PermissionResponse(**doc.model_dump())
|
||||
|
||||
@ -40,4 +40,4 @@ async def create_permission(
|
||||
) -> PermissionResponse:
|
||||
doc = await permission_service.create_permission(req.permission_key, req.permission_name, req.description)
|
||||
|
||||
return PermissionResponse(**doc.dict())
|
||||
return PermissionResponse(**doc.model_dump())
|
||||
|
||||
@ -42,4 +42,4 @@ async def update_permission(
|
||||
) -> PermissionResponse:
|
||||
doc = await permission_service.update_permission(req.permission_id, req.permission_key, req.permission_name,
|
||||
req.description)
|
||||
return PermissionResponse(**doc.dict())
|
||||
return PermissionResponse(**doc.model_dump())
|
||||
|
||||
@ -37,4 +37,4 @@ async def assign_permissions_to_role(
|
||||
#_: bool = Depends(token_manager.has_all_permissions([DefaultPermissionEnum.CHANGE_ROLES.value.permission_key]))
|
||||
) -> RoleResponse:
|
||||
doc = await role_service.assign_permissions_to_role(req.role_id, req.permission_ids)
|
||||
return RoleResponse(**doc.dict())
|
||||
return RoleResponse(**doc.model_dump())
|
||||
@ -19,7 +19,7 @@ class CreateOrUpdateRoleRequest(BaseModel):
|
||||
role_level: int
|
||||
custom_role_id: Optional[str] = None
|
||||
role_description: Optional[str] = None
|
||||
|
||||
|
||||
|
||||
|
||||
class RoleResponse(BaseModel):
|
||||
@ -45,5 +45,5 @@ async def create_or_update_permission(
|
||||
#_: bool = Depends(token_manager.has_all_permissions([DefaultPermissionEnum.CHANGE_PERMISSIONS.value.permission_key]))
|
||||
) -> RoleResponse:
|
||||
doc = await role_service.create_or_update_role(req.role_key, req.role_name, req.role_level, req.custom_role_id, req.role_description)
|
||||
|
||||
return RoleResponse(**doc.dict())
|
||||
|
||||
return RoleResponse(**doc.model_dump())
|
||||
|
||||
@ -42,4 +42,4 @@ async def create_role(
|
||||
# _: bool = Depends(token_manager.has_all_permissions([DefaultPermissionEnum.CHANGE_ROLES.value.permission_key]))
|
||||
) -> RoleResponse:
|
||||
doc = await role_service.create_role(req.role_key, req.role_name, req.role_description, req.role_level)
|
||||
return RoleResponse(**doc.dict())
|
||||
return RoleResponse(**doc.model_dump())
|
||||
|
||||
@ -43,4 +43,4 @@ async def update_role(
|
||||
#_: bool = Depends(token_manager.has_all_permissions([DefaultPermissionEnum.CHANGE_ROLES.value.permission_key]))
|
||||
) -> RoleResponse:
|
||||
doc = await role_service.update_role(req.role_id, req.role_key, req.role_name, req.role_description, req.role_level)
|
||||
return RoleResponse(**doc.dict())
|
||||
return RoleResponse(**doc.model_dump())
|
||||
|
||||
@ -33,4 +33,4 @@ async def assign_roles_to_user(
|
||||
#_: bool = Depends(token_manager.has_all_permissions([DefaultPermissionEnum.INVITE_COLLABORATOR.value.permission_key])),
|
||||
) -> UserRoleResponse:
|
||||
doc = await user_management_service.assign_roles_to_user(req.user_id, req.role_ids)
|
||||
return UserRoleResponse(**doc.dict())
|
||||
return UserRoleResponse(**doc.model_dump())
|
||||
|
||||
Loading…
Reference in New Issue
Block a user