From ef23ce0c3a54791e983535c89a5f36c167cacdb0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?S=C3=A9bastien=20Blin?=
 <sebastien.blin@savoirfairelinux.com>
Date: Mon, 17 May 2021 16:13:21 -0400
Subject: [PATCH] multiplexed_socket: use shared_ptr for ChannelInfo

This fix a potential crash into waitForData where the channel is deleted
after the check and before locking. Also notify when deleting. Also
add isDestroying on channel to avoid to wait indefinitely on destroyed channel

GitLab: #549
Change-Id: Ib37c4bab31def4ca6f594fc1338a018138217765
---
 src/jamidht/multiplexed_socket.cpp | 47 ++++++++++++++++++------------
 1 file changed, 29 insertions(+), 18 deletions(-)

diff --git a/src/jamidht/multiplexed_socket.cpp b/src/jamidht/multiplexed_socket.cpp
index 8981f66028..504e3db519 100644
--- a/src/jamidht/multiplexed_socket.cpp
+++ b/src/jamidht/multiplexed_socket.cpp
@@ -55,6 +55,7 @@ using time_point = clock::time_point;
 
 struct ChannelInfo
 {
+    std::atomic_bool isDestroying {false};
     std::deque<uint8_t> buf {};
     std::mutex mutex {};
     std::condition_variable cv {};
@@ -99,8 +100,10 @@ public:
             std::lock_guard<std::mutex> lkSockets(socketsMutex);
             socks = std::move(sockets);
             for (auto& [key, channelData] : channelDatas_) {
-                if (channelData)
+                if (channelData) {
+                    channelData->isDestroying = true;
                     channelData->cv.notify_all();
+                }
             }
         }
         for (auto& socket : socks) {
@@ -180,7 +183,7 @@ public:
     std::thread eventLoopThread_ {};
 
     // Multiplexed available datas
-    std::map<uint16_t, std::unique_ptr<ChannelInfo>> channelDatas_ {};
+    std::map<uint16_t, std::shared_ptr<ChannelInfo>> channelDatas_ {};
     std::mutex channelCbsMtx_ {};
     std::map<uint16_t, GenericSocket<uint8_t>::RecvCb> channelCbs_ {};
     std::atomic_bool isShutdown_ {false};
@@ -255,7 +258,7 @@ MultiplexedSocket::Impl::onAccept(const std::string& name, uint16_t channel)
     std::lock_guard<std::mutex> lkSockets(socketsMutex);
     auto& channelData = channelDatas_[channel];
     if (not channelData)
-        channelData = std::make_unique<ChannelInfo>();
+        channelData = std::make_shared<ChannelInfo>();
     auto& channelSocket = sockets[channel];
     if (not channelSocket)
         channelSocket = std::make_shared<ChannelSocket>(parent_.weak(), name, channel);
@@ -371,7 +374,7 @@ MultiplexedSocket::Impl::onRequest(const std::string& name, uint16_t channel)
                           "the request will close the previous one");
                 sockets.erase(sockIt);
             }
-            channelDatas_.emplace(channel, std::make_unique<ChannelInfo>());
+            channelDatas_.emplace(channel, std::make_shared<ChannelInfo>());
             sockets.emplace(channel, channelSocket);
         }
     }
@@ -419,17 +422,24 @@ MultiplexedSocket::Impl::handleControlPacket(std::vector<uint8_t>&& pkt)
                 msgpack::unpacked result;
                 msgpack::unpack(result, (const char*) pkt.data(), pkt.size(), off);
                 auto object = result.get();
-                if (shared->pimpl_->handleProtocolMsg(object))
+                auto& pimpl = *shared->pimpl_;
+                if (pimpl.handleProtocolMsg(object))
                     continue;
                 auto req = object.as<ChannelRequest>();
                 if (req.state == ChannelRequestState::ACCEPT) {
-                    shared->pimpl_->onAccept(req.name, req.channel);
+                    pimpl.onAccept(req.name, req.channel);
                 } else if (req.state == ChannelRequestState::DECLINE) {
-                    std::lock_guard<std::mutex> lkSockets(shared->pimpl_->socketsMutex);
-                    shared->pimpl_->channelDatas_.erase(req.channel);
-                    shared->pimpl_->sockets.erase(req.channel);
-                } else if (shared->pimpl_->onRequest_) {
-                    shared->pimpl_->onRequest(req.name, req.channel);
+                    std::lock_guard<std::mutex> lkSockets(pimpl.socketsMutex);
+                    auto& channelDatas = pimpl.channelDatas_;
+                    auto dataIt = channelDatas.find(req.channel);
+                    if (dataIt != channelDatas.end() && dataIt->second) {
+                        dataIt->second->isDestroying = true;
+                        dataIt->second->cv.notify_all();
+                        channelDatas.erase(dataIt);
+                    }
+                    pimpl.sockets.erase(req.channel);
+                } else if (pimpl.onRequest_) {
+                    pimpl.onRequest(req.name, req.channel);
                 }
             }
         } catch (const std::exception& e) {
@@ -447,6 +457,7 @@ MultiplexedSocket::Impl::handleChannelPacket(uint16_t channel, std::vector<uint8
     if (channel > 0 && sockIt->second && dataIt->second) {
         if (pkt.size() == 0) {
             sockIt->second->shutdown();
+            dataIt->second->isDestroying = true;
             dataIt->second->cv.notify_all();
             channelDatas_.erase(dataIt);
             sockets.erase(sockIt);
@@ -555,7 +566,7 @@ MultiplexedSocket::addChannel(const std::string& name)
         if (!socket) {
             auto& channel = pimpl_->channelDatas_[c];
             if (!channel)
-                channel = std::make_unique<ChannelInfo>();
+                channel = std::make_shared<ChannelInfo>();
             socket = std::make_shared<ChannelSocket>(weak(), name, c);
             return socket;
         }
@@ -691,18 +702,17 @@ MultiplexedSocket::waitForData(const uint16_t& channel,
         ec = std::make_error_code(std::errc::broken_pipe);
         return -1;
     }
+    std::unique_lock lkSockets {pimpl_->socketsMutex};
     auto dataIt = pimpl_->channelDatas_.find(channel);
-    if (dataIt == pimpl_->channelDatas_.end()) {
+    if (dataIt == pimpl_->channelDatas_.end() or not dataIt->second) {
         ec = std::make_error_code(std::errc::broken_pipe);
         return -1;
     }
-    auto& channelData = dataIt->second;
-    if (!channelData) {
-        return -1;
-    }
+    auto channelData = dataIt->second;
+    lkSockets.unlock();
     std::unique_lock<std::mutex> lk {channelData->mutex};
     channelData->cv.wait_for(lk, timeout, [&] {
-        return !channelData->buf.empty() or pimpl_->isShutdown_;
+        return channelData->isDestroying or !channelData->buf.empty() or pimpl_->isShutdown_;
     });
     return channelData->buf.size();
 }
@@ -712,6 +722,7 @@ MultiplexedSocket::setOnRecv(const uint16_t& channel, GenericSocket<uint8_t>::Re
 {
     std::deque<uint8_t> recv;
     {
+        // NOTE: here socketsMtx is already locked via onAccept
         std::lock_guard<std::mutex> lk(pimpl_->channelCbsMtx_);
         pimpl_->channelCbs_[channel] = cb;
 
-- 
GitLab