Law/tests/test_migrations.py
2026-02-22 10:57:49 +03:00

101 lines
3.4 KiB
Python

import os
import subprocess
import unittest
from pathlib import Path
import psycopg
from sqlalchemy import create_engine, inspect, text
from sqlalchemy.engine import make_url
class MigrationTests(unittest.TestCase):
@classmethod
def setUpClass(cls):
db_url_raw = os.getenv("DATABASE_URL", "")
if not db_url_raw.startswith("postgresql"):
raise unittest.SkipTest("Migration test requires PostgreSQL DATABASE_URL")
cls.project_root = Path(__file__).resolve().parents[1]
cls.base_url = make_url(db_url_raw)
cls.test_db_name = f"{cls.base_url.database}_migration_test"
cls.test_url = cls.base_url.set(database=cls.test_db_name)
cls.admin_url = cls.base_url.set(database="postgres")
cls._drop_create_database()
cls._run_alembic_upgrade()
cls.engine = create_engine(cls.test_url)
cls.inspector = inspect(cls.engine)
@classmethod
def tearDownClass(cls):
if hasattr(cls, "engine"):
cls.engine.dispose()
if hasattr(cls, "admin_url") and hasattr(cls, "test_db_name"):
cls._drop_database()
@classmethod
def _to_psycopg_dsn(cls, url):
return url.render_as_string(hide_password=False).replace("+psycopg", "")
@classmethod
def _drop_create_database(cls):
dsn = cls._to_psycopg_dsn(cls.admin_url)
with psycopg.connect(dsn, autocommit=True) as conn:
conn.execute(
"SELECT pg_terminate_backend(pid) "
"FROM pg_stat_activity "
"WHERE datname = %s AND pid <> pg_backend_pid()",
(cls.test_db_name,),
)
conn.execute(f'DROP DATABASE IF EXISTS "{cls.test_db_name}"')
conn.execute(f'CREATE DATABASE "{cls.test_db_name}"')
@classmethod
def _drop_database(cls):
dsn = cls._to_psycopg_dsn(cls.admin_url)
with psycopg.connect(dsn, autocommit=True) as conn:
conn.execute(
"SELECT pg_terminate_backend(pid) "
"FROM pg_stat_activity "
"WHERE datname = %s AND pid <> pg_backend_pid()",
(cls.test_db_name,),
)
conn.execute(f'DROP DATABASE IF EXISTS "{cls.test_db_name}"')
@classmethod
def _run_alembic_upgrade(cls):
env = os.environ.copy()
env["DATABASE_URL"] = cls.test_url.render_as_string(hide_password=False)
env["PYTHONPATH"] = str(cls.project_root)
subprocess.run(
["alembic", "upgrade", "head"],
cwd=cls.project_root,
env=env,
check=True,
capture_output=True,
text=True,
)
def test_upgrade_head_creates_expected_tables(self):
expected = {
"admin_users",
"topics",
"statuses",
"form_fields",
"requests",
"messages",
"attachments",
"status_history",
"audit_log",
"otp_sessions",
"quotes",
"alembic_version",
}
tables = set(self.inspector.get_table_names())
self.assertTrue(expected.issubset(tables), f"Missing tables: {expected - tables}")
def test_alembic_version_is_set(self):
with self.engine.connect() as conn:
version = conn.execute(text("SELECT version_num FROM alembic_version")).scalar_one()
self.assertEqual(version, "0001_init")