feat(cmedb): update some functions for smb.creds

main
Marshall Hallenbeck 2023-03-03 20:38:31 -05:00
parent bc2ba6a025
commit 6d140bb1ce
4 changed files with 23 additions and 27 deletions

View File

@ -169,7 +169,7 @@ class connection(object):
if cred_id.lower() == 'all': if cred_id.lower() == 'all':
creds = self.db.get_credentials() creds = self.db.get_credentials()
else: else:
creds = self.db.get_credentials(filterTerm=int(cred_id)) creds = self.db.get_credentials(filter_term=int(cred_id))
for cred in creds: for cred in creds:
logging.debug(cred) logging.debug(cred)
try: try:

View File

@ -80,7 +80,7 @@ class navigator(DatabaseNavigator):
for link in links: for link in links:
linkID, credID, hostID = link linkID, credID, hostID = link
creds = self.db.get_credentials(filterTerm=credID) creds = self.db.get_credentials(filter_term=credID)
for cred in creds: for cred in creds:
credID = cred[0] credID = cred[0]
@ -127,15 +127,15 @@ class navigator(DatabaseNavigator):
self.db.remove_links(credIDs=args) self.db.remove_links(credIDs=args)
elif filterTerm.split()[0].lower() == "plaintext": elif filterTerm.split()[0].lower() == "plaintext":
creds = self.db.get_credentials(credtype="plaintext") creds = self.db.get_credentials(cred_type="plaintext")
self.display_creds(creds) self.display_creds(creds)
elif filterTerm.split()[0].lower() == "hash": elif filterTerm.split()[0].lower() == "hash":
creds = self.db.get_credentials(credtype="hash") creds = self.db.get_credentials(cred_type="hash")
self.display_creds(creds) self.display_creds(creds)
else: else:
creds = self.db.get_credentials(filterTerm=filterTerm) creds = self.db.get_credentials(filter_term=filterTerm)
data = [['CredID', 'CredType', 'Domain', 'UserName', 'Password']] data = [['CredID', 'CredType', 'Domain', 'UserName', 'Password']]
credIDList = [] credIDList = []

View File

@ -2,6 +2,7 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import logging import logging
from sqlalchemy import func
class database: class database:
@ -419,12 +420,11 @@ class database:
return results return results
def is_credential_valid(self, credentialID): def is_credential_valid(self, credential_id):
""" """
Check if this credential ID is valid. Check if this credential ID is valid.
""" """
self.conn.execute('SELECT * FROM users WHERE id=? AND password IS NOT NULL LIMIT 1', [credentialID]) results = self.conn.query(self.users_table).filter(self.users_table.c.id == credential_id, self.users_table.c.password is not None).all()
results = self.conn.fetchall()
self.conn.commit() self.conn.commit()
self.conn.close() self.conn.close()
return len(results) > 0 return len(results) > 0
@ -440,26 +440,22 @@ class database:
self.conn.close() self.conn.close()
return len(results) > 0 return len(results) > 0
def get_credentials(self, filterTerm=None, credtype=None): def get_credentials(self, filter_term=None, cred_type=None):
""" """
Return credentials from the database. Return credentials from the database.
""" """
# if we're returning a single credential by ID # if we're returning a single credential by ID
if self.is_credential_valid(filterTerm): if self.is_credential_valid(filter_term):
self.conn.execute("SELECT * FROM users WHERE id=?", [filterTerm]) results = self.conn.query(self.users_table).filter(self.users_table.c.id == filter_term).all()
elif cred_type:
elif credtype: results = self.conn.query(self.users_table).filter(self.users_table.c.credtype == cred_type).all()
self.conn.execute("SELECT * FROM users WHERE credtype=?", [credtype])
# if we're filtering by username # if we're filtering by username
elif filterTerm and filterTerm != '': elif filter_term and filter_term != '':
self.conn.execute("SELECT * FROM users WHERE LOWER(username) LIKE LOWER(?)", ['%{}%'.format(filterTerm)]) results = self.conn.query(self.users_table).filter(func.lower(self.users_table.c.username).like(func.lower(f"%{filter_term}%"))).all()
# otherwise return all credentials # otherwise return all credentials
else: else:
self.conn.execute("SELECT * FROM users") results = self.conn.query(self.users_table).all()
results = self.conn.fetchall()
self.conn.commit() self.conn.commit()
self.conn.close() self.conn.close()
return results return results

View File

@ -157,7 +157,7 @@ class navigator(DatabaseNavigator):
data = [['CredID', 'CredType', 'Domain', 'UserName', 'Password']] data = [['CredID', 'CredType', 'Domain', 'UserName', 'Password']]
for user in users_r_access: for user in users_r_access:
userid = user[0] userid = user[0]
creds = self.db.get_credentials(filterTerm=userid) creds = self.db.get_credentials(filter_term=userid)
for cred in creds: for cred in creds:
credID = cred[0] credID = cred[0]
@ -174,7 +174,7 @@ class navigator(DatabaseNavigator):
data = [['CredID', 'CredType', 'Domain', 'UserName', 'Password']] data = [['CredID', 'CredType', 'Domain', 'UserName', 'Password']]
for user in users_w_access: for user in users_w_access:
userid = user[0] userid = user[0]
creds = self.db.get_credentials(filterTerm=userid) creds = self.db.get_credentials(filter_term=userid)
for cred in creds: for cred in creds:
credID = cred[0] credID = cred[0]
@ -221,7 +221,7 @@ class navigator(DatabaseNavigator):
for member in members: for member in members:
_,userid,_ = member _,userid,_ = member
creds = self.db.get_credentials(filterTerm=userid) creds = self.db.get_credentials(filter_term=userid)
for cred in creds: for cred in creds:
credID = cred[0] credID = cred[0]
@ -271,7 +271,7 @@ class navigator(DatabaseNavigator):
for link in links: for link in links:
linkID, credID, hostID = link linkID, credID, hostID = link
creds = self.db.get_credentials(filterTerm=credID) creds = self.db.get_credentials(filter_term=credID)
for cred in creds: for cred in creds:
credID = cred[0] credID = cred[0]
@ -366,15 +366,15 @@ class navigator(DatabaseNavigator):
self.db.remove_admin_relation(userIDs=args) self.db.remove_admin_relation(userIDs=args)
elif filterTerm.split()[0].lower() == "plaintext": elif filterTerm.split()[0].lower() == "plaintext":
creds = self.db.get_credentials(credtype="plaintext") creds = self.db.get_credentials(cred_type="plaintext")
self.display_creds(creds) self.display_creds(creds)
elif filterTerm.split()[0].lower() == "hash": elif filterTerm.split()[0].lower() == "hash":
creds = self.db.get_credentials(credtype="hash") creds = self.db.get_credentials(cred_type="hash")
self.display_creds(creds) self.display_creds(creds)
else: else:
creds = self.db.get_credentials(filterTerm=filterTerm) creds = self.db.get_credentials(filter_term=filterTerm)
if len(creds) != 1: if len(creds) != 1:
self.display_creds(creds) self.display_creds(creds)
elif len(creds) == 1: elif len(creds) == 1: