Disable foreign keys during import (#1031)

* Temporarily disable foreign keys in MySQL, MariaDB, and Postgres during `import_ctf()`
    * Likely also disables SQLite but SQLite is permissive about foreign keys to begin with so irrelevant.
selenium-screenshot-testing
Kevin Chung 2019-06-21 12:35:55 -04:00 committed by GitHub
parent ff0f2c2a0b
commit 447f71d41b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 26 additions and 2 deletions

View File

@ -143,6 +143,14 @@ def import_ctf(backup, erase=True):
if info.file_size > max_content_length: if info.file_size > max_content_length:
raise zipfile.LargeZipFile raise zipfile.LargeZipFile
try:
if postgres:
side_db.query("SET session_replication_role=replica;")
else:
side_db.query("SET FOREIGN_KEY_CHECKS=0;")
except Exception:
print("Failed to disable foreign key checks. Continuing.")
first = [ first = [
"db/teams.json", "db/teams.json",
"db/users.json", "db/users.json",
@ -280,6 +288,14 @@ def import_ctf(backup, erase=True):
app.db.create_all() app.db.create_all()
stamp_latest_revision() stamp_latest_revision()
try:
if postgres:
side_db.query("SET session_replication_role=DEFAULT;")
else:
side_db.query("SET FOREIGN_KEY_CHECKS=1;")
except Exception:
print("Failed to enable foreign key checks. Continuing.")
# Invalidate all cached data # Invalidate all cached data
cache.clear() cache.clear()

View File

@ -7,9 +7,10 @@ from tests.helpers import (
gen_challenge, gen_challenge,
gen_flag, gen_flag,
gen_user, gen_user,
gen_team,
gen_hint, gen_hint,
) )
from CTFd.models import Challenges, Flags, Users from CTFd.models import Challenges, Flags, Users, Teams
from CTFd.utils import text_type from CTFd.utils import text_type
from CTFd.utils.exports import import_ctf, export_ctf from CTFd.utils.exports import import_ctf, export_ctf
import json import json
@ -61,6 +62,12 @@ def test_import_ctf():
user_email = user + "@ctfd.io" user_email = user + "@ctfd.io"
gen_user(app.db, name=user, email=user_email) gen_user(app.db, name=user, email=user_email)
base_team = "team"
for x in range(5):
team = base_team + str(x)
team_email = team + "@ctfd.io"
gen_team(app.db, name=team, email=team_email)
for x in range(9): for x in range(9):
chal = gen_challenge(app.db, name="chal_name{}".format(x)) chal = gen_challenge(app.db, name="chal_name{}".format(x))
gen_flag(app.db, challenge_id=chal.id, content="flag") gen_flag(app.db, challenge_id=chal.id, content="flag")
@ -86,7 +93,8 @@ def test_import_ctf():
if not app.config.get("SQLALCHEMY_DATABASE_URI").startswith("postgres"): if not app.config.get("SQLALCHEMY_DATABASE_URI").startswith("postgres"):
# TODO: Dig deeper into why Postgres fails here # TODO: Dig deeper into why Postgres fails here
assert Users.query.count() == 11 assert Users.query.count() == 31
assert Teams.query.count() == 5
assert Challenges.query.count() == 10 assert Challenges.query.count() == 10
assert Flags.query.count() == 10 assert Flags.query.count() == 10