refactor(smbdb): change all add_user references to add_credential and refactor some if statements

main
Marshall Hallenbeck 2023-03-08 00:27:14 -05:00
parent b25b74d473
commit 1d33c58059
2 changed files with 9 additions and 70 deletions

View File

@ -891,7 +891,7 @@ class smb(connection):
# So I put a domain group as a member of a local group which is also a member of another local group.
# (╯°□°)╯︵ ┻━┻
if not group.isgroup:
self.db.add_user(domain, name, group_id)
self.db.add_credential("", domain, name, "", group_id, "")
elif group.isgroup:
self.db.add_group(domain, name, member_count_ad=group.membercount)
break
@ -957,11 +957,7 @@ class smb(connection):
member_count_ad=member_count
)
if not group.isgroup:
self.db.add_user(
group.memberdomain,
group.membername,
group_id
)
self.db.add_credential("", group.memberdomain, group.membername, "", group_id, "")
elif group.isgroup:
self.db.add_group(
group.groupdomain,

View File

@ -240,7 +240,7 @@ class database:
self.ComputersTable.c.ip == ip
)
results = asyncio.run(self.conn.execute(q)).all()
logging.debug(f"Results in add_computer: {results}")
logging.debug(f"add_computer() - computers returned: {results}")
host_data = {
"ip": ip,
@ -306,13 +306,14 @@ class database:
domain = domain.split('.')[0].upper()
user_rowid = None
if group_id and not self.is_group_valid(group_id):
if (group_id and not self.is_group_valid(group_id)) or \
(pillaged_from and not self.is_computer_valid(pillaged_from)):
self.conn.close()
return
if pillaged_from and not self.is_computer_valid(pillaged_from):
self.conn.close()
return
# if pillaged_from and not self.is_computer_valid(pillaged_from):
# self.conn.close()
# return
credential_data = {}
if credtype is not None:
@ -385,64 +386,6 @@ class database:
))
return user_rowid
def add_user(self, domain, username, group_id=None):
if group_id and not self.is_group_valid(group_id):
return
domain = domain.split('.')[0].upper()
user_rowid = None
user_data = {
"password": "",
"credtype": "",
"pillaged_from_computerid": ""
}
if domain is not None:
user_data["domain"] = domain
if username is not None:
user_data["username"] = username
if group_id is not None:
user_data["group_id"] = group_id
results = self.conn.query(self.UsersTable).filter(
func.lower(self.UsersTable.c.domain) == func.lower(domain),
func.lower(self.UsersTable.c.username) == func.lower(username)
).all()
if not len(results):
user_rowid = self.conn.execute(
self.UsersTable.insert(),
[user_data]
)
if group_id:
gr_data = {
"userid": user_rowid,
"groupid": group_id,
}
self.conn.execute(
self.GroupRelationsTable.insert(),
[gr_data]
)
else:
for user in results:
if (domain != user[1]) and (username != user[2]):
user_rowid = self.conn.execute(self.UsersTable.update().values(
user_data
).where(
self.UsersTable.c.id == user[0]
))
if not user_rowid: user_rowid = user[0]
if group_id and not len(self.get_group_relations(user_rowid, group_id)):
self.conn.execute(self.GroupRelationsTable.update().values(
{"userid": user_rowid, "groupid": group_id}
))
self.conn.commit()
self.conn.close()
logging.debug('add_user(domain={}, username={}, groupid={}) => {}'.format(domain, username, group_id, user_rowid))
return user_rowid
def add_group(self, domain, name, member_count_ad=None):
domain = domain.split('.')[0].upper()
groups = []
@ -452,7 +395,7 @@ class database:
func.lower(self.GroupsTable.c.name) == func.lower(name)
)
results = asyncio.run(self.conn.execute(q)).all()
logging.debug(f"add_group(): groups returned: {results}")
logging.debug(f"add_group() - groups returned: {results}")
group_data = {
"domain": domain,