/*
 *  Copyright (C) 2004-2023 Savoir-faire Linux Inc.
 *
 *  This program is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with this program. If not, see <https://www.gnu.org/licenses/>.
 */
#include "multiplexed_socket.h"
#include "peer_connection.h"
#include "ice_transport.h"
#include "certstore.h"

#include <opendht/logger.h>
#include <opendht/thread_pool.h>

#include <asio/io_context.hpp>
#include <asio/steady_timer.hpp>

#include <deque>

static constexpr std::size_t IO_BUFFER_SIZE {8192}; ///< Size of char buffer used by IO operations

static constexpr int MULTIPLEXED_SOCKET_VERSION {1};

/// Wire envelope: every packet on the TLS endpoint is a [channel, payload] msgpack array.
struct ChanneledMessage
{
    uint16_t channel;
    std::vector<uint8_t> data;
    MSGPACK_DEFINE(channel, data)
};

/// Beacon message ({"p": true} = request, {"p": false} = response).
struct BeaconMsg
{
    bool p;
    MSGPACK_DEFINE_MAP(p)
};

/// Protocol version announcement ({"v": <int>}).
struct VersionMsg
{
    int v;
    MSGPACK_DEFINE_MAP(v)
};

namespace dhtnet {

using clock = std::chrono::steady_clock;
using time_point = clock::time_point;

/**
 * Private implementation of MultiplexedSocket.
 * Owns the TLS endpoint, the channel table (guarded by socketsMutex) and the thread
 * that reads from the endpoint and dispatches packets to the control, protocol or
 * data channels.
 */
class MultiplexedSocket::Impl
{
public:
    // Initializers are listed in declaration order. eventLoopThread_ is declared last
    // so that every member the event loop touches is constructed before the thread starts.
    Impl(MultiplexedSocket& parent,
         std::shared_ptr<asio::io_context> ctx,
         const DeviceId& deviceId,
         std::unique_ptr<TlsSocketEndpoint> endpoint)
        : parent_(parent)
        , ctx_(std::move(ctx))
        , deviceId(deviceId)
        , endpoint(std::move(endpoint))
        , beaconTimer_(*ctx_)
        , eventLoopThread_ {[this] {
            try {
                eventLoop();
            } catch (const std::exception& e) {
                if (logger_)
                    logger_->error("[CNX] peer connection event loop failure: {}", e.what());
                shutdown();
            }
        }}
    {}

    ~Impl() {}

    /**
     * Shut the socket down (or just drop the channels if already shut down),
     * then wait for the event loop thread to finish.
     */
    void join()
    {
        if (!isShutdown_) {
            if (endpoint)
                endpoint->setOnStateChange({});
            shutdown();
        } else {
            clearSockets();
        }
        if (eventLoopThread_.joinable())
            eventLoopThread_.join();
    }

    /// Remove every channel from the table and stop each of them outside the lock.
    void clearSockets()
    {
        decltype(sockets) socks;
        {
            std::lock_guard<std::mutex> lkSockets(socketsMutex);
            socks = std::move(sockets);
        }
        for (auto& socket : socks) {
            // Just trigger onShutdown() to make client know
            // No need to write the EOF for the channel, the write will fail because endpoint is
            // already shutdown
            if (socket.second)
                socket.second->stop();
        }
    }

    /**
     * Idempotent shutdown: stops the event loop, cancels the beacon timer, notifies
     * onShutdown_, shuts the endpoint down and stops every channel.
     */
    void shutdown()
    {
        // exchange() makes check-and-set atomic: shutdown() can be reached concurrently
        // (event loop, beacon timer, failed writes) and must only run once.
        if (isShutdown_.exchange(true))
            return;
        stop.store(true);
        beaconTimer_.cancel();
        if (onShutdown_)
            onShutdown_();
        if (endpoint) {
            std::unique_lock<std::mutex> lk(writeMtx);
            endpoint->shutdown();
        }
        clearSockets();
    }

    /**
     * Return the socket for @p channel, creating it if the slot is empty.
     * Caller must hold socketsMutex.
     */
    std::shared_ptr<ChannelSocket> makeSocket(const std::string& name,
                                              uint16_t channel,
                                              bool isInitiator = false)
    {
        auto& channelSocket = sockets[channel];
        if (not channelSocket)
            channelSocket = std::make_shared<ChannelSocket>(
                parent_.weak(), name, channel, isInitiator, [w = parent_.weak(), channel]() {
                    // Remove socket in another thread to avoid any lock
                    dht::ThreadPool::io().run([w, channel]() {
                        if (auto shared = w.lock()) {
                            shared->eraseChannel(channel);
                        }
                    });
                });
        else {
            if (logger_)
                logger_->warn("A channel is already present on that socket, accepting "
                              "the request will close the previous one {}",
                              name);
        }
        return channelSocket;
    }

    /**
     * Handle packets on the TLS endpoint and parse RTP
     */
    void eventLoop();

    /**
     * Triggered when a new control packet is received
     */
    void handleControlPacket(std::vector<uint8_t>&& pkt);
    void handleProtocolPacket(std::vector<uint8_t>&& pkt);
    bool handleProtocolMsg(const msgpack::object& o);
    /**
     * Triggered when a new packet on a channel is received
     */
    void handleChannelPacket(uint16_t channel, std::vector<uint8_t>&& pkt);
    void onRequest(const std::string& name, uint16_t channel);
    void onAccept(const std::string& name, uint16_t channel);

    void setOnReady(OnConnectionReadyCb&& cb) { onChannelReady_ = std::move(cb); }
    void setOnRequest(OnConnectionRequestCb&& cb) { onRequest_ = std::move(cb); }

    // Beacon
    void sendBeacon(const std::chrono::milliseconds& timeout);
    void handleBeaconRequest();
    void handleBeaconResponse();
    std::atomic_int beaconCounter_ {0};

    bool writeProtocolMessage(const msgpack::sbuffer& buffer);

    msgpack::unpacker pac_ {};

    MultiplexedSocket& parent_;

    std::shared_ptr<Logger> logger_;
    std::shared_ptr<asio::io_context> ctx_;

    OnConnectionReadyCb onChannelReady_ {};
    OnConnectionRequestCb onRequest_ {};
    OnShutdownCb onShutdown_ {};

    DeviceId deviceId {};
    // Main socket
    std::unique_ptr<TlsSocketEndpoint> endpoint {};

    std::mutex socketsMutex {};
    std::map<uint16_t, std::shared_ptr<ChannelSocket>> sockets {};

    // Main loop to parse incoming packets
    std::atomic_bool stop {false};

    std::atomic_bool isShutdown_ {false};

    std::mutex writeMtx {};

    time_point start_ {clock::now()};
    asio::steady_timer beaconTimer_;

    // version related stuff
    void sendVersion();
    void onVersion(int version);
    std::atomic_bool canSendBeacon_ {false};
    std::atomic_bool answerBeacon_ {true};
    int version_ {MULTIPLEXED_SOCKET_VERSION};
    std::function<void(bool)> onBeaconCb_ {};
    std::function<void(int)> onVersionCb_ {};

    // Must stay the LAST member: it is started from the constructor's initializer list,
    // and members are constructed in declaration order.
    std::thread eventLoopThread_ {};
};

/**
 * Body of eventLoopThread_: reads from the endpoint until stop is set, unpacks
 * ChanneledMessage envelopes and routes them by channel.
 */
void
MultiplexedSocket::Impl::eventLoop()
{
    if (!endpoint) {
        shutdown();
        return;
    }
    endpoint->setOnStateChange([this](tls::TlsSessionState state) {
        if (state == tls::TlsSessionState::SHUTDOWN && !isShutdown_) {
            if (logger_)
                logger_->debug("Tls endpoint is down, shutdown multiplexed socket");
            shutdown();
            return false;
        }
        return true;
    });
    sendVersion();
    std::error_code ec;
    while (!stop) {
        if (!endpoint) {
            shutdown();
            return;
        }
        pac_.reserve_buffer(IO_BUFFER_SIZE);
        int size = endpoint->read(reinterpret_cast<uint8_t*>(&pac_.buffer()[0]),
                                  IO_BUFFER_SIZE,
                                  ec);
        if (size < 0) {
            if (ec && logger_)
                logger_->error("Read error detected: {}", ec.message());
            break;
        }
        if (size == 0) {
            // We can close the socket
            shutdown();
            break;
        }

        pac_.buffer_consumed(size);
        msgpack::object_handle oh;
        while (pac_.next(oh) && !stop) {
            try {
                auto msg = oh.get().as<ChanneledMessage>();
                if (msg.channel == CONTROL_CHANNEL)
                    handleControlPacket(std::move(msg.data));
                else if (msg.channel == PROTOCOL_CHANNEL)
                    handleProtocolPacket(std::move(msg.data));
                else
                    handleChannelPacket(msg.channel, std::move(msg.data));
            } catch (const std::exception& e) {
                if (logger_)
                    logger_->warn("Failed to unpacked message of {:d} bytes: {:s}",
                                  size,
                                  e.what());
            } catch (...) {
                if (logger_)
                    logger_->error("Unknown exception catched while unpacking message of {:d} bytes",
                                   size);
            }
        }
    }
}

/// Peer accepted a channel we requested: notify onChannelReady_ and mark the socket ready.
void
MultiplexedSocket::Impl::onAccept(const std::string& name, uint16_t channel)
{
    std::lock_guard<std::mutex> lkSockets(socketsMutex);
    // find() rather than operator[]: an answer for an unknown channel must not
    // leave an empty entry behind in the channel table.
    auto sockIt = sockets.find(channel);
    if (sockIt == sockets.end() || !sockIt->second) {
        if (logger_)
            logger_->error("Receiving an answer for a non existing channel. This is a bug.");
        return;
    }
    auto& socket = sockIt->second;

    onChannelReady_(deviceId, socket);
    socket->ready(true);
    // Due to the callbacks that can take some time, onAccept can arrive after
    // receiving all the data. In this case, the socket should be removed here
    // as handle by onChannelReady_
    if (socket->isRemovable())
        sockets.erase(sockIt);
    else
        socket->answered();
}

/**
 * Send a beacon request and arm beaconTimer_; if the beacon counter is still non-zero
 * when the timer fires, the socket is shut down.
 */
void
MultiplexedSocket::Impl::sendBeacon(const std::chrono::milliseconds& timeout)
{
    if (!canSendBeacon_)
        return;
    beaconCounter_++;
    if (logger_)
        logger_->debug("Send beacon to peer {}", deviceId);

    msgpack::sbuffer buffer(8);
    msgpack::packer<msgpack::sbuffer> pk(&buffer);
    pk.pack(BeaconMsg {true});
    if (!writeProtocolMessage(buffer))
        return;
    beaconTimer_.expires_after(timeout);
    beaconTimer_.async_wait([w = parent_.weak()](const asio::error_code& ec) {
        if (ec == asio::error::operation_aborted)
            return;
        if (auto shared = w.lock()) {
            if (shared->pimpl_->beaconCounter_ != 0) {
                if (shared->pimpl_->logger_)
                    shared->pimpl_->logger_->error("Beacon doesn't get any response. Stopping socket");
                shared->shutdown();
            }
        }
    });
}

/// Answer a peer's beacon request (unless answerBeacon_ is disabled).
void
MultiplexedSocket::Impl::handleBeaconRequest()
{
    if (!answerBeacon_)
        return;
    // Run this on dedicated thread because some callbacks can take time
    dht::ThreadPool::io().run([w = parent_.weak()]() {
        if (auto shared = w.lock()) {
            msgpack::sbuffer buffer(8);
            msgpack::packer<msgpack::sbuffer> pk(&buffer);
            pk.pack(BeaconMsg {false});
            if (shared->pimpl_->logger_)
                shared->pimpl_->logger_->debug("Send beacon response to peer {}",
                                               shared->deviceId());
            shared->pimpl_->writeProtocolMessage(buffer);
        }
    });
}

/// A beacon response arrived: decrement the outstanding-beacon counter.
void
MultiplexedSocket::Impl::handleBeaconResponse()
{
    if (logger_)
        logger_->debug("Get beacon response from peer {}", deviceId);
    beaconCounter_--;
}

/// Write a packed protocol message on PROTOCOL_CHANNEL; true if bytes were written.
bool
MultiplexedSocket::Impl::writeProtocolMessage(const msgpack::sbuffer& buffer)
{
    std::error_code ec;
    int wr = parent_.write(PROTOCOL_CHANNEL,
                           (const unsigned char*) buffer.data(),
                           buffer.size(),
                           ec);
    return wr > 0;
}

/// Announce version_ to the peer (asynchronously, on the io thread pool).
void
MultiplexedSocket::Impl::sendVersion()
{
    dht::ThreadPool::io().run([w = parent_.weak()]() {
        if (auto shared = w.lock()) {
            auto version = shared->pimpl_->version_;
            msgpack::sbuffer buffer(8);
            msgpack::packer<msgpack::sbuffer> pk(&buffer);
            pk.pack(VersionMsg {version});
            shared->pimpl_->writeProtocolMessage(buffer);
        }
    });
}

/// Record whether the peer's announced version allows sending beacons.
void
MultiplexedSocket::Impl::onVersion(int version)
{
    // Beacons are supported starting from version 1
    if (version >= 1) {
        if (logger_)
            logger_->debug("Peer {} supports beacon", deviceId);
        canSendBeacon_ = true;
    } else {
        if (logger_)
            logger_->warn("Peer {} uses version {:d} which doesn't support beacon",
                          deviceId,
                          version);
        canSendBeacon_ = false;
    }
}

/**
 * Peer requested a new channel: ask onRequest_, create the socket if accepted,
 * answer on CONTROL_CHANNEL and, on success, notify onChannelReady_.
 */
void
MultiplexedSocket::Impl::onRequest(const std::string& name, uint16_t channel)
{
    auto accept = onRequest_(endpoint->peerCertificate(), channel, name);
    std::shared_ptr<ChannelSocket> channelSocket;
    if (accept) {
        std::lock_guard<std::mutex> lkSockets(socketsMutex);
        channelSocket = makeSocket(name, channel);
    }

    // Answer to ChannelRequest if accepted
    ChannelRequest val;
    val.channel = channel;
    val.name = name;
    val.state = accept ? ChannelRequestState::ACCEPT : ChannelRequestState::DECLINE;
    msgpack::sbuffer buffer(512);
    msgpack::pack(buffer, val);
    std::error_code ec;
    int wr = parent_.write(CONTROL_CHANNEL,
                           reinterpret_cast<const uint8_t*>(buffer.data()),
                           buffer.size(),
                           ec);
    if (wr < 0) {
        if (ec && logger_)
            logger_->error("The write operation failed with error: {:s}", ec.message());
        stop.store(true);
        return;
    }

    if (accept) {
        onChannelReady_(deviceId, channelSocket);
        channelSocket->ready(true);
    }
}

/// Decode every message of a control packet (protocol messages or ChannelRequests).
void
MultiplexedSocket::Impl::handleControlPacket(std::vector<uint8_t>&& pkt)
{
    // Run this on dedicated thread because some callbacks can take time
    dht::ThreadPool::io().run([w = parent_.weak(), pkt = std::move(pkt)]() {
        auto shared = w.lock();
        if (!shared)
            return;
        auto& pimpl = *shared->pimpl_;
        try {
            size_t off = 0;
            while (off != pkt.size()) {
                msgpack::unpacked result;
                msgpack::unpack(result, (const char*) pkt.data(), pkt.size(), off);
                auto object = result.get();
                if (pimpl.handleProtocolMsg(object))
                    continue;
                auto req = object.as<ChannelRequest>();
                if (req.state == ChannelRequestState::ACCEPT) {
                    pimpl.onAccept(req.name, req.channel);
                } else if (req.state == ChannelRequestState::DECLINE) {
                    std::lock_guard<std::mutex> lkSockets(pimpl.socketsMutex);
                    auto channel = pimpl.sockets.find(req.channel);
                    if (channel != pimpl.sockets.end()) {
                        if (channel->second) {
                            channel->second->ready(false);
                            channel->second->stop();
                        }
                        pimpl.sockets.erase(channel);
                    }
                } else if (pimpl.onRequest_) {
                    pimpl.onRequest(req.name, req.channel);
                }
            }
        } catch (const std::exception& e) {
            if (pimpl.logger_)
                pimpl.logger_->error("Error on the control channel: {}", e.what());
        }
    });
}

/**
 * Deliver a data packet to its channel. An empty packet means the peer closed the
 * channel: the socket is stopped and removed (or flagged removable if not answered yet).
 */
void
MultiplexedSocket::Impl::handleChannelPacket(uint16_t channel, std::vector<uint8_t>&& pkt)
{
    std::lock_guard<std::mutex> lkSockets(socketsMutex);
    auto sockIt = sockets.find(channel);
    if (channel > 0 && sockIt != sockets.end() && sockIt->second) {
        if (pkt.size() == 0) {
            sockIt->second->stop();
            if (sockIt->second->isAnswered())
                sockets.erase(sockIt);
            else
                sockIt->second->removable(); // This means that onAccept didn't happen yet, will be
                                             // removed later.
        } else {
            sockIt->second->onRecv(std::move(pkt));
        }
    } else if (pkt.size() != 0) {
        if (logger_)
            logger_->warn("Non existing channel: {}", channel);
    }
}

/**
 * Handle a beacon ("p") or version ("v") map message.
 * @return true if the object was one of those messages, false otherwise.
 */
bool
MultiplexedSocket::Impl::handleProtocolMsg(const msgpack::object& o)
{
    try {
        if (o.type == msgpack::type::MAP && o.via.map.size > 0) {
            auto key = o.via.map.ptr[0].key.as<std::string_view>();
            if (key == "p") {
                auto msg = o.as<BeaconMsg>();
                if (msg.p)
                    handleBeaconRequest();
                else
                    handleBeaconResponse();
                if (onBeaconCb_)
                    onBeaconCb_(msg.p);
                return true;
            } else if (key == "v") {
                auto msg = o.as<VersionMsg>();
                onVersion(msg.v);
                if (onVersionCb_)
                    onVersionCb_(msg.v);
                return true;
            } else {
                if (logger_)
                    logger_->warn("Unknown message type");
            }
        }
    } catch (const std::exception& e) {
        if (logger_)
            logger_->error("Error on the protocol channel: {}", e.what());
    }
    return false;
}

/// Decode and handle every protocol message contained in a PROTOCOL_CHANNEL packet.
void
MultiplexedSocket::Impl::handleProtocolPacket(std::vector<uint8_t>&& pkt)
{
    // Run this on dedicated thread because some callbacks can take time
    dht::ThreadPool::io().run([w = parent_.weak(), pkt = std::move(pkt)]() {
        auto shared = w.lock();
        if (!shared)
            return;
        try {
            size_t off = 0;
            while (off != pkt.size()) {
                msgpack::unpacked result;
                msgpack::unpack(result, (const char*) pkt.data(), pkt.size(), off);
                auto object = result.get();
                // Keep going: a packet may carry several messages, and stopping after
                // the first handled one would silently drop the rest.
                shared->pimpl_->handleProtocolMsg(object);
            }
        } catch (const std::exception& e) {
            if (shared->pimpl_->logger_)
                shared->pimpl_->logger_->error("Error on the protocol channel: {}", e.what());
        }
    });
}

MultiplexedSocket::MultiplexedSocket(std::shared_ptr<asio::io_context> ctx,
                                     const DeviceId& deviceId,
                                     std::unique_ptr<TlsSocketEndpoint> endpoint)
    : pimpl_(std::make_unique<Impl>(*this, ctx, deviceId, std::move(endpoint)))
{}

MultiplexedSocket::~MultiplexedSocket() {}

/// Open a locally-initiated channel on a free, randomly chosen channel number.
std::shared_ptr<ChannelSocket>
MultiplexedSocket::addChannel(const std::string& name)
{
    // Note: because both sides can request the same channel number at the same time
    // it's better to use a random channel number instead of just incrementing the request.
    thread_local dht::crypto::random_device rd;
    std::uniform_int_distribution<uint16_t> dist;
    auto offset = dist(rd);
    std::lock_guard<std::mutex> lk(pimpl_->socketsMutex);
    for (int i = 1; i < UINT16_MAX; ++i) {
        auto c = (offset + i) % UINT16_MAX;
        if (c == CONTROL_CHANNEL || c == PROTOCOL_CHANNEL
            || pimpl_->sockets.find(c) != pimpl_->sockets.end())
            continue;
        auto channel = pimpl_->makeSocket(name, c, true);
        return channel;
    }
    return {};
}

DeviceId
MultiplexedSocket::deviceId() const
{
    return pimpl_->deviceId;
}

void
MultiplexedSocket::setOnReady(OnConnectionReadyCb&& cb)
{
    pimpl_->onChannelReady_ = std::move(cb);
}

void
MultiplexedSocket::setOnRequest(OnConnectionRequestCb&& cb)
{
    pimpl_->onRequest_ = std::move(cb);
}

bool
MultiplexedSocket::isReliable() const
{
    return true;
}

bool
MultiplexedSocket::isInitiator() const
{
    if (!pimpl_->endpoint) {
        if (pimpl_->logger_)
            pimpl_->logger_->warn("No endpoint found for socket");
        return false;
    }
    return pimpl_->endpoint->isInitiator();
}

int
MultiplexedSocket::maxPayload() const
{
    if (!pimpl_->endpoint) {
        if (pimpl_->logger_)
            pimpl_->logger_->warn("No endpoint found for socket");
        return 0;
    }
    return pimpl_->endpoint->maxPayload();
}

/**
 * Send @p len bytes on @p channel, framed as a [channel, bin] msgpack array.
 * Payloads under 8192 bytes are packed with the header into one write; larger ones
 * are written header-then-body while holding writeMtx. A failed write shuts the socket down.
 * @return bytes written, or -1 (as size_t) with @p ec set on error.
 */
std::size_t
MultiplexedSocket::write(const uint16_t& channel,
                         const uint8_t* buf,
                         std::size_t len,
                         std::error_code& ec)
{
    assert(nullptr != buf);

    if (pimpl_->isShutdown_) {
        ec = std::make_error_code(std::errc::broken_pipe);
        return -1;
    }
    if (len > UINT16_MAX) {
        ec = std::make_error_code(std::errc::message_size);
        return -1;
    }
    bool oneShot = len < 8192;
    msgpack::sbuffer buffer(oneShot ? 16 + len : 16);
    msgpack::packer<msgpack::sbuffer> pk(&buffer);
    pk.pack_array(2);
    pk.pack(channel);
    pk.pack_bin(len);
    if (oneShot)
        pk.pack_bin_body((const char*) buf, len);

    std::unique_lock<std::mutex> lk(pimpl_->writeMtx);
    if (!pimpl_->endpoint) {
        if (pimpl_->logger_)
            pimpl_->logger_->warn("No endpoint found for socket");
        ec = std::make_error_code(std::errc::broken_pipe);
        return -1;
    }
    int res = pimpl_->endpoint->write((const unsigned char*) buffer.data(), buffer.size(), ec);
    if (not oneShot and res >= 0)
        res = pimpl_->endpoint->write(buf, len, ec);
    lk.unlock();
    if (res < 0) {
        if (ec && pimpl_->logger_)
            pimpl_->logger_->error("Error when writing on socket: {:s}", ec.message());
        shutdown();
    }
    return res;
}

void
MultiplexedSocket::shutdown()
{
    pimpl_->shutdown();
}

void
MultiplexedSocket::join()
{
    pimpl_->join();
}

/// Register the shutdown callback; called immediately if the socket is already shut down.
void
MultiplexedSocket::onShutdown(OnShutdownCb&& cb)
{
    pimpl_->onShutdown_ = std::move(cb);
    if (pimpl_->isShutdown_ && pimpl_->onShutdown_)
        pimpl_->onShutdown_();
}

const std::shared_ptr<Logger>&
MultiplexedSocket::logger()
{
    return pimpl_->logger_;
}

/// Dump socket and channel state to the debug log (no-op without a logger or peer cert).
void
MultiplexedSocket::monitor() const
{
    auto cert = peerCertificate();
    if (!cert || !cert->issuer)
        return;
    auto now = clock::now();
    if (!pimpl_->logger_)
        return;
    pimpl_->logger_->debug("- Socket with device: {:s} - account: {:s}",
                           deviceId(),
                           cert->issuer->getId());
    pimpl_->logger_->debug("- Duration: {}", dht::print_duration(now - pimpl_->start_));
    pimpl_->endpoint->monitor();
    std::lock_guard<std::mutex> lk(pimpl_->socketsMutex);
    for (const auto& [_, channel] : pimpl_->sockets) {
        if (channel)
            pimpl_->logger_->debug("\t\t- Channel {} (count: {}) with name {:s} Initiator: {}",
                                   fmt::ptr(channel.get()),
                                   channel.use_count(),
                                   channel->name(),
                                   channel->isInitiator());
    }
}

void
MultiplexedSocket::sendBeacon(const std::chrono::milliseconds& timeout)
{
    pimpl_->sendBeacon(timeout);
}

/// Peer certificate of the TLS endpoint, or null if there is no endpoint.
std::shared_ptr<dht::crypto::Certificate>
MultiplexedSocket::peerCertificate() const
{
    if (!pimpl_->endpoint)
        return {};
    return pimpl_->endpoint->peerCertificate();
}

#ifdef LIBJAMI_TESTABLE
bool MultiplexedSocket::canSendBeacon() const { return pimpl_->canSendBeacon_; } void MultiplexedSocket::answerToBeacon(bool value) { pimpl_->answerBeacon_ = value; } void MultiplexedSocket::setVersion(int version) { pimpl_->version_ = version; } void MultiplexedSocket::setOnBeaconCb(const std::function<void(bool)>& cb) { pimpl_->onBeaconCb_ = cb; } void MultiplexedSocket::setOnVersionCb(const std::function<void(int)>& cb) { pimpl_->onVersionCb_ = cb; } void MultiplexedSocket::sendVersion() { pimpl_->sendVersion(); } #endif IpAddr MultiplexedSocket::getLocalAddress() const { return pimpl_->endpoint->getLocalAddress(); } IpAddr MultiplexedSocket::getRemoteAddress() const { return pimpl_->endpoint->getRemoteAddress(); } void MultiplexedSocket::eraseChannel(uint16_t channel) { std::lock_guard<std::mutex> lkSockets(pimpl_->socketsMutex); auto itSocket = pimpl_->sockets.find(channel); if (pimpl_->sockets.find(channel) != pimpl_->sockets.end()) pimpl_->sockets.erase(itSocket); } //////////////////////////////////////////////////////////////// class ChannelSocket::Impl { public: Impl(std::weak_ptr<MultiplexedSocket> endpoint, const std::string& name, const uint16_t& channel, bool isInitiator, std::function<void()> rmFromMxSockCb) : name(name) , channel(channel) , endpoint(std::move(endpoint)) , isInitiator_(isInitiator) , rmFromMxSockCb_(std::move(rmFromMxSockCb)) {} ~Impl() {} ChannelReadyCb readyCb_ {}; OnShutdownCb shutdownCb_ {}; std::atomic_bool isShutdown_ {false}; std::string name {}; uint16_t channel {}; std::weak_ptr<MultiplexedSocket> endpoint {}; bool isInitiator_ {false}; std::function<void()> rmFromMxSockCb_; bool isAnswered_ {false}; bool isRemovable_ {false}; std::vector<uint8_t> buf {}; std::mutex mutex {}; std::condition_variable cv {}; GenericSocket<uint8_t>::RecvCb cb {}; }; ChannelSocketTest::ChannelSocketTest(std::shared_ptr<asio::io_context> ctx, const DeviceId& deviceId, const std::string& name, const uint16_t& channel) : pimpl_deviceId(deviceId) , 
pimpl_name(name) , pimpl_channel(channel) , ioCtx_(*ctx) {} ChannelSocketTest::~ChannelSocketTest() {} void ChannelSocketTest::link(const std::shared_ptr<ChannelSocketTest>& socket1, const std::shared_ptr<ChannelSocketTest>& socket2) { socket1->remote = socket2; socket2->remote = socket1; } DeviceId ChannelSocketTest::deviceId() const { return pimpl_deviceId; } std::string ChannelSocketTest::name() const { return pimpl_name; } uint16_t ChannelSocketTest::channel() const { return pimpl_channel; } void ChannelSocketTest::shutdown() { { std::unique_lock<std::mutex> lk {mutex}; if (!isShutdown_.exchange(true)) { lk.unlock(); shutdownCb_(); } cv.notify_all(); } if (auto peer = remote.lock()) { if (!peer->isShutdown_.exchange(true)) { peer->shutdownCb_(); } peer->cv.notify_all(); } } std::size_t ChannelSocketTest::read(ValueType* buf, std::size_t len, std::error_code& ec) { std::size_t size = std::min(len, this->rx_buf.size()); for (std::size_t i = 0; i < size; ++i) buf[i] = this->rx_buf[i]; if (size == this->rx_buf.size()) { this->rx_buf.clear(); } else this->rx_buf.erase(this->rx_buf.begin(), this->rx_buf.begin() + size); return size; } std::size_t ChannelSocketTest::write(const ValueType* buf, std::size_t len, std::error_code& ec) { if (isShutdown_) { ec = std::make_error_code(std::errc::broken_pipe); return -1; } ec = {}; dht::ThreadPool::computation().run( [r = remote, data = std::vector<uint8_t>(buf, buf + len)]() mutable { if (auto peer = r.lock()) peer->onRecv(std::move(data)); }); return len; } int ChannelSocketTest::waitForData(std::chrono::milliseconds timeout, std::error_code& ec) const { std::unique_lock<std::mutex> lk {mutex}; cv.wait_for(lk, timeout, [&] { return !rx_buf.empty() or isShutdown_; }); return rx_buf.size(); } void ChannelSocketTest::setOnRecv(RecvCb&& cb) { std::lock_guard<std::mutex> lkSockets(mutex); this->cb = std::move(cb); if (!rx_buf.empty() && this->cb) { this->cb(rx_buf.data(), rx_buf.size()); rx_buf.clear(); } } void 
ChannelSocketTest::onRecv(std::vector<uint8_t>&& pkt) { std::lock_guard<std::mutex> lkSockets(mutex); if (cb) { cb(pkt.data(), pkt.size()); return; } rx_buf.insert(rx_buf.end(), std::make_move_iterator(pkt.begin()), std::make_move_iterator(pkt.end())); cv.notify_all(); } void ChannelSocketTest::onReady(ChannelReadyCb&& cb) {} void ChannelSocketTest::onShutdown(OnShutdownCb&& cb) { std::unique_lock<std::mutex> lk {mutex}; shutdownCb_ = std::move(cb); if (isShutdown_) { lk.unlock(); shutdownCb_(); } } ChannelSocket::ChannelSocket(std::weak_ptr<MultiplexedSocket> endpoint, const std::string& name, const uint16_t& channel, bool isInitiator, std::function<void()> rmFromMxSockCb) : pimpl_ { std::make_unique<Impl>(endpoint, name, channel, isInitiator, std::move(rmFromMxSockCb))} {} ChannelSocket::~ChannelSocket() {} DeviceId ChannelSocket::deviceId() const { if (auto ep = pimpl_->endpoint.lock()) { return ep->deviceId(); } return {}; } std::string ChannelSocket::name() const { return pimpl_->name; } uint16_t ChannelSocket::channel() const { return pimpl_->channel; } bool ChannelSocket::isReliable() const { if (auto ep = pimpl_->endpoint.lock()) { return ep->isReliable(); } return false; } bool ChannelSocket::isInitiator() const { // Note. Is initiator here as not the same meaning of MultiplexedSocket. // because a multiplexed socket can have sockets from accepted requests // or made via connectDevice(). Here, isInitiator_ return if the socket // is from connectDevice. 
return pimpl_->isInitiator_; } int ChannelSocket::maxPayload() const { if (auto ep = pimpl_->endpoint.lock()) { return ep->maxPayload(); } return -1; } void ChannelSocket::setOnRecv(RecvCb&& cb) { std::lock_guard<std::mutex> lkSockets(pimpl_->mutex); pimpl_->cb = std::move(cb); if (!pimpl_->buf.empty() && pimpl_->cb) { pimpl_->cb(pimpl_->buf.data(), pimpl_->buf.size()); pimpl_->buf.clear(); } } void ChannelSocket::onRecv(std::vector<uint8_t>&& pkt) { std::lock_guard<std::mutex> lkSockets(pimpl_->mutex); if (pimpl_->cb) { pimpl_->cb(&pkt[0], pkt.size()); return; } pimpl_->buf.insert(pimpl_->buf.end(), std::make_move_iterator(pkt.begin()), std::make_move_iterator(pkt.end())); pimpl_->cv.notify_all(); } #ifdef LIBJAMI_TESTABLE std::shared_ptr<MultiplexedSocket> ChannelSocket::underlyingSocket() const { if (auto mtx = pimpl_->endpoint.lock()) return mtx; return {}; } #endif void ChannelSocket::answered() { pimpl_->isAnswered_ = true; } void ChannelSocket::removable() { pimpl_->isRemovable_ = true; } bool ChannelSocket::isRemovable() const { return pimpl_->isRemovable_; } bool ChannelSocket::isAnswered() const { return pimpl_->isAnswered_; } void ChannelSocket::ready(bool accepted) { if (pimpl_->readyCb_) pimpl_->readyCb_(accepted); } void ChannelSocket::stop() { if (pimpl_->isShutdown_) return; pimpl_->isShutdown_ = true; if (pimpl_->shutdownCb_) pimpl_->shutdownCb_(); pimpl_->cv.notify_all(); // stop() can be called by ChannelSocket::shutdown() // In this case, the eventLoop is not used, but MxSock // must remove the channel from its list (so that the // channel can be destroyed and its shared_ptr invalidated). 
if (pimpl_->rmFromMxSockCb_) pimpl_->rmFromMxSockCb_(); } void ChannelSocket::shutdown() { if (pimpl_->isShutdown_) return; stop(); if (auto ep = pimpl_->endpoint.lock()) { std::error_code ec; const uint8_t dummy = '\0'; ep->write(pimpl_->channel, &dummy, 0, ec); } } std::size_t ChannelSocket::read(ValueType* outBuf, std::size_t len, std::error_code& ec) { std::lock_guard<std::mutex> lkSockets(pimpl_->mutex); std::size_t size = std::min(len, pimpl_->buf.size()); for (std::size_t i = 0; i < size; ++i) outBuf[i] = pimpl_->buf[i]; pimpl_->buf.erase(pimpl_->buf.begin(), pimpl_->buf.begin() + size); return size; } std::size_t ChannelSocket::write(const ValueType* buf, std::size_t len, std::error_code& ec) { if (pimpl_->isShutdown_) { ec = std::make_error_code(std::errc::broken_pipe); return -1; } if (auto ep = pimpl_->endpoint.lock()) { std::size_t sent = 0; do { std::size_t toSend = std::min(static_cast<std::size_t>(UINT16_MAX), len - sent); auto res = ep->write(pimpl_->channel, buf + sent, toSend, ec); if (ec) { if (ep->logger()) ep->logger()->error("Error when writing on channel: {}", ec.message()); return res; } sent += toSend; } while (sent < len); return sent; } ec = std::make_error_code(std::errc::broken_pipe); return -1; } int ChannelSocket::waitForData(std::chrono::milliseconds timeout, std::error_code& ec) const { std::unique_lock<std::mutex> lk {pimpl_->mutex}; pimpl_->cv.wait_for(lk, timeout, [&] { return !pimpl_->buf.empty() or pimpl_->isShutdown_; }); return pimpl_->buf.size(); } void ChannelSocket::onShutdown(OnShutdownCb&& cb) { pimpl_->shutdownCb_ = std::move(cb); if (pimpl_->isShutdown_) { pimpl_->shutdownCb_(); } } void ChannelSocket::onReady(ChannelReadyCb&& cb) { pimpl_->readyCb_ = std::move(cb); } void ChannelSocket::sendBeacon(const std::chrono::milliseconds& timeout) { if (auto ep = pimpl_->endpoint.lock()) { ep->sendBeacon(timeout); } else { shutdown(); } } std::shared_ptr<dht::crypto::Certificate> ChannelSocket::peerCertificate() const { if 
(auto ep = pimpl_->endpoint.lock()) return ep->peerCertificate(); return {}; } IpAddr ChannelSocket::getLocalAddress() const { if (auto ep = pimpl_->endpoint.lock()) return ep->getLocalAddress(); return {}; } IpAddr ChannelSocket::getRemoteAddress() const { if (auto ep = pimpl_->endpoint.lock()) return ep->getRemoteAddress(); return {}; } } // namespace jami