From 861626d06127976a4dc25ce170ede3e01286614c Mon Sep 17 00:00:00 2001 From: Marshall Hallenbeck Date: Fri, 17 Nov 2023 21:24:03 -0500 Subject: [PATCH] refactor: deduplicate code and simplify initial db setup --- nxc/database.py | 51 +++++++++++++++++++++++------------------------- nxc/first_run.py | 4 ++-- nxc/netexec.py | 7 +------ nxc/nxcdb.py | 2 +- nxc/paths.py | 2 +- 5 files changed, 29 insertions(+), 37 deletions(-) diff --git a/nxc/database.py b/nxc/database.py index 556e192a..45ddc241 100644 --- a/nxc/database.py +++ b/nxc/database.py @@ -8,7 +8,9 @@ from os.path import exists from os.path import join as path_join from nxc.loaders.protocolloader import ProtocolLoader -from nxc.paths import WS_PATH, WORKSPACE_DIR +from nxc.paths import WORKSPACE_DIR +from nxc.logger import nxc_logger + def create_db_engine(db_path): return create_engine(f"sqlite:///{db_path}", isolation_level="AUTOCOMMIT", future=True) @@ -37,8 +39,25 @@ def write_configfile(config, config_path): config.write(configfile) -def create_workspace(workspace_name, p_loader, protocols): - mkdir(path_join(WORKSPACE_DIR, workspace_name)) +def create_workspace(workspace_name, p_loader=None): + """ + Create a new workspace with the given name. + + Args: + ---- + workspace_name (str): The name of the workspace. + + Returns: + ------- + None + """ + if not exists(path_join(WORKSPACE_DIR, workspace_name)): + nxc_logger.debug(f"Creating {workspace_name} workspace") + mkdir(path_join(WORKSPACE_DIR, workspace_name)) + + if p_loader is None: + p_loader = ProtocolLoader() + protocols = p_loader.get_protocols() for protocol in protocols: protocol_object = p_loader.load_protocol(protocols[protocol]["dbpath"]) @@ -64,27 +83,5 @@ def delete_workspace(workspace_name): shutil.rmtree(path_join(WORKSPACE_DIR, workspace_name)) -def initialize_db(logger): - if not exists(path_join(WS_PATH, "default")): - logger.debug("Creating default workspace") - mkdir(path_join(WS_PATH, "default")) - - p_loader = ProtocolLoader() - protocols = p_loader.get_protocols() - for protocol in protocols: - protocol_object = p_loader.load_protocol(protocols[protocol]["dbpath"]) - proto_db_path = path_join(WS_PATH, "default", f"{protocol}.db") - - if not exists(proto_db_path): - logger.debug(f"Initializing {protocol.upper()} protocol database") - conn = connect(proto_db_path) - c = conn.cursor() - # try to prevent some weird sqlite I/O errors - c.execute("PRAGMA journal_mode = OFF") # could try setting to PERSIST if DB corruption starts occurring - c.execute("PRAGMA foreign_keys = 1") - # set a small timeout (5s) so if another thread is writing to the database, the entire program doesn't crash - c.execute("PRAGMA busy_timeout = 5000") - protocol_object.database.db_schema(c) - # commit the changes and close everything off - conn.commit() - conn.close() \ No newline at end of file +def initialize_db(): + create_workspace("default") \ No newline at end of file diff --git a/nxc/first_run.py b/nxc/first_run.py index e60979bc..c3b55f14 100755 --- a/nxc/first_run.py +++ b/nxc/first_run.py @@ -3,7 +3,7 @@ from os.path import exists from os.path import join as path_join import shutil from nxc.paths import NXC_PATH, CONFIG_PATH, TMP_PATH, DATA_PATH -from nxc.nxcdb import initialize_db +from nxc.database import initialize_db from nxc.logger import nxc_logger @@ -29,7 +29,7 @@ def first_run_setup(logger=nxc_logger): logger.display(f"Creating missing folder {folder}") mkdir(path_join(NXC_PATH, folder)) - initialize_db(logger) + initialize_db() if not exists(CONFIG_PATH): logger.display("Copying default configuration file") diff --git a/nxc/netexec.py b/nxc/netexec.py index e09e249e..d41dbc77 100755 --- a/nxc/netexec.py +++ b/nxc/netexec.py @@ -12,6 +12,7 @@ from nxc.paths import NXC_PATH from nxc.console import nxc_console from nxc.logger import nxc_logger from nxc.config import nxc_config, nxc_workspace, config_log, ignore_opsec +from nxc.database import create_db_engine from concurrent.futures import ThreadPoolExecutor, as_completed import asyncio from nxc.helpers import powershell @@ -21,7 +22,6 @@ from os.path import exists from os.path import join as path_join from sys import exit import logging -import sqlalchemy from rich.progress import Progress import platform @@ -38,11 +38,6 @@ if platform.system() != "Windows": resource.setrlimit(resource.RLIMIT_NOFILE, file_limit) - -def create_db_engine(db_path): - return sqlalchemy.create_engine(f"sqlite:///{db_path}", isolation_level="AUTOCOMMIT", future=True) - - async def start_run(protocol_obj, args, db, targets): nxc_logger.debug("Creating ThreadPoolExecutor") if args.no_progress or len(targets) == 1: diff --git a/nxc/nxcdb.py b/nxc/nxcdb.py index 27ce0c7f..db89cd58 100644 --- a/nxc/nxcdb.py +++ b/nxc/nxcdb.py @@ -484,7 +484,7 @@ class NXCDBMenu(cmd.Cmd): if subcommand == "create": new_workspace = line.split()[1].strip() print(f"[*] Creating workspace '{new_workspace}'") - create_workspace(new_workspace, self.p_loader, self.protocols) + create_workspace(new_workspace, self.p_loader) self.do_workspace(new_workspace) elif subcommand == "list": print("[*] Enumerating Workspaces") diff --git a/nxc/paths.py b/nxc/paths.py index 5b16c191..5ebed0e8 100644 --- a/nxc/paths.py +++ b/nxc/paths.py @@ -8,7 +8,7 @@ if os.name == "nt": TMP_PATH = os.getenv("LOCALAPPDATA") + "\\Temp\\nxc_hosted" if hasattr(sys, "getandroidapilevel"): TMP_PATH = os.path.join("/data", "data", "com.termux", "files", "usr", "tmp", "nxc_hosted") -WS_PATH = os.path.join(NXC_PATH, "workspaces") + CERT_PATH = os.path.join(NXC_PATH, "nxc.pem") CONFIG_PATH = os.path.join(NXC_PATH, "nxc.conf") WORKSPACE_DIR = os.path.join(NXC_PATH, "workspaces")