update how workspaces are created so tests can utilize functionality

main
Marshall Hallenbeck 2023-03-20 21:14:07 -04:00 committed by Marshall Hallenbeck
parent c049b9f3e2
commit 64102b35db
3 changed files with 40 additions and 30 deletions

View File

@ -3,6 +3,7 @@
import cmd
import logging
import shutil
import sqlite3
import sys
import os
@ -10,7 +11,7 @@ import requests
from terminaltables import AsciiTable
import configparser
from cme.loaders.protocol_loader import protocol_loader
from cme.paths import CONFIG_PATH, WS_PATH
from cme.paths import CONFIG_PATH, WS_PATH, WORKSPACE_DIR
from requests import ConnectionError
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.exc import SAWarning
@ -343,7 +344,7 @@ class CMEDBMenu(cmd.Cmd):
print("[-] Error reading cme.conf: {}".format(e))
sys.exit(1)
self.workspace_dir = os.path.expanduser('~/.cme/workspaces')
self.conn = None
self.p_loader = protocol_loader()
self.protocols = self.p_loader.get_protocols()
@ -363,7 +364,7 @@ class CMEDBMenu(cmd.Cmd):
if not proto:
return
proto_db_path = os.path.join(self.workspace_dir, self.workspace, proto + '.db')
proto_db_path = os.path.join(WORKSPACE_DIR, self.workspace, proto + '.db')
if os.path.exists(proto_db_path):
self.conn = create_db_engine(proto_db_path)
db_nav_object = self.p_loader.load_protocol(self.protocols[proto]['nvpath'])
@ -394,38 +395,16 @@ class CMEDBMenu(cmd.Cmd):
if subcommand == 'create':
new_workspace = line.split()[1].strip()
print("[*] Creating workspace '{}'".format(new_workspace))
os.mkdir(os.path.join(self.workspace_dir, new_workspace))
for protocol in self.protocols.keys():
try:
protocol_object = self.p_loader.load_protocol(self.protocols[protocol]['dbpath'])
except KeyError:
continue
proto_db_path = os.path.join(self.workspace_dir, new_workspace, protocol + '.db')
if not os.path.exists(proto_db_path):
print('[*] Initializing {} protocol database'.format(protocol.upper()))
conn = sqlite3.connect(proto_db_path)
c = conn.cursor()
# try to prevent some weird sqlite I/O errors
c.execute('PRAGMA journal_mode = OFF')
c.execute('PRAGMA foreign_keys = 1')
getattr(protocol_object, 'database').db_schema(c)
# commit the changes and close everything off
conn.commit()
conn.close()
self.create_workspace(new_workspace, self.p_loader, self.protocols)
self.do_workspace(new_workspace)
elif subcommand == 'list':
print("[*] Enumerating Workspaces")
for workspace in os.listdir(os.path.join(self.workspace_dir)):
for workspace in os.listdir(os.path.join(WORKSPACE_DIR)):
if workspace == self.workspace:
print("==> "+workspace)
else:
print(workspace)
elif os.path.exists(os.path.join(self.workspace_dir, line)):
elif os.path.exists(os.path.join(WORKSPACE_DIR, line)):
self.config.set('CME', 'workspace', line)
self.write_configfile()
self.workspace = line
@ -446,6 +425,36 @@ class CMEDBMenu(cmd.Cmd):
"""
print_help(help_string)
def create_workspace(workspace_name, p_loader, protocols):
os.mkdir(os.path.join(WORKSPACE_DIR, workspace_name))
for protocol in protocols.keys():
try:
protocol_object = p_loader.load_protocol(protocols[protocol]['dbpath'])
except KeyError:
continue
proto_db_path = os.path.join(WORKSPACE_DIR, workspace_name, protocol + '.db')
if not os.path.exists(proto_db_path):
print('[*] Initializing {} protocol database'.format(protocol.upper()))
conn = sqlite3.connect(proto_db_path)
c = conn.cursor()
# try to prevent some weird sqlite I/O errors
c.execute('PRAGMA journal_mode = OFF')
c.execute('PRAGMA foreign_keys = 1')
getattr(protocol_object, 'database').db_schema(c)
# commit the changes and close everything off
conn.commit()
conn.close()
def delete_workspace(workspace_name):
shutil.rmtree(os.path.join(WORKSPACE_DIR, workspace_name))
def initialize_db(logger):
if not os.path.exists(os.path.join(WS_PATH, 'default')):
logger.info('Creating default workspace')

View File

@ -228,7 +228,7 @@ def main():
protocol_object = getattr(p_loader.load_protocol(protocol_path), args.protocol)
logging.debug(f"Protocol Object: {protocol_object}")
protocol_db_object = getattr(p_loader.load_protocol(protocol_db_path), 'database')
logging.debug(f"Protocol DB Object: {protocol_object}")
logging.debug(f"Protocol DB Object: {protocol_db_object}")
db_path = os.path.join(CME_PATH, 'workspaces', current_workspace, args.protocol + '.db')
logging.debug(f"DB Path: {db_path}")

View File

@ -9,4 +9,5 @@ if hasattr(sys, 'getandroidapilevel'):
TMP_PATH = os.path.join('/data','data', 'com.termux', 'files', 'usr', 'tmp', 'cme_hosted')
WS_PATH = os.path.join(CME_PATH, 'workspaces')
CERT_PATH = os.path.join(CME_PATH, 'cme.pem')
CONFIG_PATH = os.path.join(CME_PATH, 'cme.conf')
CONFIG_PATH = os.path.join(CME_PATH, 'cme.conf')
WORKSPACE_DIR = os.path.join(CME_PATH, 'workspaces')