refactor(async): update how tasks are created to new threads using proper ThreadPool; update functionality everywhere to match

main
Marshall Hallenbeck 2023-03-26 01:52:37 -04:00
parent 4c76a30a4a
commit bfcc689acc
12 changed files with 286 additions and 373 deletions

View File

@ -39,6 +39,7 @@ def gen_cli_args():
parser.add_argument("-t", type=int, dest="threads", default=100, help="set how many concurrent threads to use (default: 100)")
parser.add_argument("--timeout", default=None, type=int, help='max timeout in seconds of each thread (default: None)')
parser.add_argument("--jitter", metavar='INTERVAL', type=str, help='sets a random delay between each connection (default: None)')
parser.add_argument("--progress", default=True, action='store_false', help='display progress bar during scan')
parser.add_argument("--darrell", action='store_true', help='give Darrell a hand')
parser.add_argument("--verbose", action='store_true', help="enable verbose output")
parser.add_argument("--version", action='store_true', help="Display CME version")

View File

@ -8,12 +8,12 @@ import sqlite3
import sys
import os
import requests
from sqlalchemy import create_engine
from terminaltables import AsciiTable
import configparser
from cme.loaders.protocol_loader import protocol_loader
from cme.paths import CONFIG_PATH, WS_PATH, WORKSPACE_DIR
from requests import ConnectionError
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.exc import SAWarning
import asyncio
import csv
@ -33,13 +33,11 @@ class UserExitedProto(Exception):
def create_db_engine(db_path):
db_engine = create_async_engine(
f"sqlite+aiosqlite:///{db_path}",
db_engine = create_engine(
f"sqlite:///{db_path}",
isolation_level="AUTOCOMMIT",
future=True
) # can add echo=True
# db_engine.execution_options(isolation_level="AUTOCOMMIT")
# db_engine.connect().connection.text_factory = str
)
return db_engine
@ -51,6 +49,7 @@ def print_table(data, title=None):
print(table.table)
print("")
def write_csv(filename, headers, entries):
"""
Writes a CSV file with the provided parameters.
@ -106,7 +105,7 @@ class DatabaseNavigator(cmd.Cmd):
self.prompt = 'cmedb ({})({}) > '.format(main_menu.workspace, proto)
def do_exit(self, line):
asyncio.run(self.db.shutdown_db())
self.db.shutdown_db()
sys.exit()
def help_exit(self):

View File

@ -2,11 +2,14 @@
# -*- coding: utf-8 -*-
import logging
import random
import socket
from os.path import isfile
from threading import BoundedSemaphore
from socket import gethostbyname
from functools import wraps
from time import sleep
from cme.logger import CMEAdapter
from cme.context import Context
from cme.helpers.logger import write_log
@ -69,6 +72,11 @@ class connection(object):
logging.debug('Error resolving hostname {}: {}'.format(self.hostname, e))
return
if args.jitter:
value = random.choice(range(args.jitter[0], args.jitter[1]))
logging.debug(f"Doin' the jitterbug for {value} second(s)")
sleep(value)
self.proto_flow()
@staticmethod

View File

@ -1,5 +1,8 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import concurrent.futures
import sqlalchemy
from cme.logger import setup_logger, setup_debug_logger, CMEAdapter
from cme.helpers.logger import highlight
@ -18,8 +21,6 @@ from concurrent.futures import ThreadPoolExecutor
from pprint import pformat
from decimal import Decimal
import asyncio
import aioconsole
import functools
import configparser
import cme.helpers.powershell as powershell
import cme
@ -31,9 +32,9 @@ import os
import sys
import logging
from sqlalchemy.orm import declarative_base
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.exc import SAWarning
import warnings
from tqdm import tqdm
Base = declarative_base()
@ -50,95 +51,20 @@ warnings.filterwarnings("ignore", category=SAWarning)
def create_db_engine(db_path):
db_engine = create_async_engine(
f"sqlite+aiosqlite:///{db_path}",
db_engine = sqlalchemy.create_engine(
f"sqlite:///{db_path}",
isolation_level="AUTOCOMMIT",
future=True
) # can add echo=True
# db_engine.execution_options(isolation_level="AUTOCOMMIT")
# db_engine.connect().connection.text_factory = str
)
return db_engine
async def monitor_threadpool(pool, targets):
logging.debug('Started thread poller')
while True:
try:
text = await aioconsole.ainput("")
if text == "":
pool_size = pool._work_queue.qsize()
finished_threads = len(targets) - pool_size
percentage = Decimal(finished_threads) / Decimal(len(targets)) * Decimal(100)
logger.info(f"completed: {percentage:.2f}% ({finished_threads}/{len(targets)})")
except asyncio.CancelledError:
logging.debug("Stopped thread poller")
break
async def run_protocol(loop, protocol_obj, args, db, target, jitter):
try:
if jitter:
value = random.choice(range(jitter[0], jitter[1]))
logging.debug(f"Doin' the jitterbug for {value} second(s)")
await asyncio.sleep(value)
thread = loop.run_in_executor(
None,
functools.partial(
protocol_obj,
args,
db,
str(target)
)
)
await asyncio.wait_for(
thread,
timeout=args.timeout
)
except asyncio.TimeoutError:
logging.debug("Thread exceeded timeout")
except asyncio.CancelledError:
logging.debug("Shutting down DB")
thread.cancel()
except sqlite3.OperationalError as e:
logging.debug("Sqlite error - sqlite3.operationalError - {}".format(str(e)))
async def start_threadpool(protocol_obj, args, db, targets, jitter):
pool = ThreadPoolExecutor(max_workers=args.threads + 1)
loop = asyncio.get_running_loop()
loop.set_default_executor(pool)
monitor_task = asyncio.create_task(
monitor_threadpool(pool, targets)
)
jobs = [
run_protocol(
loop,
protocol_obj,
args,
db,
target,
jitter
)
for target in targets
]
try:
logging.debug("Running")
await asyncio.gather(*jobs)
except asyncio.CancelledError:
print('\n')
logger.info("Shutting down, please wait...")
logging.debug("Cancelling scan")
finally:
await asyncio.shield(db.shutdown_db())
monitor_task.cancel()
pool.shutdown(wait=True)
async def start_scan(protocol_obj, args, db, targets):
with tqdm(total=len(targets), disable=args.progress) as pbar:
with ThreadPoolExecutor(max_workers=args.threads + 1) as executor:
futures = [executor.submit(protocol_obj, args, db, target) for target in targets]
for future in concurrent.futures.as_completed(futures):
pbar.update(1)
def main():
@ -164,7 +90,6 @@ def main():
module = None
module_server = None
targets = []
jitter = None
server_port_dict = {'http': 80, 'https': 443, 'smb': 445}
current_workspace = config.get('CME', 'workspace')
if config.get('CME', 'log_mode') != "False":
@ -178,9 +103,9 @@ def main():
if args.jitter:
if '-' in args.jitter:
start, end = args.jitter.split('-')
jitter = (int(start), int(end))
args.jitter = (int(start), int(end))
else:
jitter = (0, int(args.jitter))
args.jitter = (0, int(args.jitter))
if hasattr(args, 'cred_id') and args.cred_id:
for cred_id in args.cred_id:
@ -298,14 +223,14 @@ def main():
try:
asyncio.run(
start_threadpool(protocol_object, args, db, targets, jitter)
start_scan(protocol_object, args, db, targets)
)
except KeyboardInterrupt:
logging.debug("Got keyboard interrupt")
finally:
if module_server:
module_server.shutdown()
asyncio.run(db_engine.dispose())
db_engine.dispose()
if __name__ == '__main__':

View File

@ -16,8 +16,11 @@ class database:
self.db_engine = db_engine
self.metadata = MetaData()
asyncio.run(self.reflect_tables())
session_factory = sessionmaker(bind=self.db_engine, expire_on_commit=True, class_=AsyncSession)
self.reflect_tables()
session_factory = sessionmaker(
bind=self.db_engine,
expire_on_commit=True
)
Session = scoped_session(session_factory)
# this is still named "conn" when it is the session object; TODO: rename
@ -38,20 +41,9 @@ class database:
"server_banner" text
)''')
async def shutdown_db(self):
try:
await asyncio.shield(self.conn.close())
# due to the async nature of CME, sometimes session state is a bit messy and this will throw:
# Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and
# this would cause an unexpected state change to <SessionTransactionState.CLOSED: 5>
except IllegalStateChangeError as e:
logging.debug(f"Error while closing session db object: {e}")
async def reflect_tables(self):
async with self.db_engine.connect() as conn:
def reflect_tables(self):
with self.db_engine.connect() as conn:
try:
await conn.run_sync(self.metadata.reflect)
self.CredentialsTable = Table("credentials", self.metadata, autoload_with=self.db_engine)
self.HostsTable = Table("hosts", self.metadata, autoload_with=self.db_engine)
except NoInspectionAvailable:
@ -63,7 +55,16 @@ class database:
)
exit()
def shutdown_db(self):
try:
self.conn.close()
# due to the async nature of CME, sometimes session state is a bit messy and this will throw:
# Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and
# this would cause an unexpected state change to <SessionTransactionState.CLOSED: 5>
except IllegalStateChangeError as e:
logging.debug(f"Error while closing session db object: {e}")
def clear_database(self):
for table in self.metadata.sorted_tables:
asyncio.run(self.conn.execute(table.delete()))
self.conn.execute(table.delete())

View File

@ -15,8 +15,11 @@ class database:
self.db_engine = db_engine
self.metadata = MetaData()
asyncio.run(self.reflect_tables())
session_factory = sessionmaker(bind=self.db_engine, expire_on_commit=True, class_=AsyncSession)
self.reflect_tables()
session_factory = sessionmaker(
bind=self.db_engine,
expire_on_commit=True
)
Session = scoped_session(session_factory)
# this is still named "conn" when it is the session object; TODO: rename
@ -37,20 +40,9 @@ class database:
"port" integer
)''')
async def shutdown_db(self):
try:
await asyncio.shield(self.conn.close())
# due to the async nature of CME, sometimes session state is a bit messy and this will throw:
# Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and
# this would cause an unexpected state change to <SessionTransactionState.CLOSED: 5>
except IllegalStateChangeError as e:
logging.debug(f"Error while closing session db object: {e}")
async def reflect_tables(self):
async with self.db_engine.connect() as conn:
def reflect_tables(self):
with self.db_engine.connect() as conn:
try:
await conn.run_sync(self.metadata.reflect)
self.CredentialsTable = Table("credentials", self.metadata, autoload_with=self.db_engine)
self.HostsTable = Table("hosts", self.metadata, autoload_with=self.db_engine)
except NoInspectionAvailable:
@ -62,6 +54,15 @@ class database:
)
exit()
def shutdown_db(self):
try:
self.conn.close()
# due to the async nature of CME, sometimes session state is a bit messy and this will throw:
# Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and
# this would cause an unexpected state change to <SessionTransactionState.CLOSED: 5>
except IllegalStateChangeError as e:
logging.debug(f"Error while closing session db object: {e}")
def clear_database(self):
for table in self.metadata.sorted_tables:
asyncio.run(self.conn.execute(table.delete()))
self.conn.execute(table.delete())

View File

@ -5,9 +5,7 @@ from sqlalchemy import MetaData, func, Table, select, insert, update, delete
from sqlalchemy.dialects.sqlite import Insert # used for upsert
from sqlalchemy.exc import IllegalStateChangeError, NoInspectionAvailable
from sqlalchemy.orm import sessionmaker, scoped_session
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.exc import SAWarning
import asyncio
import warnings
# if there is an issue with SQLAlchemy and a connection cannot be cleaned up properly it spews out annoying warnings
@ -22,39 +20,16 @@ class database:
self.db_engine = db_engine
self.metadata = MetaData()
asyncio.run(self.reflect_tables())
session_factory = sessionmaker(bind=self.db_engine, expire_on_commit=True, class_=AsyncSession)
self.reflect_tables()
session_factory = sessionmaker(
bind=self.db_engine,
expire_on_commit=True
)
Session = scoped_session(session_factory)
# this is still named "conn" when it is the session object; TODO: rename
self.conn = Session()
async def shutdown_db(self):
try:
await asyncio.shield(self.conn.close())
# due to the async nature of CME, sometimes session state is a bit messy and this will throw:
# Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and
# this would cause an unexpected state change to <SessionTransactionState.CLOSED: 5>
except IllegalStateChangeError as e:
logging.debug(f"Error while closing session db object: {e}")
async def reflect_tables(self):
async with self.db_engine.connect() as conn:
try:
await conn.run_sync(self.metadata.reflect)
self.HostsTable = Table("hosts", self.metadata, autoload_with=self.db_engine)
self.UsersTable = Table("users", self.metadata, autoload_with=self.db_engine)
self.AdminRelationsTable = Table("admin_relations", self.metadata, autoload_with=self.db_engine)
except NoInspectionAvailable:
print(
"[-] Error reflecting tables - this means there is a DB schema mismatch \n"
"[-] This is probably because a newer version of CME is being ran on an old DB schema\n"
"[-] If you wish to save the old DB data, copy it to a new location (`cp -r ~/.cme/workspaces/ ~/old_cme_workspaces/`)\n"
"[-] Then remove the CME DB folders (`rm -rf ~/.cme/workspaces/`) and rerun CME to initialize the new DB schema"
)
exit()
@staticmethod
def db_schema(db_conn):
db_conn.execute('''CREATE TABLE "hosts" (
@ -84,6 +59,34 @@ class database:
FOREIGN KEY(pillaged_from_hostid) REFERENCES hosts(id)
)''')
def reflect_tables(self):
with self.db_engine.connect() as conn:
try:
self.HostsTable = Table("hosts", self.metadata, autoload_with=self.db_engine)
self.UsersTable = Table("users", self.metadata, autoload_with=self.db_engine)
self.AdminRelationsTable = Table("admin_relations", self.metadata, autoload_with=self.db_engine)
except NoInspectionAvailable:
print(
"[-] Error reflecting tables - this means there is a DB schema mismatch \n"
"[-] This is probably because a newer version of CME is being ran on an old DB schema\n"
"[-] If you wish to save the old DB data, copy it to a new location (`cp -r ~/.cme/workspaces/ ~/old_cme_workspaces/`)\n"
"[-] Then remove the CME DB folders (`rm -rf ~/.cme/workspaces/`) and rerun CME to initialize the new DB schema"
)
exit()
def shutdown_db(self):
try:
self.conn.close()
# due to the async nature of CME, sometimes session state is a bit messy and this will throw:
# Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and
# this would cause an unexpected state change to <SessionTransactionState.CLOSED: 5>
except IllegalStateChangeError as e:
logging.debug(f"Error while closing session db object: {e}")
def clear_database(self):
for table in self.metadata.sorted_tables:
self.conn.execute(table.delete())
def add_host(self, ip, hostname, domain, os, instances):
"""
Check if this host has already been added to the database, if not, add it in.
@ -95,7 +98,7 @@ class database:
q = select(self.HostsTable).filter(
self.HostsTable.c.ip == ip
)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
logging.debug(f"mssql add_host() - hosts returned: {results}")
host_data = {
@ -133,11 +136,9 @@ class database:
index_elements=self.HostsTable.primary_key,
set_=update_columns
)
asyncio.run(
self.conn.execute(
q,
hosts
)
self.conn.execute(
q,
hosts
)
def add_credential(self, credtype, domain, username, password, pillaged_from=None):
@ -164,7 +165,7 @@ class database:
func.lower(self.UsersTable.c.username) == func.lower(username),
func.lower(self.UsersTable.c.credtype) == func.lower(credtype)
)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
if not results:
user_data = {
@ -175,13 +176,13 @@ class database:
"pillaged_from_hostid": pillaged_from,
}
q = insert(self.UsersTable).values(user_data) # .returning(self.UsersTable.c.id)
asyncio.run(self.conn.execute(q)) # .first()
self.conn.execute(q) # .first()
else:
for user in results:
# might be able to just remove this if check, but leaving it in for now
if not user[3] and not user[4] and not user[5]:
q = update(self.UsersTable).values(credential_data) # .returning(self.UsersTable.c.id)
results = asyncio.run(self.conn.execute(q)) # .first()
results = self.conn.execute(q) # .first()
# user_rowid = results.id
logging.debug('add_credential(credtype={}, domain={}, username={}, password={}, pillaged_from={})'.format(
@ -203,7 +204,7 @@ class database:
self.UsersTable.c.id == cred_id
)
del_hosts.append(q)
asyncio.run(self.conn.execute(q))
self.conn.execute(q)
def add_admin_user(self, credtype, domain, username, password, host, user_id=None):
domain = domain.split('.')[0].upper()
@ -212,7 +213,7 @@ class database:
q = select(self.UsersTable).filter(
self.UsersTable.c.id == user_id
)
users = asyncio.run(self.conn.execute(q)).all()
users = self.conn.execute(q).all()
else:
q = select(self.UsersTable).filter(
self.UsersTable.c.credtype == credtype,
@ -220,31 +221,35 @@ class database:
func.lower(self.UsersTable.c.username) == func.lower(username),
self.UsersTable.c.password == password
)
users = asyncio.run(self.conn.execute(q)).all()
users = self.conn.execute(q).all()
logging.debug(f"Users: {users}")
like_term = func.lower(f"%{host}%")
q = q.filter(
self.HostsTable.c.ip.like(like_term)
)
hosts = asyncio.run(self.conn.execute(q)).all()
hosts = self.conn.execute(q).all()
logging.debug(f"Hosts: {hosts}")
if users is not None and hosts is not None:
for user, host in zip(users, hosts):
user_id = user[0]
host_id = host[0]
link = {
"userid": user_id,
"hostid": host_id
}
q = select(self.AdminRelationsTable).filter(
self.AdminRelationsTable.c.userid == user_id,
self.AdminRelationsTable.c.hostid == host_id
)
links = asyncio.run(self.conn.execute(q)).all()
links = self.conn.execute(q).all()
if not links:
asyncio.run(self.conn.execute(
self.conn.execute(
insert(self.AdminRelationsTable).values(link)
))
)
def get_admin_relations(self, user_id=None, host_id=None):
if user_id:
@ -258,7 +263,7 @@ class database:
else:
q = select(self.AdminRelationsTable)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
return results
def remove_admin_relation(self, user_ids=None, host_ids=None):
@ -273,7 +278,7 @@ class database:
q = q.filter(
self.AdminRelationsTable.c.hostid == host_id
)
asyncio.run(self.conn.execute(q))
self.conn.execute(q)
def is_credential_valid(self, credential_id):
"""
@ -283,7 +288,7 @@ class database:
self.UsersTable.c.id == credential_id,
self.UsersTable.c.password is not None
)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
return len(results) > 0
def get_credentials(self, filter_term=None, cred_type=None):
@ -309,7 +314,7 @@ class database:
else:
q = select(self.UsersTable)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
return results
def is_host_valid(self, host_id):
@ -319,7 +324,7 @@ class database:
q = select(self.HostsTable).filter(
self.HostsTable.c.id == host_id
)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
return len(results) > 0
def get_hosts(self, filter_term=None, domain=None):
@ -333,7 +338,7 @@ class database:
q = q.filter(
self.HostsTable.c.id == filter_term
)
results = asyncio.run(self.conn.execute(q)).first()
results = self.conn.execute(q).first()
# all() returns a list, so we keep the return format the same so consumers don't have to guess
return [results]
# if we're filtering by domain controllers
@ -353,9 +358,6 @@ class database:
func.lower(self.HostsTable.c.hostname).like(like_term)
)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
return results
def clear_database(self):
for table in self.metadata.sorted_tables:
asyncio.run(self.conn.execute(table.delete()))

View File

@ -15,8 +15,11 @@ class database:
self.db_engine = db_engine
self.metadata = MetaData()
asyncio.run(self.reflect_tables())
session_factory = sessionmaker(bind=self.db_engine, expire_on_commit=True, class_=AsyncSession)
self.reflect_tables()
session_factory = sessionmaker(
bind=self.db_engine,
expire_on_commit=True
)
Session = scoped_session(session_factory)
# this is still named "conn" when it is the session object; TODO: rename
@ -39,20 +42,9 @@ class database:
"server_banner" text
)''')
async def shutdown_db(self):
try:
await asyncio.shield(self.conn.close())
# due to the async nature of CME, sometimes session state is a bit messy and this will throw:
# Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and
# this would cause an unexpected state change to <SessionTransactionState.CLOSED: 5>
except IllegalStateChangeError as e:
logging.debug(f"Error while closing session db object: {e}")
async def reflect_tables(self):
async with self.db_engine.connect() as conn:
def reflect_tables(self):
with self.db_engine.connect() as conn:
try:
await conn.run_sync(self.metadata.reflect)
self.CredentialsTable = Table("credentials", self.metadata, autoload_with=self.db_engine)
self.HostsTable = Table("hosts", self.metadata, autoload_with=self.db_engine)
except NoInspectionAvailable:
@ -64,6 +56,15 @@ class database:
)
exit()
def shutdown_db(self):
try:
self.conn.close()
# due to the async nature of CME, sometimes session state is a bit messy and this will throw:
# Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and
# this would cause an unexpected state change to <SessionTransactionState.CLOSED: 5>
except IllegalStateChangeError as e:
logging.debug(f"Error while closing session db object: {e}")
def clear_database(self):
for table in self.metadata.sorted_tables:
asyncio.run(self.conn.execute(table.delete()))
self.conn.execute(table.delete())

View File

@ -30,13 +30,10 @@ class database:
self.db_engine = db_engine
self.metadata = MetaData()
asyncio.run(self.reflect_tables())
# we don't use async_sessionmaker or async_scoped_session because when `database` is initialized,
# there is no running async loop
self.reflect_tables()
session_factory = sessionmaker(
bind=self.db_engine,
expire_on_commit=True,
class_=AsyncSession
expire_on_commit=True
)
Session = scoped_session(session_factory)
@ -134,11 +131,9 @@ class database:
# FOREIGN KEY(hostid) REFERENCES hosts(id)
# )''')
async def reflect_tables(self):
async with self.db_engine.connect() as conn:
def reflect_tables(self):
with self.db_engine.connect() as conn:
try:
await conn.run_sync(self.metadata.reflect)
self.HostsTable = Table("hosts", self.metadata, autoload_with=self.db_engine)
self.UsersTable = Table("users", self.metadata, autoload_with=self.db_engine)
self.GroupsTable = Table("groups", self.metadata, autoload_with=self.db_engine)
@ -157,9 +152,9 @@ class database:
)
exit()
async def shutdown_db(self):
def shutdown_db(self):
try:
await asyncio.shield(self.conn.close())
self.conn.close()
# due to the async nature of CME, sometimes session state is a bit messy and this will throw:
# Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and
# this would cause an unexpected state change to <SessionTransactionState.CLOSED: 5>
@ -168,7 +163,7 @@ class database:
def clear_database(self):
for table in self.metadata.sorted_tables:
asyncio.run(self.conn.execute(table.delete()))
self.conn.execute(table.delete())
# pull/545
def add_host(self, ip, hostname, domain, os, smbv1, signing, spooler=None, zerologon=None, petitpotam=None,
@ -183,7 +178,7 @@ class database:
q = select(self.HostsTable).filter(
self.HostsTable.c.ip == ip
)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
# create new host
if not results:
@ -238,12 +233,11 @@ class database:
index_elements=self.HostsTable.primary_key,
set_=update_columns
)
asyncio.run(
self.conn.execute(
q,
hosts
)
) # .scalar()
self.conn.execute(
q,
hosts
) # .scalar()
# we only return updated IDs for now - when RETURNING clause is allowed we can return inserted
if updated_ids:
logging.debug(f"add_host() - Host IDs Updated: {updated_ids}")
@ -267,7 +261,7 @@ class database:
func.lower(self.UsersTable.c.username) == func.lower(username),
func.lower(self.UsersTable.c.credtype) == func.lower(credtype)
)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
# add new credential
if not results:
@ -314,20 +308,18 @@ class database:
set_=update_columns_users
)
logging.debug(f"Adding credentials: {credentials}")
asyncio.run(
self.conn.execute(
q_users,
credentials
)
) # .scalar()
self.conn.execute(
q_users,
credentials
) # .scalar()
if groups:
q_groups = Insert(self.GroupRelationsTable)
asyncio.run(
self.conn.execute(
q_groups,
groups
)
self.conn.execute(
q_groups,
groups
)
# return user_ids
@ -341,7 +333,7 @@ class database:
self.UsersTable.c.id == cred_id
)
del_hosts.append(q)
asyncio.run(self.conn.execute(q))
self.conn.execute(q)
def add_admin_user(self, credtype, domain, username, password, host, user_id=None):
domain = domain.split('.')[0]
@ -359,7 +351,7 @@ class database:
func.lower(self.UsersTable.c.username) == func.lower(username),
self.UsersTable.c.password == password
)
users = asyncio.run(self.conn.execute(creds_q))
users = self.conn.execute(creds_q)
hosts = self.get_hosts(host)
if users and hosts:
@ -374,7 +366,7 @@ class database:
self.AdminRelationsTable.c.userid == user_id,
self.AdminRelationsTable.c.hostid == host_id
)
links = asyncio.run(self.conn.execute(admin_relations_select)).all()
links = self.conn.execute(admin_relations_select).all()
if not links:
add_links.append(link)
@ -382,10 +374,10 @@ class database:
admin_relations_insert = Insert(self.AdminRelationsTable)
if add_links:
asyncio.run(self.conn.execute(
self.conn.execute(
admin_relations_insert,
add_links
))
)
def get_admin_relations(self, user_id=None, host_id=None):
if user_id:
@ -399,7 +391,7 @@ class database:
else:
q = select(self.AdminRelationsTable)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
return results
def remove_admin_relation(self, user_ids=None, host_ids=None):
@ -414,7 +406,7 @@ class database:
q = q.filter(
self.AdminRelationsTable.c.hostid == host_id
)
asyncio.run(self.conn.execute(q))
self.conn.execute(q)
def is_credential_valid(self, credential_id):
"""
@ -424,7 +416,7 @@ class database:
self.UsersTable.c.id == credential_id,
self.UsersTable.c.password is not None
)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
return len(results) > 0
def get_credentials(self, filter_term=None, cred_type=None):
@ -450,7 +442,7 @@ class database:
else:
q = select(self.UsersTable)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
return results
def get_credential(self, cred_type, domain, username, password):
@ -462,20 +454,20 @@ class database:
self.UsersTable.c.password == password,
self.UsersTable.c.credtype == cred_type
)
results = asyncio.run(self.conn.execute(q)).first()
results = self.conn.execute(q).first()
return results.id
def is_credential_local(self, credential_id):
q = select(self.UsersTable.c.domain).filter(
self.UsersTable.c.id == credential_id
)
user_domain = asyncio.run(self.conn.execute(q)).all()
user_domain = self.conn.execute(q).all()
if user_domain:
q = select(self.HostsTable).filter(
func.lower(self.HostsTable.c.id) == func.lower(user_domain)
)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
return len(results) > 0
@ -486,7 +478,7 @@ class database:
q = select(self.HostsTable).filter(
self.HostsTable.c.id == host_id
)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
return len(results) > 0
def get_hosts(self, filter_term=None, domain=None):
@ -500,7 +492,7 @@ class database:
q = q.filter(
self.HostsTable.c.id == filter_term
)
results = asyncio.run(self.conn.execute(q)).first()
results = self.conn.execute(q).first()
# all() returns a list, so we keep the return format the same so consumers don't have to guess
return [results]
# if we're filtering by domain controllers
@ -542,7 +534,7 @@ class database:
self.HostsTable.c.ip.like(like_term) |
func.lower(self.HostsTable.c.hostname).like(like_term)
)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
logging.debug(f"smb hosts() - results: {results}")
return results
@ -553,7 +545,7 @@ class database:
q = select(self.GroupsTable).filter(
self.GroupsTable.c.id == group_id
)
results = asyncio.run(self.conn.execute(q)).first()
results = self.conn.execute(q).first()
valid = True if results else False
logging.debug(f"is_group_valid(groupID={group_id}) => {valid}")
@ -585,11 +577,10 @@ class database:
# insert the group and get the returned id right away, this can be refactored when we can use RETURNING
q = Insert(self.GroupsTable)
asyncio.run(
self.conn.execute(
q,
groups
)
self.conn.execute(
q,
groups
)
new_group_data = self.get_groups(group_name=group_data["name"], group_domain=group_data["domain"])
returned_id = [new_group_data[0].id]
@ -623,11 +614,10 @@ class database:
index_elements=self.GroupsTable.primary_key,
set_=update_columns
)
asyncio.run(
self.conn.execute(
q,
groups
)
self.conn.execute(
q,
groups
)
# TODO: always return a list and fix code references to not expect a single integer
# inserted_result = res_inserted_result.first()
@ -650,7 +640,7 @@ class database:
q = select(self.GroupsTable).filter(
self.GroupsTable.c.id == filter_term
)
results = asyncio.run(self.conn.execute(q)).first()
results = self.conn.execute(q).first()
# all() returns a list, so we keep the return format the same so consumers don't have to guess
return [results]
elif group_name and group_domain:
@ -666,7 +656,7 @@ class database:
else:
q = select(self.GroupsTable).filter()
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
logging.debug(
f"get_groups(filter_term={filter_term}, groupName={group_name}, groupDomain={group_domain}) => {results}")
@ -687,7 +677,7 @@ class database:
self.GroupRelationsTable.c.groupid == group_id
)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
return results
def remove_group_relations(self, user_id=None, group_id=None):
@ -700,7 +690,7 @@ class database:
q = q.filter(
self.GroupRelationsTable.c.groupid == group_id
)
asyncio.run(self.conn.execute(q))
self.conn.execute(q)
def is_user_valid(self, user_id):
"""
@ -709,7 +699,7 @@ class database:
q = select(self.UsersTable).filter(
self.UsersTable.c.id == user_id
)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
return len(results) > 0
def get_users(self, filter_term=None):
@ -725,7 +715,7 @@ class database:
q = q.filter(
func.lower(self.UsersTable.c.username).like(like_term)
)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
return results
def get_user(self, domain, username):
@ -733,7 +723,7 @@ class database:
func.lower(self.UsersTable.c.domain) == func.lower(domain),
func.lower(self.UsersTable.c.username) == func.lower(username)
)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
return results
def get_domain_controllers(self, domain=None):
@ -746,7 +736,7 @@ class database:
q = select(self.SharesTable).filter(
self.SharesTable.c.id == share_id
)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
logging.debug(f"is_share_valid(shareID={share_id}) => {len(results) > 0}")
return len(results) > 0
@ -760,10 +750,10 @@ class database:
"read": read,
"write": write,
}
share_id = asyncio.run(self.conn.execute(
share_id = self.conn.execute(
Insert(self.SharesTable).on_conflict_do_nothing(), # .returning(self.SharesTable.c.id),
share_data
)) # .scalar_one()
) # .scalar_one()
# return share_id
def get_shares(self, filter_term=None):
@ -778,7 +768,7 @@ class database:
)
else:
q = select(self.SharesTable)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
return results
def get_shares_by_access(self, permissions, share_id=None):
@ -790,7 +780,7 @@ class database:
q = q.filter(self.SharesTable.c.read == 1)
if "w" in permissions:
q = q.filter(self.SharesTable.c.write == 1)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
return results
def get_users_with_share_access(self, host_id, share_name, permissions):
@ -803,7 +793,7 @@ class database:
q = q.filter(self.SharesTable.c.read == 1)
if "w" in permissions:
q = q.filter(self.SharesTable.c.write == 1)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
return results
@ -816,7 +806,7 @@ class database:
q = select(self.DpapiBackupkey).filter(
func.lower(self.DpapiBackupkey.c.domain) == func.lower(domain)
)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
if not len(results):
pvk_encoded = base64.b64encode(pvk)
@ -827,11 +817,10 @@ class database:
try:
# TODO: find a way to abstract this away to a single Upsert call
q = Insert(self.DpapiBackupkey) # .returning(self.DpapiBackupkey.c.id)
asyncio.run(
self.conn.execute(
q,
[backup_key]
)
self.conn.execute(
q,
[backup_key]
) # .scalar()
logging.debug(f"add_domain_backupkey(domain={domain}, pvk={pvk_encoded})")
# return inserted_id
@ -848,7 +837,7 @@ class database:
q = q.filter(
func.lower(self.DpapiBackupkey.c.domain) == func.lower(domain)
)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
logging.debug(f"get_domain_backupkey(domain={domain}) => {results}")
@ -864,7 +853,7 @@ class database:
q = select(self.DpapiSecrets).filter(
func.lower(self.DpapiSecrets.c.id) == dpapi_secret_id
)
results = asyncio.run(self.conn.execute(q)).first()
results = self.conn.execute(q).first()
valid = True if results is not None else False
logging.debug(f"is_dpapi_secret_valid(groupID={dpapi_secret_id}) => {valid}")
return valid
@ -883,11 +872,10 @@ class database:
"url": url
}
q = Insert(self.DpapiSecrets).on_conflict_do_nothing() # .returning(self.DpapiSecrets.c.id)
asyncio.run(
self.conn.execute(
q,
[secret]
)
self.conn.execute(
q,
[secret]
) # .scalar()
# inserted_result = res_inserted_result.first()
@ -907,14 +895,14 @@ class database:
q = q.filter(
self.DpapiSecrets.c.id == filter_term
)
results = asyncio.run(self.conn.execute(q)).first()
results = self.conn.execute(q).first()
# all() returns a list, so we keep the return format the same so consumers don't have to guess
return [results]
elif host:
q = q.filter(
self.DpapiSecrets.c.host == host
)
results = asyncio.run(self.conn.execute(q)).first()
results = self.conn.execute(q).first()
# all() returns a list, so we keep the return format the same so consumers don't have to guess
return [results]
elif dpapi_type:
@ -935,7 +923,7 @@ class database:
q = q.filter(
func.lower(self.DpapiSecrets.c.url) == func.lower(url)
)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
logging.debug(
f"get_dpapi_secrets(filter_term={filter_term}, host={host}, dpapi_type={dpapi_type}, windows_user={windows_user}, username={username}, url={url}) => {results}")
@ -946,7 +934,7 @@ class database:
self.LoggedinRelationsTable.c.userid == user_id,
self.LoggedinRelationsTable.c.hostid == host_id
)
results = asyncio.run(self.conn.execute(relation_query)).all()
results = self.conn.execute(relation_query).all()
# only add one if one doesn't already exist
if not results:
@ -958,11 +946,10 @@ class database:
logging.debug(f"Inserting loggedin_relations: {relation}")
# TODO: find a way to abstract this away to a single Upsert call
q = Insert(self.LoggedinRelationsTable) # .returning(self.LoggedinRelationsTable.c.id)
asyncio.run(
self.conn.execute(
q,
[relation]
)
self.conn.execute(
q,
[relation]
) # .scalar()
inserted_id_results = self.get_loggedin_relations(user_id, host_id)
logging.debug(f"Checking if relation was added: {inserted_id_results}")
@ -980,7 +967,7 @@ class database:
q = q.filter(
self.LoggedinRelationsTable.c.hostid == host_id
)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
return results
def remove_loggedin_relations(self, user_id=None, host_id=None):
@ -993,4 +980,4 @@ class database:
q = q.filter(
self.LoggedinRelationsTable.c.hostid == host_id
)
asyncio.run(self.conn.execute(q))
self.conn.execute(q)

View File

@ -15,8 +15,11 @@ class database:
self.db_engine = db_engine
self.metadata = MetaData()
asyncio.run(self.reflect_tables())
session_factory = sessionmaker(bind=self.db_engine, expire_on_commit=True, class_=AsyncSession)
self.reflect_tables()
session_factory = sessionmaker(
bind=self.db_engine,
expire_on_commit=True
)
Session = scoped_session(session_factory)
# this is still named "conn" when it is the session object; TODO: rename
@ -37,20 +40,9 @@ class database:
"server_banner" text
)''')
async def shutdown_db(self):
try:
await asyncio.shield(self.conn.close())
# due to the async nature of CME, sometimes session state is a bit messy and this will throw:
# Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and
# this would cause an unexpected state change to <SessionTransactionState.CLOSED: 5>
except IllegalStateChangeError as e:
logging.debug(f"Error while closing session db object: {e}")
async def reflect_tables(self):
async with self.db_engine.connect() as conn:
def reflect_tables(self):
with self.db_engine.connect() as conn:
try:
await conn.run_sync(self.metadata.reflect)
self.CredentialsTable = Table("credentials", self.metadata, autoload_with=self.db_engine)
self.HostsTable = Table("hosts", self.metadata, autoload_with=self.db_engine)
except NoInspectionAvailable:
@ -62,6 +54,15 @@ class database:
)
exit()
def shutdown_db(self):
try:
self.conn.close()
# due to the async nature of CME, sometimes session state is a bit messy and this will throw:
# Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and
# this would cause an unexpected state change to <SessionTransactionState.CLOSED: 5>
except IllegalStateChangeError as e:
logging.debug(f"Error while closing session db object: {e}")
def clear_database(self):
for table in self.metadata.sorted_tables:
asyncio.run(self.conn.execute(table.delete()))
self.conn.execute(table.delete())

View File

@ -22,13 +22,10 @@ class database:
self.db_engine = db_engine
self.metadata = MetaData()
asyncio.run(self.reflect_tables())
# we don't use async_sessionmaker or async_scoped_session because when `database` is initialized,
# there is no running async loop
self.reflect_tables()
session_factory = sessionmaker(
bind=self.db_engine,
expire_on_commit=True,
class_=AsyncSession
expire_on_commit=True
)
Session = scoped_session(session_factory)
@ -52,11 +49,9 @@ class database:
"server_banner" text
)''')
async def reflect_tables(self):
async with self.db_engine.connect() as conn:
def reflect_tables(self):
with self.db_engine.connect() as conn:
try:
await conn.run_sync(self.metadata.reflect)
self.HostsTable = Table("hosts", self.metadata, autoload_with=self.db_engine)
self.CredentialsTable = Table("credentials", self.metadata, autoload_with=self.db_engine)
except NoInspectionAvailable:
@ -68,9 +63,9 @@ class database:
)
exit()
async def shutdown_db(self):
def shutdown_db(self):
try:
await asyncio.shield(self.conn.close())
self.conn.close()
# due to the async nature of CME, sometimes session state is a bit messy and this will throw:
# Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and
# this would cause an unexpected state change to <SessionTransactionState.CLOSED: 5>
@ -79,4 +74,4 @@ class database:
def clear_database(self):
for table in self.metadata.sorted_tables:
asyncio.run(self.conn.execute(table.delete()))
self.conn.execute(table.delete())

View File

@ -19,11 +19,10 @@ class database:
self.db_engine = db_engine
self.metadata = MetaData()
asyncio.run(self.reflect_tables())
self.reflect_tables()
session_factory = sessionmaker(
bind=self.db_engine,
expire_on_commit=True,
class_=AsyncSession
expire_on_commit=True
)
Session = scoped_session(session_factory)
@ -64,11 +63,9 @@ class database:
FOREIGN KEY(hostid) REFERENCES hosts(id)
)''')
async def reflect_tables(self):
async with self.db_engine.connect() as conn:
def reflect_tables(self):
with self.db_engine.connect() as conn:
try:
await conn.run_sync(self.metadata.reflect)
self.HostsTable = Table("hosts", self.metadata, autoload_with=self.db_engine)
self.UsersTable = Table("users", self.metadata, autoload_with=self.db_engine)
self.AdminRelationsTable = Table("admin_relations", self.metadata, autoload_with=self.db_engine)
@ -82,9 +79,9 @@ class database:
)
exit()
async def shutdown_db(self):
def shutdown_db(self):
try:
await asyncio.shield(self.conn.close())
self.conn.close()
# due to the async nature of CME, sometimes session state is a bit messy and this will throw:
# Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and
# this would cause an unexpected state change to <SessionTransactionState.CLOSED: 5>
@ -93,7 +90,7 @@ class database:
def clear_database(self):
for table in self.metadata.sorted_tables:
asyncio.run(self.conn.execute(table.delete()))
self.conn.execute(table.delete())
def add_host(self, ip, port, hostname, domain, os=None):
"""
@ -106,7 +103,7 @@ class database:
q = select(self.HostsTable).filter(
self.HostsTable.c.ip == ip
)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
logging.debug(f"smb add_host() - hosts returned: {results}")
# create new host
@ -146,11 +143,9 @@ class database:
index_elements=self.HostsTable.primary_key,
set_=update_columns
)
asyncio.run(
self.conn.execute(
q,
hosts
)
self.conn.execute(
q,
hosts
)
def add_credential(self, credtype, domain, username, password, pillaged_from=None):
@ -177,7 +172,7 @@ class database:
func.lower(self.UsersTable.c.username) == func.lower(username),
func.lower(self.UsersTable.c.credtype) == func.lower(credtype)
)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
# add new credential
if not results:
@ -216,12 +211,10 @@ class database:
index_elements=self.UsersTable.primary_key,
set_=update_columns_users
)
asyncio.run(
self.conn.execute(
q_users,
credentials
)
) # .scalar()
self.conn.execute(
q_users,
credentials
) # .scalar()
# return user_ids
def remove_credentials(self, creds_id):
@ -234,7 +227,7 @@ class database:
self.UsersTable.c.id == cred_id
)
del_hosts.append(q)
asyncio.run(self.conn.execute(q))
self.conn.execute(q)
def add_admin_user(self, credtype, domain, username, password, host, user_id=None):
domain = domain.split('.')[0]
@ -252,7 +245,7 @@ class database:
func.lower(self.UsersTable.c.username) == func.lower(username),
self.UsersTable.c.password == password
)
users = asyncio.run(self.conn.execute(creds_q))
users = self.conn.execute(creds_q)
hosts = self.get_hosts(host)
if users and hosts:
@ -267,17 +260,17 @@ class database:
self.AdminRelationsTable.c.userid == user_id,
self.AdminRelationsTable.c.hostid == host_id
)
links = asyncio.run(self.conn.execute(admin_relations_select)).all()
links = self.conn.execute(admin_relations_select).all()
if not links:
add_links.append(link)
admin_relations_insert = Insert(self.AdminRelationsTable)
asyncio.run(self.conn.execute(
self.conn.execute(
admin_relations_insert,
add_links
))
)
def get_admin_relations(self, user_id=None, host_id=None):
if user_id:
@ -291,7 +284,7 @@ class database:
else:
q = select(self.AdminRelationsTable)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
return results
def remove_admin_relation(self, user_ids=None, host_ids=None):
@ -306,7 +299,7 @@ class database:
q = q.filter(
self.AdminRelationsTable.c.hostid == host_id
)
asyncio.run(self.conn.execute(q))
self.conn.execute(q)
def is_credential_valid(self, credential_id):
"""
@ -316,7 +309,7 @@ class database:
self.UsersTable.c.id == credential_id,
self.UsersTable.c.password is not None
)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
return len(results) > 0
def get_credentials(self, filter_term=None, cred_type=None):
@ -342,20 +335,20 @@ class database:
else:
q = select(self.UsersTable)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
return results
def is_credential_local(self, credential_id):
q = select(self.UsersTable.c.domain).filter(
self.UsersTable.c.id == credential_id
)
user_domain = asyncio.run(self.conn.execute(q)).all()
user_domain = self.conn.execute(q).all()
if user_domain:
q = select(self.HostsTable).filter(
func.lower(self.HostsTable.c.id) == func.lower(user_domain)
)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
return len(results) > 0
@ -366,7 +359,7 @@ class database:
q = select(self.HostsTable).filter(
self.HostsTable.c.id == host_id
)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
return len(results) > 0
def get_hosts(self, filter_term=None):
@ -380,7 +373,7 @@ class database:
q = q.filter(
self.HostsTable.c.id == filter_term
)
results = asyncio.run(self.conn.execute(q)).first()
results = self.conn.execute(q).first()
# all() returns a list, so we keep the return format the same so consumers don't have to guess
return [results]
# if we're filtering by domain controllers
@ -397,7 +390,7 @@ class database:
self.HostsTable.c.ip.like(like_term) |
func.lower(self.HostsTable.c.hostname).like(like_term)
)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
logging.debug(f"winrm get_hosts() - results: {results}")
return results
@ -408,7 +401,7 @@ class database:
q = select(self.UsersTable).filter(
self.UsersTable.c.id == user_id
)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
return len(results) > 0
def get_users(self, filter_term=None):
@ -424,7 +417,7 @@ class database:
q = q.filter(
func.lower(self.UsersTable.c.username).like(like_term)
)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
return results
def get_user(self, domain, username):
@ -432,7 +425,7 @@ class database:
func.lower(self.UsersTable.c.domain) == func.lower(domain),
func.lower(self.UsersTable.c.username) == func.lower(username)
)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
return results
def add_loggedin_relation(self, user_id, host_id):
@ -440,7 +433,7 @@ class database:
self.LoggedinRelationsTable.c.userid == user_id,
self.LoggedinRelationsTable.c.hostid == host_id
)
results = asyncio.run(self.conn.execute(relation_query)).all()
results = self.conn.execute(relation_query).all()
# only add one if one doesn't already exist
if not results:
@ -451,11 +444,10 @@ class database:
try:
# TODO: find a way to abstract this away to a single Upsert call
q = Insert(self.LoggedinRelationsTable) # .returning(self.LoggedinRelationsTable.c.id)
asyncio.run(
self.conn.execute(
q,
[relation]
)
self.conn.execute(
q,
[relation]
) # .scalar()
# return inserted_ids
except Exception as e:
@ -471,7 +463,7 @@ class database:
q = q.filter(
self.LoggedinRelationsTable.c.hostid == host_id
)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
return results
def remove_loggedin_relations(self, user_id=None, host_id=None):
@ -484,4 +476,4 @@ class database:
q = q.filter(
self.LoggedinRelationsTable.c.hostid == host_id
)
asyncio.run(self.conn.execute(q))
self.conn.execute(q)