Start to refactor tracker to cache user IPs

cache-user-ips-for-tracker
Kevin Chung 2020-04-29 18:30:17 -04:00
parent f7e7c3c337
commit 817b67d1b0
2 changed files with 27 additions and 5 deletions

View File

@ -38,7 +38,14 @@ from CTFd.utils.plugins import (
) )
from CTFd.utils.security.auth import login_user, logout_user, lookup_user_token from CTFd.utils.security.auth import login_user, logout_user, lookup_user_token
from CTFd.utils.security.csrf import generate_nonce from CTFd.utils.security.csrf import generate_nonce
from CTFd.utils.user import authed, get_current_team, get_current_user, get_ip, is_admin from CTFd.utils.user import (
authed,
get_current_team,
get_current_user,
get_ip,
get_user_ips,
is_admin,
)
def init_template_filters(app): def init_template_filters(app):
@ -170,12 +177,17 @@ def init_request_processors(app):
return return
if authed(): if authed():
track = Tracking.query.filter_by(ip=get_ip(), user_id=session["id"]).first() user_ips = get_user_ips(user_id=session["id"])
if not track: ip = get_ip()
if ip not in user_ips:
visit = Tracking(ip=get_ip(), user_id=session["id"]) visit = Tracking(ip=get_ip(), user_id=session["id"])
db.session.add(visit) db.session.add(visit)
else: else:
track.date = datetime.datetime.utcnow() if request.method != "GET":
track = Tracking.query.filter_by(
ip=get_ip(), user_id=session["id"]
).first()
track.date = datetime.datetime.utcnow()
try: try:
db.session.commit() db.session.commit()

View File

@ -4,7 +4,8 @@ import re
from flask import current_app as app from flask import current_app as app
from flask import request, session from flask import request, session
from CTFd.models import Fails, Users, db from CTFd.cache import cache
from CTFd.models import Fails, Users, db, Tracking
from CTFd.utils import get_config from CTFd.utils import get_config
@ -80,6 +81,15 @@ def get_ip(req=None):
return remote_addr return remote_addr
def get_user_ips(user_id):
addrs = (
Tracking.query.with_entities(Tracking.ip.distinct())
.filter_by(user_id=user_id)
.all()
)
return [ip for ip, in addrs]
def get_wrong_submissions_per_minute(account_id): def get_wrong_submissions_per_minute(account_id):
""" """
Get incorrect submissions per minute. Get incorrect submissions per minute.