From b44d24e8c2d7b4ffc234de8cb6cbdd3517c5218d Mon Sep 17 00:00:00 2001 From: Guillaume Roguez <guillaume.roguez@savoirfairelinux.com> Date: Fri, 12 May 2017 11:14:59 -0400 Subject: [PATCH] dtls: refactoring and fix of PMTUD/Established code MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit There are various issues in the PMTUD code: - OOO handler wasn't applied to the first packet due to unseen code duplication in PMTU code. - first packet sequence has to be known in case of OOO on it - bug in losts detection. - decrease the lost threshold time. - temporary packet allocation is not efficient. - code duplication and functional flow not well designed. - comments needed This patch fixes all of that. Change-Id: I93ec71e22f6cb7a66ad9ab0f927d31044966f1e3 Reviewed-by: Anthony Léonard <anthony.leonard@savoirfairelinux.com> --- src/security/tls_session.cpp | 157 +++++++++++++++++++---------------- src/security/tls_session.h | 12 +-- 2 files changed, 93 insertions(+), 76 deletions(-) diff --git a/src/security/tls_session.cpp b/src/security/tls_session.cpp index 58e8e83429..8cd954a743 100644 --- a/src/security/tls_session.cpp +++ b/src/security/tls_session.cpp @@ -31,6 +31,7 @@ #include "compiler_intrinsics.h" #include "manager.h" #include "certstore.h" +#include "array_size.h" #include <gnutls/gnutls.h> #include <gnutls/dtls.h> @@ -57,7 +58,7 @@ static constexpr uint8_t HEARTBEAT_RETRIES = 1; // Number of tries at each heart static constexpr auto HEARTBEAT_RETRANS_TIMEOUT = std::chrono::milliseconds(700); // gnutls heartbeat retransmission timeout for each ping (in milliseconds) static constexpr auto HEARTBEAT_TOTAL_TIMEOUT = HEARTBEAT_RETRANS_TIMEOUT * HEARTBEAT_RETRIES; // gnutls heartbeat time limit for heartbeat procedure (in milliseconds) static constexpr int MISS_ORDERING_LIMIT = 32; // maximal accepted distance of out-of-order packet (note: must be a signed type) -static constexpr auto RX_OOO_TIMEOUT = std::chrono::milliseconds(1000); +static constexpr auto RX_OOO_TIMEOUT = std::chrono::milliseconds(500); // mtus array to test, do not add mtu over the interface MTU, this will result in false result due to packet fragmentation. // also do not set over 16000 this will result in a gnutls error (unexpected packet size) @@ -73,6 +74,15 @@ duration2ms(std::chrono::duration<Rep, Period> d) return std::chrono::duration_cast<std::chrono::milliseconds>(d).count(); } +static inline uint64_t +array2uint(const std::array<uint8_t, 8>& a) +{ + uint64_t res = 0; + for (int i=0; i < 8; ++i) + res = (res << 8) + a[i]; + return res; +} + DhParams::DhParams(const std::vector<uint8_t>& data) { gnutls_dh_params_t new_params_; @@ -196,6 +206,7 @@ TlsSession::TlsSession(const std::shared_ptr<IceTransport>& ice, int ice_comp_id , params_(params) , callbacks_(cbs) , anonymous_(anonymous) + , maxPayload_(INPUT_BUFFER_SIZE) , cacred_(nullptr) , sacred_(nullptr) , xcred_(nullptr) @@ -775,6 +786,22 @@ TlsSession::handleStateHandshake(TlsSessionState state) return TlsSessionState::MTU_DISCOVERY; } +bool +TlsSession::initFromRecordState(int offset) +{ + std::array<uint8_t, 8> seq; + if (gnutls_record_get_state(session_, 1, nullptr, nullptr, nullptr, &seq[0]) != GNUTLS_E_SUCCESS) { + RING_ERR("[TLS] Fatal-error Unable to read initial state"); + return false; + } + + baseSeq_ = array2uint(seq) + offset; + gapOffset_ = baseSeq_; + lastRxSeq_ = baseSeq_ - 1; + RING_DBG("[TLS] Initial sequence number: %lx", baseSeq_); + return true; +} + TlsSessionState TlsSession::handleStateMtuDiscovery(UNUSED TlsSessionState state) { @@ -801,6 +828,11 @@ TlsSession::handleStateMtuDiscovery(UNUSED TlsSessionState state) if (pmtudOver_) RING_WARN("[TLS] maxPayload for dtls : %d B", getMaxPayload()); + if (pmtudOver_) { + if (!initFromRecordState()) + return TlsSessionState::SHUTDOWN; + } + return TlsSessionState::ESTABLISHED; } @@ -835,8 +867,7 @@ TlsSession::pathMtuHeartbeat() errno_send = gnutls_heartbeat_ping(session_, bytesToSend, HEARTBEAT_RETRIES, GNUTLS_HEARTBEAT_WAIT); RING_DBG("[TLS] Heartbeat PMTUD : ping sequence over with errno %d: %s", errno_send, gnutls_strerror(errno_send)); - } - while (errno_send == GNUTLS_E_AGAIN || errno_send == GNUTLS_E_INTERRUPTED); + } while (errno_send == GNUTLS_E_AGAIN || errno_send == GNUTLS_E_INTERRUPTED); if (errno_send == GNUTLS_E_SUCCESS) { ++mtuProbe_; @@ -872,21 +903,8 @@ TlsSession::pathMtuHeartbeat() } void -TlsSession::handleDataPacket(std::vector<uint8_t>&& buf, const uint8_t* seq_bytes) +TlsSession::handleDataPacket(std::vector<uint8_t>&& buf, uint64_t pkt_seq) { - uint64_t pkt_seq; - for (int i=0; i < 8; ++i) - pkt_seq = (pkt_seq << 8) + seq_bytes[i]; - - // Init/offset sequence number trackers - if (baseSeq_) { - pkt_seq -= baseSeq_; - } else { - baseSeq_ = pkt_seq - 1; - pkt_seq = 1; // start at 1 to have a positive seq_delta on first packet - gapOffset_ = 1; - } - // Check for a valid seq. num. delta int64_t seq_delta = pkt_seq - lastRxSeq_; if (seq_delta > 0) { @@ -894,14 +912,14 @@ TlsSession::handleDataPacket(std::vector<uint8_t>&& buf, const uint8_t* seq_byte } else { // too old? if (seq_delta <= -MISS_ORDERING_LIMIT) { - RING_WARN("[dtls] drop old pkt: %lu", pkt_seq); + RING_WARN("[dtls] drop old pkt: 0x%lx", pkt_seq); return; } // No duplicate check as DTLS prevents that for us (replay protection) // accept Out-Of-Order pkt - will be reordered by queue flush operation - RING_WARN("[dtls] OOO pkt: %lu", pkt_seq); + RING_WARN("[dtls] OOO pkt: 0x%lx", pkt_seq); } std::lock_guard<std::mutex> lk {reorderBufMutex_}; @@ -924,10 +942,15 @@ TlsSession::flushRxQueue() auto item = std::begin(reorderBuffer_); auto next_offset = item->first; + auto first_offset = next_offset; - // Wait for next continous packet until timeou - if ((lastReadTime_ - clock::now()) >= RX_OOO_TIMEOUT) { + // Wait for next continous packet until timeout + if ((clock::now() - lastReadTime_) >= RX_OOO_TIMEOUT) { // OOO packet timeout - consider waited packets as lost + if (auto lost = next_offset - gapOffset_) + RING_WARN("[dtls] %lu lost since 0x%lx", lost, gapOffset_); + else + RING_WARN("[dtls] slow flush"); } else if (next_offset != gapOffset_) return; @@ -948,76 +971,68 @@ TlsSession::flushRxQueue() gapOffset_ = std::max(gapOffset_, next_offset); lastReadTime_ = clock::now(); + + RING_DBG("[dtls] %lu pushed since 0x%lx", gapOffset_ - first_offset, first_offset); } TlsSessionState TlsSession::handleStateEstablished(TlsSessionState state) { - // block until rx/tx packet or state change - std::unique_lock<std::mutex> lk {rxMutex_}; - rxCv_.wait(lk, [this]{ return !rxQueue_.empty() or state_ != TlsSessionState::ESTABLISHED; }); - state = state_.load(); - if (state != TlsSessionState::ESTABLISHED) - return state; - - // Handle RX data from network - if (!rxQueue_.empty()) { - std::vector<uint8_t> buf(INPUT_BUFFER_SIZE); - uint8_t seq[8]; - - lk.unlock(); - auto ret = gnutls_record_recv_seq(session_, buf.data(), buf.size(), seq); - if (ret > 0 && pmtudOver_) { - buf.resize(ret); - handleDataPacket(std::move(buf), seq); + // block until rx packet or state change + { + std::unique_lock<std::mutex> lk {rxMutex_}; + rxCv_.wait(lk, [this]{ return !rxQueue_.empty() or state_ != TlsSessionState::ESTABLISHED; }); + state = state_.load(); + if (state != TlsSessionState::ESTABLISHED) return state; - } else if (ret == GNUTLS_E_HEARTBEAT_PING_RECEIVED) { + } - RING_DBG("[TLS] Heartbeat PMTUD : ping received sending pong"); - auto errno_send = gnutls_heartbeat_pong(session_, 0); + std::array<uint8_t, 8> seq; + rawPktBuf_.resize(maxPayload_); + auto ret = gnutls_record_recv_seq(session_, rawPktBuf_.data(), rawPktBuf_.size(), &seq[0]); - if (errno_send != GNUTLS_E_SUCCESS){ - RING_WARN("[TLS] Heartbeat PMTUD : failed on pong with error %d: %s", errno_send, - gnutls_strerror(errno_send)); - } else { - ++hbPingRecved_; - } - - } else if (ret > 0 && pmtudOver_ == false){ - if (hbPingRecved_ > 0){ + if (ret > 0) { + if (!pmtudOver_) { + // This is the first application packet recieved after PMTUD + // This packet gives the final MTU. + if (hbPingRecved_ > 0) { gnutls_dtls_set_mtu(session_, mtus[hbPingRecved_ - 1] - UDP_HEADER_SIZE - transportOverhead_); maxPayload_ = gnutls_dtls_get_data_mtu(session_); } else { gnutls_dtls_set_mtu(session_, MIN_MTU - UDP_HEADER_SIZE - transportOverhead_); maxPayload_ = gnutls_dtls_get_data_mtu(session_); } - RING_WARN("[TLS] maxPayload for dtls : %d B", getMaxPayload()); pmtudOver_ = true; - buf.resize(ret); - // TODO: handle sequence re-order - if (callbacks_.onRxData) - callbacks_.onRxData(std::move(buf)); - return state; - } + RING_WARN("[TLS] maxPayload for dtls : %d B", getMaxPayload()); - if (ret == 0) { - RING_DBG("[TLS] eof"); - return TlsSessionState::SHUTDOWN; + if (!initFromRecordState(-1)) + return TlsSessionState::SHUTDOWN; } - if (ret == GNUTLS_E_REHANDSHAKE) { - RING_DBG("[TLS] re-handshake"); - return TlsSessionState::HANDSHAKE; - } + rawPktBuf_.resize(ret); + handleDataPacket(std::move(rawPktBuf_), array2uint(seq)); + // no state change + } else if (ret == GNUTLS_E_HEARTBEAT_PING_RECEIVED) { + RING_DBG("[TLS] Heartbeat PMTUD : ping received sending pong"); + auto errno_send = gnutls_heartbeat_pong(session_, 0); - if (gnutls_error_is_fatal(ret)) { - RING_ERR("[TLS] fatal error in recv: %s", gnutls_strerror(ret)); - return TlsSessionState::SHUTDOWN; + if (errno_send != GNUTLS_E_SUCCESS){ + RING_WARN("[TLS] Heartbeat PMTUD : failed on pong with error %d: %s", errno_send, + gnutls_strerror(errno_send)); + } else { + ++hbPingRecved_; } - - // non-fatal error... let's continue - lk.lock(); - } + // no state change + } else if (ret == 0) { + RING_DBG("[TLS] eof"); + state = TlsSessionState::SHUTDOWN; + } else if (ret == GNUTLS_E_REHANDSHAKE) { + RING_DBG("[TLS] re-handshake"); + state = TlsSessionState::HANDSHAKE; + } else if (gnutls_error_is_fatal(ret)) { + RING_ERR("[TLS] fatal error in recv: %s", gnutls_strerror(ret)); + state = TlsSessionState::SHUTDOWN; + } // else non-fatal error... let's continue return state; } diff --git a/src/security/tls_session.h b/src/security/tls_session.h index 093bc5e8ef..de04573148 100644 --- a/src/security/tls_session.h +++ b/src/security/tls_session.h @@ -210,7 +210,7 @@ private: TlsSessionState handleStateShutdown(TlsSessionState state); std::map<TlsSessionState, StateHandler> fsmHandlers_ {}; std::atomic<TlsSessionState> state_ {TlsSessionState::SETUP}; - std::atomic<unsigned int> maxPayload_ {0}; + std::atomic<unsigned int> maxPayload_; // IO GnuTLS <-> ICE std::mutex txMutex_ {}; @@ -219,9 +219,10 @@ private: std::list<std::vector<uint8_t>> rxQueue_ {}; std::mutex reorderBufMutex_; - uint64_t baseSeq_ {0}; // sequence number of first application data packet received - uint64_t lastRxSeq_ {0}; // last received and valid packet sequence number - uint64_t gapOffset_ {1}; // offset of first byte not received yet (start at 1) + std::vector<uint8_t> rawPktBuf_; ///< gnutls incoming packet buffer + uint64_t baseSeq_ {0}; ///< sequence number of first application data packet received + uint64_t lastRxSeq_ {0}; ///< last received and valid packet sequence number + uint64_t gapOffset_ {0}; ///< offset of first byte not received yet clock::time_point lastReadTime_; std::map<uint64_t, std::vector<uint8_t>> reorderBuffer_ {}; @@ -231,7 +232,8 @@ private: ssize_t recvRaw(void*, size_t); int waitForRawData(unsigned); - void handleDataPacket(std::vector<uint8_t>&&, const uint8_t*); + bool initFromRecordState(int offset=0); + void handleDataPacket(std::vector<uint8_t>&&, uint64_t); void flushRxQueue(); // Statistics -- GitLab