refactor: centralize shared path variables and improve cmedb intialization
parent
23d8d588e8
commit
50a74951c8
42
cme/cmedb.py
42
cme/cmedb.py
|
@ -10,6 +10,7 @@ from time import sleep
|
|||
from terminaltables import AsciiTable
|
||||
import configparser
|
||||
from cme.loaders.protocol_loader import protocol_loader
|
||||
from cme.paths import CONFIG_PATH, WS_PATH
|
||||
from requests import ConnectionError
|
||||
import csv
|
||||
|
||||
|
@ -389,15 +390,48 @@ class CMEDBMenu(cmd.Cmd):
|
|||
sys.exit(0)
|
||||
|
||||
|
||||
def main():
|
||||
config_path = os.path.expanduser('~/.cme/cme.conf')
|
||||
def initialize_db(logger):
|
||||
if not os.path.exists(os.path.join(WS_PATH, 'default')):
|
||||
logger.info('Creating default workspace')
|
||||
os.mkdir(os.path.join(WS_PATH, 'default'))
|
||||
|
||||
if not os.path.exists(config_path):
|
||||
p_loader = protocol_loader()
|
||||
protocols = p_loader.get_protocols()
|
||||
for protocol in protocols.keys():
|
||||
try:
|
||||
protocol_object = p_loader.load_protocol(protocols[protocol]['dbpath'])
|
||||
except KeyError:
|
||||
continue
|
||||
|
||||
proto_db_path = os.path.join(WS_PATH, 'default', protocol + '.db')
|
||||
|
||||
if not os.path.exists(proto_db_path):
|
||||
logger.info('Initializing {} protocol database'.format(protocol.upper()))
|
||||
conn = sqlite3.connect(proto_db_path)
|
||||
c = conn.cursor()
|
||||
|
||||
# try to prevent some of the weird sqlite I/O errors
|
||||
c.execute('PRAGMA journal_mode = OFF')
|
||||
c.execute('PRAGMA foreign_keys = 1')
|
||||
|
||||
getattr(protocol_object, 'database').db_schema(c)
|
||||
|
||||
# commit the changes and close everything off
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
|
||||
def main():
|
||||
if not os.path.exists(CONFIG_PATH):
|
||||
print("[-] Unable to find config file")
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
cmedbnav = CMEDBMenu(config_path)
|
||||
cmedbnav = CMEDBMenu(CONFIG_PATH)
|
||||
cmedbnav.cmdloop()
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
@ -13,9 +13,11 @@ from cme.loaders.module_loader import module_loader
|
|||
from cme.servers.http import CMEServer
|
||||
from cme.first_run import first_run_setup
|
||||
from cme.context import Context
|
||||
from cme.paths import CME_PATH
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from pprint import pformat
|
||||
from decimal import Decimal
|
||||
import time
|
||||
import asyncio
|
||||
import aioconsole
|
||||
import functools
|
||||
|
@ -116,6 +118,7 @@ async def start_threadpool(protocol_obj, args, db, targets, jitter):
|
|||
monitor_task.cancel()
|
||||
pool.shutdown(wait=True)
|
||||
|
||||
|
||||
def main():
|
||||
first_run_setup(logger)
|
||||
|
||||
|
@ -128,10 +131,8 @@ def main():
|
|||
except:
|
||||
sys.exit(1)
|
||||
|
||||
cme_path = os.path.expanduser('~/.cme')
|
||||
|
||||
config = configparser.ConfigParser()
|
||||
config.read(os.path.join(cme_path, 'cme.conf'))
|
||||
config.read(os.path.join(CME_PATH, 'cme.conf'))
|
||||
|
||||
module = None
|
||||
module_server = None
|
||||
|
@ -197,7 +198,7 @@ def main():
|
|||
protocol_object = getattr(p_loader.load_protocol(protocol_path), args.protocol)
|
||||
protocol_db_object = getattr(p_loader.load_protocol(protocol_db_path), 'database')
|
||||
|
||||
db_path = os.path.join(cme_path, 'workspaces', current_workspace, args.protocol + '.db')
|
||||
db_path = os.path.join(CME_PATH, 'workspaces', current_workspace, args.protocol + '.db')
|
||||
# set the database connection to autocommit w/ isolation level
|
||||
db_connection = sqlite3.connect(db_path, check_same_thread=False)
|
||||
db_connection.text_factory = str
|
||||
|
|
|
@ -8,23 +8,13 @@ import shutil
|
|||
import cme
|
||||
import configparser
|
||||
from configparser import ConfigParser, NoSectionError, NoOptionError
|
||||
from cme.loaders.protocol_loader import protocol_loader
|
||||
from cme.paths import CME_PATH, CONFIG_PATH, CERT_PATH, TMP_PATH
|
||||
from cmedb import initialize_db
|
||||
from subprocess import check_output, PIPE
|
||||
import sys
|
||||
|
||||
CME_PATH = os.path.expanduser('~/.cme')
|
||||
TMP_PATH = os.path.join('/tmp', 'cme_hosted')
|
||||
if os.name == 'nt':
|
||||
TMP_PATH = os.getenv('LOCALAPPDATA') + '\\Temp\\cme_hosted'
|
||||
if hasattr(sys, 'getandroidapilevel'):
|
||||
TMP_PATH = os.path.join('/data','data', 'com.termux', 'files', 'usr', 'tmp', 'cme_hosted')
|
||||
WS_PATH = os.path.join(CME_PATH, 'workspaces')
|
||||
CERT_PATH = os.path.join(CME_PATH, 'cme.pem')
|
||||
CONFIG_PATH = os.path.join(CME_PATH, 'cme.conf')
|
||||
|
||||
|
||||
def first_run_setup(logger):
|
||||
|
||||
if not os.path.exists(TMP_PATH):
|
||||
os.mkdir(TMP_PATH)
|
||||
|
||||
|
@ -36,36 +26,10 @@ def first_run_setup(logger):
|
|||
folders = ['logs', 'modules', 'protocols', 'workspaces', 'obfuscated_scripts', 'screenshots']
|
||||
for folder in folders:
|
||||
if not os.path.exists(os.path.join(CME_PATH, folder)):
|
||||
logger.info("Creating missing folder {}".format(folder))
|
||||
os.mkdir(os.path.join(CME_PATH, folder))
|
||||
|
||||
if not os.path.exists(os.path.join(WS_PATH, 'default')):
|
||||
logger.info('Creating default workspace')
|
||||
os.mkdir(os.path.join(WS_PATH, 'default'))
|
||||
|
||||
p_loader = protocol_loader()
|
||||
protocols = p_loader.get_protocols()
|
||||
for protocol in protocols.keys():
|
||||
try:
|
||||
protocol_object = p_loader.load_protocol(protocols[protocol]['dbpath'])
|
||||
except KeyError:
|
||||
continue
|
||||
|
||||
proto_db_path = os.path.join(WS_PATH, 'default', protocol + '.db')
|
||||
|
||||
if not os.path.exists(proto_db_path):
|
||||
logger.info('Initializing {} protocol database'.format(protocol.upper()))
|
||||
conn = sqlite3.connect(proto_db_path)
|
||||
c = conn.cursor()
|
||||
|
||||
# try to prevent some of the weird sqlite I/O errors
|
||||
c.execute('PRAGMA journal_mode = OFF')
|
||||
c.execute('PRAGMA foreign_keys = 1')
|
||||
|
||||
getattr(protocol_object, 'database').db_schema(c)
|
||||
|
||||
# commit the changes and close everything off
|
||||
conn.commit()
|
||||
conn.close()
|
||||
initialize_db(logger)
|
||||
|
||||
if not os.path.exists(CONFIG_PATH):
|
||||
logger.info('Copying default configuration file')
|
||||
|
|
12
cme/paths.py
12
cme/paths.py
|
@ -0,0 +1,12 @@
|
|||
import os
|
||||
import sys
|
||||
|
||||
CME_PATH = os.path.expanduser('~/.cme')
|
||||
TMP_PATH = os.path.join('/tmp', 'cme_hosted')
|
||||
if os.name == 'nt':
|
||||
TMP_PATH = os.getenv('LOCALAPPDATA') + '\\Temp\\cme_hosted'
|
||||
if hasattr(sys, 'getandroidapilevel'):
|
||||
TMP_PATH = os.path.join('/data','data', 'com.termux', 'files', 'usr', 'tmp', 'cme_hosted')
|
||||
WS_PATH = os.path.join(CME_PATH, 'workspaces')
|
||||
CERT_PATH = os.path.join(CME_PATH, 'cme.pem')
|
||||
CONFIG_PATH = os.path.join(CME_PATH, 'cme.conf')
|
Loading…
Reference in New Issue