diff --git a/src/connectionmanager.cpp b/src/connectionmanager.cpp index e4453f0db26aec2b18ed90faac973a52c3335a14..fc104ed1b24ad0de4c0cd75c48b462cfddf185e0 100644 --- a/src/connectionmanager.cpp +++ b/src/connectionmanager.cpp @@ -118,24 +118,21 @@ public: { if (isDestroying_.exchange(true)) return; + decltype(pendingOperations_) po; { std::lock_guard<std::mutex> lk(connectCbsMtx_); - // Call all pending callbacks that channel is not ready - for (auto& [deviceId, pcbs] : pendingCbs_) - for (auto& pending : pcbs) - pending.cb(nullptr, deviceId); - pendingCbs_.clear(); + po = std::move(pendingOperations_); } + for (auto& [deviceId, pcbs] : po) { + for (auto& [id, pending] : pcbs.connecting) + pending.cb(nullptr, deviceId); + for (auto& [id, pending] : pcbs.waiting) + pending.cb(nullptr, deviceId); + } + removeUnusedConnections(); } - struct PendingCb - { - std::string name; - ConnectCallback cb; - dht::Value::Id vid; - }; - void connectDeviceStartIce(const std::shared_ptr<dht::crypto::PublicKey>& devicePk, const dht::Value::Id& vid, const std::string& connType, @@ -319,49 +316,72 @@ public: * be done in parallel and we only want one socket */ std::mutex connectCbsMtx_ {}; - std::map<DeviceId, std::vector<PendingCb>> pendingCbs_ {}; - std::vector<PendingCb> extractPendingCallbacks(const DeviceId& deviceId, - const dht::Value::Id vid = 0) + struct PendingCb + { + std::string name; + ConnectCallback cb; + }; + struct PendingOperations { + std::map<dht::Value::Id, PendingCb> connecting; + std::map<dht::Value::Id, PendingCb> waiting; + }; + + std::map<DeviceId, PendingOperations> pendingOperations_ {}; + + void executePendingOperations(const DeviceId& deviceId, const dht::Value::Id& vid, const std::shared_ptr<ChannelSocket>& sock) { std::vector<PendingCb> ret; - std::lock_guard<std::mutex> lk(connectCbsMtx_); - auto pendingIt = pendingCbs_.find(deviceId); - if (pendingIt == pendingCbs_.end()) - return ret; - auto& pendings = pendingIt->second; + std::unique_lock<std::mutex> lk(connectCbsMtx_); + auto it = pendingOperations_.find(deviceId); + if (it == pendingOperations_.end()) + return; + auto& pendingOperations = it->second; if (vid == 0) { - ret = std::move(pendings); - } else { - for (auto it = pendings.begin(); it != pendings.end(); ++it) { - if (it->vid == vid) { - ret.emplace_back(std::move(*it)); - pendings.erase(it); - break; - } + // Extract all pending callbacks + for (auto& [vid, cb] : pendingOperations.connecting) + ret.emplace_back(std::move(cb)); + pendingOperations.connecting.clear(); + for (auto& [vid, cb] : pendingOperations.waiting) + ret.emplace_back(std::move(cb)); + pendingOperations.waiting.clear(); + } else if (auto n = pendingOperations.waiting.extract(vid)) { + // If it's a waiting operation, just move it + ret.emplace_back(std::move(n.mapped())); + } else if (auto n = pendingOperations.connecting.extract(vid)) { + ret.emplace_back(std::move(n.mapped())); + // If sock is nullptr, execute if it's the last connecting operation + if (!sock && pendingOperations.connecting.empty()) { + for (auto& [vid, cb] : pendingOperations.waiting) + ret.emplace_back(std::move(cb)); + pendingOperations.waiting.clear(); + for (auto& [vid, cb] : pendingOperations.connecting) + ret.emplace_back(std::move(cb)); + pendingOperations.connecting.clear(); } } - if (pendings.empty()) - pendingCbs_.erase(pendingIt); - return ret; + if (pendingOperations.waiting.empty() && pendingOperations.connecting.empty()) + pendingOperations_.erase(it); + lk.unlock(); + for (auto& cb : ret) + cb.cb(sock, deviceId); } - std::vector<PendingCb> getPendingCallbacks(const DeviceId& deviceId, - const dht::Value::Id vid = 0) + std::map<dht::Value::Id, std::string> getPendingIds(const DeviceId& deviceId, const dht::Value::Id vid = 0) { - std::vector<PendingCb> ret; + std::map<dht::Value::Id, std::string> ret; std::lock_guard<std::mutex> lk(connectCbsMtx_); - auto pendingIt = pendingCbs_.find(deviceId); - if (pendingIt == pendingCbs_.end()) + auto it = pendingOperations_.find(deviceId); + if (it == pendingOperations_.end()) return ret; - auto& pendings = pendingIt->second; - if (vid == 0) { - ret = pendings; - } else { - std::copy_if(pendings.begin(), - pendings.end(), - std::back_inserter(ret), - [&](auto pending) { return pending.vid == vid; }); + auto& pendingOp = it->second; + for (const auto& [id, pc]: pendingOp.connecting) { + if (vid == 0 || id == vid) + ret[id] = pc.name; + } + for (const auto& [id, pc]: pendingOp.waiting) { + if (vid == 0 || id == vid) + ret[id] = pc.name; } return ret; } @@ -616,23 +636,24 @@ ConnectionManager::Impl::connectDevice(const std::shared_ptr<dht::crypto::Certif auto isConnectingToDevice = false; { std::lock_guard<std::mutex> lk(sthis->connectCbsMtx_); - auto pendingsIt = sthis->pendingCbs_.find(deviceId); - if (pendingsIt != sthis->pendingCbs_.end()) { + auto pendingsIt = sthis->pendingOperations_.find(deviceId); + if (pendingsIt != sthis->pendingOperations_.end()) { const auto& pendings = pendingsIt->second; - while (std::find_if(pendings.begin(), pendings.end(), [&](const auto& it){ return it.vid == vid; }) != pendings.end()) { - vid = ValueIdDist(1, ID_MAX_VAL)(sthis->rand); + while (pendings.connecting.find(vid) != pendings.connecting.end() + && pendings.waiting.find(vid) != pendings.waiting.end()) { + vid = ValueIdDist(1, JAMI_ID_MAX_VAL)(sthis->account.rand); } } // Check if already connecting - isConnectingToDevice = pendingsIt != sthis->pendingCbs_.end(); + isConnectingToDevice = pendingsIt != sthis->pendingOperations_.end(); // Save current request for sendChannelRequest. // Note: do not return here, cause we can be in a state where first // socket is negotiated and first channel is pending // so return only after we checked the info - if (isConnectingToDevice) - pendingsIt->second.emplace_back(PendingCb {name, std::move(cb), vid}); + if (isConnectingToDevice && !forceNewSocket) + pendingsIt->second.waiting[vid] = PendingCb {name, std::move(cb)}; else - sthis->pendingCbs_[deviceId] = {{name, std::move(cb), vid}}; + sthis->pendingOperations_[deviceId].connecting[vid] = PendingCb {name, std::move(cb)}; } // Check if already negotiated @@ -655,8 +676,7 @@ ConnectionManager::Impl::connectDevice(const std::shared_ptr<dht::crypto::Certif } if (noNewSocket) { // If no new socket is specified, we don't try to generate a new socket - for (const auto& pending : sthis->extractPendingCallbacks(deviceId, vid)) - pending.cb(nullptr, deviceId); + sthis->executePendingOperations(deviceId, vid, nullptr); return; } @@ -665,8 +685,7 @@ ConnectionManager::Impl::connectDevice(const std::shared_ptr<dht::crypto::Certif auto eraseInfo = [w, cbId] { if (auto shared = w.lock()) { // If no new socket is specified, we don't try to generate a new socket - for (const auto& pending : shared->extractPendingCallbacks(cbId.first, cbId.second)) - pending.cb(nullptr, cbId.first); + shared->executePendingOperations(cbId.first, cbId.second, nullptr); std::lock_guard<std::mutex> lk(shared->infosMtx_); shared->infos_.erase(cbId); } @@ -776,17 +795,15 @@ ConnectionManager::Impl::sendChannelRequest(std::shared_ptr<MultiplexedSocket>& auto channelSock = sock->addChannel(name); channelSock->onShutdown([name, deviceId, vid, w = weak()] { auto shared = w.lock(); - if (shared) - for (const auto& pending : shared->extractPendingCallbacks(deviceId, vid)) - pending.cb(nullptr, deviceId); + if (auto shared = w.lock()) + shared->executePendingOperations(deviceId, vid, nullptr); }); channelSock->onReady( [wSock = std::weak_ptr<ChannelSocket>(channelSock), name, deviceId, vid, w = weak()]() { auto shared = w.lock(); auto channelSock = wSock.lock(); if (shared) - for (const auto& pending : shared->extractPendingCallbacks(deviceId, vid)) - pending.cb(channelSock, deviceId); + shared->executePendingOperations(deviceId, vid, channelSock); }); ChannelRequest val; @@ -911,8 +928,7 @@ ConnectionManager::Impl::onTlsNegotiationDone(bool ok, deviceId, name, vid); - for (const auto& pending : extractPendingCallbacks(deviceId)) - pending.cb(nullptr, deviceId); + executePendingOperations(deviceId, vid, nullptr); } } else { // The socket is ready, store it @@ -934,12 +950,12 @@ ConnectionManager::Impl::onTlsNegotiationDone(bool ok, // Finally, open the channel and launch pending callbacks if (info->socket_) { // Note: do not remove pending there it's done in sendChannelRequest - for (const auto& pending : getPendingCallbacks(deviceId)) { + for (const auto& [id, name] : getPendingIds(deviceId)) { if (config_->logger) config_->logger->debug("Send request on TLS socket for channel {} to {}", - pending.name, - deviceId); - sendChannelRequest(info->socket_, pending.name, deviceId, pending.vid); + name, + deviceId.toString()); + sendChannelRequest(info->socket_, name, deviceId, id); } } } @@ -1085,8 +1101,7 @@ ConnectionManager::Impl::onDhtPeerRequest(const PeerConnectionRequest& req, auto eraseInfo = [w, id = req.id, deviceId] { if (auto shared = w.lock()) { // If no new socket is specified, we don't try to generate a new socket - for (const auto& pending : shared->extractPendingCallbacks(deviceId, id)) - pending.cb(nullptr, deviceId); + shared->executePendingOperations(deviceId, id, nullptr); if (shared->connReadyCb_) shared->connReadyCb_(deviceId, "", nullptr); std::lock_guard<std::mutex> lk(shared->infosMtx_); @@ -1202,8 +1217,7 @@ ConnectionManager::Impl::addNewMultiplexedSocket(const CallbackId& id, const std } } for (const auto& cbId : ids) - for (const auto& pending : sthis->extractPendingCallbacks(cbId.first, cbId.second)) - pending.cb(nullptr, deviceId); + sthis->executePendingOperations(cbId.first, cbId.second, nullptr); std::lock_guard<std::mutex> lk(sthis->infosMtx_); sthis->infos_.erase({deviceId, vid}); @@ -1522,8 +1536,8 @@ ConnectionManager::connectDevice(const std::shared_ptr<dht::crypto::Certificate> bool ConnectionManager::isConnecting(const DeviceId& deviceId, const std::string& name) const { - auto pending = pimpl_->getPendingCallbacks(deviceId); - return std::find_if(pending.begin(), pending.end(), [&](auto p) { return p.name == name; }) + auto pending = pimpl_->getPendingIds(deviceId); + return std::find_if(pending.begin(), pending.end(), [&](auto p) { return p.second == name; }) != pending.end(); } @@ -1549,8 +1563,7 @@ ConnectionManager::closeConnectionsWith(const std::string& peerUri) } // Stop connections to all peers devices for (const auto& deviceId : peersDevices) { - for (const auto& pending : pimpl_->extractPendingCallbacks(deviceId)) - pending.cb(nullptr, deviceId); + pimpl_->executePendingOperations(deviceId, 0, nullptr); // This will close the TLS Session pimpl_->removeUnusedConnections(deviceId); }