From b44d24e8c2d7b4ffc234de8cb6cbdd3517c5218d Mon Sep 17 00:00:00 2001
From: Guillaume Roguez <guillaume.roguez@savoirfairelinux.com>
Date: Fri, 12 May 2017 11:14:59 -0400
Subject: [PATCH] dtls: refactoring and fix of PMTUD/Established code
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

There are various issues in the PMTUD code:
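- the out-of-order (OOO) handler wasn't applied to the first packet
  due to unnoticed code duplication in the PMTUD code.
- the first packet's sequence number has to be known in case that
  packet itself arrives out of order.
- bug in lost-packet detection (the timeout subtraction was reversed).
- the lost-packet timeout was too long (reduced from 1000 ms to 500 ms).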
- temporary packet allocation is not efficient.
- code duplication and functional flow not well designed.
- comments needed

This patch fixes all of that.

Change-Id: I93ec71e22f6cb7a66ad9ab0f927d31044966f1e3
Reviewed-by: Anthony Léonard <anthony.leonard@savoirfairelinux.com>
---
 src/security/tls_session.cpp | 157 +++++++++++++++++++----------------
 src/security/tls_session.h   |  12 +--
 2 files changed, 93 insertions(+), 76 deletions(-)

diff --git a/src/security/tls_session.cpp b/src/security/tls_session.cpp
index 58e8e83429..8cd954a743 100644
--- a/src/security/tls_session.cpp
+++ b/src/security/tls_session.cpp
@@ -31,6 +31,7 @@
 #include "compiler_intrinsics.h"
 #include "manager.h"
 #include "certstore.h"
+#include "array_size.h"
 
 #include <gnutls/gnutls.h>
 #include <gnutls/dtls.h>
@@ -57,7 +58,7 @@ static constexpr uint8_t HEARTBEAT_RETRIES = 1; // Number of tries at each heart
 static constexpr auto HEARTBEAT_RETRANS_TIMEOUT = std::chrono::milliseconds(700); // gnutls heartbeat retransmission timeout for each ping (in milliseconds)
 static constexpr auto HEARTBEAT_TOTAL_TIMEOUT = HEARTBEAT_RETRANS_TIMEOUT * HEARTBEAT_RETRIES; // gnutls heartbeat time limit for heartbeat procedure (in milliseconds)
 static constexpr int MISS_ORDERING_LIMIT = 32; // maximal accepted distance of out-of-order packet (note: must be a signed type)
-static constexpr auto RX_OOO_TIMEOUT = std::chrono::milliseconds(1000);
+static constexpr auto RX_OOO_TIMEOUT = std::chrono::milliseconds(500);
 
 // mtus array to test, do not add mtu over the interface MTU, this will result in false result due to packet fragmentation.
 // also do not set over 16000 this will result in a gnutls error (unexpected packet size)
@@ -73,6 +74,15 @@ duration2ms(std::chrono::duration<Rep, Period> d)
     return std::chrono::duration_cast<std::chrono::milliseconds>(d).count();
 }
 
+static inline uint64_t
+array2uint(const std::array<uint8_t, 8>& a)
+{
+    uint64_t res = 0;
+    for (int i=0; i < 8; ++i)
+        res = (res << 8) + a[i];
+    return res;
+}
+
 DhParams::DhParams(const std::vector<uint8_t>& data)
 {
     gnutls_dh_params_t new_params_;
@@ -196,6 +206,7 @@ TlsSession::TlsSession(const std::shared_ptr<IceTransport>& ice, int ice_comp_id
     , params_(params)
     , callbacks_(cbs)
     , anonymous_(anonymous)
+    , maxPayload_(INPUT_BUFFER_SIZE)
     , cacred_(nullptr)
     , sacred_(nullptr)
     , xcred_(nullptr)
@@ -775,6 +786,22 @@ TlsSession::handleStateHandshake(TlsSessionState state)
     return TlsSessionState::MTU_DISCOVERY;
 }
 
+bool
+TlsSession::initFromRecordState(int offset)
+{
+    std::array<uint8_t, 8> seq;
+    if (gnutls_record_get_state(session_, 1, nullptr, nullptr, nullptr, &seq[0]) != GNUTLS_E_SUCCESS) {
+        RING_ERR("[TLS] Fatal-error Unable to read initial state");
+        return false;
+    }
+
+    baseSeq_ = array2uint(seq) + offset;
+    gapOffset_ = baseSeq_;
+    lastRxSeq_ = baseSeq_ - 1;
+    RING_DBG("[TLS] Initial sequence number: %lx", baseSeq_);
+    return true;
+}
+
 TlsSessionState
 TlsSession::handleStateMtuDiscovery(UNUSED TlsSessionState state)
 {
@@ -801,6 +828,11 @@ TlsSession::handleStateMtuDiscovery(UNUSED TlsSessionState state)
     if (pmtudOver_)
         RING_WARN("[TLS] maxPayload for dtls : %d B", getMaxPayload());
 
+    if (pmtudOver_) {
+        if (!initFromRecordState())
+            return TlsSessionState::SHUTDOWN;
+    }
+
     return TlsSessionState::ESTABLISHED;
 }
 
@@ -835,8 +867,7 @@ TlsSession::pathMtuHeartbeat()
             errno_send = gnutls_heartbeat_ping(session_, bytesToSend, HEARTBEAT_RETRIES, GNUTLS_HEARTBEAT_WAIT);
             RING_DBG("[TLS] Heartbeat PMTUD : ping sequence over with errno %d: %s", errno_send,
                      gnutls_strerror(errno_send));
-        }
-        while (errno_send == GNUTLS_E_AGAIN || errno_send == GNUTLS_E_INTERRUPTED);
+        } while (errno_send == GNUTLS_E_AGAIN || errno_send == GNUTLS_E_INTERRUPTED);
 
         if (errno_send == GNUTLS_E_SUCCESS) {
             ++mtuProbe_;
@@ -872,21 +903,8 @@ TlsSession::pathMtuHeartbeat()
 }
 
 void
-TlsSession::handleDataPacket(std::vector<uint8_t>&& buf, const uint8_t* seq_bytes)
+TlsSession::handleDataPacket(std::vector<uint8_t>&& buf, uint64_t pkt_seq)
 {
-    uint64_t pkt_seq;
-    for (int i=0; i < 8; ++i)
-        pkt_seq = (pkt_seq << 8) + seq_bytes[i];
-
-    // Init/offset sequence number trackers
-    if (baseSeq_) {
-        pkt_seq -= baseSeq_;
-    } else {
-        baseSeq_ = pkt_seq - 1;
-        pkt_seq = 1; // start at 1 to have a positive seq_delta on first packet
-        gapOffset_ = 1;
-    }
-
     // Check for a valid seq. num. delta
     int64_t seq_delta = pkt_seq - lastRxSeq_;
     if (seq_delta > 0) {
@@ -894,14 +912,14 @@ TlsSession::handleDataPacket(std::vector<uint8_t>&& buf, const uint8_t* seq_byte
     } else {
         // too old?
         if (seq_delta <= -MISS_ORDERING_LIMIT) {
-            RING_WARN("[dtls] drop old pkt: %lu", pkt_seq);
+            RING_WARN("[dtls] drop old pkt: 0x%lx", pkt_seq);
             return;
         }
 
         // No duplicate check as DTLS prevents that for us (replay protection)
 
         // accept Out-Of-Order pkt - will be reordered by queue flush operation
-        RING_WARN("[dtls] OOO pkt: %lu", pkt_seq);
+        RING_WARN("[dtls] OOO pkt: 0x%lx", pkt_seq);
     }
 
     std::lock_guard<std::mutex> lk {reorderBufMutex_};
@@ -924,10 +942,15 @@ TlsSession::flushRxQueue()
 
     auto item = std::begin(reorderBuffer_);
     auto next_offset = item->first;
+    auto first_offset = next_offset;
 
-    // Wait for next continous packet until timeou
-    if ((lastReadTime_ - clock::now()) >= RX_OOO_TIMEOUT) {
+    // Wait for next continuous packet until timeout
+    if ((clock::now() - lastReadTime_) >= RX_OOO_TIMEOUT) {
         // OOO packet timeout - consider waited packets as lost
+        if (auto lost = next_offset - gapOffset_)
+            RING_WARN("[dtls] %lu lost since 0x%lx", lost, gapOffset_);
+        else
+            RING_WARN("[dtls] slow flush");
     } else if (next_offset != gapOffset_)
         return;
 
@@ -948,76 +971,68 @@ TlsSession::flushRxQueue()
 
     gapOffset_ = std::max(gapOffset_, next_offset);
     lastReadTime_ = clock::now();
+
+    RING_DBG("[dtls] %lu pushed since 0x%lx", gapOffset_ - first_offset, first_offset);
 }
 
 TlsSessionState
 TlsSession::handleStateEstablished(TlsSessionState state)
 {
-    // block until rx/tx packet or state change
-    std::unique_lock<std::mutex> lk {rxMutex_};
-    rxCv_.wait(lk, [this]{ return !rxQueue_.empty() or state_ != TlsSessionState::ESTABLISHED; });
-    state = state_.load();
-    if (state != TlsSessionState::ESTABLISHED)
-        return state;
-
-    // Handle RX data from network
-    if (!rxQueue_.empty()) {
-        std::vector<uint8_t> buf(INPUT_BUFFER_SIZE);
-        uint8_t seq[8];
-
-        lk.unlock();
-        auto ret = gnutls_record_recv_seq(session_, buf.data(), buf.size(), seq);
-        if (ret > 0 && pmtudOver_) {
-            buf.resize(ret);
-            handleDataPacket(std::move(buf), seq);
+    // block until rx packet or state change
+    {
+        std::unique_lock<std::mutex> lk {rxMutex_};
+        rxCv_.wait(lk, [this]{ return !rxQueue_.empty() or state_ != TlsSessionState::ESTABLISHED; });
+        state = state_.load();
+        if (state != TlsSessionState::ESTABLISHED)
             return state;
-        } else if (ret == GNUTLS_E_HEARTBEAT_PING_RECEIVED) {
+    }
 
-            RING_DBG("[TLS] Heartbeat PMTUD : ping received sending pong");
-            auto errno_send = gnutls_heartbeat_pong(session_, 0);
+    std::array<uint8_t, 8> seq;
+    rawPktBuf_.resize(maxPayload_);
+    auto ret = gnutls_record_recv_seq(session_, rawPktBuf_.data(), rawPktBuf_.size(), &seq[0]);
 
-            if (errno_send != GNUTLS_E_SUCCESS){
-                RING_WARN("[TLS] Heartbeat PMTUD : failed on pong with error %d: %s", errno_send,
-                          gnutls_strerror(errno_send));
-            } else {
-                ++hbPingRecved_;
-            }
-
-        } else if (ret > 0 && pmtudOver_ == false){
-            if (hbPingRecved_ > 0){
+    if (ret > 0) {
+        if (!pmtudOver_) {
+            // This is the first application packet received after PMTUD
+            // This packet gives the final MTU.
+            if (hbPingRecved_ > 0) {
                 gnutls_dtls_set_mtu(session_, mtus[hbPingRecved_ - 1] - UDP_HEADER_SIZE - transportOverhead_);
                 maxPayload_ = gnutls_dtls_get_data_mtu(session_);
             } else {
                 gnutls_dtls_set_mtu(session_, MIN_MTU - UDP_HEADER_SIZE - transportOverhead_);
                 maxPayload_ = gnutls_dtls_get_data_mtu(session_);
             }
-            RING_WARN("[TLS] maxPayload for dtls : %d B", getMaxPayload());
             pmtudOver_ = true;
-            buf.resize(ret);
-            // TODO: handle sequence re-order
-            if (callbacks_.onRxData)
-                callbacks_.onRxData(std::move(buf));
-            return state;
-        }
+            RING_WARN("[TLS] maxPayload for dtls : %d B", getMaxPayload());
 
-        if (ret == 0) {
-            RING_DBG("[TLS] eof");
-            return TlsSessionState::SHUTDOWN;
+            if (!initFromRecordState(-1))
+                return TlsSessionState::SHUTDOWN;
         }
 
-        if (ret == GNUTLS_E_REHANDSHAKE) {
-            RING_DBG("[TLS] re-handshake");
-            return TlsSessionState::HANDSHAKE;
-        }
+        rawPktBuf_.resize(ret);
+        handleDataPacket(std::move(rawPktBuf_), array2uint(seq));
+        // no state change
+    } else if (ret == GNUTLS_E_HEARTBEAT_PING_RECEIVED) {
+        RING_DBG("[TLS] Heartbeat PMTUD : ping received sending pong");
+        auto errno_send = gnutls_heartbeat_pong(session_, 0);
 
-        if (gnutls_error_is_fatal(ret)) {
-            RING_ERR("[TLS] fatal error in recv: %s", gnutls_strerror(ret));
-            return TlsSessionState::SHUTDOWN;
+        if (errno_send != GNUTLS_E_SUCCESS){
+            RING_WARN("[TLS] Heartbeat PMTUD : failed on pong with error %d: %s", errno_send,
+                      gnutls_strerror(errno_send));
+        } else {
+            ++hbPingRecved_;
         }
-
-        // non-fatal error... let's continue
-        lk.lock();
-    }
+        // no state change
+    } else if (ret == 0) {
+        RING_DBG("[TLS] eof");
+        state = TlsSessionState::SHUTDOWN;
+    } else if (ret == GNUTLS_E_REHANDSHAKE) {
+        RING_DBG("[TLS] re-handshake");
+        state = TlsSessionState::HANDSHAKE;
+    } else if (gnutls_error_is_fatal(ret)) {
+        RING_ERR("[TLS] fatal error in recv: %s", gnutls_strerror(ret));
+        state = TlsSessionState::SHUTDOWN;
+    } // else non-fatal error... let's continue
 
     return state;
 }
diff --git a/src/security/tls_session.h b/src/security/tls_session.h
index 093bc5e8ef..de04573148 100644
--- a/src/security/tls_session.h
+++ b/src/security/tls_session.h
@@ -210,7 +210,7 @@ private:
     TlsSessionState handleStateShutdown(TlsSessionState state);
     std::map<TlsSessionState, StateHandler> fsmHandlers_ {};
     std::atomic<TlsSessionState> state_ {TlsSessionState::SETUP};
-    std::atomic<unsigned int> maxPayload_ {0};
+    std::atomic<unsigned int> maxPayload_;
 
     // IO GnuTLS <-> ICE
     std::mutex txMutex_ {};
@@ -219,9 +219,10 @@ private:
     std::list<std::vector<uint8_t>> rxQueue_ {};
 
     std::mutex reorderBufMutex_;
-    uint64_t baseSeq_ {0}; // sequence number of first application data packet received
-    uint64_t lastRxSeq_ {0}; // last received and valid packet sequence number
-    uint64_t gapOffset_ {1}; // offset of first byte not received yet (start at 1)
+    std::vector<uint8_t> rawPktBuf_; ///< gnutls incoming packet buffer
+    uint64_t baseSeq_ {0}; ///< sequence number of first application data packet received
+    uint64_t lastRxSeq_ {0}; ///< last received and valid packet sequence number
+    uint64_t gapOffset_ {0}; ///< sequence number of first packet not received yet
     clock::time_point lastReadTime_;
     std::map<uint64_t, std::vector<uint8_t>> reorderBuffer_ {};
 
@@ -231,7 +232,8 @@ private:
     ssize_t recvRaw(void*, size_t);
     int waitForRawData(unsigned);
 
-    void handleDataPacket(std::vector<uint8_t>&&, const uint8_t*);
+    bool initFromRecordState(int offset=0);
+    void handleDataPacket(std::vector<uint8_t>&&, uint64_t);
     void flushRxQueue();
 
     // Statistics
-- 
GitLab