mirror of
https://github.com/TronoSfera/Law.git
synced 2026-05-19 02:23:45 +03:00
87 lines
2.5 KiB
Python
87 lines
2.5 KiB
Python
from __future__ import annotations
|
|
|
|
import logging
|
|
from dataclasses import dataclass
|
|
from datetime import datetime, timedelta, timezone
|
|
from threading import Lock
|
|
from typing import Protocol
|
|
|
|
import redis
|
|
|
|
from app.core.config import settings
|
|
|
|
_LOG = logging.getLogger("app.rate_limit")
|
|
|
|
|
|
@dataclass
|
|
class RateLimitResult:
|
|
allowed: bool
|
|
retry_after_seconds: int
|
|
current_value: int
|
|
|
|
|
|
class RateLimiter(Protocol):
|
|
def hit(self, key: str, *, limit: int, window_seconds: int) -> RateLimitResult:
|
|
...
|
|
|
|
|
|
class InMemoryRateLimiter:
|
|
def __init__(self):
|
|
self._data: dict[str, tuple[int, datetime]] = {}
|
|
self._lock = Lock()
|
|
|
|
def hit(self, key: str, *, limit: int, window_seconds: int) -> RateLimitResult:
|
|
now = datetime.now(timezone.utc)
|
|
with self._lock:
|
|
count, expires_at = self._data.get(key, (0, now))
|
|
if expires_at <= now:
|
|
count = 0
|
|
expires_at = now + timedelta(seconds=max(int(window_seconds), 1))
|
|
count += 1
|
|
self._data[key] = (count, expires_at)
|
|
retry_after = max(0, int((expires_at - now).total_seconds()))
|
|
return RateLimitResult(allowed=count <= limit, retry_after_seconds=retry_after, current_value=count)
|
|
|
|
|
|
class RedisRateLimiter:
|
|
def __init__(self, client: redis.Redis):
|
|
self.client = client
|
|
|
|
def hit(self, key: str, *, limit: int, window_seconds: int) -> RateLimitResult:
|
|
count = int(self.client.incr(key))
|
|
if count == 1:
|
|
self.client.expire(key, int(max(window_seconds, 1)))
|
|
ttl = int(self.client.ttl(key))
|
|
if ttl < 0:
|
|
ttl = int(max(window_seconds, 1))
|
|
return RateLimitResult(allowed=count <= limit, retry_after_seconds=ttl, current_value=count)
|
|
|
|
|
|
_cached_limiter: RateLimiter | None = None
|
|
|
|
|
|
def _build_limiter() -> RateLimiter:
|
|
try:
|
|
client = redis.Redis.from_url(
|
|
settings.REDIS_URL,
|
|
decode_responses=True,
|
|
socket_timeout=0.4,
|
|
socket_connect_timeout=0.4,
|
|
)
|
|
client.ping()
|
|
return RedisRateLimiter(client)
|
|
except Exception:
|
|
_LOG.warning("Redis limiter unavailable; fallback to in-memory limiter")
|
|
return InMemoryRateLimiter()
|
|
|
|
|
|
def get_rate_limiter() -> RateLimiter:
|
|
global _cached_limiter
|
|
if _cached_limiter is None:
|
|
_cached_limiter = _build_limiter()
|
|
return _cached_limiter
|
|
|
|
|
|
def reset_rate_limiter_for_tests() -> None:
|
|
global _cached_limiter
|
|
_cached_limiter = None
|