From 49ed27cfd649ac505404e48e7d014faf7e760b7a Mon Sep 17 00:00:00 2001 From: Kevin Chung Date: Fri, 23 Nov 2018 06:10:33 -0500 Subject: [PATCH] Fix Uploaders to work with imports/exports (#749) * Refactor Uploaders to work better with imports/exports * Get S3 uploader working properly with imports/exports * cache pip in travis --- .travis.yml | 1 + CTFd/__init__.py | 2 -- CTFd/config.py | 5 ++- CTFd/utils/exports/__init__.py | 22 +++++--------- CTFd/utils/uploads/uploaders.py | 54 +++++++++++++++++++++++++-------- 5 files changed, 52 insertions(+), 32 deletions(-) diff --git a/.travis.yml b/.travis.yml index e2aef71..9e28644 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,4 +1,5 @@ language: python +cache: pip services: - mysql - postgresql diff --git a/CTFd/__init__.py b/CTFd/__init__.py index 0039385..c2212de 100644 --- a/CTFd/__init__.py +++ b/CTFd/__init__.py @@ -6,8 +6,6 @@ from flask import Flask from werkzeug.contrib.fixers import ProxyFix from jinja2 import FileSystemLoader from jinja2.sandbox import SandboxedEnvironment -from sqlalchemy.engine.url import make_url -from sqlalchemy_utils import database_exists, create_database from six.moves import input from CTFd import utils diff --git a/CTFd/config.py b/CTFd/config.py index 10f3fc7..8b3ef62 100644 --- a/CTFd/config.py +++ b/CTFd/config.py @@ -158,9 +158,8 @@ class Config(object): ''' UPLOAD_PROVIDER = os.getenv('UPLOAD_PROVIDER') or 'filesystem' - if UPLOAD_PROVIDER == 'filesystem': - UPLOAD_FOLDER = os.getenv('UPLOAD_FOLDER') or os.path.join(os.path.dirname(os.path.abspath(__file__)), 'uploads') - elif UPLOAD_PROVIDER == 's3': + UPLOAD_FOLDER = os.getenv('UPLOAD_FOLDER') or os.path.join(os.path.dirname(os.path.abspath(__file__)), 'uploads') + if UPLOAD_PROVIDER == 's3': AWS_ACCESS_KEY_ID = os.getenv('AWS_ACCESS_KEY_ID') AWS_SECRET_ACCESS_KEY = os.getenv('AWS_SECRET_ACCESS_KEY') AWS_S3_BUCKET = os.getenv('AWS_S3_BUCKET') diff --git a/CTFd/utils/exports/__init__.py b/CTFd/utils/exports/__init__.py index a1ef5e5..7caa28e 100644 --- a/CTFd/utils/exports/__init__.py +++ b/CTFd/utils/exports/__init__.py @@ -1,6 +1,7 @@ -from CTFd.utils import get_app_config, get_config, set_config +from CTFd.utils import get_app_config from CTFd.utils.migrations import get_current_revision, create_database, drop_database, upgrade, stamp -from CTFd.models import db, get_class_by_tablename +from CTFd.utils.uploads import get_uploader +from CTFd.models import db from CTFd.cache import cache from datafreeze.format import SERIALIZERS from flask import current_app as app @@ -12,7 +13,6 @@ import json import os import re import six -import shutil import zipfile @@ -85,6 +85,9 @@ def export_ctf(): backup_zip.writestr('db/alembic_version.json', result_file.read()) # Backup uploads + uploader = get_uploader() + uploader.sync() + upload_folder = os.path.join(os.path.normpath(app.root_path), app.config.get('UPLOAD_FOLDER')) for root, dirs, files in os.walk(upload_folder): for file in files: @@ -199,7 +202,7 @@ def import_ctf(backup, erase=True): # Extracting files files = [f for f in backup.namelist() if f.startswith('uploads/')] - upload_folder = app.config.get('UPLOAD_FOLDER') + uploader = get_uploader() for f in files: filename = f.split(os.sep, 1) @@ -207,16 +210,7 @@ def import_ctf(backup, erase=True): continue filename = filename[1] # Get the second entry in the list (the actual filename) - full_path = os.path.join(upload_folder, filename) - dirname = os.path.dirname(full_path) - - # Create any parent directories for the file - if not os.path.exists(dirname): - os.makedirs(dirname) - source = backup.open(f) - target = open(full_path, "wb") - with source, target: - shutil.copyfileobj(source, target) + uploader.store(fileobj=source, filename=filename) cache.clear() diff --git a/CTFd/utils/uploads/uploaders.py b/CTFd/utils/uploads/uploaders.py index 06a5d1c..31eca7f 100644 --- a/CTFd/utils/uploads/uploaders.py +++ b/CTFd/utils/uploads/uploaders.py @@ -13,6 +13,9 @@ class BaseUploader(object): def __init__(self): raise NotImplementedError + def store(self, fileobj, filename): + raise NotImplementedError + def upload(self, file_obj, filename): raise NotImplementedError @@ -22,32 +25,36 @@ class BaseUploader(object): def delete(self, filename): raise NotImplementedError + def sync(self): + raise NotImplementedError + class FilesystemUploader(BaseUploader): def __init__(self, base_path=None): super(BaseUploader, self).__init__() self.base_path = base_path or current_app.config.get('UPLOAD_FOLDER') + def store(self, fileobj, filename): + location = os.path.join(self.base_path, filename) + directory = os.path.dirname(location) + + if not os.path.exists(directory): + os.makedirs(directory) + + with open(location, 'wb') as dst: + copyfileobj(fileobj, dst, 16384) + + return filename + def upload(self, file_obj, filename): if len(filename) == 0: raise Exception('Empty filenames cannot be used') filename = secure_filename(filename) md5hash = hashlib.md5(os.urandom(64)).hexdigest() + file_path = os.path.join(md5hash, filename) - if not os.path.exists(os.path.join(self.base_path, md5hash)): - os.makedirs(os.path.join(self.base_path, md5hash)) - - location = os.path.join(self.base_path, md5hash, filename) - - dst_file = open(location, 'wb') - try: - copyfileobj(file_obj, dst_file, 16384) - finally: - dst_file.close() - - key = os.path.join(md5hash, filename) - return key + return self.store(file_obj, file_path) def download(self, filename): return send_file(safe_join(self.base_path, filename)) @@ -58,6 +65,9 @@ class FilesystemUploader(BaseUploader): return True return False + def sync(self): + pass + class S3Uploader(BaseUploader): def __init__(self): @@ -81,6 +91,10 @@ class S3Uploader(BaseUploader): if c in string.ascii_letters + string.digits + '-' + '_' + '.': return True + def store(self, fileobj, filename): + self.s3.upload_fileobj(fileobj, self.bucket, filename) + return filename + def upload(self, file_obj, filename): filename = filter(self._clean_filename, secure_filename(filename).replace(' ', '_')) if len(filename) <= 0: @@ -105,3 +119,17 @@ class S3Uploader(BaseUploader): def delete(self, filename): self.s3.delete_object(Bucket=self.bucket, Key=filename) return True + + def sync(self): + local_folder = current_app.config.get('UPLOAD_FOLDER') + bucket_list = self.s3.list_objects(Bucket=self.bucket)['Contents'] + + for s3_key in bucket_list: + s3_object = s3_key['Key'] + + local_path = os.path.join(local_folder, s3_object) + directory = os.path.dirname(local_path) + if not os.path.exists(directory): + os.makedirs(directory) + + self.s3.download_file(self.bucket, s3_object, local_path)