From 6d140bb1ceff42c671a2bc32d3997cefe8e1ec97 Mon Sep 17 00:00:00 2001 From: Marshall Hallenbeck Date: Fri, 3 Mar 2023 20:38:31 -0500 Subject: [PATCH] feat(cmedb): update some functions for smb.creds --- cme/connection.py | 2 +- cme/protocols/mssql/db_navigator.py | 8 ++++---- cme/protocols/smb/database.py | 26 +++++++++++--------------- cme/protocols/smb/db_navigator.py | 14 +++++++------- 4 files changed, 23 insertions(+), 27 deletions(-) diff --git a/cme/connection.py b/cme/connection.py index ce720c2f..e326618c 100755 --- a/cme/connection.py +++ b/cme/connection.py @@ -169,7 +169,7 @@ class connection(object): if cred_id.lower() == 'all': creds = self.db.get_credentials() else: - creds = self.db.get_credentials(filterTerm=int(cred_id)) + creds = self.db.get_credentials(filter_term=int(cred_id)) for cred in creds: logging.debug(cred) try: diff --git a/cme/protocols/mssql/db_navigator.py b/cme/protocols/mssql/db_navigator.py index fa48e727..a08b7436 100644 --- a/cme/protocols/mssql/db_navigator.py +++ b/cme/protocols/mssql/db_navigator.py @@ -80,7 +80,7 @@ class navigator(DatabaseNavigator): for link in links: linkID, credID, hostID = link - creds = self.db.get_credentials(filterTerm=credID) + creds = self.db.get_credentials(filter_term=credID) for cred in creds: credID = cred[0] @@ -127,15 +127,15 @@ class navigator(DatabaseNavigator): self.db.remove_links(credIDs=args) elif filterTerm.split()[0].lower() == "plaintext": - creds = self.db.get_credentials(credtype="plaintext") + creds = self.db.get_credentials(cred_type="plaintext") self.display_creds(creds) elif filterTerm.split()[0].lower() == "hash": - creds = self.db.get_credentials(credtype="hash") + creds = self.db.get_credentials(cred_type="hash") self.display_creds(creds) else: - creds = self.db.get_credentials(filterTerm=filterTerm) + creds = self.db.get_credentials(filter_term=filterTerm) data = [['CredID', 'CredType', 'Domain', 'UserName', 'Password']] credIDList = [] diff --git a/cme/protocols/smb/database.py b/cme/protocols/smb/database.py index a709b3ad..b74e847e 100755 --- a/cme/protocols/smb/database.py +++ b/cme/protocols/smb/database.py @@ -2,6 +2,7 @@ # -*- coding: utf-8 -*- import logging +from sqlalchemy import func class database: @@ -419,12 +420,11 @@ class database: return results - def is_credential_valid(self, credentialID): + def is_credential_valid(self, credential_id): """ Check if this credential ID is valid. """ - self.conn.execute('SELECT * FROM users WHERE id=? AND password IS NOT NULL LIMIT 1', [credentialID]) - results = self.conn.fetchall() + results = self.conn.query(self.users_table).filter(self.users_table.c.id == credential_id, self.users_table.c.password is not None).all() self.conn.commit() self.conn.close() return len(results) > 0 @@ -440,26 +440,22 @@ class database: self.conn.close() return len(results) > 0 - def get_credentials(self, filterTerm=None, credtype=None): + def get_credentials(self, filter_term=None, cred_type=None): """ Return credentials from the database. """ # if we're returning a single credential by ID - if self.is_credential_valid(filterTerm): - self.conn.execute("SELECT * FROM users WHERE id=?", [filterTerm]) - - elif credtype: - self.conn.execute("SELECT * FROM users WHERE credtype=?", [credtype]) - + if self.is_credential_valid(filter_term): + results = self.conn.query(self.users_table).filter(self.users_table.c.id == filter_term).all() + elif cred_type: + results = self.conn.query(self.users_table).filter(self.users_table.c.credtype == cred_type).all() # if we're filtering by username - elif filterTerm and filterTerm != '': - self.conn.execute("SELECT * FROM users WHERE LOWER(username) LIKE LOWER(?)", ['%{}%'.format(filterTerm)]) - + elif filter_term and filter_term != '': + results = self.conn.query(self.users_table).filter(func.lower(self.users_table.c.username).like(func.lower(f"%{filter_term}%"))).all() # otherwise return all credentials else: - self.conn.execute("SELECT * FROM users") + results = self.conn.query(self.users_table).all() - results = self.conn.fetchall() self.conn.commit() self.conn.close() return results diff --git a/cme/protocols/smb/db_navigator.py b/cme/protocols/smb/db_navigator.py index 164bc940..cc20f09f 100644 --- a/cme/protocols/smb/db_navigator.py +++ b/cme/protocols/smb/db_navigator.py @@ -157,7 +157,7 @@ class navigator(DatabaseNavigator): data = [['CredID', 'CredType', 'Domain', 'UserName', 'Password']] for user in users_r_access: userid = user[0] - creds = self.db.get_credentials(filterTerm=userid) + creds = self.db.get_credentials(filter_term=userid) for cred in creds: credID = cred[0] @@ -174,7 +174,7 @@ class navigator(DatabaseNavigator): data = [['CredID', 'CredType', 'Domain', 'UserName', 'Password']] for user in users_w_access: userid = user[0] - creds = self.db.get_credentials(filterTerm=userid) + creds = self.db.get_credentials(filter_term=userid) for cred in creds: credID = cred[0] @@ -221,7 +221,7 @@ class navigator(DatabaseNavigator): for member in members: _,userid,_ = member - creds = self.db.get_credentials(filterTerm=userid) + creds = self.db.get_credentials(filter_term=userid) for cred in creds: credID = cred[0] @@ -271,7 +271,7 @@ class navigator(DatabaseNavigator): for link in links: linkID, credID, hostID = link - creds = self.db.get_credentials(filterTerm=credID) + creds = self.db.get_credentials(filter_term=credID) for cred in creds: credID = cred[0] @@ -366,15 +366,15 @@ class navigator(DatabaseNavigator): self.db.remove_admin_relation(userIDs=args) elif filterTerm.split()[0].lower() == "plaintext": - creds = self.db.get_credentials(credtype="plaintext") + creds = self.db.get_credentials(cred_type="plaintext") self.display_creds(creds) elif filterTerm.split()[0].lower() == "hash": - creds = self.db.get_credentials(credtype="hash") + creds = self.db.get_credentials(cred_type="hash") self.display_creds(creds) else: - creds = self.db.get_credentials(filterTerm=filterTerm) + creds = self.db.get_credentials(filter_term=filterTerm) if len(creds) != 1: self.display_creds(creds) elif len(creds) == 1: