Law/app/services/rate_limit.py
2026-02-23 17:54:19 +03:00

87 lines
2.5 KiB
Python

from __future__ import annotations
import logging
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from threading import Lock
from typing import Protocol
import redis
from app.core.config import settings
_LOG = logging.getLogger("app.rate_limit")
@dataclass
class RateLimitResult:
allowed: bool
retry_after_seconds: int
current_value: int
class RateLimiter(Protocol):
def hit(self, key: str, *, limit: int, window_seconds: int) -> RateLimitResult:
...
class InMemoryRateLimiter:
def __init__(self):
self._data: dict[str, tuple[int, datetime]] = {}
self._lock = Lock()
def hit(self, key: str, *, limit: int, window_seconds: int) -> RateLimitResult:
now = datetime.now(timezone.utc)
with self._lock:
count, expires_at = self._data.get(key, (0, now))
if expires_at <= now:
count = 0
expires_at = now + timedelta(seconds=max(int(window_seconds), 1))
count += 1
self._data[key] = (count, expires_at)
retry_after = max(0, int((expires_at - now).total_seconds()))
return RateLimitResult(allowed=count <= limit, retry_after_seconds=retry_after, current_value=count)
class RedisRateLimiter:
def __init__(self, client: redis.Redis):
self.client = client
def hit(self, key: str, *, limit: int, window_seconds: int) -> RateLimitResult:
count = int(self.client.incr(key))
if count == 1:
self.client.expire(key, int(max(window_seconds, 1)))
ttl = int(self.client.ttl(key))
if ttl < 0:
ttl = int(max(window_seconds, 1))
return RateLimitResult(allowed=count <= limit, retry_after_seconds=ttl, current_value=count)
_cached_limiter: RateLimiter | None = None
def _build_limiter() -> RateLimiter:
try:
client = redis.Redis.from_url(
settings.REDIS_URL,
decode_responses=True,
socket_timeout=0.4,
socket_connect_timeout=0.4,
)
client.ping()
return RedisRateLimiter(client)
except Exception:
_LOG.warning("Redis limiter unavailable; fallback to in-memory limiter")
return InMemoryRateLimiter()
def get_rate_limiter() -> RateLimiter:
global _cached_limiter
if _cached_limiter is None:
_cached_limiter = _build_limiter()
return _cached_limiter
def reset_rate_limiter_for_tests() -> None:
global _cached_limiter
_cached_limiter = None