diff --git a/cme/protocols/ftp/database.py b/cme/protocols/ftp/database.py index cefe81ce..e275821d 100644 --- a/cme/protocols/ftp/database.py +++ b/cme/protocols/ftp/database.py @@ -2,9 +2,12 @@ # -*- coding: utf-8 -*- class database: - - def __init__(self, conn): + def __init__(self, conn, metadata=None): + # this is still named "conn" when it is the Session object, TODO: rename self.conn = conn + self.metadata = metadata + self.credentials_table = metadata.tables["credentials"] + self.hosts_table = metadata.tables["hosts"] @staticmethod def db_schema(db_conn): @@ -20,3 +23,8 @@ class database: "port" integer, "server_banner" text )''') + + def clear_database(self): + for table in self.metadata.tables: + self.conn.query(self.metadata.tables[table]).delete() + self.conn.commit() diff --git a/cme/protocols/ftp/db_navigator.py b/cme/protocols/ftp/db_navigator.py index f1e59e66..5359851a 100644 --- a/cme/protocols/ftp/db_navigator.py +++ b/cme/protocols/ftp/db_navigator.py @@ -5,4 +5,5 @@ from cme.cmedb import DatabaseNavigator class navigator(DatabaseNavigator): - pass + def do_clear_database(self, line): + self.db.clear_database() diff --git a/cme/protocols/ldap/database.py b/cme/protocols/ldap/database.py index dbede2c3..f2116041 100644 --- a/cme/protocols/ldap/database.py +++ b/cme/protocols/ldap/database.py @@ -2,9 +2,12 @@ # -*- coding: utf-8 -*- class database: - - def __init__(self, conn): + def __init__(self, conn, metadata=None): + # this is still named "conn" when it is the Session object, TODO: rename self.conn = conn + self.metadata = metadata + self.credentials_table = metadata.tables["credentials"] + self.hosts_table = metadata.tables["hosts"] @staticmethod def db_schema(db_conn): @@ -20,3 +23,8 @@ class database: "hostname" text, "port" integer )''') + + def clear_database(self): + for table in self.metadata.tables: + self.conn.query(self.metadata.tables[table]).delete() + self.conn.commit() diff --git a/cme/protocols/ldap/db_navigator.py b/cme/protocols/ldap/db_navigator.py index f1e59e66..5359851a 100644 --- a/cme/protocols/ldap/db_navigator.py +++ b/cme/protocols/ldap/db_navigator.py @@ -5,4 +5,5 @@ from cme.cmedb import DatabaseNavigator class navigator(DatabaseNavigator): - pass + def do_clear_database(self, line): + self.db.clear_database() diff --git a/cme/protocols/mssql/database.py b/cme/protocols/mssql/database.py index 080dc6d8..7c709e4a 100755 --- a/cme/protocols/mssql/database.py +++ b/cme/protocols/mssql/database.py @@ -2,9 +2,13 @@ # -*- coding: utf-8 -*- class database: - - def __init__(self, conn): + def __init__(self, conn, metadata=None): + # this is still named "conn" when it is the Session object, TODO: rename self.conn = conn + self.metadata = metadata + self.credentials_table = metadata.tables["credentials"] + self.admin_relations_table = metadata.tables["admin_relations"] + self.users_table = metadata.tables["users"] @staticmethod def db_schema(db_conn): @@ -220,3 +224,8 @@ class database: results = cur.fetchall() cur.close() return results + + def clear_database(self): + for table in self.metadata.tables: + self.conn.query(self.metadata.tables[table]).delete() + self.conn.commit() diff --git a/cme/protocols/mssql/db_navigator.py b/cme/protocols/mssql/db_navigator.py index a08b7436..9fcc4ec8 100644 --- a/cme/protocols/mssql/db_navigator.py +++ b/cme/protocols/mssql/db_navigator.py @@ -172,6 +172,9 @@ class navigator(DatabaseNavigator): print_table(data, title='Admin Access to Host(s)') + def do_clear_database(self, line): + self.db.clear_database() + def complete_hosts(self, text, line, begidx, endidx): "Tab-complete 'creds' commands." diff --git a/cme/protocols/rdp/database.py b/cme/protocols/rdp/database.py index fa2d5099..9e2391aa 100644 --- a/cme/protocols/rdp/database.py +++ b/cme/protocols/rdp/database.py @@ -2,9 +2,12 @@ # -*- coding: utf-8 -*- class database: - - def __init__(self, conn): + def __init__(self, conn, metadata=None): + # this is still named "conn" when it is the Session object, TODO: rename self.conn = conn + self.metadata = metadata + self.credentials_table = metadata.tables["credentials"] + self.hosts_table = metadata.tables["hosts"] @staticmethod def db_schema(db_conn): @@ -21,4 +24,9 @@ class database: "hostname" text, "port" integer, "server_banner" text - )''') \ No newline at end of file + )''') + + def clear_database(self): + for table in self.metadata.tables: + self.conn.query(self.metadata.tables[table]).delete() + self.conn.commit() diff --git a/cme/protocols/rdp/db_navigator.py b/cme/protocols/rdp/db_navigator.py index f1e59e66..5359851a 100644 --- a/cme/protocols/rdp/db_navigator.py +++ b/cme/protocols/rdp/db_navigator.py @@ -5,4 +5,5 @@ from cme.cmedb import DatabaseNavigator class navigator(DatabaseNavigator): - pass + def do_clear_database(self, line): + self.db.clear_database() diff --git a/cme/protocols/smb/database.py b/cme/protocols/smb/database.py index 9af4def5..07cdb202 100755 --- a/cme/protocols/smb/database.py +++ b/cme/protocols/smb/database.py @@ -2,11 +2,10 @@ # -*- coding: utf-8 -*- import logging -from sqlalchemy import func +from sqlalchemy import func, text class database: - def __init__(self, conn, metadata=None): # this is still named "conn" when it is the Session object, TODO: rename self.conn = conn @@ -117,69 +116,109 @@ class database: # )''') def add_share(self, computerid, userid, name, remark, read, write): - self.conn.execute("INSERT OR IGNORE INTO shares (computerid, userid, name, remark, read, write) VALUES (?,?,?,?,?,?)", [computerid, userid, name, remark, read, write]) + data = { + "computerid": computerid, + "userid": userid, + "name": name, + "remark": remark, + "read": read, + "write": write, + } + self.conn.execute( + self.shares_table.insert(), + [data] + ) self.conn.commit() self.conn.close() - def is_share_valid(self, shareID): + def is_share_valid(self, share_id): """ Check if this share ID is valid. """ - self.conn.execute('SELECT * FROM shares WHERE id=? LIMIT 1', [shareID]) - results = self.conn.fetchall() + results = self.conn.query(self.shares_table).filter( + self.shares_table.c.id == share_id + ).all() self.conn.commit() self.conn.close() - logging.debug(f"is_share_valid(shareID={shareID}) => {len(results) > 0}") + logging.debug(f"is_share_valid(shareID={share_id}) => {len(results) > 0}") return len(results) > 0 - def get_shares(self, filterTerm = None): - if self.is_share_valid(filterTerm): - self.conn.execute("SELECT * FROM shares WHERE id=?", [filterTerm]) - elif filterTerm: - self.conn.execute("SELECT * FROM shares WHERE LOWER(name) LIKE LOWER(?)", [f"%{filterTerm}%"]) + def get_shares(self, filter_term=None): + if self.is_share_valid(filter_term): + results = self.conn.query(self.shares_table).filter( + self.shares_table.c.id == filter_term + ).all() + elif filter_term: + results = self.conn.query(self.shares_table).filter( + func.lower(self.shares_table.c.name).like(func.lower(f"%{filter_term}%")) + ).all() else: - self.conn.execute("SELECT * FROM shares") - - results = self.conn.fetchall() + results = self.conn.query(self.shares_table).all() return results - def get_shares_by_access(self, permissions, shareID=None): + def get_shares_by_access(self, permissions, share_id=None): permissions = permissions.lower() - if shareID: + if share_id: if permissions == "r": - self.conn.execute("SELECT * FROM shares WHERE id=? AND read=1",[shareID]) + results = self.conn.query(self.shares_table).filter( + self.shares_table.c.id == share_id, + self.shares_table.c.read == 1 + ).all() elif permissions == "w": - self.conn.execute("SELECT * FROM shares WHERE id=? write=1", [shareID]) + results = self.conn.query(self.shares_table).filter( + self.shares_table.c.id == share_id, + self.shares_table.c.write == 1 + ).all() elif permissions == "rw": - self.conn.execute("SELECT * FROM shares WHERE id=? AND read=1 AND write=1", [shareID]) + results = self.conn.query(self.shares_table).filter( + self.shares_table.c.id == share_id, + self.shares_table.c.read == 1, + self.shares_table.c.write == 1 + ).all() else: if permissions == "r": - self.conn.execute("SELECT * FROM shares WHERE read=1") + results = self.conn.query(self.shares_table).filter( + self.shares_table.c.read == 1 + ).all() elif permissions == "w": - self.conn.execute("SELECT * FROM shares WHERE write=1") + results = self.conn.query(self.shares_table).filter( + self.shares_table.c.write == 1 + ).all() elif permissions == "rw": - self.conn.execute("SELECT * FROM shares WHERE read= AND write=1") - - results = self.conn.fetchall() + results = self.conn.query(self.shares_table).filter( + self.shares_table.c.read == 1, + self.shares_table.c.write == 1 + ).all() return results - def get_users_with_share_access(self, computerID, share_name, permissions): + def get_users_with_share_access(self, computer_id, share_name, permissions): permissions = permissions.lower() if permissions == "r": - self.conn.execute("SELECT userid FROM shares WHERE computerid=(?) AND name=(?) AND read=1", [computerID, share_name]) + results = self.conn.query(self.shares_table.c.userid).filter( + self.shares_table.c.computerid == computer_id, + self.shares_table.c.name == share_name, + self.shares_table.c.read == 1 + ).all() elif permissions == "w": - self.conn.execute("SELECT userid FROM shares WHERE computerid=(?) AND name=(?) AND write=1", [computerID, share_name]) + results = self.conn.query(self.shares_table.c.userid).filter( + self.shares_table.c.computerid == computer_id, + self.shares_table.c.name == share_name, + self.shares_table.c.write == 1 + ).all() elif permissions == "rw": - self.conn.execute("SELECT userid FROM shares WHERE computerid=(?) AND name=(?) AND read=1 AND write=1", [computerID, share_name]) - - results = self.conn.fetchall() + results = self.conn.query(self.shares_table.c.userid).filter( + self.shares_table.c.computerid == computer_id, + self.shares_table.c.name == share_name, + self.shares_table.c.read == 1, + self.shares_table.c.write == 1 + ).all() return results - #pull/545 - def add_computer(self, ip, hostname, domain, os, smbv1, signing=None, spooler=0, zerologon=0, petitpotam=0, dc=None): + # pull/545 + def add_computer(self, ip, hostname, domain, os, smbv1, signing=None, spooler=None, zerologon=None, petitpotam=None, dc=None): """ Check if this host has already been added to the database, if not add it in. """ @@ -188,51 +227,88 @@ class database: results = self.conn.query(self.computers_table).filter( self.computers_table.c.ip == ip ).all() - host = { - "ip": ip, - "hostname": hostname, - "domain": domain, - "os": os, - "dc": dc, - "smbv1": smbv1, - "signing": signing, - "spooler": spooler, - "zerologon": zerologon, - "petitpotam": petitpotam - } + data = {} + if ip is not None: + data["ip"] = ip + if hostname is not None: + data["hostname"] = hostname + if domain is not None: + data["domain"] = domain + if os is not None: + data["os"] = os + if smbv1 is not None: + data["smbv1"] = smbv1 + if signing is not None: + data["signing"] = signing + if spooler is not None: + data["spooler"] = spooler + if zerologon is not None: + data["zerologon"] = zerologon + if petitpotam is not None: + data["petitpotam"] = petitpotam + if dc is not None: + data["dc"] = dc + print(f"DATA: {data}") + print(f"RESULTS: {results}") - print(f"IP: {ip}") - print(f"Hostname: {hostname}") - print(f"Domain: {domain}") - print(f"OS: {os}") - print(f"SMB: {smbv1}") - print(f"Signing: {signing}") - print(f"DC: {dc}") if not results: - # host doesn't exist in the DB - pass + print(f"IP: {ip}") + print(f"Hostname: {hostname}") + print(f"Domain: {domain}") + print(f"OS: {os}") + print(f"SMB: {smbv1}") + print(f"Signing: {signing}") + print(f"DC: {dc}") - if not len(results): + new_host = { + "ip": ip, + "hostname": hostname, + "domain": domain, + "os": os, + "dc": dc, + "smbv1": smbv1, + "signing": signing, + "spooler": spooler, + "zerologon": zerologon, + "petitpotam": petitpotam + } try: - self.conn.execute("INSERT INTO computers (ip, hostname, domain, os, dc, smbv1, signing) VALUES (?,?,?,?,?,?,?,?,?,?)", [ip, hostname, domain, os, dc, smbv1, signing, spooler, zerologon, petitpotam]) + cid = self.conn.execute( + self.computers_table.insert(), + [new_host] + ) except Exception as e: print(f"Exception: {e}") - self.conn.execute("INSERT INTO computers (ip, hostname, domain, os, dc) VALUES (?,?,?,?,?)", [ip, hostname, domain, os, dc]) + #self.conn.execute("INSERT INTO computers (ip, hostname, domain, os, dc) VALUES (?,?,?,?,?)", [ip, hostname, domain, os, dc]) else: for host in results: + print(host.id) + print(f"Host: {host}") + print(f"Host Type: {type(host)}") try: - if (hostname != host[2]) or (domain != host[3]) or (os != host[4]) or (smbv1 != host[6]) or (signing != host[7]): - self.conn.execute("UPDATE computers SET hostname=?, domain=?, os=?, smbv1=?, signing=?, spooler=?, zerologon=?, petitpotam=? WHERE id=?", [hostname, domain, os, smbv1, signing, spooler, zerologon, petitpotam, host[0]]) - except: - if (hostname != host[2]) or (domain != host[3]) or (os != host[4]): - self.conn.execute("UPDATE computers SET hostname=?, domain=?, os=? WHERE id=?", [hostname, domain, os, host[0]]) - if dc != None and (dc != host[5]): - self.conn.execute("UPDATE computers SET dc=? WHERE id=?", [dc, host[0]]) + cid = self.conn.execute( + self.computers_table.update().values( + data + ).where( + self.computers_table.c.id == host.id + ) + ) + self.conn.commit() + except Exception as e: + print(f"Exception: {e}") + # try: + # if (hostname != host[2]) or (domain != host[3]) or (os != host[4]) or (smbv1 != host[6]) or (signing != host[7]): + # self.conn.execute("UPDATE computers SET hostname=?, domain=?, os=?, smbv1=?, signing=?, spooler=?, zerologon=?, petitpotam=? WHERE id=?", [hostname, domain, os, smbv1, signing, spooler, zerologon, petitpotam, host[0]]) + # except: + # if (hostname != host[2]) or (domain != host[3]) or (os != host[4]): + # self.conn.execute("UPDATE computers SET hostname=?, domain=?, os=? WHERE id=?", [hostname, domain, os, host[0]]) + # if dc != None and (dc != host[5]): + # self.conn.execute("UPDATE computers SET dc=? WHERE id=?", [dc, host[0]]) self.conn.commit() self.conn.close() - return self.conn.lastrowid + return cid def update_computer(self, host_id, hostname=None, domain=None, os=None, smbv1=None, signing=None, spooler=None, zerologon=None, petitpotam=None, dc=None): data = { @@ -250,23 +326,46 @@ class database: user_rowid = None if groupid and not self.is_group_valid(groupid): - self.conn.commit() self.conn.close() return if pillaged_from and not self.is_computer_valid(pillaged_from): - self.conn.commit() self.conn.close() return - self.conn.execute("SELECT * FROM users WHERE LOWER(domain)=LOWER(?) AND LOWER(username)=LOWER(?) AND LOWER(credtype)=LOWER(?)", [domain, username, credtype]) - results = self.conn.fetchall() + results = self.conn.query(self.users_table).filter( + func.lower(self.users_table.c.domain) == func.lower(domain), + func.lower(self.users_table.c.username) == func.lower(username), + func.lower(self.users_table.c.credtype) == func.lower(credtype) + ).all() + logging.debug(f"Credential results: {results}") - if not len(results): - self.conn.execute("INSERT INTO users (domain, username, password, credtype, pillaged_from_computerid) VALUES (?,?,?,?,?)", [domain, username, password, credtype, pillaged_from]) - user_rowid = self.conn.lastrowid + if not results: + data = { + "domain": domain, + "username": username, + "password": password, + "credtype": credtype, + "pillaged_from_computerid": pillaged_from, + } + #self.conn.execute("INSERT INTO users (domain, username, password, credtype, pillaged_from_computerid) VALUES (?,?,?,?,?)", [domain, username, password, credtype, pillaged_from]) + user_rowid = self.conn.execute( + self.users_table.insert(), + [data] + ) + logging.debug(f"User RowID: {user_rowid}") + #user_rowid = self.conn.lastrowid if groupid: - self.conn.execute("INSERT INTO group_relations (userid, groupid) VALUES (?,?)", [user_rowid, groupid]) + gr_data = { + "userid": user_rowid, + "groupid": groupid, + } + #self.conn.execute("INSERT INTO group_relations (userid, groupid) VALUES (?,?)", [user_rowid, groupid]) + self.conn.execute( + self.group_relations_table.insert(), + [gr_data] + ) + self.conn.commit() else: for user in results: if not user[3] and not user[4] and not user[5]: @@ -330,40 +429,53 @@ class database: return self.conn.lastrowid - def remove_credentials(self, credIDs): + def remove_credentials(self, creds_id): """ Removes a credential ID from the database """ - for credID in credIDs: - - self.conn.execute("DELETE FROM users WHERE id=?", [credID]) + for cred_id in creds_id: + self.conn.execute("DELETE FROM users WHERE id=?", [cred_id]) self.conn.commit() - self.conn.close() + self.conn.close() - def add_admin_user(self, credtype, domain, username, password, host, userid=None): + def add_admin_user(self, credtype, domain, username, password, host, user_id=None): domain = domain.split('.')[0].upper() - if userid: - self.conn.execute("SELECT * FROM users WHERE id=?", [userid]) - users = self.conn.fetchall() + if user_id: + users = self.conn.query(self.users_table).filter( + self.users_table.c.id == user_id + ).all() else: - self.conn.execute("SELECT * FROM users WHERE credtype=? AND LOWER(domain)=LOWER(?) AND LOWER(username)=LOWER(?) AND password=?", [credtype, domain, username, password]) - users = self.conn.fetchall() + users = self.conn.query(self.users_table).filter( + self.users_table.c.credtype == credtype, + func.lower(self.users_table.c.domain) == func.lower(domain), + func.lower(self.users_table.c.username) == func.lower(username), + self.users_table.c.password == password + ).all() + logging.debug(f"Users: {users}") - self.conn.execute('SELECT * FROM computers WHERE ip LIKE ?', [host]) - hosts = self.conn.fetchall() + hosts = self.conn.query(self.computers_table).filter( + self.computers_table.c.ip.like(func.lower(f"%{host}%")) + ) + logging.debug(f"Hosts: {hosts}") - if len(users) and len(hosts): + if users is not None and hosts is not None: for user, host in zip(users, hosts): - userid = user[0] - hostid = host[0] + user_id = user[0] + host_id = host[0] - #Check to see if we already added this link - self.conn.execute("SELECT * FROM admin_relations WHERE userid=? AND computerid=?", [userid, hostid]) - links = self.conn.fetchall() + # Check to see if we already added this link + links = self.conn.query(self.admin_relations_table).filter( + self.admin_relations_table.c.userid == user_id, + self.admin_relations_table.c.computerid == host_id + ).all() - if not len(links): - self.conn.execute("INSERT INTO admin_relations (userid, computerid) VALUES (?,?)", [userid, hostid]) + if not links: + self.conn.execute( + self.admin_relations_table.insert(), + [{"userid": user_id, "computerid": host_id}] + ) + self.conn.commit() self.conn.commit() self.conn.close() @@ -501,8 +613,10 @@ class database: return results def get_user(self, domain, username): - self.conn.execute("SELECT * FROM users WHERE LOWER(domain)=LOWER(?) AND LOWER(username)=LOWER(?)", [domain, username]) - results = self.conn.fetchall() + results = self.conn.query(self.users_table).filter( + func.lower(self.users_table.c.domain) == func.lower(domain), + func.lower(self.users_table.c.username) == func.lower(username) + ).all() self.conn.commit() self.conn.close() return results @@ -587,7 +701,7 @@ class database: self.conn.close() logging.debug(f"get_groups(filterTerm={filter_term}, groupName={group_name}, groupDomain={group_domain}) => {results}") return results - + def add_domain_backupkey(self, domain:str, pvk:bytes): """ Add domain backupkey @@ -626,7 +740,7 @@ class database: import base64 results = [(idkey, domain, base64.b64decode(pvk)) for idkey, domain, pvk in results] return results - + def is_dpapi_secret_valid(self, dpapiSecretID): """ Check if this group ID is valid. @@ -672,4 +786,9 @@ class database: results = cur.fetchall() cur.close() logging.debug('get_dpapi_secrets(filterTerm={}, computer={}, dpapi_type={}, windows_user={}, username={}, url={}) => {}'.format(filterTerm, computer, dpapi_type, windows_user, username, url, results)) - return results \ No newline at end of file + return results + + def clear_database(self): + for table in self.metadata.tables: + self.conn.query(self.metadata.tables[table]).delete() + self.conn.commit() diff --git a/cme/protocols/smb/db_navigator.py b/cme/protocols/smb/db_navigator.py index 3d0faab3..452b8f90 100644 --- a/cme/protocols/smb/db_navigator.py +++ b/cme/protocols/smb/db_navigator.py @@ -3,6 +3,7 @@ from cme.helpers.misc import validate_ntlm from cme.cmedb import DatabaseNavigator, print_table +from sqlalchemy.sql import text class navigator(DatabaseNavigator): @@ -86,13 +87,13 @@ class navigator(DatabaseNavigator): remark = share[4] users_r_access = self.db.get_users_with_share_access( - computerID=computerid, + computer_id=computerid, share_name=name, permissions='r' ) users_w_access = self.db.get_users_with_share_access( - computerID=computerid, + computer_id=computerid, share_name=name, permissions='w' ) @@ -108,7 +109,7 @@ class navigator(DatabaseNavigator): shares = self.db.get_shares() self.display_shares(shares) else: - shares = self.db.get_shares(filterTerm=filterTerm) + shares = self.db.get_shares(filter_term=filterTerm) if len(shares) > 1: self.display_shares(shares) @@ -120,13 +121,13 @@ class navigator(DatabaseNavigator): remark = share[4] users_r_access = self.db.get_users_with_share_access( - computerID=computerID, + computer_id=computerID, share_name=name, permissions='r' ) users_w_access = self.db.get_users_with_share_access( - computerID=computerID, + computer_id=computerID, share_name=name, permissions='w' ) @@ -431,6 +432,9 @@ class navigator(DatabaseNavigator): print_table(data, title='Admin Access to Host(s)') + def do_clear_database(self, line): + self.db.clear_database() + def complete_hosts(self, text, line, begidx, endidx): "Tab-complete 'creds' commands." diff --git a/cme/protocols/ssh/database.py b/cme/protocols/ssh/database.py index fa2d5099..9e2391aa 100644 --- a/cme/protocols/ssh/database.py +++ b/cme/protocols/ssh/database.py @@ -2,9 +2,12 @@ # -*- coding: utf-8 -*- class database: - - def __init__(self, conn): + def __init__(self, conn, metadata=None): + # this is still named "conn" when it is the Session object, TODO: rename self.conn = conn + self.metadata = metadata + self.credentials_table = metadata.tables["credentials"] + self.hosts_table = metadata.tables["hosts"] @staticmethod def db_schema(db_conn): @@ -21,4 +24,9 @@ class database: "hostname" text, "port" integer, "server_banner" text - )''') \ No newline at end of file + )''') + + def clear_database(self): + for table in self.metadata.tables: + self.conn.query(self.metadata.tables[table]).delete() + self.conn.commit() diff --git a/cme/protocols/ssh/db_navigator.py b/cme/protocols/ssh/db_navigator.py index f1e59e66..5359851a 100644 --- a/cme/protocols/ssh/db_navigator.py +++ b/cme/protocols/ssh/db_navigator.py @@ -5,4 +5,5 @@ from cme.cmedb import DatabaseNavigator class navigator(DatabaseNavigator): - pass + def do_clear_database(self, line): + self.db.clear_database() diff --git a/cme/protocols/winrm/database.py b/cme/protocols/winrm/database.py index dbede2c3..f2116041 100644 --- a/cme/protocols/winrm/database.py +++ b/cme/protocols/winrm/database.py @@ -2,9 +2,12 @@ # -*- coding: utf-8 -*- class database: - - def __init__(self, conn): + def __init__(self, conn, metadata=None): + # this is still named "conn" when it is the Session object, TODO: rename self.conn = conn + self.metadata = metadata + self.credentials_table = metadata.tables["credentials"] + self.hosts_table = metadata.tables["hosts"] @staticmethod def db_schema(db_conn): @@ -20,3 +23,8 @@ class database: "hostname" text, "port" integer )''') + + def clear_database(self): + for table in self.metadata.tables: + self.conn.query(self.metadata.tables[table]).delete() + self.conn.commit() diff --git a/cme/protocols/winrm/db_navigator.py b/cme/protocols/winrm/db_navigator.py index f1e59e66..5359851a 100644 --- a/cme/protocols/winrm/db_navigator.py +++ b/cme/protocols/winrm/db_navigator.py @@ -5,4 +5,5 @@ from cme.cmedb import DatabaseNavigator class navigator(DatabaseNavigator): - pass + def do_clear_database(self, line): + self.db.clear_database()