Merge branch 'CloseConnection' into 'master'

Close connection when peripheral is closed. Resolve access violation after...

See merge request C3/C3!132
dependabot/npm_and_yarn/Src/WebController/UI/websocket-extensions-0.1.4
Pawel Kurowski 2019-09-13 13:30:31 +01:00
commit f536b7a027
8 changed files with 146 additions and 35 deletions

View File

@ -29,7 +29,7 @@ namespace MWR::C3::Interfaces::Connectors
/// Called every time new implant is being created.
/// @param connectionId unused.
/// @param data unused. Prints debug information if not empty.
/// @para isX64 unused.
/// @param isX64 unused.
/// @returns ByteVector copy of data.
ByteVector PeripheralCreationCommand(ByteView connectionId, ByteView data, bool isX64) override;
@ -39,8 +39,13 @@ namespace MWR::C3::Interfaces::Connectors
/// @returns ByteVector response for command.
ByteVector TestErrorCommand(ByteView arg);
/// Close desired connection
/// @param connectionId id of connection (RouteId) in string form.
/// @returns ByteVector empty vector.
MWR::ByteVector CloseConnection(ByteView connectionId) override;
/// Represents a single connection with implant.
struct Connection
struct Connection : std::enable_shared_from_this<Connection>
{
/// Constructor.
/// @param owner weak pointer to connector object.
@ -63,7 +68,7 @@ namespace MWR::C3::Interfaces::Connectors
std::mutex m_ConnectionMapAccess;
/// Map of all connections.
std::unordered_map<std::string, std::unique_ptr<Connection>> m_ConnectionMap;
std::unordered_map<std::string, std::shared_ptr<Connection>> m_ConnectionMap;
};
}
@ -97,17 +102,18 @@ MWR::C3::Interfaces::Connectors::MockServer::Connection::Connection(std::weak_pt
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
void MWR::C3::Interfaces::Connectors::MockServer::Connection::StartUpdatingInSeparateThread()
{
std::thread([&]()
std::thread([&, id = m_Id]()
{
// Lock pointers.
auto owner = m_Owner.lock();
auto bridge = owner->GetBridge();
while (bridge->IsAlive())
auto self = shared_from_this();
while (bridge->IsAlive() && self.use_count() > 1)
{
// Post something to Binder and wait a little.
try
{
bridge->PostCommandToBinder(m_Id, ByteView(OBF("Beep")));
bridge->PostCommandToBinder(id, ByteView(OBF("Beep")));
}
catch (...)
{
@ -125,11 +131,19 @@ MWR::ByteVector MWR::C3::Interfaces::Connectors::MockServer::OnRunCommand(ByteVi
{
case 0:
return TestErrorCommand(command);
case 1:
return CloseConnection(command);
default:
return AbstractConnector::OnRunCommand(commandCopy);
}
}
MWR::ByteVector MWR::C3::Interfaces::Connectors::MockServer::CloseConnection(ByteView arguments)
{
m_ConnectionMap.erase(arguments);
return {};
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
MWR::ByteVector MWR::C3::Interfaces::Connectors::MockServer::TestErrorCommand(ByteView arg)
{
@ -168,6 +182,19 @@ MWR::ByteView MWR::C3::Interfaces::Connectors::MockServer::GetCapability()
"description": "Error set on connector. Send empty to clean up error"
}
]
},
{
"name": "Close connection",
"description": "Close socket connection with TeamServer if beacon is not available",
"id": 1,
"arguments":
[
{
"name": "Route Id",
"min": 1,
"description": "Id associated to beacon"
}
]
}
]
}

View File

@ -26,7 +26,7 @@ namespace MWR::C3::Interfaces::Connectors
/// Called every time new implant is being created.
/// @param connectionId adders of beacon in C3 network .
/// @param data parameters used to create implant. If payload is empty, new one will be generated.
/// @para isX64 indicates if relay staging beacon is x64.
/// @param isX64 indicates if relay staging beacon is x64.
/// @returns ByteVector correct command that will be used to stage beacon.
ByteVector PeripheralCreationCommand(ByteView connectionId, ByteView data, bool isX64) override;
@ -36,7 +36,7 @@ namespace MWR::C3::Interfaces::Connectors
private:
/// Represents a single C3 <-> Team Server connection, as well as each beacon in network.
struct Connection
struct Connection : std::enable_shared_from_this<Connection>
{
/// Constructor.
/// @param listeningPostAddress adders of TeamServer.
@ -91,9 +91,9 @@ namespace MWR::C3::Interfaces::Connectors
MWR::ByteVector GeneratePayload(ByteView binderId, std::string pipename, bool arch64, uint32_t block);
/// Close desired connection
/// @arguments arguments for command. connection Id in string form.
/// @param connectionId id of connection (RouteId) in string form.
/// @returns ByteVector empty vector.
MWR::ByteVector CloseConnection(ByteView arguments);
MWR::ByteVector CloseConnection(ByteView connectionId) override;
/// Initializes Sockets library. Can be called multiple times, but requires corresponding number of calls to DeinitializeSockets() to happen before closing the application.
/// @return value forwarded from WSAStartup call (zero if successful).
@ -116,7 +116,7 @@ namespace MWR::C3::Interfaces::Connectors
std::mutex m_SendMutex;
/// Map of all connections.
std::unordered_map<std::string, std::unique_ptr<Connection>> m_ConnectionMap;
std::unordered_map<std::string, std::shared_ptr<Connection>> m_ConnectionMap;
};
}
@ -188,8 +188,7 @@ MWR::ByteVector MWR::C3::Interfaces::Connectors::TeamServer::GeneratePayload(Byt
MWR::ByteVector MWR::C3::Interfaces::Connectors::TeamServer::CloseConnection(ByteView arguments)
{
auto id = arguments.Read<std::string>();
m_ConnectionMap.erase(id);
m_ConnectionMap.erase(arguments);
return {};
}
@ -352,12 +351,13 @@ MWR::ByteVector MWR::C3::Interfaces::Connectors::TeamServer::Connection::Receive
void MWR::C3::Interfaces::Connectors::TeamServer::Connection::StartUpdatingInSeparateThread()
{
m_SecondThreadStarted = true;
std::thread([&]()
std::thread([this]()
{
// Lock pointers.
auto owner = m_Owner.lock();
auto bridge = owner->GetBridge();
while (bridge->IsAlive())
auto self = shared_from_this();
while (bridge->IsAlive() && self.use_count() > 1)
{
try
{
@ -367,7 +367,7 @@ void MWR::C3::Interfaces::Connectors::TeamServer::Connection::StartUpdatingInSep
if (packet.size() == 1u && packet[0] == 0u)
Send(packet);
else
bridge->PostCommandToBinder(ByteView{ m_Id }, packet);
bridge->PostCommandToBinder(m_Id, packet);
}
}
catch (std::exception& e)

View File

@ -139,8 +139,18 @@ namespace MWR::C3
/// @return Command result.
virtual ByteVector RunCommand(ByteView command) = 0;
/// Called every time new peripheral is being created.
/// @param connectionId adders of peripheral in C3 network .
/// @param data all parameters used to create peripheral. Specific for each connector.
/// @param isX64 indicates if relay staging peripheral is x64.
/// @returns ByteVector correct command that will be used to stage peripheral.
virtual ByteVector PeripheralCreationCommand(ByteView connectionId, ByteView data, bool isX64 = false) = 0;
/// Close desired connection
/// @param connectionId id of connection (RouteId) in string form.
/// @returns ByteVector empty vector.
virtual ByteVector CloseConnection(ByteView connectionId) = 0;
/// Logs a message. Used by internal mechanisms to report errors, warnings, informations and debug messages.
/// @param message information to log.
virtual void Log(LogMessage const& message) = 0;

View File

@ -137,6 +137,11 @@ namespace MWR::C3
virtual ByteVector PeripheralCreationCommand(ByteView connectionId, ByteView data, bool isX64) { return data; }
/// Close desired connection
/// @param connectionId id of connection (RouteId) in string form.
/// @returns ByteVector empty vector.
virtual ByteVector CloseConnection(ByteView connectionId) = 0;
protected:
/// Close Connector.
virtual void TurnOff();

View File

@ -88,3 +88,8 @@ MWR::ByteVector MWR::C3::Core::ConnectorBridge::PeripheralCreationCommand(ByteVi
{
return m_Connector->PeripheralCreationCommand(connectionId, data, isX64);
}
MWR::ByteVector MWR::C3::Core::ConnectorBridge::CloseConnection(ByteView connectionId)
{
return m_Connector->CloseConnection(connectionId);
}

View File

@ -73,10 +73,15 @@ namespace MWR::C3::Core
/// Called every time new peripheral is being created.
/// @param connectionId adders of peripheral in C3 network .
/// @param data all parameters used to create peripheral. Specific for each connector.
/// @para isX64 indicates if relay staging peripheral is x64.
/// @param isX64 indicates if relay staging peripheral is x64.
/// @returns ByteVector correct command that will be used to stage peripheral.
ByteVector PeripheralCreationCommand(ByteView connectionId, ByteView data, bool isX64 = false) override;
/// Close desired connection
/// @param connectionId id of connection (RouteId) in string form.
/// @returns ByteVector empty vector.
ByteVector CloseConnection(ByteView connectionId) override;
protected:
/// Connector object getter.
/// @return Connector this object binds GateRelay with.

View File

@ -109,7 +109,7 @@ void MWR::C3::Core::NodeRelay::OnProtocolG2R(ByteView packet0, std::shared_ptr<D
}
catch (std::exception& exception)
{
throw std::runtime_error{ OBF_STR("Failed to parse G2A packet. ") + exception.what() };
throw std::runtime_error{ OBF_STR("Failed to parse G2R packet. ") + exception.what() };
}
}

View File

@ -509,6 +509,10 @@ void MWR::C3::Core::Profiler::Agent::ParseAndRunCommand(json const& jCommandElem
auto commandWithArgs = base64::decode<ByteVector>(jCommandElement["Command"]["ByteForm"].get<std::string>());
auto gateRelay = profiler->m_Gateway->m_Gateway.lock();
if (!gateRelay)
return; // probably shutting down
if (deviceId)
{
std::function<void()> finalizer = []() {};
@ -521,9 +525,32 @@ void MWR::C3::Core::Profiler::Agent::ParseAndRunCommand(json const& jCommandElem
switch (static_cast<MWR::C3::Core::Relay::Command>(commandId))
{
case MWR::C3::Core::Relay::Command::Close:
finalizer = [this, deviceId, deviceIsChannel]()
finalizer = [&]()
{
deviceIsChannel ? m_Channels.TryRemove(*deviceId) : m_Peripherals.TryRemove(*deviceId);
if (deviceIsChannel)
{
m_Channels.TryRemove(*deviceId);
}
else
{
// Find connector hash
auto element = m_Peripherals.Find(*deviceId);
if (!element)
return;
auto connectorHash = profiler->GetBinderTo(element->m_TypeHash);
// remove peripheral
m_Peripherals.TryRemove(*deviceId);
// Get connector
auto connector = gateRelay->m_Connectors.Find([&](auto const& e) { return e->GetNameHash() == connectorHash; });
if (!connector)
return;
// Remove connection.
connector->CloseConnection(RouteId{ m_Id, *deviceId }.ToByteVector());
}
};
break;
case MWR::C3::Core::Relay::Command::UpdateJitter:
@ -540,10 +567,6 @@ void MWR::C3::Core::Profiler::Agent::ParseAndRunCommand(json const& jCommandElem
}
}
auto gateRelay = profiler->m_Gateway->m_Gateway.lock();
if (!gateRelay)
return; // probably shutting down
auto route = gateRelay->FindRoute(m_Id);
if (!route)
throw std::runtime_error("Failed to find route to agent id = " + m_Id.ToString());
@ -590,12 +613,26 @@ void MWR::C3::Core::Profiler::Agent::RunCommand(ByteView commandWithArguments)
{
case NodeRelay::Command::Close:
{
finalizer = [this]()
finalizer = [&]()
{
auto owner = m_Owner.lock();
if (!owner)
throw std::runtime_error{ "Cannot obtain owner" };
m_Channels.Clear();
for (auto&& element : m_Peripherals.GetUnderlyingContainer())
{
auto connectorHash = profiler->GetBinderTo(element.m_TypeHash);
auto connector = gateRelay->m_Connectors.Find([&](auto const& e) { return e->GetNameHash() == connectorHash; });
if (!connector)
break;
// Remove connection.
connector->CloseConnection(RouteId{ m_Id, element.m_Id }.ToByteVector());
}
m_Peripherals.Clear();
owner->Get().m_Gateway.m_Agents.Remove(m_Id);
};
break;
@ -766,21 +803,43 @@ void MWR::C3::Core::Profiler::Gateway::ParseAndRunCommand(json const& jCommandEl
auto gateway = m_Gateway.lock();
if (auto device = gateway->FindDevice(MWR::Utils::SafeCast<DeviceId::UnderlyingIntegerType>(id)); device)
{
device->RunCommand(commandReadView);
if (auto localView = commandReadView; localView.Read<std::uint16_t>() == static_cast<std::uint16_t>(MWR::C3::Core::Relay::Command::UpdateJitter))
auto localView = commandReadView;
switch (MWR::C3::Core::Relay::Command(localView.Read<std::uint16_t>()))
{
Device* profilerElement = m_Channels.Find(device->GetDid());
if (!profilerElement)
profilerElement = m_Channels.Find(device->GetDid());
case MWR::C3::Core::Relay::Command::UpdateJitter:
{
Device* profilerElement = m_Channels.Find(device->GetDid());
if (!profilerElement)
profilerElement = m_Peripherals.Find(device->GetDid());
if (!profilerElement)
throw std::runtime_error{ "Device not found" };
if (!profilerElement)
throw std::runtime_error{ "Device not found" };
commandReadView.remove_prefix(sizeof(uint16_t)); // command id
profilerElement->m_Jitter.first = MWR::Utils::ToMilliseconds(commandReadView.Read<float>());
profilerElement->m_Jitter.second = MWR::Utils::ToMilliseconds(commandReadView.Read<float>());
profilerElement->m_Jitter.first = MWR::Utils::ToMilliseconds(localView.Read<float>());
profilerElement->m_Jitter.second = MWR::Utils::ToMilliseconds(localView.Read<float>());
break;
}
case MWR::C3::Core::Relay::Command::Close:
{
auto profilerElement = m_Peripherals.Find(device->GetDid());
if (!profilerElement)
break;
auto connectorHash = m_Owner.lock()->GetBinderTo(profilerElement->m_TypeHash);
auto connector = m_Gateway.lock()->m_Connectors.Find([&](auto const& e) { return e->GetNameHash() == connectorHash; });
if (!connector)
break;
// Remove connection.
connector->CloseConnection(RouteId{ m_Id, device->GetDid() }.ToByteVector());
break;
}
default:
break;
}
device->RunCommand(commandReadView);
return true;
}