refactor: centralize shared path variables and improve cmedb intialization
parent
23d8d588e8
commit
50a74951c8
42
cme/cmedb.py
42
cme/cmedb.py
|
@ -10,6 +10,7 @@ from time import sleep
|
||||||
from terminaltables import AsciiTable
|
from terminaltables import AsciiTable
|
||||||
import configparser
|
import configparser
|
||||||
from cme.loaders.protocol_loader import protocol_loader
|
from cme.loaders.protocol_loader import protocol_loader
|
||||||
|
from cme.paths import CONFIG_PATH, WS_PATH
|
||||||
from requests import ConnectionError
|
from requests import ConnectionError
|
||||||
import csv
|
import csv
|
||||||
|
|
||||||
|
@ -389,15 +390,48 @@ class CMEDBMenu(cmd.Cmd):
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def initialize_db(logger):
|
||||||
config_path = os.path.expanduser('~/.cme/cme.conf')
|
if not os.path.exists(os.path.join(WS_PATH, 'default')):
|
||||||
|
logger.info('Creating default workspace')
|
||||||
|
os.mkdir(os.path.join(WS_PATH, 'default'))
|
||||||
|
|
||||||
if not os.path.exists(config_path):
|
p_loader = protocol_loader()
|
||||||
|
protocols = p_loader.get_protocols()
|
||||||
|
for protocol in protocols.keys():
|
||||||
|
try:
|
||||||
|
protocol_object = p_loader.load_protocol(protocols[protocol]['dbpath'])
|
||||||
|
except KeyError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
proto_db_path = os.path.join(WS_PATH, 'default', protocol + '.db')
|
||||||
|
|
||||||
|
if not os.path.exists(proto_db_path):
|
||||||
|
logger.info('Initializing {} protocol database'.format(protocol.upper()))
|
||||||
|
conn = sqlite3.connect(proto_db_path)
|
||||||
|
c = conn.cursor()
|
||||||
|
|
||||||
|
# try to prevent some of the weird sqlite I/O errors
|
||||||
|
c.execute('PRAGMA journal_mode = OFF')
|
||||||
|
c.execute('PRAGMA foreign_keys = 1')
|
||||||
|
|
||||||
|
getattr(protocol_object, 'database').db_schema(c)
|
||||||
|
|
||||||
|
# commit the changes and close everything off
|
||||||
|
conn.commit()
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
if not os.path.exists(CONFIG_PATH):
|
||||||
print("[-] Unable to find config file")
|
print("[-] Unable to find config file")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
cmedbnav = CMEDBMenu(config_path)
|
cmedbnav = CMEDBMenu(CONFIG_PATH)
|
||||||
cmedbnav.cmdloop()
|
cmedbnav.cmdloop()
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
|
@ -13,9 +13,11 @@ from cme.loaders.module_loader import module_loader
|
||||||
from cme.servers.http import CMEServer
|
from cme.servers.http import CMEServer
|
||||||
from cme.first_run import first_run_setup
|
from cme.first_run import first_run_setup
|
||||||
from cme.context import Context
|
from cme.context import Context
|
||||||
|
from cme.paths import CME_PATH
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from pprint import pformat
|
from pprint import pformat
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
|
import time
|
||||||
import asyncio
|
import asyncio
|
||||||
import aioconsole
|
import aioconsole
|
||||||
import functools
|
import functools
|
||||||
|
@ -116,6 +118,7 @@ async def start_threadpool(protocol_obj, args, db, targets, jitter):
|
||||||
monitor_task.cancel()
|
monitor_task.cancel()
|
||||||
pool.shutdown(wait=True)
|
pool.shutdown(wait=True)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
first_run_setup(logger)
|
first_run_setup(logger)
|
||||||
|
|
||||||
|
@ -128,10 +131,8 @@ def main():
|
||||||
except:
|
except:
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
cme_path = os.path.expanduser('~/.cme')
|
|
||||||
|
|
||||||
config = configparser.ConfigParser()
|
config = configparser.ConfigParser()
|
||||||
config.read(os.path.join(cme_path, 'cme.conf'))
|
config.read(os.path.join(CME_PATH, 'cme.conf'))
|
||||||
|
|
||||||
module = None
|
module = None
|
||||||
module_server = None
|
module_server = None
|
||||||
|
@ -197,7 +198,7 @@ def main():
|
||||||
protocol_object = getattr(p_loader.load_protocol(protocol_path), args.protocol)
|
protocol_object = getattr(p_loader.load_protocol(protocol_path), args.protocol)
|
||||||
protocol_db_object = getattr(p_loader.load_protocol(protocol_db_path), 'database')
|
protocol_db_object = getattr(p_loader.load_protocol(protocol_db_path), 'database')
|
||||||
|
|
||||||
db_path = os.path.join(cme_path, 'workspaces', current_workspace, args.protocol + '.db')
|
db_path = os.path.join(CME_PATH, 'workspaces', current_workspace, args.protocol + '.db')
|
||||||
# set the database connection to autocommit w/ isolation level
|
# set the database connection to autocommit w/ isolation level
|
||||||
db_connection = sqlite3.connect(db_path, check_same_thread=False)
|
db_connection = sqlite3.connect(db_path, check_same_thread=False)
|
||||||
db_connection.text_factory = str
|
db_connection.text_factory = str
|
||||||
|
|
|
@ -8,23 +8,13 @@ import shutil
|
||||||
import cme
|
import cme
|
||||||
import configparser
|
import configparser
|
||||||
from configparser import ConfigParser, NoSectionError, NoOptionError
|
from configparser import ConfigParser, NoSectionError, NoOptionError
|
||||||
from cme.loaders.protocol_loader import protocol_loader
|
from cme.paths import CME_PATH, CONFIG_PATH, CERT_PATH, TMP_PATH
|
||||||
|
from cmedb import initialize_db
|
||||||
from subprocess import check_output, PIPE
|
from subprocess import check_output, PIPE
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
CME_PATH = os.path.expanduser('~/.cme')
|
|
||||||
TMP_PATH = os.path.join('/tmp', 'cme_hosted')
|
|
||||||
if os.name == 'nt':
|
|
||||||
TMP_PATH = os.getenv('LOCALAPPDATA') + '\\Temp\\cme_hosted'
|
|
||||||
if hasattr(sys, 'getandroidapilevel'):
|
|
||||||
TMP_PATH = os.path.join('/data','data', 'com.termux', 'files', 'usr', 'tmp', 'cme_hosted')
|
|
||||||
WS_PATH = os.path.join(CME_PATH, 'workspaces')
|
|
||||||
CERT_PATH = os.path.join(CME_PATH, 'cme.pem')
|
|
||||||
CONFIG_PATH = os.path.join(CME_PATH, 'cme.conf')
|
|
||||||
|
|
||||||
|
|
||||||
def first_run_setup(logger):
|
def first_run_setup(logger):
|
||||||
|
|
||||||
if not os.path.exists(TMP_PATH):
|
if not os.path.exists(TMP_PATH):
|
||||||
os.mkdir(TMP_PATH)
|
os.mkdir(TMP_PATH)
|
||||||
|
|
||||||
|
@ -36,36 +26,10 @@ def first_run_setup(logger):
|
||||||
folders = ['logs', 'modules', 'protocols', 'workspaces', 'obfuscated_scripts', 'screenshots']
|
folders = ['logs', 'modules', 'protocols', 'workspaces', 'obfuscated_scripts', 'screenshots']
|
||||||
for folder in folders:
|
for folder in folders:
|
||||||
if not os.path.exists(os.path.join(CME_PATH, folder)):
|
if not os.path.exists(os.path.join(CME_PATH, folder)):
|
||||||
|
logger.info("Creating missing folder {}".format(folder))
|
||||||
os.mkdir(os.path.join(CME_PATH, folder))
|
os.mkdir(os.path.join(CME_PATH, folder))
|
||||||
|
|
||||||
if not os.path.exists(os.path.join(WS_PATH, 'default')):
|
initialize_db(logger)
|
||||||
logger.info('Creating default workspace')
|
|
||||||
os.mkdir(os.path.join(WS_PATH, 'default'))
|
|
||||||
|
|
||||||
p_loader = protocol_loader()
|
|
||||||
protocols = p_loader.get_protocols()
|
|
||||||
for protocol in protocols.keys():
|
|
||||||
try:
|
|
||||||
protocol_object = p_loader.load_protocol(protocols[protocol]['dbpath'])
|
|
||||||
except KeyError:
|
|
||||||
continue
|
|
||||||
|
|
||||||
proto_db_path = os.path.join(WS_PATH, 'default', protocol + '.db')
|
|
||||||
|
|
||||||
if not os.path.exists(proto_db_path):
|
|
||||||
logger.info('Initializing {} protocol database'.format(protocol.upper()))
|
|
||||||
conn = sqlite3.connect(proto_db_path)
|
|
||||||
c = conn.cursor()
|
|
||||||
|
|
||||||
# try to prevent some of the weird sqlite I/O errors
|
|
||||||
c.execute('PRAGMA journal_mode = OFF')
|
|
||||||
c.execute('PRAGMA foreign_keys = 1')
|
|
||||||
|
|
||||||
getattr(protocol_object, 'database').db_schema(c)
|
|
||||||
|
|
||||||
# commit the changes and close everything off
|
|
||||||
conn.commit()
|
|
||||||
conn.close()
|
|
||||||
|
|
||||||
if not os.path.exists(CONFIG_PATH):
|
if not os.path.exists(CONFIG_PATH):
|
||||||
logger.info('Copying default configuration file')
|
logger.info('Copying default configuration file')
|
||||||
|
|
12
cme/paths.py
12
cme/paths.py
|
@ -0,0 +1,12 @@
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
CME_PATH = os.path.expanduser('~/.cme')
|
||||||
|
TMP_PATH = os.path.join('/tmp', 'cme_hosted')
|
||||||
|
if os.name == 'nt':
|
||||||
|
TMP_PATH = os.getenv('LOCALAPPDATA') + '\\Temp\\cme_hosted'
|
||||||
|
if hasattr(sys, 'getandroidapilevel'):
|
||||||
|
TMP_PATH = os.path.join('/data','data', 'com.termux', 'files', 'usr', 'tmp', 'cme_hosted')
|
||||||
|
WS_PATH = os.path.join(CME_PATH, 'workspaces')
|
||||||
|
CERT_PATH = os.path.join(CME_PATH, 'cme.pem')
|
||||||
|
CONFIG_PATH = os.path.join(CME_PATH, 'cme.conf')
|
Loading…
Reference in New Issue