diff --git a/Src/ChannelLinter/ChannelLinter.cpp b/Src/ChannelLinter/ChannelLinter.cpp index e4d36ee..6f4e48e 100644 --- a/Src/ChannelLinter/ChannelLinter.cpp +++ b/Src/ChannelLinter/ChannelLinter.cpp @@ -150,54 +150,30 @@ namespace FSecure::C3::Linter std::cout << "Testing channel with " << packetLen << " bytes of data ... " << std::flush; auto data = ByteVector(FSecure::Utils::GenerateRandomData(packetLen)); - // call send and receive interleaved - size_t sentTotal = 0; - ByteVector received; - ByteView sendView{ data }; - while (sentTotal != packetLen && received.size() != packetLen) - { - if (sentTotal != packetLen) - { - auto sent = channel->GetDevice()->OnSendToChannelInternal(sendView); - sendView.remove_prefix(sent); - sentTotal += sent; - } - - std::this_thread::sleep_for(channel->GetDevice()->GetUpdateDelay()); - - if (received.size() != packetLen) - { - auto receivedPackets = std::static_pointer_cast(complementary->GetDevice())->OnReceiveFromChannelInternal(); - for (auto&& packet : receivedPackets) - received.Concat(packet); - } - } - - if (data != received) + channel->Send(data); + if (data != complementary->Receive()[0]) throw std::exception("Data sent and received mismatch"); + std::cout << "OK" << std::endl; } auto numberOfTests = 10; auto packetSize = 64; std::cout << "Testing channel order with " << numberOfTests << " packets of " << packetSize << " bytes of data ... " << std::flush; - std::vector sent, received; + std::vector sent; for (auto i = 0; i < numberOfTests; ++i) { sent.push_back(FSecure::Utils::GenerateRandomData(packetSize)); - channel->GetDevice()->OnSendToChannelInternal(sent[i]); + channel->Send(sent[i]); } - for (auto i = 0; i < numberOfTests && received.size() < sent.size(); ++i) - { - auto receivedPackets = std::static_pointer_cast(complementary->GetDevice())->OnReceiveFromChannelInternal(); - received.insert(received.end(), receivedPackets.begin(), receivedPackets.end()); - std::this_thread::sleep_for(channel->GetDevice()->GetUpdateDelay()); - } + auto received = complementary->Receive(sent.size()); + received.resize(sent.size()); if (sent != received) throw std::exception("Data sent and received mismatch"); + std::cout << "OK" << std::endl; } diff --git a/Src/ChannelLinter/MockDeviceBridge.cpp b/Src/ChannelLinter/MockDeviceBridge.cpp index 7350f19..2f37535 100644 --- a/Src/ChannelLinter/MockDeviceBridge.cpp +++ b/Src/ChannelLinter/MockDeviceBridge.cpp @@ -90,4 +90,50 @@ namespace FSecure::C3::Linter return m_Device; } + void MockDeviceBridge::Send(ByteView blob) + { + auto oryginalSize = static_cast(blob.size()); + auto messageId = m_QoS.GetOutgouingPacketId(); + auto chunkId = uint32_t{ 0 }; + for (auto noProgressCounter = 0; noProgressCounter < 10; ++noProgressCounter) + { + auto data = ByteVector{}.Write(messageId, chunkId, oryginalSize).Concat(blob); + auto sent = GetDevice()->OnSendToChannelInternal(data); + + if (sent >= QualityOfService::s_MinFrameSize || sent == data.size()) + { + chunkId++; + noProgressCounter = 0; + blob.remove_prefix(sent - QualityOfService::s_HeaderSize); + } + + if (blob.empty()) + return; + } + + throw std::runtime_error("Cannot send data"); + } + + std::vector MockDeviceBridge::Receive(size_t minExpectedSize) + { + auto packets = std::vector{}; + for (auto noProgressCounter = 0; noProgressCounter < 10; ++noProgressCounter) + { + std::this_thread::sleep_for(GetDevice()->GetUpdateDelay()); + for (auto&& chunk : std::static_pointer_cast(GetDevice())->OnReceiveFromChannelInternal()) + { + m_QoS.PushReceivedChunk(chunk); + noProgressCounter = 0; + } + + auto packet = ByteVector{}; + while (packet = m_QoS.GetNextPacket(), !packet.empty()) + packets.emplace_back(std::move(packet)); + + if (packets.size() >= minExpectedSize) + return packets; + } + + throw std::runtime_error("Cannot receive data"); + } } diff --git a/Src/ChannelLinter/MockDeviceBridge.h b/Src/ChannelLinter/MockDeviceBridge.h index 12a7f88..4ecbdd3 100644 --- a/Src/ChannelLinter/MockDeviceBridge.h +++ b/Src/ChannelLinter/MockDeviceBridge.h @@ -1,5 +1,7 @@ #pragma once +#include "Core/QualityOfService.h" + namespace FSecure::C3::Linter { /// Device bridge used to mock its functionality @@ -66,9 +68,22 @@ namespace FSecure::C3::Linter /// @returns Bridged device std::shared_ptr GetDevice() const; + /// Send data to through channel. + /// @param blob - data to send. + /// @throws std::runtime_error if unable to send data. + void Send(ByteView blob); + + /// Receive data from channel. + /// @param minExpectedSize - min number of packets function should return. + /// @throws std::runtime_error if unable to return vector of required number of full packets. + std::vector Receive(size_t minExpectedSize = 1); + private: /// Bridged device std::shared_ptr m_Device; + + /// Handle chunking. + QualityOfService m_QoS; }; }