mirror of https://github.com/JohnHammond/CTFd.git
Fix Uploaders to work with imports/exports (#749)
* Refactor Uploaders to work better with imports/exports * Get S3 uploader working properly with imports/exports * cache pip in travisselenium-screenshot-testing
parent
310475d739
commit
49ed27cfd6
|
@ -1,4 +1,5 @@
|
|||
language: python
|
||||
cache: pip
|
||||
services:
|
||||
- mysql
|
||||
- postgresql
|
||||
|
|
|
@ -6,8 +6,6 @@ from flask import Flask
|
|||
from werkzeug.contrib.fixers import ProxyFix
|
||||
from jinja2 import FileSystemLoader
|
||||
from jinja2.sandbox import SandboxedEnvironment
|
||||
from sqlalchemy.engine.url import make_url
|
||||
from sqlalchemy_utils import database_exists, create_database
|
||||
from six.moves import input
|
||||
|
||||
from CTFd import utils
|
||||
|
|
|
@ -158,9 +158,8 @@ class Config(object):
|
|||
|
||||
'''
|
||||
UPLOAD_PROVIDER = os.getenv('UPLOAD_PROVIDER') or 'filesystem'
|
||||
if UPLOAD_PROVIDER == 'filesystem':
|
||||
UPLOAD_FOLDER = os.getenv('UPLOAD_FOLDER') or os.path.join(os.path.dirname(os.path.abspath(__file__)), 'uploads')
|
||||
elif UPLOAD_PROVIDER == 's3':
|
||||
UPLOAD_FOLDER = os.getenv('UPLOAD_FOLDER') or os.path.join(os.path.dirname(os.path.abspath(__file__)), 'uploads')
|
||||
if UPLOAD_PROVIDER == 's3':
|
||||
AWS_ACCESS_KEY_ID = os.getenv('AWS_ACCESS_KEY_ID')
|
||||
AWS_SECRET_ACCESS_KEY = os.getenv('AWS_SECRET_ACCESS_KEY')
|
||||
AWS_S3_BUCKET = os.getenv('AWS_S3_BUCKET')
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
from CTFd.utils import get_app_config, get_config, set_config
|
||||
from CTFd.utils import get_app_config
|
||||
from CTFd.utils.migrations import get_current_revision, create_database, drop_database, upgrade, stamp
|
||||
from CTFd.models import db, get_class_by_tablename
|
||||
from CTFd.utils.uploads import get_uploader
|
||||
from CTFd.models import db
|
||||
from CTFd.cache import cache
|
||||
from datafreeze.format import SERIALIZERS
|
||||
from flask import current_app as app
|
||||
|
@ -12,7 +13,6 @@ import json
|
|||
import os
|
||||
import re
|
||||
import six
|
||||
import shutil
|
||||
import zipfile
|
||||
|
||||
|
||||
|
@ -85,6 +85,9 @@ def export_ctf():
|
|||
backup_zip.writestr('db/alembic_version.json', result_file.read())
|
||||
|
||||
# Backup uploads
|
||||
uploader = get_uploader()
|
||||
uploader.sync()
|
||||
|
||||
upload_folder = os.path.join(os.path.normpath(app.root_path), app.config.get('UPLOAD_FOLDER'))
|
||||
for root, dirs, files in os.walk(upload_folder):
|
||||
for file in files:
|
||||
|
@ -199,7 +202,7 @@ def import_ctf(backup, erase=True):
|
|||
|
||||
# Extracting files
|
||||
files = [f for f in backup.namelist() if f.startswith('uploads/')]
|
||||
upload_folder = app.config.get('UPLOAD_FOLDER')
|
||||
uploader = get_uploader()
|
||||
for f in files:
|
||||
filename = f.split(os.sep, 1)
|
||||
|
||||
|
@ -207,16 +210,7 @@ def import_ctf(backup, erase=True):
|
|||
continue
|
||||
|
||||
filename = filename[1] # Get the second entry in the list (the actual filename)
|
||||
full_path = os.path.join(upload_folder, filename)
|
||||
dirname = os.path.dirname(full_path)
|
||||
|
||||
# Create any parent directories for the file
|
||||
if not os.path.exists(dirname):
|
||||
os.makedirs(dirname)
|
||||
|
||||
source = backup.open(f)
|
||||
target = open(full_path, "wb")
|
||||
with source, target:
|
||||
shutil.copyfileobj(source, target)
|
||||
uploader.store(fileobj=source, filename=filename)
|
||||
|
||||
cache.clear()
|
||||
|
|
|
@ -13,6 +13,9 @@ class BaseUploader(object):
|
|||
def __init__(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def store(self, fileobj, filename):
|
||||
raise NotImplementedError
|
||||
|
||||
def upload(self, file_obj, filename):
|
||||
raise NotImplementedError
|
||||
|
||||
|
@ -22,32 +25,36 @@ class BaseUploader(object):
|
|||
def delete(self, filename):
|
||||
raise NotImplementedError
|
||||
|
||||
def sync(self):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class FilesystemUploader(BaseUploader):
|
||||
def __init__(self, base_path=None):
|
||||
super(BaseUploader, self).__init__()
|
||||
self.base_path = base_path or current_app.config.get('UPLOAD_FOLDER')
|
||||
|
||||
def store(self, fileobj, filename):
|
||||
location = os.path.join(self.base_path, filename)
|
||||
directory = os.path.dirname(location)
|
||||
|
||||
if not os.path.exists(directory):
|
||||
os.makedirs(directory)
|
||||
|
||||
with open(location, 'wb') as dst:
|
||||
copyfileobj(fileobj, dst, 16384)
|
||||
|
||||
return filename
|
||||
|
||||
def upload(self, file_obj, filename):
|
||||
if len(filename) == 0:
|
||||
raise Exception('Empty filenames cannot be used')
|
||||
|
||||
filename = secure_filename(filename)
|
||||
md5hash = hashlib.md5(os.urandom(64)).hexdigest()
|
||||
file_path = os.path.join(md5hash, filename)
|
||||
|
||||
if not os.path.exists(os.path.join(self.base_path, md5hash)):
|
||||
os.makedirs(os.path.join(self.base_path, md5hash))
|
||||
|
||||
location = os.path.join(self.base_path, md5hash, filename)
|
||||
|
||||
dst_file = open(location, 'wb')
|
||||
try:
|
||||
copyfileobj(file_obj, dst_file, 16384)
|
||||
finally:
|
||||
dst_file.close()
|
||||
|
||||
key = os.path.join(md5hash, filename)
|
||||
return key
|
||||
return self.store(file_obj, file_path)
|
||||
|
||||
def download(self, filename):
|
||||
return send_file(safe_join(self.base_path, filename))
|
||||
|
@ -58,6 +65,9 @@ class FilesystemUploader(BaseUploader):
|
|||
return True
|
||||
return False
|
||||
|
||||
def sync(self):
|
||||
pass
|
||||
|
||||
|
||||
class S3Uploader(BaseUploader):
|
||||
def __init__(self):
|
||||
|
@ -81,6 +91,10 @@ class S3Uploader(BaseUploader):
|
|||
if c in string.ascii_letters + string.digits + '-' + '_' + '.':
|
||||
return True
|
||||
|
||||
def store(self, fileobj, filename):
|
||||
self.s3.upload_fileobj(fileobj, self.bucket, filename)
|
||||
return filename
|
||||
|
||||
def upload(self, file_obj, filename):
|
||||
filename = filter(self._clean_filename, secure_filename(filename).replace(' ', '_'))
|
||||
if len(filename) <= 0:
|
||||
|
@ -105,3 +119,17 @@ class S3Uploader(BaseUploader):
|
|||
def delete(self, filename):
|
||||
self.s3.delete_object(Bucket=self.bucket, Key=filename)
|
||||
return True
|
||||
|
||||
def sync(self):
|
||||
local_folder = current_app.config.get('UPLOAD_FOLDER')
|
||||
bucket_list = self.s3.list_objects(Bucket=self.bucket)['Contents']
|
||||
|
||||
for s3_key in bucket_list:
|
||||
s3_object = s3_key['Key']
|
||||
|
||||
local_path = os.path.join(local_folder, s3_object)
|
||||
directory = os.path.dirname(local_path)
|
||||
if not os.path.exists(directory):
|
||||
os.makedirs(directory)
|
||||
|
||||
self.s3.download_file(self.bucket, s3_object, local_path)
|
||||
|
|
Loading…
Reference in New Issue