From 29ae5d8abeda344db958936c0c0eab3ad8f5fb02 Mon Sep 17 00:00:00 2001 From: Olivier SOLDANO <olivier.soldano@savoirfairelinux.com> Date: Tue, 27 Sep 2016 10:13:14 -0400 Subject: [PATCH] Path MTU discovery implementation This implementation uses gnutls dtls heartbeat API to test path MTU. heartbeat allowing messages with automated response in a datagram, the application is able to guess the MTU via a timeout in the heartbeat. (timeout on packet sent and no response, implies that the MTU is lower than the lost payload.) To minimize false positives (a response is lost for example), each attempt triggers one retry on the first timeout. This version ensures a minimal MTU of 512 bytes will be returned in case of any failure in the procedure. For retrocompatibility with non heartbeat capable clients, a fallback MTU is set at 1280. Change-Id: Ib9a7f63a70e8bdad239d8fc103779a0f2c387e87 Reviewed-by: Andreas Traczyk <andreas.traczyk@savoirfairelinux.com> --- src/ice_socket.h | 7 +- src/ice_transport.cpp | 7 ++ src/media/audio/audio_rtp_session.cpp | 26 ++-- src/media/rtp_session.h | 4 + src/media/socket_pair.cpp | 12 +- src/media/socket_pair.h | 2 +- src/media/video/video_receive_thread.cpp | 6 +- src/media/video/video_receive_thread.h | 4 +- src/media/video/video_rtp_session.cpp | 5 +- src/media/video/video_sender.cpp | 5 +- src/media/video/video_sender.h | 3 +- src/ringdht/sips_transport_ice.cpp | 6 + src/ringdht/sips_transport_ice.h | 3 + src/security/tls_session.cpp | 153 ++++++++++++++++++++++- src/security/tls_session.h | 15 +++ src/sip/sipcall.cpp | 6 +- src/sip/siptransport.cpp | 6 + src/sip/siptransport.h | 2 + 18 files changed, 243 insertions(+), 29 deletions(-) diff --git a/src/ice_socket.h b/src/ice_socket.h index 69d85a44d0..a35fb64269 100644 --- a/src/ice_socket.h +++ b/src/ice_socket.h @@ -23,9 +23,9 @@ #include <memory> #include <functional> -#if defined(_MSC_VER) -#include <BaseTsd.h> -using ssize_t = SSIZE_T; +#if defined(_MSC_VER) +#include <BaseTsd.h> +using ssize_t = SSIZE_T; #endif namespace ring { @@ -49,6 +49,7 @@ class IceSocket ssize_t getNextPacketSize() const; ssize_t waitForData(unsigned int timeout); void setOnRecv(IceRecvCb cb); + uint16_t getTransportOverhead(); }; }; diff --git a/src/ice_transport.cpp b/src/ice_transport.cpp index dc359963e1..489d47e85f 100644 --- a/src/ice_transport.cpp +++ b/src/ice_transport.cpp @@ -44,6 +44,8 @@ namespace ring { static constexpr unsigned STUN_MAX_PACKET_SIZE {8192}; +static constexpr uint16_t IPV6_HEADER_SIZE = 40; // Size in bytes of IPV6 packet header +static constexpr uint16_t IPV4_HEADER_SIZE = 20; // Size in bytes of IPV4 packet header // TODO: C++14 ? remove me and use std::min template< class T > @@ -1061,4 +1063,9 @@ IceSocket::setOnRecv(IceRecvCb cb) return ice_transport_->setOnRecv(compId_, cb); } +uint16_t +IceSocket::getTransportOverhead(){ + return (ice_transport_->getRemoteAddress(compId_).getFamily() == AF_INET) ? IPV4_HEADER_SIZE : IPV6_HEADER_SIZE; +} + } // namespace ring diff --git a/src/media/audio/audio_rtp_session.cpp b/src/media/audio/audio_rtp_session.cpp index 1bbbccc954..f2238758d5 100644 --- a/src/media/audio/audio_rtp_session.cpp +++ b/src/media/audio/audio_rtp_session.cpp @@ -49,7 +49,8 @@ class AudioSender { const MediaDescription& args, SocketPair& socketPair, const uint16_t seqVal, - bool muteState); + bool muteState, + const uint16_t mtu); ~AudioSender(); void setMuted(bool isMuted); @@ -71,6 +72,7 @@ class AudioSender { AudioBuffer resampledData_; const uint16_t seqVal_; bool muteState_ = false; + uint16_t mtu_; using seconds = std::chrono::duration<double, std::ratio<1>>; const seconds secondsPerPacket_ {0.02}; // 20 ms @@ -85,12 +87,14 @@ AudioSender::AudioSender(const std::string& id, const MediaDescription& args, SocketPair& socketPair, const uint16_t seqVal, - bool muteState) : + bool muteState, + const uint16_t mtu) : id_(id), dest_(dest), args_(args), seqVal_(seqVal), muteState_(muteState), + mtu_(mtu), loop_([&] { return setup(socketPair); }, std::bind(&AudioSender::process, this), std::bind(&AudioSender::cleanup, this)) @@ -107,7 +111,7 @@ bool AudioSender::setup(SocketPair& socketPair) { audioEncoder_.reset(new MediaEncoder); - muxContext_.reset(socketPair.createIOContext()); + muxContext_.reset(socketPair.createIOContext(mtu_)); try { /* Encoder setup */ @@ -199,7 +203,8 @@ class AudioReceiveThread public: AudioReceiveThread(const std::string &id, const AudioFormat& format, - const std::string& sdp); + const std::string& sdp, + const uint16_t mtu); ~AudioReceiveThread(); void addIOContext(SocketPair &socketPair); void startLoop(); @@ -230,6 +235,8 @@ class AudioReceiveThread std::shared_ptr<RingBuffer> ringbuffer_; + uint16_t mtu_; + ThreadLoop loop_; bool setup(); void process(); @@ -238,10 +245,12 @@ class AudioReceiveThread AudioReceiveThread::AudioReceiveThread(const std::string& id, const AudioFormat& format, - const std::string& sdp) + const std::string& sdp, + const uint16_t mtu) : id_(id) , format_(format) , stream_(sdp) + , mtu_(mtu) , sdpContext_(new MediaIOHandle(sdp.size(), false, &readFunction, 0, 0, this)) , loop_(std::bind(&AudioReceiveThread::setup, this), @@ -345,7 +354,7 @@ AudioReceiveThread::interruptCb(void* data) void AudioReceiveThread::addIOContext(SocketPair& socketPair) { - demuxContext_.reset(socketPair.createIOContext()); + demuxContext_.reset(socketPair.createIOContext(mtu_)); } void @@ -391,7 +400,7 @@ AudioRtpSession::startSender() sender_.reset(); socketPair_->stopSendOp(false); sender_.reset(new AudioSender(callID_, getRemoteRtpUri(), send_, - *socketPair_, initSeqVal_, muteState_)); + *socketPair_, initSeqVal_, muteState_, mtu_)); } catch (const MediaEncoderException &e) { RING_ERR("%s", e.what()); send_.enabled = false; @@ -423,7 +432,8 @@ AudioRtpSession::startReceiver() auto accountAudioCodec = std::static_pointer_cast<AccountAudioCodecInfo>(receive_.codec); receiveThread_.reset(new AudioReceiveThread(callID_, accountAudioCodec->audioformat, - receive_.receiving_sdp)); + receive_.receiving_sdp, + mtu_)); receiveThread_->addIOContext(*socketPair_); receiveThread_->startLoop(); } diff --git a/src/media/rtp_session.h b/src/media/rtp_session.h index 0dc4e5ecae..6206181e0a 100644 --- a/src/media/rtp_session.h +++ b/src/media/rtp_session.h @@ -51,6 +51,8 @@ public: bool isSending() const noexcept { return send_.enabled; } bool isReceiving() const noexcept { return receive_.enabled; } + void setMtu(uint16_t mtu) { mtu_ = mtu; } + protected: std::recursive_mutex mutex_; std::unique_ptr<SocketPair> socketPair_; @@ -59,6 +61,8 @@ protected: MediaDescription send_; MediaDescription receive_; + uint16_t mtu_; + std::string getRemoteRtpUri() const { return "rtp://" + send_.addr.toString(true); } diff --git a/src/media/socket_pair.cpp b/src/media/socket_pair.cpp index 398bee84f6..0b7045aca9 100644 --- a/src/media/socket_pair.cpp +++ b/src/media/socket_pair.cpp @@ -60,6 +60,8 @@ namespace ring { static constexpr int NET_POLL_TIMEOUT = 100; /* poll() timeout in ms */ static constexpr int RTP_MAX_PACKET_LENGTH = 2048; +static constexpr auto UDP_HEADER_SIZE = 8; +static constexpr auto SRTP_OVERHEAD = 10; enum class DataType : unsigned { RTP=1<<0, RTCP=1<<1 }; @@ -190,10 +192,6 @@ udp_socket_create(sockaddr_storage* addr, socklen_t* addr_len, int local_port) return udp_fd; } -// Maximal size allowed for a RTP packet, this value of 1232 bytes is an IPv6 minimum (1280 - 40 IPv6 header - 8 UDP header). -static const size_t RTP_BUFFER_SIZE = 1232; -static const size_t SRTP_BUFFER_SIZE = RTP_BUFFER_SIZE - 10; - SocketPair::SocketPair(const char *uri, int localPort) : rtp_sock_() , rtcp_sock_() @@ -334,9 +332,11 @@ SocketPair::openSockets(const char* uri, int local_rtp_port) } MediaIOHandle* -SocketPair::createIOContext() +SocketPair::createIOContext(const uint16_t mtu) { - return new MediaIOHandle(srtpContext_ ? SRTP_BUFFER_SIZE : RTP_BUFFER_SIZE, true, + auto ip_header_size = rtp_sock_->getTransportOverhead(); + return new MediaIOHandle( mtu - (srtpContext_ ? SRTP_OVERHEAD : 0) - UDP_HEADER_SIZE - ip_header_size, + true, [](void* sp, uint8_t* buf, int len){ return static_cast<SocketPair*>(sp)->readCallback(buf, len); }, [](void* sp, uint8_t* buf, int len){ return static_cast<SocketPair*>(sp)->writeCallback(buf, len); }, 0, reinterpret_cast<void*>(this)); diff --git a/src/media/socket_pair.h b/src/media/socket_pair.h index 4fc447be7d..95fc503473 100644 --- a/src/media/socket_pair.h +++ b/src/media/socket_pair.h @@ -80,7 +80,7 @@ class SocketPair { void interrupt(); - MediaIOHandle* createIOContext(); + MediaIOHandle* createIOContext(const uint16_t mtu); void openSockets(const char* uri, int localPort); void closeSockets(); diff --git a/src/media/video/video_receive_thread.cpp b/src/media/video/video_receive_thread.cpp index 3f8c91c9ca..5318c0b8fb 100644 --- a/src/media/video/video_receive_thread.cpp +++ b/src/media/video/video_receive_thread.cpp @@ -38,13 +38,15 @@ using std::string; VideoReceiveThread::VideoReceiveThread(const std::string& id, const std::string &sdp, - const bool isReset) : + const bool isReset, + uint16_t mtu) : VideoGenerator::VideoGenerator() , args_() , dstWidth_(0) , dstHeight_(0) , id_(id) , stream_(sdp) + , mtu_(mtu) , sdpContext_(stream_.str().size(), false, &readFunction, 0, 0, this) , sink_ {Manager::instance().createSinkClient(id)} , restartDecoder_(false) @@ -162,7 +164,7 @@ int VideoReceiveThread::readFunction(void *opaque, uint8_t *buf, int buf_size) void VideoReceiveThread::addIOContext(SocketPair &socketPair) { - demuxContext_.reset(socketPair.createIOContext()); + demuxContext_.reset(socketPair.createIOContext(mtu_)); } bool VideoReceiveThread::decodeFrame() diff --git a/src/media/video/video_receive_thread.h b/src/media/video/video_receive_thread.h index 515c486bd7..5b2253023a 100644 --- a/src/media/video/video_receive_thread.h +++ b/src/media/video/video_receive_thread.h @@ -47,7 +47,7 @@ class SinkClient; class VideoReceiveThread : public VideoGenerator { public: - VideoReceiveThread(const std::string &id, const std::string &sdp, const bool isReset); + VideoReceiveThread(const std::string &id, const std::string &sdp, const bool isReset, uint16_t mtu); ~VideoReceiveThread(); void startLoop(); @@ -80,6 +80,8 @@ private: std::shared_ptr<SinkClient> sink_; std::atomic_bool restartDecoder_; bool isReset_; + uint16_t mtu_; + void (*requestKeyFrameCallback_)(const std::string &); void openDecoder(); bool decodeFrame(); diff --git a/src/media/video/video_rtp_session.cpp b/src/media/video/video_rtp_session.cpp index 265006ed4d..c031540a8f 100644 --- a/src/media/video/video_rtp_session.cpp +++ b/src/media/video/video_rtp_session.cpp @@ -102,7 +102,7 @@ void VideoRtpSession::startSender() sender_.reset(); socketPair_->stopSendOp(false); sender_.reset(new VideoSender(getRemoteRtpUri(), localVideoParams_, - send_, *socketPair_, initSeqVal_)); + send_, *socketPair_, initSeqVal_, mtu_)); } catch (const MediaEncoderException &e) { RING_ERR("%s", e.what()); send_.enabled = false; @@ -138,8 +138,9 @@ void VideoRtpSession::startReceiver() isReset = true; } receiveThread_.reset( - new VideoReceiveThread(callID_, receive_.receiving_sdp, isReset) + new VideoReceiveThread(callID_, receive_.receiving_sdp, isReset, mtu_) ); + /* ebail: keyframe requests can lead to timeout if they are not answered. * we decided so to disable them for the moment receiveThread_->setRequestKeyFrameCallback(&SIPVoIPLink::enqueueKeyframeRequest); diff --git a/src/media/video/video_sender.cpp b/src/media/video/video_sender.cpp index d9531c01fb..c8246c29f4 100644 --- a/src/media/video/video_sender.cpp +++ b/src/media/video/video_sender.cpp @@ -37,8 +37,9 @@ using std::string; VideoSender::VideoSender(const std::string& dest, const DeviceParams& dev, const MediaDescription& args, SocketPair& socketPair, - const uint16_t seqVal) - : muxContext_(socketPair.createIOContext()) + const uint16_t seqVal, + uint16_t mtu) + : muxContext_(socketPair.createIOContext(mtu)) , videoEncoder_(new MediaEncoder) { videoEncoder_->setDeviceOptions(dev); diff --git a/src/media/video/video_sender.h b/src/media/video/video_sender.h index 4bb2285016..6852b3452d 100644 --- a/src/media/video/video_sender.h +++ b/src/media/video/video_sender.h @@ -46,7 +46,8 @@ public: const DeviceParams& dev, const MediaDescription& args, SocketPair& socketPair, - const uint16_t seqVal); + const uint16_t seqVal, + uint16_t mtu); ~VideoSender(); diff --git a/src/ringdht/sips_transport_ice.cpp b/src/ringdht/sips_transport_ice.cpp index f4488de4fe..c56ddda5eb 100644 --- a/src/ringdht/sips_transport_ice.cpp +++ b/src/ringdht/sips_transport_ice.cpp @@ -694,4 +694,10 @@ SipsIceTransport::send(pjsip_tx_data* tdata, const pj_sockaddr_t* rem_addr, return PJ_EPENDING; } +uint16_t +SipsIceTransport::getTlsSessionMtu() +{ + return tls_->getMtu(); +} + }} // namespace ring::tls diff --git a/src/ringdht/sips_transport_ice.h b/src/ringdht/sips_transport_ice.h index d0e243cf95..48cbcb3e02 100644 --- a/src/ringdht/sips_transport_ice.h +++ b/src/ringdht/sips_transport_ice.h @@ -74,6 +74,9 @@ struct SipsIceTransport IpAddr getLocalAddress() const { return local_; } IpAddr getRemoteAddress() const { return remote_; } + // uses the tls_ uniquepointer internal gnutls_session_t, to call its method to get its MTU + uint16_t getTlsSessionMtu(); + private: NON_COPYABLE(SipsIceTransport); diff --git a/src/security/tls_session.cpp b/src/security/tls_session.cpp index b4f1dcf3ce..526638a285 100644 --- a/src/security/tls_session.cpp +++ b/src/security/tls_session.cpp @@ -30,22 +30,38 @@ #include "noncopyable.h" #include "compiler_intrinsics.h" +#include <gnutls/gnutls.h> #include <gnutls/dtls.h> #include <gnutls/abstract.h> #include <algorithm> #include <cstring> // std::memset +#include <cstdlib> +#include <unistd.h> + namespace ring { namespace tls { static constexpr const char* TLS_CERT_PRIORITY_STRING {"SECURE192:-VERS-TLS-ALL:+VERS-DTLS-ALL:-RSA:%SERVER_PRECEDENCE:%SAFE_RENEGOTIATION"}; static constexpr const char* TLS_FULL_PRIORITY_STRING {"SECURE192:-KX-ALL:+ANON-ECDH:+ANON-DH:+SECURE192:-VERS-TLS-ALL:+VERS-DTLS-ALL:-RSA:%SERVER_PRECEDENCE:%SAFE_RENEGOTIATION"}; static constexpr int DTLS_MTU {1232}; // (1280 from IPv6 minimum MTU - 40 IPv6 header - 8 UDP header) +static constexpr uint16_t MIN_MTU {512}; +static constexpr uint16_t INPUT_BUFFER_SIZE {16*1024}; // to be coherent with the maximum size advised in path mtu discovery static constexpr std::size_t INPUT_MAX_SIZE {1000}; // Maximum packet to store before dropping (pkt size = DTLS_MTU) static constexpr ssize_t FLOOD_THRESHOLD {4*1024}; static constexpr auto FLOOD_PAUSE = std::chrono::milliseconds(100); // Time to wait after an invalid cookie packet (anti flood attack) static constexpr auto DTLS_RETRANSMIT_TIMEOUT = std::chrono::milliseconds(1000); // Delay between two handshake request on DTLS static constexpr auto COOKIE_TIMEOUT = std::chrono::seconds(10); // Time to wait for a cookie packet from client +static constexpr uint8_t UDP_HEADER_SIZE = 8; // Size in bytes of UDP packet header +static constexpr uint8_t HEARTBEAT_RETRIES = 1; // Number of tries at each heartbeat ping send (1 + 1 if error) +static constexpr auto HEARTBEAT_RETRANS_TIMEOUT = std::chrono::milliseconds(700); // gnutls heartbeat retransmission timeout for each ping (in milliseconds) +static constexpr auto HEARTBEAT_TOTAL_TIMEOUT = HEARTBEAT_RETRANS_TIMEOUT * HEARTBEAT_RETRIES; // gnutls heartbeat time limit for heartbeat procedure (in milliseconds) + +// mtus array to test, do not add mtu over the interface MTU, this will result in false result due to packet fragmentation. +// also do not set over 16000 this will result in a gnutls error (unexpected packet size) +// neither under MIN_MTU because it makes no sense and could result in underflow of certain variables. +// Put mtus values in ascending order in the array to avoid sorting +static constexpr std::array<uint16_t, MTUS_TO_TEST> mtus = {MIN_MTU, 800, 1280, 1500}; // Helper to cast any duration into an integer number of milliseconds template <class Rep, class Period> @@ -209,6 +225,9 @@ TlsSessionState TlsSession::setupClient() { auto ret = gnutls_init(&session_, GNUTLS_CLIENT | GNUTLS_DATAGRAM); + RING_WARN("[TLS] set heartbeat reception for retrocompatibility check on server"); + gnutls_heartbeat_enable(session_,GNUTLS_HB_PEER_ALLOWED_TO_SEND); + if (ret != GNUTLS_E_SUCCESS) { RING_ERR("[TLS] session init failed: %s", gnutls_strerror(ret)); return TlsSessionState::SHUTDOWN; @@ -529,6 +548,7 @@ TlsSession::setup() fsmHandlers_[TlsSessionState::SETUP] = [this](TlsSessionState s){ return handleStateSetup(s); }; fsmHandlers_[TlsSessionState::COOKIE] = [this](TlsSessionState s){ return handleStateCookie(s); }; fsmHandlers_[TlsSessionState::HANDSHAKE] = [this](TlsSessionState s){ return handleStateHandshake(s); }; + fsmHandlers_[TlsSessionState::MTU_DISCOVERY] = [this](TlsSessionState s){ return handleStateMtuDiscovery(s); }; fsmHandlers_[TlsSessionState::ESTABLISHED] = [this](TlsSessionState s){ return handleStateEstablished(s); }; fsmHandlers_[TlsSessionState::SHUTDOWN] = [this](TlsSessionState s){ return handleStateShutdown(s); }; @@ -636,6 +656,9 @@ TlsSession::handleStateCookie(TlsSessionState state) RING_DBG("[TLS] cookie ok"); ret = gnutls_init(&session_, GNUTLS_SERVER | GNUTLS_DATAGRAM); + RING_WARN("[TLS] set heartbeat reception"); + gnutls_heartbeat_enable(session_,GNUTLS_HB_PEER_ALLOWED_TO_SEND); + if (ret != GNUTLS_E_SUCCESS) { RING_ERR("[TLS] session init failed: %s", gnutls_strerror(ret)); return TlsSessionState::SHUTDOWN; @@ -716,10 +739,103 @@ TlsSession::handleStateHandshake(TlsSessionState state) callbacks_.onCertificatesUpdate(local, remote, remote_count); } + return TlsSessionState::MTU_DISCOVERY; +} + +TlsSessionState +TlsSession::handleStateMtuDiscovery(TlsSessionState state) +{ + //set dtls mtu to be over each and every mtus tested + gnutls_dtls_set_mtu(session_, mtus.back()); + // retrocompatibility check + if(gnutls_heartbeat_allowed(session_, GNUTLS_HB_LOCAL_ALLOWED_TO_SEND) == 1) { + if (!isServer()){ + RING_WARN("[TLS] HEARTBEAT PATH MTU DISCOVERY"); + pathMtuHeartbeat(); + pmtudOver_ = true; + RING_WARN("[TLS] HEARTBEAT PATH MTU DISCOVERY OVER"); + } + } else { + RING_ERR("[TLS] PEER HEARTBEAT DISABLED: setting minimal value to MTU @%d for retrocompatibility", DTLS_MTU); + gnutls_dtls_set_mtu(session_, DTLS_MTU); + pmtudOver_ = true; + } maxPayload_ = gnutls_dtls_get_data_mtu(session_); + if (pmtudOver_) + RING_WARN("[TLS] maxPayload for dtls : %d B", getMaxPayload()); + return TlsSessionState::ESTABLISHED; } +/* + * Path MTU discovery heuristic + * heuristic description: + * The two members of the current tls connection will exchange dtls heartbeat messages + * of increasing size until the heartbeat times out which will be considered as a packet + * drop from the network due to the size of the packet. (one retry to test for a buffer issue) + * when timeout happens or all the values have been tested, the mtu will be returned. + * In case of unexpected error the first (and minimal) value of the mtu array + */ +void +TlsSession::pathMtuHeartbeat() +{ + int errno_send = 1; // non zero initialisation + auto tls_overhead = gnutls_record_overhead_size(session_); + RING_WARN("[TLS] tls session overhead : %d", tls_overhead); + transportOverhead_ = socket_->getTransportOverhead(); + gnutls_heartbeat_set_timeouts(session_, HEARTBEAT_RETRANS_TIMEOUT.count(), HEARTBEAT_TOTAL_TIMEOUT.count()); + RING_DBG("[TLS] Heartbeat PMTUD : retransmission timeout set to: %d ms", HEARTBEAT_RETRANS_TIMEOUT); + + // server side: managing pong in state established + // client side: managing ping on heartbeat + uint16_t bytesToSend; + mtuProbe_ = mtus.cbegin(); + RING_DBG("[TLS] Heartbeat PMTUD : client side"); + + while (mtuProbe_ != mtus.cend()){ + bytesToSend = (*mtuProbe_) - 3 - tls_overhead - UDP_HEADER_SIZE - transportOverhead_; + do { + RING_DBG("[TLS] Heartbeat PMTUD : ping with mtu %d and effective payload %d", *mtuProbe_, bytesToSend); + errno_send = gnutls_heartbeat_ping(session_, bytesToSend, HEARTBEAT_RETRIES, GNUTLS_HEARTBEAT_WAIT); + RING_DBG("[TLS] Heartbeat PMTUD : ping sequence over with errno %d: %s", errno_send, + gnutls_strerror(errno_send)); + } + while (errno_send == GNUTLS_E_AGAIN || errno_send == GNUTLS_E_INTERRUPTED); + + if (errno_send == GNUTLS_E_SUCCESS) { + ++mtuProbe_; + } else if (errno_send == GNUTLS_E_TIMEDOUT){ // timeout is considered as a packet loss, then the good mtu is the precedent. + if (mtuProbe_ == mtus.cbegin()) { + RING_WARN("[TLS] Heartbeat PMTUD : no response on first ping, setting minimal MTU value @%d", MIN_MTU); + gnutls_dtls_set_mtu(session_, MIN_MTU - UDP_HEADER_SIZE - transportOverhead_); + + } else { + --mtuProbe_; + RING_DBG("[TLS] Heartbeat PMTUD : timed out: Path MTU found @ %d", *mtuProbe_); + gnutls_dtls_set_mtu(session_, *mtuProbe_ - UDP_HEADER_SIZE - transportOverhead_); + } + return; + } else { + RING_WARN("[TLS] Heartbeat PMTUD : client ping failed: error %d: %s", errno_send, + gnutls_strerror(errno_send)); + if (mtuProbe_ != mtus.begin()) + --mtuProbe_; + gnutls_dtls_set_mtu(session_, *mtuProbe_ - UDP_HEADER_SIZE - transportOverhead_); + return; + } + } + + + if (errno_send == GNUTLS_E_SUCCESS) { + RING_WARN("[TLS] Heartbeat PMTUD completed : reached test value %d", mtus.back()); + --mtuProbe_; // for loop over, setting mtu to last valid mtu + } + + gnutls_dtls_set_mtu(session_, *mtuProbe_ - UDP_HEADER_SIZE - transportOverhead_); + RING_WARN("[TLS] Heartbeat PMTUD : new mtu set to %d", *mtuProbe_); +} + + TlsSessionState TlsSession::handleStateEstablished(TlsSessionState state) { @@ -732,12 +848,39 @@ TlsSession::handleStateEstablished(TlsSessionState state) // Handle RX data from network if (!rxQueue_.empty()) { - std::vector<uint8_t> buf(8*1024); + std::vector<uint8_t> buf(INPUT_BUFFER_SIZE); unsigned char sequence[8]; lk.unlock(); auto ret = gnutls_record_recv_seq(session_, buf.data(), buf.size(), sequence); - if (ret > 0) { + if (ret > 0 && pmtudOver_) { + buf.resize(ret); + // TODO: handle sequence re-order + if (callbacks_.onRxData) + callbacks_.onRxData(std::move(buf)); + return state; + } else if (ret == GNUTLS_E_HEARTBEAT_PING_RECEIVED) { + + RING_DBG("[TLS] Heartbeat PMTUD : ping received sending pong"); + auto errno_send = gnutls_heartbeat_pong(session_, 0); + + if (errno_send != GNUTLS_E_SUCCESS){ + RING_WARN("[TLS] Heartbeat PMTUD : failed on pong with error %d: %s", errno_send, + gnutls_strerror(errno_send)); + } else { + ++hbPingRecved_; + } + + } else if (ret > 0 && pmtudOver_ == false){ + if (hbPingRecved_ > 0){ + gnutls_dtls_set_mtu(session_, mtus[hbPingRecved_ - 1] - UDP_HEADER_SIZE - transportOverhead_); + maxPayload_ = gnutls_dtls_get_data_mtu(session_); + } else { + gnutls_dtls_set_mtu(session_, MIN_MTU - UDP_HEADER_SIZE - transportOverhead_); + maxPayload_ = gnutls_dtls_get_data_mtu(session_); + } + RING_WARN("[TLS] maxPayload for dtls : %d B", getMaxPayload()); + pmtudOver_ = true; buf.resize(ret); // TODO: handle sequence re-order if (callbacks_.onRxData) @@ -819,4 +962,10 @@ DhParams::generate() return params; } +uint16_t +TlsSession::getMtu() +{ + return gnutls_dtls_get_mtu(session_); +} + }} // namespace ring::tls diff --git a/src/security/tls_session.h b/src/security/tls_session.h index 745f029124..c66c903706 100644 --- a/src/security/tls_session.h +++ b/src/security/tls_session.h @@ -39,6 +39,8 @@ #include <vector> #include <map> #include <atomic> +#include <iterator> +#include <array> namespace ring { class IceTransport; @@ -52,10 +54,13 @@ struct PrivateKey; namespace ring { namespace tls { +static constexpr uint8_t MTUS_TO_TEST = 4; //number of mtus to test in path mtu discovery. + enum class TlsSessionState { SETUP, COOKIE, // server only HANDSHAKE, + MTU_DISCOVERY, ESTABLISHED, SHUTDOWN }; @@ -171,6 +176,8 @@ public: ssize_t send(const void* data, std::size_t size); ssize_t send(const std::vector<uint8_t>& data); + uint16_t getMtu(); + private: using clock = std::chrono::steady_clock; using StateHandler = std::function<TlsSessionState(TlsSessionState state)>; @@ -186,6 +193,7 @@ private: TlsSessionState handleStateSetup(TlsSessionState state); TlsSessionState handleStateCookie(TlsSessionState state); TlsSessionState handleStateHandshake(TlsSessionState state); + TlsSessionState handleStateMtuDiscovery(TlsSessionState state); TlsSessionState handleStateEstablished(TlsSessionState state); TlsSessionState handleStateShutdown(TlsSessionState state); std::map<TlsSessionState, StateHandler> fsmHandlers_ {}; @@ -235,6 +243,13 @@ private: bool setup(); void process(); void cleanup(); + + // Path mtu discovery + std::array<uint16_t, MTUS_TO_TEST>::const_iterator mtuProbe_; + unsigned hbPingRecved_ {0}; + bool pmtudOver_ {false}; + uint8_t transportOverhead_; + void pathMtuHeartbeat(); }; }} // namespace ring::tls diff --git a/src/sip/sipcall.cpp b/src/sip/sipcall.cpp index b81b9fb3b5..7bec8a9e92 100644 --- a/src/sip/sipcall.cpp +++ b/src/sip/sipcall.cpp @@ -856,11 +856,15 @@ SIPCall::startAllMedia() continue; } + auto new_mtu = transport_->getTlsMtu(); + avformatrtp_->setMtu(new_mtu); + #ifdef RING_VIDEO if (local.type == MEDIA_VIDEO) videortp_->switchInput(videoInput_); -#endif + videortp_->setMtu(new_mtu); +#endif rtp->updateMedia(remote, local); // Not restarting media loop on hold as it's a huge waste of CPU ressources diff --git a/src/sip/siptransport.cpp b/src/sip/siptransport.cpp index 3e2ca56863..ee54f11077 100644 --- a/src/sip/siptransport.cpp +++ b/src/sip/siptransport.cpp @@ -173,6 +173,12 @@ SipTransport::removeStateListener(uintptr_t lid) return false; } +uint16_t +SipTransport::getTlsMtu(){ + auto tls_tr = reinterpret_cast<tls::SipsIceTransport::TransportData*>(transport_.get())->self; + return tls_tr->getTlsSessionMtu(); +} + SipTransportBroker::SipTransportBroker(pjsip_endpoint *endpt, pj_caching_pool& cp, pj_pool_t& pool) : cp_(cp), pool_(pool), endpt_(endpt) diff --git a/src/sip/siptransport.h b/src/sip/siptransport.h index 0605719a2c..0e4919e16e 100644 --- a/src/sip/siptransport.h +++ b/src/sip/siptransport.h @@ -138,6 +138,8 @@ class SipTransport /** Only makes sense for connection-oriented transports */ bool isConnected() const noexcept { return connected_; } + uint16_t getTlsMtu(); + private: NON_COPYABLE(SipTransport); -- GitLab