mirror of https://github.com/JohnHammond/CTFd.git
Disable foreign keys during import (#1031)
* Temporarily disable foreign keys in MySQL, MariaDB, and Postgres during `import_ctf()` * Likely also disables SQLite but SQLite is permissive about foreign keys to begin with so irrelevant.selenium-screenshot-testing
parent
ff0f2c2a0b
commit
447f71d41b
|
@ -143,6 +143,14 @@ def import_ctf(backup, erase=True):
|
||||||
if info.file_size > max_content_length:
|
if info.file_size > max_content_length:
|
||||||
raise zipfile.LargeZipFile
|
raise zipfile.LargeZipFile
|
||||||
|
|
||||||
|
try:
|
||||||
|
if postgres:
|
||||||
|
side_db.query("SET session_replication_role=replica;")
|
||||||
|
else:
|
||||||
|
side_db.query("SET FOREIGN_KEY_CHECKS=0;")
|
||||||
|
except Exception:
|
||||||
|
print("Failed to disable foreign key checks. Continuing.")
|
||||||
|
|
||||||
first = [
|
first = [
|
||||||
"db/teams.json",
|
"db/teams.json",
|
||||||
"db/users.json",
|
"db/users.json",
|
||||||
|
@ -280,6 +288,14 @@ def import_ctf(backup, erase=True):
|
||||||
app.db.create_all()
|
app.db.create_all()
|
||||||
stamp_latest_revision()
|
stamp_latest_revision()
|
||||||
|
|
||||||
|
try:
|
||||||
|
if postgres:
|
||||||
|
side_db.query("SET session_replication_role=DEFAULT;")
|
||||||
|
else:
|
||||||
|
side_db.query("SET FOREIGN_KEY_CHECKS=1;")
|
||||||
|
except Exception:
|
||||||
|
print("Failed to enable foreign key checks. Continuing.")
|
||||||
|
|
||||||
# Invalidate all cached data
|
# Invalidate all cached data
|
||||||
cache.clear()
|
cache.clear()
|
||||||
|
|
||||||
|
|
|
@ -7,9 +7,10 @@ from tests.helpers import (
|
||||||
gen_challenge,
|
gen_challenge,
|
||||||
gen_flag,
|
gen_flag,
|
||||||
gen_user,
|
gen_user,
|
||||||
|
gen_team,
|
||||||
gen_hint,
|
gen_hint,
|
||||||
)
|
)
|
||||||
from CTFd.models import Challenges, Flags, Users
|
from CTFd.models import Challenges, Flags, Users, Teams
|
||||||
from CTFd.utils import text_type
|
from CTFd.utils import text_type
|
||||||
from CTFd.utils.exports import import_ctf, export_ctf
|
from CTFd.utils.exports import import_ctf, export_ctf
|
||||||
import json
|
import json
|
||||||
|
@ -61,6 +62,12 @@ def test_import_ctf():
|
||||||
user_email = user + "@ctfd.io"
|
user_email = user + "@ctfd.io"
|
||||||
gen_user(app.db, name=user, email=user_email)
|
gen_user(app.db, name=user, email=user_email)
|
||||||
|
|
||||||
|
base_team = "team"
|
||||||
|
for x in range(5):
|
||||||
|
team = base_team + str(x)
|
||||||
|
team_email = team + "@ctfd.io"
|
||||||
|
gen_team(app.db, name=team, email=team_email)
|
||||||
|
|
||||||
for x in range(9):
|
for x in range(9):
|
||||||
chal = gen_challenge(app.db, name="chal_name{}".format(x))
|
chal = gen_challenge(app.db, name="chal_name{}".format(x))
|
||||||
gen_flag(app.db, challenge_id=chal.id, content="flag")
|
gen_flag(app.db, challenge_id=chal.id, content="flag")
|
||||||
|
@ -86,7 +93,8 @@ def test_import_ctf():
|
||||||
|
|
||||||
if not app.config.get("SQLALCHEMY_DATABASE_URI").startswith("postgres"):
|
if not app.config.get("SQLALCHEMY_DATABASE_URI").startswith("postgres"):
|
||||||
# TODO: Dig deeper into why Postgres fails here
|
# TODO: Dig deeper into why Postgres fails here
|
||||||
assert Users.query.count() == 11
|
assert Users.query.count() == 31
|
||||||
|
assert Teams.query.count() == 5
|
||||||
assert Challenges.query.count() == 10
|
assert Challenges.query.count() == 10
|
||||||
assert Flags.query.count() == 10
|
assert Flags.query.count() == 10
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue