From bdafdfb43d4fb5a2f4ef452ecbfde082809cbad5 Mon Sep 17 00:00:00 2001 From: Guillaume Roguez <guillaume.roguez@savoirfairelinux.com> Date: Fri, 24 Nov 2017 13:08:27 -0500 Subject: [PATCH] make TlsSession great again Or at least independant of underlaying transport... To make TlsSession able to handle both TLS and DTLS this patch removes the ICE dependency and replace is by the generic network ABC class 'GenericTransport'. As a first step this class is declared in tls_session.h. Side effects of this change are: * refactoring of PMTUD procedure: 'MTU' for gnutls has the meaning on 'payload-for-gnutls' so this information is now drived by the generic transport and not hardcoded anymore. The minimal value of probing remains hardcoded, as is a minimum given by RFC's documentation and it's based on an IPv4 packet associated with UDP protocol. * getMtu() is now maxPayload() and represent correctly what the application must have. * TlsSession implements itself GenericTransport: we can chain GenericTransport instances to construct an overlayed transport protocol. * TlsSession is now considered as non thread-safe for its public API. Caller must bring itself this property. This permit to remove a redundant mutex in send() operation. Note: and it's the case in the only user (SipsIceTransport), that why the mutex is redundant in 100% of cases. Notice the benefit of this genericity refactoring let us write a unit-test for this TlsSession class without having an heavy ICE transport to mock-up. Also ICE transport gained of this by adding a new IceSocketTransport to replace IceSocket in a near future (need async IO in GenericSocket, but not required for the moment). Change-Id: I6f4591ed6c76fa9cb5519c6e9296f8fc3a6798aa Reviewed-by: Olivier Soldano <olivier.soldano@savoirfairelinux.com> --- src/Makefile.am | 3 +- src/generic_io.h | 103 ++++++ src/ice_socket.h | 45 ++- src/ice_transport.cpp | 52 +++ src/ringdht/sips_transport_ice.cpp | 43 ++- src/ringdht/sips_transport_ice.h | 4 +- src/security/tls_session.cpp | 541 ++++++++++++++++------------- src/security/tls_session.h | 95 ++--- src/sip/siptransport.cpp | 7 +- src/turn_transport.cpp | 187 ++++++---- src/turn_transport.h | 44 ++- 11 files changed, 731 insertions(+), 393 deletions(-) create mode 100644 src/generic_io.h diff --git a/src/Makefile.am b/src/Makefile.am index 4d608b9f03..2bdea570f8 100644 --- a/src/Makefile.am +++ b/src/Makefile.am @@ -143,7 +143,8 @@ libring_la_SOURCES = \ base64.cpp \ turn_transport.h \ turn_transport.cpp \ - channel.h + channel.h \ + generic_io.h if HAVE_WIN32 libring_la_SOURCES += \ diff --git a/src/generic_io.h b/src/generic_io.h new file mode 100644 index 0000000000..8e0ae80eb9 --- /dev/null +++ b/src/generic_io.h @@ -0,0 +1,103 @@ +/* + * Copyright (C) 2017 Savoir-faire Linux Inc. + * + * Author: Guillaume Roguez <guillaume.roguez@savoirfairelinux.com> + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ + +#pragma once + +#include <functional> +#include <vector> +#include <system_error> +#include <cstdint> + +#if defined(_MSC_VER) +#include <BaseTsd.h> +using ssize_t = SSIZE_T; +#endif + +namespace ring { + +template <typename T> +class GenericSocket +{ +public: + using ValueType = T; + + virtual ~GenericSocket() = default; + + using RecvCb = std::function<ssize_t(const ValueType* buf, std::size_t len)>; + + /// Set Rx callback + /// \warning This method is here for backward compatibility + /// and because async IO are not implemented yet. + virtual void setOnRecv(RecvCb&& cb) = 0; + + virtual bool isReliable() const = 0; + + virtual bool isInitiator() const = 0; + + /// Return maximum application payload size. + /// This value is negative if the session is not ready to give a valid answer. + /// The value is 0 if such information is irrelevant for the session. + /// If stricly positive, the user must use send() with an input buffer size below or equals + /// to this value if it want to be sure that the transport sent it in an atomic way. + /// Example: in case of non-reliable transport using packet oriented IO, + /// this value gives the maximal size used to send one packet. + virtual int maxPayload() const = 0; + + // TODO: make a std::chrono version + virtual bool waitForData(unsigned ms_timeout) const = 0; + + /// Write a given amount of data. + /// \param buf data to write. + /// \param len number of bytes to write. + /// \param ec error code set in case of error. + /// \return number of bytes written, 0 is valid. + /// \warning error checking consists in checking if \a !ec is true, not if returned size is 0 + /// as a write of 0 could be considered a valid operation. + virtual std::size_t write(const ValueType* buf, std::size_t len, std::error_code& ec) = 0; + + /// Read a given amount of data. + /// \param buf data to read. + /// \param len number of bytes to read. + /// \param ec error code set in case of error. + /// \return number of bytes read, 0 is valid. + /// \warning error checking consists in checking if \a !ec is true, not if returned size is 0 + /// as a read of 0 could be considered a valid operation (i.e. non-blocking IO). + virtual std::size_t read(ValueType* buf, std::size_t len, std::error_code& ec) = 0; + + /// write() adaptor for STL containers + template <typename U> + std::size_t write(const U& obj, std::error_code& ec) { + return write(obj.data(), obj.size() * sizeof(typename U::value_type), ec); + } + + /// read() adaptor for STL containers + template <typename U> + std::size_t read(U& storage, std::error_code& ec) { + auto res = read(storage.data(), storage.size() * sizeof(typename U::value_type), ec); + if (!ec) + storage.resize(res); + return res; + } + +protected: + GenericSocket() = default; +}; + +} // namespace ring diff --git a/src/ice_socket.h b/src/ice_socket.h index ba12a1b04e..0d81a2aa33 100644 --- a/src/ice_socket.h +++ b/src/ice_socket.h @@ -17,8 +17,9 @@ * along with this program; if not, write to the Free Software * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. */ -#ifndef ICE_SOCKET_H -#define ICE_SOCKET_H +#pragma once + +#include "generic_io.h" #include <memory> #include <functional> @@ -51,6 +52,44 @@ class IceSocket uint16_t getTransportOverhead(); }; +/// ICE transport as a GenericSocket. +/// +/// \warning Simplified version where we assume that ICE protocol +/// always use UDP over IP over ETHERNET, and doesn't add more header to the UDP payload. +/// +class IceSocketTransport final : public GenericSocket<uint8_t> +{ +public: + using SocketType = GenericSocket<uint8_t>; + + static constexpr uint16_t STANDARD_MTU_SIZE = 1280; // Size in bytes of MTU for IPv6 capable networks + static constexpr uint16_t IPV6_HEADER_SIZE = 40; // Size in bytes of IPv6 packet header + static constexpr uint16_t IPV4_HEADER_SIZE = 20; // Size in bytes of IPv4 packet header + static constexpr uint16_t UDP_HEADER_SIZE = 8; // Size in bytes of UDP header + + IceSocketTransport(std::shared_ptr<IceTransport>& ice, int comp_id) + : compId_ {comp_id} + , ice_ {ice} {} + + bool isReliable() const override { + return false; // we consider that a ICE transport is never reliable (UDP support only) + } + + bool isInitiator() const override; + + int maxPayload() const override; + + bool waitForData(unsigned ms_timeout) const override; + + std::size_t write(const ValueType* buf, std::size_t len, std::error_code& ec) override; + + std::size_t read(ValueType* buf, std::size_t len, std::error_code& ec) override; + + void setOnRecv(RecvCb&& cb) override; + +private: + const int compId_; + std::shared_ptr<IceTransport> ice_; }; -#endif /* ICE_SOCKET_H */ +}; diff --git a/src/ice_transport.cpp b/src/ice_transport.cpp index 7ab020daa1..82645ac8bc 100644 --- a/src/ice_transport.cpp +++ b/src/ice_transport.cpp @@ -1176,6 +1176,58 @@ IceTransportFactory::createTransport(const char* name, int component_count, //============================================================================== +void +IceSocketTransport::setOnRecv(RecvCb&& cb) +{ + return ice_->setOnRecv(compId_, cb); +} + +bool +IceSocketTransport::isInitiator() const +{ + return ice_->isInitiator(); +} + +int +IceSocketTransport::maxPayload() const +{ + auto ip_header_size = (ice_->getRemoteAddress(compId_).getFamily() == AF_INET) ? + IPV4_HEADER_SIZE : IPV6_HEADER_SIZE; + return STANDARD_MTU_SIZE - ip_header_size - UDP_HEADER_SIZE; +} + +bool +IceSocketTransport::waitForData(unsigned ms_timeout) const +{ + return ice_->waitForData(compId_, ms_timeout) > 0; +} + +std::size_t +IceSocketTransport::write(const ValueType* buf, std::size_t len, std::error_code& ec) +{ + auto res = ice_->send(compId_, buf, len); + if (res < 0) { + ec.assign(errno, std::generic_category()); + return 0; + } + ec.clear(); + return res; +} + +std::size_t +IceSocketTransport::read(ValueType* buf, std::size_t len, std::error_code& ec) +{ + auto res = ice_->recv(compId_, buf, len); + if (res < 0) { + ec.assign(errno, std::generic_category()); + return 0; + } + ec.clear(); + return res; +} + +//============================================================================== + void IceSocket::close() { diff --git a/src/ringdht/sips_transport_ice.cpp b/src/ringdht/sips_transport_ice.cpp index a86b67e822..7142b4471d 100644 --- a/src/ringdht/sips_transport_ice.cpp +++ b/src/ringdht/sips_transport_ice.cpp @@ -21,7 +21,9 @@ #include "sips_transport_ice.h" +#include "ice_socket.h" #include "ice_transport.h" + #include "manager.h" #include "sip/sip_utils.h" #include "logger.h" @@ -38,6 +40,7 @@ #include <pj/lock.h> #include <algorithm> +#include <system_error> #include <cstring> // std::memset namespace ring { namespace tls { @@ -233,14 +236,16 @@ SipsIceTransport::SipsIceTransport(pjsip_endpoint* endpt, std::memset(&localCertInfo_, 0, sizeof(pj_ssl_cert_info)); std::memset(&remoteCertInfo_, 0, sizeof(pj_ssl_cert_info)); + iceSocket_ = std::make_unique<IceSocketTransport>(ice_, comp_id); + TlsSession::TlsSessionCallbacks cbs = { /*.onStateChange = */[this](TlsSessionState state){ onTlsStateChange(state); }, /*.onRxData = */[this](std::vector<uint8_t>&& buf){ onRxData(std::move(buf)); }, /*.onCertificatesUpdate = */[this](const gnutls_datum_t* l, const gnutls_datum_t* r, - unsigned int n){ onCertificatesUpdate(l, r, n); }, + unsigned int n){ onCertificatesUpdate(l, r, n); }, /*.verifyCertificate = */[this](gnutls_session_t session){ return verifyCertificate(session); } }; - tls_.reset(new TlsSession(ice, comp_id, param, cbs)); + tls_ = std::make_unique<TlsSession>(*iceSocket_, param, cbs); if (pjsip_transport_register(base.tpmgr, &base) != PJ_SUCCESS) throw std::runtime_error("Can't register PJSIP transport."); @@ -323,18 +328,19 @@ SipsIceTransport::handleEvents() pj_status_t status; if (!fatal) { const std::size_t size = tdata->buf.cur - tdata->buf.start; - auto ret = tls_->send(tdata->buf.start, size); - if (gnutls_error_is_fatal(ret)) { - RING_ERR("[TLS] fatal error during sending: %s", gnutls_strerror(ret)); - tls_->shutdown(); - fatal = true; + std::error_code ec; + status = tls_->write(reinterpret_cast<const uint8_t*>(tdata->buf.start), size, ec); + if (ec) { + status = tls_status_from_err(ec.value()); + if (gnutls_error_is_fatal(ec.value())) { + RING_ERR("[TLS] fatal error during sending: %s", gnutls_strerror(ec.value())); + tls_->shutdown(); + fatal = true; + } } - if (ret < 0) - status = tls_status_from_err(ret); - else - status = ret; - } else + } else { status = -PJ_RETURN_OS_ERROR(OSERR_ENOTCONN); + } tdata->op_key.tdata = nullptr; if (tdata->op_key.callback) @@ -501,7 +507,7 @@ SipsIceTransport::getInfo(pj_ssl_sock_info* info, bool established) if (established) { // Cipher Suite Id std::array<uint8_t, 2> cs_id; - if (auto cipher_name = tls_->getCurrentCipherSuiteId(cs_id)) { + if (auto cipher_name = tls_->currentCipherSuiteId(cs_id)) { info->cipher = static_cast<pj_ssl_cipher>((cs_id[0] << 8) | cs_id[1]); RING_DBG("[TLS] using cipher %s (0x%02X%02X)", cipher_name, cs_id[0], cs_id[1]); } else @@ -673,14 +679,15 @@ SipsIceTransport::send(pjsip_tx_data* tdata, const pj_sockaddr_t* rem_addr, const std::size_t size = tdata->buf.cur - tdata->buf.start; std::unique_lock<std::mutex> lk {txMutex_}; if (syncTx_ and txQueue_.empty()) { - auto ret = tls_->send(tdata->buf.start, size); + std::error_code ec; + tls_->write(reinterpret_cast<const uint8_t*>(tdata->buf.start), size, ec); lk.unlock(); // Shutdown on fatal error, else ignore it - if (gnutls_error_is_fatal(ret)) { - RING_ERR("[TLS] fatal error during sending: %s", gnutls_strerror(ret)); + if (ec and gnutls_error_is_fatal(ec.value())) { + RING_ERR("[TLS] fatal error during sending: %s", gnutls_strerror(ec.value())); tls_->shutdown(); - return tls_status_from_err(ret); + return tls_status_from_err(ec.value()); } return PJ_SUCCESS; @@ -697,7 +704,7 @@ SipsIceTransport::send(pjsip_tx_data* tdata, const pj_sockaddr_t* rem_addr, uint16_t SipsIceTransport::getTlsSessionMtu() { - return tls_->getMtu(); + return tls_->maxPayload(); } }} // namespace ring::tls diff --git a/src/ringdht/sips_transport_ice.h b/src/ringdht/sips_transport_ice.h index 48cbcb3e02..9617ad46d7 100644 --- a/src/ringdht/sips_transport_ice.h +++ b/src/ringdht/sips_transport_ice.h @@ -43,6 +43,7 @@ namespace ring { class IceTransport; +class IceSocketTransport; } // namespace ring namespace ring { namespace tls { @@ -80,7 +81,7 @@ struct SipsIceTransport private: NON_COPYABLE(SipsIceTransport); - const std::shared_ptr<IceTransport> ice_; + std::shared_ptr<IceTransport> ice_; const int comp_id_; const std::function<int(unsigned, const gnutls_datum_t*, unsigned)> certCheck_; IpAddr local_ {}; @@ -109,6 +110,7 @@ private: decltype(PJSIP_TP_STATE_DISCONNECTED) state; }; + std::unique_ptr<IceSocketTransport> iceSocket_; std::unique_ptr<TlsSession> tls_; std::mutex txMutex_ {}; diff --git a/src/security/tls_session.cpp b/src/security/tls_session.cpp index 0ab03f56e4..e026b73012 100644 --- a/src/security/tls_session.cpp +++ b/src/security/tls_session.cpp @@ -24,8 +24,6 @@ #include "tls_session.h" #include "threadloop.h" -#include "ice_socket.h" -#include "ice_transport.h" #include "logger.h" #include "noncopyable.h" #include "compiler_intrinsics.h" @@ -54,27 +52,23 @@ namespace ring { namespace tls { -static constexpr const char* TLS_CERT_PRIORITY_STRING {"SECURE192:-VERS-TLS-ALL:+VERS-DTLS-ALL:-RSA:%SERVER_PRECEDENCE:%SAFE_RENEGOTIATION"}; -static constexpr const char* TLS_FULL_PRIORITY_STRING {"SECURE192:-KX-ALL:+ANON-ECDH:+ANON-DH:+SECURE192:-VERS-TLS-ALL:+VERS-DTLS-ALL:-RSA:%SERVER_PRECEDENCE:%SAFE_RENEGOTIATION"}; +static constexpr const char* DTLS_CERT_PRIORITY_STRING {"SECURE192:-VERS-TLS-ALL:+VERS-DTLS-ALL:-RSA:%SERVER_PRECEDENCE:%SAFE_RENEGOTIATION"}; +static constexpr const char* DTLS_FULL_PRIORITY_STRING {"SECURE192:-KX-ALL:+ANON-ECDH:+ANON-DH:+SECURE192:-VERS-TLS-ALL:+VERS-DTLS-ALL:-RSA:%SERVER_PRECEDENCE:%SAFE_RENEGOTIATION"}; +static constexpr const char* TLS_CERT_PRIORITY_STRING {"SECURE192:-RSA:%SERVER_PRECEDENCE:%SAFE_RENEGOTIATION"}; +static constexpr const char* TLS_FULL_PRIORITY_STRING {"SECURE192:-KX-ALL:+ANON-ECDH:+ANON-DH:+SECURE192:-RSA:%SERVER_PRECEDENCE:%SAFE_RENEGOTIATION"}; static constexpr uint16_t INPUT_BUFFER_SIZE {16*1024}; // to be coherent with the maximum size advised in path mtu discovery static constexpr std::size_t INPUT_MAX_SIZE {1000}; // Maximum number of packets to store before dropping (pkt size = DTLS_MTU) static constexpr ssize_t FLOOD_THRESHOLD {4*1024}; static constexpr auto FLOOD_PAUSE = std::chrono::milliseconds(100); // Time to wait after an invalid cookie packet (anti flood attack) static constexpr auto DTLS_RETRANSMIT_TIMEOUT = std::chrono::milliseconds(1000); // Delay between two handshake request on DTLS static constexpr auto COOKIE_TIMEOUT = std::chrono::seconds(10); // Time to wait for a cookie packet from client -static constexpr uint8_t UDP_HEADER_SIZE = 8; // Size in bytes of UDP packet header +static constexpr int MIN_MTU {512 - 20 - 8}; // minimal payload size of a DTLS packet carried by an IPv4 packet static constexpr uint8_t HEARTBEAT_TRIES = 1; // Number of tries at each heartbeat ping send static constexpr auto HEARTBEAT_RETRANS_TIMEOUT = std::chrono::milliseconds(700); // gnutls heartbeat retransmission timeout for each ping (in milliseconds) static constexpr auto HEARTBEAT_TOTAL_TIMEOUT = HEARTBEAT_RETRANS_TIMEOUT * HEARTBEAT_TRIES; // gnutls heartbeat time limit for heartbeat procedure (in milliseconds) static constexpr int MISS_ORDERING_LIMIT = 32; // maximal accepted distance of out-of-order packet (note: must be a signed type) static constexpr auto RX_OOO_TIMEOUT = std::chrono::milliseconds(1500); -// mtus array to test, do not add mtu over the interface MTU, this will result in false result due to packet fragmentation. -// also do not set over 16000 this will result in a gnutls error (unexpected packet size) -// neither under MIN_MTU because it makes no sense and could result in underflow of certain variables. -// Put mtus values in ascending order in the array to avoid sorting -static constexpr std::array<uint16_t, MTUS_TO_TEST> MTUS = {MIN_MTU, 800, 1280}; - // Helper to cast any duration into an integer number of milliseconds template <class Rep, class Period> static std::chrono::milliseconds::rep @@ -176,22 +170,20 @@ public: using StateHandler = std::function<TlsSessionState(TlsSessionState state)>; // Constants (ctor init.) - const std::unique_ptr<IceSocket> socket_; const bool isServer_; const TlsParams params_; const TlsSessionCallbacks callbacks_; const bool anonymous_; - TlsSessionImpl(const std::shared_ptr<IceTransport>& ice, - int ice_comp_id, - const TlsParams& params, - const TlsSessionCallbacks& cbs, - bool anonymous); + TlsSessionImpl(SocketType& transport, const TlsParams& params, + const TlsSessionCallbacks& cbs, bool anonymous); ~TlsSessionImpl(); const char* typeName() const; + SocketType& transport_; + // State machine TlsSessionState handleStateSetup(TlsSessionState state); TlsSessionState handleStateCookie(TlsSessionState state); @@ -201,31 +193,30 @@ public: TlsSessionState handleStateShutdown(TlsSessionState state); std::map<TlsSessionState, StateHandler> fsmHandlers_ {}; std::atomic<TlsSessionState> state_ {TlsSessionState::SETUP}; - std::atomic<unsigned int> maxPayload_; + std::atomic<int> maxPayload_ {-1}; // IO GnuTLS <-> ICE - std::mutex txMutex_ {}; std::mutex rxMutex_ {}; std::condition_variable rxCv_ {}; - std::list<std::vector<uint8_t>> rxQueue_ {}; + std::list<std::vector<ValueType>> rxQueue_ {}; std::mutex reorderBufMutex_; bool flushProcessing_ {false}; ///< protect against recursive call to flushRxQueue - std::vector<uint8_t> rawPktBuf_; ///< gnutls incoming packet buffer + std::vector<ValueType> rawPktBuf_; ///< gnutls incoming packet buffer uint64_t baseSeq_ {0}; ///< sequence number of first application data packet received uint64_t lastRxSeq_ {0}; ///< last received and valid packet sequence number uint64_t gapOffset_ {0}; ///< offset of first byte not received yet clock::time_point lastReadTime_; - std::map<uint64_t, std::vector<uint8_t>> reorderBuffer_ {}; + std::map<uint64_t, std::vector<ValueType>> reorderBuffer_ {}; - ssize_t send_(const uint8_t* tx_data, std::size_t tx_size); + std::size_t send(const ValueType*, std::size_t, std::error_code&); ssize_t sendRaw(const void*, size_t); ssize_t sendRawVec(const giovec_t*, int); ssize_t recvRaw(void*, size_t); int waitForRawData(unsigned); bool initFromRecordState(int offset=0); - void handleDataPacket(std::vector<uint8_t>&&, uint64_t); + void handleDataPacket(std::vector<ValueType>&&, uint64_t); void flushRxQueue(); // Statistics @@ -257,24 +248,22 @@ public: void cleanup(); // Path mtu discovery - std::array<uint16_t, MTUS_TO_TEST>::const_iterator mtuProbe_; - unsigned hbPingRecved_ {0}; + std::array<int, 3> MTUS_; + int mtuProbe_; + int hbPingRecved_ {0}; bool pmtudOver_ {false}; - uint8_t transportOverhead_; void pathMtuHeartbeat(); }; -TlsSession::TlsSessionImpl::TlsSessionImpl(const std::shared_ptr<IceTransport>& ice, - int ice_comp_id, +TlsSession::TlsSessionImpl::TlsSessionImpl(SocketType& transport, const TlsParams& params, const TlsSessionCallbacks& cbs, bool anonymous) - : socket_(new IceSocket(ice, ice_comp_id)) - , isServer_(not ice->isInitiator()) + : isServer_(not transport.isInitiator()) , params_(params) , callbacks_(cbs) , anonymous_(anonymous) - , maxPayload_(INPUT_BUFFER_SIZE) + , transport_ { transport } , cacred_(nullptr) , sacred_(nullptr) , xcred_(nullptr) @@ -282,18 +271,20 @@ TlsSession::TlsSessionImpl::TlsSessionImpl(const std::shared_ptr<IceTransport>& [this] { process(); }, [this] { cleanup(); }) { - socket_->setOnRecv([this](uint8_t* buf, size_t len) { - std::lock_guard<std::mutex> lk {rxMutex_}; - if (rxQueue_.size() == INPUT_MAX_SIZE) { - rxQueue_.pop_front(); // drop oldest packet if input buffer is full - ++stRxRawPacketDropCnt_; - } - rxQueue_.emplace_back(buf, buf+len); - ++stRxRawPacketCnt_; - stRxRawBytesCnt_ += len; - rxCv_.notify_one(); - return len; - }); + if (not transport_.isReliable()) { + transport_.setOnRecv([this](const ValueType* buf, size_t len) { + std::lock_guard<std::mutex> lk {rxMutex_}; + if (rxQueue_.size() == INPUT_MAX_SIZE) { + rxQueue_.pop_front(); // drop oldest packet if input buffer is full + ++stRxRawPacketDropCnt_; + } + rxQueue_.emplace_back(buf, buf+len); + ++stRxRawPacketCnt_; + stRxRawBytesCnt_ += len; + rxCv_.notify_one(); + return len; + }); + } Manager::instance().registerEventHandler((uintptr_t)this, [this]{ flushRxQueue(); }); @@ -304,9 +295,8 @@ TlsSession::TlsSessionImpl::TlsSessionImpl(const std::shared_ptr<IceTransport>& TlsSession::TlsSessionImpl::~TlsSessionImpl() { thread_.join(); - - socket_->setOnRecv(nullptr); - + if (not transport_.isReliable()) + transport_.setOnRecv(nullptr); Manager::instance().unregisterEventHandler((uintptr_t)this); } @@ -327,9 +317,15 @@ TlsSession::TlsSessionImpl::dump_io_stats() const TlsSessionState TlsSession::TlsSessionImpl::setupClient() { - auto ret = gnutls_init(&session_, GNUTLS_CLIENT | GNUTLS_DATAGRAM); - RING_WARN("[TLS] set heartbeat reception for retrocompatibility check on server"); - gnutls_heartbeat_enable(session_,GNUTLS_HB_PEER_ALLOWED_TO_SEND); + int ret; + + if (not transport_.isReliable()) { + ret = gnutls_init(&session_, GNUTLS_CLIENT | GNUTLS_DATAGRAM); + RING_DBG("[TLS] set heartbeat reception for retrocompatibility check on server"); + gnutls_heartbeat_enable(session_,GNUTLS_HB_PEER_ALLOWED_TO_SEND); + } else { + ret = gnutls_init(&session_, GNUTLS_CLIENT); + } if (ret != GNUTLS_E_SUCCESS) { RING_ERR("[TLS] session init failed: %s", gnutls_strerror(ret)); @@ -346,8 +342,30 @@ TlsSession::TlsSessionImpl::setupClient() TlsSessionState TlsSession::TlsSessionImpl::setupServer() { - gnutls_key_generate(&cookie_key_, GNUTLS_COOKIE_KEY_SIZE); - return TlsSessionState::COOKIE; + int ret; + + if (not transport_.isReliable()) { + ret = gnutls_init(&session_, GNUTLS_SERVER | GNUTLS_DATAGRAM); + + RING_DBG("[TLS] set heartbeat reception"); + gnutls_heartbeat_enable(session_, GNUTLS_HB_PEER_ALLOWED_TO_SEND); + + gnutls_dtls_prestate_set(session_, &prestate_); + } else { + ret = gnutls_init(&session_, GNUTLS_SERVER); + } + + if (ret != GNUTLS_E_SUCCESS) { + RING_ERR("[TLS] session init failed: %s", gnutls_strerror(ret)); + return TlsSessionState::SHUTDOWN; + } + + gnutls_certificate_server_set_request(session_, GNUTLS_CERT_REQUIRE); + + if (not commonSessionInit()) + return TlsSessionState::SHUTDOWN; + + return TlsSessionState::HANDSHAKE; } void @@ -440,7 +458,9 @@ TlsSession::TlsSessionImpl::commonSessionInit() if (anonymous_) { // Force anonymous connection, see handleStateHandshake how we handle failures - ret = gnutls_priority_set_direct(session_, TLS_FULL_PRIORITY_STRING, nullptr); + ret = gnutls_priority_set_direct(session_, + transport_.isReliable() ? TLS_FULL_PRIORITY_STRING : DTLS_FULL_PRIORITY_STRING, + nullptr); if (ret != GNUTLS_E_SUCCESS) { RING_ERR("[TLS] TLS priority set failed: %s", gnutls_strerror(ret)); return false; @@ -458,7 +478,9 @@ TlsSession::TlsSessionImpl::commonSessionInit() } } else { // Use a classic non-encrypted CERTIFICATE exchange method (less anonymous) - ret = gnutls_priority_set_direct(session_, TLS_CERT_PRIORITY_STRING, nullptr); + ret = gnutls_priority_set_direct(session_, + transport_.isReliable() ? TLS_CERT_PRIORITY_STRING : DTLS_CERT_PRIORITY_STRING, + nullptr); if (ret != GNUTLS_E_SUCCESS) { RING_ERR("[TLS] TLS priority set failed: %s", gnutls_strerror(ret)); return false; @@ -473,13 +495,15 @@ TlsSession::TlsSessionImpl::commonSessionInit() } gnutls_certificate_send_x509_rdn_sequence(session_, 0); - // DTLS hanshake timeouts - auto re_tx_timeout = duration2ms(DTLS_RETRANSMIT_TIMEOUT); - gnutls_dtls_set_timeouts(session_, re_tx_timeout, - std::max(duration2ms(params_.timeout), re_tx_timeout)); + if (not transport_.isReliable()) { + // DTLS hanshake timeouts + auto re_tx_timeout = duration2ms(DTLS_RETRANSMIT_TIMEOUT); + gnutls_dtls_set_timeouts(session_, re_tx_timeout, + std::max(duration2ms(params_.timeout), re_tx_timeout)); - // DTLS maximum payload size (raw packet relative) - gnutls_dtls_set_mtu(session_, DTLS_MTU); + // gnutls DTLS mtu = maximum payload size given by transport + gnutls_dtls_set_mtu(session_, transport_.maxPayload()); + } // Stuff for transport callbacks gnutls_session_set_ptr(session_, this); @@ -504,17 +528,27 @@ TlsSession::TlsSessionImpl::commonSessionInit() return true; } -ssize_t -TlsSession::TlsSessionImpl::send_(const uint8_t* tx_data, std::size_t tx_size) +std::size_t +TlsSession::TlsSessionImpl::send(const ValueType* tx_data, std::size_t tx_size, std::error_code& ec) { - std::size_t max_tx_sz = gnutls_dtls_get_data_mtu(session_); + if (state_ != TlsSessionState::ESTABLISHED) { + ec = std::error_code(GNUTLS_E_INVALID_SESSION, std::system_category()); + return 0; + } - // Split user data into MTU-suitable chunck - size_t total_written = 0; + std::size_t total_written = 0; + std::size_t max_tx_sz; + + if (transport_.isReliable()) + max_tx_sz = tx_size; + else + max_tx_sz = gnutls_dtls_get_data_mtu(session_); + + // Split incoming data into chunck suitable for the underlying transport while (total_written < tx_size) { auto chunck_sz = std::min(max_tx_sz, tx_size - total_written); - ssize_t nwritten; auto data_seq = tx_data + total_written; + ssize_t nwritten; do { nwritten = gnutls_record_send(session_, data_seq, chunck_sz); } while (nwritten == GNUTLS_E_INTERRUPTED or nwritten == GNUTLS_E_AGAIN); @@ -523,13 +557,16 @@ TlsSession::TlsSessionImpl::send_(const uint8_t* tx_data, std::size_t tx_size) * state has not changed, so we have to ask for more data first. * We will just try again later, although this should never happen. */ - RING_WARN("[TLS] send failed (only %zu bytes sent): %s", total_written, - gnutls_strerror(nwritten)); - return nwritten; + RING_ERR() << "[TLS] send failed (only " << total_written << " bytes sent): " + << gnutls_strerror(nwritten); + ec = std::error_code(nwritten, std::system_category()); + return 0; } total_written += nwritten; } + + ec.clear(); return total_written; } @@ -538,16 +575,18 @@ TlsSession::TlsSessionImpl::send_(const uint8_t* tx_data, std::size_t tx_size) ssize_t TlsSession::TlsSessionImpl::sendRaw(const void* buf, size_t size) { - auto ret = socket_->send(reinterpret_cast<const unsigned char*>(buf), size); - if (ret > 0) { + std::error_code ec; + auto n = transport_.write(reinterpret_cast<const ValueType*>(buf), size, ec); + if (!ec) { // log only on success ++stTxRawPacketCnt_; - stTxRawBytesCnt_ += size; - return ret; + stTxRawBytesCnt_ += n; + return n; } // Must be called to pass errno value to GnuTLS on Windows (cf. GnuTLS doc) - gnutls_transport_set_errno(session_, errno); + gnutls_transport_set_errno(session_, ec.value()); + RING_ERR() << "[TLS] transport failure on tx: errno = " << ec.value(); return -1; } @@ -574,39 +613,64 @@ TlsSession::TlsSessionImpl::sendRawVec(const giovec_t* iov, int iovcnt) ssize_t TlsSession::TlsSessionImpl::recvRaw(void* buf, size_t size) { - std::lock_guard<std::mutex> lk {rxMutex_}; - if (rxQueue_.empty()) { - gnutls_transport_set_errno(session_, EAGAIN); + if (transport_.isReliable()) { + std::error_code ec; + auto count = transport_.read(reinterpret_cast<ValueType*>(buf), size, ec); + if (!ec) + return count; + gnutls_transport_set_errno(session_, ec.value()); return -1; } const auto& pkt = rxQueue_.front(); const std::size_t count = std::min(pkt.size(), size); - std::copy_n(pkt.begin(), count, reinterpret_cast<uint8_t*>(buf)); + std::copy_n(pkt.begin(), count, reinterpret_cast<ValueType*>(buf)); rxQueue_.pop_front(); return count; } // Called by GNUTLS to wait for encrypted packet from low-level transport. // 'timeout' is in milliseconds. -// Should return 0 on connection termination, -// a positive number indicating the number of bytes received, -// and -1 on error. +// Should return 0 on timeout, a positive number if data are available for read, or -1 on error. int TlsSession::TlsSessionImpl::waitForRawData(unsigned timeout) { - std::unique_lock<std::mutex> lk {rxMutex_}; - if (not rxCv_.wait_for(lk, std::chrono::milliseconds(timeout), - [this]{ return !rxQueue_.empty() or state_ == TlsSessionState::SHUTDOWN; })) - return 0; + if (transport_.isReliable()) { + if (not transport_.waitForData(timeout)) { + // shutdown? + if (state_ == TlsSessionState::SHUTDOWN) { + gnutls_transport_set_errno(session_, EINTR); + return -1; + } + return 0; + } + return 1; + } - // shutdown? + // non-reliable uses callback installed with setOnRecv() + std::unique_lock<std::mutex> lk {rxMutex_}; + rxCv_.wait(lk, [this]{ return !rxQueue_.empty() or state_ == TlsSessionState::SHUTDOWN; }); if (state_ == TlsSessionState::SHUTDOWN) { gnutls_transport_set_errno(session_, EINTR); return -1; } + return 1; +} + +bool +TlsSession::TlsSessionImpl::initFromRecordState(int offset) +{ + std::array<uint8_t, 8> seq; + if (gnutls_record_get_state(session_, 1, nullptr, nullptr, nullptr, &seq[0]) != GNUTLS_E_SUCCESS) { + RING_ERR("[TLS] Fatal-error Unable to read initial state"); + return false; + } - return rxQueue_.front().size(); + baseSeq_ = array2uint(seq) + offset; + gapOffset_ = baseSeq_; + lastRxSeq_ = baseSeq_ - 1; + RING_DBG("[TLS] Initial sequence number: %lx", baseSeq_); + return true; } bool @@ -629,8 +693,10 @@ TlsSession::TlsSessionImpl::cleanup() state_ = TlsSessionState::SHUTDOWN; // be sure to block any user operations if (session_) { - // DTLS: not use GNUTLS_SHUT_RDWR to not wait for a peer answer - gnutls_bye(session_, GNUTLS_SHUT_WR); + if (transport_.isReliable()) + gnutls_bye(session_, GNUTLS_SHUT_RDWR); + else + gnutls_bye(session_, GNUTLS_SHUT_WR); // not wait for a peer answer gnutls_deinit(session_); session_ = nullptr; } @@ -642,7 +708,7 @@ TlsSession::TlsSessionImpl::cleanup() TlsSessionState TlsSession::TlsSessionImpl::handleStateSetup(UNUSED TlsSessionState state) { - RING_DBG("[TLS] Start %s DTLS session", typeName()); + RING_DBG("[TLS] Start %s session", typeName()); try { if (anonymous_) @@ -653,10 +719,15 @@ TlsSession::TlsSessionImpl::handleStateSetup(UNUSED TlsSessionState state) return TlsSessionState::SHUTDOWN; } - if (isServer_) - return setupServer(); - else + if (not isServer_) return setupClient(); + + // Extra step for DTLS-like transports + if (not transport_.isReliable()) { + gnutls_key_generate(&cookie_key_, GNUTLS_COOKIE_KEY_SIZE); + return TlsSessionState::COOKIE; + } + return setupServer(); } TlsSessionState @@ -723,22 +794,7 @@ TlsSession::TlsSessionImpl::handleStateCookie(TlsSessionState state) RING_DBG("[TLS] cookie ok"); - ret = gnutls_init(&session_, GNUTLS_SERVER | GNUTLS_DATAGRAM); - RING_WARN("[TLS] set heartbeat reception"); - gnutls_heartbeat_enable(session_,GNUTLS_HB_PEER_ALLOWED_TO_SEND); - - if (ret != GNUTLS_E_SUCCESS) { - RING_ERR("[TLS] session init failed: %s", gnutls_strerror(ret)); - return TlsSessionState::SHUTDOWN; - } - - gnutls_certificate_server_set_request(session_, GNUTLS_CERT_REQUIRE); - gnutls_dtls_prestate_set(session_, &prestate_); - - if (not commonSessionInit()) - return TlsSessionState::SHUTDOWN; - - return TlsSessionState::HANDSHAKE; + return setupServer(); } TlsSessionState @@ -769,7 +825,7 @@ TlsSession::TlsSessionImpl::handleStateHandshake(TlsSessionState state) } auto desc = gnutls_session_get_desc(session_); - RING_WARN("[TLS] session established: %s", desc); + RING_DBG("[TLS] session established: %s", desc); gnutls_free(desc); // Anonymous connection? rehandshake immediatly with certificate authentification forced @@ -778,7 +834,9 @@ TlsSession::TlsSessionImpl::handleStateHandshake(TlsSessionState state) RING_DBG("[TLS] renogotiate with certificate authentification"); // Re-setup TLS algorithms priority list with only certificate based cipher suites - ret = gnutls_priority_set_direct(session_, TLS_CERT_PRIORITY_STRING, nullptr); + ret = gnutls_priority_set_direct(session_, + transport_.isReliable() ? TLS_CERT_PRIORITY_STRING : DTLS_CERT_PRIORITY_STRING, + nullptr); if (ret != GNUTLS_E_SUCCESS) { RING_ERR("[TLS] session TLS cert-only priority set failed: %s", gnutls_strerror(ret)); return TlsSessionState::SHUTDOWN; @@ -807,52 +865,32 @@ TlsSession::TlsSessionImpl::handleStateHandshake(TlsSessionState state) callbacks_.onCertificatesUpdate(local, remote, remote_count); } - return TlsSessionState::MTU_DISCOVERY; -} - -bool -TlsSession::TlsSessionImpl::initFromRecordState(int offset) -{ - std::array<uint8_t, 8> seq; - if (gnutls_record_get_state(session_, 1, nullptr, nullptr, nullptr, &seq[0]) != GNUTLS_E_SUCCESS) { - RING_ERR("[TLS] Fatal-error Unable to read initial state"); - return false; - } - - baseSeq_ = array2uint(seq) + offset; - gapOffset_ = baseSeq_; - lastRxSeq_ = baseSeq_ - 1; - RING_DBG("[TLS] Initial sequence number: %lx", baseSeq_); - return true; + return transport_.isReliable() ? TlsSessionState::ESTABLISHED : TlsSessionState::MTU_DISCOVERY; } TlsSessionState TlsSession::TlsSessionImpl::handleStateMtuDiscovery(UNUSED TlsSessionState state) { - // set dtls mtu to be over each and every mtus tested - gnutls_dtls_set_mtu(session_, MTUS.back()); - - // get transport overhead - transportOverhead_ = socket_->getTransportOverhead(); + mtuProbe_ = transport_.maxPayload(); + assert(mtuProbe_ >= MIN_MTU); + MTUS_ = {MIN_MTU, std::max((mtuProbe_ + MIN_MTU)/2, MIN_MTU), mtuProbe_}; // retrocompatibility check - if(gnutls_heartbeat_allowed(session_, GNUTLS_HB_LOCAL_ALLOWED_TO_SEND) == 1) { - if (!isServer_){ - RING_WARN("[TLS] HEARTBEAT PATH MTU DISCOVERY"); + if (gnutls_heartbeat_allowed(session_, GNUTLS_HB_LOCAL_ALLOWED_TO_SEND) == 1) { + if (!isServer_) { pathMtuHeartbeat(); pmtudOver_ = true; - RING_WARN("[TLS] HEARTBEAT PATH MTU DISCOVERY OVER"); } } else { - RING_ERR("[TLS] PEER HEARTBEAT DISABLED: setting minimal value to MTU @%d for retrocompatibility", DTLS_MTU); - gnutls_dtls_set_mtu(session_, DTLS_MTU); + RING_ERR() << "[TLS] PEER HEARTBEAT DISABLED: using transport MTU value " << mtuProbe_; pmtudOver_ = true; } + + gnutls_dtls_set_mtu(session_, mtuProbe_); maxPayload_ = gnutls_dtls_get_data_mtu(session_); - if (pmtudOver_) - RING_WARN("[TLS] maxPayload for dtls : %d B", maxPayload_.load()); if (pmtudOver_) { + RING_DBG() << "[TLS] maxPayload: " << maxPayload_.load(); if (!initFromRecordState()) return TlsSessionState::SHUTDOWN; } @@ -872,62 +910,52 @@ TlsSession::TlsSessionImpl::handleStateMtuDiscovery(UNUSED TlsSessionState state void TlsSession::TlsSessionImpl::pathMtuHeartbeat() { - int errno_send = 1; // non zero initialisation - auto tls_overhead = gnutls_record_overhead_size(session_); - RING_WARN("[TLS] tls session overhead : %lu", tls_overhead); - gnutls_heartbeat_set_timeouts(session_, HEARTBEAT_RETRANS_TIMEOUT.count(), HEARTBEAT_TOTAL_TIMEOUT.count()); - RING_DBG("[TLS] Heartbeat PMTUD : retransmission timeout set to: %ld ms", HEARTBEAT_RETRANS_TIMEOUT.count()); - - // server side: managing pong in state established - // client side: managing ping on heartbeat - uint16_t bytesToSend; - mtuProbe_ = MTUS.cbegin(); - RING_DBG("[TLS] Heartbeat PMTUD : client side"); - - while (mtuProbe_ != MTUS.cend()){ - bytesToSend = (*mtuProbe_) - 3 - tls_overhead - UDP_HEADER_SIZE - transportOverhead_; + RING_DBG() << "[TLS] PMTUD: starting probing with " << HEARTBEAT_RETRANS_TIMEOUT.count() + << "ms of retransmission timeout"; + + gnutls_heartbeat_set_timeouts(session_, + HEARTBEAT_RETRANS_TIMEOUT.count(), + HEARTBEAT_TOTAL_TIMEOUT.count()); + + int errno_send = GNUTLS_E_SUCCESS; + mtuProbe_ = MTUS_[0]; + for (auto mtu: MTUS_) { + gnutls_dtls_set_mtu(session_, mtu); + auto data_mtu = gnutls_dtls_get_data_mtu(session_); + RING_DBG() << "[TLS] PMTUD: mtu " << mtu + << ", payload " << data_mtu; + auto bytesToSend = data_mtu - 3; // want to know why -3? ask gnutls! do { - RING_DBG("[TLS] Heartbeat PMTUD : ping with mtu %d and effective payload %d", *mtuProbe_, bytesToSend); errno_send = gnutls_heartbeat_ping(session_, bytesToSend, HEARTBEAT_TRIES, GNUTLS_HEARTBEAT_WAIT); - RING_DBG("[TLS] Heartbeat PMTUD : ping sequence over with errno %d: %s", errno_send, - gnutls_strerror(errno_send)); } while (errno_send == GNUTLS_E_AGAIN || errno_send == GNUTLS_E_INTERRUPTED); - if (errno_send == GNUTLS_E_SUCCESS) { - ++mtuProbe_; - } else if (errno_send == GNUTLS_E_TIMEDOUT){ // timeout is considered as a packet loss, then the good mtu is the precedent. - if (mtuProbe_ == MTUS.cbegin()) { - RING_WARN("[TLS] Heartbeat PMTUD : no response on first ping, setting minimal MTU value @%d", MIN_MTU); - gnutls_dtls_set_mtu(session_, MIN_MTU - UDP_HEADER_SIZE - transportOverhead_); - - } else { - --mtuProbe_; - RING_DBG("[TLS] Heartbeat PMTUD : timed out: Path MTU found @ %d", *mtuProbe_); - gnutls_dtls_set_mtu(session_, *mtuProbe_ - UDP_HEADER_SIZE - transportOverhead_); - } - return; - } else { - RING_WARN("[TLS] Heartbeat PMTUD : client ping failed: error %d: %s", errno_send, - gnutls_strerror(errno_send)); - if (mtuProbe_ != MTUS.begin()) - --mtuProbe_; - gnutls_dtls_set_mtu(session_, *mtuProbe_ - UDP_HEADER_SIZE - transportOverhead_); - return; + if (errno_send != GNUTLS_E_SUCCESS) { + RING_DBG() << "[TLS] PMTUD: mtu " << mtu << " [FAILED]"; + break; } - } - - if (errno_send == GNUTLS_E_SUCCESS) { - RING_WARN("[TLS] Heartbeat PMTUD completed : reached test value %d", MTUS.back()); - --mtuProbe_; // for loop over, setting mtu to last valid mtu + mtuProbe_ = mtu; + RING_DBG() << "[TLS] PMTUD: mtu " << mtu << " [OK]"; } - gnutls_dtls_set_mtu(session_, *mtuProbe_ - UDP_HEADER_SIZE - transportOverhead_); - RING_WARN("[TLS] Heartbeat PMTUD : new mtu set to %d", *mtuProbe_); + if (errno_send == GNUTLS_E_TIMEDOUT) { // timeout is considered as a packet loss, then the good mtu is the precedent + if (mtuProbe_ == MTUS_[0]) { + RING_WARN() << "[TLS] PMTUD: no response on first ping, using minimal MTU value " + << mtuProbe_; + } else { + RING_WARN() << "[TLS] PMTUD: timed out, using last working mtu " + << mtuProbe_; + } + } else if (errno_send != GNUTLS_E_SUCCESS) { + RING_ERR() << "[TLS] PMTUD: failed with gnutls error '" + << gnutls_strerror(errno_send) << '\''; + } else { + RING_DBG() << "[TLS] PMTUD: reached maximal value"; + } } void -TlsSession::TlsSessionImpl::handleDataPacket(std::vector<uint8_t>&& buf, uint64_t pkt_seq) +TlsSession::TlsSessionImpl::handleDataPacket(std::vector<ValueType>&& buf, uint64_t pkt_seq) { // Check for a valid seq. num. delta int64_t seq_delta = pkt_seq - lastRxSeq_; @@ -936,14 +964,14 @@ TlsSession::TlsSessionImpl::handleDataPacket(std::vector<uint8_t>&& buf, uint64_ } else { // too old? if (seq_delta <= -MISS_ORDERING_LIMIT) { - RING_WARN("[dtls] drop old pkt: 0x%lx", pkt_seq); + RING_WARN("[TLS] drop old pkt: 0x%lx", pkt_seq); return; } // No duplicate check as DTLS prevents that for us (replay protection) // accept Out-Of-Order pkt - will be reordered by queue flush operation - RING_WARN("[dtls] OOO pkt: 0x%lx", pkt_seq); + RING_WARN("[TLS] OOO pkt: 0x%lx", pkt_seq); } { @@ -986,15 +1014,14 @@ TlsSession::TlsSessionImpl::flushRxQueue() auto item = std::begin(reorderBuffer_); auto next_offset = item->first; - auto first_offset = next_offset; // Wait for next continous packet until timeout if ((clock::now() - lastReadTime_) >= RX_OOO_TIMEOUT) { // OOO packet timeout - consider waited packets as lost if (auto lost = next_offset - gapOffset_) - RING_WARN("[dtls] %lu lost since 0x%lx", lost, gapOffset_); + RING_WARN("[TLS] %lu lost since 0x%lx", lost, gapOffset_); else - RING_WARN("[dtls] slow flush"); + RING_WARN("[TLS] slow flush"); } else if (next_offset != gapOffset_) return; @@ -1021,6 +1048,13 @@ TlsSession::TlsSessionImpl::flushRxQueue() TlsSessionState TlsSession::TlsSessionImpl::handleStateEstablished(TlsSessionState state) { + // Nothing to do in reliable mode, so just wait for state change + if (transport_.isReliable()) { + std::unique_lock<std::mutex> lk {rxMutex_}; + rxCv_.wait(lk, [this]{ return state_ != TlsSessionState::ESTABLISHED; }); + return state; + } + // block until rx packet or state change { std::unique_lock<std::mutex> lk {rxMutex_}; @@ -1035,18 +1069,13 @@ TlsSession::TlsSessionImpl::handleStateEstablished(TlsSessionState state) auto ret = gnutls_record_recv_seq(session_, rawPktBuf_.data(), rawPktBuf_.size(), &seq[0]); if (ret > 0) { + // Are we in PMTUD phase? if (!pmtudOver_) { - // This is the first application packet recieved after PMTUD - // This packet gives the final MTU. - if (hbPingRecved_ > 0) { - gnutls_dtls_set_mtu(session_, MTUS[hbPingRecved_ - 1] - UDP_HEADER_SIZE - transportOverhead_); - maxPayload_ = gnutls_dtls_get_data_mtu(session_); - } else { - gnutls_dtls_set_mtu(session_, MIN_MTU - UDP_HEADER_SIZE - transportOverhead_); - maxPayload_ = gnutls_dtls_get_data_mtu(session_); - } + mtuProbe_ = MTUS_[std::max(0, hbPingRecved_ - 1)]; + gnutls_dtls_set_mtu(session_, mtuProbe_); + maxPayload_ = gnutls_dtls_get_data_mtu(session_); pmtudOver_ = true; - RING_WARN("[TLS] maxPayload for dtls : %d B", maxPayload_.load()); + RING_DBG() << "[TLS] maxPayload: " << maxPayload_.load(); if (!initFromRecordState(-1)) return TlsSessionState::SHUTDOWN; @@ -1056,11 +1085,11 @@ TlsSession::TlsSessionImpl::handleStateEstablished(TlsSessionState state) handleDataPacket(std::move(rawPktBuf_), array2uint(seq)); // no state change } else if (ret == GNUTLS_E_HEARTBEAT_PING_RECEIVED) { - RING_DBG("[TLS] Heartbeat PMTUD : ping received sending pong"); + RING_DBG("[TLS] PMTUD: ping received sending pong"); auto errno_send = gnutls_heartbeat_pong(session_, 0); if (errno_send != GNUTLS_E_SUCCESS){ - RING_WARN("[TLS] Heartbeat PMTUD : failed on pong with error %d: %s", errno_send, + RING_ERR("[TLS] PMTUD: failed on pong with error %d: %s", errno_send, gnutls_strerror(errno_send)); } else { ++hbPingRecved_; @@ -1106,10 +1135,10 @@ TlsSession::TlsSessionImpl::process() //============================================================================== -TlsSession::TlsSession(const std::shared_ptr<IceTransport>& ice, int ice_comp_id, - const TlsParams& params, const TlsSessionCallbacks& cbs, bool anonymous) +TlsSession::TlsSession(SocketType& transport, const TlsParams& params, + const TlsSessionCallbacks& cbs, bool anonymous) - : pimpl_ { std::make_unique<TlsSessionImpl>(ice, ice_comp_id, params, cbs, anonymous) } + : pimpl_ { std::make_unique<TlsSessionImpl>(transport, params, cbs, anonymous) } {} TlsSession::~TlsSession() @@ -1117,34 +1146,28 @@ TlsSession::~TlsSession() shutdown(); } -const char* -TlsSession::typeName() const -{ - return pimpl_->typeName(); -} - bool -TlsSession::isServer() const +TlsSession::isInitiator() const { - return pimpl_->isServer_; + return !pimpl_->isServer_; } -unsigned int -TlsSession::getMaxPayload() const +bool +TlsSession::isReliable() const { - return pimpl_->maxPayload_; + return pimpl_->transport_.isReliable(); } -// Called by anyone to stop the connection and the FSM thread -void -TlsSession::shutdown() +int +TlsSession::maxPayload() const { - pimpl_->state_ = TlsSessionState::SHUTDOWN; - pimpl_->rxCv_.notify_one(); // unblock waiting FSM + if (pimpl_->state_ == TlsSessionState::SHUTDOWN) + throw std::runtime_error("Getting MTU from non-valid TLS session"); + return gnutls_dtls_get_data_mtu(pimpl_->session_); } const char* -TlsSession::getCurrentCipherSuiteId(std::array<uint8_t, 2>& cs_id) const +TlsSession::currentCipherSuiteId(std::array<uint8_t, 2>& cs_id) const { // get current session cipher suite info gnutls_cipher_algorithm_t cipher, s_cipher = gnutls_cipher_get(pimpl_->session_); @@ -1166,27 +1189,71 @@ TlsSession::getCurrentCipherSuiteId(std::array<uint8_t, 2>& cs_id) const return {}; } -ssize_t -TlsSession::send(const void* data, std::size_t size) +// Called by anyone to stop the connection and the FSM thread +void +TlsSession::shutdown() +{ + pimpl_->state_ = TlsSessionState::SHUTDOWN; + pimpl_->rxCv_.notify_one(); // unblock waiting FSM +} + +std::size_t +TlsSession::write(const ValueType* data, std::size_t size, std::error_code& ec) { - std::lock_guard<std::mutex> lk {pimpl_->txMutex_}; - if (pimpl_->state_ != TlsSessionState::ESTABLISHED) - return GNUTLS_E_INVALID_SESSION; - return pimpl_->send_(static_cast<const uint8_t*>(data), size); + if (pimpl_->state_ != TlsSessionState::ESTABLISHED) { + ec = std::make_error_code(std::errc::broken_pipe); + return 0; + } + + return pimpl_->send(data, size, ec); } -ssize_t -TlsSession::send(const std::vector<uint8_t>& vec) +std::size_t +TlsSession::read(ValueType* data, std::size_t size, std::error_code& ec) { - return send(vec.data(), vec.size()); + std::errc error; + + if (pimpl_->state_ != TlsSessionState::ESTABLISHED) { + ec = std::make_error_code(std::errc::broken_pipe); + return 0; + } + + while (true) { + auto ret = gnutls_record_recv(pimpl_->session_, data, size); + if (ret > 0) { + ec.clear(); + return ret; + } + + if (ret == 0) { + RING_DBG("[TLS] eof"); + shutdown(); + error = std::errc::broken_pipe; + break; + } else if (ret == GNUTLS_E_REHANDSHAKE) { + RING_DBG("[TLS] re-handshake"); + pimpl_->state_ = TlsSessionState::HANDSHAKE; + pimpl_->rxCv_.notify_one(); // unblock waiting FSM + } else if (gnutls_error_is_fatal(ret)) { + RING_ERR("[TLS] fatal error in recv: %s", gnutls_strerror(ret)); + shutdown(); + error = std::errc::io_error; + break; + } + } + + ec = std::make_error_code(error); + return 0; } -uint16_t -TlsSession::getMtu() +void +TlsSession::connect() { - if (pimpl_->state_ == TlsSessionState::SHUTDOWN) - throw std::runtime_error("Getting MTU from dead TLS session"); - return gnutls_dtls_get_mtu(pimpl_->session_); + TlsSessionState state; + do { + state = pimpl_->state_.load(); + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } while (state != TlsSessionState::ESTABLISHED and state != TlsSessionState::SHUTDOWN); } }} // namespace ring::tls diff --git a/src/security/tls_session.h b/src/security/tls_session.h index 7b6c48dce9..5821159831 100644 --- a/src/security/tls_session.h +++ b/src/security/tls_session.h @@ -22,6 +22,7 @@ #pragma once #include "noncopyable.h" +#include "generic_io.h" #include <gnutls/gnutls.h> @@ -32,12 +33,6 @@ #include <chrono> #include <vector> #include <array> -#include <cstdint> - -namespace ring { -class IceTransport; -class IceSocket; -} // namespace ring namespace dht { namespace crypto { struct Certificate; @@ -48,20 +43,18 @@ namespace ring { namespace tls { class DhParams; -static constexpr uint8_t MTUS_TO_TEST = 3; //number of mtus to test in path mtu discovery. -static constexpr int DTLS_MTU {1232}; // (1280 from IPv6 minimum MTU - 40 IPv6 header - 8 UDP header) -static constexpr uint16_t MIN_MTU {512}; - -enum class TlsSessionState { +enum class TlsSessionState +{ SETUP, - COOKIE, // server only + COOKIE, // only used with non-initiator and non-reliable transport HANDSHAKE, - MTU_DISCOVERY, + MTU_DISCOVERY, // only used with non-reliable transport ESTABLISHED, SHUTDOWN }; -struct TlsParams { +struct TlsParams +{ // User CA list for session credentials std::string ca_list; @@ -83,20 +76,22 @@ struct TlsParams { unsigned cert_list_size)> cert_check; }; -/** - * TlsSession - * - * Manages a DTLS connection over an ICE transport. - * This implementation uses a Threadloop to manage IO from ICE and TLS states, - * so IO are asynchronous. - */ -class TlsSession { +/// TlsSession +/// +/// Manages a TLS/DTLS data transport overlayed on a given generic socket. +/// +/// \note API is not thread-safe. +/// +class TlsSession : public GenericSocket<uint8_t> +{ public: + using SocketType = GenericSocket<uint8_t>; using OnStateChangeFunc = std::function<void(TlsSessionState)>; using OnRxDataFunc = std::function<void(std::vector<uint8_t>&&)>; - using OnCertificatesUpdate = std::function<void(const gnutls_datum_t*, const gnutls_datum_t*, unsigned int)>; + using OnCertificatesUpdate = std::function<void(const gnutls_datum_t*, + const gnutls_datum_t*, + unsigned int)>; using VerifyCertificate = std::function<int(gnutls_session_t)>; - using TxDataCompleteFunc = std::function<void(std::size_t bytes_sent)>; // ===> WARNINGS <=== // Following callbacks are called into the FSM thread context @@ -108,38 +103,44 @@ public: VerifyCertificate verifyCertificate; }; - TlsSession(const std::shared_ptr<IceTransport>& ice, int ice_comp_id, const TlsParams& params, - const TlsSessionCallbacks& cbs, bool anonymous=true); + TlsSession(SocketType& transport, const TlsParams& params, const TlsSessionCallbacks& cbs, + bool anonymous=true); ~TlsSession(); - // Returns the TLS session type ('server' or 'client') - const char* typeName() const; + /// Return the name of current cipher. + /// Can be called by onStateChange callback when state == ESTABLISHED + /// to obtain the used cypher suite id. + const char* currentCipherSuiteId(std::array<uint8_t, 2>& cs_id) const; - bool isServer() const; - - // Request TLS thread to stop and quit. IO are not possible after that. + /// Request TLS thread to stop and quit. + /// \note IO operations return error after this call. void shutdown(); - // Return maximum application payload size in bytes - // Returned value must be checked and considered valid only if not 0 (session is initialized) - unsigned int getMaxPayload() const; + void setOnRecv(RecvCb&& cb) override { + (void)cb; + throw std::logic_error("TlsSession::setOnRecv not implemented"); + } + + /// Return true if the TLS session type is a server. + bool isInitiator() const override; + + bool isReliable() const override; + + int maxPayload() const override; - // Can be called by onStateChange callback when state == ESTABLISHED - // to obtain the used cypher suite id. - // Return the name of current cipher. - const char* getCurrentCipherSuiteId(std::array<uint8_t, 2>& cs_id) const; + void connect(); - // Asynchronous sending operation. on_send_complete will be called with a positive number - // for number of bytes sent, or negative for errors, or 0 in case of shutdown (end of session). - int async_send(const void* data, std::size_t size, TxDataCompleteFunc on_send_complete); - int async_send(std::vector<uint8_t>&& data, TxDataCompleteFunc on_send_complete); + /// Synchronous writing. + /// Return a positive number for number of bytes write, or 0 and \a ec set in case of error. + std::size_t write(const ValueType* data, std::size_t size, std::error_code& ec) override; - // Synchronous sending operation. Return negative number (gnutls error) or a positive number - // for bytes sent. - ssize_t send(const void* data, std::size_t size); - ssize_t send(const std::vector<uint8_t>& data); + /// Synchronous reading. + /// Return a positive number for number of bytes read, or 0 and \a ec set in case of error. + std::size_t read(ValueType* data, std::size_t size, std::error_code& ec) override; - uint16_t getMtu(); + bool waitForData(unsigned) const override { + throw std::logic_error("TlsSession::waitForData not implemented"); + } private: class TlsSessionImpl; diff --git a/src/sip/siptransport.cpp b/src/sip/siptransport.cpp index b7d8089e71..09b9e941db 100644 --- a/src/sip/siptransport.cpp +++ b/src/sip/siptransport.cpp @@ -182,7 +182,12 @@ SipTransport::getTlsMtu() auto tls_tr = reinterpret_cast<tls::SipsIceTransport::TransportData*>(transport_.get())->self; return tls_tr->getTlsSessionMtu(); } - return ring::tls::DTLS_MTU; + return 1232; /* Hardcoded yes (it's the IPv6 value). + * This method is broken by definition. + * A MTU should not be defined at this layer. + * And a correct value should come from the underlying transport itself, + * not from a constant... + */ } SipTransportBroker::SipTransportBroker(pjsip_endpoint *endpt, diff --git a/src/turn_transport.cpp b/src/turn_transport.cpp index fbe666c4d3..4d70fc3777 100644 --- a/src/turn_transport.cpp +++ b/src/turn_transport.cpp @@ -29,7 +29,6 @@ #include <pjlib-util.h> #include <pjlib.h> -#include <stdexcept> #include <future> #include <atomic> #include <thread> @@ -46,6 +45,9 @@ namespace ring { using MutexGuard = std::lock_guard<std::mutex>; using MutexLock = std::unique_lock<std::mutex>; +inline +namespace { + enum class RelayState { NONE, @@ -82,31 +84,19 @@ public: cv_.notify_one(); } - void read(std::vector<char>& output) { + template <typename Duration> + bool wait(Duration timeout) { MutexLock lk {mutex_}; - cv_.wait(lk, [&, this]{ - stream_.read(&output[0], output.size()); - return stream_.gcount() > 0 or stop_; - }); - output.resize(stop_ ? 0 : stream_.gcount()); + return cv_.wait_for(lk, timeout, [this]{ return !stream_.eof(); }); } - std::vector<char> readline() { + std::size_t read(char* output, std::size_t size) { MutexLock lk {mutex_}; - std::vector<char> result(3000); - cv_.wait(lk, [&, this] { - if (stop_) - return true; - stream_.getline(&result[0], 3000); - if (stream_) { - result.resize(stream_.gcount()); - return result.size() > 0; - } - return false; + cv_.wait(lk, [&, this]{ + stream_.read(&output[0], size); + return stream_.gcount() > 0 or stop_; }); - if (stop_) - return {}; - return result; + return stop_ ? 0 : stream_.gcount(); } private: @@ -120,6 +110,31 @@ private: friend void operator <<(std::vector<char>&, PeerChannel&); }; +} + +//============================================================================== + +template <class Callable, class... Args> +inline void +PjsipCall(Callable& func, Args... args) +{ + auto status = func(args...); + if (status != PJ_SUCCESS) + throw sip_utils::PjsipFailure(status); +} + +template <class Callable, class... Args> +inline auto +PjsipCallReturn(const Callable& func, Args... args) -> decltype(func(args...)) +{ + auto res = func(args...); + if (!res) + throw sip_utils::PjsipFailure(); + return res; +} + +//============================================================================== + class TurnTransportPimpl { public: @@ -135,6 +150,7 @@ public: std::map<IpAddr, PeerChannel> peerChannels_; + GenericSocket<uint8_t>::RecvCb onRxDataCb; TurnTransportParams settings; pj_caching_pool poolCache {}; pj_pool_t* pool {nullptr}; @@ -193,7 +209,10 @@ TurnTransportPimpl::onRxData(const uint8_t* pkt, unsigned pkt_len, return; } - (channel_it->second) << std::string(reinterpret_cast<const char*>(pkt), pkt_len); + if (onRxDataCb) + onRxDataCb(pkt, pkt_len); + else + (channel_it->second) << std::string(reinterpret_cast<const char*>(pkt), pkt_len); } void @@ -226,45 +245,10 @@ TurnTransportPimpl::ioJob() const pj_time_val delay = {0, 10}; pj_ioqueue_poll(stunConfig.ioqueue, &delay); pj_timer_heap_poll(stunConfig.timer_heap, nullptr); - } -} - -class PjsipError final : public std::exception { -public: - PjsipError() = default; - explicit PjsipError(pj_status_t st) : std::exception() { - char err_msg[PJ_ERR_MSG_SIZE]; - pj_strerror(st, err_msg, sizeof(err_msg)); - what_msg_ += ": "; - what_msg_ += err_msg; } - const char* what() const noexcept override { - return what_msg_.c_str(); - }; -private: - std::string what_msg_ {"PJSIP api error"}; -}; - -template <class Callable, class... Args> -inline void -PjsipCall(Callable& func, Args... args) -{ - auto status = func(args...); - if (status != PJ_SUCCESS) - throw PjsipError(status); -} - -template <class Callable, class... Args> -inline auto -PjsipCallReturn(const Callable& func, Args... args) -> decltype(func(args...)) -{ - auto res = func(args...); - if (!res) - throw PjsipError(); - return res; } -//================================================================================================== +//============================================================================== TurnTransport::TurnTransport(const TurnTransportParams& params) : pimpl_ {new TurnTransportPimpl} @@ -354,6 +338,12 @@ TurnTransport::TurnTransport(const TurnTransportParams& params) TurnTransport::~TurnTransport() {} +bool +TurnTransport::isInitiator() const +{ + return !pimpl_->settings.server; +} + void TurnTransport::permitPeer(const IpAddr& addr) { @@ -363,6 +353,7 @@ TurnTransport::permitPeer(const IpAddr& addr) if (addr.getFamily() != pimpl_->peerRelayAddr.getFamily()) throw std::invalid_argument("mismatching peer address family"); + sip_utils::register_thread(); PjsipCall(pj_turn_sock_set_perm, pimpl_->relay, 1, addr.pjPtr(), 1); } @@ -400,7 +391,7 @@ TurnTransport::sendto(const IpAddr& peer, const char* const buffer, std::size_t reinterpret_cast<const pj_uint8_t*>(buffer), length, peer.pjPtr(), peer.getLength()); if (status != PJ_SUCCESS && status != PJ_EPENDING) - throw PjsipError(status); + throw sip_utils::PjsipFailure(status); return status == PJ_SUCCESS; } @@ -411,40 +402,84 @@ TurnTransport::sendto(const IpAddr& peer, const std::vector<char>& buffer) return sendto(peer, &buffer[0], buffer.size()); } -bool -TurnTransport::writelineto(const IpAddr& peer, const char* const buffer, std::size_t length) +std::size_t +TurnTransport::recvfrom(const IpAddr& peer, char* buffer, std::size_t size) { - if (sendto(peer, buffer, length)) - return sendto(peer, "\n", 1); - return false; + MutexLock lk {pimpl_->apiMutex_}; + auto& channel = pimpl_->peerChannels_.at(peer); + lk.unlock(); + return channel.read(buffer, size); } void TurnTransport::recvfrom(const IpAddr& peer, std::vector<char>& result) { - if (result.empty()) - throw std::runtime_error("TurnTransport::recvfrom() called with an empty output buffer"); + auto res = recvfrom(peer, result.data(), result.size()); + result.resize(res); +} +std::vector<IpAddr> +TurnTransport::peerAddresses() const +{ MutexLock lk {pimpl_->apiMutex_}; - auto& channel = pimpl_->peerChannels_.at(peer); - lk.unlock(); - channel.read(result); + return map_utils::extractKeys(pimpl_->peerChannels_); } -void -TurnTransport::readlinefrom(const IpAddr& peer, std::vector<char>& result) +bool +TurnTransport::waitForData(const IpAddr& peer, unsigned ms_timeout) const { MutexLock lk {pimpl_->apiMutex_}; auto& channel = pimpl_->peerChannels_.at(peer); lk.unlock(); - result = channel.readline(); + return channel.wait(std::chrono::milliseconds(ms_timeout)); } -std::vector<IpAddr> -TurnTransport::peerAddresses() const +//============================================================================== + +ConnectedTurnTransport::ConnectedTurnTransport(TurnTransport& turn, const IpAddr& peer) + : turn_ {turn} + , peer_ {peer} +{} + +bool +ConnectedTurnTransport::waitForData(unsigned ms_timeout) const { - MutexLock lk {pimpl_->apiMutex_}; - return map_utils::extractKeys(pimpl_->peerChannels_); + return turn_.waitForData(peer_, ms_timeout); +} + +std::size_t +ConnectedTurnTransport::write(const ValueType* buf, std::size_t size, std::error_code& ec) +{ + try { + turn_.sendto(peer_, reinterpret_cast<const char*>(buf), size); + } catch (const sip_utils::PjsipFailure& ex) { + ec = ex.code(); + return 0; + } + + ec.clear(); + return size; +} + +std::size_t +ConnectedTurnTransport::read(ValueType* buf, std::size_t size, std::error_code& ec) +{ + if (size > 0) { + try { + size = turn_.recvfrom(peer_, reinterpret_cast<char*>(buf), size); + } catch (const sip_utils::PjsipFailure& ex) { + ec = ex.code(); + return 0; + } + + if (size == 0) { + ec = std::make_error_code(std::errc::broken_pipe); + return 0; + } + } + + ec.clear(); + return size; } } // namespace ring diff --git a/src/turn_transport.h b/src/turn_transport.h index 4b59568eb3..f884af7a98 100644 --- a/src/turn_transport.h +++ b/src/turn_transport.h @@ -21,17 +21,20 @@ #pragma once #include "ip_utils.h" +#include "generic_io.h" #include <string> #include <memory> #include <functional> #include <map> +#include <stdexcept> namespace ring { class TurnTransportPimpl; -struct TurnTransportParams { +struct TurnTransportParams +{ IpAddr server; // Plain Credentials @@ -46,7 +49,8 @@ struct TurnTransportParams { std::size_t maxPacketSize {4096}; ///< size of one "logical" packet }; -class TurnTransport { +class TurnTransport +{ public: /// Constructs a TurnTransport connected by TCP to given server. /// @@ -60,6 +64,8 @@ public: ~TurnTransport(); + bool isInitiator() const; + /// Wait for successful connection on the TURN server. /// /// TurnTransport constructor connects asynchronously on the TURN server. @@ -106,10 +112,9 @@ public: /// void recvfrom(const IpAddr& peer, std::vector<char>& data); - /// Work as recvfrom but stop on first '\n' character found. - /// If such character isn't found, stop at /a data vector size. + /// Works as recvfrom() vector version but accept a simple char array. /// - void readlinefrom(const IpAddr& peer, std::vector<char>& data); + std::size_t recvfrom(const IpAddr& peer, char* buffer, std::size_t size); /// Send data to given peer through the TURN tunnel. /// @@ -124,11 +129,9 @@ public: /// Works as sendto() vector version but accept a simple char array. /// - bool sendto(const IpAddr& peer, const char* const buffer, std::size_t length); + bool sendto(const IpAddr& peer, const char* const buffer, std::size_t size); - /// Works as sendto() char array but happend a '\n' character at the end of sent data. - /// - bool writelineto(const IpAddr& peer, const char* const buffer, std::size_t length); + bool waitForData(const IpAddr& peer, unsigned ms_timeout) const; public: // Move semantic only, not copiable @@ -140,4 +143,27 @@ private: std::unique_ptr<TurnTransportPimpl> pimpl_; }; +class ConnectedTurnTransport final : public GenericSocket<uint8_t> +{ +public: + using SocketType = GenericSocket<uint8_t>; + + ConnectedTurnTransport(TurnTransport& turn, const IpAddr& peer); + + bool isReliable() const override { return true; } + bool isInitiator() const override { return turn_.isInitiator(); } + int maxPayload() const override { return 3000; } + + bool waitForData(unsigned ms_timeout) const override; + std::size_t read(ValueType* buf, std::size_t length, std::error_code& ec) override; + std::size_t write(const ValueType* buf, std::size_t length, std::error_code& ec) override; + + void setOnRecv(RecvCb&&) override { throw std::logic_error("ConnectedTurnTransport bad call"); } + +private: + TurnTransport& turn_; + const IpAddr peer_; + RecvCb onRxDataCb_; +}; + } // namespace ring -- GitLab