Skip to content
Snippets Groups Projects
Commit b44d24e8 authored by Guillaume Roguez's avatar Guillaume Roguez Committed by Anthony Léonard
Browse files

dtls: refactoring and fix of PMTUD/Established code


There are various issues in the PMTUD code:
- OOO handler wasn't applied to the first packet
  due to unseen code duplication in PMTU code.
- first packet sequence has to be known in case of OOO on it
- bug in losts detection.
- decrease the lost threshold time.
- temporary packet allocation is not efficient.
- code duplication and functional flow not well designed.
- comments needed

This patch fixes all of that.

Change-Id: I93ec71e22f6cb7a66ad9ab0f927d31044966f1e3
Reviewed-by: default avatarAnthony Léonard <anthony.leonard@savoirfairelinux.com>
parent d3eff48f
No related branches found
No related tags found
No related merge requests found
...@@ -31,6 +31,7 @@ ...@@ -31,6 +31,7 @@
#include "compiler_intrinsics.h" #include "compiler_intrinsics.h"
#include "manager.h" #include "manager.h"
#include "certstore.h" #include "certstore.h"
#include "array_size.h"
#include <gnutls/gnutls.h> #include <gnutls/gnutls.h>
#include <gnutls/dtls.h> #include <gnutls/dtls.h>
...@@ -57,7 +58,7 @@ static constexpr uint8_t HEARTBEAT_RETRIES = 1; // Number of tries at each heart ...@@ -57,7 +58,7 @@ static constexpr uint8_t HEARTBEAT_RETRIES = 1; // Number of tries at each heart
static constexpr auto HEARTBEAT_RETRANS_TIMEOUT = std::chrono::milliseconds(700); // gnutls heartbeat retransmission timeout for each ping (in milliseconds) static constexpr auto HEARTBEAT_RETRANS_TIMEOUT = std::chrono::milliseconds(700); // gnutls heartbeat retransmission timeout for each ping (in milliseconds)
static constexpr auto HEARTBEAT_TOTAL_TIMEOUT = HEARTBEAT_RETRANS_TIMEOUT * HEARTBEAT_RETRIES; // gnutls heartbeat time limit for heartbeat procedure (in milliseconds) static constexpr auto HEARTBEAT_TOTAL_TIMEOUT = HEARTBEAT_RETRANS_TIMEOUT * HEARTBEAT_RETRIES; // gnutls heartbeat time limit for heartbeat procedure (in milliseconds)
static constexpr int MISS_ORDERING_LIMIT = 32; // maximal accepted distance of out-of-order packet (note: must be a signed type) static constexpr int MISS_ORDERING_LIMIT = 32; // maximal accepted distance of out-of-order packet (note: must be a signed type)
static constexpr auto RX_OOO_TIMEOUT = std::chrono::milliseconds(1000); static constexpr auto RX_OOO_TIMEOUT = std::chrono::milliseconds(500);
// mtus array to test, do not add mtu over the interface MTU, this will result in false result due to packet fragmentation. // mtus array to test, do not add mtu over the interface MTU, this will result in false result due to packet fragmentation.
// also do not set over 16000 this will result in a gnutls error (unexpected packet size) // also do not set over 16000 this will result in a gnutls error (unexpected packet size)
...@@ -73,6 +74,15 @@ duration2ms(std::chrono::duration<Rep, Period> d) ...@@ -73,6 +74,15 @@ duration2ms(std::chrono::duration<Rep, Period> d)
return std::chrono::duration_cast<std::chrono::milliseconds>(d).count(); return std::chrono::duration_cast<std::chrono::milliseconds>(d).count();
} }
static inline uint64_t
array2uint(const std::array<uint8_t, 8>& a)
{
uint64_t res = 0;
for (int i=0; i < 8; ++i)
res = (res << 8) + a[i];
return res;
}
DhParams::DhParams(const std::vector<uint8_t>& data) DhParams::DhParams(const std::vector<uint8_t>& data)
{ {
gnutls_dh_params_t new_params_; gnutls_dh_params_t new_params_;
...@@ -196,6 +206,7 @@ TlsSession::TlsSession(const std::shared_ptr<IceTransport>& ice, int ice_comp_id ...@@ -196,6 +206,7 @@ TlsSession::TlsSession(const std::shared_ptr<IceTransport>& ice, int ice_comp_id
, params_(params) , params_(params)
, callbacks_(cbs) , callbacks_(cbs)
, anonymous_(anonymous) , anonymous_(anonymous)
, maxPayload_(INPUT_BUFFER_SIZE)
, cacred_(nullptr) , cacred_(nullptr)
, sacred_(nullptr) , sacred_(nullptr)
, xcred_(nullptr) , xcred_(nullptr)
...@@ -775,6 +786,22 @@ TlsSession::handleStateHandshake(TlsSessionState state) ...@@ -775,6 +786,22 @@ TlsSession::handleStateHandshake(TlsSessionState state)
return TlsSessionState::MTU_DISCOVERY; return TlsSessionState::MTU_DISCOVERY;
} }
bool
TlsSession::initFromRecordState(int offset)
{
std::array<uint8_t, 8> seq;
if (gnutls_record_get_state(session_, 1, nullptr, nullptr, nullptr, &seq[0]) != GNUTLS_E_SUCCESS) {
RING_ERR("[TLS] Fatal-error Unable to read initial state");
return false;
}
baseSeq_ = array2uint(seq) + offset;
gapOffset_ = baseSeq_;
lastRxSeq_ = baseSeq_ - 1;
RING_DBG("[TLS] Initial sequence number: %lx", baseSeq_);
return true;
}
TlsSessionState TlsSessionState
TlsSession::handleStateMtuDiscovery(UNUSED TlsSessionState state) TlsSession::handleStateMtuDiscovery(UNUSED TlsSessionState state)
{ {
...@@ -801,6 +828,11 @@ TlsSession::handleStateMtuDiscovery(UNUSED TlsSessionState state) ...@@ -801,6 +828,11 @@ TlsSession::handleStateMtuDiscovery(UNUSED TlsSessionState state)
if (pmtudOver_) if (pmtudOver_)
RING_WARN("[TLS] maxPayload for dtls : %d B", getMaxPayload()); RING_WARN("[TLS] maxPayload for dtls : %d B", getMaxPayload());
if (pmtudOver_) {
if (!initFromRecordState())
return TlsSessionState::SHUTDOWN;
}
return TlsSessionState::ESTABLISHED; return TlsSessionState::ESTABLISHED;
} }
...@@ -835,8 +867,7 @@ TlsSession::pathMtuHeartbeat() ...@@ -835,8 +867,7 @@ TlsSession::pathMtuHeartbeat()
errno_send = gnutls_heartbeat_ping(session_, bytesToSend, HEARTBEAT_RETRIES, GNUTLS_HEARTBEAT_WAIT); errno_send = gnutls_heartbeat_ping(session_, bytesToSend, HEARTBEAT_RETRIES, GNUTLS_HEARTBEAT_WAIT);
RING_DBG("[TLS] Heartbeat PMTUD : ping sequence over with errno %d: %s", errno_send, RING_DBG("[TLS] Heartbeat PMTUD : ping sequence over with errno %d: %s", errno_send,
gnutls_strerror(errno_send)); gnutls_strerror(errno_send));
} } while (errno_send == GNUTLS_E_AGAIN || errno_send == GNUTLS_E_INTERRUPTED);
while (errno_send == GNUTLS_E_AGAIN || errno_send == GNUTLS_E_INTERRUPTED);
if (errno_send == GNUTLS_E_SUCCESS) { if (errno_send == GNUTLS_E_SUCCESS) {
++mtuProbe_; ++mtuProbe_;
...@@ -872,21 +903,8 @@ TlsSession::pathMtuHeartbeat() ...@@ -872,21 +903,8 @@ TlsSession::pathMtuHeartbeat()
} }
void void
TlsSession::handleDataPacket(std::vector<uint8_t>&& buf, const uint8_t* seq_bytes) TlsSession::handleDataPacket(std::vector<uint8_t>&& buf, uint64_t pkt_seq)
{ {
uint64_t pkt_seq;
for (int i=0; i < 8; ++i)
pkt_seq = (pkt_seq << 8) + seq_bytes[i];
// Init/offset sequence number trackers
if (baseSeq_) {
pkt_seq -= baseSeq_;
} else {
baseSeq_ = pkt_seq - 1;
pkt_seq = 1; // start at 1 to have a positive seq_delta on first packet
gapOffset_ = 1;
}
// Check for a valid seq. num. delta // Check for a valid seq. num. delta
int64_t seq_delta = pkt_seq - lastRxSeq_; int64_t seq_delta = pkt_seq - lastRxSeq_;
if (seq_delta > 0) { if (seq_delta > 0) {
...@@ -894,14 +912,14 @@ TlsSession::handleDataPacket(std::vector<uint8_t>&& buf, const uint8_t* seq_byte ...@@ -894,14 +912,14 @@ TlsSession::handleDataPacket(std::vector<uint8_t>&& buf, const uint8_t* seq_byte
} else { } else {
// too old? // too old?
if (seq_delta <= -MISS_ORDERING_LIMIT) { if (seq_delta <= -MISS_ORDERING_LIMIT) {
RING_WARN("[dtls] drop old pkt: %lu", pkt_seq); RING_WARN("[dtls] drop old pkt: 0x%lx", pkt_seq);
return; return;
} }
// No duplicate check as DTLS prevents that for us (replay protection) // No duplicate check as DTLS prevents that for us (replay protection)
// accept Out-Of-Order pkt - will be reordered by queue flush operation // accept Out-Of-Order pkt - will be reordered by queue flush operation
RING_WARN("[dtls] OOO pkt: %lu", pkt_seq); RING_WARN("[dtls] OOO pkt: 0x%lx", pkt_seq);
} }
std::lock_guard<std::mutex> lk {reorderBufMutex_}; std::lock_guard<std::mutex> lk {reorderBufMutex_};
...@@ -924,10 +942,15 @@ TlsSession::flushRxQueue() ...@@ -924,10 +942,15 @@ TlsSession::flushRxQueue()
auto item = std::begin(reorderBuffer_); auto item = std::begin(reorderBuffer_);
auto next_offset = item->first; auto next_offset = item->first;
auto first_offset = next_offset;
// Wait for next continous packet until timeou // Wait for next continous packet until timeout
if ((lastReadTime_ - clock::now()) >= RX_OOO_TIMEOUT) { if ((clock::now() - lastReadTime_) >= RX_OOO_TIMEOUT) {
// OOO packet timeout - consider waited packets as lost // OOO packet timeout - consider waited packets as lost
if (auto lost = next_offset - gapOffset_)
RING_WARN("[dtls] %lu lost since 0x%lx", lost, gapOffset_);
else
RING_WARN("[dtls] slow flush");
} else if (next_offset != gapOffset_) } else if (next_offset != gapOffset_)
return; return;
...@@ -948,42 +971,30 @@ TlsSession::flushRxQueue() ...@@ -948,42 +971,30 @@ TlsSession::flushRxQueue()
gapOffset_ = std::max(gapOffset_, next_offset); gapOffset_ = std::max(gapOffset_, next_offset);
lastReadTime_ = clock::now(); lastReadTime_ = clock::now();
RING_DBG("[dtls] %lu pushed since 0x%lx", gapOffset_ - first_offset, first_offset);
} }
TlsSessionState TlsSessionState
TlsSession::handleStateEstablished(TlsSessionState state) TlsSession::handleStateEstablished(TlsSessionState state)
{ {
// block until rx/tx packet or state change // block until rx packet or state change
{
std::unique_lock<std::mutex> lk {rxMutex_}; std::unique_lock<std::mutex> lk {rxMutex_};
rxCv_.wait(lk, [this]{ return !rxQueue_.empty() or state_ != TlsSessionState::ESTABLISHED; }); rxCv_.wait(lk, [this]{ return !rxQueue_.empty() or state_ != TlsSessionState::ESTABLISHED; });
state = state_.load(); state = state_.load();
if (state != TlsSessionState::ESTABLISHED) if (state != TlsSessionState::ESTABLISHED)
return state; return state;
// Handle RX data from network
if (!rxQueue_.empty()) {
std::vector<uint8_t> buf(INPUT_BUFFER_SIZE);
uint8_t seq[8];
lk.unlock();
auto ret = gnutls_record_recv_seq(session_, buf.data(), buf.size(), seq);
if (ret > 0 && pmtudOver_) {
buf.resize(ret);
handleDataPacket(std::move(buf), seq);
return state;
} else if (ret == GNUTLS_E_HEARTBEAT_PING_RECEIVED) {
RING_DBG("[TLS] Heartbeat PMTUD : ping received sending pong");
auto errno_send = gnutls_heartbeat_pong(session_, 0);
if (errno_send != GNUTLS_E_SUCCESS){
RING_WARN("[TLS] Heartbeat PMTUD : failed on pong with error %d: %s", errno_send,
gnutls_strerror(errno_send));
} else {
++hbPingRecved_;
} }
} else if (ret > 0 && pmtudOver_ == false){ std::array<uint8_t, 8> seq;
rawPktBuf_.resize(maxPayload_);
auto ret = gnutls_record_recv_seq(session_, rawPktBuf_.data(), rawPktBuf_.size(), &seq[0]);
if (ret > 0) {
if (!pmtudOver_) {
// This is the first application packet recieved after PMTUD
// This packet gives the final MTU.
if (hbPingRecved_ > 0) { if (hbPingRecved_ > 0) {
gnutls_dtls_set_mtu(session_, mtus[hbPingRecved_ - 1] - UDP_HEADER_SIZE - transportOverhead_); gnutls_dtls_set_mtu(session_, mtus[hbPingRecved_ - 1] - UDP_HEADER_SIZE - transportOverhead_);
maxPayload_ = gnutls_dtls_get_data_mtu(session_); maxPayload_ = gnutls_dtls_get_data_mtu(session_);
...@@ -991,33 +1002,37 @@ TlsSession::handleStateEstablished(TlsSessionState state) ...@@ -991,33 +1002,37 @@ TlsSession::handleStateEstablished(TlsSessionState state)
gnutls_dtls_set_mtu(session_, MIN_MTU - UDP_HEADER_SIZE - transportOverhead_); gnutls_dtls_set_mtu(session_, MIN_MTU - UDP_HEADER_SIZE - transportOverhead_);
maxPayload_ = gnutls_dtls_get_data_mtu(session_); maxPayload_ = gnutls_dtls_get_data_mtu(session_);
} }
RING_WARN("[TLS] maxPayload for dtls : %d B", getMaxPayload());
pmtudOver_ = true; pmtudOver_ = true;
buf.resize(ret); RING_WARN("[TLS] maxPayload for dtls : %d B", getMaxPayload());
// TODO: handle sequence re-order
if (callbacks_.onRxData)
callbacks_.onRxData(std::move(buf));
return state;
}
if (ret == 0) { if (!initFromRecordState(-1))
RING_DBG("[TLS] eof");
return TlsSessionState::SHUTDOWN; return TlsSessionState::SHUTDOWN;
} }
if (ret == GNUTLS_E_REHANDSHAKE) { rawPktBuf_.resize(ret);
RING_DBG("[TLS] re-handshake"); handleDataPacket(std::move(rawPktBuf_), array2uint(seq));
return TlsSessionState::HANDSHAKE; // no state change
} } else if (ret == GNUTLS_E_HEARTBEAT_PING_RECEIVED) {
RING_DBG("[TLS] Heartbeat PMTUD : ping received sending pong");
if (gnutls_error_is_fatal(ret)) { auto errno_send = gnutls_heartbeat_pong(session_, 0);
RING_ERR("[TLS] fatal error in recv: %s", gnutls_strerror(ret));
return TlsSessionState::SHUTDOWN;
}
// non-fatal error... let's continue if (errno_send != GNUTLS_E_SUCCESS){
lk.lock(); RING_WARN("[TLS] Heartbeat PMTUD : failed on pong with error %d: %s", errno_send,
gnutls_strerror(errno_send));
} else {
++hbPingRecved_;
} }
// no state change
} else if (ret == 0) {
RING_DBG("[TLS] eof");
state = TlsSessionState::SHUTDOWN;
} else if (ret == GNUTLS_E_REHANDSHAKE) {
RING_DBG("[TLS] re-handshake");
state = TlsSessionState::HANDSHAKE;
} else if (gnutls_error_is_fatal(ret)) {
RING_ERR("[TLS] fatal error in recv: %s", gnutls_strerror(ret));
state = TlsSessionState::SHUTDOWN;
} // else non-fatal error... let's continue
return state; return state;
} }
......
...@@ -210,7 +210,7 @@ private: ...@@ -210,7 +210,7 @@ private:
TlsSessionState handleStateShutdown(TlsSessionState state); TlsSessionState handleStateShutdown(TlsSessionState state);
std::map<TlsSessionState, StateHandler> fsmHandlers_ {}; std::map<TlsSessionState, StateHandler> fsmHandlers_ {};
std::atomic<TlsSessionState> state_ {TlsSessionState::SETUP}; std::atomic<TlsSessionState> state_ {TlsSessionState::SETUP};
std::atomic<unsigned int> maxPayload_ {0}; std::atomic<unsigned int> maxPayload_;
// IO GnuTLS <-> ICE // IO GnuTLS <-> ICE
std::mutex txMutex_ {}; std::mutex txMutex_ {};
...@@ -219,9 +219,10 @@ private: ...@@ -219,9 +219,10 @@ private:
std::list<std::vector<uint8_t>> rxQueue_ {}; std::list<std::vector<uint8_t>> rxQueue_ {};
std::mutex reorderBufMutex_; std::mutex reorderBufMutex_;
uint64_t baseSeq_ {0}; // sequence number of first application data packet received std::vector<uint8_t> rawPktBuf_; ///< gnutls incoming packet buffer
uint64_t lastRxSeq_ {0}; // last received and valid packet sequence number uint64_t baseSeq_ {0}; ///< sequence number of first application data packet received
uint64_t gapOffset_ {1}; // offset of first byte not received yet (start at 1) uint64_t lastRxSeq_ {0}; ///< last received and valid packet sequence number
uint64_t gapOffset_ {0}; ///< offset of first byte not received yet
clock::time_point lastReadTime_; clock::time_point lastReadTime_;
std::map<uint64_t, std::vector<uint8_t>> reorderBuffer_ {}; std::map<uint64_t, std::vector<uint8_t>> reorderBuffer_ {};
...@@ -231,7 +232,8 @@ private: ...@@ -231,7 +232,8 @@ private:
ssize_t recvRaw(void*, size_t); ssize_t recvRaw(void*, size_t);
int waitForRawData(unsigned); int waitForRawData(unsigned);
void handleDataPacket(std::vector<uint8_t>&&, const uint8_t*); bool initFromRecordState(int offset=0);
void handleDataPacket(std::vector<uint8_t>&&, uint64_t);
void flushRxQueue(); void flushRxQueue();
// Statistics // Statistics
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment