Building Custom Strategies¶
The built-in strategies cover the vast majority of rate limiting needs. But Traffik is built for extension, and writing your own strategy is surprisingly straightforward.
Maybe you need a strategy that accounts for user tier. Maybe you want a "quota with rollover" that carries unused tokens to the next window. Maybe you've discovered a new algorithm in a paper and want to try it. Whatever the reason, here's how.
The Strategy Protocol¶
A strategy is any async callable that matches this signature:
```python
async def __call__(
    self,
    key: Stringable,           # The namespaced throttle key for this client
    rate: Rate,                # The rate limit definition
    backend: ThrottleBackend,  # The backend to read/write from
    cost: int = 1,             # How many units this request costs
) -> WaitPeriod:               # Milliseconds to wait (0 = allow, >0 = throttle)
    ...
```
That's the entire contract. Return 0.0 to allow the request. Return a positive number of milliseconds to throttle the client.
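As a minimal sketch of the contract, here is the simplest possible strategy: one that allows everything. (`AllowAllStrategy` is hypothetical, and `WaitPeriod` is aliased to `float` so the snippet runs standalone without traffik installed.)

```python
import asyncio
from dataclasses import dataclass

# Stand-in alias; in real code WaitPeriod comes from traffik.types.
WaitPeriod = float


@dataclass(frozen=True)
class AllowAllStrategy:
    """A do-nothing strategy: every request is allowed immediately."""

    async def __call__(self, key, rate, backend, cost: int = 1) -> WaitPeriod:
        # 0.0 means "allow"; any positive value is the milliseconds to wait.
        return 0.0


wait = asyncio.run(AllowAllStrategy()("client:1", rate=None, backend=None))
print(wait)  # → 0.0
```

Any object with that async `__call__` shape can be passed as a `strategy=`; no base class is required.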
Add get_stat(...) for observability
If you also implement get_stat(key, rate, backend) -> StrategyStat, your strategy will support throttle.stat() calls and all the observability features that come with it. It's optional but highly recommended.
Full Example: A Rolling Window Strategy¶
Here's a complete custom strategy, RollingWindowStrategy, that tracks when a client first used the API and gives them the full window from that point, rather than aligning windows to clock boundaries the way FixedWindow does:
```python
from dataclasses import dataclass, field

from traffik.backends.base import ThrottleBackend
from traffik.rates import Rate
from traffik.types import LockConfig, StrategyStat, Stringable, WaitPeriod
from traffik.utils import time


@dataclass(frozen=True)
class RollingWindowStrategy:
    """
    A rolling window strategy that starts the window on the client's first request.

    Unlike FixedWindow (which aligns to clock boundaries), this gives each client
    a full `rate.expire` milliseconds from their first request.
    """

    lock_config: LockConfig = field(default_factory=dict)

    async def __call__(
        self,
        key: Stringable,
        rate: Rate,
        backend: ThrottleBackend,
        cost: int = 1,
    ) -> WaitPeriod:
        # Always check for unlimited rate first — zero-overhead fast path
        if rate.unlimited:
            return 0.0

        now_ms = time() * 1000
        full_key = backend.get_key(str(key))
        start_key = f"{full_key}:rolling:start"
        count_key = f"{full_key}:rolling:count"
        ttl_seconds = max(int(rate.expire // 1000), 1)

        # Multi-step read/write: need a lock to prevent races
        async with await backend.lock(f"lock:{full_key}:rolling", **self.lock_config):
            window_start = await backend.get(start_key)

            if window_start is None:
                # First request from this client — start their window now
                await backend.multi_set(
                    {
                        start_key: str(now_ms),
                        count_key: str(cost),
                    },
                    expire=ttl_seconds,
                )
                return 0.0

            window_start_ms = float(window_start)
            window_end_ms = window_start_ms + rate.expire

            if now_ms >= window_end_ms:
                # Window expired — start a fresh one
                await backend.multi_set(
                    {
                        start_key: str(now_ms),
                        count_key: str(cost),
                    },
                    expire=ttl_seconds,
                )
                return 0.0

            # Inside the window — increment and check
            counter = await backend.increment_with_ttl(count_key, amount=cost, ttl=ttl_seconds)
            if counter > rate.limit:
                # Over the limit — tell the client when their window resets
                wait_ms = window_end_ms - now_ms
                return max(wait_ms, 0.0)

            return 0.0

    async def get_stat(
        self,
        key: Stringable,
        rate: Rate,
        backend: ThrottleBackend,
    ) -> StrategyStat:
        now_ms = time() * 1000
        full_key = backend.get_key(str(key))
        start_key = f"{full_key}:rolling:start"
        count_key = f"{full_key}:rolling:count"

        window_start_raw, count_raw = await backend.multi_get(start_key, count_key)
        if window_start_raw is None:
            return StrategyStat(
                key=key,
                rate=rate,
                hits_remaining=float(rate.limit),
                wait_ms=0.0,
            )

        window_start_ms = float(window_start_raw)
        window_end_ms = window_start_ms + rate.expire
        counter = int(count_raw) if count_raw else 0

        if now_ms >= window_end_ms:
            # Window expired — fresh slate
            return StrategyStat(
                key=key,
                rate=rate,
                hits_remaining=float(rate.limit),
                wait_ms=0.0,
            )

        hits_remaining = max(rate.limit - counter, 0)
        wait_ms = max(window_end_ms - now_ms, 0.0) if counter > rate.limit else 0.0
        return StrategyStat(
            key=key,
            rate=rate,
            hits_remaining=hits_remaining,
            wait_ms=wait_ms,
        )
```
Use it like any built-in strategy:
```python
from traffik import HTTPThrottle

throttle = HTTPThrottle(
    "api:rolling",
    rate="100/min",
    strategy=RollingWindowStrategy(),
)
```
Best Practices¶
1. Handle rate.unlimited First¶
This is the zero-overhead fast path. Never skip it:
```python
async def __call__(self, key, rate, backend, cost=1):
    if rate.unlimited:
        return 0.0
    # ... your logic
```
2. Use @dataclass(frozen=True) for Configuration¶
Frozen dataclasses prevent accidental mutation and make your strategy safe to share across threads and requests:
```python
from dataclasses import dataclass, field

from traffik.types import LockConfig


@dataclass(frozen=True)
class MyStrategy:
    burst_multiplier: float = 1.5
    lock_config: LockConfig = field(default_factory=dict)
```
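To see the protection in action, here is a quick standalone check (with `LockConfig` swapped for a plain `dict` so it runs without traffik installed):

```python
import dataclasses
from dataclasses import dataclass, field


@dataclass(frozen=True)
class MyStrategy:
    burst_multiplier: float = 1.5
    lock_config: dict = field(default_factory=dict)  # plain dict standing in for LockConfig


strategy = MyStrategy()
try:
    strategy.burst_multiplier = 2.0  # any attribute assignment raises
except dataclasses.FrozenInstanceError as exc:
    print("mutation blocked:", type(exc).__name__)  # → mutation blocked: FrozenInstanceError
```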
3. Use backend.lock() for Multi-Step Operations¶
Any strategy that reads and then writes needs a lock to prevent race conditions under concurrency. The lock key should be derived from the throttle key:
```python
async with await backend.lock(f"lock:{full_key}:mystrategy", **self.lock_config):
    old_value = await backend.get(some_key)
    new_value = compute(old_value)
    await backend.set(some_key, new_value, expire=ttl)
```
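The race this prevents is easy to reproduce in miniature. The sketch below uses a plain `asyncio.Lock` standing in for `backend.lock()` and an in-memory dict standing in for the backend; the `await asyncio.sleep(0)` between read and write is exactly the gap where concurrent tasks would otherwise interleave and lose updates:

```python
import asyncio


async def main() -> int:
    store = {"count": 0}
    lock = asyncio.Lock()  # stand-in for backend.lock(); real backends use a shared lock

    async def hit() -> None:
        async with lock:
            old = store["count"]       # read
            await asyncio.sleep(0)     # yield: without the lock, tasks interleave here
            store["count"] = old + 1   # write

    await asyncio.gather(*(hit() for _ in range(100)))
    return store["count"]


result = asyncio.run(main())
print(result)  # → 100; remove the lock and interleaved read/writes lose most increments
```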
4. Use increment_with_ttl() When Possible¶
This is an atomic increment-and-set-TTL operation — much more efficient than get() + increment() + expire() with a lock:
```python
# Good: single atomic operation
counter = await backend.increment_with_ttl(counter_key, amount=cost, ttl=ttl_seconds)

# Less efficient: separate round trips under a lock
async with await backend.lock(...):
    counter = await backend.increment(counter_key, cost)
    await backend.expire(counter_key, ttl_seconds)
```
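The semantics are straightforward to picture with an in-memory model. This `MemoryCounter` class is hypothetical (traffik's real backends are async and shared across processes); it only illustrates what "increment and refresh TTL in one step" means:

```python
import time


class MemoryCounter:
    """In-memory sketch of an atomic increment-with-TTL (illustrative only)."""

    def __init__(self) -> None:
        self._data: dict = {}  # key -> (value, expires_at)

    def increment_with_ttl(self, key: str, amount: int = 1, ttl: int = 60) -> int:
        now = time.monotonic()
        value, expires_at = self._data.get(key, (0, 0.0))
        if now >= expires_at:
            value = 0  # key expired: start counting from scratch
        value += amount
        self._data[key] = (value, now + ttl)  # increment and refresh TTL together
        return value


counter = MemoryCounter()
print(counter.increment_with_ttl("hits", amount=1, ttl=60))  # → 1
print(counter.increment_with_ttl("hits", amount=2, ttl=60))  # → 3
```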
5. Always Set TTLs¶
Backend keys that never expire are a memory leak. Set TTLs on everything:
```python
ttl_seconds = max(int(rate.expire // 1000), 1)  # At least 1 second
await backend.set(key, value, expire=ttl_seconds)
```
6. Return Milliseconds, Not Seconds¶
`WaitPeriod` is in milliseconds, and so is the rate limit window (`rate.expire`). Don't mix units:
```python
# Correct: milliseconds
wait_ms = window_end_ms - now_ms
return max(wait_ms, 0.0)

# Wrong: dividing by 1000 yields seconds, which the caller
# interprets as milliseconds, so the wait is ~1000x too short
return wait_ms / 1000
```
7. Return 0.0 (not None) to Allow¶
Returning `0` or `0.0` means "allow". Returning `None` is not valid — always return a float.
Summary Checklist¶
Before shipping your custom strategy:
- Handle `rate.unlimited` at the top with `return 0.0`
- Use `@dataclass(frozen=True)` for the class
- Include `lock_config: LockConfig` for configurable locking
- Use `backend.lock()` for any multi-step read/write sequence
- Prefer `backend.increment_with_ttl()` over separate increment + expire
- Set TTLs on all backend keys
- Return milliseconds (not seconds)
- Implement `get_stat()` for observability
- Test under concurrency to verify correctness
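A concurrency test can be as simple as firing many simultaneous calls and asserting the limit is never exceeded. The sketch below uses a toy in-memory `CountingStrategy` (hypothetical, with a hardcoded limit instead of a real `Rate`) purely to show the shape of such a test:

```python
import asyncio


class CountingStrategy:
    """Toy in-memory strategy guarding a counter with a lock (illustrative only)."""

    def __init__(self, limit: int) -> None:
        self.limit = limit
        self.count = 0
        self.lock = asyncio.Lock()

    async def __call__(self, key, rate=None, backend=None, cost: int = 1) -> float:
        async with self.lock:
            if self.count + cost > self.limit:
                return 1000.0  # throttle: positive wait in milliseconds
            self.count += cost
            return 0.0


async def main() -> int:
    strategy = CountingStrategy(limit=5)
    # 20 concurrent requests racing against a limit of 5
    waits = await asyncio.gather(*(strategy("client:1") for _ in range(20)))
    return sum(1 for w in waits if w == 0.0)


allowed = asyncio.run(main())
print("allowed:", allowed)  # → allowed: 5 (exactly the limit, never more)
```

For a real strategy, run the same pattern against your actual backend and assert that the number of allowed requests never exceeds `rate.limit`.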