diff --git a/src/peer_connection.cpp b/src/peer_connection.cpp index aed32b8da5231b39ba14d6095339456979775e14..037b36625a55c64b75be9a9d1a067ecc5a036d0a 100644 --- a/src/peer_connection.cpp +++ b/src/peer_connection.cpp @@ -183,9 +183,9 @@ TlsTurnEndpoint::isInitiator() const } void -TlsTurnEndpoint::connect() +TlsTurnEndpoint::waitForReady(const std::chrono::steady_clock::duration& timeout) { - pimpl_->tls->connect(); + pimpl_->tls->waitForReady(timeout); } int @@ -411,9 +411,9 @@ TlsSocketEndpoint::write(const ValueType* buf, std::size_t len, std::error_code& } void -TlsSocketEndpoint::connect() +TlsSocketEndpoint::waitForReady(const std::chrono::steady_clock::duration& timeout) { - pimpl_->tls->connect(); + pimpl_->tls->waitForReady(timeout); } int diff --git a/src/peer_connection.h b/src/peer_connection.h index 973e224ff86f4adeda41b7dfcdf4f69057a55025..a650688f10dbb327420dad4e1606a4a0b0d598af 100644 --- a/src/peer_connection.h +++ b/src/peer_connection.h @@ -96,7 +96,7 @@ public: } int waitForData(unsigned, std::error_code&) const override; - void connect(); + void waitForReady(const std::chrono::steady_clock::duration& timeout = {}); const dht::crypto::Certificate& peerCertificate() const; @@ -160,7 +160,7 @@ public: } int waitForData(unsigned, std::error_code&) const override; - void connect(); + void waitForReady(const std::chrono::steady_clock::duration& timeout = {}); private: class Impl; diff --git a/src/ringdht/p2p.cpp b/src/ringdht/p2p.cpp index 8c81362ad4fe58c5ce697b934a767dd572d1e958..b838cd33c91056cb570e6387587a0744ef1fce4f 100644 --- a/src/ringdht/p2p.cpp +++ b/src/ringdht/p2p.cpp @@ -44,6 +44,7 @@ namespace ring { static constexpr auto DHT_MSG_TIMEOUT = std::chrono::seconds(20); static constexpr auto NET_CONNECTION_TIMEOUT = std::chrono::seconds(10); +static constexpr auto SOCK_TIMEOUT = std::chrono::seconds(3); using Clock = std::chrono::system_clock; using ValueIdDist = std::uniform_int_distribution<dht::Value::Id>; @@ -361,7 +362,19 @@ private: parent_.account.identity(), parent_.account.dhParams(), *peerCertificate_); - tls_ep->connect(); + // block until TLS is negotiated (with 3 secs of timeout) (must throw in case of error) + try { + tls_ep->waitForReady(SOCK_TIMEOUT); + } catch (const std::logic_error& e) { + // In case of a timeout + RING_WARN() << "TLS connection timeout from peer " << peer_.toString() << ": " << e.what(); + cancel(); + return; + } catch (...) { + RING_WARN() << "TLS connection failure from peer " << peer_.toString(); + cancel(); + return; + } // Connected! connection_ = std::make_unique<PeerConnection>([this] { cancel(); }, parent_.account, @@ -488,9 +501,13 @@ DhtPeerConnector::Impl::onTurnPeerConnection(const IpAddr& peer_addr) *turn_ep, account.identity(), account.dhParams(), [&, this] (const dht::crypto::Certificate& cert) { return validatePeerCertificate(cert, peer_h); }); - // block until TLS is negotiated (must throw in case of error) + // block until TLS is negotiated (with 3 secs of timeout) (must throw in case of error) try { - tls_ep->connect(); + tls_ep->waitForReady(SOCK_TIMEOUT); + } catch (const std::logic_error& e) { + // In case of a timeout + RING_WARN() << "TLS connection timeout from peer " << peer_addr.toString(true, true) << ": " << e.what(); + return; } catch (...) { RING_WARN() << "[CNX] TLS connection failure from peer " << peer_addr.toString(true, true); return; diff --git a/src/security/tls_session.cpp b/src/security/tls_session.cpp index fcd79df7eb27a195d07cb3664a8869fc162ea7bc..4b2073f512759e808f9eca0adf3ac89b74059b97 100644 --- a/src/security/tls_session.cpp +++ b/src/security/tls_session.cpp @@ -249,6 +249,9 @@ public: bool setup(); void process(); void cleanup(); + // State protectors + std::mutex stateMutex_; + std::condition_variable stateCondition_; ScheduledExecutor scheduler_; @@ -718,6 +721,7 @@ void TlsSession::TlsSessionImpl::cleanup() { state_ = TlsSessionState::SHUTDOWN; // be sure to block any user operations + stateCondition_.notify_all(); if (session_) { if (transport_.isReliable()) @@ -1102,13 +1106,12 @@ TlsSession::TlsSessionImpl::handleStateEstablished(TlsSessionState state) { // Nothing to do in reliable mode, so just wait for state change if (transport_.isReliable()) { - while (true) { - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - state = state_.load(); - if (state != TlsSessionState::ESTABLISHED) - return state; - } - return TlsSessionState::SHUTDOWN; + auto disconnected = [this]() -> bool { + return state_.load() != TlsSessionState::ESTABLISHED; + }; + std::unique_lock<std::mutex> lk(stateMutex_); + stateCondition_.wait(lk, disconnected); + return state; } // block until rx packet or state change @@ -1185,6 +1188,9 @@ TlsSession::TlsSessionImpl::process() if (not std::atomic_compare_exchange_strong(&state_, &old_state, new_state)) new_state = old_state; + if (old_state != new_state) + stateCondition_.notify_all(); + if (old_state != new_state and callbacks_.onStateChange) callbacks_.onStateChange(new_state); } @@ -1250,6 +1256,7 @@ void TlsSession::shutdown() { pimpl_->state_ = TlsSessionState::SHUTDOWN; + pimpl_->stateCondition_.notify_all(); pimpl_->rxCv_.notify_one(); // unblock waiting FSM pimpl_->transport_.shutdown(); } @@ -1291,6 +1298,7 @@ TlsSession::read(ValueType* data, std::size_t size, std::error_code& ec) RING_DBG("[TLS] re-handshake"); pimpl_->state_ = TlsSessionState::HANDSHAKE; pimpl_->rxCv_.notify_one(); // unblock waiting FSM + pimpl_->stateCondition_.notify_all(); } else if (gnutls_error_is_fatal(ret)) { RING_ERR("[TLS] fatal error in recv: %s", gnutls_strerror(ret)); shutdown(); @@ -1304,13 +1312,20 @@ TlsSession::read(ValueType* data, std::size_t size, std::error_code& ec) } void -TlsSession::connect() +TlsSession::waitForReady(const std::chrono::steady_clock::duration& timeout) { - TlsSessionState state; - do { - state = pimpl_->state_.load(); - std::this_thread::sleep_for(std::chrono::milliseconds(10)); - } while (state != TlsSessionState::ESTABLISHED and state != TlsSessionState::SHUTDOWN); + auto ready = [this]() -> bool { + auto state = pimpl_->state_.load(); + return state == TlsSessionState::ESTABLISHED or state == TlsSessionState::SHUTDOWN; + }; + std::unique_lock<std::mutex> lk(pimpl_->stateMutex_); + if (timeout == std::chrono::steady_clock::duration::zero()) + pimpl_->stateCondition_.wait(lk, ready); + else + pimpl_->stateCondition_.wait_for(lk, timeout, ready); + + if(!ready()) + throw std::logic_error("Invalid state in TlsSession::waitForReady"); } int diff --git a/src/security/tls_session.h b/src/security/tls_session.h index 750ce505f51b12f1a08a75aef45539e82fb3672d..f52b135eb174180c29b47ca6a8b9aec6f1632f33 100644 --- a/src/security/tls_session.h +++ b/src/security/tls_session.h @@ -128,7 +128,7 @@ public: int maxPayload() const override; - void connect(); + void waitForReady(const std::chrono::steady_clock::duration& timeout = {}); /// Synchronous writing. /// Return a positive number for number of bytes write, or 0 and \a ec set in case of error.