refactor(async): update how tasks are created to new threads using proper ThreadPool; update functionality everywhere to match
parent
4c76a30a4a
commit
bfcc689acc
|
@ -39,6 +39,7 @@ def gen_cli_args():
|
|||
parser.add_argument("-t", type=int, dest="threads", default=100, help="set how many concurrent threads to use (default: 100)")
|
||||
parser.add_argument("--timeout", default=None, type=int, help='max timeout in seconds of each thread (default: None)')
|
||||
parser.add_argument("--jitter", metavar='INTERVAL', type=str, help='sets a random delay between each connection (default: None)')
|
||||
parser.add_argument("--progress", default=True, action='store_false', help='display progress bar during scan')
|
||||
parser.add_argument("--darrell", action='store_true', help='give Darrell a hand')
|
||||
parser.add_argument("--verbose", action='store_true', help="enable verbose output")
|
||||
parser.add_argument("--version", action='store_true', help="Display CME version")
|
||||
|
|
13
cme/cmedb.py
13
cme/cmedb.py
|
@ -8,12 +8,12 @@ import sqlite3
|
|||
import sys
|
||||
import os
|
||||
import requests
|
||||
from sqlalchemy import create_engine
|
||||
from terminaltables import AsciiTable
|
||||
import configparser
|
||||
from cme.loaders.protocol_loader import protocol_loader
|
||||
from cme.paths import CONFIG_PATH, WS_PATH, WORKSPACE_DIR
|
||||
from requests import ConnectionError
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlalchemy.exc import SAWarning
|
||||
import asyncio
|
||||
import csv
|
||||
|
@ -33,13 +33,11 @@ class UserExitedProto(Exception):
|
|||
|
||||
|
||||
def create_db_engine(db_path):
|
||||
db_engine = create_async_engine(
|
||||
f"sqlite+aiosqlite:///{db_path}",
|
||||
db_engine = create_engine(
|
||||
f"sqlite:///{db_path}",
|
||||
isolation_level="AUTOCOMMIT",
|
||||
future=True
|
||||
) # can add echo=True
|
||||
# db_engine.execution_options(isolation_level="AUTOCOMMIT")
|
||||
# db_engine.connect().connection.text_factory = str
|
||||
)
|
||||
return db_engine
|
||||
|
||||
|
||||
|
@ -51,6 +49,7 @@ def print_table(data, title=None):
|
|||
print(table.table)
|
||||
print("")
|
||||
|
||||
|
||||
def write_csv(filename, headers, entries):
|
||||
"""
|
||||
Writes a CSV file with the provided parameters.
|
||||
|
@ -106,7 +105,7 @@ class DatabaseNavigator(cmd.Cmd):
|
|||
self.prompt = 'cmedb ({})({}) > '.format(main_menu.workspace, proto)
|
||||
|
||||
def do_exit(self, line):
|
||||
asyncio.run(self.db.shutdown_db())
|
||||
self.db.shutdown_db()
|
||||
sys.exit()
|
||||
|
||||
def help_exit(self):
|
||||
|
|
|
@ -2,11 +2,14 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import logging
|
||||
import random
|
||||
import socket
|
||||
from os.path import isfile
|
||||
from threading import BoundedSemaphore
|
||||
from socket import gethostbyname
|
||||
from functools import wraps
|
||||
from time import sleep
|
||||
|
||||
from cme.logger import CMEAdapter
|
||||
from cme.context import Context
|
||||
from cme.helpers.logger import write_log
|
||||
|
@ -69,6 +72,11 @@ class connection(object):
|
|||
logging.debug('Error resolving hostname {}: {}'.format(self.hostname, e))
|
||||
return
|
||||
|
||||
if args.jitter:
|
||||
value = random.choice(range(args.jitter[0], args.jitter[1]))
|
||||
logging.debug(f"Doin' the jitterbug for {value} second(s)")
|
||||
sleep(value)
|
||||
|
||||
self.proto_flow()
|
||||
|
||||
@staticmethod
|
||||
|
|
|
@ -1,5 +1,8 @@
|
|||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
import concurrent.futures
|
||||
|
||||
import sqlalchemy
|
||||
|
||||
from cme.logger import setup_logger, setup_debug_logger, CMEAdapter
|
||||
from cme.helpers.logger import highlight
|
||||
|
@ -18,8 +21,6 @@ from concurrent.futures import ThreadPoolExecutor
|
|||
from pprint import pformat
|
||||
from decimal import Decimal
|
||||
import asyncio
|
||||
import aioconsole
|
||||
import functools
|
||||
import configparser
|
||||
import cme.helpers.powershell as powershell
|
||||
import cme
|
||||
|
@ -31,9 +32,9 @@ import os
|
|||
import sys
|
||||
import logging
|
||||
from sqlalchemy.orm import declarative_base
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlalchemy.exc import SAWarning
|
||||
import warnings
|
||||
from tqdm import tqdm
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
@ -50,95 +51,20 @@ warnings.filterwarnings("ignore", category=SAWarning)
|
|||
|
||||
|
||||
def create_db_engine(db_path):
|
||||
db_engine = create_async_engine(
|
||||
f"sqlite+aiosqlite:///{db_path}",
|
||||
db_engine = sqlalchemy.create_engine(
|
||||
f"sqlite:///{db_path}",
|
||||
isolation_level="AUTOCOMMIT",
|
||||
future=True
|
||||
) # can add echo=True
|
||||
# db_engine.execution_options(isolation_level="AUTOCOMMIT")
|
||||
# db_engine.connect().connection.text_factory = str
|
||||
)
|
||||
return db_engine
|
||||
|
||||
|
||||
async def monitor_threadpool(pool, targets):
|
||||
logging.debug('Started thread poller')
|
||||
|
||||
while True:
|
||||
try:
|
||||
text = await aioconsole.ainput("")
|
||||
if text == "":
|
||||
pool_size = pool._work_queue.qsize()
|
||||
finished_threads = len(targets) - pool_size
|
||||
percentage = Decimal(finished_threads) / Decimal(len(targets)) * Decimal(100)
|
||||
logger.info(f"completed: {percentage:.2f}% ({finished_threads}/{len(targets)})")
|
||||
except asyncio.CancelledError:
|
||||
logging.debug("Stopped thread poller")
|
||||
break
|
||||
|
||||
|
||||
async def run_protocol(loop, protocol_obj, args, db, target, jitter):
|
||||
try:
|
||||
if jitter:
|
||||
value = random.choice(range(jitter[0], jitter[1]))
|
||||
logging.debug(f"Doin' the jitterbug for {value} second(s)")
|
||||
await asyncio.sleep(value)
|
||||
|
||||
thread = loop.run_in_executor(
|
||||
None,
|
||||
functools.partial(
|
||||
protocol_obj,
|
||||
args,
|
||||
db,
|
||||
str(target)
|
||||
)
|
||||
)
|
||||
|
||||
await asyncio.wait_for(
|
||||
thread,
|
||||
timeout=args.timeout
|
||||
)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logging.debug("Thread exceeded timeout")
|
||||
except asyncio.CancelledError:
|
||||
logging.debug("Shutting down DB")
|
||||
thread.cancel()
|
||||
except sqlite3.OperationalError as e:
|
||||
logging.debug("Sqlite error - sqlite3.operationalError - {}".format(str(e)))
|
||||
|
||||
|
||||
async def start_threadpool(protocol_obj, args, db, targets, jitter):
|
||||
pool = ThreadPoolExecutor(max_workers=args.threads + 1)
|
||||
loop = asyncio.get_running_loop()
|
||||
loop.set_default_executor(pool)
|
||||
|
||||
monitor_task = asyncio.create_task(
|
||||
monitor_threadpool(pool, targets)
|
||||
)
|
||||
|
||||
jobs = [
|
||||
run_protocol(
|
||||
loop,
|
||||
protocol_obj,
|
||||
args,
|
||||
db,
|
||||
target,
|
||||
jitter
|
||||
)
|
||||
for target in targets
|
||||
]
|
||||
|
||||
try:
|
||||
logging.debug("Running")
|
||||
await asyncio.gather(*jobs)
|
||||
except asyncio.CancelledError:
|
||||
print('\n')
|
||||
logger.info("Shutting down, please wait...")
|
||||
logging.debug("Cancelling scan")
|
||||
finally:
|
||||
await asyncio.shield(db.shutdown_db())
|
||||
monitor_task.cancel()
|
||||
pool.shutdown(wait=True)
|
||||
async def start_scan(protocol_obj, args, db, targets):
|
||||
with tqdm(total=len(targets), disable=args.progress) as pbar:
|
||||
with ThreadPoolExecutor(max_workers=args.threads + 1) as executor:
|
||||
futures = [executor.submit(protocol_obj, args, db, target) for target in targets]
|
||||
for future in concurrent.futures.as_completed(futures):
|
||||
pbar.update(1)
|
||||
|
||||
|
||||
def main():
|
||||
|
@ -164,7 +90,6 @@ def main():
|
|||
module = None
|
||||
module_server = None
|
||||
targets = []
|
||||
jitter = None
|
||||
server_port_dict = {'http': 80, 'https': 443, 'smb': 445}
|
||||
current_workspace = config.get('CME', 'workspace')
|
||||
if config.get('CME', 'log_mode') != "False":
|
||||
|
@ -178,9 +103,9 @@ def main():
|
|||
if args.jitter:
|
||||
if '-' in args.jitter:
|
||||
start, end = args.jitter.split('-')
|
||||
jitter = (int(start), int(end))
|
||||
args.jitter = (int(start), int(end))
|
||||
else:
|
||||
jitter = (0, int(args.jitter))
|
||||
args.jitter = (0, int(args.jitter))
|
||||
|
||||
if hasattr(args, 'cred_id') and args.cred_id:
|
||||
for cred_id in args.cred_id:
|
||||
|
@ -298,14 +223,14 @@ def main():
|
|||
|
||||
try:
|
||||
asyncio.run(
|
||||
start_threadpool(protocol_object, args, db, targets, jitter)
|
||||
start_scan(protocol_object, args, db, targets)
|
||||
)
|
||||
except KeyboardInterrupt:
|
||||
logging.debug("Got keyboard interrupt")
|
||||
finally:
|
||||
if module_server:
|
||||
module_server.shutdown()
|
||||
asyncio.run(db_engine.dispose())
|
||||
db_engine.dispose()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -16,8 +16,11 @@ class database:
|
|||
|
||||
self.db_engine = db_engine
|
||||
self.metadata = MetaData()
|
||||
asyncio.run(self.reflect_tables())
|
||||
session_factory = sessionmaker(bind=self.db_engine, expire_on_commit=True, class_=AsyncSession)
|
||||
self.reflect_tables()
|
||||
session_factory = sessionmaker(
|
||||
bind=self.db_engine,
|
||||
expire_on_commit=True
|
||||
)
|
||||
|
||||
Session = scoped_session(session_factory)
|
||||
# this is still named "conn" when it is the session object; TODO: rename
|
||||
|
@ -38,20 +41,9 @@ class database:
|
|||
"server_banner" text
|
||||
)''')
|
||||
|
||||
async def shutdown_db(self):
|
||||
try:
|
||||
await asyncio.shield(self.conn.close())
|
||||
# due to the async nature of CME, sometimes session state is a bit messy and this will throw:
|
||||
# Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and
|
||||
# this would cause an unexpected state change to <SessionTransactionState.CLOSED: 5>
|
||||
except IllegalStateChangeError as e:
|
||||
logging.debug(f"Error while closing session db object: {e}")
|
||||
|
||||
async def reflect_tables(self):
|
||||
async with self.db_engine.connect() as conn:
|
||||
def reflect_tables(self):
|
||||
with self.db_engine.connect() as conn:
|
||||
try:
|
||||
await conn.run_sync(self.metadata.reflect)
|
||||
|
||||
self.CredentialsTable = Table("credentials", self.metadata, autoload_with=self.db_engine)
|
||||
self.HostsTable = Table("hosts", self.metadata, autoload_with=self.db_engine)
|
||||
except NoInspectionAvailable:
|
||||
|
@ -63,7 +55,16 @@ class database:
|
|||
)
|
||||
exit()
|
||||
|
||||
def shutdown_db(self):
|
||||
try:
|
||||
self.conn.close()
|
||||
# due to the async nature of CME, sometimes session state is a bit messy and this will throw:
|
||||
# Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and
|
||||
# this would cause an unexpected state change to <SessionTransactionState.CLOSED: 5>
|
||||
except IllegalStateChangeError as e:
|
||||
logging.debug(f"Error while closing session db object: {e}")
|
||||
|
||||
def clear_database(self):
|
||||
for table in self.metadata.sorted_tables:
|
||||
asyncio.run(self.conn.execute(table.delete()))
|
||||
self.conn.execute(table.delete())
|
||||
|
||||
|
|
|
@ -15,8 +15,11 @@ class database:
|
|||
|
||||
self.db_engine = db_engine
|
||||
self.metadata = MetaData()
|
||||
asyncio.run(self.reflect_tables())
|
||||
session_factory = sessionmaker(bind=self.db_engine, expire_on_commit=True, class_=AsyncSession)
|
||||
self.reflect_tables()
|
||||
session_factory = sessionmaker(
|
||||
bind=self.db_engine,
|
||||
expire_on_commit=True
|
||||
)
|
||||
|
||||
Session = scoped_session(session_factory)
|
||||
# this is still named "conn" when it is the session object; TODO: rename
|
||||
|
@ -37,20 +40,9 @@ class database:
|
|||
"port" integer
|
||||
)''')
|
||||
|
||||
async def shutdown_db(self):
|
||||
try:
|
||||
await asyncio.shield(self.conn.close())
|
||||
# due to the async nature of CME, sometimes session state is a bit messy and this will throw:
|
||||
# Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and
|
||||
# this would cause an unexpected state change to <SessionTransactionState.CLOSED: 5>
|
||||
except IllegalStateChangeError as e:
|
||||
logging.debug(f"Error while closing session db object: {e}")
|
||||
|
||||
async def reflect_tables(self):
|
||||
async with self.db_engine.connect() as conn:
|
||||
def reflect_tables(self):
|
||||
with self.db_engine.connect() as conn:
|
||||
try:
|
||||
await conn.run_sync(self.metadata.reflect)
|
||||
|
||||
self.CredentialsTable = Table("credentials", self.metadata, autoload_with=self.db_engine)
|
||||
self.HostsTable = Table("hosts", self.metadata, autoload_with=self.db_engine)
|
||||
except NoInspectionAvailable:
|
||||
|
@ -62,6 +54,15 @@ class database:
|
|||
)
|
||||
exit()
|
||||
|
||||
def shutdown_db(self):
|
||||
try:
|
||||
self.conn.close()
|
||||
# due to the async nature of CME, sometimes session state is a bit messy and this will throw:
|
||||
# Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and
|
||||
# this would cause an unexpected state change to <SessionTransactionState.CLOSED: 5>
|
||||
except IllegalStateChangeError as e:
|
||||
logging.debug(f"Error while closing session db object: {e}")
|
||||
|
||||
def clear_database(self):
|
||||
for table in self.metadata.sorted_tables:
|
||||
asyncio.run(self.conn.execute(table.delete()))
|
||||
self.conn.execute(table.delete())
|
||||
|
|
|
@ -5,9 +5,7 @@ from sqlalchemy import MetaData, func, Table, select, insert, update, delete
|
|||
from sqlalchemy.dialects.sqlite import Insert # used for upsert
|
||||
from sqlalchemy.exc import IllegalStateChangeError, NoInspectionAvailable
|
||||
from sqlalchemy.orm import sessionmaker, scoped_session
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.exc import SAWarning
|
||||
import asyncio
|
||||
import warnings
|
||||
|
||||
# if there is an issue with SQLAlchemy and a connection cannot be cleaned up properly it spews out annoying warnings
|
||||
|
@ -22,39 +20,16 @@ class database:
|
|||
|
||||
self.db_engine = db_engine
|
||||
self.metadata = MetaData()
|
||||
asyncio.run(self.reflect_tables())
|
||||
session_factory = sessionmaker(bind=self.db_engine, expire_on_commit=True, class_=AsyncSession)
|
||||
self.reflect_tables()
|
||||
session_factory = sessionmaker(
|
||||
bind=self.db_engine,
|
||||
expire_on_commit=True
|
||||
)
|
||||
|
||||
Session = scoped_session(session_factory)
|
||||
# this is still named "conn" when it is the session object; TODO: rename
|
||||
self.conn = Session()
|
||||
|
||||
async def shutdown_db(self):
|
||||
try:
|
||||
await asyncio.shield(self.conn.close())
|
||||
# due to the async nature of CME, sometimes session state is a bit messy and this will throw:
|
||||
# Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and
|
||||
# this would cause an unexpected state change to <SessionTransactionState.CLOSED: 5>
|
||||
except IllegalStateChangeError as e:
|
||||
logging.debug(f"Error while closing session db object: {e}")
|
||||
|
||||
async def reflect_tables(self):
|
||||
async with self.db_engine.connect() as conn:
|
||||
try:
|
||||
await conn.run_sync(self.metadata.reflect)
|
||||
|
||||
self.HostsTable = Table("hosts", self.metadata, autoload_with=self.db_engine)
|
||||
self.UsersTable = Table("users", self.metadata, autoload_with=self.db_engine)
|
||||
self.AdminRelationsTable = Table("admin_relations", self.metadata, autoload_with=self.db_engine)
|
||||
except NoInspectionAvailable:
|
||||
print(
|
||||
"[-] Error reflecting tables - this means there is a DB schema mismatch \n"
|
||||
"[-] This is probably because a newer version of CME is being ran on an old DB schema\n"
|
||||
"[-] If you wish to save the old DB data, copy it to a new location (`cp -r ~/.cme/workspaces/ ~/old_cme_workspaces/`)\n"
|
||||
"[-] Then remove the CME DB folders (`rm -rf ~/.cme/workspaces/`) and rerun CME to initialize the new DB schema"
|
||||
)
|
||||
exit()
|
||||
|
||||
@staticmethod
|
||||
def db_schema(db_conn):
|
||||
db_conn.execute('''CREATE TABLE "hosts" (
|
||||
|
@ -84,6 +59,34 @@ class database:
|
|||
FOREIGN KEY(pillaged_from_hostid) REFERENCES hosts(id)
|
||||
)''')
|
||||
|
||||
def reflect_tables(self):
|
||||
with self.db_engine.connect() as conn:
|
||||
try:
|
||||
self.HostsTable = Table("hosts", self.metadata, autoload_with=self.db_engine)
|
||||
self.UsersTable = Table("users", self.metadata, autoload_with=self.db_engine)
|
||||
self.AdminRelationsTable = Table("admin_relations", self.metadata, autoload_with=self.db_engine)
|
||||
except NoInspectionAvailable:
|
||||
print(
|
||||
"[-] Error reflecting tables - this means there is a DB schema mismatch \n"
|
||||
"[-] This is probably because a newer version of CME is being ran on an old DB schema\n"
|
||||
"[-] If you wish to save the old DB data, copy it to a new location (`cp -r ~/.cme/workspaces/ ~/old_cme_workspaces/`)\n"
|
||||
"[-] Then remove the CME DB folders (`rm -rf ~/.cme/workspaces/`) and rerun CME to initialize the new DB schema"
|
||||
)
|
||||
exit()
|
||||
|
||||
def shutdown_db(self):
|
||||
try:
|
||||
self.conn.close()
|
||||
# due to the async nature of CME, sometimes session state is a bit messy and this will throw:
|
||||
# Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and
|
||||
# this would cause an unexpected state change to <SessionTransactionState.CLOSED: 5>
|
||||
except IllegalStateChangeError as e:
|
||||
logging.debug(f"Error while closing session db object: {e}")
|
||||
|
||||
def clear_database(self):
|
||||
for table in self.metadata.sorted_tables:
|
||||
self.conn.execute(table.delete())
|
||||
|
||||
def add_host(self, ip, hostname, domain, os, instances):
|
||||
"""
|
||||
Check if this host has already been added to the database, if not, add it in.
|
||||
|
@ -95,7 +98,7 @@ class database:
|
|||
q = select(self.HostsTable).filter(
|
||||
self.HostsTable.c.ip == ip
|
||||
)
|
||||
results = asyncio.run(self.conn.execute(q)).all()
|
||||
results = self.conn.execute(q).all()
|
||||
logging.debug(f"mssql add_host() - hosts returned: {results}")
|
||||
|
||||
host_data = {
|
||||
|
@ -133,11 +136,9 @@ class database:
|
|||
index_elements=self.HostsTable.primary_key,
|
||||
set_=update_columns
|
||||
)
|
||||
asyncio.run(
|
||||
self.conn.execute(
|
||||
q,
|
||||
hosts
|
||||
)
|
||||
self.conn.execute(
|
||||
q,
|
||||
hosts
|
||||
)
|
||||
|
||||
def add_credential(self, credtype, domain, username, password, pillaged_from=None):
|
||||
|
@ -164,7 +165,7 @@ class database:
|
|||
func.lower(self.UsersTable.c.username) == func.lower(username),
|
||||
func.lower(self.UsersTable.c.credtype) == func.lower(credtype)
|
||||
)
|
||||
results = asyncio.run(self.conn.execute(q)).all()
|
||||
results = self.conn.execute(q).all()
|
||||
|
||||
if not results:
|
||||
user_data = {
|
||||
|
@ -175,13 +176,13 @@ class database:
|
|||
"pillaged_from_hostid": pillaged_from,
|
||||
}
|
||||
q = insert(self.UsersTable).values(user_data) # .returning(self.UsersTable.c.id)
|
||||
asyncio.run(self.conn.execute(q)) # .first()
|
||||
self.conn.execute(q) # .first()
|
||||
else:
|
||||
for user in results:
|
||||
# might be able to just remove this if check, but leaving it in for now
|
||||
if not user[3] and not user[4] and not user[5]:
|
||||
q = update(self.UsersTable).values(credential_data) # .returning(self.UsersTable.c.id)
|
||||
results = asyncio.run(self.conn.execute(q)) # .first()
|
||||
results = self.conn.execute(q) # .first()
|
||||
# user_rowid = results.id
|
||||
|
||||
logging.debug('add_credential(credtype={}, domain={}, username={}, password={}, pillaged_from={})'.format(
|
||||
|
@ -203,7 +204,7 @@ class database:
|
|||
self.UsersTable.c.id == cred_id
|
||||
)
|
||||
del_hosts.append(q)
|
||||
asyncio.run(self.conn.execute(q))
|
||||
self.conn.execute(q)
|
||||
|
||||
def add_admin_user(self, credtype, domain, username, password, host, user_id=None):
|
||||
domain = domain.split('.')[0].upper()
|
||||
|
@ -212,7 +213,7 @@ class database:
|
|||
q = select(self.UsersTable).filter(
|
||||
self.UsersTable.c.id == user_id
|
||||
)
|
||||
users = asyncio.run(self.conn.execute(q)).all()
|
||||
users = self.conn.execute(q).all()
|
||||
else:
|
||||
q = select(self.UsersTable).filter(
|
||||
self.UsersTable.c.credtype == credtype,
|
||||
|
@ -220,31 +221,35 @@ class database:
|
|||
func.lower(self.UsersTable.c.username) == func.lower(username),
|
||||
self.UsersTable.c.password == password
|
||||
)
|
||||
users = asyncio.run(self.conn.execute(q)).all()
|
||||
users = self.conn.execute(q).all()
|
||||
logging.debug(f"Users: {users}")
|
||||
|
||||
like_term = func.lower(f"%{host}%")
|
||||
q = q.filter(
|
||||
self.HostsTable.c.ip.like(like_term)
|
||||
)
|
||||
hosts = asyncio.run(self.conn.execute(q)).all()
|
||||
hosts = self.conn.execute(q).all()
|
||||
logging.debug(f"Hosts: {hosts}")
|
||||
|
||||
if users is not None and hosts is not None:
|
||||
for user, host in zip(users, hosts):
|
||||
user_id = user[0]
|
||||
host_id = host[0]
|
||||
link = {
|
||||
"userid": user_id,
|
||||
"hostid": host_id
|
||||
}
|
||||
|
||||
q = select(self.AdminRelationsTable).filter(
|
||||
self.AdminRelationsTable.c.userid == user_id,
|
||||
self.AdminRelationsTable.c.hostid == host_id
|
||||
)
|
||||
links = asyncio.run(self.conn.execute(q)).all()
|
||||
links = self.conn.execute(q).all()
|
||||
|
||||
if not links:
|
||||
asyncio.run(self.conn.execute(
|
||||
self.conn.execute(
|
||||
insert(self.AdminRelationsTable).values(link)
|
||||
))
|
||||
)
|
||||
|
||||
def get_admin_relations(self, user_id=None, host_id=None):
|
||||
if user_id:
|
||||
|
@ -258,7 +263,7 @@ class database:
|
|||
else:
|
||||
q = select(self.AdminRelationsTable)
|
||||
|
||||
results = asyncio.run(self.conn.execute(q)).all()
|
||||
results = self.conn.execute(q).all()
|
||||
return results
|
||||
|
||||
def remove_admin_relation(self, user_ids=None, host_ids=None):
|
||||
|
@ -273,7 +278,7 @@ class database:
|
|||
q = q.filter(
|
||||
self.AdminRelationsTable.c.hostid == host_id
|
||||
)
|
||||
asyncio.run(self.conn.execute(q))
|
||||
self.conn.execute(q)
|
||||
|
||||
def is_credential_valid(self, credential_id):
|
||||
"""
|
||||
|
@ -283,7 +288,7 @@ class database:
|
|||
self.UsersTable.c.id == credential_id,
|
||||
self.UsersTable.c.password is not None
|
||||
)
|
||||
results = asyncio.run(self.conn.execute(q)).all()
|
||||
results = self.conn.execute(q).all()
|
||||
return len(results) > 0
|
||||
|
||||
def get_credentials(self, filter_term=None, cred_type=None):
|
||||
|
@ -309,7 +314,7 @@ class database:
|
|||
else:
|
||||
q = select(self.UsersTable)
|
||||
|
||||
results = asyncio.run(self.conn.execute(q)).all()
|
||||
results = self.conn.execute(q).all()
|
||||
return results
|
||||
|
||||
def is_host_valid(self, host_id):
|
||||
|
@ -319,7 +324,7 @@ class database:
|
|||
q = select(self.HostsTable).filter(
|
||||
self.HostsTable.c.id == host_id
|
||||
)
|
||||
results = asyncio.run(self.conn.execute(q)).all()
|
||||
results = self.conn.execute(q).all()
|
||||
return len(results) > 0
|
||||
|
||||
def get_hosts(self, filter_term=None, domain=None):
|
||||
|
@ -333,7 +338,7 @@ class database:
|
|||
q = q.filter(
|
||||
self.HostsTable.c.id == filter_term
|
||||
)
|
||||
results = asyncio.run(self.conn.execute(q)).first()
|
||||
results = self.conn.execute(q).first()
|
||||
# all() returns a list, so we keep the return format the same so consumers don't have to guess
|
||||
return [results]
|
||||
# if we're filtering by domain controllers
|
||||
|
@ -353,9 +358,6 @@ class database:
|
|||
func.lower(self.HostsTable.c.hostname).like(like_term)
|
||||
)
|
||||
|
||||
results = asyncio.run(self.conn.execute(q)).all()
|
||||
results = self.conn.execute(q).all()
|
||||
return results
|
||||
|
||||
def clear_database(self):
|
||||
for table in self.metadata.sorted_tables:
|
||||
asyncio.run(self.conn.execute(table.delete()))
|
||||
|
|
|
@ -15,8 +15,11 @@ class database:
|
|||
|
||||
self.db_engine = db_engine
|
||||
self.metadata = MetaData()
|
||||
asyncio.run(self.reflect_tables())
|
||||
session_factory = sessionmaker(bind=self.db_engine, expire_on_commit=True, class_=AsyncSession)
|
||||
self.reflect_tables()
|
||||
session_factory = sessionmaker(
|
||||
bind=self.db_engine,
|
||||
expire_on_commit=True
|
||||
)
|
||||
|
||||
Session = scoped_session(session_factory)
|
||||
# this is still named "conn" when it is the session object; TODO: rename
|
||||
|
@ -39,20 +42,9 @@ class database:
|
|||
"server_banner" text
|
||||
)''')
|
||||
|
||||
async def shutdown_db(self):
|
||||
try:
|
||||
await asyncio.shield(self.conn.close())
|
||||
# due to the async nature of CME, sometimes session state is a bit messy and this will throw:
|
||||
# Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and
|
||||
# this would cause an unexpected state change to <SessionTransactionState.CLOSED: 5>
|
||||
except IllegalStateChangeError as e:
|
||||
logging.debug(f"Error while closing session db object: {e}")
|
||||
|
||||
async def reflect_tables(self):
|
||||
async with self.db_engine.connect() as conn:
|
||||
def reflect_tables(self):
|
||||
with self.db_engine.connect() as conn:
|
||||
try:
|
||||
await conn.run_sync(self.metadata.reflect)
|
||||
|
||||
self.CredentialsTable = Table("credentials", self.metadata, autoload_with=self.db_engine)
|
||||
self.HostsTable = Table("hosts", self.metadata, autoload_with=self.db_engine)
|
||||
except NoInspectionAvailable:
|
||||
|
@ -64,6 +56,15 @@ class database:
|
|||
)
|
||||
exit()
|
||||
|
||||
def shutdown_db(self):
|
||||
try:
|
||||
self.conn.close()
|
||||
# due to the async nature of CME, sometimes session state is a bit messy and this will throw:
|
||||
# Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and
|
||||
# this would cause an unexpected state change to <SessionTransactionState.CLOSED: 5>
|
||||
except IllegalStateChangeError as e:
|
||||
logging.debug(f"Error while closing session db object: {e}")
|
||||
|
||||
def clear_database(self):
|
||||
for table in self.metadata.sorted_tables:
|
||||
asyncio.run(self.conn.execute(table.delete()))
|
||||
self.conn.execute(table.delete())
|
||||
|
|
|
@ -30,13 +30,10 @@ class database:
|
|||
|
||||
self.db_engine = db_engine
|
||||
self.metadata = MetaData()
|
||||
asyncio.run(self.reflect_tables())
|
||||
# we don't use async_sessionmaker or async_scoped_session because when `database` is initialized,
|
||||
# there is no running async loop
|
||||
self.reflect_tables()
|
||||
session_factory = sessionmaker(
|
||||
bind=self.db_engine,
|
||||
expire_on_commit=True,
|
||||
class_=AsyncSession
|
||||
expire_on_commit=True
|
||||
)
|
||||
|
||||
Session = scoped_session(session_factory)
|
||||
|
@ -134,11 +131,9 @@ class database:
|
|||
# FOREIGN KEY(hostid) REFERENCES hosts(id)
|
||||
# )''')
|
||||
|
||||
async def reflect_tables(self):
|
||||
async with self.db_engine.connect() as conn:
|
||||
def reflect_tables(self):
|
||||
with self.db_engine.connect() as conn:
|
||||
try:
|
||||
await conn.run_sync(self.metadata.reflect)
|
||||
|
||||
self.HostsTable = Table("hosts", self.metadata, autoload_with=self.db_engine)
|
||||
self.UsersTable = Table("users", self.metadata, autoload_with=self.db_engine)
|
||||
self.GroupsTable = Table("groups", self.metadata, autoload_with=self.db_engine)
|
||||
|
@ -157,9 +152,9 @@ class database:
|
|||
)
|
||||
exit()
|
||||
|
||||
async def shutdown_db(self):
|
||||
def shutdown_db(self):
|
||||
try:
|
||||
await asyncio.shield(self.conn.close())
|
||||
self.conn.close()
|
||||
# due to the async nature of CME, sometimes session state is a bit messy and this will throw:
|
||||
# Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and
|
||||
# this would cause an unexpected state change to <SessionTransactionState.CLOSED: 5>
|
||||
|
@ -168,7 +163,7 @@ class database:
|
|||
|
||||
def clear_database(self):
|
||||
for table in self.metadata.sorted_tables:
|
||||
asyncio.run(self.conn.execute(table.delete()))
|
||||
self.conn.execute(table.delete())
|
||||
|
||||
# pull/545
|
||||
def add_host(self, ip, hostname, domain, os, smbv1, signing, spooler=None, zerologon=None, petitpotam=None,
|
||||
|
@ -183,7 +178,7 @@ class database:
|
|||
q = select(self.HostsTable).filter(
|
||||
self.HostsTable.c.ip == ip
|
||||
)
|
||||
results = asyncio.run(self.conn.execute(q)).all()
|
||||
results = self.conn.execute(q).all()
|
||||
|
||||
# create new host
|
||||
if not results:
|
||||
|
@ -238,12 +233,11 @@ class database:
|
|||
index_elements=self.HostsTable.primary_key,
|
||||
set_=update_columns
|
||||
)
|
||||
asyncio.run(
|
||||
self.conn.execute(
|
||||
q,
|
||||
hosts
|
||||
)
|
||||
) # .scalar()
|
||||
|
||||
self.conn.execute(
|
||||
q,
|
||||
hosts
|
||||
) # .scalar()
|
||||
# we only return updated IDs for now - when RETURNING clause is allowed we can return inserted
|
||||
if updated_ids:
|
||||
logging.debug(f"add_host() - Host IDs Updated: {updated_ids}")
|
||||
|
@ -267,7 +261,7 @@ class database:
|
|||
func.lower(self.UsersTable.c.username) == func.lower(username),
|
||||
func.lower(self.UsersTable.c.credtype) == func.lower(credtype)
|
||||
)
|
||||
results = asyncio.run(self.conn.execute(q)).all()
|
||||
results = self.conn.execute(q).all()
|
||||
|
||||
# add new credential
|
||||
if not results:
|
||||
|
@ -314,20 +308,18 @@ class database:
|
|||
set_=update_columns_users
|
||||
)
|
||||
logging.debug(f"Adding credentials: {credentials}")
|
||||
asyncio.run(
|
||||
self.conn.execute(
|
||||
q_users,
|
||||
credentials
|
||||
)
|
||||
) # .scalar()
|
||||
|
||||
self.conn.execute(
|
||||
q_users,
|
||||
credentials
|
||||
) # .scalar()
|
||||
|
||||
if groups:
|
||||
q_groups = Insert(self.GroupRelationsTable)
|
||||
asyncio.run(
|
||||
self.conn.execute(
|
||||
q_groups,
|
||||
groups
|
||||
)
|
||||
|
||||
self.conn.execute(
|
||||
q_groups,
|
||||
groups
|
||||
)
|
||||
# return user_ids
|
||||
|
||||
|
@ -341,7 +333,7 @@ class database:
|
|||
self.UsersTable.c.id == cred_id
|
||||
)
|
||||
del_hosts.append(q)
|
||||
asyncio.run(self.conn.execute(q))
|
||||
self.conn.execute(q)
|
||||
|
||||
def add_admin_user(self, credtype, domain, username, password, host, user_id=None):
|
||||
domain = domain.split('.')[0]
|
||||
|
@ -359,7 +351,7 @@ class database:
|
|||
func.lower(self.UsersTable.c.username) == func.lower(username),
|
||||
self.UsersTable.c.password == password
|
||||
)
|
||||
users = asyncio.run(self.conn.execute(creds_q))
|
||||
users = self.conn.execute(creds_q)
|
||||
hosts = self.get_hosts(host)
|
||||
|
||||
if users and hosts:
|
||||
|
@ -374,7 +366,7 @@ class database:
|
|||
self.AdminRelationsTable.c.userid == user_id,
|
||||
self.AdminRelationsTable.c.hostid == host_id
|
||||
)
|
||||
links = asyncio.run(self.conn.execute(admin_relations_select)).all()
|
||||
links = self.conn.execute(admin_relations_select).all()
|
||||
|
||||
if not links:
|
||||
add_links.append(link)
|
||||
|
@ -382,10 +374,10 @@ class database:
|
|||
admin_relations_insert = Insert(self.AdminRelationsTable)
|
||||
|
||||
if add_links:
|
||||
asyncio.run(self.conn.execute(
|
||||
self.conn.execute(
|
||||
admin_relations_insert,
|
||||
add_links
|
||||
))
|
||||
)
|
||||
|
||||
def get_admin_relations(self, user_id=None, host_id=None):
|
||||
if user_id:
|
||||
|
@ -399,7 +391,7 @@ class database:
|
|||
else:
|
||||
q = select(self.AdminRelationsTable)
|
||||
|
||||
results = asyncio.run(self.conn.execute(q)).all()
|
||||
results = self.conn.execute(q).all()
|
||||
return results
|
||||
|
||||
def remove_admin_relation(self, user_ids=None, host_ids=None):
|
||||
|
@ -414,7 +406,7 @@ class database:
|
|||
q = q.filter(
|
||||
self.AdminRelationsTable.c.hostid == host_id
|
||||
)
|
||||
asyncio.run(self.conn.execute(q))
|
||||
self.conn.execute(q)
|
||||
|
||||
def is_credential_valid(self, credential_id):
|
||||
"""
|
||||
|
@ -424,7 +416,7 @@ class database:
|
|||
self.UsersTable.c.id == credential_id,
|
||||
self.UsersTable.c.password is not None
|
||||
)
|
||||
results = asyncio.run(self.conn.execute(q)).all()
|
||||
results = self.conn.execute(q).all()
|
||||
return len(results) > 0
|
||||
|
||||
def get_credentials(self, filter_term=None, cred_type=None):
|
||||
|
@ -450,7 +442,7 @@ class database:
|
|||
else:
|
||||
q = select(self.UsersTable)
|
||||
|
||||
results = asyncio.run(self.conn.execute(q)).all()
|
||||
results = self.conn.execute(q).all()
|
||||
return results
|
||||
|
||||
def get_credential(self, cred_type, domain, username, password):
|
||||
|
@ -462,20 +454,20 @@ class database:
|
|||
self.UsersTable.c.password == password,
|
||||
self.UsersTable.c.credtype == cred_type
|
||||
)
|
||||
results = asyncio.run(self.conn.execute(q)).first()
|
||||
results = self.conn.execute(q).first()
|
||||
return results.id
|
||||
|
||||
def is_credential_local(self, credential_id):
|
||||
q = select(self.UsersTable.c.domain).filter(
|
||||
self.UsersTable.c.id == credential_id
|
||||
)
|
||||
user_domain = asyncio.run(self.conn.execute(q)).all()
|
||||
user_domain = self.conn.execute(q).all()
|
||||
|
||||
if user_domain:
|
||||
q = select(self.HostsTable).filter(
|
||||
func.lower(self.HostsTable.c.id) == func.lower(user_domain)
|
||||
)
|
||||
results = asyncio.run(self.conn.execute(q)).all()
|
||||
results = self.conn.execute(q).all()
|
||||
|
||||
return len(results) > 0
|
||||
|
||||
|
@ -486,7 +478,7 @@ class database:
|
|||
q = select(self.HostsTable).filter(
|
||||
self.HostsTable.c.id == host_id
|
||||
)
|
||||
results = asyncio.run(self.conn.execute(q)).all()
|
||||
results = self.conn.execute(q).all()
|
||||
return len(results) > 0
|
||||
|
||||
def get_hosts(self, filter_term=None, domain=None):
|
||||
|
@ -500,7 +492,7 @@ class database:
|
|||
q = q.filter(
|
||||
self.HostsTable.c.id == filter_term
|
||||
)
|
||||
results = asyncio.run(self.conn.execute(q)).first()
|
||||
results = self.conn.execute(q).first()
|
||||
# all() returns a list, so we keep the return format the same so consumers don't have to guess
|
||||
return [results]
|
||||
# if we're filtering by domain controllers
|
||||
|
@ -542,7 +534,7 @@ class database:
|
|||
self.HostsTable.c.ip.like(like_term) |
|
||||
func.lower(self.HostsTable.c.hostname).like(like_term)
|
||||
)
|
||||
results = asyncio.run(self.conn.execute(q)).all()
|
||||
results = self.conn.execute(q).all()
|
||||
logging.debug(f"smb hosts() - results: {results}")
|
||||
return results
|
||||
|
||||
|
@ -553,7 +545,7 @@ class database:
|
|||
q = select(self.GroupsTable).filter(
|
||||
self.GroupsTable.c.id == group_id
|
||||
)
|
||||
results = asyncio.run(self.conn.execute(q)).first()
|
||||
results = self.conn.execute(q).first()
|
||||
|
||||
valid = True if results else False
|
||||
logging.debug(f"is_group_valid(groupID={group_id}) => {valid}")
|
||||
|
@ -585,11 +577,10 @@ class database:
|
|||
|
||||
# insert the group and get the returned id right away, this can be refactored when we can use RETURNING
|
||||
q = Insert(self.GroupsTable)
|
||||
asyncio.run(
|
||||
self.conn.execute(
|
||||
q,
|
||||
groups
|
||||
)
|
||||
|
||||
self.conn.execute(
|
||||
q,
|
||||
groups
|
||||
)
|
||||
new_group_data = self.get_groups(group_name=group_data["name"], group_domain=group_data["domain"])
|
||||
returned_id = [new_group_data[0].id]
|
||||
|
@ -623,11 +614,10 @@ class database:
|
|||
index_elements=self.GroupsTable.primary_key,
|
||||
set_=update_columns
|
||||
)
|
||||
asyncio.run(
|
||||
self.conn.execute(
|
||||
q,
|
||||
groups
|
||||
)
|
||||
|
||||
self.conn.execute(
|
||||
q,
|
||||
groups
|
||||
)
|
||||
# TODO: always return a list and fix code references to not expect a single integer
|
||||
# inserted_result = res_inserted_result.first()
|
||||
|
@ -650,7 +640,7 @@ class database:
|
|||
q = select(self.GroupsTable).filter(
|
||||
self.GroupsTable.c.id == filter_term
|
||||
)
|
||||
results = asyncio.run(self.conn.execute(q)).first()
|
||||
results = self.conn.execute(q).first()
|
||||
# all() returns a list, so we keep the return format the same so consumers don't have to guess
|
||||
return [results]
|
||||
elif group_name and group_domain:
|
||||
|
@ -666,7 +656,7 @@ class database:
|
|||
else:
|
||||
q = select(self.GroupsTable).filter()
|
||||
|
||||
results = asyncio.run(self.conn.execute(q)).all()
|
||||
results = self.conn.execute(q).all()
|
||||
|
||||
logging.debug(
|
||||
f"get_groups(filter_term={filter_term}, groupName={group_name}, groupDomain={group_domain}) => {results}")
|
||||
|
@ -687,7 +677,7 @@ class database:
|
|||
self.GroupRelationsTable.c.groupid == group_id
|
||||
)
|
||||
|
||||
results = asyncio.run(self.conn.execute(q)).all()
|
||||
results = self.conn.execute(q).all()
|
||||
return results
|
||||
|
||||
def remove_group_relations(self, user_id=None, group_id=None):
|
||||
|
@ -700,7 +690,7 @@ class database:
|
|||
q = q.filter(
|
||||
self.GroupRelationsTable.c.groupid == group_id
|
||||
)
|
||||
asyncio.run(self.conn.execute(q))
|
||||
self.conn.execute(q)
|
||||
|
||||
def is_user_valid(self, user_id):
|
||||
"""
|
||||
|
@ -709,7 +699,7 @@ class database:
|
|||
q = select(self.UsersTable).filter(
|
||||
self.UsersTable.c.id == user_id
|
||||
)
|
||||
results = asyncio.run(self.conn.execute(q)).all()
|
||||
results = self.conn.execute(q).all()
|
||||
return len(results) > 0
|
||||
|
||||
def get_users(self, filter_term=None):
|
||||
|
@ -725,7 +715,7 @@ class database:
|
|||
q = q.filter(
|
||||
func.lower(self.UsersTable.c.username).like(like_term)
|
||||
)
|
||||
results = asyncio.run(self.conn.execute(q)).all()
|
||||
results = self.conn.execute(q).all()
|
||||
return results
|
||||
|
||||
def get_user(self, domain, username):
|
||||
|
@ -733,7 +723,7 @@ class database:
|
|||
func.lower(self.UsersTable.c.domain) == func.lower(domain),
|
||||
func.lower(self.UsersTable.c.username) == func.lower(username)
|
||||
)
|
||||
results = asyncio.run(self.conn.execute(q)).all()
|
||||
results = self.conn.execute(q).all()
|
||||
return results
|
||||
|
||||
def get_domain_controllers(self, domain=None):
|
||||
|
@ -746,7 +736,7 @@ class database:
|
|||
q = select(self.SharesTable).filter(
|
||||
self.SharesTable.c.id == share_id
|
||||
)
|
||||
results = asyncio.run(self.conn.execute(q)).all()
|
||||
results = self.conn.execute(q).all()
|
||||
|
||||
logging.debug(f"is_share_valid(shareID={share_id}) => {len(results) > 0}")
|
||||
return len(results) > 0
|
||||
|
@ -760,10 +750,10 @@ class database:
|
|||
"read": read,
|
||||
"write": write,
|
||||
}
|
||||
share_id = asyncio.run(self.conn.execute(
|
||||
share_id = self.conn.execute(
|
||||
Insert(self.SharesTable).on_conflict_do_nothing(), # .returning(self.SharesTable.c.id),
|
||||
share_data
|
||||
)) # .scalar_one()
|
||||
) # .scalar_one()
|
||||
# return share_id
|
||||
|
||||
def get_shares(self, filter_term=None):
|
||||
|
@ -778,7 +768,7 @@ class database:
|
|||
)
|
||||
else:
|
||||
q = select(self.SharesTable)
|
||||
results = asyncio.run(self.conn.execute(q)).all()
|
||||
results = self.conn.execute(q).all()
|
||||
return results
|
||||
|
||||
def get_shares_by_access(self, permissions, share_id=None):
|
||||
|
@ -790,7 +780,7 @@ class database:
|
|||
q = q.filter(self.SharesTable.c.read == 1)
|
||||
if "w" in permissions:
|
||||
q = q.filter(self.SharesTable.c.write == 1)
|
||||
results = asyncio.run(self.conn.execute(q)).all()
|
||||
results = self.conn.execute(q).all()
|
||||
return results
|
||||
|
||||
def get_users_with_share_access(self, host_id, share_name, permissions):
|
||||
|
@ -803,7 +793,7 @@ class database:
|
|||
q = q.filter(self.SharesTable.c.read == 1)
|
||||
if "w" in permissions:
|
||||
q = q.filter(self.SharesTable.c.write == 1)
|
||||
results = asyncio.run(self.conn.execute(q)).all()
|
||||
results = self.conn.execute(q).all()
|
||||
|
||||
return results
|
||||
|
||||
|
@ -816,7 +806,7 @@ class database:
|
|||
q = select(self.DpapiBackupkey).filter(
|
||||
func.lower(self.DpapiBackupkey.c.domain) == func.lower(domain)
|
||||
)
|
||||
results = asyncio.run(self.conn.execute(q)).all()
|
||||
results = self.conn.execute(q).all()
|
||||
|
||||
if not len(results):
|
||||
pvk_encoded = base64.b64encode(pvk)
|
||||
|
@ -827,11 +817,10 @@ class database:
|
|||
try:
|
||||
# TODO: find a way to abstract this away to a single Upsert call
|
||||
q = Insert(self.DpapiBackupkey) # .returning(self.DpapiBackupkey.c.id)
|
||||
asyncio.run(
|
||||
self.conn.execute(
|
||||
q,
|
||||
[backup_key]
|
||||
)
|
||||
|
||||
self.conn.execute(
|
||||
q,
|
||||
[backup_key]
|
||||
) # .scalar()
|
||||
logging.debug(f"add_domain_backupkey(domain={domain}, pvk={pvk_encoded})")
|
||||
# return inserted_id
|
||||
|
@ -848,7 +837,7 @@ class database:
|
|||
q = q.filter(
|
||||
func.lower(self.DpapiBackupkey.c.domain) == func.lower(domain)
|
||||
)
|
||||
results = asyncio.run(self.conn.execute(q)).all()
|
||||
results = self.conn.execute(q).all()
|
||||
|
||||
logging.debug(f"get_domain_backupkey(domain={domain}) => {results}")
|
||||
|
||||
|
@ -864,7 +853,7 @@ class database:
|
|||
q = select(self.DpapiSecrets).filter(
|
||||
func.lower(self.DpapiSecrets.c.id) == dpapi_secret_id
|
||||
)
|
||||
results = asyncio.run(self.conn.execute(q)).first()
|
||||
results = self.conn.execute(q).first()
|
||||
valid = True if results is not None else False
|
||||
logging.debug(f"is_dpapi_secret_valid(groupID={dpapi_secret_id}) => {valid}")
|
||||
return valid
|
||||
|
@ -883,11 +872,10 @@ class database:
|
|||
"url": url
|
||||
}
|
||||
q = Insert(self.DpapiSecrets).on_conflict_do_nothing() # .returning(self.DpapiSecrets.c.id)
|
||||
asyncio.run(
|
||||
self.conn.execute(
|
||||
q,
|
||||
[secret]
|
||||
)
|
||||
|
||||
self.conn.execute(
|
||||
q,
|
||||
[secret]
|
||||
) # .scalar()
|
||||
|
||||
# inserted_result = res_inserted_result.first()
|
||||
|
@ -907,14 +895,14 @@ class database:
|
|||
q = q.filter(
|
||||
self.DpapiSecrets.c.id == filter_term
|
||||
)
|
||||
results = asyncio.run(self.conn.execute(q)).first()
|
||||
results = self.conn.execute(q).first()
|
||||
# all() returns a list, so we keep the return format the same so consumers don't have to guess
|
||||
return [results]
|
||||
elif host:
|
||||
q = q.filter(
|
||||
self.DpapiSecrets.c.host == host
|
||||
)
|
||||
results = asyncio.run(self.conn.execute(q)).first()
|
||||
results = self.conn.execute(q).first()
|
||||
# all() returns a list, so we keep the return format the same so consumers don't have to guess
|
||||
return [results]
|
||||
elif dpapi_type:
|
||||
|
@ -935,7 +923,7 @@ class database:
|
|||
q = q.filter(
|
||||
func.lower(self.DpapiSecrets.c.url) == func.lower(url)
|
||||
)
|
||||
results = asyncio.run(self.conn.execute(q)).all()
|
||||
results = self.conn.execute(q).all()
|
||||
|
||||
logging.debug(
|
||||
f"get_dpapi_secrets(filter_term={filter_term}, host={host}, dpapi_type={dpapi_type}, windows_user={windows_user}, username={username}, url={url}) => {results}")
|
||||
|
@ -946,7 +934,7 @@ class database:
|
|||
self.LoggedinRelationsTable.c.userid == user_id,
|
||||
self.LoggedinRelationsTable.c.hostid == host_id
|
||||
)
|
||||
results = asyncio.run(self.conn.execute(relation_query)).all()
|
||||
results = self.conn.execute(relation_query).all()
|
||||
|
||||
# only add one if one doesn't already exist
|
||||
if not results:
|
||||
|
@ -958,11 +946,10 @@ class database:
|
|||
logging.debug(f"Inserting loggedin_relations: {relation}")
|
||||
# TODO: find a way to abstract this away to a single Upsert call
|
||||
q = Insert(self.LoggedinRelationsTable) # .returning(self.LoggedinRelationsTable.c.id)
|
||||
asyncio.run(
|
||||
self.conn.execute(
|
||||
q,
|
||||
[relation]
|
||||
)
|
||||
|
||||
self.conn.execute(
|
||||
q,
|
||||
[relation]
|
||||
) # .scalar()
|
||||
inserted_id_results = self.get_loggedin_relations(user_id, host_id)
|
||||
logging.debug(f"Checking if relation was added: {inserted_id_results}")
|
||||
|
@ -980,7 +967,7 @@ class database:
|
|||
q = q.filter(
|
||||
self.LoggedinRelationsTable.c.hostid == host_id
|
||||
)
|
||||
results = asyncio.run(self.conn.execute(q)).all()
|
||||
results = self.conn.execute(q).all()
|
||||
return results
|
||||
|
||||
def remove_loggedin_relations(self, user_id=None, host_id=None):
|
||||
|
@ -993,4 +980,4 @@ class database:
|
|||
q = q.filter(
|
||||
self.LoggedinRelationsTable.c.hostid == host_id
|
||||
)
|
||||
asyncio.run(self.conn.execute(q))
|
||||
self.conn.execute(q)
|
||||
|
|
|
@ -15,8 +15,11 @@ class database:
|
|||
|
||||
self.db_engine = db_engine
|
||||
self.metadata = MetaData()
|
||||
asyncio.run(self.reflect_tables())
|
||||
session_factory = sessionmaker(bind=self.db_engine, expire_on_commit=True, class_=AsyncSession)
|
||||
self.reflect_tables()
|
||||
session_factory = sessionmaker(
|
||||
bind=self.db_engine,
|
||||
expire_on_commit=True
|
||||
)
|
||||
|
||||
Session = scoped_session(session_factory)
|
||||
# this is still named "conn" when it is the session object; TODO: rename
|
||||
|
@ -37,20 +40,9 @@ class database:
|
|||
"server_banner" text
|
||||
)''')
|
||||
|
||||
async def shutdown_db(self):
|
||||
try:
|
||||
await asyncio.shield(self.conn.close())
|
||||
# due to the async nature of CME, sometimes session state is a bit messy and this will throw:
|
||||
# Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and
|
||||
# this would cause an unexpected state change to <SessionTransactionState.CLOSED: 5>
|
||||
except IllegalStateChangeError as e:
|
||||
logging.debug(f"Error while closing session db object: {e}")
|
||||
|
||||
async def reflect_tables(self):
|
||||
async with self.db_engine.connect() as conn:
|
||||
def reflect_tables(self):
|
||||
with self.db_engine.connect() as conn:
|
||||
try:
|
||||
await conn.run_sync(self.metadata.reflect)
|
||||
|
||||
self.CredentialsTable = Table("credentials", self.metadata, autoload_with=self.db_engine)
|
||||
self.HostsTable = Table("hosts", self.metadata, autoload_with=self.db_engine)
|
||||
except NoInspectionAvailable:
|
||||
|
@ -62,6 +54,15 @@ class database:
|
|||
)
|
||||
exit()
|
||||
|
||||
def shutdown_db(self):
|
||||
try:
|
||||
self.conn.close()
|
||||
# due to the async nature of CME, sometimes session state is a bit messy and this will throw:
|
||||
# Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and
|
||||
# this would cause an unexpected state change to <SessionTransactionState.CLOSED: 5>
|
||||
except IllegalStateChangeError as e:
|
||||
logging.debug(f"Error while closing session db object: {e}")
|
||||
|
||||
def clear_database(self):
|
||||
for table in self.metadata.sorted_tables:
|
||||
asyncio.run(self.conn.execute(table.delete()))
|
||||
self.conn.execute(table.delete())
|
||||
|
|
|
@ -22,13 +22,10 @@ class database:
|
|||
|
||||
self.db_engine = db_engine
|
||||
self.metadata = MetaData()
|
||||
asyncio.run(self.reflect_tables())
|
||||
# we don't use async_sessionmaker or async_scoped_session because when `database` is initialized,
|
||||
# there is no running async loop
|
||||
self.reflect_tables()
|
||||
session_factory = sessionmaker(
|
||||
bind=self.db_engine,
|
||||
expire_on_commit=True,
|
||||
class_=AsyncSession
|
||||
expire_on_commit=True
|
||||
)
|
||||
|
||||
Session = scoped_session(session_factory)
|
||||
|
@ -52,11 +49,9 @@ class database:
|
|||
"server_banner" text
|
||||
)''')
|
||||
|
||||
async def reflect_tables(self):
|
||||
async with self.db_engine.connect() as conn:
|
||||
def reflect_tables(self):
|
||||
with self.db_engine.connect() as conn:
|
||||
try:
|
||||
await conn.run_sync(self.metadata.reflect)
|
||||
|
||||
self.HostsTable = Table("hosts", self.metadata, autoload_with=self.db_engine)
|
||||
self.CredentialsTable = Table("credentials", self.metadata, autoload_with=self.db_engine)
|
||||
except NoInspectionAvailable:
|
||||
|
@ -68,9 +63,9 @@ class database:
|
|||
)
|
||||
exit()
|
||||
|
||||
async def shutdown_db(self):
|
||||
def shutdown_db(self):
|
||||
try:
|
||||
await asyncio.shield(self.conn.close())
|
||||
self.conn.close()
|
||||
# due to the async nature of CME, sometimes session state is a bit messy and this will throw:
|
||||
# Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and
|
||||
# this would cause an unexpected state change to <SessionTransactionState.CLOSED: 5>
|
||||
|
@ -79,4 +74,4 @@ class database:
|
|||
|
||||
def clear_database(self):
|
||||
for table in self.metadata.sorted_tables:
|
||||
asyncio.run(self.conn.execute(table.delete()))
|
||||
self.conn.execute(table.delete())
|
||||
|
|
|
@ -19,11 +19,10 @@ class database:
|
|||
|
||||
self.db_engine = db_engine
|
||||
self.metadata = MetaData()
|
||||
asyncio.run(self.reflect_tables())
|
||||
self.reflect_tables()
|
||||
session_factory = sessionmaker(
|
||||
bind=self.db_engine,
|
||||
expire_on_commit=True,
|
||||
class_=AsyncSession
|
||||
expire_on_commit=True
|
||||
)
|
||||
|
||||
Session = scoped_session(session_factory)
|
||||
|
@ -64,11 +63,9 @@ class database:
|
|||
FOREIGN KEY(hostid) REFERENCES hosts(id)
|
||||
)''')
|
||||
|
||||
async def reflect_tables(self):
|
||||
async with self.db_engine.connect() as conn:
|
||||
def reflect_tables(self):
|
||||
with self.db_engine.connect() as conn:
|
||||
try:
|
||||
await conn.run_sync(self.metadata.reflect)
|
||||
|
||||
self.HostsTable = Table("hosts", self.metadata, autoload_with=self.db_engine)
|
||||
self.UsersTable = Table("users", self.metadata, autoload_with=self.db_engine)
|
||||
self.AdminRelationsTable = Table("admin_relations", self.metadata, autoload_with=self.db_engine)
|
||||
|
@ -82,9 +79,9 @@ class database:
|
|||
)
|
||||
exit()
|
||||
|
||||
async def shutdown_db(self):
|
||||
def shutdown_db(self):
|
||||
try:
|
||||
await asyncio.shield(self.conn.close())
|
||||
self.conn.close()
|
||||
# due to the async nature of CME, sometimes session state is a bit messy and this will throw:
|
||||
# Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and
|
||||
# this would cause an unexpected state change to <SessionTransactionState.CLOSED: 5>
|
||||
|
@ -93,7 +90,7 @@ class database:
|
|||
|
||||
def clear_database(self):
|
||||
for table in self.metadata.sorted_tables:
|
||||
asyncio.run(self.conn.execute(table.delete()))
|
||||
self.conn.execute(table.delete())
|
||||
|
||||
def add_host(self, ip, port, hostname, domain, os=None):
|
||||
"""
|
||||
|
@ -106,7 +103,7 @@ class database:
|
|||
q = select(self.HostsTable).filter(
|
||||
self.HostsTable.c.ip == ip
|
||||
)
|
||||
results = asyncio.run(self.conn.execute(q)).all()
|
||||
results = self.conn.execute(q).all()
|
||||
logging.debug(f"smb add_host() - hosts returned: {results}")
|
||||
|
||||
# create new host
|
||||
|
@ -146,11 +143,9 @@ class database:
|
|||
index_elements=self.HostsTable.primary_key,
|
||||
set_=update_columns
|
||||
)
|
||||
asyncio.run(
|
||||
self.conn.execute(
|
||||
q,
|
||||
hosts
|
||||
)
|
||||
self.conn.execute(
|
||||
q,
|
||||
hosts
|
||||
)
|
||||
|
||||
def add_credential(self, credtype, domain, username, password, pillaged_from=None):
|
||||
|
@ -177,7 +172,7 @@ class database:
|
|||
func.lower(self.UsersTable.c.username) == func.lower(username),
|
||||
func.lower(self.UsersTable.c.credtype) == func.lower(credtype)
|
||||
)
|
||||
results = asyncio.run(self.conn.execute(q)).all()
|
||||
results = self.conn.execute(q).all()
|
||||
|
||||
# add new credential
|
||||
if not results:
|
||||
|
@ -216,12 +211,10 @@ class database:
|
|||
index_elements=self.UsersTable.primary_key,
|
||||
set_=update_columns_users
|
||||
)
|
||||
asyncio.run(
|
||||
self.conn.execute(
|
||||
q_users,
|
||||
credentials
|
||||
)
|
||||
) # .scalar()
|
||||
self.conn.execute(
|
||||
q_users,
|
||||
credentials
|
||||
) # .scalar()
|
||||
# return user_ids
|
||||
|
||||
def remove_credentials(self, creds_id):
|
||||
|
@ -234,7 +227,7 @@ class database:
|
|||
self.UsersTable.c.id == cred_id
|
||||
)
|
||||
del_hosts.append(q)
|
||||
asyncio.run(self.conn.execute(q))
|
||||
self.conn.execute(q)
|
||||
|
||||
def add_admin_user(self, credtype, domain, username, password, host, user_id=None):
|
||||
domain = domain.split('.')[0]
|
||||
|
@ -252,7 +245,7 @@ class database:
|
|||
func.lower(self.UsersTable.c.username) == func.lower(username),
|
||||
self.UsersTable.c.password == password
|
||||
)
|
||||
users = asyncio.run(self.conn.execute(creds_q))
|
||||
users = self.conn.execute(creds_q)
|
||||
hosts = self.get_hosts(host)
|
||||
|
||||
if users and hosts:
|
||||
|
@ -267,17 +260,17 @@ class database:
|
|||
self.AdminRelationsTable.c.userid == user_id,
|
||||
self.AdminRelationsTable.c.hostid == host_id
|
||||
)
|
||||
links = asyncio.run(self.conn.execute(admin_relations_select)).all()
|
||||
links = self.conn.execute(admin_relations_select).all()
|
||||
|
||||
if not links:
|
||||
add_links.append(link)
|
||||
|
||||
admin_relations_insert = Insert(self.AdminRelationsTable)
|
||||
|
||||
asyncio.run(self.conn.execute(
|
||||
self.conn.execute(
|
||||
admin_relations_insert,
|
||||
add_links
|
||||
))
|
||||
)
|
||||
|
||||
def get_admin_relations(self, user_id=None, host_id=None):
|
||||
if user_id:
|
||||
|
@ -291,7 +284,7 @@ class database:
|
|||
else:
|
||||
q = select(self.AdminRelationsTable)
|
||||
|
||||
results = asyncio.run(self.conn.execute(q)).all()
|
||||
results = self.conn.execute(q).all()
|
||||
return results
|
||||
|
||||
def remove_admin_relation(self, user_ids=None, host_ids=None):
|
||||
|
@ -306,7 +299,7 @@ class database:
|
|||
q = q.filter(
|
||||
self.AdminRelationsTable.c.hostid == host_id
|
||||
)
|
||||
asyncio.run(self.conn.execute(q))
|
||||
self.conn.execute(q)
|
||||
|
||||
def is_credential_valid(self, credential_id):
|
||||
"""
|
||||
|
@ -316,7 +309,7 @@ class database:
|
|||
self.UsersTable.c.id == credential_id,
|
||||
self.UsersTable.c.password is not None
|
||||
)
|
||||
results = asyncio.run(self.conn.execute(q)).all()
|
||||
results = self.conn.execute(q).all()
|
||||
return len(results) > 0
|
||||
|
||||
def get_credentials(self, filter_term=None, cred_type=None):
|
||||
|
@ -342,20 +335,20 @@ class database:
|
|||
else:
|
||||
q = select(self.UsersTable)
|
||||
|
||||
results = asyncio.run(self.conn.execute(q)).all()
|
||||
results = self.conn.execute(q).all()
|
||||
return results
|
||||
|
||||
def is_credential_local(self, credential_id):
|
||||
q = select(self.UsersTable.c.domain).filter(
|
||||
self.UsersTable.c.id == credential_id
|
||||
)
|
||||
user_domain = asyncio.run(self.conn.execute(q)).all()
|
||||
user_domain = self.conn.execute(q).all()
|
||||
|
||||
if user_domain:
|
||||
q = select(self.HostsTable).filter(
|
||||
func.lower(self.HostsTable.c.id) == func.lower(user_domain)
|
||||
)
|
||||
results = asyncio.run(self.conn.execute(q)).all()
|
||||
results = self.conn.execute(q).all()
|
||||
|
||||
return len(results) > 0
|
||||
|
||||
|
@ -366,7 +359,7 @@ class database:
|
|||
q = select(self.HostsTable).filter(
|
||||
self.HostsTable.c.id == host_id
|
||||
)
|
||||
results = asyncio.run(self.conn.execute(q)).all()
|
||||
results = self.conn.execute(q).all()
|
||||
return len(results) > 0
|
||||
|
||||
def get_hosts(self, filter_term=None):
|
||||
|
@ -380,7 +373,7 @@ class database:
|
|||
q = q.filter(
|
||||
self.HostsTable.c.id == filter_term
|
||||
)
|
||||
results = asyncio.run(self.conn.execute(q)).first()
|
||||
results = self.conn.execute(q).first()
|
||||
# all() returns a list, so we keep the return format the same so consumers don't have to guess
|
||||
return [results]
|
||||
# if we're filtering by domain controllers
|
||||
|
@ -397,7 +390,7 @@ class database:
|
|||
self.HostsTable.c.ip.like(like_term) |
|
||||
func.lower(self.HostsTable.c.hostname).like(like_term)
|
||||
)
|
||||
results = asyncio.run(self.conn.execute(q)).all()
|
||||
results = self.conn.execute(q).all()
|
||||
logging.debug(f"winrm get_hosts() - results: {results}")
|
||||
return results
|
||||
|
||||
|
@ -408,7 +401,7 @@ class database:
|
|||
q = select(self.UsersTable).filter(
|
||||
self.UsersTable.c.id == user_id
|
||||
)
|
||||
results = asyncio.run(self.conn.execute(q)).all()
|
||||
results = self.conn.execute(q).all()
|
||||
return len(results) > 0
|
||||
|
||||
def get_users(self, filter_term=None):
|
||||
|
@ -424,7 +417,7 @@ class database:
|
|||
q = q.filter(
|
||||
func.lower(self.UsersTable.c.username).like(like_term)
|
||||
)
|
||||
results = asyncio.run(self.conn.execute(q)).all()
|
||||
results = self.conn.execute(q).all()
|
||||
return results
|
||||
|
||||
def get_user(self, domain, username):
|
||||
|
@ -432,7 +425,7 @@ class database:
|
|||
func.lower(self.UsersTable.c.domain) == func.lower(domain),
|
||||
func.lower(self.UsersTable.c.username) == func.lower(username)
|
||||
)
|
||||
results = asyncio.run(self.conn.execute(q)).all()
|
||||
results = self.conn.execute(q).all()
|
||||
return results
|
||||
|
||||
def add_loggedin_relation(self, user_id, host_id):
|
||||
|
@ -440,7 +433,7 @@ class database:
|
|||
self.LoggedinRelationsTable.c.userid == user_id,
|
||||
self.LoggedinRelationsTable.c.hostid == host_id
|
||||
)
|
||||
results = asyncio.run(self.conn.execute(relation_query)).all()
|
||||
results = self.conn.execute(relation_query).all()
|
||||
|
||||
# only add one if one doesn't already exist
|
||||
if not results:
|
||||
|
@ -451,11 +444,10 @@ class database:
|
|||
try:
|
||||
# TODO: find a way to abstract this away to a single Upsert call
|
||||
q = Insert(self.LoggedinRelationsTable) # .returning(self.LoggedinRelationsTable.c.id)
|
||||
asyncio.run(
|
||||
self.conn.execute(
|
||||
q,
|
||||
[relation]
|
||||
)
|
||||
|
||||
self.conn.execute(
|
||||
q,
|
||||
[relation]
|
||||
) # .scalar()
|
||||
# return inserted_ids
|
||||
except Exception as e:
|
||||
|
@ -471,7 +463,7 @@ class database:
|
|||
q = q.filter(
|
||||
self.LoggedinRelationsTable.c.hostid == host_id
|
||||
)
|
||||
results = asyncio.run(self.conn.execute(q)).all()
|
||||
results = self.conn.execute(q).all()
|
||||
return results
|
||||
|
||||
def remove_loggedin_relations(self, user_id=None, host_id=None):
|
||||
|
@ -484,4 +476,4 @@ class database:
|
|||
q = q.filter(
|
||||
self.LoggedinRelationsTable.c.hostid == host_id
|
||||
)
|
||||
asyncio.run(self.conn.execute(q))
|
||||
self.conn.execute(q)
|
Loading…
Reference in New Issue