Merge branch 'tests_marshall' into modules_marshall

main
Marshall Hallenbeck 2023-03-26 02:05:43 -04:00
commit f187453525
15 changed files with 380 additions and 399 deletions

View File

@ -39,6 +39,7 @@ def gen_cli_args():
parser.add_argument("-t", type=int, dest="threads", default=100, help="set how many concurrent threads to use (default: 100)")
parser.add_argument("--timeout", default=None, type=int, help='max timeout in seconds of each thread (default: None)')
parser.add_argument("--jitter", metavar='INTERVAL', type=str, help='sets a random delay between each connection (default: None)')
parser.add_argument("--progress", default=True, action='store_false', help='display progress bar during scan')
parser.add_argument("--darrell", action='store_true', help='give Darrell a hand')
parser.add_argument("--verbose", action='store_true', help="enable verbose output")
parser.add_argument("--version", action='store_true', help="Display CME version")

View File

@ -8,12 +8,12 @@ import sqlite3
import sys
import os
import requests
from sqlalchemy import create_engine
from terminaltables import AsciiTable
import configparser
from cme.loaders.protocol_loader import protocol_loader
from cme.paths import CONFIG_PATH, WS_PATH, WORKSPACE_DIR
from requests import ConnectionError
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.exc import SAWarning
import asyncio
import csv
@ -33,13 +33,11 @@ class UserExitedProto(Exception):
def create_db_engine(db_path):
db_engine = create_async_engine(
f"sqlite+aiosqlite:///{db_path}",
db_engine = create_engine(
f"sqlite:///{db_path}",
isolation_level="AUTOCOMMIT",
future=True
) # can add echo=True
# db_engine.execution_options(isolation_level="AUTOCOMMIT")
# db_engine.connect().connection.text_factory = str
)
return db_engine
@ -52,10 +50,6 @@ def print_table(data, title=None):
print("")
# def do_exit():
# sys.exit(0)
def write_csv(filename, headers, entries):
"""
Writes a CSV file with the provided parameters.
@ -111,7 +105,7 @@ class DatabaseNavigator(cmd.Cmd):
self.prompt = 'cmedb ({})({}) > '.format(main_menu.workspace, proto)
def do_exit(self, line):
asyncio.run(self.db.shutdown_db())
self.db.shutdown_db()
sys.exit()
def help_exit(self):

View File

@ -2,11 +2,14 @@
# -*- coding: utf-8 -*-
import logging
import random
import socket
from os.path import isfile
from threading import BoundedSemaphore
from socket import gethostbyname
from functools import wraps
from time import sleep
from cme.logger import CMEAdapter
from cme.context import Context
from cme.helpers.logger import write_log
@ -69,6 +72,11 @@ class connection(object):
logging.debug('Error resolving hostname {}: {}'.format(self.hostname, e))
return
if args.jitter:
value = random.choice(range(args.jitter[0], args.jitter[1]))
logging.debug(f"Doin' the jitterbug for {value} second(s)")
sleep(value)
self.proto_flow()
@staticmethod

View File

@ -1,5 +1,8 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import concurrent.futures
import sqlalchemy
from cme.logger import setup_logger, setup_debug_logger, CMEAdapter
from cme.helpers.logger import highlight
@ -18,8 +21,6 @@ from concurrent.futures import ThreadPoolExecutor
from pprint import pformat
from decimal import Decimal
import asyncio
import aioconsole
import functools
import configparser
import cme.helpers.powershell as powershell
import cme
@ -31,9 +32,9 @@ import os
import sys
import logging
from sqlalchemy.orm import declarative_base
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.exc import SAWarning
import warnings
from tqdm import tqdm
Base = declarative_base()
@ -50,95 +51,20 @@ warnings.filterwarnings("ignore", category=SAWarning)
def create_db_engine(db_path):
db_engine = create_async_engine(
f"sqlite+aiosqlite:///{db_path}",
db_engine = sqlalchemy.create_engine(
f"sqlite:///{db_path}",
isolation_level="AUTOCOMMIT",
future=True
) # can add echo=True
# db_engine.execution_options(isolation_level="AUTOCOMMIT")
# db_engine.connect().connection.text_factory = str
)
return db_engine
async def monitor_threadpool(pool, targets):
logging.debug('Started thread poller')
while True:
try:
text = await aioconsole.ainput("")
if text == "":
pool_size = pool._work_queue.qsize()
finished_threads = len(targets) - pool_size
percentage = Decimal(finished_threads) / Decimal(len(targets)) * Decimal(100)
logger.info(f"completed: {percentage:.2f}% ({finished_threads}/{len(targets)})")
except asyncio.CancelledError:
logging.debug("Stopped thread poller")
break
async def run_protocol(loop, protocol_obj, args, db, target, jitter):
try:
if jitter:
value = random.choice(range(jitter[0], jitter[1]))
logging.debug(f"Doin' the jitterbug for {value} second(s)")
await asyncio.sleep(value)
thread = loop.run_in_executor(
None,
functools.partial(
protocol_obj,
args,
db,
str(target)
)
)
await asyncio.wait_for(
thread,
timeout=args.timeout
)
except asyncio.TimeoutError:
logging.debug("Thread exceeded timeout")
except asyncio.CancelledError:
logging.debug("Shutting down DB")
thread.cancel()
except sqlite3.OperationalError as e:
logging.debug("Sqlite error - sqlite3.operationalError - {}".format(str(e)))
async def start_threadpool(protocol_obj, args, db, targets, jitter):
pool = ThreadPoolExecutor(max_workers=args.threads + 1)
loop = asyncio.get_running_loop()
loop.set_default_executor(pool)
monitor_task = asyncio.create_task(
monitor_threadpool(pool, targets)
)
jobs = [
run_protocol(
loop,
protocol_obj,
args,
db,
target,
jitter
)
for target in targets
]
try:
logging.debug("Running")
await asyncio.gather(*jobs)
except asyncio.CancelledError:
print('\n')
logger.info("Shutting down, please wait...")
logging.debug("Cancelling scan")
finally:
await asyncio.shield(db.shutdown_db())
monitor_task.cancel()
pool.shutdown(wait=True)
async def start_scan(protocol_obj, args, db, targets):
with tqdm(total=len(targets), disable=args.progress) as pbar:
with ThreadPoolExecutor(max_workers=args.threads + 1) as executor:
futures = [executor.submit(protocol_obj, args, db, target) for target in targets]
for future in concurrent.futures.as_completed(futures):
pbar.update(1)
def main():
@ -164,7 +90,6 @@ def main():
module = None
module_server = None
targets = []
jitter = None
server_port_dict = {'http': 80, 'https': 443, 'smb': 445}
current_workspace = config.get('CME', 'workspace')
if config.get('CME', 'log_mode') != "False":
@ -178,9 +103,9 @@ def main():
if args.jitter:
if '-' in args.jitter:
start, end = args.jitter.split('-')
jitter = (int(start), int(end))
args.jitter = (int(start), int(end))
else:
jitter = (0, int(args.jitter))
args.jitter = (0, int(args.jitter))
if hasattr(args, 'cred_id') and args.cred_id:
for cred_id in args.cred_id:
@ -304,14 +229,14 @@ def main():
try:
asyncio.run(
start_threadpool(protocol_object, args, db, targets, jitter)
start_scan(protocol_object, args, db, targets)
)
except KeyboardInterrupt:
logging.debug("Got keyboard interrupt")
finally:
if module_server:
module_server.shutdown()
asyncio.run(db_engine.dispose())
db_engine.dispose()
if __name__ == '__main__':

View File

@ -102,11 +102,11 @@ class CMEModule:
if entries:
num = len(entries)
if 1 == num:
logging.info('Received one endpoint.')
logging.debug(f"[Spooler] Received one endpoint")
else:
logging.info('Received %d endpoints.' % num)
logging.debug(f"[Spooler] Received {num} endpoints")
else:
logging.info('No endpoints found.')
logging.debug(f"[Spooler] No endpoints found")
def __fetchList(self, rpctransport):
dce = rpctransport.get_dce_rpc()

View File

@ -16,8 +16,11 @@ class database:
self.db_engine = db_engine
self.metadata = MetaData()
asyncio.run(self.reflect_tables())
session_factory = sessionmaker(bind=self.db_engine, expire_on_commit=True, class_=AsyncSession)
self.reflect_tables()
session_factory = sessionmaker(
bind=self.db_engine,
expire_on_commit=True
)
Session = scoped_session(session_factory)
# this is still named "conn" when it is the session object; TODO: rename
@ -38,20 +41,9 @@ class database:
"server_banner" text
)''')
async def shutdown_db(self):
def reflect_tables(self):
with self.db_engine.connect() as conn:
try:
await asyncio.shield(self.conn.close())
# due to the async nature of CME, sometimes session state is a bit messy and this will throw:
# Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and
# this would cause an unexpected state change to <SessionTransactionState.CLOSED: 5>
except IllegalStateChangeError as e:
logging.debug(f"Error while closing session db object: {e}")
async def reflect_tables(self):
async with self.db_engine.connect() as conn:
try:
await conn.run_sync(self.metadata.reflect)
self.CredentialsTable = Table("credentials", self.metadata, autoload_with=self.db_engine)
self.HostsTable = Table("hosts", self.metadata, autoload_with=self.db_engine)
except NoInspectionAvailable:
@ -63,7 +55,16 @@ class database:
)
exit()
def shutdown_db(self):
try:
self.conn.close()
# due to the async nature of CME, sometimes session state is a bit messy and this will throw:
# Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and
# this would cause an unexpected state change to <SessionTransactionState.CLOSED: 5>
except IllegalStateChangeError as e:
logging.debug(f"Error while closing session db object: {e}")
def clear_database(self):
for table in self.metadata.sorted_tables:
asyncio.run(self.conn.execute(table.delete()))
self.conn.execute(table.delete())

View File

@ -15,8 +15,11 @@ class database:
self.db_engine = db_engine
self.metadata = MetaData()
asyncio.run(self.reflect_tables())
session_factory = sessionmaker(bind=self.db_engine, expire_on_commit=True, class_=AsyncSession)
self.reflect_tables()
session_factory = sessionmaker(
bind=self.db_engine,
expire_on_commit=True
)
Session = scoped_session(session_factory)
# this is still named "conn" when it is the session object; TODO: rename
@ -37,20 +40,9 @@ class database:
"port" integer
)''')
async def shutdown_db(self):
def reflect_tables(self):
with self.db_engine.connect() as conn:
try:
await asyncio.shield(self.conn.close())
# due to the async nature of CME, sometimes session state is a bit messy and this will throw:
# Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and
# this would cause an unexpected state change to <SessionTransactionState.CLOSED: 5>
except IllegalStateChangeError as e:
logging.debug(f"Error while closing session db object: {e}")
async def reflect_tables(self):
async with self.db_engine.connect() as conn:
try:
await conn.run_sync(self.metadata.reflect)
self.CredentialsTable = Table("credentials", self.metadata, autoload_with=self.db_engine)
self.HostsTable = Table("hosts", self.metadata, autoload_with=self.db_engine)
except NoInspectionAvailable:
@ -62,6 +54,15 @@ class database:
)
exit()
def shutdown_db(self):
try:
self.conn.close()
# due to the async nature of CME, sometimes session state is a bit messy and this will throw:
# Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and
# this would cause an unexpected state change to <SessionTransactionState.CLOSED: 5>
except IllegalStateChangeError as e:
logging.debug(f"Error while closing session db object: {e}")
def clear_database(self):
for table in self.metadata.sorted_tables:
asyncio.run(self.conn.execute(table.delete()))
self.conn.execute(table.delete())

View File

@ -5,9 +5,7 @@ from sqlalchemy import MetaData, func, Table, select, insert, update, delete
from sqlalchemy.dialects.sqlite import Insert # used for upsert
from sqlalchemy.exc import IllegalStateChangeError, NoInspectionAvailable
from sqlalchemy.orm import sessionmaker, scoped_session
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.exc import SAWarning
import asyncio
import warnings
# if there is an issue with SQLAlchemy and a connection cannot be cleaned up properly it spews out annoying warnings
@ -22,39 +20,16 @@ class database:
self.db_engine = db_engine
self.metadata = MetaData()
asyncio.run(self.reflect_tables())
session_factory = sessionmaker(bind=self.db_engine, expire_on_commit=True, class_=AsyncSession)
self.reflect_tables()
session_factory = sessionmaker(
bind=self.db_engine,
expire_on_commit=True
)
Session = scoped_session(session_factory)
# this is still named "conn" when it is the session object; TODO: rename
self.conn = Session()
async def shutdown_db(self):
try:
await asyncio.shield(self.conn.close())
# due to the async nature of CME, sometimes session state is a bit messy and this will throw:
# Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and
# this would cause an unexpected state change to <SessionTransactionState.CLOSED: 5>
except IllegalStateChangeError as e:
logging.debug(f"Error while closing session db object: {e}")
async def reflect_tables(self):
async with self.db_engine.connect() as conn:
try:
await conn.run_sync(self.metadata.reflect)
self.HostsTable = Table("hosts", self.metadata, autoload_with=self.db_engine)
self.UsersTable = Table("users", self.metadata, autoload_with=self.db_engine)
self.AdminRelationsTable = Table("admin_relations", self.metadata, autoload_with=self.db_engine)
except NoInspectionAvailable:
print(
"[-] Error reflecting tables - this means there is a DB schema mismatch \n"
"[-] This is probably because a newer version of CME is being ran on an old DB schema\n"
"[-] If you wish to save the old DB data, copy it to a new location (`cp -r ~/.cme/workspaces/ ~/old_cme_workspaces/`)\n"
"[-] Then remove the CME DB folders (`rm -rf ~/.cme/workspaces/`) and rerun CME to initialize the new DB schema"
)
exit()
@staticmethod
def db_schema(db_conn):
db_conn.execute('''CREATE TABLE "hosts" (
@ -84,6 +59,34 @@ class database:
FOREIGN KEY(pillaged_from_hostid) REFERENCES hosts(id)
)''')
def reflect_tables(self):
with self.db_engine.connect() as conn:
try:
self.HostsTable = Table("hosts", self.metadata, autoload_with=self.db_engine)
self.UsersTable = Table("users", self.metadata, autoload_with=self.db_engine)
self.AdminRelationsTable = Table("admin_relations", self.metadata, autoload_with=self.db_engine)
except NoInspectionAvailable:
print(
"[-] Error reflecting tables - this means there is a DB schema mismatch \n"
"[-] This is probably because a newer version of CME is being ran on an old DB schema\n"
"[-] If you wish to save the old DB data, copy it to a new location (`cp -r ~/.cme/workspaces/ ~/old_cme_workspaces/`)\n"
"[-] Then remove the CME DB folders (`rm -rf ~/.cme/workspaces/`) and rerun CME to initialize the new DB schema"
)
exit()
def shutdown_db(self):
try:
self.conn.close()
# due to the async nature of CME, sometimes session state is a bit messy and this will throw:
# Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and
# this would cause an unexpected state change to <SessionTransactionState.CLOSED: 5>
except IllegalStateChangeError as e:
logging.debug(f"Error while closing session db object: {e}")
def clear_database(self):
for table in self.metadata.sorted_tables:
self.conn.execute(table.delete())
def add_host(self, ip, hostname, domain, os, instances):
"""
Check if this host has already been added to the database, if not, add it in.
@ -95,7 +98,7 @@ class database:
q = select(self.HostsTable).filter(
self.HostsTable.c.ip == ip
)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
logging.debug(f"mssql add_host() - hosts returned: {results}")
host_data = {
@ -133,12 +136,10 @@ class database:
index_elements=self.HostsTable.primary_key,
set_=update_columns
)
asyncio.run(
self.conn.execute(
q,
hosts
)
)
def add_credential(self, credtype, domain, username, password, pillaged_from=None):
"""
@ -164,7 +165,7 @@ class database:
func.lower(self.UsersTable.c.username) == func.lower(username),
func.lower(self.UsersTable.c.credtype) == func.lower(credtype)
)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
if not results:
user_data = {
@ -175,13 +176,13 @@ class database:
"pillaged_from_hostid": pillaged_from,
}
q = insert(self.UsersTable).values(user_data) # .returning(self.UsersTable.c.id)
asyncio.run(self.conn.execute(q)) # .first()
self.conn.execute(q) # .first()
else:
for user in results:
# might be able to just remove this if check, but leaving it in for now
if not user[3] and not user[4] and not user[5]:
q = update(self.UsersTable).values(credential_data) # .returning(self.UsersTable.c.id)
results = asyncio.run(self.conn.execute(q)) # .first()
results = self.conn.execute(q) # .first()
# user_rowid = results.id
logging.debug('add_credential(credtype={}, domain={}, username={}, password={}, pillaged_from={})'.format(
@ -203,7 +204,7 @@ class database:
self.UsersTable.c.id == cred_id
)
del_hosts.append(q)
asyncio.run(self.conn.execute(q))
self.conn.execute(q)
def add_admin_user(self, credtype, domain, username, password, host, user_id=None):
domain = domain.split('.')[0].upper()
@ -212,7 +213,7 @@ class database:
q = select(self.UsersTable).filter(
self.UsersTable.c.id == user_id
)
users = asyncio.run(self.conn.execute(q)).all()
users = self.conn.execute(q).all()
else:
q = select(self.UsersTable).filter(
self.UsersTable.c.credtype == credtype,
@ -220,31 +221,35 @@ class database:
func.lower(self.UsersTable.c.username) == func.lower(username),
self.UsersTable.c.password == password
)
users = asyncio.run(self.conn.execute(q)).all()
users = self.conn.execute(q).all()
logging.debug(f"Users: {users}")
like_term = func.lower(f"%{host}%")
q = q.filter(
self.HostsTable.c.ip.like(like_term)
)
hosts = asyncio.run(self.conn.execute(q)).all()
hosts = self.conn.execute(q).all()
logging.debug(f"Hosts: {hosts}")
if users is not None and hosts is not None:
for user, host in zip(users, hosts):
user_id = user[0]
host_id = host[0]
link = {
"userid": user_id,
"hostid": host_id
}
q = select(self.AdminRelationsTable).filter(
self.AdminRelationsTable.c.userid == user_id,
self.AdminRelationsTable.c.hostid == host_id
)
links = asyncio.run(self.conn.execute(q)).all()
links = self.conn.execute(q).all()
if not links:
asyncio.run(self.conn.execute(
self.conn.execute(
insert(self.AdminRelationsTable).values(link)
))
)
def get_admin_relations(self, user_id=None, host_id=None):
if user_id:
@ -258,7 +263,7 @@ class database:
else:
q = select(self.AdminRelationsTable)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
return results
def remove_admin_relation(self, user_ids=None, host_ids=None):
@ -273,7 +278,7 @@ class database:
q = q.filter(
self.AdminRelationsTable.c.hostid == host_id
)
asyncio.run(self.conn.execute(q))
self.conn.execute(q)
def is_credential_valid(self, credential_id):
"""
@ -283,7 +288,7 @@ class database:
self.UsersTable.c.id == credential_id,
self.UsersTable.c.password is not None
)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
return len(results) > 0
def get_credentials(self, filter_term=None, cred_type=None):
@ -309,7 +314,7 @@ class database:
else:
q = select(self.UsersTable)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
return results
def is_host_valid(self, host_id):
@ -319,7 +324,7 @@ class database:
q = select(self.HostsTable).filter(
self.HostsTable.c.id == host_id
)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
return len(results) > 0
def get_hosts(self, filter_term=None, domain=None):
@ -333,7 +338,7 @@ class database:
q = q.filter(
self.HostsTable.c.id == filter_term
)
results = asyncio.run(self.conn.execute(q)).first()
results = self.conn.execute(q).first()
# all() returns a list, so we keep the return format the same so consumers don't have to guess
return [results]
# if we're filtering by domain controllers
@ -353,9 +358,6 @@ class database:
func.lower(self.HostsTable.c.hostname).like(like_term)
)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
return results
def clear_database(self):
for table in self.metadata.sorted_tables:
asyncio.run(self.conn.execute(table.delete()))

View File

@ -15,8 +15,11 @@ class database:
self.db_engine = db_engine
self.metadata = MetaData()
asyncio.run(self.reflect_tables())
session_factory = sessionmaker(bind=self.db_engine, expire_on_commit=True, class_=AsyncSession)
self.reflect_tables()
session_factory = sessionmaker(
bind=self.db_engine,
expire_on_commit=True
)
Session = scoped_session(session_factory)
# this is still named "conn" when it is the session object; TODO: rename
@ -39,20 +42,9 @@ class database:
"server_banner" text
)''')
async def shutdown_db(self):
def reflect_tables(self):
with self.db_engine.connect() as conn:
try:
await asyncio.shield(self.conn.close())
# due to the async nature of CME, sometimes session state is a bit messy and this will throw:
# Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and
# this would cause an unexpected state change to <SessionTransactionState.CLOSED: 5>
except IllegalStateChangeError as e:
logging.debug(f"Error while closing session db object: {e}")
async def reflect_tables(self):
async with self.db_engine.connect() as conn:
try:
await conn.run_sync(self.metadata.reflect)
self.CredentialsTable = Table("credentials", self.metadata, autoload_with=self.db_engine)
self.HostsTable = Table("hosts", self.metadata, autoload_with=self.db_engine)
except NoInspectionAvailable:
@ -64,6 +56,15 @@ class database:
)
exit()
def shutdown_db(self):
try:
self.conn.close()
# due to the async nature of CME, sometimes session state is a bit messy and this will throw:
# Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and
# this would cause an unexpected state change to <SessionTransactionState.CLOSED: 5>
except IllegalStateChangeError as e:
logging.debug(f"Error while closing session db object: {e}")
def clear_database(self):
for table in self.metadata.sorted_tables:
asyncio.run(self.conn.execute(table.delete()))
self.conn.execute(table.delete())

View File

@ -30,13 +30,10 @@ class database:
self.db_engine = db_engine
self.metadata = MetaData()
asyncio.run(self.reflect_tables())
# we don't use async_sessionmaker or async_scoped_session because when `database` is initialized,
# there is no running async loop
self.reflect_tables()
session_factory = sessionmaker(
bind=self.db_engine,
expire_on_commit=True,
class_=AsyncSession
expire_on_commit=True
)
Session = scoped_session(session_factory)
@ -134,11 +131,9 @@ class database:
# FOREIGN KEY(hostid) REFERENCES hosts(id)
# )''')
async def reflect_tables(self):
async with self.db_engine.connect() as conn:
def reflect_tables(self):
with self.db_engine.connect() as conn:
try:
await conn.run_sync(self.metadata.reflect)
self.HostsTable = Table("hosts", self.metadata, autoload_with=self.db_engine)
self.UsersTable = Table("users", self.metadata, autoload_with=self.db_engine)
self.GroupsTable = Table("groups", self.metadata, autoload_with=self.db_engine)
@ -157,9 +152,9 @@ class database:
)
exit()
async def shutdown_db(self):
def shutdown_db(self):
try:
await asyncio.shield(self.conn.close())
self.conn.close()
# due to the async nature of CME, sometimes session state is a bit messy and this will throw:
# Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and
# this would cause an unexpected state change to <SessionTransactionState.CLOSED: 5>
@ -168,7 +163,7 @@ class database:
def clear_database(self):
for table in self.metadata.sorted_tables:
asyncio.run(self.conn.execute(table.delete()))
self.conn.execute(table.delete())
# pull/545
def add_host(self, ip, hostname, domain, os, smbv1, signing, spooler=None, zerologon=None, petitpotam=None,
@ -183,7 +178,7 @@ class database:
q = select(self.HostsTable).filter(
self.HostsTable.c.ip == ip
)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
# create new host
if not results:
@ -238,11 +233,10 @@ class database:
index_elements=self.HostsTable.primary_key,
set_=update_columns
)
asyncio.run(
self.conn.execute(
q,
hosts
)
) # .scalar()
# we only return updated IDs for now - when RETURNING clause is allowed we can return inserted
if updated_ids:
@ -267,7 +261,7 @@ class database:
func.lower(self.UsersTable.c.username) == func.lower(username),
func.lower(self.UsersTable.c.credtype) == func.lower(credtype)
)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
# add new credential
if not results:
@ -314,21 +308,19 @@ class database:
set_=update_columns_users
)
logging.debug(f"Adding credentials: {credentials}")
asyncio.run(
self.conn.execute(
q_users,
credentials
)
) # .scalar()
if groups:
q_groups = Insert(self.GroupRelationsTable)
asyncio.run(
self.conn.execute(
q_groups,
groups
)
)
# return user_ids
def remove_credentials(self, creds_id):
@ -341,7 +333,7 @@ class database:
self.UsersTable.c.id == cred_id
)
del_hosts.append(q)
asyncio.run(self.conn.execute(q))
self.conn.execute(q)
def add_admin_user(self, credtype, domain, username, password, host, user_id=None):
domain = domain.split('.')[0]
@ -359,7 +351,7 @@ class database:
func.lower(self.UsersTable.c.username) == func.lower(username),
self.UsersTable.c.password == password
)
users = asyncio.run(self.conn.execute(creds_q))
users = self.conn.execute(creds_q)
hosts = self.get_hosts(host)
if users and hosts:
@ -374,7 +366,7 @@ class database:
self.AdminRelationsTable.c.userid == user_id,
self.AdminRelationsTable.c.hostid == host_id
)
links = asyncio.run(self.conn.execute(admin_relations_select)).all()
links = self.conn.execute(admin_relations_select).all()
if not links:
add_links.append(link)
@ -382,10 +374,10 @@ class database:
admin_relations_insert = Insert(self.AdminRelationsTable)
if add_links:
asyncio.run(self.conn.execute(
self.conn.execute(
admin_relations_insert,
add_links
))
)
def get_admin_relations(self, user_id=None, host_id=None):
if user_id:
@ -399,7 +391,7 @@ class database:
else:
q = select(self.AdminRelationsTable)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
return results
def remove_admin_relation(self, user_ids=None, host_ids=None):
@ -414,7 +406,7 @@ class database:
q = q.filter(
self.AdminRelationsTable.c.hostid == host_id
)
asyncio.run(self.conn.execute(q))
self.conn.execute(q)
def is_credential_valid(self, credential_id):
"""
@ -424,7 +416,7 @@ class database:
self.UsersTable.c.id == credential_id,
self.UsersTable.c.password is not None
)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
return len(results) > 0
def get_credentials(self, filter_term=None, cred_type=None):
@ -450,7 +442,7 @@ class database:
else:
q = select(self.UsersTable)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
return results
def get_credential(self, cred_type, domain, username, password):
@ -462,20 +454,20 @@ class database:
self.UsersTable.c.password == password,
self.UsersTable.c.credtype == cred_type
)
results = asyncio.run(self.conn.execute(q)).first()
results = self.conn.execute(q).first()
return results.id
def is_credential_local(self, credential_id):
q = select(self.UsersTable.c.domain).filter(
self.UsersTable.c.id == credential_id
)
user_domain = asyncio.run(self.conn.execute(q)).all()
user_domain = self.conn.execute(q).all()
if user_domain:
q = select(self.HostsTable).filter(
func.lower(self.HostsTable.c.id) == func.lower(user_domain)
)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
return len(results) > 0
@ -486,7 +478,7 @@ class database:
q = select(self.HostsTable).filter(
self.HostsTable.c.id == host_id
)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
return len(results) > 0
def get_hosts(self, filter_term=None, domain=None):
@ -500,7 +492,7 @@ class database:
q = q.filter(
self.HostsTable.c.id == filter_term
)
results = asyncio.run(self.conn.execute(q)).first()
results = self.conn.execute(q).first()
# all() returns a list, so we keep the return format the same so consumers don't have to guess
return [results]
# if we're filtering by domain controllers
@ -542,7 +534,7 @@ class database:
self.HostsTable.c.ip.like(like_term) |
func.lower(self.HostsTable.c.hostname).like(like_term)
)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
logging.debug(f"smb hosts() - results: {results}")
return results
@ -553,7 +545,7 @@ class database:
q = select(self.GroupsTable).filter(
self.GroupsTable.c.id == group_id
)
results = asyncio.run(self.conn.execute(q)).first()
results = self.conn.execute(q).first()
valid = True if results else False
logging.debug(f"is_group_valid(groupID={group_id}) => {valid}")
@ -585,12 +577,11 @@ class database:
# insert the group and get the returned id right away, this can be refactored when we can use RETURNING
q = Insert(self.GroupsTable)
asyncio.run(
self.conn.execute(
q,
groups
)
)
new_group_data = self.get_groups(group_name=group_data["name"], group_domain=group_data["domain"])
returned_id = [new_group_data[0].id]
logging.debug(f"Inserted group with ID: {returned_id[0]}")
@ -623,12 +614,11 @@ class database:
index_elements=self.GroupsTable.primary_key,
set_=update_columns
)
asyncio.run(
self.conn.execute(
q,
groups
)
)
# TODO: always return a list and fix code references to not expect a single integer
# inserted_result = res_inserted_result.first()
# gid = inserted_result.id
@ -650,7 +640,7 @@ class database:
q = select(self.GroupsTable).filter(
self.GroupsTable.c.id == filter_term
)
results = asyncio.run(self.conn.execute(q)).first()
results = self.conn.execute(q).first()
# all() returns a list, so we keep the return format the same so consumers don't have to guess
return [results]
elif group_name and group_domain:
@ -666,7 +656,7 @@ class database:
else:
q = select(self.GroupsTable).filter()
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
logging.debug(
f"get_groups(filter_term={filter_term}, groupName={group_name}, groupDomain={group_domain}) => {results}")
@ -687,7 +677,7 @@ class database:
self.GroupRelationsTable.c.groupid == group_id
)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
return results
def remove_group_relations(self, user_id=None, group_id=None):
@ -700,7 +690,7 @@ class database:
q = q.filter(
self.GroupRelationsTable.c.groupid == group_id
)
asyncio.run(self.conn.execute(q))
self.conn.execute(q)
def is_user_valid(self, user_id):
"""
@ -709,7 +699,7 @@ class database:
q = select(self.UsersTable).filter(
self.UsersTable.c.id == user_id
)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
return len(results) > 0
def get_users(self, filter_term=None):
@ -725,7 +715,7 @@ class database:
q = q.filter(
func.lower(self.UsersTable.c.username).like(like_term)
)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
return results
def get_user(self, domain, username):
@ -733,7 +723,7 @@ class database:
func.lower(self.UsersTable.c.domain) == func.lower(domain),
func.lower(self.UsersTable.c.username) == func.lower(username)
)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
return results
def get_domain_controllers(self, domain=None):
@ -746,7 +736,7 @@ class database:
q = select(self.SharesTable).filter(
self.SharesTable.c.id == share_id
)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
logging.debug(f"is_share_valid(shareID={share_id}) => {len(results) > 0}")
return len(results) > 0
@ -760,10 +750,10 @@ class database:
"read": read,
"write": write,
}
share_id = asyncio.run(self.conn.execute(
share_id = self.conn.execute(
Insert(self.SharesTable).on_conflict_do_nothing(), # .returning(self.SharesTable.c.id),
share_data
)) # .scalar_one()
) # .scalar_one()
# return share_id
def get_shares(self, filter_term=None):
@ -778,7 +768,7 @@ class database:
)
else:
q = select(self.SharesTable)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
return results
def get_shares_by_access(self, permissions, share_id=None):
@ -790,7 +780,7 @@ class database:
q = q.filter(self.SharesTable.c.read == 1)
if "w" in permissions:
q = q.filter(self.SharesTable.c.write == 1)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
return results
def get_users_with_share_access(self, host_id, share_name, permissions):
@ -803,7 +793,7 @@ class database:
q = q.filter(self.SharesTable.c.read == 1)
if "w" in permissions:
q = q.filter(self.SharesTable.c.write == 1)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
return results
@ -816,7 +806,7 @@ class database:
q = select(self.DpapiBackupkey).filter(
func.lower(self.DpapiBackupkey.c.domain) == func.lower(domain)
)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
if not len(results):
pvk_encoded = base64.b64encode(pvk)
@ -827,11 +817,10 @@ class database:
try:
# TODO: find a way to abstract this away to a single Upsert call
q = Insert(self.DpapiBackupkey) # .returning(self.DpapiBackupkey.c.id)
asyncio.run(
self.conn.execute(
q,
[backup_key]
)
) # .scalar()
logging.debug(f"add_domain_backupkey(domain={domain}, pvk={pvk_encoded})")
# return inserted_id
@ -848,7 +837,7 @@ class database:
q = q.filter(
func.lower(self.DpapiBackupkey.c.domain) == func.lower(domain)
)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
logging.debug(f"get_domain_backupkey(domain={domain}) => {results}")
@ -864,7 +853,7 @@ class database:
q = select(self.DpapiSecrets).filter(
func.lower(self.DpapiSecrets.c.id) == dpapi_secret_id
)
results = asyncio.run(self.conn.execute(q)).first()
results = self.conn.execute(q).first()
valid = True if results is not None else False
logging.debug(f"is_dpapi_secret_valid(groupID={dpapi_secret_id}) => {valid}")
return valid
@ -883,11 +872,10 @@ class database:
"url": url
}
q = Insert(self.DpapiSecrets).on_conflict_do_nothing() # .returning(self.DpapiSecrets.c.id)
asyncio.run(
self.conn.execute(
q,
[secret]
)
) # .scalar()
# inserted_result = res_inserted_result.first()
@ -907,14 +895,14 @@ class database:
q = q.filter(
self.DpapiSecrets.c.id == filter_term
)
results = asyncio.run(self.conn.execute(q)).first()
results = self.conn.execute(q).first()
# all() returns a list, so we keep the return format the same so consumers don't have to guess
return [results]
elif host:
q = q.filter(
self.DpapiSecrets.c.host == host
)
results = asyncio.run(self.conn.execute(q)).first()
results = self.conn.execute(q).first()
# all() returns a list, so we keep the return format the same so consumers don't have to guess
return [results]
elif dpapi_type:
@ -935,7 +923,7 @@ class database:
q = q.filter(
func.lower(self.DpapiSecrets.c.url) == func.lower(url)
)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
logging.debug(
f"get_dpapi_secrets(filter_term={filter_term}, host={host}, dpapi_type={dpapi_type}, windows_user={windows_user}, username={username}, url={url}) => {results}")
@ -946,7 +934,7 @@ class database:
self.LoggedinRelationsTable.c.userid == user_id,
self.LoggedinRelationsTable.c.hostid == host_id
)
results = asyncio.run(self.conn.execute(relation_query)).all()
results = self.conn.execute(relation_query).all()
# only add one if one doesn't already exist
if not results:
@ -958,11 +946,10 @@ class database:
logging.debug(f"Inserting loggedin_relations: {relation}")
# TODO: find a way to abstract this away to a single Upsert call
q = Insert(self.LoggedinRelationsTable) # .returning(self.LoggedinRelationsTable.c.id)
asyncio.run(
self.conn.execute(
q,
[relation]
)
) # .scalar()
inserted_id_results = self.get_loggedin_relations(user_id, host_id)
logging.debug(f"Checking if relation was added: {inserted_id_results}")
@ -980,7 +967,7 @@ class database:
q = q.filter(
self.LoggedinRelationsTable.c.hostid == host_id
)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
return results
def remove_loggedin_relations(self, user_id=None, host_id=None):
@ -993,4 +980,4 @@ class database:
q = q.filter(
self.LoggedinRelationsTable.c.hostid == host_id
)
asyncio.run(self.conn.execute(q))
self.conn.execute(q)

View File

@ -15,8 +15,11 @@ class database:
self.db_engine = db_engine
self.metadata = MetaData()
asyncio.run(self.reflect_tables())
session_factory = sessionmaker(bind=self.db_engine, expire_on_commit=True, class_=AsyncSession)
self.reflect_tables()
session_factory = sessionmaker(
bind=self.db_engine,
expire_on_commit=True
)
Session = scoped_session(session_factory)
# this is still named "conn" when it is the session object; TODO: rename
@ -37,20 +40,9 @@ class database:
"server_banner" text
)''')
async def shutdown_db(self):
def reflect_tables(self):
with self.db_engine.connect() as conn:
try:
await asyncio.shield(self.conn.close())
# due to the async nature of CME, sometimes session state is a bit messy and this will throw:
# Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and
# this would cause an unexpected state change to <SessionTransactionState.CLOSED: 5>
except IllegalStateChangeError as e:
logging.debug(f"Error while closing session db object: {e}")
async def reflect_tables(self):
async with self.db_engine.connect() as conn:
try:
await conn.run_sync(self.metadata.reflect)
self.CredentialsTable = Table("credentials", self.metadata, autoload_with=self.db_engine)
self.HostsTable = Table("hosts", self.metadata, autoload_with=self.db_engine)
except NoInspectionAvailable:
@ -62,6 +54,15 @@ class database:
)
exit()
def shutdown_db(self):
try:
self.conn.close()
# due to the async nature of CME, sometimes session state is a bit messy and this will throw:
# Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and
# this would cause an unexpected state change to <SessionTransactionState.CLOSED: 5>
except IllegalStateChangeError as e:
logging.debug(f"Error while closing session db object: {e}")
def clear_database(self):
for table in self.metadata.sorted_tables:
asyncio.run(self.conn.execute(table.delete()))
self.conn.execute(table.delete())

View File

@ -22,13 +22,10 @@ class database:
self.db_engine = db_engine
self.metadata = MetaData()
asyncio.run(self.reflect_tables())
# we don't use async_sessionmaker or async_scoped_session because when `database` is initialized,
# there is no running async loop
self.reflect_tables()
session_factory = sessionmaker(
bind=self.db_engine,
expire_on_commit=True,
class_=AsyncSession
expire_on_commit=True
)
Session = scoped_session(session_factory)
@ -52,11 +49,9 @@ class database:
"server_banner" text
)''')
async def reflect_tables(self):
async with self.db_engine.connect() as conn:
def reflect_tables(self):
with self.db_engine.connect() as conn:
try:
await conn.run_sync(self.metadata.reflect)
self.HostsTable = Table("hosts", self.metadata, autoload_with=self.db_engine)
self.CredentialsTable = Table("credentials", self.metadata, autoload_with=self.db_engine)
except NoInspectionAvailable:
@ -68,9 +63,9 @@ class database:
)
exit()
async def shutdown_db(self):
def shutdown_db(self):
try:
await asyncio.shield(self.conn.close())
self.conn.close()
# due to the async nature of CME, sometimes session state is a bit messy and this will throw:
# Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and
# this would cause an unexpected state change to <SessionTransactionState.CLOSED: 5>
@ -79,4 +74,4 @@ class database:
def clear_database(self):
for table in self.metadata.sorted_tables:
asyncio.run(self.conn.execute(table.delete()))
self.conn.execute(table.delete())

View File

@ -19,11 +19,10 @@ class database:
self.db_engine = db_engine
self.metadata = MetaData()
asyncio.run(self.reflect_tables())
self.reflect_tables()
session_factory = sessionmaker(
bind=self.db_engine,
expire_on_commit=True,
class_=AsyncSession
expire_on_commit=True
)
Session = scoped_session(session_factory)
@ -64,11 +63,9 @@ class database:
FOREIGN KEY(hostid) REFERENCES hosts(id)
)''')
async def reflect_tables(self):
async with self.db_engine.connect() as conn:
def reflect_tables(self):
with self.db_engine.connect() as conn:
try:
await conn.run_sync(self.metadata.reflect)
self.HostsTable = Table("hosts", self.metadata, autoload_with=self.db_engine)
self.UsersTable = Table("users", self.metadata, autoload_with=self.db_engine)
self.AdminRelationsTable = Table("admin_relations", self.metadata, autoload_with=self.db_engine)
@ -82,9 +79,9 @@ class database:
)
exit()
async def shutdown_db(self):
def shutdown_db(self):
try:
await asyncio.shield(self.conn.close())
self.conn.close()
# due to the async nature of CME, sometimes session state is a bit messy and this will throw:
# Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and
# this would cause an unexpected state change to <SessionTransactionState.CLOSED: 5>
@ -93,7 +90,7 @@ class database:
def clear_database(self):
for table in self.metadata.sorted_tables:
asyncio.run(self.conn.execute(table.delete()))
self.conn.execute(table.delete())
def add_host(self, ip, port, hostname, domain, os=None):
"""
@ -106,7 +103,7 @@ class database:
q = select(self.HostsTable).filter(
self.HostsTable.c.ip == ip
)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
logging.debug(f"smb add_host() - hosts returned: {results}")
# create new host
@ -146,12 +143,10 @@ class database:
index_elements=self.HostsTable.primary_key,
set_=update_columns
)
asyncio.run(
self.conn.execute(
q,
hosts
)
)
def add_credential(self, credtype, domain, username, password, pillaged_from=None):
"""
@ -177,7 +172,7 @@ class database:
func.lower(self.UsersTable.c.username) == func.lower(username),
func.lower(self.UsersTable.c.credtype) == func.lower(credtype)
)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
# add new credential
if not results:
@ -216,11 +211,9 @@ class database:
index_elements=self.UsersTable.primary_key,
set_=update_columns_users
)
asyncio.run(
self.conn.execute(
q_users,
credentials
)
) # .scalar()
# return user_ids
@ -234,7 +227,7 @@ class database:
self.UsersTable.c.id == cred_id
)
del_hosts.append(q)
asyncio.run(self.conn.execute(q))
self.conn.execute(q)
def add_admin_user(self, credtype, domain, username, password, host, user_id=None):
domain = domain.split('.')[0]
@ -252,7 +245,7 @@ class database:
func.lower(self.UsersTable.c.username) == func.lower(username),
self.UsersTable.c.password == password
)
users = asyncio.run(self.conn.execute(creds_q))
users = self.conn.execute(creds_q)
hosts = self.get_hosts(host)
if users and hosts:
@ -267,17 +260,17 @@ class database:
self.AdminRelationsTable.c.userid == user_id,
self.AdminRelationsTable.c.hostid == host_id
)
links = asyncio.run(self.conn.execute(admin_relations_select)).all()
links = self.conn.execute(admin_relations_select).all()
if not links:
add_links.append(link)
admin_relations_insert = Insert(self.AdminRelationsTable)
asyncio.run(self.conn.execute(
self.conn.execute(
admin_relations_insert,
add_links
))
)
def get_admin_relations(self, user_id=None, host_id=None):
if user_id:
@ -291,7 +284,7 @@ class database:
else:
q = select(self.AdminRelationsTable)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
return results
def remove_admin_relation(self, user_ids=None, host_ids=None):
@ -306,7 +299,7 @@ class database:
q = q.filter(
self.AdminRelationsTable.c.hostid == host_id
)
asyncio.run(self.conn.execute(q))
self.conn.execute(q)
def is_credential_valid(self, credential_id):
"""
@ -316,7 +309,7 @@ class database:
self.UsersTable.c.id == credential_id,
self.UsersTable.c.password is not None
)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
return len(results) > 0
def get_credentials(self, filter_term=None, cred_type=None):
@ -342,20 +335,20 @@ class database:
else:
q = select(self.UsersTable)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
return results
def is_credential_local(self, credential_id):
q = select(self.UsersTable.c.domain).filter(
self.UsersTable.c.id == credential_id
)
user_domain = asyncio.run(self.conn.execute(q)).all()
user_domain = self.conn.execute(q).all()
if user_domain:
q = select(self.HostsTable).filter(
func.lower(self.HostsTable.c.id) == func.lower(user_domain)
)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
return len(results) > 0
@ -366,7 +359,7 @@ class database:
q = select(self.HostsTable).filter(
self.HostsTable.c.id == host_id
)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
return len(results) > 0
def get_hosts(self, filter_term=None):
@ -380,7 +373,7 @@ class database:
q = q.filter(
self.HostsTable.c.id == filter_term
)
results = asyncio.run(self.conn.execute(q)).first()
results = self.conn.execute(q).first()
# all() returns a list, so we keep the return format the same so consumers don't have to guess
return [results]
# if we're filtering by domain controllers
@ -397,7 +390,7 @@ class database:
self.HostsTable.c.ip.like(like_term) |
func.lower(self.HostsTable.c.hostname).like(like_term)
)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
logging.debug(f"winrm get_hosts() - results: {results}")
return results
@ -408,7 +401,7 @@ class database:
q = select(self.UsersTable).filter(
self.UsersTable.c.id == user_id
)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
return len(results) > 0
def get_users(self, filter_term=None):
@ -424,7 +417,7 @@ class database:
q = q.filter(
func.lower(self.UsersTable.c.username).like(like_term)
)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
return results
def get_user(self, domain, username):
@ -432,7 +425,7 @@ class database:
func.lower(self.UsersTable.c.domain) == func.lower(domain),
func.lower(self.UsersTable.c.username) == func.lower(username)
)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
return results
def add_loggedin_relation(self, user_id, host_id):
@ -440,7 +433,7 @@ class database:
self.LoggedinRelationsTable.c.userid == user_id,
self.LoggedinRelationsTable.c.hostid == host_id
)
results = asyncio.run(self.conn.execute(relation_query)).all()
results = self.conn.execute(relation_query).all()
# only add one if one doesn't already exist
if not results:
@ -451,11 +444,10 @@ class database:
try:
# TODO: find a way to abstract this away to a single Upsert call
q = Insert(self.LoggedinRelationsTable) # .returning(self.LoggedinRelationsTable.c.id)
asyncio.run(
self.conn.execute(
q,
[relation]
)
) # .scalar()
# return inserted_ids
except Exception as e:
@ -471,7 +463,7 @@ class database:
q = q.filter(
self.LoggedinRelationsTable.c.hostid == host_id
)
results = asyncio.run(self.conn.execute(q)).all()
results = self.conn.execute(q).all()
return results
def remove_loggedin_relations(self, user_id=None, host_id=None):
@ -484,4 +476,4 @@ class database:
q = q.filter(
self.LoggedinRelationsTable.c.hostid == host_id
)
asyncio.run(self.conn.execute(q))
self.conn.execute(q)

View File

@ -47,7 +47,9 @@ def run_e2e_tests():
result = subprocess.Popen("crackmapexec --version", shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
version = result.communicate()[0].decode().strip()
with console.status(f"[bold green] :brain: Running test commands for cme v{version}...") as status:
with console.status(f"[bold green] :brain: Running {len(tasks)} test commands for cme v{version}...") as status:
passed = 0
failed = 0
while tasks:
task = tasks.pop(0)
result = subprocess.Popen(str(task), shell=True, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
@ -56,8 +58,11 @@ def run_e2e_tests():
return_code = result.returncode
if return_code == 0:
console.log(f"{task.strip()} :heavy_check_mark:")
passed += 1
else:
console.log(f"[bold red]{task.strip()} :cross_mark:[/]")
failed += 1
console.log(f"Tests [bold green] Passed: {passed} [bold red] Failed: {failed}")
if __name__ == "__main__":

View File

@ -1,11 +1,8 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import asyncio
import os
from time import sleep
import os
import pytest
import pytest_asyncio
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, scoped_session
@ -18,11 +15,14 @@ from sqlalchemy.dialects.sqlite import Insert
@pytest.fixture(scope="session")
<<<<<<< HEAD
def event_loop():
return asyncio.get_event_loop()
@pytest_asyncio.fixture(scope="session")
=======
>>>>>>> tests_marshall
def db_engine():
db_path = os.path.join(WS_PATH, "test/smb.db")
db_engine = create_engine(
@ -34,8 +34,8 @@ def db_engine():
db_engine.dispose()
@pytest_asyncio.fixture(scope="session")
async def db(db_engine):
@pytest.fixture(scope="session")
def db(db_engine):
proto = "smb"
setup_logger()
logger = CMEAdapter()
@ -47,19 +47,19 @@ async def db(db_engine):
protocol_db_path = p_loader.get_protocols()[proto]["dbpath"]
protocol_db_object = getattr(p_loader.load_protocol(protocol_db_path), "database")
db = protocol_db_object(db_engine)
yield db
db.shutdown_db()
database_obj = protocol_db_object(db_engine)
database_obj.reflect_tables()
yield database_obj
database_obj.shutdown_db()
delete_workspace("test")
@pytest_asyncio.fixture(scope="session")
@pytest.fixture(scope="session")
def sess(db_engine):
session_factory = sessionmaker(
bind=db_engine,
expire_on_commit=True
)
Session = scoped_session(
session_factory
)
@ -68,9 +68,8 @@ def sess(db_engine):
sess.close()
@pytest.mark.asyncio
async def test_add_host(db):
await db.add_host(
def test_add_host(db):
db.add_host(
"127.0.0.1",
"localhost",
"TEST.DEV",
@ -82,6 +81,20 @@ async def test_add_host(db):
False,
False
)
inserted_host = db.get_hosts()
assert len(inserted_host) == 1
host = inserted_host[0]
assert host.id == 1
assert host.ip == "127.0.0.1"
assert host.hostname == "localhost"
assert host.os == "Windows Testing 2023"
assert host.smbv1 is False
assert host.signing is True
assert host.spooler is True
assert host.zerologon is True
assert host.petitpotam is False
assert host.dc is False
db.clear_database()
def test_update_host(db, sess):
@ -90,15 +103,41 @@ def test_update_host(db, sess):
"hostname": "localhost",
"domain": "TEST.DEV",
"os": "Windows Testing 2023",
"dc": False,
"smbv1": True,
"signing": True,
"signing": False,
"spooler": True,
"zerologon": False,
"petitpotam": False
"petitpotam": False,
"dc": False
}
iq = Insert(db.HostsTable)
sess.execute(iq, [host])
db.add_host(
"127.0.0.1",
"localhost",
"TEST.DEV",
"Windows Testing 2023 Updated",
False,
True,
False,
False,
False,
False
)
inserted_host = db.get_hosts()
assert len(inserted_host) == 1
host = inserted_host[0]
assert host.id == 1
assert host.ip == "127.0.0.1"
assert host.hostname == "localhost"
assert host.os == "Windows Testing 2023 Updated"
assert host.smbv1 is False
assert host.signing is True
assert host.spooler is False
assert host.zerologon is False
assert host.petitpotam is False
assert host.dc is False
db.clear_database()
def test_add_credential():
@ -116,93 +155,122 @@ def test_remove_credential():
def test_add_admin_user():
pass
def test_get_admin_relations():
pass
def test_remove_admin_relation():
pass
def test_is_credential_valid():
pass
def test_get_credentials():
pass
def test_get_credential():
pass
def test_is_credential_local():
pass
def test_is_host_valid():
pass
def test_get_hosts():
pass
def test_is_group_valid():
pass
def test_add_group():
pass
def test_get_groups():
pass
def test_get_group_relations():
pass
def test_remove_group_relations():
pass
def test_is_user_valid():
pass
def test_get_users():
pass
def test_get_user():
pass
def test_get_domain_controllers():
pass
def test_is_share_valid():
pass
def test_add_share():
pass
def test_get_shares():
pass
def test_get_shares_by_access():
pass
def test_get_users_with_share_access():
pass
def test_add_domain_backupkey():
pass
def test_get_domain_backupkey():
pass
def test_is_dpapi_secret_valid():
pass
def test_add_dpapi_secrets():
pass
def test_get_dpapi_secrets():
pass
def test_add_loggedin_relation():
pass
def test_get_loggedin_relations():
pass
def test_remove_loggedin_relations():
pass