diff --git a/CTFd/__init__.py b/CTFd/__init__.py index 0005486..dbc0b66 100644 --- a/CTFd/__init__.py +++ b/CTFd/__init__.py @@ -179,10 +179,9 @@ def create_app(config="CTFd.config.Config"): # Alembic sqlite support is lacking so we should just create_all anyway if url.drivername.startswith("sqlite"): - db.create_all() - stamp_latest_revision() - - # Enable foreign keys for SQLite + # Enable foreign keys for SQLite. This must be before the + # db.create_all call because tests use the in-memory SQLite + # database (each connection, including db creation, is a new db). from sqlalchemy.engine import Engine from sqlalchemy import event @@ -192,6 +191,8 @@ def create_app(config="CTFd.config.Config"): cursor.execute("PRAGMA foreign_keys=ON") cursor.close() + db.create_all() + stamp_latest_revision() else: # This creates tables instead of db.create_all() # Allows migrations to happen properly diff --git a/tests/admin/test_config.py b/tests/admin/test_config.py index 5451351..43578d3 100644 --- a/tests/admin/test_config.py +++ b/tests/admin/test_config.py @@ -1,6 +1,21 @@ import random -from CTFd.models import Challenges, Fails, Solves, Teams, Tracking, Users +from CTFd.models import ( + Awards, + Challenges, + Fails, + Flags, + Hints, + Notifications, + Pages, + Solves, + Submissions, + Tags, + Teams, + Tracking, + Unlocks, + Users, +) from tests.helpers import ( create_ctfd, destroy_ctfd, @@ -8,12 +23,12 @@ from tests.helpers import ( gen_challenge, gen_fail, gen_flag, + gen_hint, gen_solve, gen_team, gen_tracking, gen_user, login_as_user, - register_user, ) @@ -25,6 +40,7 @@ def test_reset(): for x in range(10): chal = gen_challenge(app.db, name="chal_name{}".format(x)) gen_flag(app.db, challenge_id=chal.id, content="flag") + gen_hint(app.db, challenge_id=chal.id) for x in range(10): user = base_user + str(x) @@ -37,16 +53,62 @@ def test_reset(): assert Users.query.count() == 11 # 11 because of the first admin user assert Challenges.query.count() == 10 + assert Flags.query.count() == 10 + assert Hints.query.count() == 10 + assert Submissions.query.count() == 20 + assert Pages.query.count() == 1 + assert Tracking.query.count() == 10 - register_user(app) client = login_as_user(app, name="admin", password="password") with client.session_transaction() as sess: - data = {"nonce": sess.get("nonce")} - client.post("/admin/reset", data=data) - - assert Users.query.count() == 0 + data = {"nonce": sess.get("nonce"), "submissions": "on"} + r = client.post("/admin/reset", data=data) + assert r.location.endswith("/admin/statistics") + assert Submissions.query.count() == 0 + assert Solves.query.count() == 0 + assert Fails.query.count() == 0 + assert Awards.query.count() == 0 + assert Unlocks.query.count() == 0 + assert Users.query.count() == 11 assert Challenges.query.count() == 10 + assert Flags.query.count() == 10 + assert Tracking.query.count() == 11 + + with client.session_transaction() as sess: + data = {"nonce": sess.get("nonce"), "pages": "on"} + r = client.post("/admin/reset", data=data) + assert r.location.endswith("/admin/statistics") + assert Pages.query.count() == 0 + assert Users.query.count() == 11 + assert Challenges.query.count() == 10 + assert Tracking.query.count() == 11 + + with client.session_transaction() as sess: + data = {"nonce": sess.get("nonce"), "notifications": "on"} + r = client.post("/admin/reset", data=data) + assert r.location.endswith("/admin/statistics") + assert Notifications.query.count() == 0 + assert Users.query.count() == 11 + assert Challenges.query.count() == 10 + assert Tracking.query.count() == 11 + + with client.session_transaction() as sess: + data = {"nonce": sess.get("nonce"), "challenges": "on"} + r = client.post("/admin/reset", data=data) + assert r.location.endswith("/admin/statistics") + assert Challenges.query.count() == 0 + assert Flags.query.count() == 0 + assert Hints.query.count() == 0 + assert Tags.query.count() == 0 + assert Users.query.count() == 11 + assert Tracking.query.count() == 11 + + with client.session_transaction() as sess: + data = {"nonce": sess.get("nonce"), "accounts": "on"} + r = client.post("/admin/reset", data=data) + assert r.location.endswith("/setup") + assert Users.query.count() == 0 assert Solves.query.count() == 0 assert Fails.query.count() == 0 assert Tracking.query.count() == 0 @@ -62,6 +124,7 @@ def test_reset_team_mode(): for x in range(10): chal = gen_challenge(app.db, name="chal_name{}".format(x)) gen_flag(app.db, challenge_id=chal.id, content="flag") + gen_hint(app.db, challenge_id=chal.id) for x in range(10): user = base_user + str(x) @@ -79,21 +142,69 @@ def test_reset_team_mode(): gen_tracking(app.db, user_id=user_obj.id) assert Teams.query.count() == 10 - assert ( - Users.query.count() == 51 - ) # 10 random users, 40 users (10 teams * 4), 1 admin user + # 10 random users, 40 users (10 teams * 4), 1 admin user + assert Users.query.count() == 51 assert Challenges.query.count() == 10 + assert Submissions.query.count() == 20 + assert Solves.query.count() == 10 + assert Fails.query.count() == 10 + assert Tracking.query.count() == 10 - register_user(app) client = login_as_user(app, name="admin", password="password") with client.session_transaction() as sess: - data = {"nonce": sess.get("nonce")} - client.post("/admin/reset", data=data) - - assert Teams.query.count() == 0 - assert Users.query.count() == 0 + data = {"nonce": sess.get("nonce"), "submissions": "on"} + r = client.post("/admin/reset", data=data) + assert r.location.endswith("/admin/statistics") + assert Submissions.query.count() == 0 + assert Solves.query.count() == 0 + assert Fails.query.count() == 0 + assert Awards.query.count() == 0 + assert Unlocks.query.count() == 0 + assert Teams.query.count() == 10 + assert Users.query.count() == 51 assert Challenges.query.count() == 10 + assert Flags.query.count() == 10 + assert Tracking.query.count() == 11 + + with client.session_transaction() as sess: + data = {"nonce": sess.get("nonce"), "pages": "on"} + r = client.post("/admin/reset", data=data) + assert r.location.endswith("/admin/statistics") + assert Pages.query.count() == 0 + assert Teams.query.count() == 10 + assert Users.query.count() == 51 + assert Challenges.query.count() == 10 + assert Tracking.query.count() == 11 + + with client.session_transaction() as sess: + data = {"nonce": sess.get("nonce"), "notifications": "on"} + r = client.post("/admin/reset", data=data) + assert r.location.endswith("/admin/statistics") + assert Notifications.query.count() == 0 + assert Teams.query.count() == 10 + assert Users.query.count() == 51 + assert Challenges.query.count() == 10 + assert Tracking.query.count() == 11 + + with client.session_transaction() as sess: + data = {"nonce": sess.get("nonce"), "challenges": "on"} + r = client.post("/admin/reset", data=data) + assert r.location.endswith("/admin/statistics") + assert Challenges.query.count() == 0 + assert Flags.query.count() == 0 + assert Hints.query.count() == 0 + assert Tags.query.count() == 0 + assert Teams.query.count() == 10 + assert Users.query.count() == 51 + assert Tracking.query.count() == 11 + + with client.session_transaction() as sess: + data = {"nonce": sess.get("nonce"), "accounts": "on"} + r = client.post("/admin/reset", data=data) + assert r.location.endswith("/setup") + assert Users.query.count() == 0 + assert Teams.query.count() == 0 assert Solves.query.count() == 0 assert Fails.query.count() == 0 assert Tracking.query.count() == 0