feat(database): update each protocol to use sqlalchemy table reference and add database clear function; closes #189

main
Marshall Hallenbeck 2023-03-04 11:12:29 -05:00
parent 10e7180c20
commit e34fdc2dda
14 changed files with 303 additions and 123 deletions

View File

@ -2,9 +2,12 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
class database: class database:
def __init__(self, conn, metadata=None):
def __init__(self, conn): # this is still named "conn" when it is the Session object, TODO: rename
self.conn = conn self.conn = conn
self.metadata = metadata
self.credentials_table = metadata.tables["credentials"]
self.hosts_table = metadata.tables["hosts"]
@staticmethod @staticmethod
def db_schema(db_conn): def db_schema(db_conn):
@ -20,3 +23,8 @@ class database:
"port" integer, "port" integer,
"server_banner" text "server_banner" text
)''') )''')
def clear_database(self):
for table in self.metadata.tables:
self.conn.query(self.metadata.tables[table]).delete()
self.conn.commit()

View File

@ -5,4 +5,5 @@ from cme.cmedb import DatabaseNavigator
class navigator(DatabaseNavigator): class navigator(DatabaseNavigator):
pass def do_clear_database(self, line):
self.db.clear_database()

View File

@ -2,9 +2,12 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
class database: class database:
def __init__(self, conn, metadata=None):
def __init__(self, conn): # this is still named "conn" when it is the Session object, TODO: rename
self.conn = conn self.conn = conn
self.metadata = metadata
self.credentials_table = metadata.tables["credentials"]
self.hosts_table = metadata.tables["hosts"]
@staticmethod @staticmethod
def db_schema(db_conn): def db_schema(db_conn):
@ -20,3 +23,8 @@ class database:
"hostname" text, "hostname" text,
"port" integer "port" integer
)''') )''')
def clear_database(self):
for table in self.metadata.tables:
self.conn.query(self.metadata.tables[table]).delete()
self.conn.commit()

View File

@ -5,4 +5,5 @@ from cme.cmedb import DatabaseNavigator
class navigator(DatabaseNavigator): class navigator(DatabaseNavigator):
pass def do_clear_database(self, line):
self.db.clear_database()

View File

@ -2,9 +2,13 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
class database: class database:
def __init__(self, conn, metadata=None):
def __init__(self, conn): # this is still named "conn" when it is the Session object, TODO: rename
self.conn = conn self.conn = conn
self.metadata = metadata
self.credentials_table = metadata.tables["credentials"]
self.admin_relations_table = metadata.tables["admin_relations"]
self.users_table = metadata.tables["users"]
@staticmethod @staticmethod
def db_schema(db_conn): def db_schema(db_conn):
@ -220,3 +224,8 @@ class database:
results = cur.fetchall() results = cur.fetchall()
cur.close() cur.close()
return results return results
def clear_database(self):
for table in self.metadata.tables:
self.conn.query(self.metadata.tables[table]).delete()
self.conn.commit()

View File

@ -172,6 +172,9 @@ class navigator(DatabaseNavigator):
print_table(data, title='Admin Access to Host(s)') print_table(data, title='Admin Access to Host(s)')
def do_clear_database(self, line):
self.db.clear_database()
def complete_hosts(self, text, line, begidx, endidx): def complete_hosts(self, text, line, begidx, endidx):
"Tab-complete 'creds' commands." "Tab-complete 'creds' commands."

View File

@ -2,9 +2,12 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
class database: class database:
def __init__(self, conn, metadata=None):
def __init__(self, conn): # this is still named "conn" when it is the Session object, TODO: rename
self.conn = conn self.conn = conn
self.metadata = metadata
self.credentials_table = metadata.tables["credentials"]
self.hosts_table = metadata.tables["hosts"]
@staticmethod @staticmethod
def db_schema(db_conn): def db_schema(db_conn):
@ -22,3 +25,8 @@ class database:
"port" integer, "port" integer,
"server_banner" text "server_banner" text
)''') )''')
def clear_database(self):
for table in self.metadata.tables:
self.conn.query(self.metadata.tables[table]).delete()
self.conn.commit()

View File

@ -5,4 +5,5 @@ from cme.cmedb import DatabaseNavigator
class navigator(DatabaseNavigator): class navigator(DatabaseNavigator):
pass def do_clear_database(self, line):
self.db.clear_database()

View File

@ -2,11 +2,10 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import logging import logging
from sqlalchemy import func from sqlalchemy import func, text
class database: class database:
def __init__(self, conn, metadata=None): def __init__(self, conn, metadata=None):
# this is still named "conn" when it is the Session object, TODO: rename # this is still named "conn" when it is the Session object, TODO: rename
self.conn = conn self.conn = conn
@ -117,69 +116,109 @@ class database:
# )''') # )''')
def add_share(self, computerid, userid, name, remark, read, write): def add_share(self, computerid, userid, name, remark, read, write):
self.conn.execute("INSERT OR IGNORE INTO shares (computerid, userid, name, remark, read, write) VALUES (?,?,?,?,?,?)", [computerid, userid, name, remark, read, write]) data = {
"computerid": computerid,
"userid": userid,
"name": name,
"remark": remark,
"read": read,
"write": write,
}
self.conn.execute(
self.shares_table.insert(),
[data]
)
self.conn.commit() self.conn.commit()
self.conn.close() self.conn.close()
def is_share_valid(self, shareID): def is_share_valid(self, share_id):
""" """
Check if this share ID is valid. Check if this share ID is valid.
""" """
self.conn.execute('SELECT * FROM shares WHERE id=? LIMIT 1', [shareID]) results = self.conn.query(self.shares_table).filter(
results = self.conn.fetchall() self.shares_table.c.id == share_id
).all()
self.conn.commit() self.conn.commit()
self.conn.close() self.conn.close()
logging.debug(f"is_share_valid(shareID={shareID}) => {len(results) > 0}") logging.debug(f"is_share_valid(shareID={share_id}) => {len(results) > 0}")
return len(results) > 0 return len(results) > 0
def get_shares(self, filterTerm = None): def get_shares(self, filter_term=None):
if self.is_share_valid(filterTerm): if self.is_share_valid(filter_term):
self.conn.execute("SELECT * FROM shares WHERE id=?", [filterTerm]) results = self.conn.query(self.shares_table).filter(
elif filterTerm: self.shares_table.c.id == filter_term
self.conn.execute("SELECT * FROM shares WHERE LOWER(name) LIKE LOWER(?)", [f"%{filterTerm}%"]) ).all()
elif filter_term:
results = self.conn.query(self.shares_table).filter(
func.lower(self.shares_table.c.name).like(func.lower(f"%{filter_term}%"))
).all()
else: else:
self.conn.execute("SELECT * FROM shares") results = self.conn.query(self.shares_table).all()
results = self.conn.fetchall()
return results return results
def get_shares_by_access(self, permissions, shareID=None): def get_shares_by_access(self, permissions, share_id=None):
permissions = permissions.lower() permissions = permissions.lower()
if shareID: if share_id:
if permissions == "r": if permissions == "r":
self.conn.execute("SELECT * FROM shares WHERE id=? AND read=1",[shareID]) results = self.conn.query(self.shares_table).filter(
self.shares_table.c.id == share_id,
self.shares_table.c.read == 1
).all()
elif permissions == "w": elif permissions == "w":
self.conn.execute("SELECT * FROM shares WHERE id=? write=1", [shareID]) results = self.conn.query(self.shares_table).filter(
self.shares_table.c.id == share_id,
self.shares_table.c.write == 1
).all()
elif permissions == "rw": elif permissions == "rw":
self.conn.execute("SELECT * FROM shares WHERE id=? AND read=1 AND write=1", [shareID]) results = self.conn.query(self.shares_table).filter(
self.shares_table.c.id == share_id,
self.shares_table.c.read == 1,
self.shares_table.c.write == 1
).all()
else: else:
if permissions == "r": if permissions == "r":
self.conn.execute("SELECT * FROM shares WHERE read=1") results = self.conn.query(self.shares_table).filter(
self.shares_table.c.read == 1
).all()
elif permissions == "w": elif permissions == "w":
self.conn.execute("SELECT * FROM shares WHERE write=1") results = self.conn.query(self.shares_table).filter(
self.shares_table.c.write == 1
).all()
elif permissions == "rw": elif permissions == "rw":
self.conn.execute("SELECT * FROM shares WHERE read= AND write=1") results = self.conn.query(self.shares_table).filter(
self.shares_table.c.read == 1,
results = self.conn.fetchall() self.shares_table.c.write == 1
).all()
return results return results
def get_users_with_share_access(self, computerID, share_name, permissions): def get_users_with_share_access(self, computer_id, share_name, permissions):
permissions = permissions.lower() permissions = permissions.lower()
if permissions == "r": if permissions == "r":
self.conn.execute("SELECT userid FROM shares WHERE computerid=(?) AND name=(?) AND read=1", [computerID, share_name]) results = self.conn.query(self.shares_table.c.userid).filter(
self.shares_table.c.computerid == computer_id,
self.shares_table.c.name == share_name,
self.shares_table.c.read == 1
).all()
elif permissions == "w": elif permissions == "w":
self.conn.execute("SELECT userid FROM shares WHERE computerid=(?) AND name=(?) AND write=1", [computerID, share_name]) results = self.conn.query(self.shares_table.c.userid).filter(
self.shares_table.c.computerid == computer_id,
self.shares_table.c.name == share_name,
self.shares_table.c.write == 1
).all()
elif permissions == "rw": elif permissions == "rw":
self.conn.execute("SELECT userid FROM shares WHERE computerid=(?) AND name=(?) AND read=1 AND write=1", [computerID, share_name]) results = self.conn.query(self.shares_table.c.userid).filter(
self.shares_table.c.computerid == computer_id,
results = self.conn.fetchall() self.shares_table.c.name == share_name,
self.shares_table.c.read == 1,
self.shares_table.c.write == 1
).all()
return results return results
# pull/545 # pull/545
def add_computer(self, ip, hostname, domain, os, smbv1, signing=None, spooler=0, zerologon=0, petitpotam=0, dc=None): def add_computer(self, ip, hostname, domain, os, smbv1, signing=None, spooler=None, zerologon=None, petitpotam=None, dc=None):
""" """
Check if this host has already been added to the database, if not add it in. Check if this host has already been added to the database, if not add it in.
""" """
@ -188,7 +227,41 @@ class database:
results = self.conn.query(self.computers_table).filter( results = self.conn.query(self.computers_table).filter(
self.computers_table.c.ip == ip self.computers_table.c.ip == ip
).all() ).all()
host = { data = {}
if ip is not None:
data["ip"] = ip
if hostname is not None:
data["hostname"] = hostname
if domain is not None:
data["domain"] = domain
if os is not None:
data["os"] = os
if smbv1 is not None:
data["smbv1"] = smbv1
if signing is not None:
data["signing"] = signing
if spooler is not None:
data["spooler"] = spooler
if zerologon is not None:
data["zerologon"] = zerologon
if petitpotam is not None:
data["petitpotam"] = petitpotam
if dc is not None:
data["dc"] = dc
print(f"DATA: {data}")
print(f"RESULTS: {results}")
if not results:
print(f"IP: {ip}")
print(f"Hostname: {hostname}")
print(f"Domain: {domain}")
print(f"OS: {os}")
print(f"SMB: {smbv1}")
print(f"Signing: {signing}")
print(f"DC: {dc}")
new_host = {
"ip": ip, "ip": ip,
"hostname": hostname, "hostname": hostname,
"domain": domain, "domain": domain,
@ -200,39 +273,42 @@ class database:
"zerologon": zerologon, "zerologon": zerologon,
"petitpotam": petitpotam "petitpotam": petitpotam
} }
print(f"RESULTS: {results}")
print(f"IP: {ip}")
print(f"Hostname: {hostname}")
print(f"Domain: {domain}")
print(f"OS: {os}")
print(f"SMB: {smbv1}")
print(f"Signing: {signing}")
print(f"DC: {dc}")
if not results:
# host doesn't exist in the DB
pass
if not len(results):
try: try:
self.conn.execute("INSERT INTO computers (ip, hostname, domain, os, dc, smbv1, signing) VALUES (?,?,?,?,?,?,?,?,?,?)", [ip, hostname, domain, os, dc, smbv1, signing, spooler, zerologon, petitpotam]) cid = self.conn.execute(
self.computers_table.insert(),
[new_host]
)
except Exception as e: except Exception as e:
print(f"Exception: {e}") print(f"Exception: {e}")
self.conn.execute("INSERT INTO computers (ip, hostname, domain, os, dc) VALUES (?,?,?,?,?)", [ip, hostname, domain, os, dc]) #self.conn.execute("INSERT INTO computers (ip, hostname, domain, os, dc) VALUES (?,?,?,?,?)", [ip, hostname, domain, os, dc])
else: else:
for host in results: for host in results:
print(host.id)
print(f"Host: {host}")
print(f"Host Type: {type(host)}")
try: try:
if (hostname != host[2]) or (domain != host[3]) or (os != host[4]) or (smbv1 != host[6]) or (signing != host[7]): cid = self.conn.execute(
self.conn.execute("UPDATE computers SET hostname=?, domain=?, os=?, smbv1=?, signing=?, spooler=?, zerologon=?, petitpotam=? WHERE id=?", [hostname, domain, os, smbv1, signing, spooler, zerologon, petitpotam, host[0]]) self.computers_table.update().values(
except: data
if (hostname != host[2]) or (domain != host[3]) or (os != host[4]): ).where(
self.conn.execute("UPDATE computers SET hostname=?, domain=?, os=? WHERE id=?", [hostname, domain, os, host[0]]) self.computers_table.c.id == host.id
if dc != None and (dc != host[5]): )
self.conn.execute("UPDATE computers SET dc=? WHERE id=?", [dc, host[0]]) )
self.conn.commit()
except Exception as e:
print(f"Exception: {e}")
# try:
# if (hostname != host[2]) or (domain != host[3]) or (os != host[4]) or (smbv1 != host[6]) or (signing != host[7]):
# self.conn.execute("UPDATE computers SET hostname=?, domain=?, os=?, smbv1=?, signing=?, spooler=?, zerologon=?, petitpotam=? WHERE id=?", [hostname, domain, os, smbv1, signing, spooler, zerologon, petitpotam, host[0]])
# except:
# if (hostname != host[2]) or (domain != host[3]) or (os != host[4]):
# self.conn.execute("UPDATE computers SET hostname=?, domain=?, os=? WHERE id=?", [hostname, domain, os, host[0]])
# if dc != None and (dc != host[5]):
# self.conn.execute("UPDATE computers SET dc=? WHERE id=?", [dc, host[0]])
self.conn.commit() self.conn.commit()
self.conn.close() self.conn.close()
return self.conn.lastrowid return cid
def update_computer(self, host_id, hostname=None, domain=None, os=None, smbv1=None, signing=None, spooler=None, zerologon=None, petitpotam=None, dc=None): def update_computer(self, host_id, hostname=None, domain=None, os=None, smbv1=None, signing=None, spooler=None, zerologon=None, petitpotam=None, dc=None):
data = { data = {
@ -250,23 +326,46 @@ class database:
user_rowid = None user_rowid = None
if groupid and not self.is_group_valid(groupid): if groupid and not self.is_group_valid(groupid):
self.conn.commit()
self.conn.close() self.conn.close()
return return
if pillaged_from and not self.is_computer_valid(pillaged_from): if pillaged_from and not self.is_computer_valid(pillaged_from):
self.conn.commit()
self.conn.close() self.conn.close()
return return
self.conn.execute("SELECT * FROM users WHERE LOWER(domain)=LOWER(?) AND LOWER(username)=LOWER(?) AND LOWER(credtype)=LOWER(?)", [domain, username, credtype]) results = self.conn.query(self.users_table).filter(
results = self.conn.fetchall() func.lower(self.users_table.c.domain) == func.lower(domain),
func.lower(self.users_table.c.username) == func.lower(username),
func.lower(self.users_table.c.credtype) == func.lower(credtype)
).all()
logging.debug(f"Credential results: {results}")
if not len(results): if not results:
self.conn.execute("INSERT INTO users (domain, username, password, credtype, pillaged_from_computerid) VALUES (?,?,?,?,?)", [domain, username, password, credtype, pillaged_from]) data = {
user_rowid = self.conn.lastrowid "domain": domain,
"username": username,
"password": password,
"credtype": credtype,
"pillaged_from_computerid": pillaged_from,
}
#self.conn.execute("INSERT INTO users (domain, username, password, credtype, pillaged_from_computerid) VALUES (?,?,?,?,?)", [domain, username, password, credtype, pillaged_from])
user_rowid = self.conn.execute(
self.users_table.insert(),
[data]
)
logging.debug(f"User RowID: {user_rowid}")
#user_rowid = self.conn.lastrowid
if groupid: if groupid:
self.conn.execute("INSERT INTO group_relations (userid, groupid) VALUES (?,?)", [user_rowid, groupid]) gr_data = {
"userid": user_rowid,
"groupid": groupid,
}
#self.conn.execute("INSERT INTO group_relations (userid, groupid) VALUES (?,?)", [user_rowid, groupid])
self.conn.execute(
self.group_relations_table.insert(),
[gr_data]
)
self.conn.commit()
else: else:
for user in results: for user in results:
if not user[3] and not user[4] and not user[5]: if not user[3] and not user[4] and not user[5]:
@ -330,40 +429,53 @@ class database:
return self.conn.lastrowid return self.conn.lastrowid
def remove_credentials(self, credIDs): def remove_credentials(self, creds_id):
""" """
Removes a credential ID from the database Removes a credential ID from the database
""" """
for credID in credIDs: for cred_id in creds_id:
self.conn.execute("DELETE FROM users WHERE id=?", [cred_id])
self.conn.execute("DELETE FROM users WHERE id=?", [credID])
self.conn.commit() self.conn.commit()
self.conn.close() self.conn.close()
def add_admin_user(self, credtype, domain, username, password, host, userid=None): def add_admin_user(self, credtype, domain, username, password, host, user_id=None):
domain = domain.split('.')[0].upper() domain = domain.split('.')[0].upper()
if userid: if user_id:
self.conn.execute("SELECT * FROM users WHERE id=?", [userid]) users = self.conn.query(self.users_table).filter(
users = self.conn.fetchall() self.users_table.c.id == user_id
).all()
else: else:
self.conn.execute("SELECT * FROM users WHERE credtype=? AND LOWER(domain)=LOWER(?) AND LOWER(username)=LOWER(?) AND password=?", [credtype, domain, username, password]) users = self.conn.query(self.users_table).filter(
users = self.conn.fetchall() self.users_table.c.credtype == credtype,
func.lower(self.users_table.c.domain) == func.lower(domain),
func.lower(self.users_table.c.username) == func.lower(username),
self.users_table.c.password == password
).all()
logging.debug(f"Users: {users}")
self.conn.execute('SELECT * FROM computers WHERE ip LIKE ?', [host]) hosts = self.conn.query(self.computers_table).filter(
hosts = self.conn.fetchall() self.computers_table.c.ip.like(func.lower(f"%{host}%"))
)
logging.debug(f"Hosts: {hosts}")
if len(users) and len(hosts): if users is not None and hosts is not None:
for user, host in zip(users, hosts): for user, host in zip(users, hosts):
userid = user[0] user_id = user[0]
hostid = host[0] host_id = host[0]
# Check to see if we already added this link # Check to see if we already added this link
self.conn.execute("SELECT * FROM admin_relations WHERE userid=? AND computerid=?", [userid, hostid]) links = self.conn.query(self.admin_relations_table).filter(
links = self.conn.fetchall() self.admin_relations_table.c.userid == user_id,
self.admin_relations_table.c.computerid == host_id
).all()
if not len(links): if not links:
self.conn.execute("INSERT INTO admin_relations (userid, computerid) VALUES (?,?)", [userid, hostid]) self.conn.execute(
self.admin_relations_table.insert(),
[{"userid": user_id, "computerid": host_id}]
)
self.conn.commit()
self.conn.commit() self.conn.commit()
self.conn.close() self.conn.close()
@ -501,8 +613,10 @@ class database:
return results return results
def get_user(self, domain, username): def get_user(self, domain, username):
self.conn.execute("SELECT * FROM users WHERE LOWER(domain)=LOWER(?) AND LOWER(username)=LOWER(?)", [domain, username]) results = self.conn.query(self.users_table).filter(
results = self.conn.fetchall() func.lower(self.users_table.c.domain) == func.lower(domain),
func.lower(self.users_table.c.username) == func.lower(username)
).all()
self.conn.commit() self.conn.commit()
self.conn.close() self.conn.close()
return results return results
@ -673,3 +787,8 @@ class database:
cur.close() cur.close()
logging.debug('get_dpapi_secrets(filterTerm={}, computer={}, dpapi_type={}, windows_user={}, username={}, url={}) => {}'.format(filterTerm, computer, dpapi_type, windows_user, username, url, results)) logging.debug('get_dpapi_secrets(filterTerm={}, computer={}, dpapi_type={}, windows_user={}, username={}, url={}) => {}'.format(filterTerm, computer, dpapi_type, windows_user, username, url, results))
return results return results
def clear_database(self):
for table in self.metadata.tables:
self.conn.query(self.metadata.tables[table]).delete()
self.conn.commit()

View File

@ -3,6 +3,7 @@
from cme.helpers.misc import validate_ntlm from cme.helpers.misc import validate_ntlm
from cme.cmedb import DatabaseNavigator, print_table from cme.cmedb import DatabaseNavigator, print_table
from sqlalchemy.sql import text
class navigator(DatabaseNavigator): class navigator(DatabaseNavigator):
@ -86,13 +87,13 @@ class navigator(DatabaseNavigator):
remark = share[4] remark = share[4]
users_r_access = self.db.get_users_with_share_access( users_r_access = self.db.get_users_with_share_access(
computerID=computerid, computer_id=computerid,
share_name=name, share_name=name,
permissions='r' permissions='r'
) )
users_w_access = self.db.get_users_with_share_access( users_w_access = self.db.get_users_with_share_access(
computerID=computerid, computer_id=computerid,
share_name=name, share_name=name,
permissions='w' permissions='w'
) )
@ -108,7 +109,7 @@ class navigator(DatabaseNavigator):
shares = self.db.get_shares() shares = self.db.get_shares()
self.display_shares(shares) self.display_shares(shares)
else: else:
shares = self.db.get_shares(filterTerm=filterTerm) shares = self.db.get_shares(filter_term=filterTerm)
if len(shares) > 1: if len(shares) > 1:
self.display_shares(shares) self.display_shares(shares)
@ -120,13 +121,13 @@ class navigator(DatabaseNavigator):
remark = share[4] remark = share[4]
users_r_access = self.db.get_users_with_share_access( users_r_access = self.db.get_users_with_share_access(
computerID=computerID, computer_id=computerID,
share_name=name, share_name=name,
permissions='r' permissions='r'
) )
users_w_access = self.db.get_users_with_share_access( users_w_access = self.db.get_users_with_share_access(
computerID=computerID, computer_id=computerID,
share_name=name, share_name=name,
permissions='w' permissions='w'
) )
@ -431,6 +432,9 @@ class navigator(DatabaseNavigator):
print_table(data, title='Admin Access to Host(s)') print_table(data, title='Admin Access to Host(s)')
def do_clear_database(self, line):
self.db.clear_database()
def complete_hosts(self, text, line, begidx, endidx): def complete_hosts(self, text, line, begidx, endidx):
"Tab-complete 'creds' commands." "Tab-complete 'creds' commands."

View File

@ -2,9 +2,12 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
class database: class database:
def __init__(self, conn, metadata=None):
def __init__(self, conn): # this is still named "conn" when it is the Session object, TODO: rename
self.conn = conn self.conn = conn
self.metadata = metadata
self.credentials_table = metadata.tables["credentials"]
self.hosts_table = metadata.tables["hosts"]
@staticmethod @staticmethod
def db_schema(db_conn): def db_schema(db_conn):
@ -22,3 +25,8 @@ class database:
"port" integer, "port" integer,
"server_banner" text "server_banner" text
)''') )''')
def clear_database(self):
for table in self.metadata.tables:
self.conn.query(self.metadata.tables[table]).delete()
self.conn.commit()

View File

@ -5,4 +5,5 @@ from cme.cmedb import DatabaseNavigator
class navigator(DatabaseNavigator): class navigator(DatabaseNavigator):
pass def do_clear_database(self, line):
self.db.clear_database()

View File

@ -2,9 +2,12 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
class database: class database:
def __init__(self, conn, metadata=None):
def __init__(self, conn): # this is still named "conn" when it is the Session object, TODO: rename
self.conn = conn self.conn = conn
self.metadata = metadata
self.credentials_table = metadata.tables["credentials"]
self.hosts_table = metadata.tables["hosts"]
@staticmethod @staticmethod
def db_schema(db_conn): def db_schema(db_conn):
@ -20,3 +23,8 @@ class database:
"hostname" text, "hostname" text,
"port" integer "port" integer
)''') )''')
def clear_database(self):
for table in self.metadata.tables:
self.conn.query(self.metadata.tables[table]).delete()
self.conn.commit()

View File

@ -5,4 +5,5 @@ from cme.cmedb import DatabaseNavigator
class navigator(DatabaseNavigator): class navigator(DatabaseNavigator):
pass def do_clear_database(self, line):
self.db.clear_database()