diff --git a/src/connectivity/multiplexed_socket.cpp b/src/connectivity/multiplexed_socket.cpp index 54ec39cf005915b074b7dbe1e051657ba79d7a46..3ab3eb35db352464de8c8cc689dd6eadaf462a6c 100644 --- a/src/connectivity/multiplexed_socket.cpp +++ b/src/connectivity/multiplexed_socket.cpp @@ -129,7 +129,16 @@ public: channelSocket = std::make_shared<ChannelSocket>(parent_.weak(), name, channel, - isInitiator); + isInitiator, + [w = parent_.weak(), channel]() { + // Remove socket in another thread to avoid any lock + dht::ThreadPool::io().run([w, channel]() { + if (auto shared = w.lock()) { + shared->eraseChannel(channel); + } + }); + + }); else { JAMI_WARN("A channel is already present on that socket, accepting " "the request will close the previous one %s", @@ -440,7 +449,7 @@ MultiplexedSocket::Impl::handleChannelPacket(uint16_t channel, std::vector<uint8 { std::lock_guard<std::mutex> lkSockets(socketsMutex); auto sockIt = sockets.find(channel); - if (channel > 0 && sockIt->second) { + if (channel > 0 && sockIt != sockets.end() && sockIt->second) { if (pkt.size() == 0) { sockIt->second->stop(); if (sockIt->second->isAnswered()) @@ -659,7 +668,7 @@ MultiplexedSocket::monitor() const std::lock_guard<std::mutex> lk(pimpl_->socketsMutex); for (const auto& [_, channel] : pimpl_->sockets) { if (channel) - JAMI_DEBUG("\t\t- Channel with name {:s}", channel->name()); + JAMI_DEBUG("\t\t- Channel {} (count: {}) with name {:s} Initiator: {}", fmt::ptr(channel.get()), channel.use_count(), channel->name(), channel->isInitiator()); } } @@ -726,6 +735,15 @@ MultiplexedSocket::getRemoteAddress() const #endif +void +MultiplexedSocket::eraseChannel(uint16_t channel) +{ + std::lock_guard<std::mutex> lkSockets(pimpl_->socketsMutex); + auto itSocket = pimpl_->sockets.find(channel); + if (pimpl_->sockets.find(channel) != pimpl_->sockets.end()) + pimpl_->sockets.erase(itSocket); +} + //////////////////////////////////////////////////////////////// class ChannelSocket::Impl @@ -734,11 +752,13 @@ public: Impl(std::weak_ptr<MultiplexedSocket> endpoint, const std::string& name, const uint16_t& channel, - bool isInitiator) + bool isInitiator, + std::function<void()> rmFromMxSockCb) : name(name) , channel(channel) , endpoint(std::move(endpoint)) , isInitiator_(isInitiator) + , rmFromMxSockCb_(std::move(rmFromMxSockCb)) {} ~Impl() {} @@ -750,6 +770,7 @@ public: uint16_t channel {}; std::weak_ptr<MultiplexedSocket> endpoint {}; bool isInitiator_ {false}; + std::function<void()> rmFromMxSockCb_; bool isAnswered_ {false}; bool isRemovable_ {false}; @@ -940,8 +961,9 @@ ChannelSocketTest::eventLoop() ChannelSocket::ChannelSocket(std::weak_ptr<MultiplexedSocket> endpoint, const std::string& name, const uint16_t& channel, - bool isInitiator) - : pimpl_ {std::make_unique<Impl>(endpoint, name, channel, isInitiator)} + bool isInitiator, + std::function<void()> rmFromMxSockCb) + : pimpl_ {std::make_unique<Impl>(endpoint, name, channel, isInitiator, std::move(rmFromMxSockCb))} {} ChannelSocket::~ChannelSocket() {} @@ -1070,6 +1092,12 @@ ChannelSocket::stop() if (pimpl_->shutdownCb_) pimpl_->shutdownCb_(); pimpl_->cv.notify_all(); + // stop() can be called by ChannelSocket::shutdown() + // In this case, the eventLoop is not used, but MxSock + // must remove the channel from its list (so that the + // channel can be destroyed and its shared_ptr invalidated). + if (pimpl_->rmFromMxSockCb_) + pimpl_->rmFromMxSockCb_(); } void diff --git a/src/connectivity/multiplexed_socket.h b/src/connectivity/multiplexed_socket.h index 91f0f616eb9556ee09e859286a49ff62791e9535..6f1bad9c89125d81fd034d34eab0513012dd4c3a 100644 --- a/src/connectivity/multiplexed_socket.h +++ b/src/connectivity/multiplexed_socket.h @@ -17,6 +17,7 @@ */ #pragma once +#include <cstdint> #include <opendht/default_types.h> #include "connectivity/generic_io.h" @@ -138,6 +139,8 @@ public: IpAddr getLocalAddress() const; IpAddr getRemoteAddress() const; + void eraseChannel(uint16_t channel); + #ifdef LIBJAMI_TESTABLE /** * Check if we can send beacon on the socket @@ -263,7 +266,8 @@ public: ChannelSocket(std::weak_ptr<MultiplexedSocket> endpoint, const std::string& name, const uint16_t& channel, - bool isInitiator = false); + bool isInitiator = false, + std::function<void()> rmFromMxSockCb = {}); ~ChannelSocket(); DeviceId deviceId() const override;