350 lines
13 KiB
Python
350 lines
13 KiB
Python
import json
from datetime import datetime, timezone
from decimal import ROUND_HALF_UP, Decimal
from typing import Any, Dict, Optional, Tuple

import stripe
from stripe.error import SignatureVerificationError

from backend.infra.payment.constants import TransactionStatus
from backend.infra.payment.models import StripeTransactionDoc
from common.config.app_settings import app_settings
from common.log.module_logger import ModuleLogger
|
|
|
|
|
|
# Configure the Stripe SDK's global API key at import time; every stripe.*
# call in this module authenticates with it.
stripe.api_key = app_settings.STRIPE_API_KEY
|
|
|
|
|
|
class StripeManager:
    """Async facade over the Stripe SDK and ``StripeTransactionDoc`` storage.

    Covers connected-account onboarding, one transaction record per project
    milestone, Payment Links / Checkout Sessions for those transactions, and
    the ``checkout.session.completed`` webhook.

    NOTE(review): the ``stripe.*`` calls below are the SDK's synchronous
    methods (none are awaited), so each one blocks the event loop for the
    duration of its HTTP request. Consider the SDK's async variants if the
    installed version provides them.
    """

    def __init__(self) -> None:
        # Trailing "/" is stripped so the "{}/path" URLs built below never
        # contain "//".
        self.site_url_root = app_settings.SITE_URL_ROOT.rstrip("/")
        self.module_logger = ModuleLogger(sender_id="StripeManager")

    async def create_stripe_account(self) -> Optional[str]:
        """Create a Stripe connected account of type ``standard``; return its id."""
        account = stripe.Account.create(type="standard")
        return account.id

    async def create_account_link(self, account_id: str) -> Optional[str]:
        """Return a Stripe-hosted onboarding URL for ``account_id``.

        ``refresh_url`` points at ``/front-door`` and ``return_url`` at
        ``/user-profile`` under ``site_url_root``.
        """
        account_link = stripe.AccountLink.create(
            account=account_id,
            refresh_url="{}/front-door".format(self.site_url_root),
            return_url="{}/user-profile".format(self.site_url_root),
            type="account_onboarding",
        )
        return account_link.url

    async def can_account_receive_payments(self, account_id: str) -> bool:
        """Return True when the account's ``transfers`` capability is ``active``.

        A missing ``capabilities`` object or a missing ``transfers`` entry is
        reported as False rather than raising ``KeyError``.
        """
        account = stripe.Account.retrieve(account_id)
        capabilities = account.capabilities
        if not capabilities:
            return False
        return capabilities.get("transfers") == "active"

    async def __fetch_transaction_by_id(
        self, transaction_id: str
    ) -> Optional[StripeTransactionDoc]:
        """Load the transaction document with ``transaction_id``, or None."""
        transaction = await StripeTransactionDoc.get(transaction_id)
        return transaction if transaction else None

    async def fetch_transaction_by_id(
        self, transaction_id: str
    ) -> Optional[Dict[str, Any]]:
        """Return the transaction with ``transaction_id`` as a dict, or None."""
        transaction = await self.__fetch_transaction_by_id(transaction_id)
        return transaction.model_dump() if transaction else None

    async def __fetch_transaction_by_session_id(
        self, session_id: str
    ) -> Optional[StripeTransactionDoc]:
        """Load the transaction whose Checkout Session id is ``session_id``.

        Returns None when nothing matches. When several documents match, the
        duplication is logged as an error and the first one is returned.
        """
        transactions = await StripeTransactionDoc.find(
            StripeTransactionDoc.stripe_checkout_session_id == session_id
        ).to_list()
        if not transactions:
            return None
        if len(transactions) > 1:
            await self.module_logger.log_error(
                error="More than one transaction found for session_id: {}".format(
                    session_id
                ),
                properties={"session_id": session_id},
            )
        return transactions[0]

    async def fetch_transaction_by_session_id(
        self, session_id: str
    ) -> Optional[Dict[str, Any]]:
        """Return the transaction for Checkout Session ``session_id`` as a dict.

        Returns None when nothing matches; duplicates are logged and the
        first match is returned.
        """
        transaction = await self.__fetch_transaction_by_session_id(session_id)
        return transaction.model_dump() if transaction else None

    async def fetch_stripe_transaction_for_milestone(
        self, project_id: str, milestone_index: int
    ) -> Optional[Dict[str, Any]]:
        """Return the transaction for a project milestone as a dict, or None.

        When several documents match, the duplication is logged as an error
        and the first one is returned.
        """
        transactions = await StripeTransactionDoc.find(
            StripeTransactionDoc.project_id == project_id,
            StripeTransactionDoc.milestone_index == milestone_index,
        ).to_list()
        if not transactions:
            return None
        if len(transactions) > 1:
            await self.module_logger.log_error(
                error="More than one transaction found for project_id: {} and milestone_index: {}".format(
                    project_id, milestone_index
                ),
                properties={
                    "project_id": project_id,
                    "milestone_index": milestone_index,
                },
            )
        return transactions[0].model_dump()

    async def create_stripe_transaction_for_milestone(
        self,
        project_id: str,
        milestone_index: int,
        currency: str,
        expected_payment: Decimal,
        from_user: str,
        to_user: str,
        to_stripe_account_id: str,
    ) -> Optional[str]:
        """Create a PENDING transaction for a milestone and return its id.

        If a transaction already exists for (project_id, milestone_index),
        an error is logged and the id of the first existing one is returned;
        nothing new is created.

        ``expected_payment`` is multiplied by 100 and rounded half-up to a
        whole number before being stored as ``unit_amount`` (Stripe amounts
        are in the currency's smallest unit). NOTE(review): the fixed x100
        assumes a two-decimal currency; zero-decimal currencies such as JPY
        would be overcharged — confirm which currencies are used.
        """
        existing = await StripeTransactionDoc.find(
            StripeTransactionDoc.project_id == project_id,
            StripeTransactionDoc.milestone_index == milestone_index,
        ).to_list()
        if existing:
            await self.module_logger.log_error(
                error="Transaction already exists for project_id: {} and milestone_index: {}".format(
                    project_id, milestone_index
                ),
                properties={
                    "project_id": project_id,
                    "milestone_index": milestone_index,
                },
            )
            return existing[0].id

        # Round instead of truncating. Converting through str() also keeps a
        # float argument (despite the annotation) exact: 19.99 -> 1999, not
        # 1998 as int(19.99 * 100) would give.
        minor_units = (Decimal(str(expected_payment)) * 100).quantize(
            Decimal("1"), rounding=ROUND_HALF_UP
        )
        # One timestamp so created_time and updated_time start out equal.
        now = datetime.now(timezone.utc)
        transaction_doc = StripeTransactionDoc(
            project_id=project_id,
            milestone_index=milestone_index,
            currency=currency,
            unit_amount=int(minor_units),
            from_user=from_user,
            to_user=to_user,
            to_stripe_account_id=to_stripe_account_id,
            created_time=now,
            updated_time=now,
            status=TransactionStatus.PENDING,
        )
        transaction = await transaction_doc.create()
        return transaction.id

    async def _ensure_product_and_price(
        self, transaction: StripeTransactionDoc
    ) -> None:
        """Create and persist the Stripe Product and Price for ``transaction`` if missing.

        The product is named "<project_id>-<milestone_index>"; the price uses
        the stored ``unit_amount`` and ``currency``. Each new id is saved
        immediately after it is created, so a failure in a later step does
        not lose it.
        """
        if not transaction.stripe_product_id:
            product = stripe.Product.create(
                name="{}-{}".format(
                    transaction.project_id, transaction.milestone_index
                )
            )
            transaction.stripe_product_id = product.id
            await transaction.save()

        if not transaction.stripe_price_id:
            price = stripe.Price.create(
                unit_amount=transaction.unit_amount,
                currency=transaction.currency,
                product=transaction.stripe_product_id,
            )
            transaction.stripe_price_id = price.id
            await transaction.save()

    async def create_payment_link(self, transaction_id: str) -> Optional[str]:
        """Return a Stripe Payment Link URL for the transaction, creating it once.

        A link already stored on the transaction is returned as-is. Otherwise
        the product/price are ensured and a link is created that sends the
        funds to ``to_stripe_account_id`` minus ``application_fee_amount``;
        its URL is stored on the transaction.

        Returns None when the transaction does not exist or no link was
        returned by Stripe.
        """
        transaction = await StripeTransactionDoc.get(transaction_id)
        if not transaction:
            return None
        if transaction.stripe_payment_link:
            return transaction.stripe_payment_link

        await self._ensure_product_and_price(transaction)

        payment_link = stripe.PaymentLink.create(
            line_items=[
                {
                    "price": transaction.stripe_price_id,
                    "quantity": 1,
                }
            ],
            application_fee_amount=transaction.application_fee_amount,
            on_behalf_of=transaction.to_stripe_account_id,
            transfer_data={
                "destination": transaction.to_stripe_account_id,
            },
        )
        if not payment_link:
            return None

        transaction.stripe_payment_link = payment_link.url
        transaction.updated_time = datetime.now(timezone.utc)
        await transaction.save()
        return payment_link.url

    async def create_checkout_session(
        self, transaction_id: str
    ) -> Tuple[Optional[str], Optional[str]]:
        """Return ``(session_id, session_url)`` of a Checkout Session for the transaction.

        - Unknown transaction: ``(None, None)``. This matches the declared
          tuple return type; the function previously returned a bare None.
        - COMPLETED transaction: the stored id/url are returned and no new
          session is created, so an already-paid milestone cannot be charged
          twice after its old session expires.
        - Stored session not yet expired: it is reused.
        - Otherwise: a new card-only session is created. It sends the funds
          to ``to_stripe_account_id`` minus ``application_fee_amount``, and
          its id/url are stored on the transaction.
        """
        transaction = await StripeTransactionDoc.get(transaction_id)
        if not transaction:
            return None, None

        if transaction.status == TransactionStatus.COMPLETED:
            return (
                transaction.stripe_checkout_session_id,
                transaction.stripe_checkout_session_url,
            )

        if transaction.stripe_checkout_session_id:
            existing = stripe.checkout.Session.retrieve(
                transaction.stripe_checkout_session_id
            )
            expires_at_utc = datetime.fromtimestamp(
                existing.expires_at, tz=timezone.utc
            )
            if datetime.now(timezone.utc) < expires_at_utc:
                return (
                    transaction.stripe_checkout_session_id,
                    transaction.stripe_checkout_session_url,
                )

        # NOTE: a card_payments/transfers capability check on the connected
        # account was drafted here but disabled; the unused Account.retrieve
        # call that accompanied it has been dropped.
        await self._ensure_product_and_price(transaction)

        session = stripe.checkout.Session.create(
            payment_method_types=["card"],
            line_items=[
                {
                    "price": transaction.stripe_price_id,
                    "quantity": 1,
                }
            ],
            payment_intent_data={
                "on_behalf_of": transaction.to_stripe_account_id,
                "application_fee_amount": transaction.application_fee_amount,
                "transfer_data": {
                    "destination": transaction.to_stripe_account_id,
                },
            },
            mode="payment",
            success_url="{}/work-space".format(
                self.site_url_root
            ),  # needs to be set, local: http://localhost/
            cancel_url="{}/work-space".format(self.site_url_root),
        )
        if not session:
            return None, None

        transaction.stripe_checkout_session_id = session.id
        transaction.stripe_checkout_session_url = session.url
        transaction.updated_time = datetime.now(timezone.utc)
        await transaction.save()
        return session.id, session.url

    async def fetch_payment_link(self, transaction_id: str) -> Optional[str]:
        """Return the stored Payment Link URL, or None if absent or empty."""
        transaction = await StripeTransactionDoc.get(transaction_id)
        if transaction and transaction.stripe_payment_link:
            return transaction.stripe_payment_link
        return None

    async def fetch_checkout_session_id(self, transaction_id: str) -> Optional[str]:
        """Return the stored Checkout Session id, or None if absent or empty."""
        transaction = await StripeTransactionDoc.get(transaction_id)
        if transaction and transaction.stripe_checkout_session_id:
            return transaction.stripe_checkout_session_id
        return None

    async def fetch_checkout_session_url(self, transaction_id: str) -> Optional[str]:
        """Return the stored Checkout Session URL, or None if absent or empty."""
        transaction = await StripeTransactionDoc.get(transaction_id)
        if transaction and transaction.stripe_checkout_session_url:
            return transaction.stripe_checkout_session_url
        return None

    async def invoke_checkout_session_webhook(
        self, payload: str, stripe_signature: str
    ) -> Tuple[bool, Optional[str], Optional[int]]:
        """Verify a Stripe webhook and mark the matching transaction COMPLETED.

        Args:
            payload: raw request body, as received.
            stripe_signature: value of the ``Stripe-Signature`` header.

        Returns:
            ``(True, project_id, milestone_index)`` when a
            ``checkout.session.completed`` event matched a stored
            transaction. ``(False, None, None)`` in every other case: invalid
            payload, bad signature, unknown session, or another event type.
            Every exit returns this 3-tuple shape; previously the
            unknown-session path returned a bare ``False``.
        """
        try:
            event = stripe.Webhook.construct_event(
                payload, stripe_signature, app_settings.STRIPE_WEBHOOK_SECRET
            )
        except ValueError as e:
            await self.module_logger.log_exception(exception=e, text="Invalid payload")
            return False, None, None
        except SignatureVerificationError as e:
            await self.module_logger.log_exception(
                exception=e, text="Invalid signature"
            )
            return False, None, None

        if event["type"] != "checkout.session.completed":
            return False, None, None

        session = event["data"]["object"]
        transaction = await self.__fetch_transaction_by_session_id(session.id)
        if not transaction:
            await self.module_logger.log_error(
                error="Transaction not found for session_id: {}".format(session.id),
                properties={"session_id": session.id},
            )
            return False, None, None

        transaction.status = TransactionStatus.COMPLETED
        transaction.updated_time = datetime.now(timezone.utc)
        await transaction.save()
        return True, transaction.project_id, transaction.milestone_index
|