# Mirror of https://github.com/TronoSfera/Law.git
# (synced 2026-05-18 18:13:46 +03:00; 196 lines, 7.4 KiB, Python)
import os
import subprocess
import unittest
from pathlib import Path

import psycopg
from psycopg import sql
from sqlalchemy import create_engine, inspect, text
from sqlalchemy.engine import make_url


class MigrationTests(unittest.TestCase):
    """Integration tests for the Alembic migration chain.

    ``setUpClass`` creates a scratch PostgreSQL database named after the one in
    ``DATABASE_URL`` plus the suffix ``_migration_test``, runs
    ``alembic upgrade head`` against it in a subprocess, and exposes a
    SQLAlchemy engine and inspector for schema assertions. ``tearDownClass``
    drops the scratch database again. The whole class is skipped unless
    ``DATABASE_URL`` is a PostgreSQL URL.
    """

    # Tables the migrations are expected to create, excluding Alembic's own
    # ``alembic_version`` bookkeeping table. Shared by the table-existence test
    # and the common-column test so the two lists cannot drift apart.
    DOMAIN_TABLES = frozenset(
        {
            "admin_users",
            "topics",
            "statuses",
            "form_fields",
            "topic_required_fields",
            "topic_data_templates",
            "request_data_requirements",
            "requests",
            "messages",
            "attachments",
            "status_history",
            "audit_log",
            "otp_sessions",
            "quotes",
            "admin_user_topics",
            "topic_status_transitions",
            "notifications",
            "invoices",
            "security_audit_log",
        }
    )

    @classmethod
    def setUpClass(cls):
        """Create the scratch database, migrate it to head and open an engine.

        Raises:
            unittest.SkipTest: if ``DATABASE_URL`` is not a PostgreSQL URL.
            RuntimeError: if ``alembic upgrade head`` fails.
        """
        db_url_raw = os.getenv("DATABASE_URL", "")
        if not db_url_raw.startswith("postgresql"):
            raise unittest.SkipTest("Migration test requires PostgreSQL DATABASE_URL")

        # Alembic is run from here; assumes this file sits one directory below
        # the project root — TODO confirm if the file is moved.
        cls.project_root = Path(__file__).resolve().parents[1]
        cls.base_url = make_url(db_url_raw)
        cls.test_db_name = f"{cls.base_url.database}_migration_test"
        cls.test_url = cls.base_url.set(database=cls.test_db_name)
        # CREATE/DROP DATABASE for the scratch DB is issued while connected to
        # the "postgres" maintenance database instead.
        cls.admin_url = cls.base_url.set(database="postgres")

        cls._drop_create_database()
        try:
            cls._run_alembic_upgrade()
            cls.engine = create_engine(cls.test_url)
            cls.inspector = inspect(cls.engine)
        except BaseException:
            # unittest does not call tearDownClass when setUpClass raises, so
            # release the engine and drop the scratch database here instead of
            # leaking them.
            cls.tearDownClass()
            raise

    @classmethod
    def tearDownClass(cls):
        """Dispose of the engine and drop the scratch database, if they exist."""
        if hasattr(cls, "engine"):
            cls.engine.dispose()
        if hasattr(cls, "admin_url") and hasattr(cls, "test_db_name"):
            cls._drop_database()

    @classmethod
    def _to_psycopg_dsn(cls, url):
        """Render a SQLAlchemy *url* as a connection URI for ``psycopg.connect``.

        SQLAlchemy URLs may carry a driver suffix (``postgresql+psycopg``) that
        libpq does not understand. The driver name is reset structurally via
        ``URL.set`` rather than by text replacement, which would mangle e.g.
        ``postgresql+psycopg2`` into ``postgresql2`` and miss other drivers.
        """
        return url.set(drivername="postgresql").render_as_string(hide_password=False)

    @classmethod
    def _terminate_and_drop(cls, conn):
        """Disconnect other sessions from the scratch database, then drop it.

        *conn* must be an autocommit psycopg connection: PostgreSQL does not
        allow ``DROP DATABASE`` inside a transaction block, nor dropping a
        database that still has open connections.
        """
        conn.execute(
            "SELECT pg_terminate_backend(pid) "
            "FROM pg_stat_activity "
            "WHERE datname = %s AND pid <> pg_backend_pid()",
            (cls.test_db_name,),
        )
        # sql.Identifier quotes the name correctly whatever characters it has.
        conn.execute(
            sql.SQL("DROP DATABASE IF EXISTS {}").format(sql.Identifier(cls.test_db_name))
        )

    @classmethod
    def _drop_create_database(cls):
        """Recreate the scratch database empty, discarding any leftover copy."""
        with psycopg.connect(cls._to_psycopg_dsn(cls.admin_url), autocommit=True) as conn:
            cls._terminate_and_drop(conn)
            conn.execute(
                sql.SQL("CREATE DATABASE {}").format(sql.Identifier(cls.test_db_name))
            )

    @classmethod
    def _drop_database(cls):
        """Drop the scratch database, disconnecting any remaining sessions first."""
        with psycopg.connect(cls._to_psycopg_dsn(cls.admin_url), autocommit=True) as conn:
            cls._terminate_and_drop(conn)

    @staticmethod
    def _describe_alembic_failure(exc):
        """Build a readable message from a failed alembic ``CalledProcessError``.

        ``CalledProcessError`` does not show captured output in its own string
        form, so stdout and stderr are appended explicitly.
        """
        return (
            f"alembic upgrade head failed with exit code {exc.returncode}\n"
            f"--- stdout ---\n{exc.stdout or ''}\n"
            f"--- stderr ---\n{exc.stderr or ''}"
        )

    @classmethod
    def _run_alembic_upgrade(cls):
        """Run ``alembic upgrade head`` against the scratch database.

        Alembic runs in a subprocess from the project root, with
        ``DATABASE_URL`` pointing at the scratch database and the project root
        prepended to ``PYTHONPATH``.

        Raises:
            RuntimeError: if alembic exits non-zero; the message contains the
                captured stdout and stderr.
        """
        env = os.environ.copy()
        env["DATABASE_URL"] = cls.test_url.render_as_string(hide_password=False)
        # Prepend rather than overwrite, so an existing PYTHONPATH keeps working.
        existing_path = env.get("PYTHONPATH")
        env["PYTHONPATH"] = (
            f"{cls.project_root}{os.pathsep}{existing_path}"
            if existing_path
            else str(cls.project_root)
        )
        try:
            subprocess.run(
                ["alembic", "upgrade", "head"],
                cwd=cls.project_root,
                env=env,
                check=True,
                capture_output=True,
                text=True,
            )
        except subprocess.CalledProcessError as exc:
            raise RuntimeError(cls._describe_alembic_failure(exc)) from exc

    def _column_names(self, table):
        """Return the set of column names the inspector reports for *table*."""
        return {column["name"] for column in self.inspector.get_columns(table)}

    def _assert_has_columns(self, table, *names):
        """Assert that *table* has every column in *names*.

        All missing columns are reported together, rather than stopping at
        the first absent one.
        """
        missing = set(names) - self._column_names(table)
        self.assertFalse(missing, f"{table} is missing columns: {sorted(missing)}")

    def test_upgrade_head_creates_expected_tables(self):
        expected = self.DOMAIN_TABLES | {"alembic_version"}
        tables = set(self.inspector.get_table_names())
        missing = expected - tables
        self.assertFalse(missing, f"Missing tables: {sorted(missing)}")

    def test_alembic_version_is_set(self):
        # Pinned to the current head revision; update when a new migration
        # becomes head.
        with self.engine.connect() as conn:
            version = conn.execute(text("SELECT version_num FROM alembic_version")).scalar_one()
        self.assertEqual(version, "0014_security_audit_log")

    def test_responsible_column_exists_in_all_domain_tables(self):
        # Stable order plus subTest, so a failure names the offending table.
        for table in sorted(self.DOMAIN_TABLES):
            with self.subTest(table=table):
                self._assert_has_columns(table, "id", "created_at", "responsible")

    def test_admin_users_contains_primary_topic_profile_column(self):
        self._assert_has_columns("admin_users", "primary_topic_code")

    def test_admin_users_contains_avatar_column(self):
        self._assert_has_columns("admin_users", "avatar_url")

    def test_requests_contains_read_marker_columns(self):
        self._assert_has_columns(
            "requests",
            "client_has_unread_updates",
            "client_unread_event_type",
            "lawyer_has_unread_updates",
            "lawyer_unread_event_type",
        )

    def test_status_transitions_contains_sla_hours_column(self):
        self._assert_has_columns("topic_status_transitions", "sla_hours")

    def test_notifications_has_recipient_and_read_columns(self):
        self._assert_has_columns(
            "notifications",
            "recipient_type",
            "recipient_admin_user_id",
            "recipient_track_number",
            "event_type",
            "is_read",
            "read_at",
        )

    def test_admin_users_contains_rate_columns(self):
        self._assert_has_columns("admin_users", "default_rate", "salary_percent")

    def test_requests_contains_financial_columns(self):
        self._assert_has_columns(
            "requests",
            "effective_rate",
            "invoice_amount",
            "paid_at",
            "paid_by_admin_id",
        )

    def test_invoices_contains_core_columns(self):
        self._assert_has_columns(
            "invoices",
            "request_id",
            "invoice_number",
            "status",
            "amount",
            "currency",
            "payer_display_name",
            "payer_details_encrypted",
            "issued_by_admin_user_id",
            "issued_by_role",
            "issued_at",
            "paid_at",
        )

    def test_statuses_contains_billing_columns(self):
        self._assert_has_columns("statuses", "kind", "invoice_template")