From b0209b799b03014e9c1d1490855a6eba6515f4ab Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?S=C3=A9bastien=20Blin?=
 <sebastien.blin@savoirfairelinux.com>
Date: Wed, 21 Oct 2020 13:23:48 -0400
Subject: [PATCH] connectionmanager: onShutdown must call all callbacks

When a shutdown appears for a multiplexed socket, all current
pending channels are closed, so callbacks related to these
closed channels must be closed.

Change-Id: Ic15d4552bc3c2445c9aa25babd93ad9d6c473e19
---
 src/jamidht/connectionmanager.cpp             | 82 ++++++++++---------
 .../connectionManager/connectionManager.cpp   | 69 +++++++++++++++-
 2 files changed, 110 insertions(+), 41 deletions(-)

diff --git a/src/jamidht/connectionmanager.cpp b/src/jamidht/connectionmanager.cpp
index e7a39ae545..84d64fb46a 100644
--- a/src/jamidht/connectionmanager.cpp
+++ b/src/jamidht/connectionmanager.cpp
@@ -39,6 +39,7 @@ static constexpr std::chrono::seconds ICE_INIT_TIMEOUT {10};
 static constexpr std::chrono::seconds DHT_MSG_TIMEOUT {30};
 static constexpr std::chrono::seconds SOCK_TIMEOUT {10};
 using ValueIdDist = std::uniform_int_distribution<dht::Value::Id>;
+using CallbackId = std::pair<jami::DeviceId, dht::Value::Id>;
 
 namespace jami {
 
@@ -52,13 +53,12 @@ struct ConnectionInfo
     // Used to store currently non ready TLS Socket
     std::unique_ptr<TlsSocketEndpoint> tls_ {nullptr};
     std::shared_ptr<MultiplexedSocket> socket_ {};
+    std::set<CallbackId> cbIds_ {};
 };
 
 class ConnectionManager::Impl : public std::enable_shared_from_this<ConnectionManager::Impl>
 {
 public:
-    using ConnectionKey = std::pair<DeviceId /* device id */, dht::Value::Id /* uid */>;
-
     explicit Impl(JamiAccount& account)
         : account {account}
     {}
@@ -161,15 +161,25 @@ public:
     std::mutex infosMtx_ {};
     // Note: Someone can ask multiple sockets, so to avoid any race condition,
     // each device can have multiple multiplexed sockets.
-    std::map<ConnectionKey, std::shared_ptr<ConnectionInfo>> infos_ {};
+    std::map<CallbackId, std::shared_ptr<ConnectionInfo>> infos_ {};
 
-    std::shared_ptr<ConnectionInfo> getInfo(const DeviceId& deviceId, const dht::Value::Id& id)
+    std::shared_ptr<ConnectionInfo> getInfo(const DeviceId& deviceId,
+                                            const dht::Value::Id& id = dht::Value::INVALID_ID)
     {
         std::lock_guard<std::mutex> lk(infosMtx_);
-        auto it = infos_.find({deviceId, id});
-        if (it == infos_.end())
-            return {};
-        return it->second;
+        decltype(infos_)::iterator it;
+        if (id == dht::Value::INVALID_ID) {
+            it = std::find_if(infos_.begin(), infos_.end(), [&](const auto& item) {
+                auto& [key, value] = item;
+                return key.first == deviceId;
+            });
+        } else {
+            it = infos_.find({deviceId, id});
+        }
+
+        if (it != infos_.end())
+            return it->second;
+        return {};
     }
 
     ChannelRequestCallback channelReqCb_ {};
@@ -177,7 +187,6 @@ public:
     onICERequestCallback iceReqCb_ {};
 
     std::mutex connectCbsMtx_ {};
-    using CallbackId = std::pair<DeviceId, dht::Value::Id>;
     std::map<CallbackId, ConnectCallback> pendingCbs_ {};
 
     ConnectCallback getPendingCallback(const CallbackId& cbId)
@@ -371,7 +380,7 @@ ConnectionManager::Impl::connectDevice(const DeviceId& deviceId,
                     return;
                 }
                 auto vid = ValueIdDist()(sthis->account.rand);
-                ConnectionKey cbId(deviceId, vid);
+                CallbackId cbId(deviceId, vid);
                 {
                     std::lock_guard<std::mutex> lk(sthis->connectCbsMtx_);
                     auto cbIt = sthis->pendingCbs_.find(cbId);
@@ -383,25 +392,16 @@ ConnectionManager::Impl::connectDevice(const DeviceId& deviceId,
                     }
                 }
 
-                std::shared_ptr<MultiplexedSocket> sock;
-                {
-                    // Test if a socket already exists for this device
-                    std::lock_guard<std::mutex> lk(sthis->infosMtx_);
-                    auto it = std::find_if(sthis->infos_.begin(),
-                                           sthis->infos_.end(),
-                                           [deviceId](const auto& item) {
-                                               auto& [key, value] = item;
-                                               return key.first == deviceId;
-                                           });
-                    if (it != sthis->infos_.end() && it->second) {
-                        sock = it->second->socket_;
+                if (auto info = sthis->getInfo(deviceId)) {
+                    std::lock_guard<std::mutex> lk(info->mutex_);
+                    if (info->socket_) {
+                        JAMI_DBG("Peer already connected. Add a new channel");
+                        info->cbIds_.emplace(cbId);
+                        sthis->sendChannelRequest(info->socket_, name, deviceId, vid);
+                        return;
                     }
                 }
-                if (sock) {
-                    JAMI_DBG("Peer already connected. Add a new channel");
-                    sthis->sendChannelRequest(sock, name, deviceId, vid);
-                    return;
-                }
+
                 // If no socket exists, we need to initiate an ICE connection.
                 auto& iceTransportFactory = Manager::instance().getIceTransportFactory();
                 auto ice_config = sthis->account.getIceOptions();
@@ -794,25 +794,27 @@ ConnectionManager::Impl::addNewMultiplexedSocket(const DeviceId& deviceId, const
             return false;
         });
     info->socket_->onShutdown([w = weak(), deviceId, vid]() {
-        auto sthis = w.lock();
-        if (!sthis)
-            return;
         // Cancel current outgoing connections
-        if (auto cb = sthis->getPendingCallback({deviceId, vid}))
-            cb(nullptr);
         dht::ThreadPool::io().run([w, deviceId = dht::InfoHash(deviceId), vid] {
             auto sthis = w.lock();
             if (!sthis)
                 return;
-            auto info = sthis->getInfo(deviceId, vid);
-            if (!info)
-                return;
-
-            if (info->socket_)
-                info->socket_->shutdown();
 
-            if (info && info->ice_)
-                info->ice_->cancelOperations();
+            std::set<CallbackId> ids;
+            if (auto info = sthis->getInfo(deviceId, vid)) {
+                std::lock_guard<std::mutex> lk(info->mutex_);
+                if (info->socket_) {
+                    ids = std::move(info->cbIds_);
+                    info->socket_->shutdown();
+                }
+                if (info->ice_)
+                    info->ice_->cancelOperations();
+            }
+            for (const auto& cbId : ids) {
+                if (auto cb = sthis->getPendingCallback(cbId)) {
+                    cb(nullptr);
+                }
+            }
 
             std::lock_guard<std::mutex> lk(sthis->infosMtx_);
             sthis->infos_.erase({deviceId, vid});
diff --git a/test/unitTest/connectionManager/connectionManager.cpp b/test/unitTest/connectionManager/connectionManager.cpp
index 3af21fe308..10b78b10ee 100644
--- a/test/unitTest/connectionManager/connectionManager.cpp
+++ b/test/unitTest/connectionManager/connectionManager.cpp
@@ -65,6 +65,7 @@ private:
     void testChannelRcvShutdown();
     void testChannelSenderShutdown();
     void testCloseConnectionWithDevice();
+    void testShutdownCallbacks();
 
     CPPUNIT_TEST_SUITE(ConnectionManagerTest);
     CPPUNIT_TEST(testConnectDevice);
@@ -78,6 +79,7 @@ private:
     CPPUNIT_TEST(testChannelRcvShutdown);
     CPPUNIT_TEST(testChannelSenderShutdown);
     CPPUNIT_TEST(testCloseConnectionWithDevice);
+    CPPUNIT_TEST(testShutdownCallbacks);
     CPPUNIT_TEST_SUITE_END();
 };
 
@@ -724,13 +726,78 @@ ConnectionManagerTest::testCloseConnectionWithDevice()
     // This should trigger onShutdown
     aliceAccount->connectionManager().closeConnectionsWith(bobDeviceId);
     auto expiration = std::chrono::system_clock::now() + std::chrono::seconds(10);
-    scv.wait_until(lk, expiration, [&events]() { return events == 4; });
+    scv.wait_until(lk, expiration, [&events]() { return events == 2; });
     CPPUNIT_ASSERT(events == 2);
     CPPUNIT_ASSERT(successfullyReceive);
     CPPUNIT_ASSERT(successfullyConnected);
     CPPUNIT_ASSERT(receiverConnected);
 }
 
+void
+ConnectionManagerTest::testShutdownCallbacks()
+{
+    auto aliceAccount = Manager::instance().getAccount<JamiAccount>(aliceId);
+    auto bobAccount = Manager::instance().getAccount<JamiAccount>(bobId);
+    auto bobDeviceId = DeviceId(bobAccount->getAccountDetails()[ConfProperties::RING_DEVICE_ID]);
+    auto aliceDeviceId = DeviceId(aliceAccount->getAccountDetails()[ConfProperties::RING_DEVICE_ID]);
+
+    bobAccount->connectionManager().onICERequest([](const DeviceId&) { return true; });
+    aliceAccount->connectionManager().onICERequest([](const DeviceId&) { return true; });
+
+    std::mutex mtx;
+    std::unique_lock<std::mutex> lk {mtx};
+    std::condition_variable rcv, chan2cv;
+    bool successfullyConnected = false;
+    bool successfullyReceive = false;
+    bool receiverConnected = false;
+
+    bobAccount->connectionManager().onChannelRequest(
+        [&successfullyReceive, &chan2cv](const DeviceId&, const std::string& name) {
+            if (name == "1") {
+                successfullyReceive = true;
+            } else {
+                chan2cv.notify_one();
+                // Do not return directly. Let the connection be closed
+                std::this_thread::sleep_for(std::chrono::seconds(10));
+            }
+            return true;
+        });
+
+    bobAccount->connectionManager().onConnectionReady(
+        [&](const DeviceId&, const std::string& name, std::shared_ptr<ChannelSocket> socket) {
+            receiverConnected = socket && (name == "1");
+        });
+
+    aliceAccount->connectionManager().connectDevice(bobDeviceId,
+                                                    "1",
+                                                    [&](std::shared_ptr<ChannelSocket> socket) {
+                                                        if (socket) {
+                                                            successfullyConnected = true;
+                                                            rcv.notify_one();
+                                                        }
+                                                    });
+    // Connect first channel. This will initiate a mx sock
+    rcv.wait_for(lk, std::chrono::seconds(30));
+    CPPUNIT_ASSERT(successfullyReceive);
+    CPPUNIT_ASSERT(successfullyConnected);
+    CPPUNIT_ASSERT(receiverConnected);
+
+    // Connect another channel, but close the connection
+    bool channel2NotConnected = false;
+    aliceAccount->connectionManager().connectDevice(bobDeviceId,
+                                                    "2",
+                                                    [&](std::shared_ptr<ChannelSocket> socket) {
+                                                        channel2NotConnected = !socket;
+                                                        rcv.notify_one();
+                                                    });
+    chan2cv.wait_for(lk, std::chrono::seconds(30));
+
+    // This should trigger onShutdown for second callback
+    bobAccount->connectionManager().closeConnectionsWith(aliceDeviceId);
+    rcv.wait_for(lk, std::chrono::seconds(30));
+    CPPUNIT_ASSERT(channel2NotConnected);
+}
+
 } // namespace test
 } // namespace jami
 
-- 
GitLab