From 447f71d41b46912e383b4045e9395063421fd939 Mon Sep 17 00:00:00 2001 From: Kevin Chung Date: Fri, 21 Jun 2019 12:35:55 -0400 Subject: [PATCH] Disable foreign keys during import (#1031) * Temporarily disable foreign keys in MySQL, MariaDB, and Postgres during `import_ctf()` * Likely also disables SQLite but SQLite is permissive about foreign keys to begin with so irrelevant. --- CTFd/utils/exports/__init__.py | 16 ++++++++++++++++ tests/utils/test_exports.py | 12 ++++++++++-- 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/CTFd/utils/exports/__init__.py b/CTFd/utils/exports/__init__.py index d994a69..14cf014 100644 --- a/CTFd/utils/exports/__init__.py +++ b/CTFd/utils/exports/__init__.py @@ -143,6 +143,14 @@ def import_ctf(backup, erase=True): if info.file_size > max_content_length: raise zipfile.LargeZipFile + try: + if postgres: + side_db.query("SET session_replication_role=replica;") + else: + side_db.query("SET FOREIGN_KEY_CHECKS=0;") + except Exception: + print("Failed to disable foreign key checks. Continuing.") + first = [ "db/teams.json", "db/users.json", @@ -280,6 +288,14 @@ def import_ctf(backup, erase=True): app.db.create_all() stamp_latest_revision() + try: + if postgres: + side_db.query("SET session_replication_role=DEFAULT;") + else: + side_db.query("SET FOREIGN_KEY_CHECKS=1;") + except Exception: + print("Failed to enable foreign key checks. Continuing.") + # Invalidate all cached data cache.clear() diff --git a/tests/utils/test_exports.py b/tests/utils/test_exports.py index 0859e85..98a592f 100644 --- a/tests/utils/test_exports.py +++ b/tests/utils/test_exports.py @@ -7,9 +7,10 @@ from tests.helpers import ( gen_challenge, gen_flag, gen_user, + gen_team, gen_hint, ) -from CTFd.models import Challenges, Flags, Users +from CTFd.models import Challenges, Flags, Users, Teams from CTFd.utils import text_type from CTFd.utils.exports import import_ctf, export_ctf import json @@ -61,6 +62,12 @@ def test_import_ctf(): user_email = user + "@ctfd.io" gen_user(app.db, name=user, email=user_email) + base_team = "team" + for x in range(5): + team = base_team + str(x) + team_email = team + "@ctfd.io" + gen_team(app.db, name=team, email=team_email) + for x in range(9): chal = gen_challenge(app.db, name="chal_name{}".format(x)) gen_flag(app.db, challenge_id=chal.id, content="flag") @@ -86,7 +93,8 @@ def test_import_ctf(): if not app.config.get("SQLALCHEMY_DATABASE_URI").startswith("postgres"): # TODO: Dig deeper into why Postgres fails here - assert Users.query.count() == 11 + assert Users.query.count() == 31 + assert Teams.query.count() == 5 assert Challenges.query.count() == 10 assert Flags.query.count() == 10