Commit 4c32af16 authored by Sébastien Blin's avatar Sébastien Blin Committed by Adrien Béraud

tlssession: add timeout parameter for connection

+ Rename connect->waitForReady (that is the goal of that method)
+ Remove while loops to wait for a new state (use condition_variables)

Change-Id: I9f763245e4bd0300fab2f015704366a2c43ace88
Reviewed-by: Philippe Gorley's avatarPhilippe Gorley <philippe.gorley@savoirfairelinux.com>
parent 232c2696
......@@ -183,9 +183,9 @@ TlsTurnEndpoint::isInitiator() const
}
void
TlsTurnEndpoint::connect()
TlsTurnEndpoint::waitForReady(const std::chrono::steady_clock::duration& timeout)
{
pimpl_->tls->connect();
pimpl_->tls->waitForReady(timeout);
}
int
......@@ -411,9 +411,9 @@ TlsSocketEndpoint::write(const ValueType* buf, std::size_t len, std::error_code&
}
void
TlsSocketEndpoint::connect()
TlsSocketEndpoint::waitForReady(const std::chrono::steady_clock::duration& timeout)
{
pimpl_->tls->connect();
pimpl_->tls->waitForReady(timeout);
}
int
......
......@@ -96,7 +96,7 @@ public:
}
int waitForData(unsigned, std::error_code&) const override;
void connect();
void waitForReady(const std::chrono::steady_clock::duration& timeout = {});
const dht::crypto::Certificate& peerCertificate() const;
......@@ -160,7 +160,7 @@ public:
}
int waitForData(unsigned, std::error_code&) const override;
void connect();
void waitForReady(const std::chrono::steady_clock::duration& timeout = {});
private:
class Impl;
......
......@@ -44,6 +44,7 @@ namespace ring {
static constexpr auto DHT_MSG_TIMEOUT = std::chrono::seconds(20);
static constexpr auto NET_CONNECTION_TIMEOUT = std::chrono::seconds(10);
static constexpr auto SOCK_TIMEOUT = std::chrono::seconds(3);
using Clock = std::chrono::system_clock;
using ValueIdDist = std::uniform_int_distribution<dht::Value::Id>;
......@@ -361,7 +362,19 @@ private:
parent_.account.identity(),
parent_.account.dhParams(),
*peerCertificate_);
tls_ep->connect();
// block until TLS is negotiated (with 3 secs of timeout) (must throw in case of error)
try {
tls_ep->waitForReady(SOCK_TIMEOUT);
} catch (const std::logic_error& e) {
// In case of a timeout
RING_WARN() << "TLS connection timeout from peer " << peer_.toString() << ": " << e.what();
cancel();
return;
} catch (...) {
RING_WARN() << "TLS connection failure from peer " << peer_.toString();
cancel();
return;
}
// Connected!
connection_ = std::make_unique<PeerConnection>([this] { cancel(); }, parent_.account,
......@@ -488,9 +501,13 @@ DhtPeerConnector::Impl::onTurnPeerConnection(const IpAddr& peer_addr)
*turn_ep, account.identity(), account.dhParams(),
[&, this] (const dht::crypto::Certificate& cert) { return validatePeerCertificate(cert, peer_h); });
// block until TLS is negotiated (must throw in case of error)
// block until TLS is negotiated (with 3 secs of timeout) (must throw in case of error)
try {
tls_ep->connect();
tls_ep->waitForReady(SOCK_TIMEOUT);
} catch (const std::logic_error& e) {
// In case of a timeout
RING_WARN() << "TLS connection timeout from peer " << peer_addr.toString(true, true) << ": " << e.what();
return;
} catch (...) {
RING_WARN() << "[CNX] TLS connection failure from peer " << peer_addr.toString(true, true);
return;
......
......@@ -249,6 +249,9 @@ public:
bool setup();
void process();
void cleanup();
// State protectors
std::mutex stateMutex_;
std::condition_variable stateCondition_;
ScheduledExecutor scheduler_;
......@@ -718,6 +721,7 @@ void
TlsSession::TlsSessionImpl::cleanup()
{
state_ = TlsSessionState::SHUTDOWN; // be sure to block any user operations
stateCondition_.notify_all();
if (session_) {
if (transport_.isReliable())
......@@ -1102,13 +1106,12 @@ TlsSession::TlsSessionImpl::handleStateEstablished(TlsSessionState state)
{
// Nothing to do in reliable mode, so just wait for state change
if (transport_.isReliable()) {
while (true) {
std::this_thread::sleep_for(std::chrono::milliseconds(100));
state = state_.load();
if (state != TlsSessionState::ESTABLISHED)
return state;
}
return TlsSessionState::SHUTDOWN;
auto disconnected = [this]() -> bool {
return state_.load() != TlsSessionState::ESTABLISHED;
};
std::unique_lock<std::mutex> lk(stateMutex_);
stateCondition_.wait(lk, disconnected);
return state;
}
// block until rx packet or state change
......@@ -1185,6 +1188,9 @@ TlsSession::TlsSessionImpl::process()
if (not std::atomic_compare_exchange_strong(&state_, &old_state, new_state))
new_state = old_state;
if (old_state != new_state)
stateCondition_.notify_all();
if (old_state != new_state and callbacks_.onStateChange)
callbacks_.onStateChange(new_state);
}
......@@ -1250,6 +1256,7 @@ void
TlsSession::shutdown()
{
pimpl_->state_ = TlsSessionState::SHUTDOWN;
pimpl_->stateCondition_.notify_all();
pimpl_->rxCv_.notify_one(); // unblock waiting FSM
pimpl_->transport_.shutdown();
}
......@@ -1291,6 +1298,7 @@ TlsSession::read(ValueType* data, std::size_t size, std::error_code& ec)
RING_DBG("[TLS] re-handshake");
pimpl_->state_ = TlsSessionState::HANDSHAKE;
pimpl_->rxCv_.notify_one(); // unblock waiting FSM
pimpl_->stateCondition_.notify_all();
} else if (gnutls_error_is_fatal(ret)) {
RING_ERR("[TLS] fatal error in recv: %s", gnutls_strerror(ret));
shutdown();
......@@ -1304,13 +1312,20 @@ TlsSession::read(ValueType* data, std::size_t size, std::error_code& ec)
}
void
TlsSession::connect()
TlsSession::waitForReady(const std::chrono::steady_clock::duration& timeout)
{
TlsSessionState state;
do {
state = pimpl_->state_.load();
std::this_thread::sleep_for(std::chrono::milliseconds(10));
} while (state != TlsSessionState::ESTABLISHED and state != TlsSessionState::SHUTDOWN);
auto ready = [this]() -> bool {
auto state = pimpl_->state_.load();
return state == TlsSessionState::ESTABLISHED or state == TlsSessionState::SHUTDOWN;
};
std::unique_lock<std::mutex> lk(pimpl_->stateMutex_);
if (timeout == std::chrono::steady_clock::duration::zero())
pimpl_->stateCondition_.wait(lk, ready);
else
pimpl_->stateCondition_.wait_for(lk, timeout, ready);
if(!ready())
throw std::logic_error("Invalid state in TlsSession::waitForReady");
}
int
......
......@@ -128,7 +128,7 @@ public:
int maxPayload() const override;
void connect();
void waitForReady(const std::chrono::steady_clock::duration& timeout = {});
/// Synchronous writing.
/// Return a positive number for number of bytes write, or 0 and \a ec set in case of error.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment