From 5f5e10dbf43fd7e74a5eebb642f7def9650b175c Mon Sep 17 00:00:00 2001 From: Thomas Ballasi <thomas.ballasi@savoirfairelinux.com> Date: Fri, 21 Oct 2022 12:53:29 -0400 Subject: [PATCH] multiplexed_socket: add ChannelSocketTest The usage of ChannelSocketTest helps avoiding unnecessary overhead when running unit tests (especially on creating a huge quantity of nodes and sockets). ChannelSocketTest implements a simplier form of ChannelSocket that directly communicates to a peer rather than through a socket. Change-Id: Id1c68aaa92f8f8cf8002c417f670254b0b851cfb --- src/jamidht/multiplexed_socket.cpp | 178 ++++++++++++++++++++++++++++- src/jamidht/multiplexed_socket.h | 94 +++++++++++++-- 2 files changed, 263 insertions(+), 9 deletions(-) diff --git a/src/jamidht/multiplexed_socket.cpp b/src/jamidht/multiplexed_socket.cpp index 0bf7ac5428..34bd0238eb 100644 --- a/src/jamidht/multiplexed_socket.cpp +++ b/src/jamidht/multiplexed_socket.cpp @@ -712,7 +712,6 @@ MultiplexedSocket::sendVersion() pimpl_->sendVersion(); } - IpAddr MultiplexedSocket::getLocalAddress() const { @@ -761,6 +760,183 @@ public: GenericSocket<uint8_t>::RecvCb cb {}; }; +ChannelSocketTest::ChannelSocketTest(const DeviceId& deviceId, + const std::string& name, + const uint16_t& channel) + : pimpl_deviceId(deviceId) + , pimpl_name(name) + , pimpl_channel(channel) + , eventLoopThread_ {[this] { + try { + eventLoop(); + } catch (const std::exception& e) { + JAMI_ERR() << "[CNX] peer connection event loop failure: " << e.what(); + shutdown(); + } + }} +{} + +ChannelSocketTest::~ChannelSocketTest() +{ + eventLoopThread_.join(); +} + +void +ChannelSocketTest::link(const std::weak_ptr<ChannelSocketTest>& socket1, + const std::weak_ptr<ChannelSocketTest>& socket2) +{ + if (auto peer = socket1.lock()) { + peer->remote = socket2; + } + if (auto peer = socket2.lock()) { + peer->remote = socket1; + } +} + +DeviceId +ChannelSocketTest::deviceId() const +{ + return pimpl_deviceId; +} + +std::string +ChannelSocketTest::name() const +{ + return pimpl_name; +} + +uint16_t +ChannelSocketTest::channel() const +{ + return pimpl_channel; +} + +void +ChannelSocketTest::shutdown() +{ + if (!isShutdown_) { + isShutdown_ = true; + shutdownCb_(); + } + cv.notify_all(); + if (auto peer = remote.lock()) { + if (!peer->isShutdown_) { + peer->isShutdown_ = true; + peer->shutdownCb_(); + } + peer->cv.notify_all(); + } +} + +std::size_t +ChannelSocketTest::read(ValueType* buf, std::size_t len, std::error_code& ec) +{ + std::lock_guard<std::mutex> lkSockets(mutex); + std::size_t size = std::min(len, this->buf.size()); + + for (std::size_t i = 0; i < size; ++i) + buf[i] = this->buf[i]; + + this->buf.erase(this->buf.begin(), this->buf.begin() + size); + return size; +}; + +std::size_t +ChannelSocketTest::write(const ValueType* buf, std::size_t len, std::error_code& ec) +{ + if (isShutdown_) { + ec = std::make_error_code(std::errc::broken_pipe); + return -1; + } + if (auto peer = remote.lock()) { + std::vector<uint8_t> bufToSend(buf, buf + len); + std::size_t sent = 0; + do { + std::size_t lenToSend = std::min(static_cast<std::size_t>(UINT16_MAX), len - sent); + peer->buf.insert(peer->buf.end(), + bufToSend.begin() + sent, + bufToSend.begin() + sent + lenToSend); + sent += lenToSend; + peer->cv.notify_all(); + } while (sent < len); + return sent; + } + ec = std::make_error_code(std::errc::broken_pipe); + return -1; +} + +int +ChannelSocketTest::waitForData(std::chrono::milliseconds timeout, std::error_code& ec) const +{ + std::unique_lock<std::mutex> lk {mutex}; + cv.wait_for(lk, timeout, [&] { return !buf.empty() or isShutdown_; }); + return buf.size(); +} + +void +ChannelSocketTest::setOnRecv(RecvCb&& cb) +{ + std::lock_guard<std::mutex> lkSockets(mutex); + this->cb = std::move(cb); + if (!buf.empty() && this->cb) { + this->cb(buf.data(), buf.size()); + buf.clear(); + } +} + +void +ChannelSocketTest::onRecv(std::vector<uint8_t>&& pkt) +{ + std::lock_guard<std::mutex> lkSockets(mutex); + if (cb) { + cb(&pkt[0], pkt.size()); + return; + } + buf.insert(buf.end(), std::make_move_iterator(pkt.begin()), std::make_move_iterator(pkt.end())); +} + +void +ChannelSocketTest::onReady(ChannelReadyCb&& cb) +{} + +void +ChannelSocketTest::onShutdown(OnShutdownCb&& cb) +{ + shutdownCb_ = std::move(cb); + if (isShutdown_) { + shutdownCb_(); + } +} + +void +ChannelSocketTest::eventLoop() +{ + std::error_code ec; + std::vector<uint8_t> buf(IO_BUFFER_SIZE); + + while (!isShutdown_) { + // wait for new data before reading + std::unique_lock<std::mutex> lk {mutex}; + cv.wait(lk, [&] { return !this->buf.empty() or isShutdown_; }); + lk.unlock(); + + int size = read(reinterpret_cast<uint8_t*>(buf.data()), IO_BUFFER_SIZE, ec); + if (size < 0) { + if (ec) + JAMI_ERR("Read error detected: %s", ec.message().c_str()); + break; + } + + if (size == 0) { + shutdown(); + } + + if (size != 0) { + onRecv(std::move(buf)); + } + } +} + ChannelSocket::ChannelSocket(std::weak_ptr<MultiplexedSocket> endpoint, const std::string& name, const uint16_t& channel, diff --git a/src/jamidht/multiplexed_socket.h b/src/jamidht/multiplexed_socket.h index 99bc200bd4..f02e02cfbe 100644 --- a/src/jamidht/multiplexed_socket.h +++ b/src/jamidht/multiplexed_socket.h @@ -20,6 +20,7 @@ #include <opendht/default_types.h> #include "generic_io.h" +#include <condition_variable> namespace jami { @@ -175,22 +176,99 @@ private: std::unique_ptr<Impl> pimpl_; }; +class ChannelSocketInterface : public GenericSocket<uint8_t> +{ +public: + using SocketType = GenericSocket<uint8_t>; + + virtual DeviceId deviceId() const = 0; + virtual std::string name() const = 0; + virtual uint16_t channel() const = 0; + /** + * Triggered when a specific channel is ready + * Used by ConnectionManager::connectDevice() + */ + virtual void onReady(ChannelReadyCb&& cb) = 0; + /** + * Will trigger that callback when shutdown() is called + */ + virtual void onShutdown(OnShutdownCb&& cb) = 0; + + virtual void onRecv(std::vector<uint8_t>&& pkt) = 0; +}; + +class ChannelSocketTest : public ChannelSocketInterface +{ +public: + ChannelSocketTest(const DeviceId& deviceId, const std::string& name, const uint16_t& channel); + ~ChannelSocketTest(); + + static void link(const std::weak_ptr<ChannelSocketTest>& socket1, + const std::weak_ptr<ChannelSocketTest>& socket2); + + DeviceId deviceId() const override; + std::string name() const override; + uint16_t channel() const override; + + bool isReliable() const override { return true; }; + bool isInitiator() const override { return true; }; + int maxPayload() const override { return 0; }; + + void shutdown() override; + + std::size_t read(ValueType* buf, + std::size_t len, + std::error_code& ec) override; + std::size_t write(const ValueType* buf, + std::size_t len, + std::error_code& ec) override; + int waitForData(std::chrono::milliseconds timeout, + std::error_code&) const override; + void setOnRecv(RecvCb&&) override; + void onRecv(std::vector<uint8_t>&& pkt) override; + + /** + * Triggered when a specific channel is ready + * Used by ConnectionManager::connectDevice() + */ + void onReady(ChannelReadyCb&& cb) override; + /** + * Will trigger that callback when shutdown() is called + */ + void onShutdown(OnShutdownCb&& cb) override; + + std::vector<uint8_t> buf {}; + mutable std::mutex mutex {}; + mutable std::condition_variable cv {}; + GenericSocket<uint8_t>::RecvCb cb {}; + +private: + const DeviceId pimpl_deviceId; + const std::string pimpl_name; + const uint16_t pimpl_channel; + std::weak_ptr<ChannelSocketTest> remote; + OnShutdownCb shutdownCb_ { [&] {} }; + std::atomic_bool isShutdown_ {false}; + + void eventLoop(); + std::thread eventLoopThread_ {}; +}; + /** * Represents a channel of the multiplexed socket (channel, name) */ -class ChannelSocket : public GenericSocket<uint8_t> +class ChannelSocket : ChannelSocketInterface { public: - using SocketType = GenericSocket<uint8_t>; ChannelSocket(std::weak_ptr<MultiplexedSocket> endpoint, const std::string& name, const uint16_t& channel, bool isInitiator = false); ~ChannelSocket(); - DeviceId deviceId() const; - std::string name() const; - uint16_t channel() const; + DeviceId deviceId() const override; + std::string name() const override; + uint16_t channel() const override; bool isReliable() const override; bool isInitiator() const override; int maxPayload() const override; @@ -211,11 +289,11 @@ public: * Triggered when a specific channel is ready * Used by ConnectionManager::connectDevice() */ - void onReady(ChannelReadyCb&& cb); + void onReady(ChannelReadyCb&& cb) override; /** * Will trigger that callback when shutdown() is called */ - void onShutdown(OnShutdownCb&& cb); + void onShutdown(OnShutdownCb&& cb) override; std::size_t read(ValueType* buf, std::size_t len, std::error_code& ec) override; /** @@ -231,7 +309,7 @@ public: */ void setOnRecv(RecvCb&&) override; - void onRecv(std::vector<uint8_t>&& pkt); + void onRecv(std::vector<uint8_t>&& pkt) override; /** * Send a beacon on the socket and close if no response come -- GitLab