From 37d1d9a5e36f245e0b577a2279c22df6a4869e4a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?S=C3=A9bastien=20Blin?=
 <sebastien.blin@savoirfairelinux.com>
Date: Wed, 15 Jun 2022 10:54:03 -0400
Subject: [PATCH] connectionmanager: erase info on connection failure

In some cases, the infos were not correctly refreshed, causing some
pending callbacks to never be called.
Also, split getInfo() in two methods to improve readability.

Change-Id: I1b60f2cf2ac5bf97c9a44a53794b56906d314e6a
GitLab: #TODO
---
 src/jamidht/connectionmanager.cpp             | 131 ++++++++----------
 .../connectionManager/connectionManager.cpp   |  60 ++++++++
 2 files changed, 121 insertions(+), 70 deletions(-)

diff --git a/src/jamidht/connectionmanager.cpp b/src/jamidht/connectionmanager.cpp
index 560a35e50c..c2e39402b8 100644
--- a/src/jamidht/connectionmanager.cpp
+++ b/src/jamidht/connectionmanager.cpp
@@ -115,9 +115,9 @@ public:
         dht::Value::Id vid;
     };
 
-    void connectDeviceStartIce(const std::shared_ptr<dht::crypto::PublicKey>& devicePk,
+    bool connectDeviceStartIce(const std::shared_ptr<dht::crypto::PublicKey>& devicePk,
                                const dht::Value::Id& vid);
-    void connectDeviceOnNegoDone(const DeviceId& deviceId,
+    bool connectDeviceOnNegoDone(const DeviceId& deviceId,
                                  const std::string& name,
                                  const dht::Value::Id& vid,
                                  const std::shared_ptr<dht::crypto::Certificate>& cert);
@@ -146,8 +146,8 @@ public:
     void answerTo(IceTransport& ice,
                   const dht::Value::Id& id,
                   const std::shared_ptr<dht::crypto::PublicKey>& fromPk);
-    void onRequestStartIce(const PeerConnectionRequest& req);
-    void onRequestOnNegoDone(const PeerConnectionRequest& req);
+    bool onRequestStartIce(const PeerConnectionRequest& req);
+    bool onRequestOnNegoDone(const PeerConnectionRequest& req);
     void onDhtPeerRequest(const PeerConnectionRequest& req,
                           const std::shared_ptr<dht::crypto::Certificate>& cert);
 
@@ -175,19 +175,22 @@ public:
     std::map<CallbackId, std::shared_ptr<ConnectionInfo>> infos_ {};
 
     std::shared_ptr<ConnectionInfo> getInfo(const DeviceId& deviceId,
-                                            const dht::Value::Id& id = dht::Value::INVALID_ID)
+                                            const dht::Value::Id& id)
     {
         std::lock_guard<std::mutex> lk(infosMtx_);
-        decltype(infos_)::iterator it;
-        if (id == dht::Value::INVALID_ID) {
-            it = std::find_if(infos_.begin(), infos_.end(), [&](const auto& item) {
-                auto& [key, value] = item;
-                return key.first == deviceId;
-            });
-        } else {
-            it = infos_.find({deviceId, id});
-        }
+        auto it = infos_.find({deviceId, id});
+        if (it != infos_.end())
+            return it->second;
+        return {};
+    }
 
+    std::shared_ptr<ConnectionInfo> getConnectedInfo(const DeviceId& deviceId)
+    {
+        std::lock_guard<std::mutex> lk(infosMtx_);
+        auto it = std::find_if(infos_.begin(), infos_.end(), [&](const auto& item) {
+            auto& [key, value] = item;
+            return key.first == deviceId && value && value->socket_;
+        });
         if (it != infos_.end())
             return it->second;
         return {};
@@ -270,30 +273,21 @@ public:
     std::atomic_bool isDestroying_ {false};
 };
 
-void
+bool
 ConnectionManager::Impl::connectDeviceStartIce(
     const std::shared_ptr<dht::crypto::PublicKey>& devicePk, const dht::Value::Id& vid)
 {
     auto deviceId = devicePk->getLongId();
     auto info = getInfo(deviceId, vid);
-    if (!info) {
-        return;
-    }
+    if (!info)
+        return false;
 
     std::unique_lock<std::mutex> lk(info->mutex_);
     auto& ice = info->ice_;
 
-    auto onError = [&]() {
-        ice.reset();
-        // Erase all pending connect
-        for (const auto& pending : extractPendingCallbacks(deviceId))
-            pending.cb(nullptr, deviceId);
-    };
-
     if (!ice) {
         JAMI_ERR("No ICE detected");
-        onError();
-        return;
+        return false;
     }
 
     auto iceAttributes = ice->getLocalAttributes();
@@ -326,28 +320,28 @@ ConnectionManager::Impl::connectDeviceStartIce(
                                 });
     // Wait for call to onResponse() operated by DHT
     if (isDestroying_)
-        return; // This avoid to wait new negotiation when destroying
+        return true; // This avoid to wait new negotiation when destroying
     info->responseCv_.wait_for(lk, DHT_MSG_TIMEOUT);
     if (isDestroying_)
-        return; // The destructor can wake a pending wait here.
+        return true; // The destructor can wake a pending wait here.
     if (!info->responseReceived_) {
         JAMI_ERR("no response from DHT to E2E request.");
-        onError();
-        return;
+        return false;
     }
 
     if (!ice)
-        return;
+        return false;
 
     auto sdp = ice->parseIceCandidates(info->response_.ice_msg);
 
     if (not ice->startIce({sdp.rem_ufrag, sdp.rem_pwd}, std::move(sdp.rem_candidates))) {
         JAMI_WARN("[Account:%s] start ICE failed", account.getAccountID().c_str());
-        onError();
+        return false;
     }
+    return true;
 }
 
-void
+bool
 ConnectionManager::Impl::connectDeviceOnNegoDone(
     const DeviceId& deviceId,
     const std::string& name,
@@ -356,15 +350,13 @@ ConnectionManager::Impl::connectDeviceOnNegoDone(
 {
     auto info = getInfo(deviceId, vid);
     if (!info)
-        return;
+        return false;
 
     std::unique_lock<std::mutex> lk {info->mutex_};
     auto& ice = info->ice_;
     if (!ice || !ice->isRunning()) {
         JAMI_ERR("No ICE detected or not running");
-        for (const auto& pending : extractPendingCallbacks(deviceId))
-            pending.cb(nullptr, deviceId);
-        return;
+        return false;
     }
 
     // Build socket
@@ -387,6 +379,7 @@ ConnectionManager::Impl::connectDeviceOnNegoDone(
             if (auto shared = w.lock())
                 shared->onTlsNegotiationDone(ok, deviceId, vid, name);
         });
+    return true;
 }
 
 void
@@ -467,7 +460,7 @@ ConnectionManager::Impl::connectDevice(const std::shared_ptr<dht::crypto::Certif
 
         // Check if already negotiated
         CallbackId cbId(deviceId, vid);
-        if (auto info = sthis->getInfo(deviceId)) {
+        if (auto info = sthis->getConnectedInfo(deviceId)) {
             std::lock_guard<std::mutex> lk(info->mutex_);
             if (info->socket_) {
                 JAMI_DBG("Peer already connected to %s. Add a new channel", deviceId.to_c_str());
@@ -490,8 +483,11 @@ ConnectionManager::Impl::connectDevice(const std::shared_ptr<dht::crypto::Certif
 
         // Note: used when the ice negotiation fails to erase
         // all stored structures.
-        auto eraseInfo = [w, cbId] {
+        auto eraseInfo = [w, cbId, deviceId] {
             if (auto shared = w.lock()) {
+                // If no new socket is specified, we don't try to generate a new socket
+                for (const auto& pending : shared->extractPendingCallbacks(deviceId))
+                    pending.cb(nullptr, deviceId);
                 std::lock_guard<std::mutex> lk(shared->infosMtx_);
                 shared->infos_.erase(cbId);
             }
@@ -521,16 +517,15 @@ ConnectionManager::Impl::connectDevice(const std::shared_ptr<dht::crypto::Certif
                     return;
                 if (!ok) {
                     JAMI_ERR("Cannot initialize ICE session.");
-                    for (const auto& pending : sthis->extractPendingCallbacks(deviceId))
-                        pending.cb(nullptr, deviceId);
                     runOnMainThread([eraseInfo = std::move(eraseInfo)] { eraseInfo(); });
                     return;
                 }
 
                 dht::ThreadPool::io().run(
-                    [w = std::move(w), devicePk = std::move(devicePk), vid = std::move(vid)] {
+                    [w = std::move(w), devicePk = std::move(devicePk), vid = std::move(vid), eraseInfo] {
                         if (auto sthis = w.lock())
-                            sthis->connectDeviceStartIce(devicePk, vid);
+                            if (!sthis->connectDeviceStartIce(devicePk, vid))
+                                runOnMainThread([eraseInfo = std::move(eraseInfo)] { eraseInfo(); });
                     });
             };
             ice_config.onNegoDone = [w,
@@ -544,8 +539,6 @@ ConnectionManager::Impl::connectDevice(const std::shared_ptr<dht::crypto::Certif
                     return;
                 if (!ok) {
                     JAMI_ERR("ICE negotiation failed.");
-                    for (const auto& pending : sthis->extractPendingCallbacks(deviceId))
-                        pending.cb(nullptr, deviceId);
                     runOnMainThread([eraseInfo = std::move(eraseInfo)] { eraseInfo(); });
                     return;
                 }
@@ -554,11 +547,13 @@ ConnectionManager::Impl::connectDevice(const std::shared_ptr<dht::crypto::Certif
                                            deviceId = std::move(deviceId),
                                            name = std::move(name),
                                            cert = std::move(cert),
-                                           vid = std::move(vid)] {
+                                           vid = std::move(vid),
+                                           eraseInfo = std::move(eraseInfo)] {
                     auto sthis = w.lock();
                     if (!sthis)
                         return;
-                    sthis->connectDeviceOnNegoDone(deviceId, name, vid, cert);
+                    if (!sthis->connectDeviceOnNegoDone(deviceId, name, vid, cert))
+                        runOnMainThread([eraseInfo = std::move(eraseInfo)] { eraseInfo(); });
                 });
             };
 
@@ -577,8 +572,6 @@ ConnectionManager::Impl::connectDevice(const std::shared_ptr<dht::crypto::Certif
 
             if (!info->ice_) {
                 JAMI_ERR("Cannot initialize ICE session.");
-                for (const auto& pending : sthis->extractPendingCallbacks(deviceId))
-                    pending.cb(nullptr, deviceId);
                 eraseInfo();
             }
         });
@@ -596,7 +589,7 @@ ConnectionManager::Impl::sendChannelRequest(std::shared_ptr<MultiplexedSocket>&
         [wSock = std::weak_ptr<ChannelSocket>(channelSock), name, deviceId, vid, w = weak()]() {
             auto shared = w.lock();
             auto channelSock = wSock.lock();
-            if (shared and channelSock)
+            if (shared && channelSock)
                 for (const auto& pending : shared->extractPendingCallbacks(deviceId, vid))
                     pending.cb(channelSock, deviceId);
         });
@@ -776,13 +769,13 @@ ConnectionManager::Impl::answerTo(IceTransport& ice,
         });
 }
 
-void
+bool
 ConnectionManager::Impl::onRequestStartIce(const PeerConnectionRequest& req)
 {
     auto deviceId = req.owner->getLongId();
     auto info = getInfo(deviceId, req.id);
     if (!info)
-        return;
+        return false;
 
     std::unique_lock<std::mutex> lk {info->mutex_};
     auto& ice = info->ice_;
@@ -790,7 +783,7 @@ ConnectionManager::Impl::onRequestStartIce(const PeerConnectionRequest& req)
         JAMI_ERR("No ICE detected");
         if (connReadyCb_)
             connReadyCb_(deviceId, "", nullptr);
-        return;
+        return false;
     }
 
     auto sdp = ice->parseIceCandidates(req.ice_msg);
@@ -800,25 +793,24 @@ ConnectionManager::Impl::onRequestStartIce(const PeerConnectionRequest& req)
         ice = nullptr;
         if (connReadyCb_)
             connReadyCb_(deviceId, "", nullptr);
-        return;
+        return false;
     }
+    return true;
 }
 
-void
+bool
 ConnectionManager::Impl::onRequestOnNegoDone(const PeerConnectionRequest& req)
 {
     auto deviceId = req.owner->getLongId();
     auto info = getInfo(deviceId, req.id);
     if (!info)
-        return;
+        return false;
 
     std::unique_lock<std::mutex> lk {info->mutex_};
     auto& ice = info->ice_;
     if (!ice) {
         JAMI_ERR("No ICE detected");
-        if (connReadyCb_)
-            connReadyCb_(deviceId, "", nullptr);
-        return;
+        return false;
     }
 
     // Build socket
@@ -849,6 +841,7 @@ ConnectionManager::Impl::onRequestOnNegoDone(const PeerConnectionRequest& req)
             if (auto shared = w.lock())
                 shared->onTlsNegotiationDone(ok, deviceId, vid);
         });
+    return true;
 }
 
 void
@@ -873,6 +866,8 @@ ConnectionManager::Impl::onDhtPeerRequest(const PeerConnectionRequest& req,
         // all stored structures.
         auto eraseInfo = [w, id = req.id, deviceId] {
             if (auto shared = w.lock()) {
+                if (shared->connReadyCb_)
+                    shared->connReadyCb_(deviceId, "", nullptr);
                 std::lock_guard<std::mutex> lk(shared->infosMtx_);
                 shared->infos_.erase({deviceId, id});
             }
@@ -885,17 +880,16 @@ ConnectionManager::Impl::onDhtPeerRequest(const PeerConnectionRequest& req,
                 return;
             if (!ok) {
                 JAMI_ERR("Cannot initialize ICE session.");
-                if (shared->connReadyCb_)
-                    shared->connReadyCb_(deviceId, "", nullptr);
                 runOnMainThread([eraseInfo = std::move(eraseInfo)] { eraseInfo(); });
                 return;
             }
 
-            dht::ThreadPool::io().run([w = std::move(w), req = std::move(req)] {
+            dht::ThreadPool::io().run([w = std::move(w), req = std::move(req), eraseInfo = std::move(eraseInfo)] {
                 auto shared = w.lock();
                 if (!shared)
                     return;
-                shared->onRequestStartIce(req);
+                if (!shared->onRequestStartIce(req))
+                    runOnMainThread([eraseInfo = std::move(eraseInfo)] { eraseInfo(); });
             });
         };
 
@@ -905,15 +899,14 @@ ConnectionManager::Impl::onDhtPeerRequest(const PeerConnectionRequest& req,
                 return;
             if (!ok) {
                 JAMI_ERR("ICE negotiation failed");
-                if (shared->connReadyCb_)
-                    shared->connReadyCb_(deviceId, "", nullptr);
                 runOnMainThread([eraseInfo = std::move(eraseInfo)] { eraseInfo(); });
                 return;
             }
 
-            dht::ThreadPool::io().run([w = std::move(w), req = std::move(req)] {
+            dht::ThreadPool::io().run([w = std::move(w), req = std::move(req), eraseInfo = std::move(eraseInfo)] {
                 if (auto shared = w.lock())
-                    shared->onRequestOnNegoDone(req);
+                    if (!shared->onRequestOnNegoDone(req))
+                        runOnMainThread([eraseInfo = std::move(eraseInfo)] { eraseInfo(); });
             });
         };
 
@@ -936,8 +929,6 @@ ConnectionManager::Impl::onDhtPeerRequest(const PeerConnectionRequest& req,
 
         if (not info->ice_) {
             JAMI_ERR("Cannot initialize ICE session.");
-            if (shared->connReadyCb_)
-                shared->connReadyCb_(deviceId, "", nullptr);
             eraseInfo();
         }
     });
diff --git a/test/unitTest/connectionManager/connectionManager.cpp b/test/unitTest/connectionManager/connectionManager.cpp
index 056c045039..cd3b99028e 100644
--- a/test/unitTest/connectionManager/connectionManager.cpp
+++ b/test/unitTest/connectionManager/connectionManager.cpp
@@ -58,6 +58,7 @@ private:
     void testConnectDevice();
     void testAcceptConnection();
     void testMultipleChannels();
+    void testMultipleChannelsOneDeclined();
     void testMultipleChannelsSameName();
     void testDeclineConnection();
     void testSendReceiveData();
@@ -79,6 +80,7 @@ private:
     CPPUNIT_TEST(testConnectDevice);
     CPPUNIT_TEST(testAcceptConnection);
     CPPUNIT_TEST(testMultipleChannels);
+    CPPUNIT_TEST(testMultipleChannelsOneDeclined);
     CPPUNIT_TEST(testMultipleChannelsSameName);
     CPPUNIT_TEST(testDeclineConnection);
     CPPUNIT_TEST(testSendReceiveData);
@@ -252,6 +254,64 @@ ConnectionManagerTest::testMultipleChannels()
     CPPUNIT_ASSERT(aliceAccount->connectionManager().activeSockets() == 1);
 }
 
+void
+ConnectionManagerTest::testMultipleChannelsOneDeclined()
+{
+    auto aliceAccount = Manager::instance().getAccount<JamiAccount>(aliceId);
+    auto bobAccount = Manager::instance().getAccount<JamiAccount>(bobId);
+    auto bobDeviceId = DeviceId(std::string(bobAccount->currentDeviceId()));
+
+    bobAccount->connectionManager().onICERequest([](const DeviceId&) { return true; });
+    aliceAccount->connectionManager().onICERequest([](const DeviceId&) { return true; });
+
+    std::mutex mtx;
+    std::unique_lock<std::mutex> lk {mtx};
+    std::condition_variable cv;
+    bool successfullyNotConnected = false;
+    bool successfullyConnected2 = false;
+    int receiverConnected = 0;
+
+    bobAccount->connectionManager().onChannelRequest(
+        [](const std::shared_ptr<dht::crypto::Certificate>&, const std::string& name) {
+            if (name == "git://*")
+                return false;
+            return true;
+        });
+
+    bobAccount->connectionManager().onConnectionReady(
+        [&receiverConnected](const DeviceId&,
+                             const std::string&,
+                             std::shared_ptr<ChannelSocket> socket) {
+            if (socket)
+                receiverConnected += 1;
+        });
+
+    aliceAccount->connectionManager().connectDevice(bobDeviceId,
+                                                    "git://*",
+                                                    [&](std::shared_ptr<ChannelSocket> socket,
+                                                        const DeviceId&) {
+                                                        if (!socket) {
+                                                            successfullyNotConnected = true;
+                                                        }
+                                                        cv.notify_one();
+                                                    });
+
+    aliceAccount->connectionManager().connectDevice(bobDeviceId,
+                                                    "sip://*",
+                                                    [&](std::shared_ptr<ChannelSocket> socket,
+                                                        const DeviceId&) {
+                                                        if (socket) {
+                                                            successfullyConnected2 = true;
+                                                        }
+                                                        cv.notify_one();
+                                                    });
+
+    CPPUNIT_ASSERT(cv.wait_for(lk, std::chrono::seconds(60), [&] {
+        return successfullyNotConnected && successfullyConnected2 && receiverConnected == 1;
+    }));
+    CPPUNIT_ASSERT(aliceAccount->connectionManager().activeSockets() == 1);
+}
+
 void
 ConnectionManagerTest::testMultipleChannelsSameName()
 {
-- 
GitLab