"""Redis-backed, fixed-window rate limiting for Flask views.

Each rate-limited request increments a Redis counter (through ``g.redis``)
whose key encodes the view, the caller scope and the end of the current time
window. Views opt in with the :func:`ratelimit` decorator.
"""
import time
from functools import update_wrapper

from flask import request, g, jsonify, session, flash, redirect


class RateLimit(object):
    """Count one request against the current fixed time window.

    Constructing an instance has a side effect: it increments the Redis
    counter for the current window via ``g.redis`` and (re)sets its expiry.

    :param key_prefix: Redis key prefix; the window's reset timestamp is
        appended to it to form the actual counter key.
    :param limit: number of requests allowed per window.
    :param per: window length in seconds.
    :param send_x_headers: stored on the instance only; this class does not
        read it (presumably consumed by code that emits X-RateLimit headers
        -- confirm against the rest of the application).
    """

    # Keys expire this many seconds after their window ends. Presumably this
    # is slack for clock differences between app servers and Redis.
    expiration_window = 10

    def __init__(self, key_prefix, limit, per, send_x_headers):
        # Windows are aligned to multiples of ``per``; ``reset`` is the Unix
        # timestamp at which the current window ends.
        self.reset = (int(time.time()) // per) * per + per
        self.key = key_prefix + str(self.reset)
        self.limit = limit
        self.per = per
        self.send_x_headers = send_x_headers
        p = g.redis.pipeline()
        p.incr(self.key)
        p.expireat(self.key, self.reset + self.expiration_window)
        # Raw number of requests seen in this window, including this one.
        self.count = p.execute()[0]
        # Clamped to ``limit`` so ``remaining`` never goes negative.
        self.current = min(self.count, limit)

    @property
    def remaining(self):
        """Requests still allowed in the current window (never negative)."""
        return self.limit - self.current

    @property
    def over_limit(self):
        """True when this request exceeds ``limit`` for the current window."""
        # Compare the unclamped count: requests 1..limit are allowed and only
        # request limit+1 onwards is over. Comparing the clamped ``current``
        # with ``>=`` would wrongly reject the limit-th request.
        return self.count > self.limit


def get_view_rate_limit():
    """Return the RateLimit recorded for the current request, or None.

    :func:`ratelimit` stores it on ``g._view_rate_limit`` whenever it counts
    a request.
    """
    return getattr(g, '_view_rate_limit', None)


def on_over_limit(limit):
    """Default over-limit handler: flash a message and redirect to this path.

    :param limit: the exceeded :class:`RateLimit` (not used here).
    """
    flash("You are doing that too fast!")
    return redirect(request.path)


def on_over_api_limit(limit):
    """Over-limit handler for API views: return a JSON error payload.

    :param limit: the exceeded :class:`RateLimit` (not used here).
    """
    # NOTE(review): this responds with HTTP 200; HTTP 429 (Too Many Requests)
    # is the conventional status, but switching could break existing clients.
    return jsonify(dict(code=1000, message="You are doing that too fast!"))


def scope_func():
    """Return the identity a request is rate-limited under.

    The identity is the client's remote address. For logged-in users it
    becomes ``"<remote_addr>/<user_id>"``.
    """
    scope = str(request.remote_addr)
    # Treat a missing ``g.logged_in`` as anonymous instead of failing the
    # request with AttributeError.
    if getattr(g, 'logged_in', False):
        scope += "/%s" % (session["user_id"])
    return scope


def ratelimit(limit, per=300, send_x_headers=True, methods=("POST",),
              over_limit=on_over_limit, scope_func=scope_func,
              key_func=lambda: request.endpoint):
    """Decorate a view so that requests using ``methods`` are rate limited.

    :param limit: requests allowed per window per (key, scope) pair.
    :param per: window length in seconds.
    :param send_x_headers: passed through to :class:`RateLimit`.
    :param methods: HTTP methods that are counted; other methods bypass the
        limiter entirely. Any container supporting ``in`` works.
    :param over_limit: callable receiving the :class:`RateLimit` and
        returning the response to send when over the limit. Pass ``None``
        to count requests without ever rejecting them.
    :param scope_func: returns the caller identity part of the Redis key.
    :param key_func: returns the view identity part of the Redis key
        (defaults to the Flask endpoint name).
    """
    def decorator(f):
        def rate_limited(*args, **kwargs):
            if request.method in methods:
                key = 'rate-limit/%s/%s/' % (key_func(), scope_func())
                rlimit = RateLimit(key, limit, per, send_x_headers)
                # Expose the limiter so other code can read it through
                # get_view_rate_limit().
                g._view_rate_limit = rlimit
                if over_limit is not None and rlimit.over_limit:
                    return over_limit(rlimit)
            return f(*args, **kwargs)
        return update_wrapper(rate_limited, f)
    return decorator