mirror of https://github.com/JohnHammond/CTFd.git
Fix reset tests and enforce foreign keys on sqlite
parent
6092ed1f31
commit
1d33ed4cb2
|
@ -179,10 +179,9 @@ def create_app(config="CTFd.config.Config"):
|
|||
|
||||
# Alembic sqlite support is lacking so we should just create_all anyway
|
||||
if url.drivername.startswith("sqlite"):
|
||||
db.create_all()
|
||||
stamp_latest_revision()
|
||||
|
||||
# Enable foreign keys for SQLite
|
||||
# Enable foreign keys for SQLite. This must be before the
|
||||
# db.create_all call because tests use the in-memory SQLite
|
||||
# database (each connection, including db creation, is a new db).
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy import event
|
||||
|
||||
|
@ -192,6 +191,8 @@ def create_app(config="CTFd.config.Config"):
|
|||
cursor.execute("PRAGMA foreign_keys=ON")
|
||||
cursor.close()
|
||||
|
||||
db.create_all()
|
||||
stamp_latest_revision()
|
||||
else:
|
||||
# This creates tables instead of db.create_all()
|
||||
# Allows migrations to happen properly
|
||||
|
|
|
@ -1,6 +1,21 @@
|
|||
import random
|
||||
|
||||
from CTFd.models import Challenges, Fails, Solves, Teams, Tracking, Users
|
||||
from CTFd.models import (
|
||||
Awards,
|
||||
Challenges,
|
||||
Fails,
|
||||
Flags,
|
||||
Hints,
|
||||
Notifications,
|
||||
Pages,
|
||||
Solves,
|
||||
Submissions,
|
||||
Tags,
|
||||
Teams,
|
||||
Tracking,
|
||||
Unlocks,
|
||||
Users,
|
||||
)
|
||||
from tests.helpers import (
|
||||
create_ctfd,
|
||||
destroy_ctfd,
|
||||
|
@ -8,12 +23,12 @@ from tests.helpers import (
|
|||
gen_challenge,
|
||||
gen_fail,
|
||||
gen_flag,
|
||||
gen_hint,
|
||||
gen_solve,
|
||||
gen_team,
|
||||
gen_tracking,
|
||||
gen_user,
|
||||
login_as_user,
|
||||
register_user,
|
||||
)
|
||||
|
||||
|
||||
|
@ -25,6 +40,7 @@ def test_reset():
|
|||
for x in range(10):
|
||||
chal = gen_challenge(app.db, name="chal_name{}".format(x))
|
||||
gen_flag(app.db, challenge_id=chal.id, content="flag")
|
||||
gen_hint(app.db, challenge_id=chal.id)
|
||||
|
||||
for x in range(10):
|
||||
user = base_user + str(x)
|
||||
|
@ -37,16 +53,62 @@ def test_reset():
|
|||
|
||||
assert Users.query.count() == 11 # 11 because of the first admin user
|
||||
assert Challenges.query.count() == 10
|
||||
assert Flags.query.count() == 10
|
||||
assert Hints.query.count() == 10
|
||||
assert Submissions.query.count() == 20
|
||||
assert Pages.query.count() == 1
|
||||
assert Tracking.query.count() == 10
|
||||
|
||||
register_user(app)
|
||||
client = login_as_user(app, name="admin", password="password")
|
||||
|
||||
with client.session_transaction() as sess:
|
||||
data = {"nonce": sess.get("nonce")}
|
||||
client.post("/admin/reset", data=data)
|
||||
|
||||
assert Users.query.count() == 0
|
||||
data = {"nonce": sess.get("nonce"), "submissions": "on"}
|
||||
r = client.post("/admin/reset", data=data)
|
||||
assert r.location.endswith("/admin/statistics")
|
||||
assert Submissions.query.count() == 0
|
||||
assert Solves.query.count() == 0
|
||||
assert Fails.query.count() == 0
|
||||
assert Awards.query.count() == 0
|
||||
assert Unlocks.query.count() == 0
|
||||
assert Users.query.count() == 11
|
||||
assert Challenges.query.count() == 10
|
||||
assert Flags.query.count() == 10
|
||||
assert Tracking.query.count() == 11
|
||||
|
||||
with client.session_transaction() as sess:
|
||||
data = {"nonce": sess.get("nonce"), "pages": "on"}
|
||||
r = client.post("/admin/reset", data=data)
|
||||
assert r.location.endswith("/admin/statistics")
|
||||
assert Pages.query.count() == 0
|
||||
assert Users.query.count() == 11
|
||||
assert Challenges.query.count() == 10
|
||||
assert Tracking.query.count() == 11
|
||||
|
||||
with client.session_transaction() as sess:
|
||||
data = {"nonce": sess.get("nonce"), "notifications": "on"}
|
||||
r = client.post("/admin/reset", data=data)
|
||||
assert r.location.endswith("/admin/statistics")
|
||||
assert Notifications.query.count() == 0
|
||||
assert Users.query.count() == 11
|
||||
assert Challenges.query.count() == 10
|
||||
assert Tracking.query.count() == 11
|
||||
|
||||
with client.session_transaction() as sess:
|
||||
data = {"nonce": sess.get("nonce"), "challenges": "on"}
|
||||
r = client.post("/admin/reset", data=data)
|
||||
assert r.location.endswith("/admin/statistics")
|
||||
assert Challenges.query.count() == 0
|
||||
assert Flags.query.count() == 0
|
||||
assert Hints.query.count() == 0
|
||||
assert Tags.query.count() == 0
|
||||
assert Users.query.count() == 11
|
||||
assert Tracking.query.count() == 11
|
||||
|
||||
with client.session_transaction() as sess:
|
||||
data = {"nonce": sess.get("nonce"), "accounts": "on"}
|
||||
r = client.post("/admin/reset", data=data)
|
||||
assert r.location.endswith("/setup")
|
||||
assert Users.query.count() == 0
|
||||
assert Solves.query.count() == 0
|
||||
assert Fails.query.count() == 0
|
||||
assert Tracking.query.count() == 0
|
||||
|
@ -62,6 +124,7 @@ def test_reset_team_mode():
|
|||
for x in range(10):
|
||||
chal = gen_challenge(app.db, name="chal_name{}".format(x))
|
||||
gen_flag(app.db, challenge_id=chal.id, content="flag")
|
||||
gen_hint(app.db, challenge_id=chal.id)
|
||||
|
||||
for x in range(10):
|
||||
user = base_user + str(x)
|
||||
|
@ -79,21 +142,69 @@ def test_reset_team_mode():
|
|||
gen_tracking(app.db, user_id=user_obj.id)
|
||||
|
||||
assert Teams.query.count() == 10
|
||||
assert (
|
||||
Users.query.count() == 51
|
||||
) # 10 random users, 40 users (10 teams * 4), 1 admin user
|
||||
# 10 random users, 40 users (10 teams * 4), 1 admin user
|
||||
assert Users.query.count() == 51
|
||||
assert Challenges.query.count() == 10
|
||||
assert Submissions.query.count() == 20
|
||||
assert Solves.query.count() == 10
|
||||
assert Fails.query.count() == 10
|
||||
assert Tracking.query.count() == 10
|
||||
|
||||
register_user(app)
|
||||
client = login_as_user(app, name="admin", password="password")
|
||||
|
||||
with client.session_transaction() as sess:
|
||||
data = {"nonce": sess.get("nonce")}
|
||||
client.post("/admin/reset", data=data)
|
||||
|
||||
assert Teams.query.count() == 0
|
||||
assert Users.query.count() == 0
|
||||
data = {"nonce": sess.get("nonce"), "submissions": "on"}
|
||||
r = client.post("/admin/reset", data=data)
|
||||
assert r.location.endswith("/admin/statistics")
|
||||
assert Submissions.query.count() == 0
|
||||
assert Solves.query.count() == 0
|
||||
assert Fails.query.count() == 0
|
||||
assert Awards.query.count() == 0
|
||||
assert Unlocks.query.count() == 0
|
||||
assert Teams.query.count() == 10
|
||||
assert Users.query.count() == 51
|
||||
assert Challenges.query.count() == 10
|
||||
assert Flags.query.count() == 10
|
||||
assert Tracking.query.count() == 11
|
||||
|
||||
with client.session_transaction() as sess:
|
||||
data = {"nonce": sess.get("nonce"), "pages": "on"}
|
||||
r = client.post("/admin/reset", data=data)
|
||||
assert r.location.endswith("/admin/statistics")
|
||||
assert Pages.query.count() == 0
|
||||
assert Teams.query.count() == 10
|
||||
assert Users.query.count() == 51
|
||||
assert Challenges.query.count() == 10
|
||||
assert Tracking.query.count() == 11
|
||||
|
||||
with client.session_transaction() as sess:
|
||||
data = {"nonce": sess.get("nonce"), "notifications": "on"}
|
||||
r = client.post("/admin/reset", data=data)
|
||||
assert r.location.endswith("/admin/statistics")
|
||||
assert Notifications.query.count() == 0
|
||||
assert Teams.query.count() == 10
|
||||
assert Users.query.count() == 51
|
||||
assert Challenges.query.count() == 10
|
||||
assert Tracking.query.count() == 11
|
||||
|
||||
with client.session_transaction() as sess:
|
||||
data = {"nonce": sess.get("nonce"), "challenges": "on"}
|
||||
r = client.post("/admin/reset", data=data)
|
||||
assert r.location.endswith("/admin/statistics")
|
||||
assert Challenges.query.count() == 0
|
||||
assert Flags.query.count() == 0
|
||||
assert Hints.query.count() == 0
|
||||
assert Tags.query.count() == 0
|
||||
assert Teams.query.count() == 10
|
||||
assert Users.query.count() == 51
|
||||
assert Tracking.query.count() == 11
|
||||
|
||||
with client.session_transaction() as sess:
|
||||
data = {"nonce": sess.get("nonce"), "accounts": "on"}
|
||||
r = client.post("/admin/reset", data=data)
|
||||
assert r.location.endswith("/setup")
|
||||
assert Users.query.count() == 0
|
||||
assert Teams.query.count() == 0
|
||||
assert Solves.query.count() == 0
|
||||
assert Fails.query.count() == 0
|
||||
assert Tracking.query.count() == 0
|
||||
|
|
Loading…
Reference in New Issue