Fix reset tests and enforce foreign keys on sqlite

table-granular-admin-reset
Kevin Chung 2020-04-28 02:08:05 -04:00
parent 6092ed1f31
commit 1d33ed4cb2
2 changed files with 132 additions and 20 deletions

View File

@ -179,10 +179,9 @@ def create_app(config="CTFd.config.Config"):
# Alembic sqlite support is lacking so we should just create_all anyway # Alembic sqlite support is lacking so we should just create_all anyway
if url.drivername.startswith("sqlite"): if url.drivername.startswith("sqlite"):
db.create_all() # Enable foreign keys for SQLite. This must be before the
stamp_latest_revision() # db.create_all call because tests use the in-memory SQLite
# database (each connection, including db creation, is a new db).
# Enable foreign keys for SQLite
from sqlalchemy.engine import Engine from sqlalchemy.engine import Engine
from sqlalchemy import event from sqlalchemy import event
@ -192,6 +191,8 @@ def create_app(config="CTFd.config.Config"):
cursor.execute("PRAGMA foreign_keys=ON") cursor.execute("PRAGMA foreign_keys=ON")
cursor.close() cursor.close()
db.create_all()
stamp_latest_revision()
else: else:
# This creates tables instead of db.create_all() # This creates tables instead of db.create_all()
# Allows migrations to happen properly # Allows migrations to happen properly

View File

@ -1,6 +1,21 @@
import random import random
from CTFd.models import Challenges, Fails, Solves, Teams, Tracking, Users from CTFd.models import (
Awards,
Challenges,
Fails,
Flags,
Hints,
Notifications,
Pages,
Solves,
Submissions,
Tags,
Teams,
Tracking,
Unlocks,
Users,
)
from tests.helpers import ( from tests.helpers import (
create_ctfd, create_ctfd,
destroy_ctfd, destroy_ctfd,
@ -8,12 +23,12 @@ from tests.helpers import (
gen_challenge, gen_challenge,
gen_fail, gen_fail,
gen_flag, gen_flag,
gen_hint,
gen_solve, gen_solve,
gen_team, gen_team,
gen_tracking, gen_tracking,
gen_user, gen_user,
login_as_user, login_as_user,
register_user,
) )
@ -25,6 +40,7 @@ def test_reset():
for x in range(10): for x in range(10):
chal = gen_challenge(app.db, name="chal_name{}".format(x)) chal = gen_challenge(app.db, name="chal_name{}".format(x))
gen_flag(app.db, challenge_id=chal.id, content="flag") gen_flag(app.db, challenge_id=chal.id, content="flag")
gen_hint(app.db, challenge_id=chal.id)
for x in range(10): for x in range(10):
user = base_user + str(x) user = base_user + str(x)
@ -37,16 +53,62 @@ def test_reset():
assert Users.query.count() == 11 # 11 because of the first admin user assert Users.query.count() == 11 # 11 because of the first admin user
assert Challenges.query.count() == 10 assert Challenges.query.count() == 10
assert Flags.query.count() == 10
assert Hints.query.count() == 10
assert Submissions.query.count() == 20
assert Pages.query.count() == 1
assert Tracking.query.count() == 10
register_user(app)
client = login_as_user(app, name="admin", password="password") client = login_as_user(app, name="admin", password="password")
with client.session_transaction() as sess: with client.session_transaction() as sess:
data = {"nonce": sess.get("nonce")} data = {"nonce": sess.get("nonce"), "submissions": "on"}
client.post("/admin/reset", data=data) r = client.post("/admin/reset", data=data)
assert r.location.endswith("/admin/statistics")
assert Users.query.count() == 0 assert Submissions.query.count() == 0
assert Solves.query.count() == 0
assert Fails.query.count() == 0
assert Awards.query.count() == 0
assert Unlocks.query.count() == 0
assert Users.query.count() == 11
assert Challenges.query.count() == 10 assert Challenges.query.count() == 10
assert Flags.query.count() == 10
assert Tracking.query.count() == 11
with client.session_transaction() as sess:
data = {"nonce": sess.get("nonce"), "pages": "on"}
r = client.post("/admin/reset", data=data)
assert r.location.endswith("/admin/statistics")
assert Pages.query.count() == 0
assert Users.query.count() == 11
assert Challenges.query.count() == 10
assert Tracking.query.count() == 11
with client.session_transaction() as sess:
data = {"nonce": sess.get("nonce"), "notifications": "on"}
r = client.post("/admin/reset", data=data)
assert r.location.endswith("/admin/statistics")
assert Notifications.query.count() == 0
assert Users.query.count() == 11
assert Challenges.query.count() == 10
assert Tracking.query.count() == 11
with client.session_transaction() as sess:
data = {"nonce": sess.get("nonce"), "challenges": "on"}
r = client.post("/admin/reset", data=data)
assert r.location.endswith("/admin/statistics")
assert Challenges.query.count() == 0
assert Flags.query.count() == 0
assert Hints.query.count() == 0
assert Tags.query.count() == 0
assert Users.query.count() == 11
assert Tracking.query.count() == 11
with client.session_transaction() as sess:
data = {"nonce": sess.get("nonce"), "accounts": "on"}
r = client.post("/admin/reset", data=data)
assert r.location.endswith("/setup")
assert Users.query.count() == 0
assert Solves.query.count() == 0 assert Solves.query.count() == 0
assert Fails.query.count() == 0 assert Fails.query.count() == 0
assert Tracking.query.count() == 0 assert Tracking.query.count() == 0
@ -62,6 +124,7 @@ def test_reset_team_mode():
for x in range(10): for x in range(10):
chal = gen_challenge(app.db, name="chal_name{}".format(x)) chal = gen_challenge(app.db, name="chal_name{}".format(x))
gen_flag(app.db, challenge_id=chal.id, content="flag") gen_flag(app.db, challenge_id=chal.id, content="flag")
gen_hint(app.db, challenge_id=chal.id)
for x in range(10): for x in range(10):
user = base_user + str(x) user = base_user + str(x)
@ -79,21 +142,69 @@ def test_reset_team_mode():
gen_tracking(app.db, user_id=user_obj.id) gen_tracking(app.db, user_id=user_obj.id)
assert Teams.query.count() == 10 assert Teams.query.count() == 10
assert ( # 10 random users, 40 users (10 teams * 4), 1 admin user
Users.query.count() == 51 assert Users.query.count() == 51
) # 10 random users, 40 users (10 teams * 4), 1 admin user
assert Challenges.query.count() == 10 assert Challenges.query.count() == 10
assert Submissions.query.count() == 20
assert Solves.query.count() == 10
assert Fails.query.count() == 10
assert Tracking.query.count() == 10
register_user(app)
client = login_as_user(app, name="admin", password="password") client = login_as_user(app, name="admin", password="password")
with client.session_transaction() as sess: with client.session_transaction() as sess:
data = {"nonce": sess.get("nonce")} data = {"nonce": sess.get("nonce"), "submissions": "on"}
client.post("/admin/reset", data=data) r = client.post("/admin/reset", data=data)
assert r.location.endswith("/admin/statistics")
assert Teams.query.count() == 0 assert Submissions.query.count() == 0
assert Users.query.count() == 0 assert Solves.query.count() == 0
assert Fails.query.count() == 0
assert Awards.query.count() == 0
assert Unlocks.query.count() == 0
assert Teams.query.count() == 10
assert Users.query.count() == 51
assert Challenges.query.count() == 10 assert Challenges.query.count() == 10
assert Flags.query.count() == 10
assert Tracking.query.count() == 11
with client.session_transaction() as sess:
data = {"nonce": sess.get("nonce"), "pages": "on"}
r = client.post("/admin/reset", data=data)
assert r.location.endswith("/admin/statistics")
assert Pages.query.count() == 0
assert Teams.query.count() == 10
assert Users.query.count() == 51
assert Challenges.query.count() == 10
assert Tracking.query.count() == 11
with client.session_transaction() as sess:
data = {"nonce": sess.get("nonce"), "notifications": "on"}
r = client.post("/admin/reset", data=data)
assert r.location.endswith("/admin/statistics")
assert Notifications.query.count() == 0
assert Teams.query.count() == 10
assert Users.query.count() == 51
assert Challenges.query.count() == 10
assert Tracking.query.count() == 11
with client.session_transaction() as sess:
data = {"nonce": sess.get("nonce"), "challenges": "on"}
r = client.post("/admin/reset", data=data)
assert r.location.endswith("/admin/statistics")
assert Challenges.query.count() == 0
assert Flags.query.count() == 0
assert Hints.query.count() == 0
assert Tags.query.count() == 0
assert Teams.query.count() == 10
assert Users.query.count() == 51
assert Tracking.query.count() == 11
with client.session_transaction() as sess:
data = {"nonce": sess.get("nonce"), "accounts": "on"}
r = client.post("/admin/reset", data=data)
assert r.location.endswith("/setup")
assert Users.query.count() == 0
assert Teams.query.count() == 0
assert Solves.query.count() == 0 assert Solves.query.count() == 0
assert Fails.query.count() == 0 assert Fails.query.count() == 0
assert Tracking.query.count() == 0 assert Tracking.query.count() == 0