update how workspaces are created so tests can utilize functionality
parent
c049b9f3e2
commit
64102b35db
65
cme/cmedb.py
65
cme/cmedb.py
|
@ -3,6 +3,7 @@
|
|||
|
||||
import cmd
|
||||
import logging
|
||||
import shutil
|
||||
import sqlite3
|
||||
import sys
|
||||
import os
|
||||
|
@ -10,7 +11,7 @@ import requests
|
|||
from terminaltables import AsciiTable
|
||||
import configparser
|
||||
from cme.loaders.protocol_loader import protocol_loader
|
||||
from cme.paths import CONFIG_PATH, WS_PATH
|
||||
from cme.paths import CONFIG_PATH, WS_PATH, WORKSPACE_DIR
|
||||
from requests import ConnectionError
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlalchemy.exc import SAWarning
|
||||
|
@ -343,7 +344,7 @@ class CMEDBMenu(cmd.Cmd):
|
|||
print("[-] Error reading cme.conf: {}".format(e))
|
||||
sys.exit(1)
|
||||
|
||||
self.workspace_dir = os.path.expanduser('~/.cme/workspaces')
|
||||
|
||||
self.conn = None
|
||||
self.p_loader = protocol_loader()
|
||||
self.protocols = self.p_loader.get_protocols()
|
||||
|
@ -363,7 +364,7 @@ class CMEDBMenu(cmd.Cmd):
|
|||
if not proto:
|
||||
return
|
||||
|
||||
proto_db_path = os.path.join(self.workspace_dir, self.workspace, proto + '.db')
|
||||
proto_db_path = os.path.join(WORKSPACE_DIR, self.workspace, proto + '.db')
|
||||
if os.path.exists(proto_db_path):
|
||||
self.conn = create_db_engine(proto_db_path)
|
||||
db_nav_object = self.p_loader.load_protocol(self.protocols[proto]['nvpath'])
|
||||
|
@ -394,38 +395,16 @@ class CMEDBMenu(cmd.Cmd):
|
|||
if subcommand == 'create':
|
||||
new_workspace = line.split()[1].strip()
|
||||
print("[*] Creating workspace '{}'".format(new_workspace))
|
||||
os.mkdir(os.path.join(self.workspace_dir, new_workspace))
|
||||
|
||||
for protocol in self.protocols.keys():
|
||||
try:
|
||||
protocol_object = self.p_loader.load_protocol(self.protocols[protocol]['dbpath'])
|
||||
except KeyError:
|
||||
continue
|
||||
proto_db_path = os.path.join(self.workspace_dir, new_workspace, protocol + '.db')
|
||||
|
||||
if not os.path.exists(proto_db_path):
|
||||
print('[*] Initializing {} protocol database'.format(protocol.upper()))
|
||||
conn = sqlite3.connect(proto_db_path)
|
||||
c = conn.cursor()
|
||||
|
||||
# try to prevent some weird sqlite I/O errors
|
||||
c.execute('PRAGMA journal_mode = OFF')
|
||||
c.execute('PRAGMA foreign_keys = 1')
|
||||
|
||||
getattr(protocol_object, 'database').db_schema(c)
|
||||
|
||||
# commit the changes and close everything off
|
||||
conn.commit()
|
||||
conn.close()
|
||||
self.create_workspace(new_workspace, self.p_loader, self.protocols)
|
||||
self.do_workspace(new_workspace)
|
||||
elif subcommand == 'list':
|
||||
print("[*] Enumerating Workspaces")
|
||||
for workspace in os.listdir(os.path.join(self.workspace_dir)):
|
||||
for workspace in os.listdir(os.path.join(WORKSPACE_DIR)):
|
||||
if workspace == self.workspace:
|
||||
print("==> "+workspace)
|
||||
else:
|
||||
print(workspace)
|
||||
elif os.path.exists(os.path.join(self.workspace_dir, line)):
|
||||
elif os.path.exists(os.path.join(WORKSPACE_DIR, line)):
|
||||
self.config.set('CME', 'workspace', line)
|
||||
self.write_configfile()
|
||||
self.workspace = line
|
||||
|
@ -446,6 +425,36 @@ class CMEDBMenu(cmd.Cmd):
|
|||
"""
|
||||
print_help(help_string)
|
||||
|
||||
def create_workspace(workspace_name, p_loader, protocols):
|
||||
os.mkdir(os.path.join(WORKSPACE_DIR, workspace_name))
|
||||
|
||||
for protocol in protocols.keys():
|
||||
try:
|
||||
protocol_object = p_loader.load_protocol(protocols[protocol]['dbpath'])
|
||||
except KeyError:
|
||||
continue
|
||||
proto_db_path = os.path.join(WORKSPACE_DIR, workspace_name, protocol + '.db')
|
||||
|
||||
if not os.path.exists(proto_db_path):
|
||||
print('[*] Initializing {} protocol database'.format(protocol.upper()))
|
||||
conn = sqlite3.connect(proto_db_path)
|
||||
c = conn.cursor()
|
||||
|
||||
# try to prevent some weird sqlite I/O errors
|
||||
c.execute('PRAGMA journal_mode = OFF')
|
||||
c.execute('PRAGMA foreign_keys = 1')
|
||||
|
||||
getattr(protocol_object, 'database').db_schema(c)
|
||||
|
||||
# commit the changes and close everything off
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
|
||||
def delete_workspace(workspace_name):
|
||||
shutil.rmtree(os.path.join(WORKSPACE_DIR, workspace_name))
|
||||
|
||||
|
||||
def initialize_db(logger):
|
||||
if not os.path.exists(os.path.join(WS_PATH, 'default')):
|
||||
logger.info('Creating default workspace')
|
||||
|
|
|
@ -228,7 +228,7 @@ def main():
|
|||
protocol_object = getattr(p_loader.load_protocol(protocol_path), args.protocol)
|
||||
logging.debug(f"Protocol Object: {protocol_object}")
|
||||
protocol_db_object = getattr(p_loader.load_protocol(protocol_db_path), 'database')
|
||||
logging.debug(f"Protocol DB Object: {protocol_object}")
|
||||
logging.debug(f"Protocol DB Object: {protocol_db_object}")
|
||||
|
||||
db_path = os.path.join(CME_PATH, 'workspaces', current_workspace, args.protocol + '.db')
|
||||
logging.debug(f"DB Path: {db_path}")
|
||||
|
|
|
@ -10,3 +10,4 @@ if hasattr(sys, 'getandroidapilevel'):
|
|||
WS_PATH = os.path.join(CME_PATH, 'workspaces')
|
||||
CERT_PATH = os.path.join(CME_PATH, 'cme.pem')
|
||||
CONFIG_PATH = os.path.join(CME_PATH, 'cme.conf')
|
||||
WORKSPACE_DIR = os.path.join(CME_PATH, 'workspaces')
|
||||
|
|
Loading…
Reference in New Issue