import inspect
import json  # noqa: F401  (currently unused in this module)
from datetime import datetime, timezone
from decimal import ROUND_HALF_UP, Decimal
from typing import Any, Dict, Optional, Tuple

import stripe
from stripe.error import SignatureVerificationError  # noqa: F401  (currently unused)

from backend.infra.payment.models import StripeTransactionDoc
from backend.infra.payment.constants import TransactionStatus
from common.config.app_settings import app_settings
from common.log.module_logger import ModuleLogger

stripe.api_key = app_settings.STRIPE_API_KEY


class StripeManager:
    """Async facade over the Stripe SDK and ``StripeTransactionDoc`` storage.

    Covers Connect account creation/onboarding links, one transaction record
    per (project, milestone), payment links / Checkout sessions for those
    transactions, and the ``checkout.session.completed`` webhook.

    NOTE(review): the ``stripe`` SDK calls below are synchronous and run on the
    event-loop thread; consider the SDK's async variants or an executor if
    request latency matters.
    """

    def __init__(self) -> None:
        self.site_url_root = app_settings.SITE_URL_ROOT.rstrip("/")
        self.module_logger = ModuleLogger(sender_id="StripeManager")

    # ------------------------------------------------------------------ #
    # Internal helpers
    # ------------------------------------------------------------------ #

    async def _log_info(self, message: str, properties: Dict[str, Any]) -> None:
        """Call ``module_logger.log_info`` and await it if it returns an awaitable.

        ``log_error`` is awaited elsewhere in this class, so ``log_info`` is
        presumably async as well; calling it without ``await`` would create a
        coroutine that never runs. Awaiting only when needed keeps this
        correct whether ``log_info`` is sync or async.
        """
        result = self.module_logger.log_info(message, properties)
        if inspect.isawaitable(result):
            await result

    @staticmethod
    def _to_minor_units(amount: Decimal) -> int:
        """Convert a major-unit amount (e.g. 19.99) to integer minor units (1999).

        Going through ``str`` means a float such as ``19.99`` becomes exactly
        1999 rather than ``int(1998.999...) == 1998``, and sub-cent amounts
        are rounded half-up instead of truncated.

        NOTE(review): assumes a two-decimal currency; zero-decimal currencies
        (e.g. JPY) would need a different multiplier.
        """
        cents = Decimal(str(amount)) * 100
        return int(cents.quantize(Decimal("1"), rounding=ROUND_HALF_UP))

    async def _ensure_product_and_price(
        self, transaction: StripeTransactionDoc
    ) -> None:
        """Create and persist the Stripe Product and Price for *transaction* if missing.

        The product is named ``"<project_id>-<milestone_index>"`` and the
        price uses the transaction's ``unit_amount`` and ``currency``. Each new
        id is saved immediately, so a later failure does not recreate it.
        """
        if not transaction.stripe_product_id:
            product = stripe.Product.create(
                name="{}-{}".format(
                    transaction.project_id, transaction.milestone_index
                )
            )
            transaction.stripe_product_id = product.id
            await transaction.save()
        if not transaction.stripe_price_id:
            price = stripe.Price.create(
                unit_amount=transaction.unit_amount,
                currency=transaction.currency,
                product=transaction.stripe_product_id,
            )
            transaction.stripe_price_id = price.id
            await transaction.save()

    # ------------------------------------------------------------------ #
    # Connected accounts
    # ------------------------------------------------------------------ #

    async def create_stripe_account(self) -> Optional[str]:
        """Create a ``standard`` Stripe Connect account and return its id."""
        account = stripe.Account.create(type="standard")
        return account.id

    async def create_account_link(
        self, account_id: str, link_type: str = "account_onboarding"
    ) -> Optional[str]:
        """Return a URL that sends the connected account holder to Stripe.

        If ``link_type == "account_update"`` and the account has a ToS
        acceptance date, a dashboard login link is returned. Otherwise an
        ``account_onboarding`` AccountLink is created; any other
        ``link_type`` value is treated as onboarding.

        NOTE(review): Stripe issues login links only for Express-dashboard
        accounts, while ``create_stripe_account`` creates ``standard``
        accounts. Confirm that the ``account_update`` branch works for them.
        """
        work_url = "{}/work".format(self.site_url_root)
        front_door_url = "{}/front-door".format(self.site_url_root)

        account = stripe.Account.retrieve(account_id)
        await self._log_info(
            "create_account_link urls",
            {
                "redirect_url": work_url,
                "refresh_url": front_door_url,
                "return_url": work_url,
            },
        )

        # For account_update, show the dashboard once the ToS has been accepted.
        if link_type == "account_update" and account.tos_acceptance.date:
            login_link = stripe.Account.create_login_link(
                account_id, redirect_url=work_url
            )
            return login_link.url

        # Otherwise show onboarding.
        account_link = stripe.AccountLink.create(
            account=account_id,
            refresh_url=front_door_url,
            return_url=work_url,
            type="account_onboarding",
        )
        return account_link.url

    async def can_account_receive_payments(self, account_id: str) -> bool:
        """Return True when the account's ``transfers`` capability is ``"active"``."""
        account = stripe.Account.retrieve(account_id)
        capabilities = account.capabilities
        # .get(): the "transfers" key is absent when that capability was never
        # requested; indexing would raise KeyError instead of answering False.
        return bool(capabilities) and capabilities.get("transfers") == "active"

    # ------------------------------------------------------------------ #
    # Transaction lookups
    # ------------------------------------------------------------------ #

    async def __fetch_transaction_by_id(
        self, transaction_id: str
    ) -> Optional[StripeTransactionDoc]:
        """Load a transaction document by id, or None if it does not exist."""
        transaction = await StripeTransactionDoc.get(transaction_id)
        return transaction if transaction else None

    async def fetch_transaction_by_id(
        self, transaction_id: str
    ) -> Optional[Dict[str, Any]]:
        """Return the transaction with *transaction_id* as a dict, or None."""
        transaction = await self.__fetch_transaction_by_id(transaction_id)
        return transaction.model_dump() if transaction else None

    async def __fetch_transaction_by_session_id(
        self, session_id: str
    ) -> Optional[StripeTransactionDoc]:
        """Load the transaction linked to a Checkout session id.

        If several transactions match, an error is logged and the first one
        is returned. Returns None when nothing matches.
        """
        transactions = await StripeTransactionDoc.find(
            StripeTransactionDoc.stripe_checkout_session_id == session_id
        ).to_list()
        if not transactions:
            return None
        if len(transactions) > 1:
            await self.module_logger.log_error(
                error="More than one transaction found for session_id: {}".format(
                    session_id
                ),
                properties={"session_id": session_id},
            )
        return transactions[0]

    async def fetch_transaction_by_session_id(
        self, session_id: str
    ) -> Optional[Dict[str, Any]]:
        """Return the transaction linked to *session_id* as a dict, or None."""
        transaction = await self.__fetch_transaction_by_session_id(session_id)
        return transaction.model_dump() if transaction else None

    async def fetch_stripe_transaction_for_milestone(
        self, project_id: str, milestone_index: int
    ) -> Optional[Dict[str, Any]]:
        """Return the transaction for (project_id, milestone_index) as a dict.

        If several transactions match, an error is logged and the first one
        is returned. Returns None when nothing matches.
        """
        transactions = await StripeTransactionDoc.find(
            StripeTransactionDoc.project_id == project_id,
            StripeTransactionDoc.milestone_index == milestone_index,
        ).to_list()
        if not transactions:
            return None
        if len(transactions) > 1:
            await self.module_logger.log_error(
                error="More than one transaction found for project_id: {} and milestone_index: {}".format(
                    project_id, milestone_index
                ),
                properties={
                    "project_id": project_id,
                    "milestone_index": milestone_index,
                },
            )
        return transactions[0].model_dump()

    # ------------------------------------------------------------------ #
    # Transaction creation / payment surfaces
    # ------------------------------------------------------------------ #

    async def create_stripe_transaction_for_milestone(
        self,
        project_id: str,
        milestone_index: int,
        currency: str,
        expected_payment: Decimal,
        from_user: str,
        to_user: str,
        to_stripe_account_id: str,
    ) -> Optional[str]:
        """Create a PENDING transaction for a milestone and return its id.

        If a transaction already exists for (project_id, milestone_index), an
        error is logged and the existing transaction's id is returned instead.

        ``expected_payment`` is in major units and is stored as
        ``unit_amount`` in minor units (x100, rounded half-up).
        """
        transactions = await StripeTransactionDoc.find(
            StripeTransactionDoc.project_id == project_id,
            StripeTransactionDoc.milestone_index == milestone_index,
        ).to_list()

        if transactions:
            await self.module_logger.log_error(
                error="Transaction already exists for project_id: {} and milestone_index: {}".format(
                    project_id, milestone_index
                ),
                properties={
                    "project_id": project_id,
                    "milestone_index": milestone_index,
                },
            )
            return transactions[0].id

        now = datetime.now(timezone.utc)
        transaction_doc = StripeTransactionDoc(
            project_id=project_id,
            milestone_index=milestone_index,
            currency=currency,
            unit_amount=self._to_minor_units(expected_payment),
            from_user=from_user,
            to_user=to_user,
            to_stripe_account_id=to_stripe_account_id,
            created_time=now,
            updated_time=now,
            status=TransactionStatus.PENDING,
        )
        transaction = await transaction_doc.create()
        return transaction.id

    async def create_payment_link(self, transaction_id: str) -> Optional[str]:
        """Return (creating on first call) a Stripe Payment Link URL for a transaction.

        The link charges the transaction's price on behalf of, and transfers
        to, ``to_stripe_account_id``, minus ``application_fee_amount``.
        Returns None if the transaction does not exist.
        """
        transaction = await StripeTransactionDoc.get(transaction_id)
        if not transaction:
            return None
        if transaction.stripe_payment_link:
            return transaction.stripe_payment_link

        await self._ensure_product_and_price(transaction)

        payment_link = stripe.PaymentLink.create(
            line_items=[
                {
                    "price": transaction.stripe_price_id,
                    "quantity": 1,
                }
            ],
            application_fee_amount=transaction.application_fee_amount,
            on_behalf_of=transaction.to_stripe_account_id,
            transfer_data={
                "destination": transaction.to_stripe_account_id,
            },
        )
        if not payment_link:
            return None

        transaction.stripe_payment_link = payment_link.url
        transaction.updated_time = datetime.now(timezone.utc)
        await transaction.save()
        return payment_link.url

    async def create_checkout_session(
        self, transaction_id: str
    ) -> Tuple[Optional[str], Optional[str]]:
        """Return ``(session_id, session_url)`` for a transaction's Checkout session.

        A previously stored session is reused while Stripe reports it as not
        yet expired; otherwise a new card-only payment session is created
        (destination charge to ``to_stripe_account_id`` minus
        ``application_fee_amount``) and stored on the transaction.
        Returns ``(None, None)`` if the transaction does not exist.
        """
        transaction = await StripeTransactionDoc.get(transaction_id)
        if not transaction:
            return None, None

        if transaction.stripe_checkout_session_id:
            existing = stripe.checkout.Session.retrieve(
                transaction.stripe_checkout_session_id
            )
            expires_at_utc = datetime.fromtimestamp(
                existing.expires_at, tz=timezone.utc
            )
            if datetime.now(timezone.utc) < expires_at_utc:
                return (
                    transaction.stripe_checkout_session_id,
                    transaction.stripe_checkout_session_url,
                )

        # TODO(review): a connected-account capability check (card_payments /
        # transfers == "active") was previously disabled here; decide whether
        # to reinstate it via can_account_receive_payments().

        await self._ensure_product_and_price(transaction)

        session = stripe.checkout.Session.create(
            payment_method_types=["card"],
            line_items=[
                {
                    "price": transaction.stripe_price_id,
                    "quantity": 1,
                }
            ],
            payment_intent_data={
                "on_behalf_of": transaction.to_stripe_account_id,
                "application_fee_amount": transaction.application_fee_amount,
                "transfer_data": {
                    "destination": transaction.to_stripe_account_id,
                },
            },
            mode="payment",
            # success_url must be set; locally e.g. http://localhost/
            success_url="{}/projects".format(self.site_url_root),
            cancel_url="{}/projects".format(self.site_url_root),
        )
        if not session:
            return None, None

        transaction.stripe_checkout_session_id = session.id
        transaction.stripe_checkout_session_url = session.url
        transaction.updated_time = datetime.now(timezone.utc)
        await transaction.save()
        return session.id, session.url

    # ------------------------------------------------------------------ #
    # Stored-field accessors
    # ------------------------------------------------------------------ #

    async def fetch_payment_link(self, transaction_id: str) -> Optional[str]:
        """Return the stored payment link URL for a transaction, or None."""
        transaction = await StripeTransactionDoc.get(transaction_id)
        if transaction and transaction.stripe_payment_link:
            return transaction.stripe_payment_link
        return None

    async def fetch_checkout_session_id(self, transaction_id: str) -> Optional[str]:
        """Return the stored Checkout session id for a transaction, or None."""
        transaction = await StripeTransactionDoc.get(transaction_id)
        if transaction and transaction.stripe_checkout_session_id:
            return transaction.stripe_checkout_session_id
        return None

    async def fetch_checkout_session_url(self, transaction_id: str) -> Optional[str]:
        """Return the stored Checkout session URL for a transaction, or None."""
        transaction = await StripeTransactionDoc.get(transaction_id)
        if transaction and transaction.stripe_checkout_session_url:
            return transaction.stripe_checkout_session_url
        return None

    # ------------------------------------------------------------------ #
    # Webhooks
    # ------------------------------------------------------------------ #

    async def invoke_checkout_session_webhook(
        self, event: dict
    ) -> Tuple[bool, Optional[str], Optional[int]]:
        """Handle an already-parsed Stripe webhook event.

        Only ``checkout.session.completed`` is acted on: the transaction
        linked to the session is marked COMPLETED.

        Returns:
            ``(True, project_id, milestone_index)`` when a transaction was
            completed; otherwise ``(False, None, None)``. That includes the
            case where no transaction matches the session, so callers can
            always unpack three values.
        """
        if event["type"] != "checkout.session.completed":
            return False, None, None

        session = event["data"]["object"]
        transaction = await self.__fetch_transaction_by_session_id(session["id"])
        if not transaction:
            await self.module_logger.log_error(
                error="Transaction not found for session_id: {}".format(session["id"]),
                properties={"session_id": session["id"]},
            )
            # Previously returned a bare False, breaking 3-tuple unpacking.
            return False, None, None

        transaction.status = TransactionStatus.COMPLETED
        transaction.updated_time = datetime.now(timezone.utc)
        await transaction.save()
        return True, transaction.project_id, transaction.milestone_index