From afa8e282d9d6f08a5e34057fbb4f8c07209792f2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrien=20B=C3=A9raud?= <adrien.beraud@savoirfairelinux.com>
Date: Sun, 24 Sep 2023 12:53:20 -0400
Subject: [PATCH] ConnectionManager: use peer certificate from TLS in
 closeConnectionsWith

Change-Id: I55ea604cc2542fb0d38b465cfa6a090450fe9322
---
 include/multiplexed_socket.h | 2 ++
 src/connectionmanager.cpp    | 7 ++++++-
 src/multiplexed_socket.cpp   | 6 ++++++
 3 files changed, 14 insertions(+), 1 deletion(-)

diff --git a/include/multiplexed_socket.h b/include/multiplexed_socket.h
index e265db9..2079df5 100644
--- a/include/multiplexed_socket.h
+++ b/include/multiplexed_socket.h
@@ -162,6 +162,8 @@ public:
 
     void eraseChannel(uint16_t channel);
 
+    TlsSocketEndpoint* endpoint();
+
 #ifdef DHTNET_TESTABLE
     /**
      * Check if we can send beacon on the socket
diff --git a/src/connectionmanager.cpp b/src/connectionmanager.cpp
index 67623b2..751b798 100644
--- a/src/connectionmanager.cpp
+++ b/src/connectionmanager.cpp
@@ -1713,11 +1713,16 @@ ConnectionManager::closeConnectionsWith(const std::string& peerUri)
         std::lock_guard<std::mutex> lk(pimpl_->infosMtx_);
         for (auto iter = pimpl_->infos_.begin(); iter != pimpl_->infos_.end();) {
             auto const& [key, value] = *iter;
+            std::unique_lock<std::mutex> lkv {value->mutex_};
             auto deviceId = key.first;
-            auto cert = pimpl_->certStore().getCertificate(deviceId.toString());
+            auto tls = value->tls_ ? value->tls_.get() : (value->socket_ ? value->socket_->endpoint() : nullptr);
+            auto cert = tls ? tls->peerCertificate() : nullptr;
+            if (not cert)
+                cert = pimpl_->certStore().getCertificate(deviceId.toString());
             if (cert && cert->issuer && peerUri == cert->issuer->getId().toString()) {
                 connInfos.emplace_back(value);
                 peersDevices.emplace(deviceId);
+                lkv.unlock();
                 iter = pimpl_->infos_.erase(iter);
             } else {
                 iter++;
diff --git a/src/multiplexed_socket.cpp b/src/multiplexed_socket.cpp
index 154741b..9c1f952 100644
--- a/src/multiplexed_socket.cpp
+++ b/src/multiplexed_socket.cpp
@@ -770,6 +770,12 @@ MultiplexedSocket::getRemoteAddress() const
     return pimpl_->endpoint->getRemoteAddress();
 }
 
+TlsSocketEndpoint*
+MultiplexedSocket::endpoint()
+{
+    return pimpl_->endpoint.get();
+}
+
 void
 MultiplexedSocket::eraseChannel(uint16_t channel)
 {
-- 
GitLab