From c7d85bb92668623382d868ea4efb7c7ba5c37191 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?S=C3=A9bastien=20Blin?=
 <sebastien.blin@savoirfairelinux.com>
Date: Wed, 27 May 2020 14:06:10 -0400
Subject: [PATCH] tlsturntransport: avoid to call unnecessary callbacks

onStateChanged can be used to just clean some structure, for example
with TlsTurnEndpoint. Also, if called after move, the data can be
scrapped and incorrect.
Also, fix a cast in TlsSocketEndpoint when using ICE.

Change-Id: I8104bc8a0fd8e9cd3dae92e06eee45c22feced45
---
 src/jamidht/multiplexed_socket.cpp |  2 +
 src/jamidht/p2p.cpp                | 32 ++++++++------
 src/peer_connection.cpp            | 68 ++++++++++++++++--------------
 src/peer_connection.h              |  8 ++--
 4 files changed, 62 insertions(+), 48 deletions(-)

diff --git a/src/jamidht/multiplexed_socket.cpp b/src/jamidht/multiplexed_socket.cpp
index 946b39d5dc..4af6ba90ca 100644
--- a/src/jamidht/multiplexed_socket.cpp
+++ b/src/jamidht/multiplexed_socket.cpp
@@ -132,7 +132,9 @@ MultiplexedSocket::Impl::eventLoop()
         if (state == tls::TlsSessionState::SHUTDOWN && !isShutdown_) {
             JAMI_INFO("Tls endpoint is down, shutdown multiplexed socket");
             shutdown();
+            return false;
         }
+        return true;
     });
     std::error_code ec;
     while (!stop) {
diff --git a/src/jamidht/p2p.cpp b/src/jamidht/p2p.cpp
index a4982717aa..124875182b 100644
--- a/src/jamidht/p2p.cpp
+++ b/src/jamidht/p2p.cpp
@@ -457,7 +457,7 @@ private:
         JAMI_DBG() << parent_.account << "[CNX] start TLS session";
         tls_ep_ = std::make_unique<TlsSocketEndpoint>(
             std::move(peer_ep), parent_.account.identity(), parent_.account.dhParams(),
-            *peerCertificate_);
+            *peerCertificate_, ice->isRunning());
         tls_ep_->setOnStateChange([this, ice=std::move(ice)] (tls::TlsSessionState state) {
             if (state == tls::TlsSessionState::SHUTDOWN) {
                 if (!connected_)
@@ -474,6 +474,7 @@ private:
                     cb(connection_.get());
                 }
             }
+            return true;
         });
     }
 
@@ -627,24 +628,24 @@ DhtPeerConnector::Impl::onTurnPeerConnection(const IpAddr& peer_addr)
 
     tls_turn_ep_[peer_addr]->setOnStateChange([this, peer_addr, peer_h] (tls::TlsSessionState state)
     {
+        auto it = tls_turn_ep_.find(peer_addr);
+        if (it == tls_turn_ep_.end()) return false;
         if (state == tls::TlsSessionState::SHUTDOWN) {
             JAMI_WARN() << "[CNX] TLS connection failure from peer " << peer_addr.toString(true, true);
-            tls_turn_ep_.erase(peer_addr);
+            tls_turn_ep_.erase(it);
+            return false;
         } else if (state == tls::TlsSessionState::ESTABLISHED) {
             if (peer_h) {
-                 JAMI_DBG() << account << "[CNX] Accepted TLS-TURN connection from RingID " << *peer_h;
-                 connectedPeers_
-                .emplace(peer_addr, tls_turn_ep_[peer_addr]->peerCertificate().getId());
-                 auto connection =
-                 std::make_unique<PeerConnection>([] {},
-                                                 peer_addr.toString(),
-                                                 std::move(tls_turn_ep_[peer_addr]));
-                 connection->attachOutputStream(std::make_shared<FtpServer>(account.getAccountID(),
-                                                                            peer_h->toString()));
-                 servers_.emplace(std::make_pair(*peer_h, peer_addr), std::move(connection));
+                JAMI_DBG() << account << "[CNX] Accepted TLS-TURN connection from " << *peer_h;
+                connectedPeers_.emplace(peer_addr, it->second->peerCertificate().getId());
+                auto connection = std::make_unique<PeerConnection>([] {}, peer_addr.toString(), std::move(it->second));
+                connection->attachOutputStream(std::make_shared<FtpServer>(account.getAccountID(), peer_h->toString()));
+                servers_.emplace(std::make_pair(*peer_h, peer_addr), std::move(connection));
             }
-            tls_turn_ep_.erase(peer_addr);
+            tls_turn_ep_.erase(it);
+            return false;
         }
+        return true;
     });
 }
 
@@ -855,11 +856,12 @@ DhtPeerConnector::Impl::answerToRequest(PeerConnectionMsg&& request,
 
         it.first->second->setOnStateChange([this, idx=std::move(idx)] (tls::TlsSessionState state) {
             if (waitForReadyEndpoints_.find(idx) == waitForReadyEndpoints_.end()) {
-                return;
+                return false;
             }
             if (state == tls::TlsSessionState::SHUTDOWN) {
                 JAMI_WARN() << "TLS connection failure";
                 waitForReadyEndpoints_.erase(idx);
+                return false;
             } else if (state == tls::TlsSessionState::ESTABLISHED) {
                 // Connected!
                 auto peer_h = idx.first.toString();
@@ -869,7 +871,9 @@ DhtPeerConnector::Impl::answerToRequest(PeerConnectionMsg&& request,
                 connection->attachOutputStream(std::make_shared<FtpServer>(account.getAccountID(), peer_h));
                 servers_.emplace(idx, std::move(connection));
                 waitForReadyEndpoints_.erase(idx);
+                return false;
             }
+            return true;
         });
     }
     // Now wait for a TURN connection from peer (see onTurnPeerConnection) if fallbacking
diff --git a/src/peer_connection.cpp b/src/peer_connection.cpp
index 906c2c931a..27ea10cc11 100644
--- a/src/peer_connection.cpp
+++ b/src/peer_connection.cpp
@@ -100,7 +100,6 @@ public:
          const Identity& local_identity,
          const std::shared_future<tls::DhParams>& dh_params)
         : peerCertificateCheckFunc {std::move(cert_check)} {
-
         // Add TLS over TURN
         tls::TlsSession::TlsSessionCallbacks tls_cbs = {
             /*.onStateChange = */[this](tls::TlsSessionState state){ onTlsStateChange(state); },
@@ -160,7 +159,8 @@ void
 TlsTurnEndpoint::Impl::onTlsStateChange(tls::TlsSessionState state)
 {
     if (onStateChangeCb_)
-        onStateChangeCb_(state);
+        if (!onStateChangeCb_(state))
+            onStateChangeCb_ = {};
 }
 
 void
@@ -233,7 +233,7 @@ TlsTurnEndpoint::waitForData(std::chrono::milliseconds timeout, std::error_code&
 }
 
 void
-TlsTurnEndpoint::setOnStateChange(std::function<void(tls::TlsSessionState state)>&& cb)
+TlsTurnEndpoint::setOnStateChange(std::function<bool(tls::TlsSessionState state)>&& cb)
 {
     pimpl_->onStateChangeCb_ = std::move(cb);
 }
@@ -395,8 +395,9 @@ public:
     Impl(std::unique_ptr<AbstractSocketEndpoint>&& ep,
          const dht::crypto::Certificate& peer_cert,
          const Identity& local_identity,
-         const std::shared_future<tls::DhParams>& dh_params)
-        : peerCertificate {peer_cert}, ep_ {ep.get()} {
+         const std::shared_future<tls::DhParams>& dh_params,
+         bool isIceTransport = true)
+        : peerCertificate {peer_cert}, ep_ {ep.get()}, isIce_{isIceTransport} {
         tls::TlsSession::TlsSessionCallbacks tls_cbs = {
             /*.onStateChange = */[this](tls::TlsSessionState state){ onTlsStateChange(state); },
             /*.onRxData = */[this](std::vector<uint8_t>&& buf){ onTlsRxData(std::move(buf)); },
@@ -415,19 +416,21 @@ public:
         };
         tls = std::make_unique<tls::TlsSession>(std::move(ep), tls_param, tls_cbs);
 
-        const IceSocketEndpoint* iceSocket = (const IceSocketEndpoint*)(ep_);
-        if (iceSocket) {
-            iceSocket->underlyingICE()->setOnShutdown([this]() {
-                tls->shutdown();
-            });
+        if (isIce_) {
+            if (const auto* iceSocket = reinterpret_cast<const IceSocketEndpoint*>(ep_)) {
+                iceSocket->underlyingICE()->setOnShutdown([this]() {
+                    tls->shutdown();
+                });
+            }
         }
     }
 
     Impl(std::unique_ptr<AbstractSocketEndpoint>&& ep,
          std::function<bool(const dht::crypto::Certificate &)>&& cert_check,
          const Identity& local_identity,
-         const std::shared_future<tls::DhParams>& dh_params)
-        : peerCertificateCheckFunc{std::move(cert_check)}, peerCertificate {null_cert}, ep_ {ep.get()} {
+         const std::shared_future<tls::DhParams>& dh_params,
+         bool isIce = true)
+        : peerCertificateCheckFunc{std::move(cert_check)}, peerCertificate {null_cert}, ep_ {ep.get()}, isIce_{isIce} {
         tls::TlsSession::TlsSessionCallbacks tls_cbs = {
             /*.onStateChange = */[this](tls::TlsSessionState state){ onTlsStateChange(state); },
             /*.onRxData = */[this](std::vector<uint8_t>&& buf){ onTlsRxData(std::move(buf)); },
@@ -446,11 +449,12 @@ public:
         };
         tls = std::make_unique<tls::TlsSession>(std::move(ep), tls_param, tls_cbs);
 
-        const IceSocketEndpoint* iceSocket = (const IceSocketEndpoint*)(ep_);
-        if (iceSocket) {
-            iceSocket->underlyingICE()->setOnShutdown([this]() {
-                if (tls) tls->shutdown();
-            });
+        if (isIce_) {
+            if (const auto* iceSocket = reinterpret_cast<const IceSocketEndpoint*>(ep_)) {
+                iceSocket->underlyingICE()->setOnShutdown([this]() {
+                    if (tls) tls->shutdown();
+                });
+            }
         }
     }
 
@@ -474,6 +478,7 @@ public:
     OnReadyCb onReadyCb_;
     std::unique_ptr<tls::TlsSession> tls;
     const AbstractSocketEndpoint* ep_;
+    bool isIce_ {true};
 };
 
 // Declaration at namespace scope is necessary (until C++17)
@@ -512,7 +517,8 @@ TlsSocketEndpoint::Impl::onTlsStateChange(tls::TlsSessionState state)
             onReadyCb_(state == tls::TlsSessionState::ESTABLISHED);
     }
     if (onStateChangeCb_)
-        onStateChangeCb_(state);
+        if (!onStateChangeCb_(state))
+            onStateChangeCb_ = {};
 }
 
 void
@@ -528,15 +534,17 @@ TlsSocketEndpoint::Impl::onTlsCertificatesUpdate(UNUSED const gnutls_datum_t* lo
 TlsSocketEndpoint::TlsSocketEndpoint(std::unique_ptr<AbstractSocketEndpoint>&& tr,
                                      const Identity& local_identity,
                                      const std::shared_future<tls::DhParams>& dh_params,
-                                     const dht::crypto::Certificate& peer_cert)
-    : pimpl_ { std::make_unique<Impl>(std::move(tr), peer_cert, local_identity, dh_params) }
+                                     const dht::crypto::Certificate& peer_cert,
+                                     bool isIce)
+    : pimpl_ { std::make_unique<Impl>(std::move(tr), peer_cert, local_identity, dh_params, isIce) }
 {}
 
 TlsSocketEndpoint::TlsSocketEndpoint(std::unique_ptr<AbstractSocketEndpoint>&& tr,
                                     const Identity& local_identity,
                                     const std::shared_future<tls::DhParams>& dh_params,
-                                    std::function<bool(const dht::crypto::Certificate&)>&& cert_check)
-    : pimpl_ { std::make_unique<Impl>(std::move(tr), std::move(cert_check), local_identity, dh_params) }
+                                    std::function<bool(const dht::crypto::Certificate&)>&& cert_check,
+                                    bool isIce)
+    : pimpl_ { std::make_unique<Impl>(std::move(tr), std::move(cert_check), local_identity, dh_params, isIce) }
 {
 }
 
@@ -601,7 +609,7 @@ TlsSocketEndpoint::waitForData(std::chrono::milliseconds timeout, std::error_cod
 }
 
 void
-TlsSocketEndpoint::setOnStateChange(std::function<void(tls::TlsSessionState state)>&& cb)
+TlsSocketEndpoint::setOnStateChange(std::function<bool(tls::TlsSessionState state)>&& cb)
 {
     pimpl_->onStateChangeCb_ = std::move(cb);
 }
@@ -615,9 +623,10 @@ TlsSocketEndpoint::setOnReady(std::function<void(bool ok)>&& cb)
 void
 TlsSocketEndpoint::shutdown()
 {
-    const IceSocketEndpoint* iceSocket = (const IceSocketEndpoint*)(pimpl_->ep_);
-    if (iceSocket && iceSocket->underlyingICE()) {
-        iceSocket->underlyingICE()->cancelOperations();
+    if (pimpl_->ep_ && pimpl_->isIce_) {
+        const auto* iceSocket = reinterpret_cast<const IceSocketEndpoint*>(pimpl_->ep_);
+        if (iceSocket && iceSocket->underlyingICE())
+            iceSocket->underlyingICE()->cancelOperations();
     }
     pimpl_->tls->shutdown();
 }
@@ -625,12 +634,9 @@ TlsSocketEndpoint::shutdown()
 std::shared_ptr<IceTransport>
 TlsSocketEndpoint::underlyingICE() const
 {
-    if (pimpl_->ep_) {
-        const IceSocketEndpoint* iceSocket = (const IceSocketEndpoint*)(pimpl_->ep_);
-        if (iceSocket) {
+    if (pimpl_->ep_ && pimpl_->isIce_)
+        if (const auto* iceSocket = reinterpret_cast<const IceSocketEndpoint*>(pimpl_->ep_))
             return iceSocket->underlyingICE();
-        }
-    }
     return {};
 }
 
diff --git a/src/peer_connection.h b/src/peer_connection.h
index ce5d364e52..29cc183395 100644
--- a/src/peer_connection.h
+++ b/src/peer_connection.h
@@ -45,7 +45,7 @@ struct Certificate;
 
 namespace jami {
 
-using OnStateChangeCb = std::function<void(tls::TlsSessionState state)>;
+using OnStateChangeCb = std::function<bool(tls::TlsSessionState state)>;
 using OnReadyCb = std::function<void(bool ok)>;
 using onShutdownCb = std::function<void(void)>;
 
@@ -199,11 +199,13 @@ public:
     TlsSocketEndpoint(std::unique_ptr<AbstractSocketEndpoint>&& tr,
                       const Identity& local_identity,
                       const std::shared_future<tls::DhParams>& dh_params,
-                      const dht::crypto::Certificate& peer_cert);
+                      const dht::crypto::Certificate& peer_cert,
+                      bool isIceTransport = true);
     TlsSocketEndpoint(std::unique_ptr<AbstractSocketEndpoint>&& tr,
                     const Identity& local_identity,
                     const std::shared_future<tls::DhParams>& dh_params,
-                    std::function<bool(const dht::crypto::Certificate&)>&& cert_check);
+                    std::function<bool(const dht::crypto::Certificate&)>&& cert_check,
+                    bool isIceTransport = true);
     ~TlsSocketEndpoint();
 
     bool isReliable() const override { return true; }
-- 
GitLab