Fix Uploaders to work with imports/exports (#749)

* Refactor Uploaders to work better with imports/exports
* Get S3 uploader working properly with imports/exports
* cache pip in travis
selenium-screenshot-testing
Kevin Chung 2018-11-23 06:10:33 -05:00 committed by GitHub
parent 310475d739
commit 49ed27cfd6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 52 additions and 32 deletions

View File

@ -1,4 +1,5 @@
language: python
cache: pip
services:
- mysql
- postgresql

View File

@ -6,8 +6,6 @@ from flask import Flask
from werkzeug.contrib.fixers import ProxyFix
from jinja2 import FileSystemLoader
from jinja2.sandbox import SandboxedEnvironment
from sqlalchemy.engine.url import make_url
from sqlalchemy_utils import database_exists, create_database
from six.moves import input
from CTFd import utils

View File

@ -158,9 +158,8 @@ class Config(object):
'''
UPLOAD_PROVIDER = os.getenv('UPLOAD_PROVIDER') or 'filesystem'
if UPLOAD_PROVIDER == 'filesystem':
UPLOAD_FOLDER = os.getenv('UPLOAD_FOLDER') or os.path.join(os.path.dirname(os.path.abspath(__file__)), 'uploads')
elif UPLOAD_PROVIDER == 's3':
UPLOAD_FOLDER = os.getenv('UPLOAD_FOLDER') or os.path.join(os.path.dirname(os.path.abspath(__file__)), 'uploads')
if UPLOAD_PROVIDER == 's3':
AWS_ACCESS_KEY_ID = os.getenv('AWS_ACCESS_KEY_ID')
AWS_SECRET_ACCESS_KEY = os.getenv('AWS_SECRET_ACCESS_KEY')
AWS_S3_BUCKET = os.getenv('AWS_S3_BUCKET')

View File

@ -1,6 +1,7 @@
from CTFd.utils import get_app_config, get_config, set_config
from CTFd.utils import get_app_config
from CTFd.utils.migrations import get_current_revision, create_database, drop_database, upgrade, stamp
from CTFd.models import db, get_class_by_tablename
from CTFd.utils.uploads import get_uploader
from CTFd.models import db
from CTFd.cache import cache
from datafreeze.format import SERIALIZERS
from flask import current_app as app
@ -12,7 +13,6 @@ import json
import os
import re
import six
import shutil
import zipfile
@ -85,6 +85,9 @@ def export_ctf():
backup_zip.writestr('db/alembic_version.json', result_file.read())
# Backup uploads
uploader = get_uploader()
uploader.sync()
upload_folder = os.path.join(os.path.normpath(app.root_path), app.config.get('UPLOAD_FOLDER'))
for root, dirs, files in os.walk(upload_folder):
for file in files:
@ -199,7 +202,7 @@ def import_ctf(backup, erase=True):
# Extracting files
files = [f for f in backup.namelist() if f.startswith('uploads/')]
upload_folder = app.config.get('UPLOAD_FOLDER')
uploader = get_uploader()
for f in files:
filename = f.split(os.sep, 1)
@ -207,16 +210,7 @@ def import_ctf(backup, erase=True):
continue
filename = filename[1] # Get the second entry in the list (the actual filename)
full_path = os.path.join(upload_folder, filename)
dirname = os.path.dirname(full_path)
# Create any parent directories for the file
if not os.path.exists(dirname):
os.makedirs(dirname)
source = backup.open(f)
target = open(full_path, "wb")
with source, target:
shutil.copyfileobj(source, target)
uploader.store(fileobj=source, filename=filename)
cache.clear()

View File

@ -13,6 +13,9 @@ class BaseUploader(object):
def __init__(self):
raise NotImplementedError
def store(self, fileobj, filename):
raise NotImplementedError
def upload(self, file_obj, filename):
raise NotImplementedError
@ -22,32 +25,36 @@ class BaseUploader(object):
def delete(self, filename):
raise NotImplementedError
def sync(self):
raise NotImplementedError
class FilesystemUploader(BaseUploader):
def __init__(self, base_path=None):
super(BaseUploader, self).__init__()
self.base_path = base_path or current_app.config.get('UPLOAD_FOLDER')
def store(self, fileobj, filename):
location = os.path.join(self.base_path, filename)
directory = os.path.dirname(location)
if not os.path.exists(directory):
os.makedirs(directory)
with open(location, 'wb') as dst:
copyfileobj(fileobj, dst, 16384)
return filename
def upload(self, file_obj, filename):
if len(filename) == 0:
raise Exception('Empty filenames cannot be used')
filename = secure_filename(filename)
md5hash = hashlib.md5(os.urandom(64)).hexdigest()
file_path = os.path.join(md5hash, filename)
if not os.path.exists(os.path.join(self.base_path, md5hash)):
os.makedirs(os.path.join(self.base_path, md5hash))
location = os.path.join(self.base_path, md5hash, filename)
dst_file = open(location, 'wb')
try:
copyfileobj(file_obj, dst_file, 16384)
finally:
dst_file.close()
key = os.path.join(md5hash, filename)
return key
return self.store(file_obj, file_path)
def download(self, filename):
return send_file(safe_join(self.base_path, filename))
@ -58,6 +65,9 @@ class FilesystemUploader(BaseUploader):
return True
return False
def sync(self):
pass
class S3Uploader(BaseUploader):
def __init__(self):
@ -81,6 +91,10 @@ class S3Uploader(BaseUploader):
if c in string.ascii_letters + string.digits + '-' + '_' + '.':
return True
def store(self, fileobj, filename):
self.s3.upload_fileobj(fileobj, self.bucket, filename)
return filename
def upload(self, file_obj, filename):
filename = filter(self._clean_filename, secure_filename(filename).replace(' ', '_'))
if len(filename) <= 0:
@ -105,3 +119,17 @@ class S3Uploader(BaseUploader):
def delete(self, filename):
self.s3.delete_object(Bucket=self.bucket, Key=filename)
return True
def sync(self):
local_folder = current_app.config.get('UPLOAD_FOLDER')
bucket_list = self.s3.list_objects(Bucket=self.bucket)['Contents']
for s3_key in bucket_list:
s3_object = s3_key['Key']
local_path = os.path.join(local_folder, s3_object)
directory = os.path.dirname(local_path)
if not os.path.exists(directory):
os.makedirs(directory)
self.s3.download_file(self.bucket, s3_object, local_path)