diff --git a/src/connectionmanager.cpp b/src/connectionmanager.cpp index cf78c542c5d695210fdcb0dab898f09687014686..0304c2196c756c61045e3079d227d2af472e6d7f 100644 --- a/src/connectionmanager.cpp +++ b/src/connectionmanager.cpp @@ -40,14 +40,14 @@ static constexpr std::chrono::seconds DHT_MSG_TIMEOUT {30}; static constexpr uint64_t ID_MAX_VAL = 9007199254740992; using ValueIdDist = std::uniform_int_distribution<dht::Value::Id>; -using CallbackId = std::pair<dhtnet::DeviceId, dht::Value::Id>; + std::string callbackIdToString(const dhtnet::DeviceId& did, const dht::Value::Id& vid) { return fmt::format("{} {}", did.to_view(), vid); } -CallbackId parseCallbackId(std::string_view ci) +std::pair<dhtnet::DeviceId, dht::Value::Id> parseCallbackId(std::string_view ci) { auto sep = ci.find(' '); std::string_view deviceIdString = ci.substr(0, sep); @@ -55,8 +55,7 @@ CallbackId parseCallbackId(std::string_view ci) dhtnet::DeviceId deviceId(deviceIdString); dht::Value::Id vid = std::stoul(std::string(vidString), nullptr, 10); - - return CallbackId(deviceId, vid); + return {deviceId, vid}; } std::shared_ptr<ConnectionManager::Config> @@ -101,12 +100,252 @@ struct ConnectionInfo // Used to store currently non ready TLS Socket std::unique_ptr<TlsSocketEndpoint> tls_ {nullptr}; std::shared_ptr<MultiplexedSocket> socket_ {}; - std::set<CallbackId> cbIds_ {}; + std::set<dht::Value::Id> cbIds_ {}; std::function<void(bool)> onConnected_; std::unique_ptr<asio::steady_timer> waitForAnswer_ {}; + + void shutdown() { + std::lock_guard<std::mutex> lk(mutex_); + if (tls_) + tls_->shutdown(); + if (socket_) + socket_->shutdown(); + if (waitForAnswer_) + waitForAnswer_->cancel(); + if (ice_) { + dht::ThreadPool::io().run( + [ice = std::shared_ptr<IceTransport>(std::move(ice_))] {}); + } + } + + std::map<std::string, std::string> + getInfo(const DeviceId& deviceId, dht::Value::Id valueId, tls::CertificateStore& certStore) const + { + std::map<std::string, std::string> connectionInfo; + connectionInfo["id"] = callbackIdToString(deviceId, valueId); + connectionInfo["device"] = deviceId.toString(); + auto cert = tls_ ? tls_->peerCertificate() : (socket_ ? socket_->peerCertificate() : nullptr); + if (not cert) + cert = certStore.getCertificate(deviceId.toString()); + if (cert) { + connectionInfo["peer"] = cert->issuer->getId().toString(); + } + if (socket_) { + connectionInfo["status"] = std::to_string(static_cast<int>(ConnectionStatus::Connected)); + connectionInfo["remoteAddress"] = socket_->getRemoteAddress(); + } else if (tls_) { + connectionInfo["status"] = std::to_string(static_cast<int>(ConnectionStatus::TLS)); + connectionInfo["remoteAddress"] = tls_->getRemoteAddress(); + } else if(ice_) { + connectionInfo["status"] = std::to_string(static_cast<int>(ConnectionStatus::ICE)); + connectionInfo["remoteAddress"] = ice_->getRemoteAddress(ICE_COMP_ID_SIP_TRANSPORT); + } + return connectionInfo; + } +}; + +struct PendingCb { + std::string name; + ConnectCallback cb; }; +struct DeviceInfo { + const DeviceId deviceId; + mutable std::mutex mtx_ {}; + std::map<dht::Value::Id, std::shared_ptr<ConnectionInfo>> info; + std::map<dht::Value::Id, PendingCb> connecting; + std::map<dht::Value::Id, PendingCb> waiting; + DeviceInfo(DeviceId id) : deviceId {id} {} + + inline bool isConnecting() const { + return !connecting.empty() || !waiting.empty(); + } + + inline bool empty() const { + return info.empty() && connecting.empty() && waiting.empty(); + } + + dht::Value::Id newId(std::mt19937_64& rand) const { + ValueIdDist dist(1, ID_MAX_VAL); + dht::Value::Id id; + do { + id = dist(rand); + } while (info.find(id) != info.end() + || connecting.find(id) != connecting.end() + || waiting.find(id) != waiting.end()); + return id; + } + + std::shared_ptr<ConnectionInfo> getConnectedInfo() const { + for (auto& [id, ci] : info) { + if (ci->socket_) + return ci; + } + return {}; + } + + std::vector<PendingCb> extractPendingOperations(dht::Value::Id vid, const std::shared_ptr<ChannelSocket>& sock, bool accepted = true) + { + std::vector<PendingCb> ret; + if (vid == 0) { + // Extract all pending callbacks + ret.reserve(connecting.size() + waiting.size()); + for (auto& [vid, cb] : connecting) + ret.emplace_back(std::move(cb)); + connecting.clear(); + for (auto& [vid, cb] : waiting) + ret.emplace_back(std::move(cb)); + waiting.clear(); + } else if (auto n = waiting.extract(vid)) { + // If it's a waiting operation, just move it + ret.emplace_back(std::move(n.mapped())); + } else if (auto n = connecting.extract(vid)) { + ret.emplace_back(std::move(n.mapped())); + // If sock is nullptr, execute if it's the last connecting operation + // If accepted is false, it means that underlying socket is ok, but channel is declined + if (!sock && connecting.empty() && accepted) { + for (auto& [vid, cb] : waiting) + ret.emplace_back(std::move(cb)); + waiting.clear(); + for (auto& [vid, cb] : connecting) + ret.emplace_back(std::move(cb)); + connecting.clear(); + } + } + return ret; + } + + std::vector<std::shared_ptr<ConnectionInfo>> extractUnusedConnections() { + std::vector<std::shared_ptr<ConnectionInfo>> unused {}; + for (auto& [id, info] : info) + unused.emplace_back(std::move(info)); + info.clear(); + return unused; + } + + void executePendingOperations(std::unique_lock<std::mutex>& lock, dht::Value::Id vid, const std::shared_ptr<ChannelSocket>& sock, bool accepted = true) { + auto ops = extractPendingOperations(vid, sock, accepted); + lock.unlock(); + for (auto& cb : ops) + cb.cb(sock, deviceId); + } + void executePendingOperations(dht::Value::Id vid, const std::shared_ptr<ChannelSocket>& sock, bool accepted = true) { + std::unique_lock<std::mutex> lock(mtx_); + executePendingOperations(lock, vid, sock, accepted); + } + + std::map<dht::Value::Id, std::string> getPendingIds() const { + std::map<dht::Value::Id, std::string> ret; + for (const auto& [id, pc]: connecting) + ret[id] = pc.name; + for (const auto& [id, pc]: waiting) + ret[id] = pc.name; + return ret; + } + + std::vector<std::map<std::string, std::string>> + getConnectionList(tls::CertificateStore& certStore) const { + std::lock_guard<std::mutex> lk(mtx_); + std::vector<std::map<std::string, std::string>> ret; + ret.reserve(info.size()); + for (auto& [id, ci] : info) { + std::lock_guard<std::mutex> lk(ci->mutex_); + ret.emplace_back(ci->getInfo(deviceId, id, certStore)); + } + auto cert = certStore.getCertificate(deviceId.toString()); + for (const auto& [vid, ci] : connecting) { + ret.emplace_back(std::map<std::string, std::string> { + {"id", callbackIdToString(deviceId, vid)}, + {"status", std::to_string(static_cast<int>(ConnectionStatus::Connecting))}, + {"device", deviceId.toString()}, + {"peer", cert ? cert->issuer->getId().toString() : ""} + }); + } + for (const auto& [vid, ci] : waiting) { + ret.emplace_back(std::map<std::string, std::string> { + {"id", callbackIdToString(deviceId, vid)}, + {"status", std::to_string(static_cast<int>(ConnectionStatus::Waiting))}, + {"device", deviceId.toString()}, + {"peer", cert ? cert->issuer->getId().toString() : ""} + }); + } + return ret; + } +}; + +class DeviceInfoSet { +public: + std::shared_ptr<DeviceInfo> getDeviceInfo(const DeviceId& deviceId) { + std::lock_guard<std::mutex> lk(mtx_); + auto it = infos_.find(deviceId); + if (it != infos_.end()) + return it->second; + return {}; + } + + std::vector<std::shared_ptr<DeviceInfo>> getDeviceInfos() { + std::vector<std::shared_ptr<DeviceInfo>> deviceInfos; + std::lock_guard<std::mutex> lk(mtx_); + deviceInfos.reserve(infos_.size()); + for (auto& [deviceId, info] : infos_) + deviceInfos.emplace_back(info); + return deviceInfos; + } + + std::shared_ptr<DeviceInfo> createDeviceInfo(const DeviceId& deviceId) { + std::lock_guard<std::mutex> lk(mtx_); + auto& info = infos_[deviceId]; + if (!info) + info = std::make_shared<DeviceInfo>(deviceId); + return info; + } + + bool removeDeviceInfo(const DeviceId& deviceId) { + std::lock_guard<std::mutex> lk(mtx_); + return infos_.erase(deviceId) != 0; + } + + std::shared_ptr<ConnectionInfo> getInfo(const DeviceId& deviceId, const dht::Value::Id& id) { + if (auto info = getDeviceInfo(deviceId)) { + std::lock_guard<std::mutex> lk(info->mtx_); + auto it = info->info.find(id); + if (it != info->info.end()) + return it->second; + } + return {}; + } + + std::vector<std::shared_ptr<ConnectionInfo>> getConnectedInfos() { + auto deviceInfos = getDeviceInfos(); + std::vector<std::shared_ptr<ConnectionInfo>> ret; + ret.reserve(deviceInfos.size()); + for (auto& info : deviceInfos) { + std::lock_guard<std::mutex> lk(info->mtx_); + for (auto& [id, ci] : info->info) { + if (ci->socket_) + ret.emplace_back(ci); + } + } + return ret; + } + std::vector<std::shared_ptr<DeviceInfo>> shutdown() { + std::vector<std::shared_ptr<DeviceInfo>> ret; + std::lock_guard<std::mutex> lk(mtx_); + ret.reserve(infos_.size()); + for (auto& [deviceId, info] : infos_) { + ret.emplace_back(std::move(info)); + } + infos_.clear(); + return ret; + } + +private: + std::mutex mtx_ {}; + std::map<DeviceId, std::shared_ptr<DeviceInfo>> infos_ {}; +}; + + /** * returns whether or not UPnP is enabled and active_ * ie: if it is able to make port mappings @@ -124,7 +363,7 @@ class ConnectionManager::Impl : public std::enable_shared_from_this<ConnectionMa public: explicit Impl(std::shared_ptr<ConnectionManager::Config> config_) : config_ {std::move(createConfig(config_))} - , rand {dht::crypto::getSeededRandomEngine<std::mt19937_64>()} + , rand_ {dht::crypto::getSeededRandomEngine<std::mt19937_64>()} { loadTreatedMessages(); if(!config_->ioContext) { @@ -151,61 +390,40 @@ public: std::shared_ptr<dht::DhtRunner> dht() { return config_->dht; } const dht::crypto::Identity& identity() const { return config_->id; } - void removeUnusedConnections(const DeviceId& deviceId = {}) + void shutdown() { - std::vector<std::shared_ptr<ConnectionInfo>> unused {}; - - { - std::lock_guard<std::mutex> lk(infosMtx_); - for (auto it = infos_.begin(); it != infos_.end();) { - auto& [key, info] = *it; - if (info && (!deviceId || key.first == deviceId)) { - unused.emplace_back(std::move(info)); - it = infos_.erase(it); - } else { - ++it; - } - } - } - for (auto& info: unused) { - if (info->tls_) - info->tls_->shutdown(); - if (info->socket_) - info->socket_->shutdown(); - if (info->waitForAnswer_) - info->waitForAnswer_->cancel(); + if (isDestroying_.exchange(true)) + return; + std::vector<std::shared_ptr<ConnectionInfo>> unused; + std::vector<std::pair<DeviceId, std::vector<PendingCb>>> pending; + for (auto& dinfo: infos_.shutdown()) { + std::lock_guard<std::mutex> lk(dinfo->mtx_); + auto p = dinfo->extractPendingOperations(0, nullptr, false); + if (!p.empty()) + pending.emplace_back(dinfo->deviceId, std::move(p)); + auto uc = dinfo->extractUnusedConnections(); + unused.insert(unused.end(), std::make_move_iterator(uc.begin()), std::make_move_iterator(uc.end())); } + for (auto& info: unused) + info->shutdown(); + for (auto& op: pending) + for (auto& cb: op.second) + cb.cb(nullptr, op.first); if (!unused.empty()) dht::ThreadPool::io().run([infos = std::move(unused)]() mutable { infos.clear(); }); } - void shutdown() - { - if (isDestroying_.exchange(true)) - return; - decltype(pendingOperations_) po; - { - std::lock_guard<std::mutex> lk(connectCbsMtx_); - po = std::move(pendingOperations_); - } - for (auto& [deviceId, pcbs] : po) { - for (auto& [id, pending] : pcbs.connecting) - pending.cb(nullptr, deviceId); - for (auto& [id, pending] : pcbs.waiting) - pending.cb(nullptr, deviceId); - } - - removeUnusedConnections(); - } - - void connectDeviceStartIce(const std::shared_ptr<dht::crypto::PublicKey>& devicePk, + void connectDeviceStartIce(const std::shared_ptr<ConnectionInfo>& info, + const std::shared_ptr<dht::crypto::PublicKey>& devicePk, const dht::Value::Id& vid, const std::string& connType, std::function<void(bool)> onConnected); - void onResponse(const asio::error_code& ec, const DeviceId& deviceId, const dht::Value::Id& vid); - bool connectDeviceOnNegoDone(const DeviceId& deviceId, + void onResponse(const asio::error_code& ec, const std::weak_ptr<ConnectionInfo>& info, const DeviceId& deviceId, const dht::Value::Id& vid); + bool connectDeviceOnNegoDone(const std::weak_ptr<DeviceInfo>& dinfo, + const std::shared_ptr<ConnectionInfo>& info, + const DeviceId& deviceId, const std::string& name, const dht::Value::Id& vid, const std::shared_ptr<dht::crypto::Certificate>& cert); @@ -235,9 +453,9 @@ public: * @param vid channel's id * @param deviceId to identify the linked ConnectCallback */ - void sendChannelRequest(std::shared_ptr<MultiplexedSocket>& sock, + void sendChannelRequest(const std::weak_ptr<DeviceInfo>& dinfo, + const std::shared_ptr<MultiplexedSocket>& sock, const std::string& name, - const DeviceId& deviceId, const dht::Value::Id& vid); /** * Triggered when a PeerConnectionRequest comes from the DHT @@ -245,15 +463,29 @@ public: void answerTo(IceTransport& ice, const dht::Value::Id& id, const std::shared_ptr<dht::crypto::PublicKey>& fromPk); - bool onRequestStartIce(const PeerConnectionRequest& req); - bool onRequestOnNegoDone(const PeerConnectionRequest& req); + bool onRequestStartIce(const std::shared_ptr<ConnectionInfo>& info, const PeerConnectionRequest& req); + bool onRequestOnNegoDone(const std::weak_ptr<DeviceInfo>& dinfo, const std::shared_ptr<ConnectionInfo>& info, const PeerConnectionRequest& req); void onDhtPeerRequest(const PeerConnectionRequest& req, const std::shared_ptr<dht::crypto::Certificate>& cert); + /** + * Triggered when a new TLS socket is ready to use + * @param ok If succeed + * @param deviceId Related device + * @param vid vid of the connection request + * @param name non empty if TLS was created by connectDevice() + */ + void onTlsNegotiationDone(const std::shared_ptr<DeviceInfo>& dinfo, + const std::shared_ptr<ConnectionInfo>& info, + bool ok, + const DeviceId& deviceId, + const dht::Value::Id& vid, + const std::string& name = ""); - void addNewMultiplexedSocket(const CallbackId& id, const std::shared_ptr<ConnectionInfo>& info); + void addNewMultiplexedSocket(const std::weak_ptr<DeviceInfo>& dinfo, const DeviceId& deviceId, const dht::Value::Id& vid, const std::shared_ptr<ConnectionInfo>& info); void onPeerResponse(const PeerConnectionRequest& req); void onDhtConnected(const dht::crypto::PublicKey& devicePk); + const std::shared_future<tls::DhParams> dhParams() const; tls::CertificateStore& certStore() const { return *config_->certStore; } @@ -327,162 +559,31 @@ public: */ bool getUPnPActive() const; - /** - * Triggered when a new TLS socket is ready to use - * @param ok If succeed - * @param deviceId Related device - * @param vid vid of the connection request - * @param name non empty if TLS was created by connectDevice() - */ - void onTlsNegotiationDone(bool ok, - const DeviceId& deviceId, - const dht::Value::Id& vid, - const std::string& name = ""); - std::shared_ptr<ConnectionManager::Config> config_; std::unique_ptr<std::thread> ioContextRunner_; - mutable std::mt19937_64 rand; + mutable std::mutex randMtx_; + mutable std::mt19937_64 rand_; iOSConnectedCallback iOSConnectedCb_ {}; - std::mutex infosMtx_ {}; - // Note: Someone can ask multiple sockets, so to avoid any race condition, - // each device can have multiple multiplexed sockets. - std::map<CallbackId, std::shared_ptr<ConnectionInfo>> infos_ {}; - - std::shared_ptr<ConnectionInfo> getInfo(const DeviceId& deviceId, const dht::Value::Id& id) - { - std::lock_guard<std::mutex> lk(infosMtx_); - auto it = infos_.find({deviceId, id}); - if (it != infos_.end()) - return it->second; - return {}; - } - - std::shared_ptr<ConnectionInfo> getConnectedInfo(const DeviceId& deviceId) - { - std::lock_guard<std::mutex> lk(infosMtx_); - auto it = std::find_if(infos_.begin(), infos_.end(), [&](const auto& item) { - auto& [key, value] = item; - return key.first == deviceId && value && value->socket_; - }); - if (it != infos_.end()) - return it->second; - return {}; - } + DeviceInfoSet infos_ {}; ChannelRequestCallback channelReqCb_ {}; ConnectionReadyCallback connReadyCb_ {}; onICERequestCallback iceReqCb_ {}; - - /** - * Stores callback from connectDevice - * @note: each device needs a vector because several connectDevice can - * be done in parallel and we only want one socket - */ - std::mutex connectCbsMtx_ {}; - - - struct PendingCb - { - std::string name; - ConnectCallback cb; - }; - struct PendingOperations { - std::map<dht::Value::Id, PendingCb> connecting; - std::map<dht::Value::Id, PendingCb> waiting; - }; - - std::map<DeviceId, PendingOperations> pendingOperations_ {}; - - void executePendingOperations(const DeviceId& deviceId, const dht::Value::Id& vid, const std::shared_ptr<ChannelSocket>& sock, bool accepted = true) - { - std::vector<PendingCb> ret; - std::unique_lock<std::mutex> lk(connectCbsMtx_); - auto it = pendingOperations_.find(deviceId); - if (it == pendingOperations_.end()) - return; - auto& pendingOperations = it->second; - if (vid == 0) { - // Extract all pending callbacks - for (auto& [vid, cb] : pendingOperations.connecting) - ret.emplace_back(std::move(cb)); - pendingOperations.connecting.clear(); - for (auto& [vid, cb] : pendingOperations.waiting) - ret.emplace_back(std::move(cb)); - pendingOperations.waiting.clear(); - } else if (auto n = pendingOperations.waiting.extract(vid)) { - // If it's a waiting operation, just move it - ret.emplace_back(std::move(n.mapped())); - } else if (auto n = pendingOperations.connecting.extract(vid)) { - ret.emplace_back(std::move(n.mapped())); - // If sock is nullptr, execute if it's the last connecting operation - // If accepted is false, it means that underlying socket is ok, but channel is declined - if (!sock && pendingOperations.connecting.empty() && accepted) { - for (auto& [vid, cb] : pendingOperations.waiting) - ret.emplace_back(std::move(cb)); - pendingOperations.waiting.clear(); - for (auto& [vid, cb] : pendingOperations.connecting) - ret.emplace_back(std::move(cb)); - pendingOperations.connecting.clear(); - } - } - if (pendingOperations.waiting.empty() && pendingOperations.connecting.empty()) - pendingOperations_.erase(it); - lk.unlock(); - for (auto& cb : ret) - cb.cb(sock, deviceId); - } - - std::map<dht::Value::Id, std::string> getPendingIds(const DeviceId& deviceId, const dht::Value::Id vid = 0) - { - std::map<dht::Value::Id, std::string> ret; - std::lock_guard<std::mutex> lk(connectCbsMtx_); - auto it = pendingOperations_.find(deviceId); - if (it == pendingOperations_.end()) - return ret; - auto& pendingOp = it->second; - for (const auto& [id, pc]: pendingOp.connecting) { - if (vid == 0 || id == vid) - ret[id] = pc.name; - } - for (const auto& [id, pc]: pendingOp.waiting) { - if (vid == 0 || id == vid) - ret[id] = pc.name; - } - return ret; - } - - std::shared_ptr<ConnectionManager::Impl> shared() - { - return std::static_pointer_cast<ConnectionManager::Impl>(shared_from_this()); - } - std::shared_ptr<ConnectionManager::Impl const> shared() const - { - return std::static_pointer_cast<ConnectionManager::Impl const>(shared_from_this()); - } - std::weak_ptr<ConnectionManager::Impl> weak() - { - return std::static_pointer_cast<ConnectionManager::Impl>(shared_from_this()); - } - std::weak_ptr<ConnectionManager::Impl const> weak() const - { - return std::static_pointer_cast<ConnectionManager::Impl const>(shared_from_this()); - } - std::atomic_bool isDestroying_ {false}; }; void ConnectionManager::Impl::connectDeviceStartIce( + const std::shared_ptr<ConnectionInfo>& info, const std::shared_ptr<dht::crypto::PublicKey>& devicePk, const dht::Value::Id& vid, const std::string& connType, std::function<void(bool)> onConnected) { auto deviceId = devicePk->getLongId(); - auto info = getInfo(deviceId, vid); if (!info) { onConnected(false); return; @@ -542,17 +643,18 @@ ConnectionManager::Impl::connectDeviceStartIce( std::chrono::steady_clock::now() + DHT_MSG_TIMEOUT); info->waitForAnswer_->async_wait( - std::bind(&ConnectionManager::Impl::onResponse, this, std::placeholders::_1, deviceId, vid)); + std::bind(&ConnectionManager::Impl::onResponse, this, std::placeholders::_1, info, deviceId, vid)); } void ConnectionManager::Impl::onResponse(const asio::error_code& ec, + const std::weak_ptr<ConnectionInfo>& winfo, const DeviceId& deviceId, const dht::Value::Id& vid) { if (ec == asio::error::operation_aborted) return; - auto info = getInfo(deviceId, vid); + auto info = winfo.lock(); if (!info) return; @@ -587,12 +689,13 @@ ConnectionManager::Impl::onResponse(const asio::error_code& ec, bool ConnectionManager::Impl::connectDeviceOnNegoDone( + const std::weak_ptr<DeviceInfo>& dinfo, + const std::shared_ptr<ConnectionInfo>& info, const DeviceId& deviceId, const std::string& name, const dht::Value::Id& vid, const std::shared_ptr<dht::crypto::Certificate>& cert) { - auto info = getInfo(deviceId, vid); if (!info) return false; @@ -625,10 +728,10 @@ ConnectionManager::Impl::connectDeviceOnNegoDone( *cert); info->tls_->setOnReady( - [w = weak(), deviceId = std::move(deviceId), vid = std::move(vid), name = std::move(name)]( + [w = weak_from_this(), dinfo, winfo=std::weak_ptr(info), deviceId = std::move(deviceId), vid = std::move(vid), name = std::move(name)]( bool ok) { if (auto shared = w.lock()) - shared->onTlsNegotiationDone(ok, deviceId, vid, name); + shared->onTlsNegotiationDone(dinfo.lock(), winfo.lock(), ok, deviceId, vid, name); }); return true; } @@ -650,7 +753,7 @@ ConnectionManager::Impl::connectDevice(const DeviceId& deviceId, return; } findCertificate(deviceId, - [w = weak(), + [w = weak_from_this(), deviceId, name, cb = std::move(cb), @@ -695,7 +798,7 @@ ConnectionManager::Impl::connectDevice(const dht::InfoHash& deviceId, return; } findCertificate(deviceId, - [w = weak(), + [w = weak_from_this(), deviceId, name, cb = std::move(cb), @@ -734,7 +837,7 @@ ConnectionManager::Impl::connectDevice(const std::shared_ptr<dht::crypto::Certif const std::string& connType) { // Avoid dht operation in a DHT callback to avoid deadlocks - dht::ThreadPool::computation().run([w = weak(), + dht::ThreadPool::computation().run([w = weak_from_this(), name = std::move(name), cert = std::move(cert), cb = std::move(cb), @@ -748,40 +851,35 @@ ConnectionManager::Impl::connectDevice(const std::shared_ptr<dht::crypto::Certif cb(nullptr, deviceId); return; } + auto di = sthis->infos_.createDeviceInfo(deviceId); + std::unique_lock<std::mutex> lk(di->mtx_); + dht::Value::Id vid; - auto isConnectingToDevice = false; { - std::lock_guard<std::mutex> lk(sthis->connectCbsMtx_); - vid = ValueIdDist(1, ID_MAX_VAL)(sthis->rand); - auto pendingsIt = sthis->pendingOperations_.find(deviceId); - if (pendingsIt != sthis->pendingOperations_.end()) { - const auto& pendings = pendingsIt->second; - while (pendings.connecting.find(vid) != pendings.connecting.end() - || pendings.waiting.find(vid) != pendings.waiting.end()) { - vid = ValueIdDist(1, ID_MAX_VAL)(sthis->rand); - } - } - // Check if already connecting - isConnectingToDevice = pendingsIt != sthis->pendingOperations_.end(); - // Save current request for sendChannelRequest. - // Note: do not return here, cause we can be in a state where first - // socket is negotiated and first channel is pending - // so return only after we checked the info - if (isConnectingToDevice && !forceNewSocket) - pendingsIt->second.waiting[vid] = PendingCb {name, std::move(cb)}; - else - sthis->pendingOperations_[deviceId].connecting[vid] = PendingCb {name, std::move(cb)}; + std::lock_guard<std::mutex> lkr(sthis->randMtx_); + vid = di->newId(sthis->rand_); } + // Check if already connecting + auto isConnectingToDevice = di->isConnecting(); + // Note: we can be in a state where first + // socket is negotiated and first channel is pending + // so return only after we checked the info + if (isConnectingToDevice && !forceNewSocket) + di->waiting[vid] = PendingCb {name, std::move(cb)}; + else + di->connecting[vid] = PendingCb {name, std::move(cb)}; + // Check if already negotiated - CallbackId cbId(deviceId, vid); - if (auto info = sthis->getConnectedInfo(deviceId)) { - std::lock_guard<std::mutex> lk(info->mutex_); - if (info->socket_) { + if (auto info = di->getConnectedInfo()) { + std::unique_lock<std::mutex> lkc(info->mutex_); + if (auto sock = info->socket_) { + info->cbIds_.emplace(vid); + lkc.unlock(); + lk.unlock(); if (sthis->config_->logger) sthis->config_->logger->debug("[device {}] Peer already connected. Add a new channel", deviceId); - info->cbIds_.emplace(cbId); - sthis->sendChannelRequest(info->socket_, name, deviceId, vid); + sthis->sendChannelRequest(di, sock, name, vid); return; } } @@ -793,18 +891,24 @@ ConnectionManager::Impl::connectDevice(const std::shared_ptr<dht::crypto::Certif } if (noNewSocket) { // If no new socket is specified, we don't try to generate a new socket - sthis->executePendingOperations(deviceId, vid, nullptr); + di->executePendingOperations(lk, vid, nullptr); return; } // Note: used when the ice negotiation fails to erase // all stored structures. - auto eraseInfo = [w, cbId] { - if (auto shared = w.lock()) { - // If no new socket is specified, we don't try to generate a new socket - shared->executePendingOperations(cbId.first, cbId.second, nullptr); - std::lock_guard<std::mutex> lk(shared->infosMtx_); - shared->infos_.erase(cbId); + auto eraseInfo = [w, diw=std::weak_ptr(di), vid] { + if (auto di = diw.lock()) { + std::unique_lock<std::mutex> lk(di->mtx_); + di->info.erase(vid); + auto ops = di->extractPendingOperations(vid, nullptr); + if (di->empty()) { + if (auto shared = w.lock()) + shared->infos_.removeDeviceInfo(di->deviceId); + } + lk.unlock(); + for (const auto& op: ops) + op.cb(nullptr, di->deviceId); } }; @@ -812,6 +916,7 @@ ConnectionManager::Impl::connectDevice(const std::shared_ptr<dht::crypto::Certif sthis->getIceOptions([w, deviceId = std::move(deviceId), devicePk = std::move(devicePk), + diw=std::weak_ptr(di), name = std::move(name), cert = std::move(cert), vid, @@ -822,17 +927,22 @@ ConnectionManager::Impl::connectDevice(const std::shared_ptr<dht::crypto::Certif dht::ThreadPool::io().run([eraseInfo = std::move(eraseInfo)] { eraseInfo(); }); return; } + auto info = std::make_shared<ConnectionInfo>(); + auto winfo = std::weak_ptr(info); ice_config.tcpEnable = true; ice_config.onInitDone = [w, devicePk = std::move(devicePk), name = std::move(name), cert = std::move(cert), + diw, + winfo = std::weak_ptr(info), vid, connType, eraseInfo](bool ok) { dht::ThreadPool::io().run([w = std::move(w), devicePk = std::move(devicePk), - vid = std::move(vid), + vid, + winfo, eraseInfo, connType, ok] { auto sthis = w.lock(); @@ -842,7 +952,7 @@ ConnectionManager::Impl::connectDevice(const std::shared_ptr<dht::crypto::Certif eraseInfo(); return; } - sthis->connectDeviceStartIce(devicePk, vid, connType, [=](bool ok) { + sthis->connectDeviceStartIce(winfo.lock(), devicePk, vid, connType, [=](bool ok) { if (!ok) { dht::ThreadPool::io().run([eraseInfo = std::move(eraseInfo)] { eraseInfo(); }); } @@ -853,27 +963,30 @@ ConnectionManager::Impl::connectDevice(const std::shared_ptr<dht::crypto::Certif deviceId, name, cert = std::move(cert), + diw, + winfo = std::weak_ptr(info), vid, eraseInfo](bool ok) { dht::ThreadPool::io().run([w = std::move(w), deviceId = std::move(deviceId), name = std::move(name), cert = std::move(cert), + diw = std::move(diw), + winfo = std::move(winfo), vid = std::move(vid), eraseInfo = std::move(eraseInfo), ok] { auto sthis = w.lock(); if (!ok && sthis && sthis->config_->logger) sthis->config_->logger->error("[device {}] ICE negotiation failed.", deviceId); - if (!sthis || !ok || !sthis->connectDeviceOnNegoDone(deviceId, name, vid, cert)) + if (!sthis || !ok || !sthis->connectDeviceOnNegoDone(diw, winfo.lock(), deviceId, name, vid, cert)) eraseInfo(); }); }; - auto info = std::make_shared<ConnectionInfo>(); - { - std::lock_guard<std::mutex> lk(sthis->infosMtx_); - sthis->infos_[{deviceId, vid}] = info; + if (auto di = diw.lock()) { + std::lock_guard<std::mutex> lk(di->mtx_); + di->info[vid] = info; } std::unique_lock<std::mutex> lk {info->mutex_}; ice_config.master = false; @@ -903,23 +1016,20 @@ ConnectionManager::Impl::connectDevice(const std::shared_ptr<dht::crypto::Certif } void -ConnectionManager::Impl::sendChannelRequest(std::shared_ptr<MultiplexedSocket>& sock, +ConnectionManager::Impl::sendChannelRequest(const std::weak_ptr<DeviceInfo>& dinfo, + const std::shared_ptr<MultiplexedSocket>& sock, const std::string& name, - const DeviceId& deviceId, const dht::Value::Id& vid) { auto channelSock = sock->addChannel(name); - channelSock->onShutdown([name, deviceId, vid, w = weak()] { - auto shared = w.lock(); - if (auto shared = w.lock()) - shared->executePendingOperations(deviceId, vid, nullptr); + channelSock->onShutdown([dinfo, name, vid] { + if (auto info = dinfo.lock()) + info->executePendingOperations(vid, nullptr); }); channelSock->onReady( - [wSock = std::weak_ptr<ChannelSocket>(channelSock), name, deviceId, vid, w = weak()](bool accepted) { - auto shared = w.lock(); - auto channelSock = wSock.lock(); - if (shared) - shared->executePendingOperations(deviceId, vid, accepted ? channelSock : nullptr, accepted); + [dinfo, wSock = std::weak_ptr(channelSock), name, vid](bool accepted) { + if (auto info = dinfo.lock()) + info->executePendingOperations(vid, accepted ? wSock.lock() : nullptr, accepted); }); ChannelRequest val; @@ -937,7 +1047,7 @@ ConnectionManager::Impl::sendChannelRequest(std::shared_ptr<MultiplexedSocket>& if (res < 0) { // TODO check if we should handle errors here if (config_->logger) - config_->logger->error("[device {}] sendChannelRequest failed - error: {}", deviceId, ec.message()); + config_->logger->error("sendChannelRequest failed - error: {}", ec.message()); } } @@ -945,7 +1055,7 @@ void ConnectionManager::Impl::onPeerResponse(const PeerConnectionRequest& req) { auto device = req.owner->getLongId(); - if (auto info = getInfo(device, req.id)) { + if (auto info = infos_.getInfo(device, req.id)) { if (config_->logger) config_->logger->debug("[device {}] New response received", device); std::lock_guard<std::mutex> lk {info->mutex_}; @@ -955,6 +1065,7 @@ ConnectionManager::Impl::onPeerResponse(const PeerConnectionRequest& req) info->waitForAnswer_->async_wait(std::bind(&ConnectionManager::Impl::onResponse, this, std::placeholders::_1, + std::weak_ptr(info), device, req.id)); } else { @@ -970,7 +1081,7 @@ ConnectionManager::Impl::onDhtConnected(const dht::crypto::PublicKey& devicePk) return; dht()->listen<PeerConnectionRequest>( dht::InfoHash::get(PeerConnectionRequest::key_prefix + devicePk.getId().toString()), - [w = weak()](PeerConnectionRequest&& req) { + [w = weak_from_this()](PeerConnectionRequest&& req) { auto shared = w.lock(); if (!shared) return false; @@ -1018,7 +1129,9 @@ ConnectionManager::Impl::onDhtConnected(const dht::crypto::PublicKey& devicePk) } void -ConnectionManager::Impl::onTlsNegotiationDone(bool ok, +ConnectionManager::Impl::onTlsNegotiationDone(const std::shared_ptr<DeviceInfo>& dinfo, + const std::shared_ptr<ConnectionInfo>& info, + bool ok, const DeviceId& deviceId, const dht::Value::Id& vid, const std::string& name) @@ -1044,7 +1157,7 @@ ConnectionManager::Impl::onTlsNegotiationDone(bool ok, deviceId, name, vid); - executePendingOperations(deviceId, vid, nullptr); + dinfo->executePendingOperations(vid, nullptr); } } else { // The socket is ready, store it @@ -1061,17 +1174,19 @@ ConnectionManager::Impl::onTlsNegotiationDone(bool ok, vid); } - auto info = getInfo(deviceId, vid); - addNewMultiplexedSocket({deviceId, vid}, info); + // Note: do not remove pending there it's done in sendChannelRequest + std::unique_lock<std::mutex> lk2 {dinfo->mtx_}; + auto pendingIds = dinfo->getPendingIds(); + lk2.unlock(); + std::unique_lock<std::mutex> lk {info->mutex_}; + addNewMultiplexedSocket(dinfo, deviceId, vid, info); // Finally, open the channel and launch pending callbacks - if (info->socket_) { - // Note: do not remove pending there it's done in sendChannelRequest - for (const auto& [id, name] : getPendingIds(deviceId)) { - if (config_->logger) - config_->logger->debug("[device {}] Send request on TLS socket for channel {}", - deviceId, name); - sendChannelRequest(info->socket_, name, deviceId, id); - } + lk.unlock(); + for (const auto& [id, name]: pendingIds) { + if (config_->logger) + config_->logger->debug("[device {}] Send request on TLS socket for channel {}", + deviceId, name); + sendChannelRequest(dinfo, info->socket_, name, id); } } } @@ -1113,13 +1228,12 @@ ConnectionManager::Impl::answerTo(IceTransport& ice, } bool -ConnectionManager::Impl::onRequestStartIce(const PeerConnectionRequest& req) +ConnectionManager::Impl::onRequestStartIce(const std::shared_ptr<ConnectionInfo>& info, const PeerConnectionRequest& req) { - auto deviceId = req.owner->getLongId(); - auto info = getInfo(deviceId, req.id); if (!info) return false; + auto deviceId = req.owner->getLongId(); std::unique_lock<std::mutex> lk {info->mutex_}; auto& ice = info->ice_; if (!ice) { @@ -1144,13 +1258,12 @@ ConnectionManager::Impl::onRequestStartIce(const PeerConnectionRequest& req) } bool -ConnectionManager::Impl::onRequestOnNegoDone(const PeerConnectionRequest& req) +ConnectionManager::Impl::onRequestOnNegoDone(const std::weak_ptr<DeviceInfo>& dinfo, const std::shared_ptr<ConnectionInfo>& info, const PeerConnectionRequest& req) { - auto deviceId = req.owner->getLongId(); - auto info = getInfo(deviceId, req.id); if (!info) return false; + auto deviceId = req.owner->getLongId(); std::unique_lock<std::mutex> lk {info->mutex_}; auto& ice = info->ice_; if (!ice) { @@ -1176,7 +1289,7 @@ ConnectionManager::Impl::onRequestOnNegoDone(const PeerConnectionRequest& req) config_->ioContext, identity(), dhParams(), - [ph, deviceId, w=weak(), l=config_->logger](const dht::crypto::Certificate& cert) { + [ph, deviceId, w=weak_from_this(), l=config_->logger](const dht::crypto::Certificate& cert) { auto shared = w.lock(); if (!shared) return false; @@ -1194,9 +1307,9 @@ ConnectionManager::Impl::onRequestOnNegoDone(const PeerConnectionRequest& req) }); info->tls_->setOnReady( - [w = weak(), deviceId = std::move(deviceId), vid = std::move(req.id)](bool ok) { + [w = weak_from_this(), dinfo, winfo=std::weak_ptr(info), deviceId = std::move(deviceId), vid = std::move(req.id)](bool ok) { if (auto shared = w.lock()) - shared->onTlsNegotiationDone(ok, deviceId, vid); + shared->onTlsNegotiationDone(dinfo.lock(), winfo.lock(), ok, deviceId, vid); }); return true; } @@ -1215,25 +1328,41 @@ ConnectionManager::Impl::onDhtPeerRequest(const PeerConnectionRequest& req, } // Because the connection is accepted, create an ICE socket. - getIceOptions([w = weak(), req, deviceId](auto&& ice_config) { + getIceOptions([w = weak_from_this(), req, deviceId](auto&& ice_config) { auto shared = w.lock(); if (!shared) return; + + auto di = shared->infos_.createDeviceInfo(deviceId); + auto info = std::make_shared<ConnectionInfo>(); + auto wdi = std::weak_ptr(di); + auto winfo = std::weak_ptr(info); + // Note: used when the ice negotiation fails to erase // all stored structures. - auto eraseInfo = [w, id = req.id, deviceId] { - if (auto shared = w.lock()) { - // If no new socket is specified, we don't try to generate a new socket - shared->executePendingOperations(deviceId, id, nullptr); - if (shared->connReadyCb_) - shared->connReadyCb_(deviceId, "", nullptr); - std::lock_guard<std::mutex> lk(shared->infosMtx_); - shared->infos_.erase({deviceId, id}); + auto eraseInfo = [w, wdi, id = req.id] { + auto shared = w.lock(); + if (auto di = wdi.lock()) { + std::unique_lock<std::mutex> lk(di->mtx_); + di->info.erase(id); + auto ops = di->extractPendingOperations(id, nullptr); + if (di->empty()) { + if (shared) + shared->infos_.removeDeviceInfo(di->deviceId); + } + lk.unlock(); + for (const auto& op: ops) + op.cb(nullptr, di->deviceId); + if (shared && shared->connReadyCb_) + shared->connReadyCb_(di->deviceId, "", nullptr); } }; + ice_config.master = true; + ice_config.streamsCount = 1; + ice_config.compCountPerStream = 1; // TCP ice_config.tcpEnable = true; - ice_config.onInitDone = [w, req, eraseInfo](bool ok) { + ice_config.onInitDone = [w, winfo, req, eraseInfo](bool ok) { auto shared = w.lock(); if (!shared) return; @@ -1245,16 +1374,15 @@ ConnectionManager::Impl::onDhtPeerRequest(const PeerConnectionRequest& req, } dht::ThreadPool::io().run( - [w = std::move(w), req = std::move(req), eraseInfo = std::move(eraseInfo)] { - auto shared = w.lock(); - if (!shared) - return; - if (!shared->onRequestStartIce(req)) - eraseInfo(); + [w = std::move(w), winfo = std::move(winfo), req = std::move(req), eraseInfo = std::move(eraseInfo)] { + if (auto shared = w.lock()) { + if (!shared->onRequestStartIce(winfo.lock(), req)) + eraseInfo(); + } }); }; - ice_config.onNegoDone = [w, req, eraseInfo](bool ok) { + ice_config.onNegoDone = [w, wdi, winfo, req, eraseInfo](bool ok) { auto shared = w.lock(); if (!shared) return; @@ -1266,25 +1394,22 @@ ConnectionManager::Impl::onDhtPeerRequest(const PeerConnectionRequest& req, } dht::ThreadPool::io().run( - [w = std::move(w), req = std::move(req), eraseInfo = std::move(eraseInfo)] { + [w = std::move(w), wdi = std::move(wdi), winfo = std::move(winfo), req = std::move(req), eraseInfo = std::move(eraseInfo)] { if (auto shared = w.lock()) - if (!shared->onRequestOnNegoDone(req)) + if (!shared->onRequestOnNegoDone(wdi.lock(), winfo.lock(), req)) eraseInfo(); }); }; // Negotiate a new ICE socket - auto info = std::make_shared<ConnectionInfo>(); { - std::lock_guard<std::mutex> lk(shared->infosMtx_); - shared->infos_[{deviceId, req.id}] = info; + std::lock_guard<std::mutex> lk(di->mtx_); + di->info[req.id] = info; } + if (shared->config_->logger) shared->config_->logger->debug("[device {}] Accepting connection", deviceId); std::unique_lock<std::mutex> lk {info->mutex_}; - ice_config.streamsCount = 1; - ice_config.compCountPerStream = 1; // TCP - ice_config.master = true; info->ice_ = shared->config_->factory->createUTransport(""); if (not info->ice_) { if (shared->config_->logger) @@ -1307,16 +1432,16 @@ ConnectionManager::Impl::onDhtPeerRequest(const PeerConnectionRequest& req, } void -ConnectionManager::Impl::addNewMultiplexedSocket(const CallbackId& id, const std::shared_ptr<ConnectionInfo>& info) +ConnectionManager::Impl::addNewMultiplexedSocket(const std::weak_ptr<DeviceInfo>& dinfo, const DeviceId& deviceId, const dht::Value::Id& vid, const std::shared_ptr<ConnectionInfo>& info) { - info->socket_ = std::make_shared<MultiplexedSocket>(config_->ioContext, id.first, std::move(info->tls_), config_->logger); + info->socket_ = std::make_shared<MultiplexedSocket>(config_->ioContext, deviceId, std::move(info->tls_), config_->logger); info->socket_->setOnReady( - [w = weak()](const DeviceId& deviceId, const std::shared_ptr<ChannelSocket>& socket) { + [w = weak_from_this()](const DeviceId& deviceId, const std::shared_ptr<ChannelSocket>& socket) { if (auto sthis = w.lock()) if (sthis->connReadyCb_) sthis->connReadyCb_(deviceId, socket->name(), socket); }); - info->socket_->setOnRequest([w = weak()](const std::shared_ptr<dht::crypto::Certificate>& peer, + info->socket_->setOnRequest([w = weak_from_this()](const std::shared_ptr<dht::crypto::Certificate>& peer, const uint16_t&, const std::string& name) { if (auto sthis = w.lock()) @@ -1324,26 +1449,34 @@ ConnectionManager::Impl::addNewMultiplexedSocket(const CallbackId& id, const std return sthis->channelReqCb_(peer, name); return false; }); - info->socket_->onShutdown([w = weak(), deviceId=id.first, vid=id.second]() { + info->socket_->onShutdown([dinfo, wi=std::weak_ptr(info), vid]() { // Cancel current outgoing connections - dht::ThreadPool::io().run([w, deviceId, vid] { - auto sthis = w.lock(); - if (!sthis) - return; - - std::set<CallbackId> ids; - if (auto info = sthis->getInfo(deviceId, vid)) { + dht::ThreadPool::io().run([dinfo, wi, vid] { + std::set<dht::Value::Id> ids; + if (auto info = wi.lock()) { std::lock_guard<std::mutex> lk(info->mutex_); if (info->socket_) { ids = std::move(info->cbIds_); info->socket_->shutdown(); } } - for (const auto& cbId : ids) - sthis->executePendingOperations(cbId.first, cbId.second, nullptr); - - std::lock_guard<std::mutex> lk(sthis->infosMtx_); - sthis->infos_.erase({deviceId, vid}); + if (auto deviceInfo = dinfo.lock()) { + std::shared_ptr<ConnectionInfo> info; + std::vector<PendingCb> ops; + std::unique_lock<std::mutex> lk(deviceInfo->mtx_); + auto it = deviceInfo->info.find(vid); + if (it != deviceInfo->info.end()) { + info = std::move(it->second); + deviceInfo->info.erase(it); + } + for (const auto& cbId : ids) { + auto po = deviceInfo->extractPendingOperations(cbId, nullptr); + ops.insert(ops.end(), po.begin(), po.end()); + } + lk.unlock(); + for (auto& op : ops) + op.cb(nullptr, deviceInfo->deviceId); + } }); }); } @@ -1409,7 +1542,7 @@ ConnectionManager::Impl::loadTreatedMessages() void ConnectionManager::Impl::saveTreatedMessages() const { - dht::ThreadPool::io().run([w = weak()]() { + dht::ThreadPool::io().run([w = weak_from_this()]() { if (auto sthis = w.lock()) { auto& this_ = *sthis; std::lock_guard<std::mutex> lock(this_.messageMutex_); @@ -1474,7 +1607,7 @@ ConnectionManager::Impl::setPublishedAddress(const IpAddr& ip_addr) void ConnectionManager::Impl::storeActiveIpAddress(std::function<void()>&& cb) { - dht()->getPublicAddress([w=weak(), cb = std::move(cb)](std::vector<dht::SockAddr>&& results) { + dht()->getPublicAddress([w=weak_from_this(), cb = std::move(cb)](std::vector<dht::SockAddr>&& results) { auto shared = w.lock(); if (!shared) return; @@ -1703,52 +1836,50 @@ ConnectionManager::connectDevice(const std::shared_ptr<dht::crypto::Certificate> bool ConnectionManager::isConnecting(const DeviceId& deviceId, const std::string& name) const { - auto pending = pimpl_->getPendingIds(deviceId); - return std::find_if(pending.begin(), pending.end(), [&](auto p) { return p.second == name; }) - != pending.end(); + if (auto dinfo = pimpl_->infos_.getDeviceInfo(deviceId)) { + std::unique_lock<std::mutex> lk {dinfo->mtx_}; + auto pending = dinfo->getPendingIds(); + lk.unlock(); + return std::find_if(pending.begin(), pending.end(), [&](const auto& p) { return p.second == name; }) + != pending.end(); + } + return false; } void ConnectionManager::closeConnectionsWith(const std::string& peerUri) { - std::vector<std::shared_ptr<ConnectionInfo>> connInfos; - std::set<DeviceId> peersDevices; - { - std::lock_guard<std::mutex> lk(pimpl_->infosMtx_); - for (auto iter = pimpl_->infos_.begin(); iter != pimpl_->infos_.end();) { - auto const& [key, value] = *iter; - std::unique_lock<std::mutex> lkv {value->mutex_}; - auto deviceId = key.first; - auto tls = value->tls_ ? value->tls_.get() : (value->socket_ ? value->socket_->endpoint() : nullptr); + std::vector<std::shared_ptr<DeviceInfo>> dInfos; + for (const auto& dinfo: pimpl_->infos_.getDeviceInfos()) { + std::unique_lock<std::mutex> lk(dinfo->mtx_); + bool isPeer = false; + for (auto const& [id, cinfo]: dinfo->info) { + std::lock_guard<std::mutex> lkv {cinfo->mutex_}; + auto tls = cinfo->tls_ ? cinfo->tls_.get() : (cinfo->socket_ ? cinfo->socket_->endpoint() : nullptr); auto cert = tls ? tls->peerCertificate() : nullptr; if (not cert) - cert = pimpl_->certStore().getCertificate(deviceId.toString()); + cert = pimpl_->certStore().getCertificate(dinfo->deviceId.toString()); if (cert && cert->issuer && peerUri == cert->issuer->getId().toString()) { - connInfos.emplace_back(value); - peersDevices.emplace(deviceId); - lkv.unlock(); - iter = pimpl_->infos_.erase(iter); - } else { - iter++; + isPeer = true; + break; } } + lk.unlock(); + if (isPeer) { + dInfos.emplace_back(std::move(dinfo)); + } } // Stop connections to all peers devices - for (const auto& deviceId : peersDevices) { - pimpl_->executePendingOperations(deviceId, 0, nullptr); - // This will close the TLS Session - pimpl_->removeUnusedConnections(deviceId); - } - for (auto& info : connInfos) { - if (info->socket_) - info->socket_->shutdown(); - if (info->waitForAnswer_) - info->waitForAnswer_->cancel(); - if (info->ice_) { - std::unique_lock<std::mutex> lk {info->mutex_}; - dht::ThreadPool::io().run( - [ice = std::shared_ptr<IceTransport>(std::move(info->ice_))] {}); - } + for (const auto& dinfo : dInfos) { + std::unique_lock<std::mutex> lk {dinfo->mtx_}; + auto unused = dinfo->extractUnusedConnections(); + auto pending = dinfo->extractPendingOperations(0, nullptr); + pimpl_->infos_.removeDeviceInfo(dinfo->deviceId); + lk.unlock(); + for (auto& op : unused) + op->shutdown(); + for (auto& op : pending) + op.cb(nullptr, dinfo->deviceId); } } @@ -1785,19 +1916,18 @@ ConnectionManager::oniOSConnected(iOSConnectedCallback&& cb) std::size_t ConnectionManager::activeSockets() const { - std::lock_guard<std::mutex> lk(pimpl_->infosMtx_); - return pimpl_->infos_.size(); + return pimpl_->infos_.getConnectedInfos().size(); } void ConnectionManager::monitor() const { - std::lock_guard<std::mutex> lk(pimpl_->infosMtx_); auto logger = pimpl_->config_->logger; if (!logger) return; logger->debug("ConnectionManager current status:"); - for (const auto& [_, ci] : pimpl_->infos_) { + for (const auto& ci : pimpl_->infos_.getConnectedInfos()) { + std::lock_guard<std::mutex> lk(ci->mutex_); if (ci->socket_) ci->socket_->monitor(); } @@ -1807,8 +1937,8 @@ ConnectionManager::monitor() const void ConnectionManager::connectivityChanged() { - std::lock_guard<std::mutex> lk(pimpl_->infosMtx_); - for (const auto& [_, ci] : pimpl_->infos_) { + for (const auto& ci : pimpl_->infos_.getConnectedInfos()) { + std::lock_guard<std::mutex> lk(ci->mutex_); if (ci->socket_) ci->socket_->sendBeacon(); } @@ -1854,71 +1984,14 @@ std::vector<std::map<std::string, std::string>> ConnectionManager::getConnectionList(const DeviceId& device) const { std::vector<std::map<std::string, std::string>> connectionsList; - std::lock_guard<std::mutex> lk(pimpl_->infosMtx_); - - for (const auto& [key, ci] : pimpl_->infos_) { - if (device && key.first != device) - continue; - std::map<std::string, std::string> connectionInfo; - connectionInfo["id"] = callbackIdToString(key.first, key.second); - connectionInfo["device"] = key.first.toString(); - if (ci->tls_) { - if (auto cert = ci->tls_->peerCertificate()) { - connectionInfo["peer"] = cert->issuer->getId().toString(); - } - } - if (ci->socket_) { - connectionInfo["status"] = std::to_string(static_cast<int>(ConnectionStatus::Connected)); - } else if (ci->tls_) { - connectionInfo["status"] = std::to_string(static_cast<int>(ConnectionStatus::TLS)); - } else if(ci->ice_) - { - connectionInfo["status"] = std::to_string(static_cast<int>(ConnectionStatus::ICE)); - } - if (ci->tls_) { - std::string remoteAddress = ci->tls_->getRemoteAddress(); - std::string remoteAddressIp = remoteAddress.substr(0, remoteAddress.find(':')); - std::string remoteAddressPort = remoteAddress.substr(remoteAddress.find(':') + 1); - connectionInfo["remoteAdress"] = remoteAddressIp; - connectionInfo["remotePort"] = remoteAddressPort; - } - connectionsList.emplace_back(std::move(connectionInfo)); - } - if (device) { - auto it = pimpl_->pendingOperations_.find(device); - if (it != pimpl_->pendingOperations_.end()) { - const auto& po = it->second; - for (const auto& [vid, ci] : po.connecting) { - std::map<std::string, std::string> connectionInfo; - connectionInfo["id"] = callbackIdToString(device, vid); - connectionInfo["status"] = std::to_string(static_cast<int>(ConnectionStatus::Connecting)); - connectionsList.emplace_back(std::move(connectionInfo)); - } - - for (const auto& [vid, ci] : po.waiting) { - std::map<std::string, std::string> connectionInfo; - connectionInfo["id"] = callbackIdToString(device, vid); - connectionInfo["status"] = std::to_string(static_cast<int>(ConnectionStatus::Waiting)); - connectionsList.emplace_back(std::move(connectionInfo)); - } + if (auto deviceInfo = pimpl_->infos_.getDeviceInfo(device)) { + connectionsList = deviceInfo->getConnectionList(pimpl_->certStore()); } - } - else { - for (const auto& [key, po] : pimpl_->pendingOperations_) { - for (const auto& [vid, ci] : po.connecting) { - std::map<std::string, std::string> connectionInfo; - connectionInfo["id"] = callbackIdToString(device, vid); - connectionInfo["status"] = std::to_string(static_cast<int>(ConnectionStatus::Connecting)); - connectionsList.emplace_back(std::move(connectionInfo)); - } - - for (const auto& [vid, ci] : po.waiting) { - std::map<std::string, std::string> connectionInfo; - connectionInfo["id"] = callbackIdToString(device, vid); - connectionInfo["status"] = std::to_string(static_cast<int>(ConnectionStatus::Waiting)); - connectionsList.emplace_back(std::move(connectionInfo)); - } + } else { + for (const auto& deviceInfo : pimpl_->infos_.getDeviceInfos()) { + auto cl = deviceInfo->getConnectionList(pimpl_->certStore()); + connectionsList.insert(connectionsList.end(), std::make_move_iterator(cl.begin()), std::make_move_iterator(cl.end())); } } return connectionsList; @@ -1927,13 +2000,13 @@ ConnectionManager::getConnectionList(const DeviceId& device) const std::vector<std::map<std::string, std::string>> ConnectionManager::getChannelList(const std::string& connectionId) const { - std::lock_guard<std::mutex> lk(pimpl_->infosMtx_); - CallbackId cbid = parseCallbackId(connectionId); - if (pimpl_->infos_.count(cbid) > 0) { - return pimpl_->infos_[cbid]->socket_->getChannelList(); - } else { - return {}; + auto [deviceId, valueId] = parseCallbackId(connectionId); + if (auto info = pimpl_->infos_.getInfo(deviceId, valueId)) { + std::lock_guard<std::mutex> lk(info->mutex_); + if (info->socket_) + return info->socket_->getChannelList(); } + return {}; } } // namespace dhtnet diff --git a/src/peer_connection.cpp b/src/peer_connection.cpp index c1bf14e96b6d4073db44f7e5ea73a6dea99ef839..1c8ae3e888fc1426dcc0b411bcde018a865df29c 100644 --- a/src/peer_connection.cpp +++ b/src/peer_connection.cpp @@ -42,8 +42,6 @@ #include <sys/time.h> #endif -static constexpr int ICE_COMP_ID_SIP_TRANSPORT {1}; - namespace dhtnet { int diff --git a/src/peer_connection.h b/src/peer_connection.h index e92e24986f096a111c0064f9744ec4b5d86dde85..c6f01973232dce7abdabdf0b828cebb64ee3b5fb 100644 --- a/src/peer_connection.h +++ b/src/peer_connection.h @@ -47,6 +47,8 @@ using OnStateChangeCb = std::function<bool(tls::TlsSessionState state)>; using OnReadyCb = std::function<void(bool ok)>; using onShutdownCb = std::function<void(void)>; +static constexpr int ICE_COMP_ID_SIP_TRANSPORT {1}; + //============================================================================== class IceSocketEndpoint : public GenericSocket<uint8_t> diff --git a/tests/connectionManager.cpp b/tests/connectionManager.cpp index 946aa2d3df55d04c8971e00b69c166ff53b4b3e2..83914425824be5b2b31ddc0785b008761ffcfdfa 100644 --- a/tests/connectionManager.cpp +++ b/tests/connectionManager.cpp @@ -23,6 +23,7 @@ #include <opendht/log.h> #include <asio/executor_work_guard.hpp> #include <asio/io_context.hpp> +#include <fmt/compile.h> #include <cppunit/TestAssert.h> #include <cppunit/TestFixture.h> @@ -1342,6 +1343,7 @@ ConnectionManagerTest::testShutdownWhileNegotiating() CPPUNIT_ASSERT(cv.wait_for(lk, 30s, [&] { return notConnected; })); } + void ConnectionManagerTest::testGetChannelList() { @@ -1353,11 +1355,10 @@ ConnectionManagerTest::testGetChannelList() bob->connectionManager->onChannelRequest( [](const std::shared_ptr<dht::crypto::Certificate>&, const std::string&) { return true; }); bob->connectionManager->onConnectionReady( - [&receiverConnected, - &cv](const DeviceId&, const std::string&, std::shared_ptr<ChannelSocket> socket) { + [&](const DeviceId&, const std::string&, std::shared_ptr<ChannelSocket> socket) { + std::lock_guard<std::mutex> lk {mtx}; if (socket) receiverConnected += 1; - cv.notify_one(); }); std::string channelId; @@ -1365,36 +1366,27 @@ ConnectionManagerTest::testGetChannelList() "git://*", [&](std::shared_ptr<ChannelSocket> socket, const DeviceId&) { + std::lock_guard<std::mutex> lk {mtx}; if (socket) { - channelId = std::to_string(socket->channel()); + channelId = fmt::format(FMT_COMPILE("{:x}"), socket->channel()); successfullyConnected = true; } - cv.notify_one(); }); CPPUNIT_ASSERT( cv.wait_for(lk, 60s, [&] { return successfullyConnected && receiverConnected == 1; })); std::vector<std::map<std::string, std::string>> expectedList = { - {{"channel", channelId}, {"channelName", "git://*"}}}; + {{"id", channelId}, {"name", "git://*"}}}; auto connectionList = alice->connectionManager->getConnectionList(); CPPUNIT_ASSERT(!connectionList.empty()); const auto& connectionInfo = connectionList[0]; auto it = connectionInfo.find("id"); CPPUNIT_ASSERT(it != connectionInfo.end()); - std::string connectionId = it->second; - auto actualList = alice->connectionManager->getChannelList(connectionId); + auto actualList = alice->connectionManager->getChannelList(it->second); CPPUNIT_ASSERT(expectedList.size() == actualList.size()); - CPPUNIT_ASSERT(std::equal(expectedList.begin(), expectedList.end(), actualList.begin())); for (const auto& expectedMap : expectedList) { - auto it = std::find_if(actualList.begin(), - actualList.end(), - [&](const std::map<std::string, std::string>& actualMap) { - return expectedMap.size() == actualMap.size() - && std::equal(expectedMap.begin(), - expectedMap.end(), - actualMap.begin()); - }); - CPPUNIT_ASSERT(it != actualList.end()); + CPPUNIT_ASSERT(std::find(actualList.begin(), actualList.end(), expectedMap) + != actualList.end()); } }