From bfcc689accbeef95fc14cedd9606b2e95478e9d8 Mon Sep 17 00:00:00 2001 From: Marshall Hallenbeck Date: Sun, 26 Mar 2023 01:52:37 -0400 Subject: [PATCH] refactor(async): update how tasks are created to new threads using proper ThreadPool; update functionality everywhere to match --- cme/cli.py | 1 + cme/cmedb.py | 13 ++- cme/connection.py | 8 ++ cme/crackmapexec.py | 109 ++++---------------- cme/protocols/ftp/database.py | 33 +++--- cme/protocols/ldap/database.py | 33 +++--- cme/protocols/mssql/database.py | 114 +++++++++++---------- cme/protocols/rdp/database.py | 33 +++--- cme/protocols/smb/database.py | 175 +++++++++++++++----------------- cme/protocols/ssh/database.py | 33 +++--- cme/protocols/vnc/database.py | 19 ++-- cme/protocols/winrm/database.py | 88 ++++++++-------- 12 files changed, 286 insertions(+), 373 deletions(-) diff --git a/cme/cli.py b/cme/cli.py index 8067f09e..14a28614 100755 --- a/cme/cli.py +++ b/cme/cli.py @@ -39,6 +39,7 @@ def gen_cli_args(): parser.add_argument("-t", type=int, dest="threads", default=100, help="set how many concurrent threads to use (default: 100)") parser.add_argument("--timeout", default=None, type=int, help='max timeout in seconds of each thread (default: None)') parser.add_argument("--jitter", metavar='INTERVAL', type=str, help='sets a random delay between each connection (default: None)') + parser.add_argument("--progress", default=True, action='store_false', help='display progress bar during scan') parser.add_argument("--darrell", action='store_true', help='give Darrell a hand') parser.add_argument("--verbose", action='store_true', help="enable verbose output") parser.add_argument("--version", action='store_true', help="Display CME version") diff --git a/cme/cmedb.py b/cme/cmedb.py index f713920e..d0a86381 100644 --- a/cme/cmedb.py +++ b/cme/cmedb.py @@ -8,12 +8,12 @@ import sqlite3 import sys import os import requests +from sqlalchemy import create_engine from terminaltables import AsciiTable import configparser from cme.loaders.protocol_loader import protocol_loader from cme.paths import CONFIG_PATH, WS_PATH, WORKSPACE_DIR from requests import ConnectionError -from sqlalchemy.ext.asyncio import create_async_engine from sqlalchemy.exc import SAWarning import asyncio import csv @@ -33,13 +33,11 @@ class UserExitedProto(Exception): def create_db_engine(db_path): - db_engine = create_async_engine( - f"sqlite+aiosqlite:///{db_path}", + db_engine = create_engine( + f"sqlite:///{db_path}", isolation_level="AUTOCOMMIT", future=True - ) # can add echo=True - # db_engine.execution_options(isolation_level="AUTOCOMMIT") - # db_engine.connect().connection.text_factory = str + ) return db_engine @@ -51,6 +49,7 @@ def print_table(data, title=None): print(table.table) print("") + def write_csv(filename, headers, entries): """ Writes a CSV file with the provided parameters. @@ -106,7 +105,7 @@ class DatabaseNavigator(cmd.Cmd): self.prompt = 'cmedb ({})({}) > '.format(main_menu.workspace, proto) def do_exit(self, line): - asyncio.run(self.db.shutdown_db()) + self.db.shutdown_db() sys.exit() def help_exit(self): diff --git a/cme/connection.py b/cme/connection.py index a0190b12..a42a2943 100755 --- a/cme/connection.py +++ b/cme/connection.py @@ -2,11 +2,14 @@ # -*- coding: utf-8 -*- import logging +import random import socket from os.path import isfile from threading import BoundedSemaphore from socket import gethostbyname from functools import wraps +from time import sleep + from cme.logger import CMEAdapter from cme.context import Context from cme.helpers.logger import write_log @@ -69,6 +72,11 @@ class connection(object): logging.debug('Error resolving hostname {}: {}'.format(self.hostname, e)) return + if args.jitter: + value = random.choice(range(args.jitter[0], args.jitter[1])) + logging.debug(f"Doin' the jitterbug for {value} second(s)") + sleep(value) + self.proto_flow() @staticmethod diff --git a/cme/crackmapexec.py b/cme/crackmapexec.py index 429386d2..9b88dd5d 100755 --- a/cme/crackmapexec.py +++ b/cme/crackmapexec.py @@ -1,5 +1,8 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- +import concurrent.futures + +import sqlalchemy from cme.logger import setup_logger, setup_debug_logger, CMEAdapter from cme.helpers.logger import highlight @@ -18,8 +21,6 @@ from concurrent.futures import ThreadPoolExecutor from pprint import pformat from decimal import Decimal import asyncio -import aioconsole -import functools import configparser import cme.helpers.powershell as powershell import cme @@ -31,9 +32,9 @@ import os import sys import logging from sqlalchemy.orm import declarative_base -from sqlalchemy.ext.asyncio import create_async_engine from sqlalchemy.exc import SAWarning import warnings +from tqdm import tqdm Base = declarative_base() @@ -50,95 +51,20 @@ warnings.filterwarnings("ignore", category=SAWarning) def create_db_engine(db_path): - db_engine = create_async_engine( - f"sqlite+aiosqlite:///{db_path}", + db_engine = sqlalchemy.create_engine( + f"sqlite:///{db_path}", isolation_level="AUTOCOMMIT", future=True - ) # can add echo=True - # db_engine.execution_options(isolation_level="AUTOCOMMIT") - # db_engine.connect().connection.text_factory = str + ) return db_engine -async def monitor_threadpool(pool, targets): - logging.debug('Started thread poller') - - while True: - try: - text = await aioconsole.ainput("") - if text == "": - pool_size = pool._work_queue.qsize() - finished_threads = len(targets) - pool_size - percentage = Decimal(finished_threads) / Decimal(len(targets)) * Decimal(100) - logger.info(f"completed: {percentage:.2f}% ({finished_threads}/{len(targets)})") - except asyncio.CancelledError: - logging.debug("Stopped thread poller") - break - - -async def run_protocol(loop, protocol_obj, args, db, target, jitter): - try: - if jitter: - value = random.choice(range(jitter[0], jitter[1])) - logging.debug(f"Doin' the jitterbug for {value} second(s)") - await asyncio.sleep(value) - - thread = loop.run_in_executor( - None, - functools.partial( - protocol_obj, - args, - db, - str(target) - ) - ) - - await asyncio.wait_for( - thread, - timeout=args.timeout - ) - - except asyncio.TimeoutError: - logging.debug("Thread exceeded timeout") - except asyncio.CancelledError: - logging.debug("Shutting down DB") - thread.cancel() - except sqlite3.OperationalError as e: - logging.debug("Sqlite error - sqlite3.operationalError - {}".format(str(e))) - - -async def start_threadpool(protocol_obj, args, db, targets, jitter): - pool = ThreadPoolExecutor(max_workers=args.threads + 1) - loop = asyncio.get_running_loop() - loop.set_default_executor(pool) - - monitor_task = asyncio.create_task( - monitor_threadpool(pool, targets) - ) - - jobs = [ - run_protocol( - loop, - protocol_obj, - args, - db, - target, - jitter - ) - for target in targets - ] - - try: - logging.debug("Running") - await asyncio.gather(*jobs) - except asyncio.CancelledError: - print('\n') - logger.info("Shutting down, please wait...") - logging.debug("Cancelling scan") - finally: - await asyncio.shield(db.shutdown_db()) - monitor_task.cancel() - pool.shutdown(wait=True) +async def start_scan(protocol_obj, args, db, targets): + with tqdm(total=len(targets), disable=args.progress) as pbar: + with ThreadPoolExecutor(max_workers=args.threads + 1) as executor: + futures = [executor.submit(protocol_obj, args, db, target) for target in targets] + for future in concurrent.futures.as_completed(futures): + pbar.update(1) def main(): @@ -164,7 +90,6 @@ def main(): module = None module_server = None targets = [] - jitter = None server_port_dict = {'http': 80, 'https': 443, 'smb': 445} current_workspace = config.get('CME', 'workspace') if config.get('CME', 'log_mode') != "False": @@ -178,9 +103,9 @@ def main(): if args.jitter: if '-' in args.jitter: start, end = args.jitter.split('-') - jitter = (int(start), int(end)) + args.jitter = (int(start), int(end)) else: - jitter = (0, int(args.jitter)) + args.jitter = (0, int(args.jitter)) if hasattr(args, 'cred_id') and args.cred_id: for cred_id in args.cred_id: @@ -298,14 +223,14 @@ def main(): try: asyncio.run( - start_threadpool(protocol_object, args, db, targets, jitter) + start_scan(protocol_object, args, db, targets) ) except KeyboardInterrupt: logging.debug("Got keyboard interrupt") finally: if module_server: module_server.shutdown() - asyncio.run(db_engine.dispose()) + db_engine.dispose() if __name__ == '__main__': diff --git a/cme/protocols/ftp/database.py b/cme/protocols/ftp/database.py index 762385d5..190d3192 100644 --- a/cme/protocols/ftp/database.py +++ b/cme/protocols/ftp/database.py @@ -16,8 +16,11 @@ class database: self.db_engine = db_engine self.metadata = MetaData() - asyncio.run(self.reflect_tables()) - session_factory = sessionmaker(bind=self.db_engine, expire_on_commit=True, class_=AsyncSession) + self.reflect_tables() + session_factory = sessionmaker( + bind=self.db_engine, + expire_on_commit=True + ) Session = scoped_session(session_factory) # this is still named "conn" when it is the session object; TODO: rename @@ -38,20 +41,9 @@ class database: "server_banner" text )''') - async def shutdown_db(self): - try: - await asyncio.shield(self.conn.close()) - # due to the async nature of CME, sometimes session state is a bit messy and this will throw: - # Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and - # this would cause an unexpected state change to - except IllegalStateChangeError as e: - logging.debug(f"Error while closing session db object: {e}") - - async def reflect_tables(self): - async with self.db_engine.connect() as conn: + def reflect_tables(self): + with self.db_engine.connect() as conn: try: - await conn.run_sync(self.metadata.reflect) - self.CredentialsTable = Table("credentials", self.metadata, autoload_with=self.db_engine) self.HostsTable = Table("hosts", self.metadata, autoload_with=self.db_engine) except NoInspectionAvailable: @@ -63,7 +55,16 @@ class database: ) exit() + def shutdown_db(self): + try: + self.conn.close() + # due to the async nature of CME, sometimes session state is a bit messy and this will throw: + # Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and + # this would cause an unexpected state change to + except IllegalStateChangeError as e: + logging.debug(f"Error while closing session db object: {e}") + def clear_database(self): for table in self.metadata.sorted_tables: - asyncio.run(self.conn.execute(table.delete())) + self.conn.execute(table.delete()) diff --git a/cme/protocols/ldap/database.py b/cme/protocols/ldap/database.py index eb19c13c..19874772 100644 --- a/cme/protocols/ldap/database.py +++ b/cme/protocols/ldap/database.py @@ -15,8 +15,11 @@ class database: self.db_engine = db_engine self.metadata = MetaData() - asyncio.run(self.reflect_tables()) - session_factory = sessionmaker(bind=self.db_engine, expire_on_commit=True, class_=AsyncSession) + self.reflect_tables() + session_factory = sessionmaker( + bind=self.db_engine, + expire_on_commit=True + ) Session = scoped_session(session_factory) # this is still named "conn" when it is the session object; TODO: rename @@ -37,20 +40,9 @@ class database: "port" integer )''') - async def shutdown_db(self): - try: - await asyncio.shield(self.conn.close()) - # due to the async nature of CME, sometimes session state is a bit messy and this will throw: - # Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and - # this would cause an unexpected state change to - except IllegalStateChangeError as e: - logging.debug(f"Error while closing session db object: {e}") - - async def reflect_tables(self): - async with self.db_engine.connect() as conn: + def reflect_tables(self): + with self.db_engine.connect() as conn: try: - await conn.run_sync(self.metadata.reflect) - self.CredentialsTable = Table("credentials", self.metadata, autoload_with=self.db_engine) self.HostsTable = Table("hosts", self.metadata, autoload_with=self.db_engine) except NoInspectionAvailable: @@ -62,6 +54,15 @@ class database: ) exit() + def shutdown_db(self): + try: + self.conn.close() + # due to the async nature of CME, sometimes session state is a bit messy and this will throw: + # Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and + # this would cause an unexpected state change to + except IllegalStateChangeError as e: + logging.debug(f"Error while closing session db object: {e}") + def clear_database(self): for table in self.metadata.sorted_tables: - asyncio.run(self.conn.execute(table.delete())) + self.conn.execute(table.delete()) diff --git a/cme/protocols/mssql/database.py b/cme/protocols/mssql/database.py index 357349a3..e6810bbe 100755 --- a/cme/protocols/mssql/database.py +++ b/cme/protocols/mssql/database.py @@ -5,9 +5,7 @@ from sqlalchemy import MetaData, func, Table, select, insert, update, delete from sqlalchemy.dialects.sqlite import Insert # used for upsert from sqlalchemy.exc import IllegalStateChangeError, NoInspectionAvailable from sqlalchemy.orm import sessionmaker, scoped_session -from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.exc import SAWarning -import asyncio import warnings # if there is an issue with SQLAlchemy and a connection cannot be cleaned up properly it spews out annoying warnings @@ -22,39 +20,16 @@ class database: self.db_engine = db_engine self.metadata = MetaData() - asyncio.run(self.reflect_tables()) - session_factory = sessionmaker(bind=self.db_engine, expire_on_commit=True, class_=AsyncSession) + self.reflect_tables() + session_factory = sessionmaker( + bind=self.db_engine, + expire_on_commit=True + ) Session = scoped_session(session_factory) # this is still named "conn" when it is the session object; TODO: rename self.conn = Session() - async def shutdown_db(self): - try: - await asyncio.shield(self.conn.close()) - # due to the async nature of CME, sometimes session state is a bit messy and this will throw: - # Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and - # this would cause an unexpected state change to - except IllegalStateChangeError as e: - logging.debug(f"Error while closing session db object: {e}") - - async def reflect_tables(self): - async with self.db_engine.connect() as conn: - try: - await conn.run_sync(self.metadata.reflect) - - self.HostsTable = Table("hosts", self.metadata, autoload_with=self.db_engine) - self.UsersTable = Table("users", self.metadata, autoload_with=self.db_engine) - self.AdminRelationsTable = Table("admin_relations", self.metadata, autoload_with=self.db_engine) - except NoInspectionAvailable: - print( - "[-] Error reflecting tables - this means there is a DB schema mismatch \n" - "[-] This is probably because a newer version of CME is being ran on an old DB schema\n" - "[-] If you wish to save the old DB data, copy it to a new location (`cp -r ~/.cme/workspaces/ ~/old_cme_workspaces/`)\n" - "[-] Then remove the CME DB folders (`rm -rf ~/.cme/workspaces/`) and rerun CME to initialize the new DB schema" - ) - exit() - @staticmethod def db_schema(db_conn): db_conn.execute('''CREATE TABLE "hosts" ( @@ -84,6 +59,34 @@ class database: FOREIGN KEY(pillaged_from_hostid) REFERENCES hosts(id) )''') + def reflect_tables(self): + with self.db_engine.connect() as conn: + try: + self.HostsTable = Table("hosts", self.metadata, autoload_with=self.db_engine) + self.UsersTable = Table("users", self.metadata, autoload_with=self.db_engine) + self.AdminRelationsTable = Table("admin_relations", self.metadata, autoload_with=self.db_engine) + except NoInspectionAvailable: + print( + "[-] Error reflecting tables - this means there is a DB schema mismatch \n" + "[-] This is probably because a newer version of CME is being ran on an old DB schema\n" + "[-] If you wish to save the old DB data, copy it to a new location (`cp -r ~/.cme/workspaces/ ~/old_cme_workspaces/`)\n" + "[-] Then remove the CME DB folders (`rm -rf ~/.cme/workspaces/`) and rerun CME to initialize the new DB schema" + ) + exit() + + def shutdown_db(self): + try: + self.conn.close() + # due to the async nature of CME, sometimes session state is a bit messy and this will throw: + # Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and + # this would cause an unexpected state change to + except IllegalStateChangeError as e: + logging.debug(f"Error while closing session db object: {e}") + + def clear_database(self): + for table in self.metadata.sorted_tables: + self.conn.execute(table.delete()) + def add_host(self, ip, hostname, domain, os, instances): """ Check if this host has already been added to the database, if not, add it in. @@ -95,7 +98,7 @@ class database: q = select(self.HostsTable).filter( self.HostsTable.c.ip == ip ) - results = asyncio.run(self.conn.execute(q)).all() + results = self.conn.execute(q).all() logging.debug(f"mssql add_host() - hosts returned: {results}") host_data = { @@ -133,11 +136,9 @@ class database: index_elements=self.HostsTable.primary_key, set_=update_columns ) - asyncio.run( - self.conn.execute( - q, - hosts - ) + self.conn.execute( + q, + hosts ) def add_credential(self, credtype, domain, username, password, pillaged_from=None): @@ -164,7 +165,7 @@ class database: func.lower(self.UsersTable.c.username) == func.lower(username), func.lower(self.UsersTable.c.credtype) == func.lower(credtype) ) - results = asyncio.run(self.conn.execute(q)).all() + results = self.conn.execute(q).all() if not results: user_data = { @@ -175,13 +176,13 @@ class database: "pillaged_from_hostid": pillaged_from, } q = insert(self.UsersTable).values(user_data) # .returning(self.UsersTable.c.id) - asyncio.run(self.conn.execute(q)) # .first() + self.conn.execute(q) # .first() else: for user in results: # might be able to just remove this if check, but leaving it in for now if not user[3] and not user[4] and not user[5]: q = update(self.UsersTable).values(credential_data) # .returning(self.UsersTable.c.id) - results = asyncio.run(self.conn.execute(q)) # .first() + results = self.conn.execute(q) # .first() # user_rowid = results.id logging.debug('add_credential(credtype={}, domain={}, username={}, password={}, pillaged_from={})'.format( @@ -203,7 +204,7 @@ class database: self.UsersTable.c.id == cred_id ) del_hosts.append(q) - asyncio.run(self.conn.execute(q)) + self.conn.execute(q) def add_admin_user(self, credtype, domain, username, password, host, user_id=None): domain = domain.split('.')[0].upper() @@ -212,7 +213,7 @@ class database: q = select(self.UsersTable).filter( self.UsersTable.c.id == user_id ) - users = asyncio.run(self.conn.execute(q)).all() + users = self.conn.execute(q).all() else: q = select(self.UsersTable).filter( self.UsersTable.c.credtype == credtype, @@ -220,31 +221,35 @@ class database: func.lower(self.UsersTable.c.username) == func.lower(username), self.UsersTable.c.password == password ) - users = asyncio.run(self.conn.execute(q)).all() + users = self.conn.execute(q).all() logging.debug(f"Users: {users}") like_term = func.lower(f"%{host}%") q = q.filter( self.HostsTable.c.ip.like(like_term) ) - hosts = asyncio.run(self.conn.execute(q)).all() + hosts = self.conn.execute(q).all() logging.debug(f"Hosts: {hosts}") if users is not None and hosts is not None: for user, host in zip(users, hosts): user_id = user[0] host_id = host[0] + link = { + "userid": user_id, + "hostid": host_id + } q = select(self.AdminRelationsTable).filter( self.AdminRelationsTable.c.userid == user_id, self.AdminRelationsTable.c.hostid == host_id ) - links = asyncio.run(self.conn.execute(q)).all() + links = self.conn.execute(q).all() if not links: - asyncio.run(self.conn.execute( + self.conn.execute( insert(self.AdminRelationsTable).values(link) - )) + ) def get_admin_relations(self, user_id=None, host_id=None): if user_id: @@ -258,7 +263,7 @@ class database: else: q = select(self.AdminRelationsTable) - results = asyncio.run(self.conn.execute(q)).all() + results = self.conn.execute(q).all() return results def remove_admin_relation(self, user_ids=None, host_ids=None): @@ -273,7 +278,7 @@ class database: q = q.filter( self.AdminRelationsTable.c.hostid == host_id ) - asyncio.run(self.conn.execute(q)) + self.conn.execute(q) def is_credential_valid(self, credential_id): """ @@ -283,7 +288,7 @@ class database: self.UsersTable.c.id == credential_id, self.UsersTable.c.password is not None ) - results = asyncio.run(self.conn.execute(q)).all() + results = self.conn.execute(q).all() return len(results) > 0 def get_credentials(self, filter_term=None, cred_type=None): @@ -309,7 +314,7 @@ class database: else: q = select(self.UsersTable) - results = asyncio.run(self.conn.execute(q)).all() + results = self.conn.execute(q).all() return results def is_host_valid(self, host_id): @@ -319,7 +324,7 @@ class database: q = select(self.HostsTable).filter( self.HostsTable.c.id == host_id ) - results = asyncio.run(self.conn.execute(q)).all() + results = self.conn.execute(q).all() return len(results) > 0 def get_hosts(self, filter_term=None, domain=None): @@ -333,7 +338,7 @@ class database: q = q.filter( self.HostsTable.c.id == filter_term ) - results = asyncio.run(self.conn.execute(q)).first() + results = self.conn.execute(q).first() # all() returns a list, so we keep the return format the same so consumers don't have to guess return [results] # if we're filtering by domain controllers @@ -353,9 +358,6 @@ class database: func.lower(self.HostsTable.c.hostname).like(like_term) ) - results = asyncio.run(self.conn.execute(q)).all() + results = self.conn.execute(q).all() return results - def clear_database(self): - for table in self.metadata.sorted_tables: - asyncio.run(self.conn.execute(table.delete())) diff --git a/cme/protocols/rdp/database.py b/cme/protocols/rdp/database.py index 8f3263e4..05a7c16c 100644 --- a/cme/protocols/rdp/database.py +++ b/cme/protocols/rdp/database.py @@ -15,8 +15,11 @@ class database: self.db_engine = db_engine self.metadata = MetaData() - asyncio.run(self.reflect_tables()) - session_factory = sessionmaker(bind=self.db_engine, expire_on_commit=True, class_=AsyncSession) + self.reflect_tables() + session_factory = sessionmaker( + bind=self.db_engine, + expire_on_commit=True + ) Session = scoped_session(session_factory) # this is still named "conn" when it is the session object; TODO: rename @@ -39,20 +42,9 @@ class database: "server_banner" text )''') - async def shutdown_db(self): - try: - await asyncio.shield(self.conn.close()) - # due to the async nature of CME, sometimes session state is a bit messy and this will throw: - # Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and - # this would cause an unexpected state change to - except IllegalStateChangeError as e: - logging.debug(f"Error while closing session db object: {e}") - - async def reflect_tables(self): - async with self.db_engine.connect() as conn: + def reflect_tables(self): + with self.db_engine.connect() as conn: try: - await conn.run_sync(self.metadata.reflect) - self.CredentialsTable = Table("credentials", self.metadata, autoload_with=self.db_engine) self.HostsTable = Table("hosts", self.metadata, autoload_with=self.db_engine) except NoInspectionAvailable: @@ -64,6 +56,15 @@ class database: ) exit() + def shutdown_db(self): + try: + self.conn.close() + # due to the async nature of CME, sometimes session state is a bit messy and this will throw: + # Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and + # this would cause an unexpected state change to + except IllegalStateChangeError as e: + logging.debug(f"Error while closing session db object: {e}") + def clear_database(self): for table in self.metadata.sorted_tables: - asyncio.run(self.conn.execute(table.delete())) + self.conn.execute(table.delete()) diff --git a/cme/protocols/smb/database.py b/cme/protocols/smb/database.py index d7070678..be75b06a 100755 --- a/cme/protocols/smb/database.py +++ b/cme/protocols/smb/database.py @@ -30,13 +30,10 @@ class database: self.db_engine = db_engine self.metadata = MetaData() - asyncio.run(self.reflect_tables()) - # we don't use async_sessionmaker or async_scoped_session because when `database` is initialized, - # there is no running async loop + self.reflect_tables() session_factory = sessionmaker( bind=self.db_engine, - expire_on_commit=True, - class_=AsyncSession + expire_on_commit=True ) Session = scoped_session(session_factory) @@ -134,11 +131,9 @@ class database: # FOREIGN KEY(hostid) REFERENCES hosts(id) # )''') - async def reflect_tables(self): - async with self.db_engine.connect() as conn: + def reflect_tables(self): + with self.db_engine.connect() as conn: try: - await conn.run_sync(self.metadata.reflect) - self.HostsTable = Table("hosts", self.metadata, autoload_with=self.db_engine) self.UsersTable = Table("users", self.metadata, autoload_with=self.db_engine) self.GroupsTable = Table("groups", self.metadata, autoload_with=self.db_engine) @@ -157,9 +152,9 @@ class database: ) exit() - async def shutdown_db(self): + def shutdown_db(self): try: - await asyncio.shield(self.conn.close()) + self.conn.close() # due to the async nature of CME, sometimes session state is a bit messy and this will throw: # Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and # this would cause an unexpected state change to @@ -168,7 +163,7 @@ class database: def clear_database(self): for table in self.metadata.sorted_tables: - asyncio.run(self.conn.execute(table.delete())) + self.conn.execute(table.delete()) # pull/545 def add_host(self, ip, hostname, domain, os, smbv1, signing, spooler=None, zerologon=None, petitpotam=None, @@ -183,7 +178,7 @@ class database: q = select(self.HostsTable).filter( self.HostsTable.c.ip == ip ) - results = asyncio.run(self.conn.execute(q)).all() + results = self.conn.execute(q).all() # create new host if not results: @@ -238,12 +233,11 @@ class database: index_elements=self.HostsTable.primary_key, set_=update_columns ) - asyncio.run( - self.conn.execute( - q, - hosts - ) - ) # .scalar() + + self.conn.execute( + q, + hosts + ) # .scalar() # we only return updated IDs for now - when RETURNING clause is allowed we can return inserted if updated_ids: logging.debug(f"add_host() - Host IDs Updated: {updated_ids}") @@ -267,7 +261,7 @@ class database: func.lower(self.UsersTable.c.username) == func.lower(username), func.lower(self.UsersTable.c.credtype) == func.lower(credtype) ) - results = asyncio.run(self.conn.execute(q)).all() + results = self.conn.execute(q).all() # add new credential if not results: @@ -314,20 +308,18 @@ class database: set_=update_columns_users ) logging.debug(f"Adding credentials: {credentials}") - asyncio.run( - self.conn.execute( - q_users, - credentials - ) - ) # .scalar() + + self.conn.execute( + q_users, + credentials + ) # .scalar() if groups: q_groups = Insert(self.GroupRelationsTable) - asyncio.run( - self.conn.execute( - q_groups, - groups - ) + + self.conn.execute( + q_groups, + groups ) # return user_ids @@ -341,7 +333,7 @@ class database: self.UsersTable.c.id == cred_id ) del_hosts.append(q) - asyncio.run(self.conn.execute(q)) + self.conn.execute(q) def add_admin_user(self, credtype, domain, username, password, host, user_id=None): domain = domain.split('.')[0] @@ -359,7 +351,7 @@ class database: func.lower(self.UsersTable.c.username) == func.lower(username), self.UsersTable.c.password == password ) - users = asyncio.run(self.conn.execute(creds_q)) + users = self.conn.execute(creds_q) hosts = self.get_hosts(host) if users and hosts: @@ -374,7 +366,7 @@ class database: self.AdminRelationsTable.c.userid == user_id, self.AdminRelationsTable.c.hostid == host_id ) - links = asyncio.run(self.conn.execute(admin_relations_select)).all() + links = self.conn.execute(admin_relations_select).all() if not links: add_links.append(link) @@ -382,10 +374,10 @@ class database: admin_relations_insert = Insert(self.AdminRelationsTable) if add_links: - asyncio.run(self.conn.execute( + self.conn.execute( admin_relations_insert, add_links - )) + ) def get_admin_relations(self, user_id=None, host_id=None): if user_id: @@ -399,7 +391,7 @@ class database: else: q = select(self.AdminRelationsTable) - results = asyncio.run(self.conn.execute(q)).all() + results = self.conn.execute(q).all() return results def remove_admin_relation(self, user_ids=None, host_ids=None): @@ -414,7 +406,7 @@ class database: q = q.filter( self.AdminRelationsTable.c.hostid == host_id ) - asyncio.run(self.conn.execute(q)) + self.conn.execute(q) def is_credential_valid(self, credential_id): """ @@ -424,7 +416,7 @@ class database: self.UsersTable.c.id == credential_id, self.UsersTable.c.password is not None ) - results = asyncio.run(self.conn.execute(q)).all() + results = self.conn.execute(q).all() return len(results) > 0 def get_credentials(self, filter_term=None, cred_type=None): @@ -450,7 +442,7 @@ class database: else: q = select(self.UsersTable) - results = asyncio.run(self.conn.execute(q)).all() + results = self.conn.execute(q).all() return results def get_credential(self, cred_type, domain, username, password): @@ -462,20 +454,20 @@ class database: self.UsersTable.c.password == password, self.UsersTable.c.credtype == cred_type ) - results = asyncio.run(self.conn.execute(q)).first() + results = self.conn.execute(q).first() return results.id def is_credential_local(self, credential_id): q = select(self.UsersTable.c.domain).filter( self.UsersTable.c.id == credential_id ) - user_domain = asyncio.run(self.conn.execute(q)).all() + user_domain = self.conn.execute(q).all() if user_domain: q = select(self.HostsTable).filter( func.lower(self.HostsTable.c.id) == func.lower(user_domain) ) - results = asyncio.run(self.conn.execute(q)).all() + results = self.conn.execute(q).all() return len(results) > 0 @@ -486,7 +478,7 @@ class database: q = select(self.HostsTable).filter( self.HostsTable.c.id == host_id ) - results = asyncio.run(self.conn.execute(q)).all() + results = self.conn.execute(q).all() return len(results) > 0 def get_hosts(self, filter_term=None, domain=None): @@ -500,7 +492,7 @@ class database: q = q.filter( self.HostsTable.c.id == filter_term ) - results = asyncio.run(self.conn.execute(q)).first() + results = self.conn.execute(q).first() # all() returns a list, so we keep the return format the same so consumers don't have to guess return [results] # if we're filtering by domain controllers @@ -542,7 +534,7 @@ class database: self.HostsTable.c.ip.like(like_term) | func.lower(self.HostsTable.c.hostname).like(like_term) ) - results = asyncio.run(self.conn.execute(q)).all() + results = self.conn.execute(q).all() logging.debug(f"smb hosts() - results: {results}") return results @@ -553,7 +545,7 @@ class database: q = select(self.GroupsTable).filter( self.GroupsTable.c.id == group_id ) - results = asyncio.run(self.conn.execute(q)).first() + results = self.conn.execute(q).first() valid = True if results else False logging.debug(f"is_group_valid(groupID={group_id}) => {valid}") @@ -585,11 +577,10 @@ class database: # insert the group and get the returned id right away, this can be refactored when we can use RETURNING q = Insert(self.GroupsTable) - asyncio.run( - self.conn.execute( - q, - groups - ) + + self.conn.execute( + q, + groups ) new_group_data = self.get_groups(group_name=group_data["name"], group_domain=group_data["domain"]) returned_id = [new_group_data[0].id] @@ -623,11 +614,10 @@ class database: index_elements=self.GroupsTable.primary_key, set_=update_columns ) - asyncio.run( - self.conn.execute( - q, - groups - ) + + self.conn.execute( + q, + groups ) # TODO: always return a list and fix code references to not expect a single integer # inserted_result = res_inserted_result.first() @@ -650,7 +640,7 @@ class database: q = select(self.GroupsTable).filter( self.GroupsTable.c.id == filter_term ) - results = asyncio.run(self.conn.execute(q)).first() + results = self.conn.execute(q).first() # all() returns a list, so we keep the return format the same so consumers don't have to guess return [results] elif group_name and group_domain: @@ -666,7 +656,7 @@ class database: else: q = select(self.GroupsTable).filter() - results = asyncio.run(self.conn.execute(q)).all() + results = self.conn.execute(q).all() logging.debug( f"get_groups(filter_term={filter_term}, groupName={group_name}, groupDomain={group_domain}) => {results}") @@ -687,7 +677,7 @@ class database: self.GroupRelationsTable.c.groupid == group_id ) - results = asyncio.run(self.conn.execute(q)).all() + results = self.conn.execute(q).all() return results def remove_group_relations(self, user_id=None, group_id=None): @@ -700,7 +690,7 @@ class database: q = q.filter( self.GroupRelationsTable.c.groupid == group_id ) - asyncio.run(self.conn.execute(q)) + self.conn.execute(q) def is_user_valid(self, user_id): """ @@ -709,7 +699,7 @@ class database: q = select(self.UsersTable).filter( self.UsersTable.c.id == user_id ) - results = asyncio.run(self.conn.execute(q)).all() + results = self.conn.execute(q).all() return len(results) > 0 def get_users(self, filter_term=None): @@ -725,7 +715,7 @@ class database: q = q.filter( func.lower(self.UsersTable.c.username).like(like_term) ) - results = asyncio.run(self.conn.execute(q)).all() + results = self.conn.execute(q).all() return results def get_user(self, domain, username): @@ -733,7 +723,7 @@ class database: func.lower(self.UsersTable.c.domain) == func.lower(domain), func.lower(self.UsersTable.c.username) == func.lower(username) ) - results = asyncio.run(self.conn.execute(q)).all() + results = self.conn.execute(q).all() return results def get_domain_controllers(self, domain=None): @@ -746,7 +736,7 @@ class database: q = select(self.SharesTable).filter( self.SharesTable.c.id == share_id ) - results = asyncio.run(self.conn.execute(q)).all() + results = self.conn.execute(q).all() logging.debug(f"is_share_valid(shareID={share_id}) => {len(results) > 0}") return len(results) > 0 @@ -760,10 +750,10 @@ class database: "read": read, "write": write, } - share_id = asyncio.run(self.conn.execute( + share_id = self.conn.execute( Insert(self.SharesTable).on_conflict_do_nothing(), # .returning(self.SharesTable.c.id), share_data - )) # .scalar_one() + ) # .scalar_one() # return share_id def get_shares(self, filter_term=None): @@ -778,7 +768,7 @@ class database: ) else: q = select(self.SharesTable) - results = asyncio.run(self.conn.execute(q)).all() + results = self.conn.execute(q).all() return results def get_shares_by_access(self, permissions, share_id=None): @@ -790,7 +780,7 @@ class database: q = q.filter(self.SharesTable.c.read == 1) if "w" in permissions: q = q.filter(self.SharesTable.c.write == 1) - results = asyncio.run(self.conn.execute(q)).all() + results = self.conn.execute(q).all() return results def get_users_with_share_access(self, host_id, share_name, permissions): @@ -803,7 +793,7 @@ class database: q = q.filter(self.SharesTable.c.read == 1) if "w" in permissions: q = q.filter(self.SharesTable.c.write == 1) - results = asyncio.run(self.conn.execute(q)).all() + results = self.conn.execute(q).all() return results @@ -816,7 +806,7 @@ class database: q = select(self.DpapiBackupkey).filter( func.lower(self.DpapiBackupkey.c.domain) == func.lower(domain) ) - results = asyncio.run(self.conn.execute(q)).all() + results = self.conn.execute(q).all() if not len(results): pvk_encoded = base64.b64encode(pvk) @@ -827,11 +817,10 @@ class database: try: # TODO: find a way to abstract this away to a single Upsert call q = Insert(self.DpapiBackupkey) # .returning(self.DpapiBackupkey.c.id) - asyncio.run( - self.conn.execute( - q, - [backup_key] - ) + + self.conn.execute( + q, + [backup_key] ) # .scalar() logging.debug(f"add_domain_backupkey(domain={domain}, pvk={pvk_encoded})") # return inserted_id @@ -848,7 +837,7 @@ class database: q = q.filter( func.lower(self.DpapiBackupkey.c.domain) == func.lower(domain) ) - results = asyncio.run(self.conn.execute(q)).all() + results = self.conn.execute(q).all() logging.debug(f"get_domain_backupkey(domain={domain}) => {results}") @@ -864,7 +853,7 @@ class database: q = select(self.DpapiSecrets).filter( func.lower(self.DpapiSecrets.c.id) == dpapi_secret_id ) - results = asyncio.run(self.conn.execute(q)).first() + results = self.conn.execute(q).first() valid = True if results is not None else False logging.debug(f"is_dpapi_secret_valid(groupID={dpapi_secret_id}) => {valid}") return valid @@ -883,11 +872,10 @@ class database: "url": url } q = Insert(self.DpapiSecrets).on_conflict_do_nothing() # .returning(self.DpapiSecrets.c.id) - asyncio.run( - self.conn.execute( - q, - [secret] - ) + + self.conn.execute( + q, + [secret] ) # .scalar() # inserted_result = res_inserted_result.first() @@ -907,14 +895,14 @@ class database: q = q.filter( self.DpapiSecrets.c.id == filter_term ) - results = asyncio.run(self.conn.execute(q)).first() + results = self.conn.execute(q).first() # all() returns a list, so we keep the return format the same so consumers don't have to guess return [results] elif host: q = q.filter( self.DpapiSecrets.c.host == host ) - results = asyncio.run(self.conn.execute(q)).first() + results = self.conn.execute(q).first() # all() returns a list, so we keep the return format the same so consumers don't have to guess return [results] elif dpapi_type: @@ -935,7 +923,7 @@ class database: q = q.filter( func.lower(self.DpapiSecrets.c.url) == func.lower(url) ) - results = asyncio.run(self.conn.execute(q)).all() + results = self.conn.execute(q).all() logging.debug( f"get_dpapi_secrets(filter_term={filter_term}, host={host}, dpapi_type={dpapi_type}, windows_user={windows_user}, username={username}, url={url}) => {results}") @@ -946,7 +934,7 @@ class database: self.LoggedinRelationsTable.c.userid == user_id, self.LoggedinRelationsTable.c.hostid == host_id ) - results = asyncio.run(self.conn.execute(relation_query)).all() + results = self.conn.execute(relation_query).all() # only add one if one doesn't already exist if not results: @@ -958,11 +946,10 @@ class database: logging.debug(f"Inserting loggedin_relations: {relation}") # TODO: find a way to abstract this away to a single Upsert call q = Insert(self.LoggedinRelationsTable) # .returning(self.LoggedinRelationsTable.c.id) - asyncio.run( - self.conn.execute( - q, - [relation] - ) + + self.conn.execute( + q, + [relation] ) # .scalar() inserted_id_results = self.get_loggedin_relations(user_id, host_id) logging.debug(f"Checking if relation was added: {inserted_id_results}") @@ -980,7 +967,7 @@ class database: q = q.filter( self.LoggedinRelationsTable.c.hostid == host_id ) - results = asyncio.run(self.conn.execute(q)).all() + results = self.conn.execute(q).all() return results def remove_loggedin_relations(self, user_id=None, host_id=None): @@ -993,4 +980,4 @@ class database: q = q.filter( self.LoggedinRelationsTable.c.hostid == host_id ) - asyncio.run(self.conn.execute(q)) + self.conn.execute(q) diff --git a/cme/protocols/ssh/database.py b/cme/protocols/ssh/database.py index 5542eadb..69721ccb 100644 --- a/cme/protocols/ssh/database.py +++ b/cme/protocols/ssh/database.py @@ -15,8 +15,11 @@ class database: self.db_engine = db_engine self.metadata = MetaData() - asyncio.run(self.reflect_tables()) - session_factory = sessionmaker(bind=self.db_engine, expire_on_commit=True, class_=AsyncSession) + self.reflect_tables() + session_factory = sessionmaker( + bind=self.db_engine, + expire_on_commit=True + ) Session = scoped_session(session_factory) # this is still named "conn" when it is the session object; TODO: rename @@ -37,20 +40,9 @@ class database: "server_banner" text )''') - async def shutdown_db(self): - try: - await asyncio.shield(self.conn.close()) - # due to the async nature of CME, sometimes session state is a bit messy and this will throw: - # Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and - # this would cause an unexpected state change to - except IllegalStateChangeError as e: - logging.debug(f"Error while closing session db object: {e}") - - async def reflect_tables(self): - async with self.db_engine.connect() as conn: + def reflect_tables(self): + with self.db_engine.connect() as conn: try: - await conn.run_sync(self.metadata.reflect) - self.CredentialsTable = Table("credentials", self.metadata, autoload_with=self.db_engine) self.HostsTable = Table("hosts", self.metadata, autoload_with=self.db_engine) except NoInspectionAvailable: @@ -62,6 +54,15 @@ class database: ) exit() + def shutdown_db(self): + try: + self.conn.close() + # due to the async nature of CME, sometimes session state is a bit messy and this will throw: + # Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and + # this would cause an unexpected state change to + except IllegalStateChangeError as e: + logging.debug(f"Error while closing session db object: {e}") + def clear_database(self): for table in self.metadata.sorted_tables: - asyncio.run(self.conn.execute(table.delete())) + self.conn.execute(table.delete()) diff --git a/cme/protocols/vnc/database.py b/cme/protocols/vnc/database.py index 62c8dd55..2b5aeac5 100644 --- a/cme/protocols/vnc/database.py +++ b/cme/protocols/vnc/database.py @@ -22,13 +22,10 @@ class database: self.db_engine = db_engine self.metadata = MetaData() - asyncio.run(self.reflect_tables()) - # we don't use async_sessionmaker or async_scoped_session because when `database` is initialized, - # there is no running async loop + self.reflect_tables() session_factory = sessionmaker( bind=self.db_engine, - expire_on_commit=True, - class_=AsyncSession + expire_on_commit=True ) Session = scoped_session(session_factory) @@ -52,11 +49,9 @@ class database: "server_banner" text )''') - async def reflect_tables(self): - async with self.db_engine.connect() as conn: + def reflect_tables(self): + with self.db_engine.connect() as conn: try: - await conn.run_sync(self.metadata.reflect) - self.HostsTable = Table("hosts", self.metadata, autoload_with=self.db_engine) self.CredentialsTable = Table("credentials", self.metadata, autoload_with=self.db_engine) except NoInspectionAvailable: @@ -68,9 +63,9 @@ class database: ) exit() - async def shutdown_db(self): + def shutdown_db(self): try: - await asyncio.shield(self.conn.close()) + self.conn.close() # due to the async nature of CME, sometimes session state is a bit messy and this will throw: # Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and # this would cause an unexpected state change to @@ -79,4 +74,4 @@ class database: def clear_database(self): for table in self.metadata.sorted_tables: - asyncio.run(self.conn.execute(table.delete())) \ No newline at end of file + self.conn.execute(table.delete()) diff --git a/cme/protocols/winrm/database.py b/cme/protocols/winrm/database.py index 7044cd84..6af7dfdb 100644 --- a/cme/protocols/winrm/database.py +++ b/cme/protocols/winrm/database.py @@ -19,11 +19,10 @@ class database: self.db_engine = db_engine self.metadata = MetaData() - asyncio.run(self.reflect_tables()) + self.reflect_tables() session_factory = sessionmaker( bind=self.db_engine, - expire_on_commit=True, - class_=AsyncSession + expire_on_commit=True ) Session = scoped_session(session_factory) @@ -64,11 +63,9 @@ class database: FOREIGN KEY(hostid) REFERENCES hosts(id) )''') - async def reflect_tables(self): - async with self.db_engine.connect() as conn: + def reflect_tables(self): + with self.db_engine.connect() as conn: try: - await conn.run_sync(self.metadata.reflect) - self.HostsTable = Table("hosts", self.metadata, autoload_with=self.db_engine) self.UsersTable = Table("users", self.metadata, autoload_with=self.db_engine) self.AdminRelationsTable = Table("admin_relations", self.metadata, autoload_with=self.db_engine) @@ -82,9 +79,9 @@ class database: ) exit() - async def shutdown_db(self): + def shutdown_db(self): try: - await asyncio.shield(self.conn.close()) + self.conn.close() # due to the async nature of CME, sometimes session state is a bit messy and this will throw: # Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and # this would cause an unexpected state change to @@ -93,7 +90,7 @@ class database: def clear_database(self): for table in self.metadata.sorted_tables: - asyncio.run(self.conn.execute(table.delete())) + self.conn.execute(table.delete()) def add_host(self, ip, port, hostname, domain, os=None): """ @@ -106,7 +103,7 @@ class database: q = select(self.HostsTable).filter( self.HostsTable.c.ip == ip ) - results = asyncio.run(self.conn.execute(q)).all() + results = self.conn.execute(q).all() logging.debug(f"smb add_host() - hosts returned: {results}") # create new host @@ -146,11 +143,9 @@ class database: index_elements=self.HostsTable.primary_key, set_=update_columns ) - asyncio.run( - self.conn.execute( - q, - hosts - ) + self.conn.execute( + q, + hosts ) def add_credential(self, credtype, domain, username, password, pillaged_from=None): @@ -177,7 +172,7 @@ class database: func.lower(self.UsersTable.c.username) == func.lower(username), func.lower(self.UsersTable.c.credtype) == func.lower(credtype) ) - results = asyncio.run(self.conn.execute(q)).all() + results = self.conn.execute(q).all() # add new credential if not results: @@ -216,12 +211,10 @@ class database: index_elements=self.UsersTable.primary_key, set_=update_columns_users ) - asyncio.run( - self.conn.execute( - q_users, - credentials - ) - ) # .scalar() + self.conn.execute( + q_users, + credentials + ) # .scalar() # return user_ids def remove_credentials(self, creds_id): @@ -234,7 +227,7 @@ class database: self.UsersTable.c.id == cred_id ) del_hosts.append(q) - asyncio.run(self.conn.execute(q)) + self.conn.execute(q) def add_admin_user(self, credtype, domain, username, password, host, user_id=None): domain = domain.split('.')[0] @@ -252,7 +245,7 @@ class database: func.lower(self.UsersTable.c.username) == func.lower(username), self.UsersTable.c.password == password ) - users = asyncio.run(self.conn.execute(creds_q)) + users = self.conn.execute(creds_q) hosts = self.get_hosts(host) if users and hosts: @@ -267,17 +260,17 @@ class database: self.AdminRelationsTable.c.userid == user_id, self.AdminRelationsTable.c.hostid == host_id ) - links = asyncio.run(self.conn.execute(admin_relations_select)).all() + links = self.conn.execute(admin_relations_select).all() if not links: add_links.append(link) admin_relations_insert = Insert(self.AdminRelationsTable) - asyncio.run(self.conn.execute( + self.conn.execute( admin_relations_insert, add_links - )) + ) def get_admin_relations(self, user_id=None, host_id=None): if user_id: @@ -291,7 +284,7 @@ class database: else: q = select(self.AdminRelationsTable) - results = asyncio.run(self.conn.execute(q)).all() + results = self.conn.execute(q).all() return results def remove_admin_relation(self, user_ids=None, host_ids=None): @@ -306,7 +299,7 @@ class database: q = q.filter( self.AdminRelationsTable.c.hostid == host_id ) - asyncio.run(self.conn.execute(q)) + self.conn.execute(q) def is_credential_valid(self, credential_id): """ @@ -316,7 +309,7 @@ class database: self.UsersTable.c.id == credential_id, self.UsersTable.c.password is not None ) - results = asyncio.run(self.conn.execute(q)).all() + results = self.conn.execute(q).all() return len(results) > 0 def get_credentials(self, filter_term=None, cred_type=None): @@ -342,20 +335,20 @@ class database: else: q = select(self.UsersTable) - results = asyncio.run(self.conn.execute(q)).all() + results = self.conn.execute(q).all() return results def is_credential_local(self, credential_id): q = select(self.UsersTable.c.domain).filter( self.UsersTable.c.id == credential_id ) - user_domain = asyncio.run(self.conn.execute(q)).all() + user_domain = self.conn.execute(q).all() if user_domain: q = select(self.HostsTable).filter( func.lower(self.HostsTable.c.id) == func.lower(user_domain) ) - results = asyncio.run(self.conn.execute(q)).all() + results = self.conn.execute(q).all() return len(results) > 0 @@ -366,7 +359,7 @@ class database: q = select(self.HostsTable).filter( self.HostsTable.c.id == host_id ) - results = asyncio.run(self.conn.execute(q)).all() + results = self.conn.execute(q).all() return len(results) > 0 def get_hosts(self, filter_term=None): @@ -380,7 +373,7 @@ class database: q = q.filter( self.HostsTable.c.id == filter_term ) - results = asyncio.run(self.conn.execute(q)).first() + results = self.conn.execute(q).first() # all() returns a list, so we keep the return format the same so consumers don't have to guess return [results] # if we're filtering by domain controllers @@ -397,7 +390,7 @@ class database: self.HostsTable.c.ip.like(like_term) | func.lower(self.HostsTable.c.hostname).like(like_term) ) - results = asyncio.run(self.conn.execute(q)).all() + results = self.conn.execute(q).all() logging.debug(f"winrm get_hosts() - results: {results}") return results @@ -408,7 +401,7 @@ class database: q = select(self.UsersTable).filter( self.UsersTable.c.id == user_id ) - results = asyncio.run(self.conn.execute(q)).all() + results = self.conn.execute(q).all() return len(results) > 0 def get_users(self, filter_term=None): @@ -424,7 +417,7 @@ class database: q = q.filter( func.lower(self.UsersTable.c.username).like(like_term) ) - results = asyncio.run(self.conn.execute(q)).all() + results = self.conn.execute(q).all() return results def get_user(self, domain, username): @@ -432,7 +425,7 @@ class database: func.lower(self.UsersTable.c.domain) == func.lower(domain), func.lower(self.UsersTable.c.username) == func.lower(username) ) - results = asyncio.run(self.conn.execute(q)).all() + results = self.conn.execute(q).all() return results def add_loggedin_relation(self, user_id, host_id): @@ -440,7 +433,7 @@ class database: self.LoggedinRelationsTable.c.userid == user_id, self.LoggedinRelationsTable.c.hostid == host_id ) - results = asyncio.run(self.conn.execute(relation_query)).all() + results = self.conn.execute(relation_query).all() # only add one if one doesn't already exist if not results: @@ -451,11 +444,10 @@ class database: try: # TODO: find a way to abstract this away to a single Upsert call q = Insert(self.LoggedinRelationsTable) # .returning(self.LoggedinRelationsTable.c.id) - asyncio.run( - self.conn.execute( - q, - [relation] - ) + + self.conn.execute( + q, + [relation] ) # .scalar() # return inserted_ids except Exception as e: @@ -471,7 +463,7 @@ class database: q = q.filter( self.LoggedinRelationsTable.c.hostid == host_id ) - results = asyncio.run(self.conn.execute(q)).all() + results = self.conn.execute(q).all() return results def remove_loggedin_relations(self, user_id=None, host_id=None): @@ -484,4 +476,4 @@ class database: q = q.filter( self.LoggedinRelationsTable.c.hostid == host_id ) - asyncio.run(self.conn.execute(q)) \ No newline at end of file + self.conn.execute(q) \ No newline at end of file