From afa8e282d9d6f08a5e34057fbb4f8c07209792f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrien=20B=C3=A9raud?= <adrien.beraud@savoirfairelinux.com> Date: Sun, 24 Sep 2023 12:53:20 -0400 Subject: [PATCH] ConnectionManager: use peer certificate from TLS in closeConnectionsWith Change-Id: I55ea604cc2542fb0d38b465cfa6a090450fe9322 --- include/multiplexed_socket.h | 2 ++ src/connectionmanager.cpp | 7 ++++++- src/multiplexed_socket.cpp | 6 ++++++ 3 files changed, 14 insertions(+), 1 deletion(-) diff --git a/include/multiplexed_socket.h b/include/multiplexed_socket.h index e265db9..2079df5 100644 --- a/include/multiplexed_socket.h +++ b/include/multiplexed_socket.h @@ -162,6 +162,8 @@ public: void eraseChannel(uint16_t channel); + TlsSocketEndpoint* endpoint(); + #ifdef DHTNET_TESTABLE /** * Check if we can send beacon on the socket diff --git a/src/connectionmanager.cpp b/src/connectionmanager.cpp index 67623b2..751b798 100644 --- a/src/connectionmanager.cpp +++ b/src/connectionmanager.cpp @@ -1713,11 +1713,16 @@ ConnectionManager::closeConnectionsWith(const std::string& peerUri) std::lock_guard<std::mutex> lk(pimpl_->infosMtx_); for (auto iter = pimpl_->infos_.begin(); iter != pimpl_->infos_.end();) { auto const& [key, value] = *iter; + std::unique_lock<std::mutex> lkv {value->mutex_}; auto deviceId = key.first; - auto cert = pimpl_->certStore().getCertificate(deviceId.toString()); + auto tls = value->tls_ ? value->tls_.get() : (value->socket_ ? value->socket_->endpoint() : nullptr); + auto cert = tls ? tls->peerCertificate() : nullptr; + if (not cert) + cert = pimpl_->certStore().getCertificate(deviceId.toString()); if (cert && cert->issuer && peerUri == cert->issuer->getId().toString()) { connInfos.emplace_back(value); peersDevices.emplace(deviceId); + lkv.unlock(); iter = pimpl_->infos_.erase(iter); } else { iter++; diff --git a/src/multiplexed_socket.cpp b/src/multiplexed_socket.cpp index 154741b..9c1f952 100644 --- a/src/multiplexed_socket.cpp +++ b/src/multiplexed_socket.cpp @@ -770,6 +770,12 @@ MultiplexedSocket::getRemoteAddress() const return pimpl_->endpoint->getRemoteAddress(); } +TlsSocketEndpoint* +MultiplexedSocket::endpoint() +{ + return pimpl_->endpoint.get(); +} + void MultiplexedSocket::eraseChannel(uint16_t channel) { -- GitLab