From b0209b799b03014e9c1d1490855a6eba6515f4ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Blin?= <sebastien.blin@savoirfairelinux.com> Date: Wed, 21 Oct 2020 13:23:48 -0400 Subject: [PATCH] connectionmanager: onShutdown must call all callbacks When a shutdown appears for a multiplexed socket, all current pending channels are closed, so callbacks related to these closed channels must be closed. Change-Id: Ic15d4552bc3c2445c9aa25babd93ad9d6c473e19 --- src/jamidht/connectionmanager.cpp | 82 ++++++++++--------- .../connectionManager/connectionManager.cpp | 69 +++++++++++++++- 2 files changed, 110 insertions(+), 41 deletions(-) diff --git a/src/jamidht/connectionmanager.cpp b/src/jamidht/connectionmanager.cpp index e7a39ae545..84d64fb46a 100644 --- a/src/jamidht/connectionmanager.cpp +++ b/src/jamidht/connectionmanager.cpp @@ -39,6 +39,7 @@ static constexpr std::chrono::seconds ICE_INIT_TIMEOUT {10}; static constexpr std::chrono::seconds DHT_MSG_TIMEOUT {30}; static constexpr std::chrono::seconds SOCK_TIMEOUT {10}; using ValueIdDist = std::uniform_int_distribution<dht::Value::Id>; +using CallbackId = std::pair<jami::DeviceId, dht::Value::Id>; namespace jami { @@ -52,13 +53,12 @@ struct ConnectionInfo // Used to store currently non ready TLS Socket std::unique_ptr<TlsSocketEndpoint> tls_ {nullptr}; std::shared_ptr<MultiplexedSocket> socket_ {}; + std::set<CallbackId> cbIds_ {}; }; class ConnectionManager::Impl : public std::enable_shared_from_this<ConnectionManager::Impl> { public: - using ConnectionKey = std::pair<DeviceId /* device id */, dht::Value::Id /* uid */>; - explicit Impl(JamiAccount& account) : account {account} {} @@ -161,15 +161,25 @@ public: std::mutex infosMtx_ {}; // Note: Someone can ask multiple sockets, so to avoid any race condition, // each device can have multiple multiplexed sockets. - std::map<ConnectionKey, std::shared_ptr<ConnectionInfo>> infos_ {}; + std::map<CallbackId, std::shared_ptr<ConnectionInfo>> infos_ {}; - std::shared_ptr<ConnectionInfo> getInfo(const DeviceId& deviceId, const dht::Value::Id& id) + std::shared_ptr<ConnectionInfo> getInfo(const DeviceId& deviceId, + const dht::Value::Id& id = dht::Value::INVALID_ID) { std::lock_guard<std::mutex> lk(infosMtx_); - auto it = infos_.find({deviceId, id}); - if (it == infos_.end()) - return {}; - return it->second; + decltype(infos_)::iterator it; + if (id == dht::Value::INVALID_ID) { + it = std::find_if(infos_.begin(), infos_.end(), [&](const auto& item) { + auto& [key, value] = item; + return key.first == deviceId; + }); + } else { + it = infos_.find({deviceId, id}); + } + + if (it != infos_.end()) + return it->second; + return {}; } ChannelRequestCallback channelReqCb_ {}; @@ -177,7 +187,6 @@ public: onICERequestCallback iceReqCb_ {}; std::mutex connectCbsMtx_ {}; - using CallbackId = std::pair<DeviceId, dht::Value::Id>; std::map<CallbackId, ConnectCallback> pendingCbs_ {}; ConnectCallback getPendingCallback(const CallbackId& cbId) @@ -371,7 +380,7 @@ ConnectionManager::Impl::connectDevice(const DeviceId& deviceId, return; } auto vid = ValueIdDist()(sthis->account.rand); - ConnectionKey cbId(deviceId, vid); + CallbackId cbId(deviceId, vid); { std::lock_guard<std::mutex> lk(sthis->connectCbsMtx_); auto cbIt = sthis->pendingCbs_.find(cbId); @@ -383,25 +392,16 @@ ConnectionManager::Impl::connectDevice(const DeviceId& deviceId, } } - std::shared_ptr<MultiplexedSocket> sock; - { - // Test if a socket already exists for this device - std::lock_guard<std::mutex> lk(sthis->infosMtx_); - auto it = std::find_if(sthis->infos_.begin(), - sthis->infos_.end(), - [deviceId](const auto& item) { - auto& [key, value] = item; - return key.first == deviceId; - }); - if (it != sthis->infos_.end() && it->second) { - sock = it->second->socket_; + if (auto info = sthis->getInfo(deviceId)) { + std::lock_guard<std::mutex> lk(info->mutex_); + if (info->socket_) { + JAMI_DBG("Peer already connected. Add a new channel"); + info->cbIds_.emplace(cbId); + sthis->sendChannelRequest(info->socket_, name, deviceId, vid); + return; } } - if (sock) { - JAMI_DBG("Peer already connected. Add a new channel"); - sthis->sendChannelRequest(sock, name, deviceId, vid); - return; - } + // If no socket exists, we need to initiate an ICE connection. auto& iceTransportFactory = Manager::instance().getIceTransportFactory(); auto ice_config = sthis->account.getIceOptions(); @@ -794,25 +794,27 @@ ConnectionManager::Impl::addNewMultiplexedSocket(const DeviceId& deviceId, const return false; }); info->socket_->onShutdown([w = weak(), deviceId, vid]() { - auto sthis = w.lock(); - if (!sthis) - return; // Cancel current outgoing connections - if (auto cb = sthis->getPendingCallback({deviceId, vid})) - cb(nullptr); dht::ThreadPool::io().run([w, deviceId = dht::InfoHash(deviceId), vid] { auto sthis = w.lock(); if (!sthis) return; - auto info = sthis->getInfo(deviceId, vid); - if (!info) - return; - - if (info->socket_) - info->socket_->shutdown(); - if (info && info->ice_) - info->ice_->cancelOperations(); + std::set<CallbackId> ids; + if (auto info = sthis->getInfo(deviceId, vid)) { + std::lock_guard<std::mutex> lk(info->mutex_); + if (info->socket_) { + ids = std::move(info->cbIds_); + info->socket_->shutdown(); + } + if (info->ice_) + info->ice_->cancelOperations(); + } + for (const auto& cbId : ids) { + if (auto cb = sthis->getPendingCallback(cbId)) { + cb(nullptr); + } + } std::lock_guard<std::mutex> lk(sthis->infosMtx_); sthis->infos_.erase({deviceId, vid}); diff --git a/test/unitTest/connectionManager/connectionManager.cpp b/test/unitTest/connectionManager/connectionManager.cpp index 3af21fe308..10b78b10ee 100644 --- a/test/unitTest/connectionManager/connectionManager.cpp +++ b/test/unitTest/connectionManager/connectionManager.cpp @@ -65,6 +65,7 @@ private: void testChannelRcvShutdown(); void testChannelSenderShutdown(); void testCloseConnectionWithDevice(); + void testShutdownCallbacks(); CPPUNIT_TEST_SUITE(ConnectionManagerTest); CPPUNIT_TEST(testConnectDevice); @@ -78,6 +79,7 @@ private: CPPUNIT_TEST(testChannelRcvShutdown); CPPUNIT_TEST(testChannelSenderShutdown); CPPUNIT_TEST(testCloseConnectionWithDevice); + CPPUNIT_TEST(testShutdownCallbacks); CPPUNIT_TEST_SUITE_END(); }; @@ -724,13 +726,78 @@ ConnectionManagerTest::testCloseConnectionWithDevice() // This should trigger onShutdown aliceAccount->connectionManager().closeConnectionsWith(bobDeviceId); auto expiration = std::chrono::system_clock::now() + std::chrono::seconds(10); - scv.wait_until(lk, expiration, [&events]() { return events == 4; }); + scv.wait_until(lk, expiration, [&events]() { return events == 2; }); CPPUNIT_ASSERT(events == 2); CPPUNIT_ASSERT(successfullyReceive); CPPUNIT_ASSERT(successfullyConnected); CPPUNIT_ASSERT(receiverConnected); } +void +ConnectionManagerTest::testShutdownCallbacks() +{ + auto aliceAccount = Manager::instance().getAccount<JamiAccount>(aliceId); + auto bobAccount = Manager::instance().getAccount<JamiAccount>(bobId); + auto bobDeviceId = DeviceId(bobAccount->getAccountDetails()[ConfProperties::RING_DEVICE_ID]); + auto aliceDeviceId = DeviceId(aliceAccount->getAccountDetails()[ConfProperties::RING_DEVICE_ID]); + + bobAccount->connectionManager().onICERequest([](const DeviceId&) { return true; }); + aliceAccount->connectionManager().onICERequest([](const DeviceId&) { return true; }); + + std::mutex mtx; + std::unique_lock<std::mutex> lk {mtx}; + std::condition_variable rcv, chan2cv; + bool successfullyConnected = false; + bool successfullyReceive = false; + bool receiverConnected = false; + + bobAccount->connectionManager().onChannelRequest( + [&successfullyReceive, &chan2cv](const DeviceId&, const std::string& name) { + if (name == "1") { + successfullyReceive = true; + } else { + chan2cv.notify_one(); + // Do not return directly. Let the connection be closed + std::this_thread::sleep_for(std::chrono::seconds(10)); + } + return true; + }); + + bobAccount->connectionManager().onConnectionReady( + [&](const DeviceId&, const std::string& name, std::shared_ptr<ChannelSocket> socket) { + receiverConnected = socket && (name == "1"); + }); + + aliceAccount->connectionManager().connectDevice(bobDeviceId, + "1", + [&](std::shared_ptr<ChannelSocket> socket) { + if (socket) { + successfullyConnected = true; + rcv.notify_one(); + } + }); + // Connect first channel. This will initiate a mx sock + rcv.wait_for(lk, std::chrono::seconds(30)); + CPPUNIT_ASSERT(successfullyReceive); + CPPUNIT_ASSERT(successfullyConnected); + CPPUNIT_ASSERT(receiverConnected); + + // Connect another channel, but close the connection + bool channel2NotConnected = false; + aliceAccount->connectionManager().connectDevice(bobDeviceId, + "2", + [&](std::shared_ptr<ChannelSocket> socket) { + channel2NotConnected = !socket; + rcv.notify_one(); + }); + chan2cv.wait_for(lk, std::chrono::seconds(30)); + + // This should trigger onShutdown for second callback + bobAccount->connectionManager().closeConnectionsWith(aliceDeviceId); + rcv.wait_for(lk, std::chrono::seconds(30)); + CPPUNIT_ASSERT(channel2NotConnected); +} + } // namespace test } // namespace jami -- GitLab