mirror of
https://github.com/TronoSfera/Law.git
synced 2026-05-18 18:13:46 +03:00
440 lines
17 KiB
Python
440 lines
17 KiB
Python
import os
|
||
import unittest
|
||
from datetime import timedelta
|
||
from uuid import UUID, uuid4
|
||
from unittest.mock import patch
|
||
|
||
from fastapi.testclient import TestClient
|
||
from sqlalchemy import create_engine, delete
|
||
from sqlalchemy.orm import sessionmaker
|
||
from sqlalchemy.pool import StaticPool
|
||
|
||
os.environ.setdefault("DATABASE_URL", "sqlite+pysqlite:///:memory:")
|
||
os.environ.setdefault("REDIS_URL", "redis://localhost:6379/0")
|
||
os.environ.setdefault("S3_ENDPOINT", "http://localhost:9000")
|
||
os.environ.setdefault("S3_ACCESS_KEY", "test")
|
||
os.environ.setdefault("S3_SECRET_KEY", "test")
|
||
os.environ.setdefault("S3_BUCKET", "test")
|
||
|
||
from app.core.config import settings
|
||
from app.core.security import create_jwt
|
||
from app.db.session import get_db
|
||
from app.main import app
|
||
from app.models.admin_user import AdminUser
|
||
from app.models.attachment import Attachment
|
||
from app.models.invoice import Invoice
|
||
from app.models.message import Message
|
||
from app.models.notification import Notification
|
||
from app.models.request import Request
|
||
from app.models.status import Status
|
||
from app.models.status_history import StatusHistory
|
||
from app.models.topic_status_transition import TopicStatusTransition
|
||
from app.services.invoice_crypto import decrypt_requisites
|
||
|
||
|
||
class _FakeS3Storage:
|
||
def __init__(self):
|
||
self.objects = {}
|
||
|
||
|
||
class BillingFlowTests(unittest.TestCase):
|
||
@classmethod
|
||
def setUpClass(cls):
|
||
cls.engine = create_engine(
|
||
"sqlite+pysqlite:///:memory:",
|
||
connect_args={"check_same_thread": False},
|
||
poolclass=StaticPool,
|
||
)
|
||
cls.SessionLocal = sessionmaker(bind=cls.engine, autocommit=False, autoflush=False)
|
||
|
||
AdminUser.__table__.create(bind=cls.engine)
|
||
Status.__table__.create(bind=cls.engine)
|
||
Request.__table__.create(bind=cls.engine)
|
||
Message.__table__.create(bind=cls.engine)
|
||
Attachment.__table__.create(bind=cls.engine)
|
||
StatusHistory.__table__.create(bind=cls.engine)
|
||
Notification.__table__.create(bind=cls.engine)
|
||
Invoice.__table__.create(bind=cls.engine)
|
||
TopicStatusTransition.__table__.create(bind=cls.engine)
|
||
|
||
@classmethod
|
||
def tearDownClass(cls):
|
||
Invoice.__table__.drop(bind=cls.engine)
|
||
TopicStatusTransition.__table__.drop(bind=cls.engine)
|
||
Notification.__table__.drop(bind=cls.engine)
|
||
StatusHistory.__table__.drop(bind=cls.engine)
|
||
Attachment.__table__.drop(bind=cls.engine)
|
||
Message.__table__.drop(bind=cls.engine)
|
||
Request.__table__.drop(bind=cls.engine)
|
||
Status.__table__.drop(bind=cls.engine)
|
||
AdminUser.__table__.drop(bind=cls.engine)
|
||
cls.engine.dispose()
|
||
|
||
def setUp(self):
|
||
with self.SessionLocal() as db:
|
||
db.execute(delete(Invoice))
|
||
db.execute(delete(Notification))
|
||
db.execute(delete(StatusHistory))
|
||
db.execute(delete(TopicStatusTransition))
|
||
db.execute(delete(Attachment))
|
||
db.execute(delete(Message))
|
||
db.execute(delete(Request))
|
||
db.execute(delete(Status))
|
||
db.execute(delete(AdminUser))
|
||
db.commit()
|
||
|
||
def override_get_db():
|
||
db = self.SessionLocal()
|
||
try:
|
||
yield db
|
||
finally:
|
||
db.close()
|
||
|
||
app.dependency_overrides[get_db] = override_get_db
|
||
self.client = TestClient(app)
|
||
self.fake_s3 = _FakeS3Storage()
|
||
self.s3_patch = patch("app.services.invoice_chat.get_s3_storage", return_value=self.fake_s3)
|
||
self.s3_patch.start()
|
||
|
||
def tearDown(self):
|
||
self.client.close()
|
||
self.s3_patch.stop()
|
||
app.dependency_overrides.clear()
|
||
|
||
@staticmethod
|
||
def _auth_headers(role: str, email: str, sub: str | None = None) -> dict[str, str]:
|
||
token = create_jwt(
|
||
{"sub": str(sub or uuid4()), "email": email, "role": role},
|
||
settings.ADMIN_JWT_SECRET,
|
||
timedelta(minutes=30),
|
||
)
|
||
return {"Authorization": f"Bearer {token}"}
|
||
|
||
def _seed_statuses(self):
|
||
with self.SessionLocal() as db:
|
||
db.add_all(
|
||
[
|
||
Status(code="NEW", name="Новая", enabled=True, sort_order=0, is_terminal=False, kind="DEFAULT"),
|
||
Status(
|
||
code="BILLING",
|
||
name="Выставление счета",
|
||
enabled=True,
|
||
sort_order=1,
|
||
is_terminal=False,
|
||
kind="INVOICE",
|
||
invoice_template="Счет по заявке {track_number}; клиент {client_name}; сумма {amount}",
|
||
),
|
||
Status(code="IN_PROGRESS", name="В работе", enabled=True, sort_order=2, is_terminal=False, kind="DEFAULT"),
|
||
Status(code="PAID", name="Оплачено", enabled=True, sort_order=3, is_terminal=False, kind="PAID"),
|
||
]
|
||
)
|
||
db.commit()
|
||
|
||
def test_entering_billing_status_creates_waiting_invoice_from_template(self):
|
||
self._seed_statuses()
|
||
with self.SessionLocal() as db:
|
||
req = Request(
|
||
track_number="TRK-BILL-1",
|
||
client_name="ООО Клиент",
|
||
client_phone="+79990000021",
|
||
status_code="NEW",
|
||
topic_code=None,
|
||
description="billing",
|
||
extra_fields={},
|
||
effective_rate=4300,
|
||
)
|
||
db.add(req)
|
||
db.commit()
|
||
request_id = str(req.id)
|
||
|
||
admin_headers = self._auth_headers("ADMIN", "root@example.com")
|
||
changed = self.client.patch(
|
||
f"/api/admin/requests/{request_id}",
|
||
headers=admin_headers,
|
||
json={"status_code": "BILLING"},
|
||
)
|
||
self.assertEqual(changed.status_code, 200)
|
||
|
||
with self.SessionLocal() as db:
|
||
req = db.get(Request, UUID(request_id))
|
||
self.assertIsNotNone(req)
|
||
self.assertEqual(req.status_code, "BILLING")
|
||
self.assertAlmostEqual(float(req.invoice_amount or 0), 4300.0, places=2)
|
||
|
||
rows = db.query(Invoice).filter(Invoice.request_id == req.id).all()
|
||
self.assertEqual(len(rows), 1)
|
||
invoice = rows[0]
|
||
self.assertEqual(invoice.status, "WAITING_PAYMENT")
|
||
self.assertEqual(invoice.payer_display_name, "ООО Клиент")
|
||
self.assertAlmostEqual(float(invoice.amount or 0), 4300.0, places=2)
|
||
details = decrypt_requisites(invoice.payer_details_encrypted)
|
||
rendered = str((details or {}).get("template_rendered") or "")
|
||
self.assertIn("TRK-BILL-1", rendered)
|
||
self.assertIn("ООО Клиент", rendered)
|
||
|
||
def test_workflow_billing_invoice_contains_autofilled_requisites(self):
|
||
self._seed_statuses()
|
||
with self.SessionLocal() as db:
|
||
req = Request(
|
||
track_number="TRK-BILL-AUTO",
|
||
client_name="ООО Авто",
|
||
client_phone="+79990000111",
|
||
status_code="NEW",
|
||
topic_code="consulting",
|
||
description="auto requisites",
|
||
extra_fields={},
|
||
invoice_amount=12500.5,
|
||
)
|
||
db.add(req)
|
||
db.commit()
|
||
request_id = str(req.id)
|
||
|
||
admin_headers = self._auth_headers("ADMIN", "root@example.com")
|
||
changed = self.client.patch(
|
||
f"/api/admin/requests/{request_id}",
|
||
headers=admin_headers,
|
||
json={"status_code": "BILLING"},
|
||
)
|
||
self.assertEqual(changed.status_code, 200)
|
||
|
||
with self.SessionLocal() as db:
|
||
req = db.get(Request, UUID(request_id))
|
||
self.assertIsNotNone(req)
|
||
invoice = (
|
||
db.query(Invoice)
|
||
.filter(Invoice.request_id == req.id)
|
||
.order_by(Invoice.issued_at.desc(), Invoice.created_at.desc(), Invoice.id.desc())
|
||
.first()
|
||
)
|
||
self.assertIsNotNone(invoice)
|
||
details = decrypt_requisites(invoice.payer_details_encrypted)
|
||
self.assertEqual(details.get("request_track_number"), "TRK-BILL-AUTO")
|
||
self.assertEqual(details.get("topic_code"), "consulting")
|
||
rendered = str(details.get("template_rendered") or "")
|
||
self.assertTrue(rendered)
|
||
self.assertIn("TRK-BILL-AUTO", rendered)
|
||
self.assertIn("ООО Авто", rendered)
|
||
message = None
|
||
message_rows = (
|
||
db.query(Message)
|
||
.filter(Message.request_id == req.id)
|
||
.order_by(Message.created_at.desc(), Message.id.desc())
|
||
.all()
|
||
)
|
||
for item in message_rows:
|
||
if str(item.body or "").strip() == "Счет на оплату":
|
||
message = item
|
||
break
|
||
self.assertIsNotNone(message)
|
||
attachment = (
|
||
db.query(Attachment)
|
||
.filter(Attachment.request_id == req.id, Attachment.message_id == message.id)
|
||
.order_by(Attachment.created_at.desc(), Attachment.id.desc())
|
||
.first()
|
||
)
|
||
self.assertIsNotNone(attachment)
|
||
self.assertEqual(attachment.mime_type, "application/pdf")
|
||
self.assertIn(str(invoice.invoice_number), str(attachment.file_name))
|
||
self.assertGreater(int(req.total_attachments_bytes or 0), 0)
|
||
|
||
stored = self.fake_s3.objects.get(str(attachment.s3_key))
|
||
self.assertIsNotNone(stored)
|
||
self.assertEqual(stored.get("mime"), "application/pdf")
|
||
self.assertTrue(bytes(stored.get("content") or b"").startswith(b"%PDF"))
|
||
|
||
def test_paid_status_requires_admin_and_marks_waiting_invoice_paid(self):
|
||
self._seed_statuses()
|
||
with self.SessionLocal() as db:
|
||
lawyer = AdminUser(
|
||
role="LAWYER",
|
||
name="Юрист",
|
||
email="lawyer-paid@example.com",
|
||
password_hash="hash",
|
||
is_active=True,
|
||
)
|
||
req = Request(
|
||
track_number="TRK-BILL-2",
|
||
client_name="Клиент",
|
||
client_phone="+79990000022",
|
||
status_code="BILLING",
|
||
topic_code=None,
|
||
description="billing",
|
||
extra_fields={},
|
||
)
|
||
db.add_all([lawyer, req])
|
||
db.flush()
|
||
invoice = Invoice(
|
||
request_id=req.id,
|
||
invoice_number="INV-MANUAL-1",
|
||
status="WAITING_PAYMENT",
|
||
amount=7500,
|
||
currency="RUB",
|
||
payer_display_name=req.client_name,
|
||
payer_details_encrypted=None,
|
||
issued_by_admin_user_id=None,
|
||
issued_by_role="ADMIN",
|
||
issued_at=req.created_at,
|
||
paid_at=None,
|
||
responsible="root@example.com",
|
||
)
|
||
db.add(invoice)
|
||
db.commit()
|
||
request_id = str(req.id)
|
||
lawyer_id = str(lawyer.id)
|
||
invoice_id = str(invoice.id)
|
||
|
||
lawyer_headers = self._auth_headers("LAWYER", "lawyer-paid@example.com", sub=lawyer_id)
|
||
blocked = self.client.patch(
|
||
f"/api/admin/requests/{request_id}",
|
||
headers=lawyer_headers,
|
||
json={"status_code": "PAID"},
|
||
)
|
||
self.assertEqual(blocked.status_code, 403)
|
||
|
||
admin_headers = self._auth_headers("ADMIN", "root@example.com")
|
||
paid = self.client.patch(
|
||
f"/api/admin/requests/{request_id}",
|
||
headers=admin_headers,
|
||
json={"status_code": "PAID"},
|
||
)
|
||
self.assertEqual(paid.status_code, 200)
|
||
|
||
with self.SessionLocal() as db:
|
||
req = db.get(Request, UUID(request_id))
|
||
inv = db.get(Invoice, UUID(invoice_id))
|
||
self.assertIsNotNone(req)
|
||
self.assertIsNotNone(inv)
|
||
self.assertEqual(inv.status, "PAID")
|
||
self.assertIsNotNone(inv.paid_at)
|
||
self.assertEqual(req.status_code, "PAID")
|
||
self.assertIsNotNone(req.paid_at)
|
||
self.assertEqual(str(req.paid_at), str(inv.paid_at))
|
||
self.assertIsNotNone(req.paid_by_admin_id)
|
||
self.assertAlmostEqual(float(req.invoice_amount or 0), 7500.0, places=2)
|
||
|
||
def test_paid_status_without_waiting_invoice_returns_400(self):
|
||
self._seed_statuses()
|
||
with self.SessionLocal() as db:
|
||
req = Request(
|
||
track_number="TRK-BILL-3",
|
||
client_name="Клиент",
|
||
client_phone="+79990000023",
|
||
status_code="IN_PROGRESS",
|
||
topic_code=None,
|
||
description="billing",
|
||
extra_fields={},
|
||
)
|
||
db.add(req)
|
||
db.commit()
|
||
request_id = str(req.id)
|
||
|
||
admin_headers = self._auth_headers("ADMIN", "root@example.com")
|
||
blocked = self.client.patch(
|
||
f"/api/admin/requests/{request_id}",
|
||
headers=admin_headers,
|
||
json={"status_code": "PAID"},
|
||
)
|
||
self.assertEqual(blocked.status_code, 400)
|
||
self.assertIn("Ожидает оплату", blocked.json().get("detail", ""))
|
||
|
||
def test_multiple_billing_cycles_are_supported(self):
|
||
self._seed_statuses()
|
||
with self.SessionLocal() as db:
|
||
req = Request(
|
||
track_number="TRK-BILL-4",
|
||
client_name="Клиент",
|
||
client_phone="+79990000024",
|
||
status_code="NEW",
|
||
topic_code=None,
|
||
description="billing",
|
||
extra_fields={},
|
||
effective_rate=1000,
|
||
)
|
||
db.add(req)
|
||
db.commit()
|
||
request_id = str(req.id)
|
||
|
||
admin_headers = self._auth_headers("ADMIN", "root@example.com")
|
||
|
||
first_billing = self.client.patch(
|
||
f"/api/admin/requests/{request_id}",
|
||
headers=admin_headers,
|
||
json={"status_code": "BILLING"},
|
||
)
|
||
self.assertEqual(first_billing.status_code, 200)
|
||
|
||
with self.SessionLocal() as db:
|
||
req = db.get(Request, UUID(request_id))
|
||
first_invoice = (
|
||
db.query(Invoice)
|
||
.filter(Invoice.request_id == req.id)
|
||
.order_by(Invoice.issued_at.desc(), Invoice.created_at.desc(), Invoice.id.desc())
|
||
.first()
|
||
)
|
||
self.assertIsNotNone(first_invoice)
|
||
first_invoice_id = str(first_invoice.id)
|
||
|
||
tune_first_amount = self.client.patch(
|
||
f"/api/admin/invoices/{first_invoice_id}",
|
||
headers=admin_headers,
|
||
json={"amount": 1100},
|
||
)
|
||
self.assertEqual(tune_first_amount.status_code, 200)
|
||
|
||
first_paid = self.client.patch(
|
||
f"/api/admin/requests/{request_id}",
|
||
headers=admin_headers,
|
||
json={"status_code": "PAID"},
|
||
)
|
||
self.assertEqual(first_paid.status_code, 200)
|
||
|
||
back_to_work = self.client.patch(
|
||
f"/api/admin/requests/{request_id}",
|
||
headers=admin_headers,
|
||
json={"status_code": "IN_PROGRESS"},
|
||
)
|
||
self.assertEqual(back_to_work.status_code, 200)
|
||
|
||
set_second_amount = self.client.patch(
|
||
f"/api/admin/requests/{request_id}",
|
||
headers=admin_headers,
|
||
json={"invoice_amount": 2500},
|
||
)
|
||
self.assertEqual(set_second_amount.status_code, 200)
|
||
|
||
second_billing = self.client.patch(
|
||
f"/api/admin/requests/{request_id}",
|
||
headers=admin_headers,
|
||
json={"status_code": "BILLING"},
|
||
)
|
||
self.assertEqual(second_billing.status_code, 200)
|
||
|
||
second_paid = self.client.patch(
|
||
f"/api/admin/requests/{request_id}",
|
||
headers=admin_headers,
|
||
json={"status_code": "PAID"},
|
||
)
|
||
self.assertEqual(second_paid.status_code, 200)
|
||
|
||
with self.SessionLocal() as db:
|
||
req = db.get(Request, UUID(request_id))
|
||
self.assertIsNotNone(req)
|
||
invoices = (
|
||
db.query(Invoice)
|
||
.filter(Invoice.request_id == req.id)
|
||
.order_by(Invoice.issued_at.asc(), Invoice.created_at.asc(), Invoice.id.asc())
|
||
.all()
|
||
)
|
||
self.assertEqual(len(invoices), 2)
|
||
self.assertEqual(invoices[0].status, "PAID")
|
||
self.assertEqual(invoices[1].status, "PAID")
|
||
self.assertIsNotNone(invoices[0].paid_at)
|
||
self.assertIsNotNone(invoices[1].paid_at)
|
||
self.assertAlmostEqual(float(invoices[0].amount or 0), 1100.0, places=2)
|
||
self.assertAlmostEqual(float(invoices[1].amount or 0), 2500.0, places=2)
|
||
self.assertAlmostEqual(float(req.invoice_amount or 0), 2500.0, places=2)
|
||
self.assertEqual(str(req.paid_at), str(invoices[1].paid_at))
|
||
|
||
|
||
if __name__ == "__main__":
|
||
unittest.main()
|