mirror of
https://github.com/TronoSfera/Law.git
synced 2026-05-18 10:03:45 +03:00
44 lines
1.5 KiB
Python
44 lines
1.5 KiB
Python
from __future__ import annotations
|
||
|
||
from urllib.parse import urlsplit
|
||
|
||
from fastapi import HTTPException, Request
|
||
|
||
from app.core.config import settings
|
||
|
||
|
||
def _origin_from_header(value: str | None) -> str:
|
||
raw = str(value or "").strip()
|
||
if not raw:
|
||
return ""
|
||
parts = urlsplit(raw)
|
||
if not parts.scheme or not parts.netloc:
|
||
return ""
|
||
return f"{parts.scheme.lower()}://{parts.netloc.lower()}".rstrip("/")
|
||
|
||
|
||
def _sec_fetch_site(value: str | None) -> str:
|
||
return str(value or "").strip().lower()
|
||
|
||
|
||
def enforce_public_origin_or_403(request: Request, *, endpoint: str) -> None:
    """Reject requests to a public endpoint that do not come from an allowed web origin.

    The check runs only when ``settings.PUBLIC_STRICT_ORIGIN_CHECK`` is truthy
    (treated as ``True`` when the attribute is absent) and
    ``settings.app_env_is_production`` is truthy (treated as ``False`` when
    absent); otherwise the function returns immediately.

    Args:
        request: Incoming request whose ``Sec-Fetch-Site``, ``Origin`` and
            ``Referer`` headers are inspected.
        endpoint: Label interpolated into the 403 error detail.

    Raises:
        HTTPException: 403 when ``Sec-Fetch-Site`` is ``cross-site``, when
            neither ``Origin`` nor ``Referer`` yields an origin, or when the
            derived origin is not allowed; 500 when the configured allowed
            list contains no usable origin.
    """
    if not bool(getattr(settings, "PUBLIC_STRICT_ORIGIN_CHECK", True)):
        return
    if not bool(getattr(settings, "app_env_is_production", False)):
        return

    # "cross-site" is rejected outright, before any origin comparison.
    if _sec_fetch_site(request.headers.get("sec-fetch-site")) == "cross-site":
        raise HTTPException(status_code=403, detail=f"Forbidden origin for {endpoint}")

    # Normalize configured entries exactly like the request headers below, so
    # trailing slashes, paths or upper-case letters in the config still match.
    # Entries that do not parse as scheme://host are dropped.
    allowed = {
        normalized
        for normalized in (
            _origin_from_header(entry) for entry in settings.public_allowed_web_origins_list
        )
        if normalized
    }
    if not allowed:
        raise HTTPException(status_code=500, detail="Не настроен список разрешенных public-origin")

    # Prefer Origin; fall back to Referer only when Origin is absent or unusable.
    origin = _origin_from_header(request.headers.get("origin")) or _origin_from_header(
        request.headers.get("referer")
    )
    if not origin or origin not in allowed:
        raise HTTPException(status_code=403, detail=f"Forbidden origin for {endpoint}")
|