tests: improvement on tests

main
Marshall Hallenbeck 2023-03-25 21:29:18 -04:00
parent c7679c7acf
commit d705290a09
1 changed files with 44 additions and 12 deletions

View File

@ -7,6 +7,7 @@ from time import sleep
import pytest
import pytest_asyncio
from sqlalchemy import create_engine
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker, scoped_session
from cme.cmedb import create_workspace, delete_workspace
@ -16,17 +17,32 @@ from cme.logger import setup_logger, CMEAdapter
from cme.paths import WS_PATH
from sqlalchemy.dialects.sqlite import Insert
pytest_plugins = ('pytest_asyncio',)
@pytest.fixture(scope="session")
def event_loop():
return asyncio.get_event_loop()
# @pytest.fixture(scope="session")
# def event_loop_instance(request):
# """ Add the event_loop as an attribute to the unittest style test class. """
# request.event_loop = asyncio.get_event_loop_policy().new_event_loop()
# yield
# request.event_loop.close()
# @pytest.fixture(scope="session")
# def event_loop():
# return asyncio.get_event_loop()
@pytest.fixture(scope="session", autouse=True)
def event_loop(request):
"""Overrides pytest default function scoped event loop"""
policy = asyncio.get_event_loop_policy()
loop = policy.new_event_loop()
yield loop
loop.close()
@pytest_asyncio.fixture(scope="session")
def db_engine():
db_path = os.path.join(WS_PATH, "test/smb.db")
db_engine = create_engine(
f"sqlite:///{db_path}",
db_engine = create_async_engine(
f"sqlite+aiosqlite:///{db_path}",
isolation_level="AUTOCOMMIT",
future=True
)
@ -47,9 +63,10 @@ async def db(db_engine):
protocol_db_path = p_loader.get_protocols()[proto]["dbpath"]
protocol_db_object = getattr(p_loader.load_protocol(protocol_db_path), "database")
db = protocol_db_object(db_engine)
yield db
db.shutdown_db()
database_obj = protocol_db_object(db_engine)
await asyncio.shield(database_obj.reflect_tables())
yield database_obj
database_obj.shutdown_db()
delete_workspace("test")
@ -57,7 +74,8 @@ async def db(db_engine):
def sess(db_engine):
session_factory = sessionmaker(
bind=db_engine,
expire_on_commit=True
expire_on_commit=True,
class_=AsyncSession
)
Session = scoped_session(
@ -70,7 +88,7 @@ def sess(db_engine):
@pytest.mark.asyncio
async def test_add_host(db):
await db.add_host(
db.add_host(
"127.0.0.1",
"localhost",
"TEST.DEV",
@ -84,7 +102,8 @@ async def test_add_host(db):
)
def test_update_host(db, sess):
@pytest.mark.asyncio
async def test_update_host(db, sess):
host = {
"ip": "127.0.0.1",
"hostname": "localhost",
@ -98,7 +117,20 @@ def test_update_host(db, sess):
"petitpotam": False
}
iq = Insert(db.HostsTable)
sess.execute(iq, [host])
await sess.execute(iq, [host])
inserted_host = await db.get_hosts()
assert len(inserted_host) == 1
host = inserted_host[0]
assert host.id == 1
assert host.ip == "127.0.0.1"
assert host.hostname == "localhost"
assert host.os == "Windows Testing 2023"
assert host.dc == False
assert host.smbv1 == True
assert host.signing == True
assert host.spooler == True
assert host.zerologon == False
assert host.petitpotam == False
def test_add_credential():