diff --git a/src/jamidht/multiplexed_socket.cpp b/src/jamidht/multiplexed_socket.cpp index 8981f66028ea4129e485c38d237144cb00d3751d..504e3db519f0a6da23fe54aa09282c2253759d6e 100644 --- a/src/jamidht/multiplexed_socket.cpp +++ b/src/jamidht/multiplexed_socket.cpp @@ -55,6 +55,7 @@ using time_point = clock::time_point; struct ChannelInfo { + std::atomic_bool isDestroying {false}; std::deque<uint8_t> buf {}; std::mutex mutex {}; std::condition_variable cv {}; @@ -99,8 +100,10 @@ public: std::lock_guard<std::mutex> lkSockets(socketsMutex); socks = std::move(sockets); for (auto& [key, channelData] : channelDatas_) { - if (channelData) + if (channelData) { + channelData->isDestroying = true; channelData->cv.notify_all(); + } } } for (auto& socket : socks) { @@ -180,7 +183,7 @@ public: std::thread eventLoopThread_ {}; // Multiplexed available datas - std::map<uint16_t, std::unique_ptr<ChannelInfo>> channelDatas_ {}; + std::map<uint16_t, std::shared_ptr<ChannelInfo>> channelDatas_ {}; std::mutex channelCbsMtx_ {}; std::map<uint16_t, GenericSocket<uint8_t>::RecvCb> channelCbs_ {}; std::atomic_bool isShutdown_ {false}; @@ -255,7 +258,7 @@ MultiplexedSocket::Impl::onAccept(const std::string& name, uint16_t channel) std::lock_guard<std::mutex> lkSockets(socketsMutex); auto& channelData = channelDatas_[channel]; if (not channelData) - channelData = std::make_unique<ChannelInfo>(); + channelData = std::make_shared<ChannelInfo>(); auto& channelSocket = sockets[channel]; if (not channelSocket) channelSocket = std::make_shared<ChannelSocket>(parent_.weak(), name, channel); @@ -371,7 +374,7 @@ MultiplexedSocket::Impl::onRequest(const std::string& name, uint16_t channel) "the request will close the previous one"); sockets.erase(sockIt); } - channelDatas_.emplace(channel, std::make_unique<ChannelInfo>()); + channelDatas_.emplace(channel, std::make_shared<ChannelInfo>()); sockets.emplace(channel, channelSocket); } } @@ -419,17 +422,24 @@ MultiplexedSocket::Impl::handleControlPacket(std::vector<uint8_t>&& pkt) msgpack::unpacked result; msgpack::unpack(result, (const char*) pkt.data(), pkt.size(), off); auto object = result.get(); - if (shared->pimpl_->handleProtocolMsg(object)) + auto& pimpl = *shared->pimpl_; + if (pimpl.handleProtocolMsg(object)) continue; auto req = object.as<ChannelRequest>(); if (req.state == ChannelRequestState::ACCEPT) { - shared->pimpl_->onAccept(req.name, req.channel); + pimpl.onAccept(req.name, req.channel); } else if (req.state == ChannelRequestState::DECLINE) { - std::lock_guard<std::mutex> lkSockets(shared->pimpl_->socketsMutex); - shared->pimpl_->channelDatas_.erase(req.channel); - shared->pimpl_->sockets.erase(req.channel); - } else if (shared->pimpl_->onRequest_) { - shared->pimpl_->onRequest(req.name, req.channel); + std::lock_guard<std::mutex> lkSockets(pimpl.socketsMutex); + auto& channelDatas = pimpl.channelDatas_; + auto dataIt = channelDatas.find(req.channel); + if (dataIt != channelDatas.end() && dataIt->second) { + dataIt->second->isDestroying = true; + dataIt->second->cv.notify_all(); + channelDatas.erase(dataIt); + } + pimpl.sockets.erase(req.channel); + } else if (pimpl.onRequest_) { + pimpl.onRequest(req.name, req.channel); } } } catch (const std::exception& e) { @@ -447,6 +457,7 @@ MultiplexedSocket::Impl::handleChannelPacket(uint16_t channel, std::vector<uint8 if (channel > 0 && sockIt->second && dataIt->second) { if (pkt.size() == 0) { sockIt->second->shutdown(); + dataIt->second->isDestroying = true; dataIt->second->cv.notify_all(); channelDatas_.erase(dataIt); sockets.erase(sockIt); @@ -555,7 +566,7 @@ MultiplexedSocket::addChannel(const std::string& name) if (!socket) { auto& channel = pimpl_->channelDatas_[c]; if (!channel) - channel = std::make_unique<ChannelInfo>(); + channel = std::make_shared<ChannelInfo>(); socket = std::make_shared<ChannelSocket>(weak(), name, c); return socket; } @@ -691,18 +702,17 @@ MultiplexedSocket::waitForData(const uint16_t& channel, ec = std::make_error_code(std::errc::broken_pipe); return -1; } + std::unique_lock lkSockets {pimpl_->socketsMutex}; auto dataIt = pimpl_->channelDatas_.find(channel); - if (dataIt == pimpl_->channelDatas_.end()) { + if (dataIt == pimpl_->channelDatas_.end() or not dataIt->second) { ec = std::make_error_code(std::errc::broken_pipe); return -1; } - auto& channelData = dataIt->second; - if (!channelData) { - return -1; - } + auto channelData = dataIt->second; + lkSockets.unlock(); std::unique_lock<std::mutex> lk {channelData->mutex}; channelData->cv.wait_for(lk, timeout, [&] { - return !channelData->buf.empty() or pimpl_->isShutdown_; + return channelData->isDestroying or !channelData->buf.empty() or pimpl_->isShutdown_; }); return channelData->buf.size(); } @@ -712,6 +722,7 @@ MultiplexedSocket::setOnRecv(const uint16_t& channel, GenericSocket<uint8_t>::Re { std::deque<uint8_t> recv; { + // NOTE: here socketsMtx is already locked via onAccept std::lock_guard<std::mutex> lk(pimpl_->channelCbsMtx_); pimpl_->channelCbs_[channel] = cb;