Source code for converge.extensions.rate_limit

from __future__ import annotations

import time
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any

from converge.core.message import Message
from converge.core.store import Store
from converge.extensions.storage.memory import MemoryStore
from converge.network.transport.hooks import MessageHook
from converge.observability.metrics import MetricsCollector


@dataclass(frozen=True)
class TokenBucketConfig:
    """Parameters of one token bucket used by ``RateLimiter``.

    Attributes:
        capacity: Maximum number of tokens the bucket can hold. A bucket with
            no stored state starts full. Admitting a message takes one token
            per matching bucket check, so a capacity below 1 admits nothing.
        refill_tokens_per_sec: Tokens added back per second of elapsed clock
            time, never exceeding ``capacity``.

    Raises:
        ValueError: If either value is negative or NaN.
    """

    capacity: float
    refill_tokens_per_sec: float

    def __post_init__(self) -> None:
        """Reject configurations that would corrupt the bucket arithmetic."""
        for name in ("capacity", "refill_tokens_per_sec"):
            value = getattr(self, name)
            # ``not value >= 0`` is also true for NaN, which would otherwise
            # make the ``min(capacity, ...)`` refill result order-dependent.
            if not value >= 0:
                raise ValueError(
                    f"{name} must be a non-negative number, got {value!r}"
                )
class RateLimiter:
    """Token-bucket rate limiter keyed by direction, sender and topic namespace.

    Up to three independent bucket families can be enabled:

    * ``global_config``: one bucket per direction (``"<direction>:global"``);
    * ``sender_config``: one bucket per sender, used only when the message
      has a truthy ``sender`` (``"<direction>:sender:<sender>"``);
    * ``topic_config``: one bucket per topic namespace
      (``"<direction>:topic:<namespace>"``).

    Bucket state lives in ``store`` under ``"rate_limit:<bucket key>"`` as a
    dict ``{"tokens": float, "last_ts": float}``. Timestamps come from
    ``clock`` (``time.monotonic`` by default). A monotonic clock has an
    undefined reference point, so only differences between readings are
    meaningful. A stored ``last_ts`` that lies in the "future" of the current
    clock is treated as zero elapsed time. No locking is done around the
    store's get/put pair.
    """

    def __init__(
        self,
        *,
        store: Store | None = None,
        global_config: TokenBucketConfig | None = None,
        sender_config: TokenBucketConfig | None = None,
        topic_config: TokenBucketConfig | None = None,
        clock: Callable[[], float] = time.monotonic,
    ) -> None:
        """Create a limiter.

        Args:
            store: Where bucket state is kept; a fresh ``MemoryStore`` when
                omitted.
            global_config: Bucket parameters for the per-direction bucket, or
                ``None`` to disable it.
            sender_config: Bucket parameters for per-sender buckets, or
                ``None`` to disable them.
            topic_config: Bucket parameters for per-topic-namespace buckets,
                or ``None`` to disable them.
            clock: Zero-argument callable returning the current time in
                seconds; injectable for deterministic tests.
        """
        # Compare against None instead of relying on truthiness: a store that
        # defines __len__ is falsy while empty and would otherwise be silently
        # replaced by a private MemoryStore.
        self.store = store if store is not None else MemoryStore()
        self.global_config = global_config
        self.sender_config = sender_config
        self.topic_config = topic_config
        self.clock = clock

    def allow_message(self, message: Message, *, direction: str) -> bool:
        """Decide whether *message* may pass and charge the matching buckets.

        Every enabled bucket that applies to *message* is refilled first. The
        message is allowed only if all of them hold enough tokens. Tokens are
        deducted only when the message is allowed, so a message rejected by
        one bucket (e.g. its sender's) does not drain the others (e.g. the
        global bucket). A bucket matched several times by the same message
        (several topics sharing a namespace) is charged one token per match.

        Args:
            message: The message to check; its ``sender`` and ``topics``
                (each with a ``namespace``) are read.
            direction: Prefix separating bucket families, e.g. ``"egress"``
                or ``"ingress"``.

        Returns:
            ``True`` if the message is within every applicable limit, or if no
            limit applies.
        """
        demand: dict[str, float] = {}
        configs: dict[str, TokenBucketConfig] = {}
        for key, cfg in self._matching_buckets(message, direction):
            demand[key] = demand.get(key, 0.0) + 1.0
            configs[key] = cfg
        if not demand:
            return True

        now = self.clock()
        levels = {key: self._refilled_tokens(key, configs[key], now) for key in demand}
        allowed = all(levels[key] >= cost for key, cost in demand.items())
        for key, cost in demand.items():
            # Persist the refilled level even on rejection so last_ts advances.
            # Refill is additive then clamped to capacity, so splitting an
            # interval in two yields the same level as refilling it once.
            remaining = levels[key] - cost if allowed else levels[key]
            self._store_tokens(key, remaining, now)
        return allowed

    def _matching_buckets(
        self, message: Message, direction: str
    ) -> list[tuple[str, TokenBucketConfig]]:
        """List ``(bucket key, config)`` pairs that apply to *message*, one per charge."""
        buckets: list[tuple[str, TokenBucketConfig]] = []
        if self.global_config is not None:
            buckets.append((f"{direction}:global", self.global_config))
        if self.sender_config is not None and message.sender:
            buckets.append((f"{direction}:sender:{message.sender}", self.sender_config))
        if self.topic_config is not None:
            buckets.extend(
                (f"{direction}:topic:{topic.namespace}", self.topic_config)
                for topic in message.topics
            )
        return buckets

    def _refilled_tokens(self, key: str, cfg: TokenBucketConfig, now: float) -> float:
        """Return bucket *key*'s token level at time *now*, without writing it."""
        state = self.store.get(f"rate_limit:{key}")
        if not isinstance(state, dict):
            # No (or unrecognised) stored state: the bucket starts full.
            return cfg.capacity
        tokens = float(state.get("tokens", cfg.capacity))
        last_ts = float(state.get("last_ts", now))
        # Clamp negative gaps (e.g. a last_ts taken from another clock epoch).
        elapsed = max(0.0, now - last_ts)
        return min(cfg.capacity, tokens + elapsed * cfg.refill_tokens_per_sec)

    def _store_tokens(self, key: str, tokens: float, now: float) -> None:
        """Persist bucket *key*'s token level as observed at time *now*."""
        self.store.put(f"rate_limit:{key}", {"tokens": tokens, "last_ts": now})

    def _consume(self, key: str, cfg: TokenBucketConfig) -> bool:
        """Refill bucket *key* and try to take a single token from it.

        Single-bucket helper retained for backward compatibility;
        ``allow_message`` charges all matching buckets all-or-nothing instead.
        """
        now = self.clock()
        level = self._refilled_tokens(key, cfg, now)
        if level < 1.0:
            self._store_tokens(key, level, now)
            return False
        self._store_tokens(key, level - 1.0, now)
        return True
class RateLimitHook(MessageHook):
    """Message hook that rejects messages refused by a ``RateLimiter``.

    Outgoing messages are checked with direction ``"egress"`` and incoming
    messages with direction ``"ingress"``. A refused message makes the hook
    return ``None`` and, when a metrics collector is configured, increments
    ``rate_limit_<direction>_dropped_total``. The metric name suggests the
    transport drops ``None`` results; confirm against ``MessageHook``.
    """

    def __init__(
        self,
        rate_limiter: RateLimiter,
        *,
        metrics_collector: MetricsCollector | None = None,
    ) -> None:
        """Create the hook.

        Args:
            rate_limiter: Limiter consulted for every message.
            metrics_collector: Optional collector whose ``inc`` is called once
                per refused message.
        """
        self.rate_limiter = rate_limiter
        self.metrics = metrics_collector

    def pre_send(self, message: Message) -> Message | None:
        """Return *message* if egress limits allow it, otherwise ``None``."""
        return self._admit(message, "egress")

    def post_receive(self, message: Message) -> Message | None:
        """Return *message* if ingress limits allow it, otherwise ``None``."""
        return self._admit(message, "ingress")

    def on_error(self, stage: str, error: Exception, context: dict[str, Any]) -> None:
        """Ignore errors reported by the transport; this hook has nothing to clean up."""
        _ = (stage, error, context)

    def _admit(self, message: Message, direction: str) -> Message | None:
        """Shared check for both directions: pass the message through or count a drop."""
        if self.rate_limiter.allow_message(message, direction=direction):
            return message
        # ``is not None`` rather than truthiness: a collector defining
        # __len__/__bool__ may be falsy (e.g. before its first increment),
        # which would silently suppress every drop counter.
        if self.metrics is not None:
            self.metrics.inc(f"rate_limit_{direction}_dropped_total")
        return None