From e781b9c73aec2c8ea4f6b9f622eeb579580634a8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?S=C3=A9bastien=20Blin?=
 <sebastien.blin@savoirfairelinux.com>
Date: Tue, 16 Jun 2020 12:15:09 -0400
Subject: [PATCH] multiplexedsocket: various fixes

Check for endpoint existance to avoid nullptr exception.
Avoid to write multiple packets at the same time on the socket. This avoid
to mix packet.
Avoid to be able to shut the channel multiple times

Change-Id: If5158b51f55f368091616062ced4d641130c8468
---
 src/jamidht/multiplexed_socket.cpp | 39 ++++++++++++++++++------------
 src/jamidht/multiplexed_socket.h   |  5 ++++
 2 files changed, 28 insertions(+), 16 deletions(-)

diff --git a/src/jamidht/multiplexed_socket.cpp b/src/jamidht/multiplexed_socket.cpp
index 4af6ba90ca..81daa39937 100644
--- a/src/jamidht/multiplexed_socket.cpp
+++ b/src/jamidht/multiplexed_socket.cpp
@@ -67,17 +67,18 @@ public:
         stop.store(true);
         isShutdown_ = true;
         if (onShutdown_) onShutdown_();
-        endpoint->setOnStateChange({});
-        endpoint->shutdown();
-        {
-            std::lock_guard<std::mutex> lkSockets(socketsMutex);
-            for (auto& socket : sockets) {
-                // Just trigger onShutdown() to make client know
-                // No need to write the EOF for the channel, the write will fail because endpoint is already shutdown
-                if (socket.second) socket.second->stop();
-            }
-            sockets.clear();
+        if (endpoint) {
+            endpoint->setOnStateChange({});
+            std::unique_lock<std::mutex> lk(writeMtx);
+            endpoint->shutdown();
         }
+        std::lock_guard<std::mutex> lkSockets(socketsMutex);
+        for (auto& socket : sockets) {
+            // Just trigger onShutdown() to make client know
+            // No need to write the EOF for the channel, the write will fail because endpoint is already shutdown
+            if (socket.second) socket.second->stop();
+        }
+        sockets.clear();
     }
 
     /**
@@ -123,6 +124,8 @@ public:
     std::mutex channelCbsMtx_ {};
     std::map<uint16_t, GenericSocket<uint8_t>::RecvCb> channelCbs_ {};
     std::atomic_bool isShutdown_ {false};
+
+    std::mutex writeMtx;
 };
 
 void
@@ -160,11 +163,10 @@ MultiplexedSocket::Impl::eventLoop()
             try {
                 auto msg = oh.get().as<ChanneledMessage>();
 
-                if (msg.channel == 0) {
+                if (msg.channel == 0)
                     handleControlPacket(std::move(msg.data));
-                } else {
+                else
                     handleChannelPacket(msg.channel, std::move(msg.data));
-                }
             } catch (const msgpack::unpack_error &e) {
                 JAMI_WARN("Error when decoding msgpack message: %s", e.what());
             }
@@ -273,7 +275,7 @@ MultiplexedSocket::Impl::handleChannelPacket(uint16_t channel, const std::vector
             auto cb = channelCbs_.find(channel);
             if (cb != channelCbs_.end()) {
                 lk.unlock();
-                cb->second(&pkt[0], pkt.size());
+                if (cb->second) cb->second(&pkt[0], pkt.size());
                 return;
             }
             lk.unlock();
@@ -413,7 +415,9 @@ MultiplexedSocket::write(const uint16_t& channel, const uint8_t* buf, std::size_
     pk.pack_bin(len);
     pk.pack_bin_body((const char*)buf, len);
 
+    std::unique_lock<std::mutex> lk(pimpl_->writeMtx);
     int res = pimpl_->endpoint->write((const unsigned char*)buffer.data(), buffer.size(), ec);
+    lk.unlock();
     if (res < 0) {
         if (ec)
             JAMI_ERR("Error when writing on socket: %s", ec.message().c_str());
@@ -481,7 +485,7 @@ public:
     Impl(std::weak_ptr<MultiplexedSocket> endpoint, const std::string& name, const uint16_t& channel)
     : name(name), channel(channel), endpoint(std::move(endpoint)) {}
 
-    ~Impl() {}
+    ~Impl() { }
 
     OnShutdownCb shutdownCb_;
     std::atomic_bool isShutdown_ {false};
@@ -564,13 +568,16 @@ ChannelSocket::underlyingICE() const
 void
 ChannelSocket::stop()
 {
+    if (pimpl_->isShutdown_) return;
     pimpl_->isShutdown_ = true;
-    if (pimpl_->shutdownCb_) pimpl_->shutdownCb_();
+    if (pimpl_->shutdownCb_)
+        pimpl_->shutdownCb_();
 }
 
 void
 ChannelSocket::shutdown()
 {
+    if (pimpl_->isShutdown_) return;
     stop();
     if (auto ep = pimpl_->endpoint.lock()) {
         std::error_code ec;
diff --git a/src/jamidht/multiplexed_socket.h b/src/jamidht/multiplexed_socket.h
index 4f4767ef6f..101f50bb51 100644
--- a/src/jamidht/multiplexed_socket.h
+++ b/src/jamidht/multiplexed_socket.h
@@ -158,6 +158,11 @@ public:
     std::size_t write(const ValueType* buf, std::size_t len, std::error_code& ec) override;
     int waitForData(std::chrono::milliseconds timeout, std::error_code&) const override;
 
+    /**
+     * set a callback when receiving data
+     * @note: this callback should take a little time and not block
+     * but you can move it in a thread
+     */
     void setOnRecv(RecvCb&&) override;
 
     std::shared_ptr<IceTransport> underlyingICE() const;
-- 
GitLab