diff --git a/src/jamidht/multiplexed_socket.cpp b/src/jamidht/multiplexed_socket.cpp index 946b39d5dc49d739a0d28d88605a5d4d9f36cbe4..4af6ba90cab8eb4f6dd0aa45f2b99bf625e9be50 100644 --- a/src/jamidht/multiplexed_socket.cpp +++ b/src/jamidht/multiplexed_socket.cpp @@ -132,7 +132,9 @@ MultiplexedSocket::Impl::eventLoop() if (state == tls::TlsSessionState::SHUTDOWN && !isShutdown_) { JAMI_INFO("Tls endpoint is down, shutdown multiplexed socket"); shutdown(); + return false; } + return true; }); std::error_code ec; while (!stop) { diff --git a/src/jamidht/p2p.cpp b/src/jamidht/p2p.cpp index a4982717aa1a5663aaac9e800a3e1b8eedbc0e17..124875182b3245d4518417852309a23d5df0d5f5 100644 --- a/src/jamidht/p2p.cpp +++ b/src/jamidht/p2p.cpp @@ -457,7 +457,7 @@ private: JAMI_DBG() << parent_.account << "[CNX] start TLS session"; tls_ep_ = std::make_unique<TlsSocketEndpoint>( std::move(peer_ep), parent_.account.identity(), parent_.account.dhParams(), - *peerCertificate_); + *peerCertificate_, ice->isRunning()); tls_ep_->setOnStateChange([this, ice=std::move(ice)] (tls::TlsSessionState state) { if (state == tls::TlsSessionState::SHUTDOWN) { if (!connected_) @@ -474,6 +474,7 @@ private: cb(connection_.get()); } } + return true; }); } @@ -627,24 +628,24 @@ DhtPeerConnector::Impl::onTurnPeerConnection(const IpAddr& peer_addr) tls_turn_ep_[peer_addr]->setOnStateChange([this, peer_addr, peer_h] (tls::TlsSessionState state) { + auto it = tls_turn_ep_.find(peer_addr); + if (it == tls_turn_ep_.end()) return false; if (state == tls::TlsSessionState::SHUTDOWN) { JAMI_WARN() << "[CNX] TLS connection failure from peer " << peer_addr.toString(true, true); - tls_turn_ep_.erase(peer_addr); + tls_turn_ep_.erase(it); + return false; } else if (state == tls::TlsSessionState::ESTABLISHED) { if (peer_h) { - JAMI_DBG() << account << "[CNX] Accepted TLS-TURN connection from RingID " << *peer_h; - connectedPeers_ - .emplace(peer_addr, tls_turn_ep_[peer_addr]->peerCertificate().getId()); - auto connection = - std::make_unique<PeerConnection>([] {}, - peer_addr.toString(), - std::move(tls_turn_ep_[peer_addr])); - connection->attachOutputStream(std::make_shared<FtpServer>(account.getAccountID(), - peer_h->toString())); - servers_.emplace(std::make_pair(*peer_h, peer_addr), std::move(connection)); + JAMI_DBG() << account << "[CNX] Accepted TLS-TURN connection from " << *peer_h; + connectedPeers_.emplace(peer_addr, it->second->peerCertificate().getId()); + auto connection = std::make_unique<PeerConnection>([] {}, peer_addr.toString(), std::move(it->second)); + connection->attachOutputStream(std::make_shared<FtpServer>(account.getAccountID(), peer_h->toString())); + servers_.emplace(std::make_pair(*peer_h, peer_addr), std::move(connection)); } - tls_turn_ep_.erase(peer_addr); + tls_turn_ep_.erase(it); + return false; } + return true; }); } @@ -855,11 +856,12 @@ DhtPeerConnector::Impl::answerToRequest(PeerConnectionMsg&& request, it.first->second->setOnStateChange([this, idx=std::move(idx)] (tls::TlsSessionState state) { if (waitForReadyEndpoints_.find(idx) == waitForReadyEndpoints_.end()) { - return; + return false; } if (state == tls::TlsSessionState::SHUTDOWN) { JAMI_WARN() << "TLS connection failure"; waitForReadyEndpoints_.erase(idx); + return false; } else if (state == tls::TlsSessionState::ESTABLISHED) { // Connected! auto peer_h = idx.first.toString(); @@ -869,7 +871,9 @@ DhtPeerConnector::Impl::answerToRequest(PeerConnectionMsg&& request, connection->attachOutputStream(std::make_shared<FtpServer>(account.getAccountID(), peer_h)); servers_.emplace(idx, std::move(connection)); waitForReadyEndpoints_.erase(idx); + return false; } + return true; }); } // Now wait for a TURN connection from peer (see onTurnPeerConnection) if fallbacking diff --git a/src/peer_connection.cpp b/src/peer_connection.cpp index 906c2c931a6cf8a8c63928e5921eef9301f4f7a4..27ea10cc1196fd2c653352061f7c5bbe6b3ababd 100644 --- a/src/peer_connection.cpp +++ b/src/peer_connection.cpp @@ -100,7 +100,6 @@ public: const Identity& local_identity, const std::shared_future<tls::DhParams>& dh_params) : peerCertificateCheckFunc {std::move(cert_check)} { - // Add TLS over TURN tls::TlsSession::TlsSessionCallbacks tls_cbs = { /*.onStateChange = */[this](tls::TlsSessionState state){ onTlsStateChange(state); }, @@ -160,7 +159,8 @@ void TlsTurnEndpoint::Impl::onTlsStateChange(tls::TlsSessionState state) { if (onStateChangeCb_) - onStateChangeCb_(state); + if (!onStateChangeCb_(state)) + onStateChangeCb_ = {}; } void @@ -233,7 +233,7 @@ TlsTurnEndpoint::waitForData(std::chrono::milliseconds timeout, std::error_code& } void -TlsTurnEndpoint::setOnStateChange(std::function<void(tls::TlsSessionState state)>&& cb) +TlsTurnEndpoint::setOnStateChange(std::function<bool(tls::TlsSessionState state)>&& cb) { pimpl_->onStateChangeCb_ = std::move(cb); } @@ -395,8 +395,9 @@ public: Impl(std::unique_ptr<AbstractSocketEndpoint>&& ep, const dht::crypto::Certificate& peer_cert, const Identity& local_identity, - const std::shared_future<tls::DhParams>& dh_params) - : peerCertificate {peer_cert}, ep_ {ep.get()} { + const std::shared_future<tls::DhParams>& dh_params, + bool isIceTransport = true) + : peerCertificate {peer_cert}, ep_ {ep.get()}, isIce_{isIceTransport} { tls::TlsSession::TlsSessionCallbacks tls_cbs = { /*.onStateChange = */[this](tls::TlsSessionState state){ onTlsStateChange(state); }, /*.onRxData = */[this](std::vector<uint8_t>&& buf){ onTlsRxData(std::move(buf)); }, @@ -415,19 +416,21 @@ public: }; tls = std::make_unique<tls::TlsSession>(std::move(ep), tls_param, tls_cbs); - const IceSocketEndpoint* iceSocket = (const IceSocketEndpoint*)(ep_); - if (iceSocket) { - iceSocket->underlyingICE()->setOnShutdown([this]() { - tls->shutdown(); - }); + if (isIce_) { + if (const auto* iceSocket = reinterpret_cast<const IceSocketEndpoint*>(ep_)) { + iceSocket->underlyingICE()->setOnShutdown([this]() { + tls->shutdown(); + }); + } } } Impl(std::unique_ptr<AbstractSocketEndpoint>&& ep, std::function<bool(const dht::crypto::Certificate &)>&& cert_check, const Identity& local_identity, - const std::shared_future<tls::DhParams>& dh_params) - : peerCertificateCheckFunc{std::move(cert_check)}, peerCertificate {null_cert}, ep_ {ep.get()} { + const std::shared_future<tls::DhParams>& dh_params, + bool isIce = true) + : peerCertificateCheckFunc{std::move(cert_check)}, peerCertificate {null_cert}, ep_ {ep.get()}, isIce_{isIce} { tls::TlsSession::TlsSessionCallbacks tls_cbs = { /*.onStateChange = */[this](tls::TlsSessionState state){ onTlsStateChange(state); }, /*.onRxData = */[this](std::vector<uint8_t>&& buf){ onTlsRxData(std::move(buf)); }, @@ -446,11 +449,12 @@ public: }; tls = std::make_unique<tls::TlsSession>(std::move(ep), tls_param, tls_cbs); - const IceSocketEndpoint* iceSocket = (const IceSocketEndpoint*)(ep_); - if (iceSocket) { - iceSocket->underlyingICE()->setOnShutdown([this]() { - if (tls) tls->shutdown(); - }); + if (isIce_) { + if (const auto* iceSocket = reinterpret_cast<const IceSocketEndpoint*>(ep_)) { + iceSocket->underlyingICE()->setOnShutdown([this]() { + if (tls) tls->shutdown(); + }); + } } } @@ -474,6 +478,7 @@ public: OnReadyCb onReadyCb_; std::unique_ptr<tls::TlsSession> tls; const AbstractSocketEndpoint* ep_; + bool isIce_ {true}; }; // Declaration at namespace scope is necessary (until C++17) @@ -512,7 +517,8 @@ TlsSocketEndpoint::Impl::onTlsStateChange(tls::TlsSessionState state) onReadyCb_(state == tls::TlsSessionState::ESTABLISHED); } if (onStateChangeCb_) - onStateChangeCb_(state); + if (!onStateChangeCb_(state)) + onStateChangeCb_ = {}; } void @@ -528,15 +534,17 @@ TlsSocketEndpoint::Impl::onTlsCertificatesUpdate(UNUSED const gnutls_datum_t* lo TlsSocketEndpoint::TlsSocketEndpoint(std::unique_ptr<AbstractSocketEndpoint>&& tr, const Identity& local_identity, const std::shared_future<tls::DhParams>& dh_params, - const dht::crypto::Certificate& peer_cert) - : pimpl_ { std::make_unique<Impl>(std::move(tr), peer_cert, local_identity, dh_params) } + const dht::crypto::Certificate& peer_cert, + bool isIce) + : pimpl_ { std::make_unique<Impl>(std::move(tr), peer_cert, local_identity, dh_params, isIce) } {} TlsSocketEndpoint::TlsSocketEndpoint(std::unique_ptr<AbstractSocketEndpoint>&& tr, const Identity& local_identity, const std::shared_future<tls::DhParams>& dh_params, - std::function<bool(const dht::crypto::Certificate&)>&& cert_check) - : pimpl_ { std::make_unique<Impl>(std::move(tr), std::move(cert_check), local_identity, dh_params) } + std::function<bool(const dht::crypto::Certificate&)>&& cert_check, + bool isIce) + : pimpl_ { std::make_unique<Impl>(std::move(tr), std::move(cert_check), local_identity, dh_params, isIce) } { } @@ -601,7 +609,7 @@ TlsSocketEndpoint::waitForData(std::chrono::milliseconds timeout, std::error_cod } void -TlsSocketEndpoint::setOnStateChange(std::function<void(tls::TlsSessionState state)>&& cb) +TlsSocketEndpoint::setOnStateChange(std::function<bool(tls::TlsSessionState state)>&& cb) { pimpl_->onStateChangeCb_ = std::move(cb); } @@ -615,9 +623,10 @@ TlsSocketEndpoint::setOnReady(std::function<void(bool ok)>&& cb) void TlsSocketEndpoint::shutdown() { - const IceSocketEndpoint* iceSocket = (const IceSocketEndpoint*)(pimpl_->ep_); - if (iceSocket && iceSocket->underlyingICE()) { - iceSocket->underlyingICE()->cancelOperations(); + if (pimpl_->ep_ && pimpl_->isIce_) { + const auto* iceSocket = reinterpret_cast<const IceSocketEndpoint*>(pimpl_->ep_); + if (iceSocket && iceSocket->underlyingICE()) + iceSocket->underlyingICE()->cancelOperations(); } pimpl_->tls->shutdown(); } @@ -625,12 +634,9 @@ TlsSocketEndpoint::shutdown() std::shared_ptr<IceTransport> TlsSocketEndpoint::underlyingICE() const { - if (pimpl_->ep_) { - const IceSocketEndpoint* iceSocket = (const IceSocketEndpoint*)(pimpl_->ep_); - if (iceSocket) { + if (pimpl_->ep_ && pimpl_->isIce_) + if (const auto* iceSocket = reinterpret_cast<const IceSocketEndpoint*>(pimpl_->ep_)) return iceSocket->underlyingICE(); - } - } return {}; } diff --git a/src/peer_connection.h b/src/peer_connection.h index ce5d364e5217003e640806d166c16ddac9bfc334..29cc1833956edb3546e087d3ef13d9cb2ad60386 100644 --- a/src/peer_connection.h +++ b/src/peer_connection.h @@ -45,7 +45,7 @@ struct Certificate; namespace jami { -using OnStateChangeCb = std::function<void(tls::TlsSessionState state)>; +using OnStateChangeCb = std::function<bool(tls::TlsSessionState state)>; using OnReadyCb = std::function<void(bool ok)>; using onShutdownCb = std::function<void(void)>; @@ -199,11 +199,13 @@ public: TlsSocketEndpoint(std::unique_ptr<AbstractSocketEndpoint>&& tr, const Identity& local_identity, const std::shared_future<tls::DhParams>& dh_params, - const dht::crypto::Certificate& peer_cert); + const dht::crypto::Certificate& peer_cert, + bool isIceTransport = true); TlsSocketEndpoint(std::unique_ptr<AbstractSocketEndpoint>&& tr, const Identity& local_identity, const std::shared_future<tls::DhParams>& dh_params, - std::function<bool(const dht::crypto::Certificate&)>&& cert_check); + std::function<bool(const dht::crypto::Certificate&)>&& cert_check, + bool isIceTransport = true); ~TlsSocketEndpoint(); bool isReliable() const override { return true; }