Skip to content
Snippets Groups Projects
Commit 665294f4 authored by Adrien Béraud's avatar Adrien Béraud
Browse files

connectionmanager: close all pending callback if all connections fails

backport of af3c2229c051975aa9d518f77ba597604546a6c6 in jami-daemon

Change-Id: I7ace35c63c97bb957359ea8c7dafae6f6f93f5b8
parent ec224b9a
No related branches found
No related tags found
No related merge requests found
...@@ -118,24 +118,21 @@ public: ...@@ -118,24 +118,21 @@ public:
{ {
if (isDestroying_.exchange(true)) if (isDestroying_.exchange(true))
return; return;
decltype(pendingOperations_) po;
{ {
std::lock_guard<std::mutex> lk(connectCbsMtx_); std::lock_guard<std::mutex> lk(connectCbsMtx_);
// Call all pending callbacks that channel is not ready po = std::move(pendingOperations_);
for (auto& [deviceId, pcbs] : pendingCbs_) }
for (auto& pending : pcbs) for (auto& [deviceId, pcbs] : po) {
for (auto& [id, pending] : pcbs.connecting)
pending.cb(nullptr, deviceId);
for (auto& [id, pending] : pcbs.waiting)
pending.cb(nullptr, deviceId); pending.cb(nullptr, deviceId);
pendingCbs_.clear();
} }
removeUnusedConnections(); removeUnusedConnections();
} }
struct PendingCb
{
std::string name;
ConnectCallback cb;
dht::Value::Id vid;
};
void connectDeviceStartIce(const std::shared_ptr<dht::crypto::PublicKey>& devicePk, void connectDeviceStartIce(const std::shared_ptr<dht::crypto::PublicKey>& devicePk,
const dht::Value::Id& vid, const dht::Value::Id& vid,
const std::string& connType, const std::string& connType,
...@@ -319,49 +316,72 @@ public: ...@@ -319,49 +316,72 @@ public:
* be done in parallel and we only want one socket * be done in parallel and we only want one socket
*/ */
std::mutex connectCbsMtx_ {}; std::mutex connectCbsMtx_ {};
std::map<DeviceId, std::vector<PendingCb>> pendingCbs_ {};
std::vector<PendingCb> extractPendingCallbacks(const DeviceId& deviceId, struct PendingCb
const dht::Value::Id vid = 0)
{ {
std::vector<PendingCb> ret; std::string name;
std::lock_guard<std::mutex> lk(connectCbsMtx_); ConnectCallback cb;
auto pendingIt = pendingCbs_.find(deviceId); };
if (pendingIt == pendingCbs_.end()) struct PendingOperations {
return ret; std::map<dht::Value::Id, PendingCb> connecting;
auto& pendings = pendingIt->second; std::map<dht::Value::Id, PendingCb> waiting;
if (vid == 0) { };
ret = std::move(pendings);
} else { std::map<DeviceId, PendingOperations> pendingOperations_ {};
for (auto it = pendings.begin(); it != pendings.end(); ++it) {
if (it->vid == vid) {
ret.emplace_back(std::move(*it));
pendings.erase(it);
break;
}
}
}
if (pendings.empty())
pendingCbs_.erase(pendingIt);
return ret;
}
std::vector<PendingCb> getPendingCallbacks(const DeviceId& deviceId, void executePendingOperations(const DeviceId& deviceId, const dht::Value::Id& vid, const std::shared_ptr<ChannelSocket>& sock)
const dht::Value::Id vid = 0)
{ {
std::vector<PendingCb> ret; std::vector<PendingCb> ret;
std::unique_lock<std::mutex> lk(connectCbsMtx_);
auto it = pendingOperations_.find(deviceId);
if (it == pendingOperations_.end())
return;
auto& pendingOperations = it->second;
if (vid == 0) {
// Extract all pending callbacks
for (auto& [vid, cb] : pendingOperations.connecting)
ret.emplace_back(std::move(cb));
pendingOperations.connecting.clear();
for (auto& [vid, cb] : pendingOperations.waiting)
ret.emplace_back(std::move(cb));
pendingOperations.waiting.clear();
} else if (auto n = pendingOperations.waiting.extract(vid)) {
// If it's a waiting operation, just move it
ret.emplace_back(std::move(n.mapped()));
} else if (auto n = pendingOperations.connecting.extract(vid)) {
ret.emplace_back(std::move(n.mapped()));
// If sock is nullptr, execute if it's the last connecting operation
if (!sock && pendingOperations.connecting.empty()) {
for (auto& [vid, cb] : pendingOperations.waiting)
ret.emplace_back(std::move(cb));
pendingOperations.waiting.clear();
for (auto& [vid, cb] : pendingOperations.connecting)
ret.emplace_back(std::move(cb));
pendingOperations.connecting.clear();
}
}
if (pendingOperations.waiting.empty() && pendingOperations.connecting.empty())
pendingOperations_.erase(it);
lk.unlock();
for (auto& cb : ret)
cb.cb(sock, deviceId);
}
std::map<dht::Value::Id, std::string> getPendingIds(const DeviceId& deviceId, const dht::Value::Id vid = 0)
{
std::map<dht::Value::Id, std::string> ret;
std::lock_guard<std::mutex> lk(connectCbsMtx_); std::lock_guard<std::mutex> lk(connectCbsMtx_);
auto pendingIt = pendingCbs_.find(deviceId); auto it = pendingOperations_.find(deviceId);
if (pendingIt == pendingCbs_.end()) if (it == pendingOperations_.end())
return ret; return ret;
auto& pendings = pendingIt->second; auto& pendingOp = it->second;
if (vid == 0) { for (const auto& [id, pc]: pendingOp.connecting) {
ret = pendings; if (vid == 0 || id == vid)
} else { ret[id] = pc.name;
std::copy_if(pendings.begin(), }
pendings.end(), for (const auto& [id, pc]: pendingOp.waiting) {
std::back_inserter(ret), if (vid == 0 || id == vid)
[&](auto pending) { return pending.vid == vid; }); ret[id] = pc.name;
} }
return ret; return ret;
} }
...@@ -616,23 +636,24 @@ ConnectionManager::Impl::connectDevice(const std::shared_ptr<dht::crypto::Certif ...@@ -616,23 +636,24 @@ ConnectionManager::Impl::connectDevice(const std::shared_ptr<dht::crypto::Certif
auto isConnectingToDevice = false; auto isConnectingToDevice = false;
{ {
std::lock_guard<std::mutex> lk(sthis->connectCbsMtx_); std::lock_guard<std::mutex> lk(sthis->connectCbsMtx_);
auto pendingsIt = sthis->pendingCbs_.find(deviceId); auto pendingsIt = sthis->pendingOperations_.find(deviceId);
if (pendingsIt != sthis->pendingCbs_.end()) { if (pendingsIt != sthis->pendingOperations_.end()) {
const auto& pendings = pendingsIt->second; const auto& pendings = pendingsIt->second;
while (std::find_if(pendings.begin(), pendings.end(), [&](const auto& it){ return it.vid == vid; }) != pendings.end()) { while (pendings.connecting.find(vid) != pendings.connecting.end()
vid = ValueIdDist(1, ID_MAX_VAL)(sthis->rand); && pendings.waiting.find(vid) != pendings.waiting.end()) {
vid = ValueIdDist(1, JAMI_ID_MAX_VAL)(sthis->account.rand);
} }
} }
// Check if already connecting // Check if already connecting
isConnectingToDevice = pendingsIt != sthis->pendingCbs_.end(); isConnectingToDevice = pendingsIt != sthis->pendingOperations_.end();
// Save current request for sendChannelRequest. // Save current request for sendChannelRequest.
// Note: do not return here, cause we can be in a state where first // Note: do not return here, cause we can be in a state where first
// socket is negotiated and first channel is pending // socket is negotiated and first channel is pending
// so return only after we checked the info // so return only after we checked the info
if (isConnectingToDevice) if (isConnectingToDevice && !forceNewSocket)
pendingsIt->second.emplace_back(PendingCb {name, std::move(cb), vid}); pendingsIt->second.waiting[vid] = PendingCb {name, std::move(cb)};
else else
sthis->pendingCbs_[deviceId] = {{name, std::move(cb), vid}}; sthis->pendingOperations_[deviceId].connecting[vid] = PendingCb {name, std::move(cb)};
} }
// Check if already negotiated // Check if already negotiated
...@@ -655,8 +676,7 @@ ConnectionManager::Impl::connectDevice(const std::shared_ptr<dht::crypto::Certif ...@@ -655,8 +676,7 @@ ConnectionManager::Impl::connectDevice(const std::shared_ptr<dht::crypto::Certif
} }
if (noNewSocket) { if (noNewSocket) {
// If no new socket is specified, we don't try to generate a new socket // If no new socket is specified, we don't try to generate a new socket
for (const auto& pending : sthis->extractPendingCallbacks(deviceId, vid)) sthis->executePendingOperations(deviceId, vid, nullptr);
pending.cb(nullptr, deviceId);
return; return;
} }
...@@ -665,8 +685,7 @@ ConnectionManager::Impl::connectDevice(const std::shared_ptr<dht::crypto::Certif ...@@ -665,8 +685,7 @@ ConnectionManager::Impl::connectDevice(const std::shared_ptr<dht::crypto::Certif
auto eraseInfo = [w, cbId] { auto eraseInfo = [w, cbId] {
if (auto shared = w.lock()) { if (auto shared = w.lock()) {
// If no new socket is specified, we don't try to generate a new socket // If no new socket is specified, we don't try to generate a new socket
for (const auto& pending : shared->extractPendingCallbacks(cbId.first, cbId.second)) shared->executePendingOperations(cbId.first, cbId.second, nullptr);
pending.cb(nullptr, cbId.first);
std::lock_guard<std::mutex> lk(shared->infosMtx_); std::lock_guard<std::mutex> lk(shared->infosMtx_);
shared->infos_.erase(cbId); shared->infos_.erase(cbId);
} }
...@@ -776,17 +795,15 @@ ConnectionManager::Impl::sendChannelRequest(std::shared_ptr<MultiplexedSocket>& ...@@ -776,17 +795,15 @@ ConnectionManager::Impl::sendChannelRequest(std::shared_ptr<MultiplexedSocket>&
auto channelSock = sock->addChannel(name); auto channelSock = sock->addChannel(name);
channelSock->onShutdown([name, deviceId, vid, w = weak()] { channelSock->onShutdown([name, deviceId, vid, w = weak()] {
auto shared = w.lock(); auto shared = w.lock();
if (shared) if (auto shared = w.lock())
for (const auto& pending : shared->extractPendingCallbacks(deviceId, vid)) shared->executePendingOperations(deviceId, vid, nullptr);
pending.cb(nullptr, deviceId);
}); });
channelSock->onReady( channelSock->onReady(
[wSock = std::weak_ptr<ChannelSocket>(channelSock), name, deviceId, vid, w = weak()]() { [wSock = std::weak_ptr<ChannelSocket>(channelSock), name, deviceId, vid, w = weak()]() {
auto shared = w.lock(); auto shared = w.lock();
auto channelSock = wSock.lock(); auto channelSock = wSock.lock();
if (shared) if (shared)
for (const auto& pending : shared->extractPendingCallbacks(deviceId, vid)) shared->executePendingOperations(deviceId, vid, channelSock);
pending.cb(channelSock, deviceId);
}); });
ChannelRequest val; ChannelRequest val;
...@@ -911,8 +928,7 @@ ConnectionManager::Impl::onTlsNegotiationDone(bool ok, ...@@ -911,8 +928,7 @@ ConnectionManager::Impl::onTlsNegotiationDone(bool ok,
deviceId, deviceId,
name, name,
vid); vid);
for (const auto& pending : extractPendingCallbacks(deviceId)) executePendingOperations(deviceId, vid, nullptr);
pending.cb(nullptr, deviceId);
} }
} else { } else {
// The socket is ready, store it // The socket is ready, store it
...@@ -934,12 +950,12 @@ ConnectionManager::Impl::onTlsNegotiationDone(bool ok, ...@@ -934,12 +950,12 @@ ConnectionManager::Impl::onTlsNegotiationDone(bool ok,
// Finally, open the channel and launch pending callbacks // Finally, open the channel and launch pending callbacks
if (info->socket_) { if (info->socket_) {
// Note: do not remove pending there it's done in sendChannelRequest // Note: do not remove pending there it's done in sendChannelRequest
for (const auto& pending : getPendingCallbacks(deviceId)) { for (const auto& [id, name] : getPendingIds(deviceId)) {
if (config_->logger) if (config_->logger)
config_->logger->debug("Send request on TLS socket for channel {} to {}", config_->logger->debug("Send request on TLS socket for channel {} to {}",
pending.name, name,
deviceId); deviceId.toString());
sendChannelRequest(info->socket_, pending.name, deviceId, pending.vid); sendChannelRequest(info->socket_, name, deviceId, id);
} }
} }
} }
...@@ -1085,8 +1101,7 @@ ConnectionManager::Impl::onDhtPeerRequest(const PeerConnectionRequest& req, ...@@ -1085,8 +1101,7 @@ ConnectionManager::Impl::onDhtPeerRequest(const PeerConnectionRequest& req,
auto eraseInfo = [w, id = req.id, deviceId] { auto eraseInfo = [w, id = req.id, deviceId] {
if (auto shared = w.lock()) { if (auto shared = w.lock()) {
// If no new socket is specified, we don't try to generate a new socket // If no new socket is specified, we don't try to generate a new socket
for (const auto& pending : shared->extractPendingCallbacks(deviceId, id)) shared->executePendingOperations(deviceId, id, nullptr);
pending.cb(nullptr, deviceId);
if (shared->connReadyCb_) if (shared->connReadyCb_)
shared->connReadyCb_(deviceId, "", nullptr); shared->connReadyCb_(deviceId, "", nullptr);
std::lock_guard<std::mutex> lk(shared->infosMtx_); std::lock_guard<std::mutex> lk(shared->infosMtx_);
...@@ -1202,8 +1217,7 @@ ConnectionManager::Impl::addNewMultiplexedSocket(const CallbackId& id, const std ...@@ -1202,8 +1217,7 @@ ConnectionManager::Impl::addNewMultiplexedSocket(const CallbackId& id, const std
} }
} }
for (const auto& cbId : ids) for (const auto& cbId : ids)
for (const auto& pending : sthis->extractPendingCallbacks(cbId.first, cbId.second)) sthis->executePendingOperations(cbId.first, cbId.second, nullptr);
pending.cb(nullptr, deviceId);
std::lock_guard<std::mutex> lk(sthis->infosMtx_); std::lock_guard<std::mutex> lk(sthis->infosMtx_);
sthis->infos_.erase({deviceId, vid}); sthis->infos_.erase({deviceId, vid});
...@@ -1522,8 +1536,8 @@ ConnectionManager::connectDevice(const std::shared_ptr<dht::crypto::Certificate> ...@@ -1522,8 +1536,8 @@ ConnectionManager::connectDevice(const std::shared_ptr<dht::crypto::Certificate>
bool bool
ConnectionManager::isConnecting(const DeviceId& deviceId, const std::string& name) const ConnectionManager::isConnecting(const DeviceId& deviceId, const std::string& name) const
{ {
auto pending = pimpl_->getPendingCallbacks(deviceId); auto pending = pimpl_->getPendingIds(deviceId);
return std::find_if(pending.begin(), pending.end(), [&](auto p) { return p.name == name; }) return std::find_if(pending.begin(), pending.end(), [&](auto p) { return p.second == name; })
!= pending.end(); != pending.end();
} }
...@@ -1549,8 +1563,7 @@ ConnectionManager::closeConnectionsWith(const std::string& peerUri) ...@@ -1549,8 +1563,7 @@ ConnectionManager::closeConnectionsWith(const std::string& peerUri)
} }
// Stop connections to all peers devices // Stop connections to all peers devices
for (const auto& deviceId : peersDevices) { for (const auto& deviceId : peersDevices) {
for (const auto& pending : pimpl_->extractPendingCallbacks(deviceId)) pimpl_->executePendingOperations(deviceId, 0, nullptr);
pending.cb(nullptr, deviceId);
// This will close the TLS Session // This will close the TLS Session
pimpl_->removeUnusedConnections(deviceId); pimpl_->removeUnusedConnections(deviceId);
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment