mirror of
https://github.com/TronoSfera/Law.git
synced 2026-05-18 18:13:46 +03:00
300 lines
11 KiB
Python
300 lines
11 KiB
Python
from __future__ import annotations
|
||
|
||
import base64
|
||
import hashlib
|
||
import hmac
|
||
import secrets
|
||
from typing import Any
|
||
|
||
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
|
||
|
||
from app.services.crypto_keyring import get_chat_secrets, key_digest, ordered_unique_key_digests
|
||
|
||
# Two-byte version marker at the start of a decoded legacy (v1) blob.
_VERSION_LEGACY: bytes = b"v1"

# Token prefixes, one per wire-format generation:
#   v1: "chatenc:v1:" + b64(version(2) | nonce(16) | hmac tag(32) | cipher)
#   v2: "chatenc:v2:<kid>:" + b64(nonce(16) | hmac tag(32) | cipher)
#   v3: "chatenc:v3:<kid>:" + b64(nonce(12) | AES-GCM ciphertext)
_PREFIX_LEGACY: str = "chatenc:v1:"
_PREFIX_V2: str = "chatenc:v2:"
_PREFIX_V3: str = "chatenc:v3:"

# Key inside a request's extra_fields that holds the wrapped per-chat key payload.
_CHAT_CRYPTO_EXTRA_FIELDS_KEY: str = "chat_crypto"
|
||
|
||
|
||
def _xor_bytes(a: bytes, b: bytes) -> bytes:
    """XOR two byte strings position by position.

    The result has the length of the shorter input; surplus bytes of the
    longer input are ignored.
    """
    return bytes([left ^ right for left, right in zip(a, b)])
|
||
|
||
|
||
def _urlsafe_b64encode(value: bytes) -> str:
    """Encode *value* as padded URL-safe base64 and return it as ASCII text."""
    encoded = base64.urlsafe_b64encode(value)
    return encoded.decode("ascii")
|
||
|
||
|
||
def _urlsafe_b64decode(value: str) -> bytes:
    """Decode URL-safe base64 text into bytes.

    A falsy *value* (``None``, ``""``) decodes to ``b""``. Non-ASCII input
    raises ``UnicodeEncodeError``; malformed base64 raises ``binascii.Error``.
    """
    text = "" if not value else str(value)
    return base64.urlsafe_b64decode(text.encode("ascii"))
|
||
|
||
|
||
def _aad_v2(kid: str) -> bytes:
    """Return the bytes ``v2|<kid>|`` mixed into v2 message HMAC tags."""
    kid_bytes = str(kid).encode("utf-8")
    return b"|".join((b"v2", kid_bytes, b""))
|
||
|
||
|
||
def _aad_v3_message(kid: str) -> bytes:
    """Return the AES-GCM associated data ``v3|message|<kid>|`` for v3 messages."""
    kid_bytes = str(kid).encode("utf-8")
    return b"|".join((b"v3", b"message", kid_bytes, b""))
|
||
|
||
|
||
def _aad_v3_wrapped_key(kid: str) -> bytes:
    """Return the AES-GCM associated data ``v3|chat-key|<kid>|`` for wrapped chat keys."""
    kid_bytes = str(kid).encode("utf-8")
    return b"|".join((b"v3", b"chat-key", kid_bytes, b""))
|
||
|
||
|
||
def active_chat_kid() -> str:
    """Return the key id (kid) that the keyring currently reports as active."""
    current_kid, _unused_key_map = get_chat_secrets()
    return current_kid
|
||
|
||
|
||
def _active_chat_secret() -> tuple[str, str]:
    """Return ``(kid, secret)`` for the key that new v2 messages are encrypted with.

    Prefers the keyring's active kid. If that kid has no (truthy) secret, falls
    back to the first entry of the key map. In that case it returns the fallback
    entry's *own* kid, so the kid embedded in the ciphertext always names the
    secret that actually produced it.

    Raises:
        ValueError: if no usable secret is configured.
    """
    active_kid, key_map = get_chat_secrets()
    active_secret = key_map.get(active_kid)
    if not active_secret and key_map:
        # Report the kid that belongs to the fallback secret. Pairing the active kid
        # with another kid's secret would mislabel the token, and decryption looks the
        # secret up by the embedded kid.
        active_kid, active_secret = next(iter(key_map.items()))
    if not active_secret:
        raise ValueError("Не найден активный ключ шифрования чата")
    return active_kid, active_secret
|
||
|
||
|
||
def _chat_payload_or_none(extra_fields: dict[str, Any] | None) -> dict[str, Any] | None:
    """Return the wrapped chat-key payload stored in *extra_fields*.

    Returns ``None`` when *extra_fields* is empty/``None`` or when the stored
    value under the chat-crypto key is not a dict.
    """
    if not extra_fields:
        return None
    candidate = extra_fields.get(_CHAT_CRYPTO_EXTRA_FIELDS_KEY)
    if isinstance(candidate, dict):
        return candidate
    return None
|
||
|
||
|
||
def _wrap_chat_key(chat_key: bytes, *, kid: str, secret: str) -> dict[str, Any]:
    """Encrypt (wrap) a per-chat key with AES-GCM under the KEK derived from *secret*.

    The associated data binds the ciphertext to *kid*. Returns a JSON-friendly
    dict with keys ``version`` (always 1), ``kek_kid``, and base64 ``nonce``
    (12 random bytes) and ``wrapped_key``.
    """
    wrap_nonce = secrets.token_bytes(12)
    kek = AESGCM(key_digest(secret))
    sealed = kek.encrypt(wrap_nonce, chat_key, _aad_v3_wrapped_key(kid))
    envelope: dict[str, Any] = {"version": 1}
    envelope["kek_kid"] = str(kid)
    envelope["nonce"] = _urlsafe_b64encode(wrap_nonce)
    envelope["wrapped_key"] = _urlsafe_b64encode(sealed)
    return envelope
|
||
|
||
|
||
def _unwrap_chat_key(payload: dict[str, Any], *, key_map: dict[str, str]) -> tuple[bytes, str]:
    """Decrypt a payload produced by ``_wrap_chat_key`` and return the chat key.

    The secret for the payload's declared ``kek_kid`` is tried first, if it is in
    *key_map*. Every other secret is then tried in map order. When a kid is
    declared, the associated data is always built from that declared kid, even
    while trying other secrets.

    Returns:
        ``(chat_key, kid)``. *kid* is the declared ``kek_kid``, or, when the
        payload declares none, the kid of the secret that succeeded.

    Raises:
        ValueError: unsupported version, malformed nonce/wrapped key, a decrypted
            key whose length is not 16/24/32 bytes, or no secret decrypts it.
    """
    if int(payload.get("version") or 0) != 1:
        raise ValueError("Неподдерживаемая версия ключа чата")
    declared_kid = str(payload.get("kek_kid") or "").strip()
    wrap_nonce = _urlsafe_b64decode(str(payload.get("nonce") or ""))
    sealed_key = _urlsafe_b64decode(str(payload.get("wrapped_key") or ""))
    if len(wrap_nonce) != 12 or not sealed_key:
        raise ValueError("Некорректный формат ключа чата")

    # The declared KEK goes first, then all remaining keys in their original order.
    if declared_kid and declared_kid in key_map:
        preferred = [(declared_kid, key_map[declared_kid])]
    else:
        preferred = []
    remaining = [
        (other_kid, other_secret)
        for other_kid, other_secret in key_map.items()
        if not (declared_kid and other_kid == declared_kid)
    ]

    for candidate_kid, candidate_secret in preferred + remaining:
        effective_kid = declared_kid or candidate_kid
        try:
            chat_key = AESGCM(key_digest(candidate_secret)).decrypt(
                wrap_nonce, sealed_key, _aad_v3_wrapped_key(effective_kid)
            )
        except Exception:
            # Wrong secret (authentication failure) or unusable key: try the next one.
            continue
        if len(chat_key) not in (16, 24, 32):
            raise ValueError("Некорректная длина ключа чата")
        return chat_key, effective_kid
    raise ValueError("Не удалось расшифровать ключ чата")
|
||
|
||
|
||
def extract_message_kid(value: str | None) -> str | None:
    """Return the key id embedded in a v2/v3 token (``chatenc:vN:<kid>:<payload>``).

    Returns ``None`` for empty input, legacy/plain text, tokens without a payload
    separator, or a blank kid. Surrounding whitespace is ignored.
    """
    token = str(value or "").strip()
    if not token.startswith((_PREFIX_V2, _PREFIX_V3)):
        return None
    segments = token.split(":", 3)
    if len(segments) != 4:
        return None
    return segments[2].strip() or None
|
||
|
||
|
||
def is_encrypted_message(value: str | None) -> bool:
    """Return True if *value* (whitespace-stripped) starts with a known chatenc prefix."""
    known_prefixes = (_PREFIX_LEGACY, _PREFIX_V2, _PREFIX_V3)
    return str(value or "").strip().startswith(known_prefixes)
|
||
|
||
|
||
def prepare_request_chat_crypto(extra_fields: dict[str, Any] | None) -> tuple[dict[str, Any], bytes, bool]:
    """Ensure a request's extra fields carry a chat key wrapped under the active KEK.

    Args:
        extra_fields: the request's extra fields. This dict is not mutated; a
            shallow copy is returned.

    Returns:
        ``(updated_extra_fields, chat_key, changed)``. ``changed`` is True when the
        wrapped-key payload was (re)written, i.e. the caller has something to persist.

    Raises:
        ValueError: if the payload must be (re)wrapped but the active kid has no
            secret in the keyring.
    """
    active_kid, key_map = get_chat_secrets()
    updated = dict(extra_fields or {})
    payload = _chat_payload_or_none(updated)

    chat_key: bytes | None = None
    payload_kid = active_kid
    if payload:
        try:
            chat_key, payload_kid = _unwrap_chat_key(payload, key_map=key_map)
        except Exception:
            # NOTE(review): an existing payload that cannot be unwrapped (for example,
            # its KEK was removed from the keyring) is silently replaced below by a
            # freshly generated chat key. v3 messages encrypted with the old chat key
            # then become undecryptable. Confirm this best-effort behaviour is intended.
            chat_key = None

    changed = chat_key is None
    if chat_key is None:
        chat_key = secrets.token_bytes(32)

    # Re-wrap when the chat key is new, or when it was wrapped under a KEK other than
    # the currently active one (key rotation).
    if changed or payload_kid != active_kid:
        active_secret = key_map.get(active_kid)
        if not active_secret:
            raise ValueError("Не найден активный ключ шифрования чата")
        updated[_CHAT_CRYPTO_EXTRA_FIELDS_KEY] = _wrap_chat_key(chat_key, kid=active_kid, secret=active_secret)
        changed = True

    return updated, chat_key, changed
|
||
|
||
|
||
def _request_chat_key(extra_fields: dict[str, Any] | None) -> tuple[bytes, str]:
    """Unwrap and return ``(chat_key, kek_kid)`` from a request's extra fields.

    Raises:
        ValueError: if no wrapped chat-key payload is stored, or if unwrapping fails.
    """
    wrapped = _chat_payload_or_none(extra_fields)
    if wrapped:
        return _unwrap_chat_key(wrapped, key_map=get_chat_secrets()[1])
    raise ValueError("Не найден ключ шифрования чата для заявки")
|
||
|
||
|
||
def encrypt_message_body(value: str | None) -> str | None:
    """Encrypt a chat message into a v2 token using the keyring's active secret.

    ``None`` maps to ``None``. Empty strings and text that already looks like a
    chatenc token are returned unchanged.

    v2 layout: ``chatenc:v2:<kid>:`` + b64(nonce(16) | tag(32) | cipher). The
    keystream is PBKDF2-HMAC-SHA256 output stretched to the message length, and
    the tag is HMAC-SHA256 over ``_aad_v2(kid) + nonce + cipher``. Each 32-byte
    PBKDF2 output block costs the full 120k iterations, so encryption cost grows
    with message length.
    """
    if value is None:
        return None
    text = str(value)
    if text == "" or is_encrypted_message(text):
        return text

    kid, secret = _active_chat_secret()
    mac_key = key_digest(secret)

    plaintext = text.encode("utf-8")
    salt = secrets.token_bytes(16)
    keystream = hashlib.pbkdf2_hmac("sha256", mac_key, salt, 120_000, dklen=len(plaintext))
    ciphertext = _xor_bytes(plaintext, keystream)
    mac = hmac.new(mac_key, _aad_v2(kid) + salt + ciphertext, hashlib.sha256).digest()
    return f"{_PREFIX_V2}{kid}:" + _urlsafe_b64encode(b"".join((salt, mac, ciphertext)))
|
||
|
||
|
||
def encrypt_message_body_for_request(
    value: str | None,
    *,
    request_extra_fields: dict[str, Any] | None,
) -> tuple[str | None, dict[str, Any], bool]:
    """Encrypt a message into a v3 token with the request's per-chat AES-GCM key.

    Returns ``(token, extra_fields, changed)``. ``extra_fields`` is a copy of
    *request_extra_fields*, possibly carrying a new or re-wrapped chat key, and
    ``changed`` signals that it must be persisted. ``None``, empty text and
    already-encrypted text pass through unchanged with ``changed=False``.
    """
    text = None if value is None else str(value)
    if not text or is_encrypted_message(text):
        return text, dict(request_extra_fields or {}), False

    extra_fields, chat_key, extra_changed = prepare_request_chat_crypto(request_extra_fields)
    # The token names the KEK kid recorded in the payload, falling back to the active kid.
    message_kid = str(extract_request_chat_kek_kid(extra_fields) or active_chat_kid())
    iv = secrets.token_bytes(12)
    sealed = AESGCM(chat_key).encrypt(iv, text.encode("utf-8"), _aad_v3_message(message_kid))
    token = f"{_PREFIX_V3}{message_kid}:" + _urlsafe_b64encode(iv + sealed)
    return token, extra_fields, extra_changed
|
||
|
||
|
||
def extract_request_chat_kek_kid(extra_fields: dict[str, Any] | None) -> str | None:
    """Return the stripped ``kek_kid`` of the request's wrapped chat key, or ``None``."""
    wrapped = _chat_payload_or_none(extra_fields)
    kek_kid = str((wrapped or {}).get("kek_kid") or "").strip()
    return kek_kid if kek_kid else None
|
||
|
||
|
||
def _decrypt_v2(encoded: str, *, kid: str, key: bytes) -> str:
    """Verify and decrypt a v2 payload (b64 of nonce(16) | tag(32) | cipher).

    Raises:
        ValueError: if the blob is too short or the HMAC tag does not match
            (wrong key, wrong kid, or tampered data).
    """
    nonce_len, tag_len = 16, 32
    blob = _urlsafe_b64decode(encoded)
    if len(blob) < nonce_len + tag_len:
        raise ValueError("Некорректный зашифрованный формат сообщения")
    nonce = blob[:nonce_len]
    tag = blob[nonce_len:nonce_len + tag_len]
    cipher = blob[nonce_len + tag_len:]

    expected = hmac.new(key, _aad_v2(kid) + nonce + cipher, hashlib.sha256).digest()
    if not hmac.compare_digest(tag, expected):
        raise ValueError("Поврежденные данные сообщения")
    keystream = hashlib.pbkdf2_hmac("sha256", key, nonce, 120_000, dklen=len(cipher))
    return _xor_bytes(cipher, keystream).decode("utf-8")
|
||
|
||
|
||
def _decrypt_v3(encoded: str, *, kid: str, request_extra_fields: dict[str, Any] | None) -> str:
    """Decrypt a v3 payload (b64 of nonce(12) | AES-GCM ciphertext) with the chat key.

    Args:
        encoded: base64 part of the token, after ``chatenc:v3:<kid>:``.
        kid: the kid from the token. It is bound into the AES-GCM associated data.
        request_extra_fields: the request's extra fields holding the wrapped chat key.

    Raises:
        ValueError: no chat key for the request, malformed blob, or authentication
            failure. An authentication failure is reported as ``ValueError``, like
            the v1/v2 paths, instead of leaking ``cryptography``'s ``InvalidTag``.
    """
    from cryptography.exceptions import InvalidTag

    chat_key, _ = _request_chat_key(request_extra_fields)
    blob = _urlsafe_b64decode(encoded)
    # A 12-byte nonce followed by AES-GCM output, which always ends in a 16-byte tag.
    if len(blob) < 12 + 16:
        raise ValueError("Некорректный зашифрованный формат сообщения")
    nonce, cipher = blob[:12], blob[12:]
    try:
        raw = AESGCM(chat_key).decrypt(nonce, cipher, _aad_v3_message(kid))
    except InvalidTag as exc:
        raise ValueError("Поврежденные данные сообщения") from exc
    return raw.decode("utf-8")
|
||
|
||
|
||
def _decrypt_legacy(encoded: str, keys: list[bytes]) -> str:
    """Decrypt a legacy v1 payload (b64 of version(2) | nonce(16) | tag(32) | cipher).

    Each key in *keys* is tried in order. The first key whose HMAC-SHA256 over
    ``version + nonce + cipher`` matches the tag is used to derive the keystream.

    Raises:
        ValueError: blob too short, unknown version marker, or no key matches.
    """
    blob = _urlsafe_b64decode(encoded)
    if len(blob) < 2 + 16 + 32:
        raise ValueError("Некорректный зашифрованный формат сообщения")
    version, nonce, tag, cipher = blob[:2], blob[2:18], blob[18:50], blob[50:]
    if version != _VERSION_LEGACY:
        raise ValueError("Неподдерживаемая версия шифрования чата")

    signed = version + nonce + cipher
    matching_key = next(
        (
            candidate
            for candidate in keys
            if hmac.compare_digest(tag, hmac.new(candidate, signed, hashlib.sha256).digest())
        ),
        None,
    )
    if matching_key is None:
        raise ValueError("Поврежденные данные сообщения")
    keystream = hashlib.pbkdf2_hmac("sha256", matching_key, nonce, 120_000, dklen=len(cipher))
    return _xor_bytes(cipher, keystream).decode("utf-8")
|
||
|
||
|
||
def decrypt_message_body(value: str | None) -> str | None:
    """Decrypt a legacy (v1) or v2 chat message token using the keyring.

    ``None`` maps to ``None``. Empty or non-token text is returned unchanged.

    For v2, the secret registered under the token's kid is tried first. If that
    secret is absent or fails to authenticate, every other configured secret is
    tried. The HMAC tag authenticates the result, so a successful fallback yields
    the genuine plaintext.

    Raises:
        ValueError: for v3 tokens (these need ``decrypt_message_body_for_request``),
            malformed tokens, or when no configured key authenticates the message.
    """
    if value is None:
        return None
    text = str(value)
    if not text or not is_encrypted_message(text):
        return text

    # is_encrypted_message() ignores surrounding whitespace, so dispatch on the
    # stripped token as well. Otherwise a leading space would route a v2/v3 token
    # to the v1 parser.
    token = text.strip()
    key_map = get_chat_secrets()[1]
    if token.startswith(_PREFIX_V3):
        raise ValueError("Для сообщений v3 требуется контекст заявки")

    if token.startswith(_PREFIX_V2):
        kid, sep, payload = token[len(_PREFIX_V2):].partition(":")
        if not sep:
            raise ValueError("Некорректный зашифрованный формат сообщения")
        kid = kid.strip()

        direct_key: bytes | None = None
        direct_error: Exception | None = None
        if kid in key_map:
            direct_key = key_digest(key_map[kid])
            try:
                return _decrypt_v2(payload, kid=kid, key=direct_key)
            except Exception as exc:
                # Keep the error and try the other keys before giving up. The token
                # may carry a kid that does not match the secret that produced it.
                direct_error = exc

        for fallback_key in ordered_unique_key_digests(key_map.values()):
            if direct_key is not None and fallback_key == direct_key:
                continue
            try:
                return _decrypt_v2(payload, kid=kid, key=fallback_key)
            except Exception:
                continue

        if direct_error is not None:
            raise direct_error
        raise ValueError("Неподдерживаемый идентификатор ключа шифрования")

    encoded = token[len(_PREFIX_LEGACY):]
    return _decrypt_legacy(encoded, ordered_unique_key_digests(key_map.values()))
|
||
|
||
|
||
def decrypt_message_body_for_request(
    value: str | None,
    *,
    request_extra_fields: dict[str, Any] | None,
) -> str | None:
    """Decrypt any chat message token, using the request's chat key for v3 tokens.

    ``None`` maps to ``None``. Empty or non-token text is returned unchanged.
    v1/v2 tokens are delegated to ``decrypt_message_body``.

    Raises:
        ValueError: for malformed tokens or when decryption fails.
    """
    if value is None:
        return None
    text = str(value)
    if not text or not is_encrypted_message(text):
        return text

    # Match is_encrypted_message(), which ignores surrounding whitespace. Otherwise
    # a v3 token with a leading space would be mis-routed to the v1/v2 decoder.
    token = text.strip()
    if token.startswith(_PREFIX_V3):
        kid, sep, payload = token[len(_PREFIX_V3):].partition(":")
        if not sep:
            raise ValueError("Некорректный зашифрованный формат сообщения")
        return _decrypt_v3(payload, kid=kid.strip(), request_extra_fields=request_extra_fields)
    return decrypt_message_body(token)
|