refactor: centralize shared path variables and improve cmedb intialization

main
Marshall Hallenbeck 2023-03-02 11:01:29 -05:00 committed by Marshall Hallenbeck
parent 23d8d588e8
commit 50a74951c8
4 changed files with 59 additions and 48 deletions

View File

@ -10,6 +10,7 @@ from time import sleep
from terminaltables import AsciiTable from terminaltables import AsciiTable
import configparser import configparser
from cme.loaders.protocol_loader import protocol_loader from cme.loaders.protocol_loader import protocol_loader
from cme.paths import CONFIG_PATH, WS_PATH
from requests import ConnectionError from requests import ConnectionError
import csv import csv
@ -389,15 +390,48 @@ class CMEDBMenu(cmd.Cmd):
sys.exit(0) sys.exit(0)
def main(): def initialize_db(logger):
config_path = os.path.expanduser('~/.cme/cme.conf') if not os.path.exists(os.path.join(WS_PATH, 'default')):
logger.info('Creating default workspace')
os.mkdir(os.path.join(WS_PATH, 'default'))
if not os.path.exists(config_path): p_loader = protocol_loader()
protocols = p_loader.get_protocols()
for protocol in protocols.keys():
try:
protocol_object = p_loader.load_protocol(protocols[protocol]['dbpath'])
except KeyError:
continue
proto_db_path = os.path.join(WS_PATH, 'default', protocol + '.db')
if not os.path.exists(proto_db_path):
logger.info('Initializing {} protocol database'.format(protocol.upper()))
conn = sqlite3.connect(proto_db_path)
c = conn.cursor()
# try to prevent some of the weird sqlite I/O errors
c.execute('PRAGMA journal_mode = OFF')
c.execute('PRAGMA foreign_keys = 1')
getattr(protocol_object, 'database').db_schema(c)
# commit the changes and close everything off
conn.commit()
conn.close()
def main():
if not os.path.exists(CONFIG_PATH):
print("[-] Unable to find config file") print("[-] Unable to find config file")
sys.exit(1) sys.exit(1)
try: try:
cmedbnav = CMEDBMenu(config_path) cmedbnav = CMEDBMenu(CONFIG_PATH)
cmedbnav.cmdloop() cmedbnav.cmdloop()
except KeyboardInterrupt: except KeyboardInterrupt:
pass pass
if __name__ == "__main__":
main()

View File

@ -13,9 +13,11 @@ from cme.loaders.module_loader import module_loader
from cme.servers.http import CMEServer from cme.servers.http import CMEServer
from cme.first_run import first_run_setup from cme.first_run import first_run_setup
from cme.context import Context from cme.context import Context
from cme.paths import CME_PATH
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from pprint import pformat from pprint import pformat
from decimal import Decimal from decimal import Decimal
import time
import asyncio import asyncio
import aioconsole import aioconsole
import functools import functools
@ -116,6 +118,7 @@ async def start_threadpool(protocol_obj, args, db, targets, jitter):
monitor_task.cancel() monitor_task.cancel()
pool.shutdown(wait=True) pool.shutdown(wait=True)
def main(): def main():
first_run_setup(logger) first_run_setup(logger)
@ -128,10 +131,8 @@ def main():
except: except:
sys.exit(1) sys.exit(1)
cme_path = os.path.expanduser('~/.cme')
config = configparser.ConfigParser() config = configparser.ConfigParser()
config.read(os.path.join(cme_path, 'cme.conf')) config.read(os.path.join(CME_PATH, 'cme.conf'))
module = None module = None
module_server = None module_server = None
@ -197,7 +198,7 @@ def main():
protocol_object = getattr(p_loader.load_protocol(protocol_path), args.protocol) protocol_object = getattr(p_loader.load_protocol(protocol_path), args.protocol)
protocol_db_object = getattr(p_loader.load_protocol(protocol_db_path), 'database') protocol_db_object = getattr(p_loader.load_protocol(protocol_db_path), 'database')
db_path = os.path.join(cme_path, 'workspaces', current_workspace, args.protocol + '.db') db_path = os.path.join(CME_PATH, 'workspaces', current_workspace, args.protocol + '.db')
# set the database connection to autocommit w/ isolation level # set the database connection to autocommit w/ isolation level
db_connection = sqlite3.connect(db_path, check_same_thread=False) db_connection = sqlite3.connect(db_path, check_same_thread=False)
db_connection.text_factory = str db_connection.text_factory = str

View File

@ -8,23 +8,13 @@ import shutil
import cme import cme
import configparser import configparser
from configparser import ConfigParser, NoSectionError, NoOptionError from configparser import ConfigParser, NoSectionError, NoOptionError
from cme.loaders.protocol_loader import protocol_loader from cme.paths import CME_PATH, CONFIG_PATH, CERT_PATH, TMP_PATH
from cmedb import initialize_db
from subprocess import check_output, PIPE from subprocess import check_output, PIPE
import sys import sys
CME_PATH = os.path.expanduser('~/.cme')
TMP_PATH = os.path.join('/tmp', 'cme_hosted')
if os.name == 'nt':
TMP_PATH = os.getenv('LOCALAPPDATA') + '\\Temp\\cme_hosted'
if hasattr(sys, 'getandroidapilevel'):
TMP_PATH = os.path.join('/data','data', 'com.termux', 'files', 'usr', 'tmp', 'cme_hosted')
WS_PATH = os.path.join(CME_PATH, 'workspaces')
CERT_PATH = os.path.join(CME_PATH, 'cme.pem')
CONFIG_PATH = os.path.join(CME_PATH, 'cme.conf')
def first_run_setup(logger): def first_run_setup(logger):
if not os.path.exists(TMP_PATH): if not os.path.exists(TMP_PATH):
os.mkdir(TMP_PATH) os.mkdir(TMP_PATH)
@ -36,36 +26,10 @@ def first_run_setup(logger):
folders = ['logs', 'modules', 'protocols', 'workspaces', 'obfuscated_scripts', 'screenshots'] folders = ['logs', 'modules', 'protocols', 'workspaces', 'obfuscated_scripts', 'screenshots']
for folder in folders: for folder in folders:
if not os.path.exists(os.path.join(CME_PATH, folder)): if not os.path.exists(os.path.join(CME_PATH, folder)):
logger.info("Creating missing folder {}".format(folder))
os.mkdir(os.path.join(CME_PATH, folder)) os.mkdir(os.path.join(CME_PATH, folder))
if not os.path.exists(os.path.join(WS_PATH, 'default')): initialize_db(logger)
logger.info('Creating default workspace')
os.mkdir(os.path.join(WS_PATH, 'default'))
p_loader = protocol_loader()
protocols = p_loader.get_protocols()
for protocol in protocols.keys():
try:
protocol_object = p_loader.load_protocol(protocols[protocol]['dbpath'])
except KeyError:
continue
proto_db_path = os.path.join(WS_PATH, 'default', protocol + '.db')
if not os.path.exists(proto_db_path):
logger.info('Initializing {} protocol database'.format(protocol.upper()))
conn = sqlite3.connect(proto_db_path)
c = conn.cursor()
# try to prevent some of the weird sqlite I/O errors
c.execute('PRAGMA journal_mode = OFF')
c.execute('PRAGMA foreign_keys = 1')
getattr(protocol_object, 'database').db_schema(c)
# commit the changes and close everything off
conn.commit()
conn.close()
if not os.path.exists(CONFIG_PATH): if not os.path.exists(CONFIG_PATH):
logger.info('Copying default configuration file') logger.info('Copying default configuration file')

View File

@ -0,0 +1,12 @@
import os
import sys
CME_PATH = os.path.expanduser('~/.cme')
TMP_PATH = os.path.join('/tmp', 'cme_hosted')
if os.name == 'nt':
TMP_PATH = os.getenv('LOCALAPPDATA') + '\\Temp\\cme_hosted'
if hasattr(sys, 'getandroidapilevel'):
TMP_PATH = os.path.join('/data','data', 'com.termux', 'files', 'usr', 'tmp', 'cme_hosted')
WS_PATH = os.path.join(CME_PATH, 'workspaces')
CERT_PATH = os.path.join(CME_PATH, 'cme.pem')
CONFIG_PATH = os.path.join(CME_PATH, 'cme.conf')