refactor: deduplicate code and simplify initial db setup

main
Marshall Hallenbeck 2023-11-17 21:24:03 -05:00
parent b4f3bacb99
commit 861626d061
5 changed files with 29 additions and 37 deletions

View File

@ -8,7 +8,9 @@ from os.path import exists
from os.path import join as path_join from os.path import join as path_join
from nxc.loaders.protocolloader import ProtocolLoader from nxc.loaders.protocolloader import ProtocolLoader
from nxc.paths import WS_PATH, WORKSPACE_DIR from nxc.paths import WORKSPACE_DIR
from nxc.logger import nxc_logger
def create_db_engine(db_path): def create_db_engine(db_path):
return create_engine(f"sqlite:///{db_path}", isolation_level="AUTOCOMMIT", future=True) return create_engine(f"sqlite:///{db_path}", isolation_level="AUTOCOMMIT", future=True)
@ -37,8 +39,25 @@ def write_configfile(config, config_path):
config.write(configfile) config.write(configfile)
def create_workspace(workspace_name, p_loader, protocols): def create_workspace(workspace_name, p_loader=None):
mkdir(path_join(WORKSPACE_DIR, workspace_name)) """
Create a new workspace with the given name.
Args:
----
workspace_name (str): The name of the workspace.
Returns:
-------
None
"""
if not exists(path_join(WORKSPACE_DIR, workspace_name)):
nxc_logger.debug(f"Creating {workspace_name} workspace")
mkdir(path_join(WORKSPACE_DIR, workspace_name))
if p_loader is None:
p_loader = ProtocolLoader()
protocols = p_loader.get_protocols()
for protocol in protocols: for protocol in protocols:
protocol_object = p_loader.load_protocol(protocols[protocol]["dbpath"]) protocol_object = p_loader.load_protocol(protocols[protocol]["dbpath"])
@ -64,27 +83,5 @@ def delete_workspace(workspace_name):
shutil.rmtree(path_join(WORKSPACE_DIR, workspace_name)) shutil.rmtree(path_join(WORKSPACE_DIR, workspace_name))
def initialize_db(logger): def initialize_db():
if not exists(path_join(WS_PATH, "default")): create_workspace("default")
logger.debug("Creating default workspace")
mkdir(path_join(WS_PATH, "default"))
p_loader = ProtocolLoader()
protocols = p_loader.get_protocols()
for protocol in protocols:
protocol_object = p_loader.load_protocol(protocols[protocol]["dbpath"])
proto_db_path = path_join(WS_PATH, "default", f"{protocol}.db")
if not exists(proto_db_path):
logger.debug(f"Initializing {protocol.upper()} protocol database")
conn = connect(proto_db_path)
c = conn.cursor()
# try to prevent some weird sqlite I/O errors
c.execute("PRAGMA journal_mode = OFF") # could try setting to PERSIST if DB corruption starts occurring
c.execute("PRAGMA foreign_keys = 1")
# set a small timeout (5s) so if another thread is writing to the database, the entire program doesn't crash
c.execute("PRAGMA busy_timeout = 5000")
protocol_object.database.db_schema(c)
# commit the changes and close everything off
conn.commit()
conn.close()

View File

@ -3,7 +3,7 @@ from os.path import exists
from os.path import join as path_join from os.path import join as path_join
import shutil import shutil
from nxc.paths import NXC_PATH, CONFIG_PATH, TMP_PATH, DATA_PATH from nxc.paths import NXC_PATH, CONFIG_PATH, TMP_PATH, DATA_PATH
from nxc.nxcdb import initialize_db from nxc.database import initialize_db
from nxc.logger import nxc_logger from nxc.logger import nxc_logger
@ -29,7 +29,7 @@ def first_run_setup(logger=nxc_logger):
logger.display(f"Creating missing folder {folder}") logger.display(f"Creating missing folder {folder}")
mkdir(path_join(NXC_PATH, folder)) mkdir(path_join(NXC_PATH, folder))
initialize_db(logger) initialize_db()
if not exists(CONFIG_PATH): if not exists(CONFIG_PATH):
logger.display("Copying default configuration file") logger.display("Copying default configuration file")

View File

@ -12,6 +12,7 @@ from nxc.paths import NXC_PATH
from nxc.console import nxc_console from nxc.console import nxc_console
from nxc.logger import nxc_logger from nxc.logger import nxc_logger
from nxc.config import nxc_config, nxc_workspace, config_log, ignore_opsec from nxc.config import nxc_config, nxc_workspace, config_log, ignore_opsec
from nxc.database import create_db_engine
from concurrent.futures import ThreadPoolExecutor, as_completed from concurrent.futures import ThreadPoolExecutor, as_completed
import asyncio import asyncio
from nxc.helpers import powershell from nxc.helpers import powershell
@ -21,7 +22,6 @@ from os.path import exists
from os.path import join as path_join from os.path import join as path_join
from sys import exit from sys import exit
import logging import logging
import sqlalchemy
from rich.progress import Progress from rich.progress import Progress
import platform import platform
@ -38,11 +38,6 @@ if platform.system() != "Windows":
resource.setrlimit(resource.RLIMIT_NOFILE, file_limit) resource.setrlimit(resource.RLIMIT_NOFILE, file_limit)
def create_db_engine(db_path):
return sqlalchemy.create_engine(f"sqlite:///{db_path}", isolation_level="AUTOCOMMIT", future=True)
async def start_run(protocol_obj, args, db, targets): async def start_run(protocol_obj, args, db, targets):
nxc_logger.debug("Creating ThreadPoolExecutor") nxc_logger.debug("Creating ThreadPoolExecutor")
if args.no_progress or len(targets) == 1: if args.no_progress or len(targets) == 1:

View File

@ -484,7 +484,7 @@ class NXCDBMenu(cmd.Cmd):
if subcommand == "create": if subcommand == "create":
new_workspace = line.split()[1].strip() new_workspace = line.split()[1].strip()
print(f"[*] Creating workspace '{new_workspace}'") print(f"[*] Creating workspace '{new_workspace}'")
create_workspace(new_workspace, self.p_loader, self.protocols) create_workspace(new_workspace, self.p_loader)
self.do_workspace(new_workspace) self.do_workspace(new_workspace)
elif subcommand == "list": elif subcommand == "list":
print("[*] Enumerating Workspaces") print("[*] Enumerating Workspaces")

View File

@ -8,7 +8,7 @@ if os.name == "nt":
TMP_PATH = os.getenv("LOCALAPPDATA") + "\\Temp\\nxc_hosted" TMP_PATH = os.getenv("LOCALAPPDATA") + "\\Temp\\nxc_hosted"
if hasattr(sys, "getandroidapilevel"): if hasattr(sys, "getandroidapilevel"):
TMP_PATH = os.path.join("/data", "data", "com.termux", "files", "usr", "tmp", "nxc_hosted") TMP_PATH = os.path.join("/data", "data", "com.termux", "files", "usr", "tmp", "nxc_hosted")
WS_PATH = os.path.join(NXC_PATH, "workspaces")
CERT_PATH = os.path.join(NXC_PATH, "nxc.pem") CERT_PATH = os.path.join(NXC_PATH, "nxc.pem")
CONFIG_PATH = os.path.join(NXC_PATH, "nxc.conf") CONFIG_PATH = os.path.join(NXC_PATH, "nxc.conf")
WORKSPACE_DIR = os.path.join(NXC_PATH, "workspaces") WORKSPACE_DIR = os.path.join(NXC_PATH, "workspaces")