diff --git a/src/Makefile.am b/src/Makefile.am index 4d608b9f0343a16417924efbc92779b6ac6274e2..2bdea570f82ea8935d6cf7a56cc625efa36b97fa 100644 --- a/src/Makefile.am +++ b/src/Makefile.am @@ -143,7 +143,8 @@ libring_la_SOURCES = \ base64.cpp \ turn_transport.h \ turn_transport.cpp \ - channel.h + channel.h \ + generic_io.h if HAVE_WIN32 libring_la_SOURCES += \ diff --git a/src/generic_io.h b/src/generic_io.h new file mode 100644 index 0000000000000000000000000000000000000000..8e0ae80eb9d2a5c16040f0b7b6224215d72f012e --- /dev/null +++ b/src/generic_io.h @@ -0,0 +1,103 @@ +/* + * Copyright (C) 2017 Savoir-faire Linux Inc. + * + * Author: Guillaume Roguez <guillaume.roguez@savoirfairelinux.com> + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ + +#pragma once + +#include <functional> +#include <vector> +#include <system_error> +#include <cstdint> + +#if defined(_MSC_VER) +#include <BaseTsd.h> +using ssize_t = SSIZE_T; +#endif + +namespace ring { + +template <typename T> +class GenericSocket +{ +public: + using ValueType = T; + + virtual ~GenericSocket() = default; + + using RecvCb = std::function<ssize_t(const ValueType* buf, std::size_t len)>; + + /// Set Rx callback + /// \warning This method is here for backward compatibility + /// and because async IO are not implemented yet. + virtual void setOnRecv(RecvCb&& cb) = 0; + + virtual bool isReliable() const = 0; + + virtual bool isInitiator() const = 0; + + /// Return maximum application payload size. + /// This value is negative if the session is not ready to give a valid answer. + /// The value is 0 if such information is irrelevant for the session. + /// If stricly positive, the user must use send() with an input buffer size below or equals + /// to this value if it want to be sure that the transport sent it in an atomic way. + /// Example: in case of non-reliable transport using packet oriented IO, + /// this value gives the maximal size used to send one packet. + virtual int maxPayload() const = 0; + + // TODO: make a std::chrono version + virtual bool waitForData(unsigned ms_timeout) const = 0; + + /// Write a given amount of data. + /// \param buf data to write. + /// \param len number of bytes to write. + /// \param ec error code set in case of error. + /// \return number of bytes written, 0 is valid. + /// \warning error checking consists in checking if \a !ec is true, not if returned size is 0 + /// as a write of 0 could be considered a valid operation. + virtual std::size_t write(const ValueType* buf, std::size_t len, std::error_code& ec) = 0; + + /// Read a given amount of data. + /// \param buf data to read. + /// \param len number of bytes to read. + /// \param ec error code set in case of error. + /// \return number of bytes read, 0 is valid. + /// \warning error checking consists in checking if \a !ec is true, not if returned size is 0 + /// as a read of 0 could be considered a valid operation (i.e. non-blocking IO). + virtual std::size_t read(ValueType* buf, std::size_t len, std::error_code& ec) = 0; + + /// write() adaptor for STL containers + template <typename U> + std::size_t write(const U& obj, std::error_code& ec) { + return write(obj.data(), obj.size() * sizeof(typename U::value_type), ec); + } + + /// read() adaptor for STL containers + template <typename U> + std::size_t read(U& storage, std::error_code& ec) { + auto res = read(storage.data(), storage.size() * sizeof(typename U::value_type), ec); + if (!ec) + storage.resize(res); + return res; + } + +protected: + GenericSocket() = default; +}; + +} // namespace ring diff --git a/src/ice_socket.h b/src/ice_socket.h index ba12a1b04e02b414a20137afa9c557fa5a1c1b6c..0d81a2aa33d99ceb6f42079f83387f5a88144971 100644 --- a/src/ice_socket.h +++ b/src/ice_socket.h @@ -17,8 +17,9 @@ * along with this program; if not, write to the Free Software * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. */ -#ifndef ICE_SOCKET_H -#define ICE_SOCKET_H +#pragma once + +#include "generic_io.h" #include <memory> #include <functional> @@ -51,6 +52,44 @@ class IceSocket uint16_t getTransportOverhead(); }; +/// ICE transport as a GenericSocket. +/// +/// \warning Simplified version where we assume that ICE protocol +/// always use UDP over IP over ETHERNET, and doesn't add more header to the UDP payload. +/// +class IceSocketTransport final : public GenericSocket<uint8_t> +{ +public: + using SocketType = GenericSocket<uint8_t>; + + static constexpr uint16_t STANDARD_MTU_SIZE = 1280; // Size in bytes of MTU for IPv6 capable networks + static constexpr uint16_t IPV6_HEADER_SIZE = 40; // Size in bytes of IPv6 packet header + static constexpr uint16_t IPV4_HEADER_SIZE = 20; // Size in bytes of IPv4 packet header + static constexpr uint16_t UDP_HEADER_SIZE = 8; // Size in bytes of UDP header + + IceSocketTransport(std::shared_ptr<IceTransport>& ice, int comp_id) + : compId_ {comp_id} + , ice_ {ice} {} + + bool isReliable() const override { + return false; // we consider that a ICE transport is never reliable (UDP support only) + } + + bool isInitiator() const override; + + int maxPayload() const override; + + bool waitForData(unsigned ms_timeout) const override; + + std::size_t write(const ValueType* buf, std::size_t len, std::error_code& ec) override; + + std::size_t read(ValueType* buf, std::size_t len, std::error_code& ec) override; + + void setOnRecv(RecvCb&& cb) override; + +private: + const int compId_; + std::shared_ptr<IceTransport> ice_; }; -#endif /* ICE_SOCKET_H */ +}; diff --git a/src/ice_transport.cpp b/src/ice_transport.cpp index 7ab020daa131d972775050949f6dbd32a0d4834e..82645ac8bcdc5e6f15e850f05255878838f9b9a8 100644 --- a/src/ice_transport.cpp +++ b/src/ice_transport.cpp @@ -1176,6 +1176,58 @@ IceTransportFactory::createTransport(const char* name, int component_count, //============================================================================== +void +IceSocketTransport::setOnRecv(RecvCb&& cb) +{ + return ice_->setOnRecv(compId_, cb); +} + +bool +IceSocketTransport::isInitiator() const +{ + return ice_->isInitiator(); +} + +int +IceSocketTransport::maxPayload() const +{ + auto ip_header_size = (ice_->getRemoteAddress(compId_).getFamily() == AF_INET) ? + IPV4_HEADER_SIZE : IPV6_HEADER_SIZE; + return STANDARD_MTU_SIZE - ip_header_size - UDP_HEADER_SIZE; +} + +bool +IceSocketTransport::waitForData(unsigned ms_timeout) const +{ + return ice_->waitForData(compId_, ms_timeout) > 0; +} + +std::size_t +IceSocketTransport::write(const ValueType* buf, std::size_t len, std::error_code& ec) +{ + auto res = ice_->send(compId_, buf, len); + if (res < 0) { + ec.assign(errno, std::generic_category()); + return 0; + } + ec.clear(); + return res; +} + +std::size_t +IceSocketTransport::read(ValueType* buf, std::size_t len, std::error_code& ec) +{ + auto res = ice_->recv(compId_, buf, len); + if (res < 0) { + ec.assign(errno, std::generic_category()); + return 0; + } + ec.clear(); + return res; +} + +//============================================================================== + void IceSocket::close() { diff --git a/src/ringdht/sips_transport_ice.cpp b/src/ringdht/sips_transport_ice.cpp index a86b67e82241a2d0192ec90d056bd614938f0cda..7142b4471d3061477db0482171647914fca7ddb8 100644 --- a/src/ringdht/sips_transport_ice.cpp +++ b/src/ringdht/sips_transport_ice.cpp @@ -21,7 +21,9 @@ #include "sips_transport_ice.h" +#include "ice_socket.h" #include "ice_transport.h" + #include "manager.h" #include "sip/sip_utils.h" #include "logger.h" @@ -38,6 +40,7 @@ #include <pj/lock.h> #include <algorithm> +#include <system_error> #include <cstring> // std::memset namespace ring { namespace tls { @@ -233,14 +236,16 @@ SipsIceTransport::SipsIceTransport(pjsip_endpoint* endpt, std::memset(&localCertInfo_, 0, sizeof(pj_ssl_cert_info)); std::memset(&remoteCertInfo_, 0, sizeof(pj_ssl_cert_info)); + iceSocket_ = std::make_unique<IceSocketTransport>(ice_, comp_id); + TlsSession::TlsSessionCallbacks cbs = { /*.onStateChange = */[this](TlsSessionState state){ onTlsStateChange(state); }, /*.onRxData = */[this](std::vector<uint8_t>&& buf){ onRxData(std::move(buf)); }, /*.onCertificatesUpdate = */[this](const gnutls_datum_t* l, const gnutls_datum_t* r, - unsigned int n){ onCertificatesUpdate(l, r, n); }, + unsigned int n){ onCertificatesUpdate(l, r, n); }, /*.verifyCertificate = */[this](gnutls_session_t session){ return verifyCertificate(session); } }; - tls_.reset(new TlsSession(ice, comp_id, param, cbs)); + tls_ = std::make_unique<TlsSession>(*iceSocket_, param, cbs); if (pjsip_transport_register(base.tpmgr, &base) != PJ_SUCCESS) throw std::runtime_error("Can't register PJSIP transport."); @@ -323,18 +328,19 @@ SipsIceTransport::handleEvents() pj_status_t status; if (!fatal) { const std::size_t size = tdata->buf.cur - tdata->buf.start; - auto ret = tls_->send(tdata->buf.start, size); - if (gnutls_error_is_fatal(ret)) { - RING_ERR("[TLS] fatal error during sending: %s", gnutls_strerror(ret)); - tls_->shutdown(); - fatal = true; + std::error_code ec; + status = tls_->write(reinterpret_cast<const uint8_t*>(tdata->buf.start), size, ec); + if (ec) { + status = tls_status_from_err(ec.value()); + if (gnutls_error_is_fatal(ec.value())) { + RING_ERR("[TLS] fatal error during sending: %s", gnutls_strerror(ec.value())); + tls_->shutdown(); + fatal = true; + } } - if (ret < 0) - status = tls_status_from_err(ret); - else - status = ret; - } else + } else { status = -PJ_RETURN_OS_ERROR(OSERR_ENOTCONN); + } tdata->op_key.tdata = nullptr; if (tdata->op_key.callback) @@ -501,7 +507,7 @@ SipsIceTransport::getInfo(pj_ssl_sock_info* info, bool established) if (established) { // Cipher Suite Id std::array<uint8_t, 2> cs_id; - if (auto cipher_name = tls_->getCurrentCipherSuiteId(cs_id)) { + if (auto cipher_name = tls_->currentCipherSuiteId(cs_id)) { info->cipher = static_cast<pj_ssl_cipher>((cs_id[0] << 8) | cs_id[1]); RING_DBG("[TLS] using cipher %s (0x%02X%02X)", cipher_name, cs_id[0], cs_id[1]); } else @@ -673,14 +679,15 @@ SipsIceTransport::send(pjsip_tx_data* tdata, const pj_sockaddr_t* rem_addr, const std::size_t size = tdata->buf.cur - tdata->buf.start; std::unique_lock<std::mutex> lk {txMutex_}; if (syncTx_ and txQueue_.empty()) { - auto ret = tls_->send(tdata->buf.start, size); + std::error_code ec; + tls_->write(reinterpret_cast<const uint8_t*>(tdata->buf.start), size, ec); lk.unlock(); // Shutdown on fatal error, else ignore it - if (gnutls_error_is_fatal(ret)) { - RING_ERR("[TLS] fatal error during sending: %s", gnutls_strerror(ret)); + if (ec and gnutls_error_is_fatal(ec.value())) { + RING_ERR("[TLS] fatal error during sending: %s", gnutls_strerror(ec.value())); tls_->shutdown(); - return tls_status_from_err(ret); + return tls_status_from_err(ec.value()); } return PJ_SUCCESS; @@ -697,7 +704,7 @@ SipsIceTransport::send(pjsip_tx_data* tdata, const pj_sockaddr_t* rem_addr, uint16_t SipsIceTransport::getTlsSessionMtu() { - return tls_->getMtu(); + return tls_->maxPayload(); } }} // namespace ring::tls diff --git a/src/ringdht/sips_transport_ice.h b/src/ringdht/sips_transport_ice.h index 48cbcb3e02538eff095f54ed5462ed529784b5e5..9617ad46d77d73e96134e37faaf9b87be61e9af8 100644 --- a/src/ringdht/sips_transport_ice.h +++ b/src/ringdht/sips_transport_ice.h @@ -43,6 +43,7 @@ namespace ring { class IceTransport; +class IceSocketTransport; } // namespace ring namespace ring { namespace tls { @@ -80,7 +81,7 @@ struct SipsIceTransport private: NON_COPYABLE(SipsIceTransport); - const std::shared_ptr<IceTransport> ice_; + std::shared_ptr<IceTransport> ice_; const int comp_id_; const std::function<int(unsigned, const gnutls_datum_t*, unsigned)> certCheck_; IpAddr local_ {}; @@ -109,6 +110,7 @@ private: decltype(PJSIP_TP_STATE_DISCONNECTED) state; }; + std::unique_ptr<IceSocketTransport> iceSocket_; std::unique_ptr<TlsSession> tls_; std::mutex txMutex_ {}; diff --git a/src/security/tls_session.cpp b/src/security/tls_session.cpp index 0ab03f56e40f6771fb3946085d03ee6032493293..e026b73012680a971183e4131da04af46e915b52 100644 --- a/src/security/tls_session.cpp +++ b/src/security/tls_session.cpp @@ -24,8 +24,6 @@ #include "tls_session.h" #include "threadloop.h" -#include "ice_socket.h" -#include "ice_transport.h" #include "logger.h" #include "noncopyable.h" #include "compiler_intrinsics.h" @@ -54,27 +52,23 @@ namespace ring { namespace tls { -static constexpr const char* TLS_CERT_PRIORITY_STRING {"SECURE192:-VERS-TLS-ALL:+VERS-DTLS-ALL:-RSA:%SERVER_PRECEDENCE:%SAFE_RENEGOTIATION"}; -static constexpr const char* TLS_FULL_PRIORITY_STRING {"SECURE192:-KX-ALL:+ANON-ECDH:+ANON-DH:+SECURE192:-VERS-TLS-ALL:+VERS-DTLS-ALL:-RSA:%SERVER_PRECEDENCE:%SAFE_RENEGOTIATION"}; +static constexpr const char* DTLS_CERT_PRIORITY_STRING {"SECURE192:-VERS-TLS-ALL:+VERS-DTLS-ALL:-RSA:%SERVER_PRECEDENCE:%SAFE_RENEGOTIATION"}; +static constexpr const char* DTLS_FULL_PRIORITY_STRING {"SECURE192:-KX-ALL:+ANON-ECDH:+ANON-DH:+SECURE192:-VERS-TLS-ALL:+VERS-DTLS-ALL:-RSA:%SERVER_PRECEDENCE:%SAFE_RENEGOTIATION"}; +static constexpr const char* TLS_CERT_PRIORITY_STRING {"SECURE192:-RSA:%SERVER_PRECEDENCE:%SAFE_RENEGOTIATION"}; +static constexpr const char* TLS_FULL_PRIORITY_STRING {"SECURE192:-KX-ALL:+ANON-ECDH:+ANON-DH:+SECURE192:-RSA:%SERVER_PRECEDENCE:%SAFE_RENEGOTIATION"}; static constexpr uint16_t INPUT_BUFFER_SIZE {16*1024}; // to be coherent with the maximum size advised in path mtu discovery static constexpr std::size_t INPUT_MAX_SIZE {1000}; // Maximum number of packets to store before dropping (pkt size = DTLS_MTU) static constexpr ssize_t FLOOD_THRESHOLD {4*1024}; static constexpr auto FLOOD_PAUSE = std::chrono::milliseconds(100); // Time to wait after an invalid cookie packet (anti flood attack) static constexpr auto DTLS_RETRANSMIT_TIMEOUT = std::chrono::milliseconds(1000); // Delay between two handshake request on DTLS static constexpr auto COOKIE_TIMEOUT = std::chrono::seconds(10); // Time to wait for a cookie packet from client -static constexpr uint8_t UDP_HEADER_SIZE = 8; // Size in bytes of UDP packet header +static constexpr int MIN_MTU {512 - 20 - 8}; // minimal payload size of a DTLS packet carried by an IPv4 packet static constexpr uint8_t HEARTBEAT_TRIES = 1; // Number of tries at each heartbeat ping send static constexpr auto HEARTBEAT_RETRANS_TIMEOUT = std::chrono::milliseconds(700); // gnutls heartbeat retransmission timeout for each ping (in milliseconds) static constexpr auto HEARTBEAT_TOTAL_TIMEOUT = HEARTBEAT_RETRANS_TIMEOUT * HEARTBEAT_TRIES; // gnutls heartbeat time limit for heartbeat procedure (in milliseconds) static constexpr int MISS_ORDERING_LIMIT = 32; // maximal accepted distance of out-of-order packet (note: must be a signed type) static constexpr auto RX_OOO_TIMEOUT = std::chrono::milliseconds(1500); -// mtus array to test, do not add mtu over the interface MTU, this will result in false result due to packet fragmentation. -// also do not set over 16000 this will result in a gnutls error (unexpected packet size) -// neither under MIN_MTU because it makes no sense and could result in underflow of certain variables. -// Put mtus values in ascending order in the array to avoid sorting -static constexpr std::array<uint16_t, MTUS_TO_TEST> MTUS = {MIN_MTU, 800, 1280}; - // Helper to cast any duration into an integer number of milliseconds template <class Rep, class Period> static std::chrono::milliseconds::rep @@ -176,22 +170,20 @@ public: using StateHandler = std::function<TlsSessionState(TlsSessionState state)>; // Constants (ctor init.) - const std::unique_ptr<IceSocket> socket_; const bool isServer_; const TlsParams params_; const TlsSessionCallbacks callbacks_; const bool anonymous_; - TlsSessionImpl(const std::shared_ptr<IceTransport>& ice, - int ice_comp_id, - const TlsParams& params, - const TlsSessionCallbacks& cbs, - bool anonymous); + TlsSessionImpl(SocketType& transport, const TlsParams& params, + const TlsSessionCallbacks& cbs, bool anonymous); ~TlsSessionImpl(); const char* typeName() const; + SocketType& transport_; + // State machine TlsSessionState handleStateSetup(TlsSessionState state); TlsSessionState handleStateCookie(TlsSessionState state); @@ -201,31 +193,30 @@ public: TlsSessionState handleStateShutdown(TlsSessionState state); std::map<TlsSessionState, StateHandler> fsmHandlers_ {}; std::atomic<TlsSessionState> state_ {TlsSessionState::SETUP}; - std::atomic<unsigned int> maxPayload_; + std::atomic<int> maxPayload_ {-1}; // IO GnuTLS <-> ICE - std::mutex txMutex_ {}; std::mutex rxMutex_ {}; std::condition_variable rxCv_ {}; - std::list<std::vector<uint8_t>> rxQueue_ {}; + std::list<std::vector<ValueType>> rxQueue_ {}; std::mutex reorderBufMutex_; bool flushProcessing_ {false}; ///< protect against recursive call to flushRxQueue - std::vector<uint8_t> rawPktBuf_; ///< gnutls incoming packet buffer + std::vector<ValueType> rawPktBuf_; ///< gnutls incoming packet buffer uint64_t baseSeq_ {0}; ///< sequence number of first application data packet received uint64_t lastRxSeq_ {0}; ///< last received and valid packet sequence number uint64_t gapOffset_ {0}; ///< offset of first byte not received yet clock::time_point lastReadTime_; - std::map<uint64_t, std::vector<uint8_t>> reorderBuffer_ {}; + std::map<uint64_t, std::vector<ValueType>> reorderBuffer_ {}; - ssize_t send_(const uint8_t* tx_data, std::size_t tx_size); + std::size_t send(const ValueType*, std::size_t, std::error_code&); ssize_t sendRaw(const void*, size_t); ssize_t sendRawVec(const giovec_t*, int); ssize_t recvRaw(void*, size_t); int waitForRawData(unsigned); bool initFromRecordState(int offset=0); - void handleDataPacket(std::vector<uint8_t>&&, uint64_t); + void handleDataPacket(std::vector<ValueType>&&, uint64_t); void flushRxQueue(); // Statistics @@ -257,24 +248,22 @@ public: void cleanup(); // Path mtu discovery - std::array<uint16_t, MTUS_TO_TEST>::const_iterator mtuProbe_; - unsigned hbPingRecved_ {0}; + std::array<int, 3> MTUS_; + int mtuProbe_; + int hbPingRecved_ {0}; bool pmtudOver_ {false}; - uint8_t transportOverhead_; void pathMtuHeartbeat(); }; -TlsSession::TlsSessionImpl::TlsSessionImpl(const std::shared_ptr<IceTransport>& ice, - int ice_comp_id, +TlsSession::TlsSessionImpl::TlsSessionImpl(SocketType& transport, const TlsParams& params, const TlsSessionCallbacks& cbs, bool anonymous) - : socket_(new IceSocket(ice, ice_comp_id)) - , isServer_(not ice->isInitiator()) + : isServer_(not transport.isInitiator()) , params_(params) , callbacks_(cbs) , anonymous_(anonymous) - , maxPayload_(INPUT_BUFFER_SIZE) + , transport_ { transport } , cacred_(nullptr) , sacred_(nullptr) , xcred_(nullptr) @@ -282,18 +271,20 @@ TlsSession::TlsSessionImpl::TlsSessionImpl(const std::shared_ptr<IceTransport>& [this] { process(); }, [this] { cleanup(); }) { - socket_->setOnRecv([this](uint8_t* buf, size_t len) { - std::lock_guard<std::mutex> lk {rxMutex_}; - if (rxQueue_.size() == INPUT_MAX_SIZE) { - rxQueue_.pop_front(); // drop oldest packet if input buffer is full - ++stRxRawPacketDropCnt_; - } - rxQueue_.emplace_back(buf, buf+len); - ++stRxRawPacketCnt_; - stRxRawBytesCnt_ += len; - rxCv_.notify_one(); - return len; - }); + if (not transport_.isReliable()) { + transport_.setOnRecv([this](const ValueType* buf, size_t len) { + std::lock_guard<std::mutex> lk {rxMutex_}; + if (rxQueue_.size() == INPUT_MAX_SIZE) { + rxQueue_.pop_front(); // drop oldest packet if input buffer is full + ++stRxRawPacketDropCnt_; + } + rxQueue_.emplace_back(buf, buf+len); + ++stRxRawPacketCnt_; + stRxRawBytesCnt_ += len; + rxCv_.notify_one(); + return len; + }); + } Manager::instance().registerEventHandler((uintptr_t)this, [this]{ flushRxQueue(); }); @@ -304,9 +295,8 @@ TlsSession::TlsSessionImpl::TlsSessionImpl(const std::shared_ptr<IceTransport>& TlsSession::TlsSessionImpl::~TlsSessionImpl() { thread_.join(); - - socket_->setOnRecv(nullptr); - + if (not transport_.isReliable()) + transport_.setOnRecv(nullptr); Manager::instance().unregisterEventHandler((uintptr_t)this); } @@ -327,9 +317,15 @@ TlsSession::TlsSessionImpl::dump_io_stats() const TlsSessionState TlsSession::TlsSessionImpl::setupClient() { - auto ret = gnutls_init(&session_, GNUTLS_CLIENT | GNUTLS_DATAGRAM); - RING_WARN("[TLS] set heartbeat reception for retrocompatibility check on server"); - gnutls_heartbeat_enable(session_,GNUTLS_HB_PEER_ALLOWED_TO_SEND); + int ret; + + if (not transport_.isReliable()) { + ret = gnutls_init(&session_, GNUTLS_CLIENT | GNUTLS_DATAGRAM); + RING_DBG("[TLS] set heartbeat reception for retrocompatibility check on server"); + gnutls_heartbeat_enable(session_,GNUTLS_HB_PEER_ALLOWED_TO_SEND); + } else { + ret = gnutls_init(&session_, GNUTLS_CLIENT); + } if (ret != GNUTLS_E_SUCCESS) { RING_ERR("[TLS] session init failed: %s", gnutls_strerror(ret)); @@ -346,8 +342,30 @@ TlsSession::TlsSessionImpl::setupClient() TlsSessionState TlsSession::TlsSessionImpl::setupServer() { - gnutls_key_generate(&cookie_key_, GNUTLS_COOKIE_KEY_SIZE); - return TlsSessionState::COOKIE; + int ret; + + if (not transport_.isReliable()) { + ret = gnutls_init(&session_, GNUTLS_SERVER | GNUTLS_DATAGRAM); + + RING_DBG("[TLS] set heartbeat reception"); + gnutls_heartbeat_enable(session_, GNUTLS_HB_PEER_ALLOWED_TO_SEND); + + gnutls_dtls_prestate_set(session_, &prestate_); + } else { + ret = gnutls_init(&session_, GNUTLS_SERVER); + } + + if (ret != GNUTLS_E_SUCCESS) { + RING_ERR("[TLS] session init failed: %s", gnutls_strerror(ret)); + return TlsSessionState::SHUTDOWN; + } + + gnutls_certificate_server_set_request(session_, GNUTLS_CERT_REQUIRE); + + if (not commonSessionInit()) + return TlsSessionState::SHUTDOWN; + + return TlsSessionState::HANDSHAKE; } void @@ -440,7 +458,9 @@ TlsSession::TlsSessionImpl::commonSessionInit() if (anonymous_) { // Force anonymous connection, see handleStateHandshake how we handle failures - ret = gnutls_priority_set_direct(session_, TLS_FULL_PRIORITY_STRING, nullptr); + ret = gnutls_priority_set_direct(session_, + transport_.isReliable() ? TLS_FULL_PRIORITY_STRING : DTLS_FULL_PRIORITY_STRING, + nullptr); if (ret != GNUTLS_E_SUCCESS) { RING_ERR("[TLS] TLS priority set failed: %s", gnutls_strerror(ret)); return false; @@ -458,7 +478,9 @@ TlsSession::TlsSessionImpl::commonSessionInit() } } else { // Use a classic non-encrypted CERTIFICATE exchange method (less anonymous) - ret = gnutls_priority_set_direct(session_, TLS_CERT_PRIORITY_STRING, nullptr); + ret = gnutls_priority_set_direct(session_, + transport_.isReliable() ? TLS_CERT_PRIORITY_STRING : DTLS_CERT_PRIORITY_STRING, + nullptr); if (ret != GNUTLS_E_SUCCESS) { RING_ERR("[TLS] TLS priority set failed: %s", gnutls_strerror(ret)); return false; @@ -473,13 +495,15 @@ TlsSession::TlsSessionImpl::commonSessionInit() } gnutls_certificate_send_x509_rdn_sequence(session_, 0); - // DTLS hanshake timeouts - auto re_tx_timeout = duration2ms(DTLS_RETRANSMIT_TIMEOUT); - gnutls_dtls_set_timeouts(session_, re_tx_timeout, - std::max(duration2ms(params_.timeout), re_tx_timeout)); + if (not transport_.isReliable()) { + // DTLS hanshake timeouts + auto re_tx_timeout = duration2ms(DTLS_RETRANSMIT_TIMEOUT); + gnutls_dtls_set_timeouts(session_, re_tx_timeout, + std::max(duration2ms(params_.timeout), re_tx_timeout)); - // DTLS maximum payload size (raw packet relative) - gnutls_dtls_set_mtu(session_, DTLS_MTU); + // gnutls DTLS mtu = maximum payload size given by transport + gnutls_dtls_set_mtu(session_, transport_.maxPayload()); + } // Stuff for transport callbacks gnutls_session_set_ptr(session_, this); @@ -504,17 +528,27 @@ TlsSession::TlsSessionImpl::commonSessionInit() return true; } -ssize_t -TlsSession::TlsSessionImpl::send_(const uint8_t* tx_data, std::size_t tx_size) +std::size_t +TlsSession::TlsSessionImpl::send(const ValueType* tx_data, std::size_t tx_size, std::error_code& ec) { - std::size_t max_tx_sz = gnutls_dtls_get_data_mtu(session_); + if (state_ != TlsSessionState::ESTABLISHED) { + ec = std::error_code(GNUTLS_E_INVALID_SESSION, std::system_category()); + return 0; + } - // Split user data into MTU-suitable chunck - size_t total_written = 0; + std::size_t total_written = 0; + std::size_t max_tx_sz; + + if (transport_.isReliable()) + max_tx_sz = tx_size; + else + max_tx_sz = gnutls_dtls_get_data_mtu(session_); + + // Split incoming data into chunck suitable for the underlying transport while (total_written < tx_size) { auto chunck_sz = std::min(max_tx_sz, tx_size - total_written); - ssize_t nwritten; auto data_seq = tx_data + total_written; + ssize_t nwritten; do { nwritten = gnutls_record_send(session_, data_seq, chunck_sz); } while (nwritten == GNUTLS_E_INTERRUPTED or nwritten == GNUTLS_E_AGAIN); @@ -523,13 +557,16 @@ TlsSession::TlsSessionImpl::send_(const uint8_t* tx_data, std::size_t tx_size) * state has not changed, so we have to ask for more data first. * We will just try again later, although this should never happen. */ - RING_WARN("[TLS] send failed (only %zu bytes sent): %s", total_written, - gnutls_strerror(nwritten)); - return nwritten; + RING_ERR() << "[TLS] send failed (only " << total_written << " bytes sent): " + << gnutls_strerror(nwritten); + ec = std::error_code(nwritten, std::system_category()); + return 0; } total_written += nwritten; } + + ec.clear(); return total_written; } @@ -538,16 +575,18 @@ TlsSession::TlsSessionImpl::send_(const uint8_t* tx_data, std::size_t tx_size) ssize_t TlsSession::TlsSessionImpl::sendRaw(const void* buf, size_t size) { - auto ret = socket_->send(reinterpret_cast<const unsigned char*>(buf), size); - if (ret > 0) { + std::error_code ec; + auto n = transport_.write(reinterpret_cast<const ValueType*>(buf), size, ec); + if (!ec) { // log only on success ++stTxRawPacketCnt_; - stTxRawBytesCnt_ += size; - return ret; + stTxRawBytesCnt_ += n; + return n; } // Must be called to pass errno value to GnuTLS on Windows (cf. GnuTLS doc) - gnutls_transport_set_errno(session_, errno); + gnutls_transport_set_errno(session_, ec.value()); + RING_ERR() << "[TLS] transport failure on tx: errno = " << ec.value(); return -1; } @@ -574,39 +613,64 @@ TlsSession::TlsSessionImpl::sendRawVec(const giovec_t* iov, int iovcnt) ssize_t TlsSession::TlsSessionImpl::recvRaw(void* buf, size_t size) { - std::lock_guard<std::mutex> lk {rxMutex_}; - if (rxQueue_.empty()) { - gnutls_transport_set_errno(session_, EAGAIN); + if (transport_.isReliable()) { + std::error_code ec; + auto count = transport_.read(reinterpret_cast<ValueType*>(buf), size, ec); + if (!ec) + return count; + gnutls_transport_set_errno(session_, ec.value()); return -1; } const auto& pkt = rxQueue_.front(); const std::size_t count = std::min(pkt.size(), size); - std::copy_n(pkt.begin(), count, reinterpret_cast<uint8_t*>(buf)); + std::copy_n(pkt.begin(), count, reinterpret_cast<ValueType*>(buf)); rxQueue_.pop_front(); return count; } // Called by GNUTLS to wait for encrypted packet from low-level transport. // 'timeout' is in milliseconds. -// Should return 0 on connection termination, -// a positive number indicating the number of bytes received, -// and -1 on error. +// Should return 0 on timeout, a positive number if data are available for read, or -1 on error. int TlsSession::TlsSessionImpl::waitForRawData(unsigned timeout) { - std::unique_lock<std::mutex> lk {rxMutex_}; - if (not rxCv_.wait_for(lk, std::chrono::milliseconds(timeout), - [this]{ return !rxQueue_.empty() or state_ == TlsSessionState::SHUTDOWN; })) - return 0; + if (transport_.isReliable()) { + if (not transport_.waitForData(timeout)) { + // shutdown? + if (state_ == TlsSessionState::SHUTDOWN) { + gnutls_transport_set_errno(session_, EINTR); + return -1; + } + return 0; + } + return 1; + } - // shutdown? + // non-reliable uses callback installed with setOnRecv() + std::unique_lock<std::mutex> lk {rxMutex_}; + rxCv_.wait(lk, [this]{ return !rxQueue_.empty() or state_ == TlsSessionState::SHUTDOWN; }); if (state_ == TlsSessionState::SHUTDOWN) { gnutls_transport_set_errno(session_, EINTR); return -1; } + return 1; +} + +bool +TlsSession::TlsSessionImpl::initFromRecordState(int offset) +{ + std::array<uint8_t, 8> seq; + if (gnutls_record_get_state(session_, 1, nullptr, nullptr, nullptr, &seq[0]) != GNUTLS_E_SUCCESS) { + RING_ERR("[TLS] Fatal-error Unable to read initial state"); + return false; + } - return rxQueue_.front().size(); + baseSeq_ = array2uint(seq) + offset; + gapOffset_ = baseSeq_; + lastRxSeq_ = baseSeq_ - 1; + RING_DBG("[TLS] Initial sequence number: %lx", baseSeq_); + return true; } bool @@ -629,8 +693,10 @@ TlsSession::TlsSessionImpl::cleanup() state_ = TlsSessionState::SHUTDOWN; // be sure to block any user operations if (session_) { - // DTLS: not use GNUTLS_SHUT_RDWR to not wait for a peer answer - gnutls_bye(session_, GNUTLS_SHUT_WR); + if (transport_.isReliable()) + gnutls_bye(session_, GNUTLS_SHUT_RDWR); + else + gnutls_bye(session_, GNUTLS_SHUT_WR); // not wait for a peer answer gnutls_deinit(session_); session_ = nullptr; } @@ -642,7 +708,7 @@ TlsSession::TlsSessionImpl::cleanup() TlsSessionState TlsSession::TlsSessionImpl::handleStateSetup(UNUSED TlsSessionState state) { - RING_DBG("[TLS] Start %s DTLS session", typeName()); + RING_DBG("[TLS] Start %s session", typeName()); try { if (anonymous_) @@ -653,10 +719,15 @@ TlsSession::TlsSessionImpl::handleStateSetup(UNUSED TlsSessionState state) return TlsSessionState::SHUTDOWN; } - if (isServer_) - return setupServer(); - else + if (not isServer_) return setupClient(); + + // Extra step for DTLS-like transports + if (not transport_.isReliable()) { + gnutls_key_generate(&cookie_key_, GNUTLS_COOKIE_KEY_SIZE); + return TlsSessionState::COOKIE; + } + return setupServer(); } TlsSessionState @@ -723,22 +794,7 @@ TlsSession::TlsSessionImpl::handleStateCookie(TlsSessionState state) RING_DBG("[TLS] cookie ok"); - ret = gnutls_init(&session_, GNUTLS_SERVER | GNUTLS_DATAGRAM); - RING_WARN("[TLS] set heartbeat reception"); - gnutls_heartbeat_enable(session_,GNUTLS_HB_PEER_ALLOWED_TO_SEND); - - if (ret != GNUTLS_E_SUCCESS) { - RING_ERR("[TLS] session init failed: %s", gnutls_strerror(ret)); - return TlsSessionState::SHUTDOWN; - } - - gnutls_certificate_server_set_request(session_, GNUTLS_CERT_REQUIRE); - gnutls_dtls_prestate_set(session_, &prestate_); - - if (not commonSessionInit()) - return TlsSessionState::SHUTDOWN; - - return TlsSessionState::HANDSHAKE; + return setupServer(); } TlsSessionState @@ -769,7 +825,7 @@ TlsSession::TlsSessionImpl::handleStateHandshake(TlsSessionState state) } auto desc = gnutls_session_get_desc(session_); - RING_WARN("[TLS] session established: %s", desc); + RING_DBG("[TLS] session established: %s", desc); gnutls_free(desc); // Anonymous connection? rehandshake immediatly with certificate authentification forced @@ -778,7 +834,9 @@ TlsSession::TlsSessionImpl::handleStateHandshake(TlsSessionState state) RING_DBG("[TLS] renogotiate with certificate authentification"); // Re-setup TLS algorithms priority list with only certificate based cipher suites - ret = gnutls_priority_set_direct(session_, TLS_CERT_PRIORITY_STRING, nullptr); + ret = gnutls_priority_set_direct(session_, + transport_.isReliable() ? TLS_CERT_PRIORITY_STRING : DTLS_CERT_PRIORITY_STRING, + nullptr); if (ret != GNUTLS_E_SUCCESS) { RING_ERR("[TLS] session TLS cert-only priority set failed: %s", gnutls_strerror(ret)); return TlsSessionState::SHUTDOWN; @@ -807,52 +865,32 @@ TlsSession::TlsSessionImpl::handleStateHandshake(TlsSessionState state) callbacks_.onCertificatesUpdate(local, remote, remote_count); } - return TlsSessionState::MTU_DISCOVERY; -} - -bool -TlsSession::TlsSessionImpl::initFromRecordState(int offset) -{ - std::array<uint8_t, 8> seq; - if (gnutls_record_get_state(session_, 1, nullptr, nullptr, nullptr, &seq[0]) != GNUTLS_E_SUCCESS) { - RING_ERR("[TLS] Fatal-error Unable to read initial state"); - return false; - } - - baseSeq_ = array2uint(seq) + offset; - gapOffset_ = baseSeq_; - lastRxSeq_ = baseSeq_ - 1; - RING_DBG("[TLS] Initial sequence number: %lx", baseSeq_); - return true; + return transport_.isReliable() ? TlsSessionState::ESTABLISHED : TlsSessionState::MTU_DISCOVERY; } TlsSessionState TlsSession::TlsSessionImpl::handleStateMtuDiscovery(UNUSED TlsSessionState state) { - // set dtls mtu to be over each and every mtus tested - gnutls_dtls_set_mtu(session_, MTUS.back()); - - // get transport overhead - transportOverhead_ = socket_->getTransportOverhead(); + mtuProbe_ = transport_.maxPayload(); + assert(mtuProbe_ >= MIN_MTU); + MTUS_ = {MIN_MTU, std::max((mtuProbe_ + MIN_MTU)/2, MIN_MTU), mtuProbe_}; // retrocompatibility check - if(gnutls_heartbeat_allowed(session_, GNUTLS_HB_LOCAL_ALLOWED_TO_SEND) == 1) { - if (!isServer_){ - RING_WARN("[TLS] HEARTBEAT PATH MTU DISCOVERY"); + if (gnutls_heartbeat_allowed(session_, GNUTLS_HB_LOCAL_ALLOWED_TO_SEND) == 1) { + if (!isServer_) { pathMtuHeartbeat(); pmtudOver_ = true; - RING_WARN("[TLS] HEARTBEAT PATH MTU DISCOVERY OVER"); } } else { - RING_ERR("[TLS] PEER HEARTBEAT DISABLED: setting minimal value to MTU @%d for retrocompatibility", DTLS_MTU); - gnutls_dtls_set_mtu(session_, DTLS_MTU); + RING_ERR() << "[TLS] PEER HEARTBEAT DISABLED: using transport MTU value " << mtuProbe_; pmtudOver_ = true; } + + gnutls_dtls_set_mtu(session_, mtuProbe_); maxPayload_ = gnutls_dtls_get_data_mtu(session_); - if (pmtudOver_) - RING_WARN("[TLS] maxPayload for dtls : %d B", maxPayload_.load()); if (pmtudOver_) { + RING_DBG() << "[TLS] maxPayload: " << maxPayload_.load(); if (!initFromRecordState()) return TlsSessionState::SHUTDOWN; } @@ -872,62 +910,52 @@ TlsSession::TlsSessionImpl::handleStateMtuDiscovery(UNUSED TlsSessionState state void TlsSession::TlsSessionImpl::pathMtuHeartbeat() { - int errno_send = 1; // non zero initialisation - auto tls_overhead = gnutls_record_overhead_size(session_); - RING_WARN("[TLS] tls session overhead : %lu", tls_overhead); - gnutls_heartbeat_set_timeouts(session_, HEARTBEAT_RETRANS_TIMEOUT.count(), HEARTBEAT_TOTAL_TIMEOUT.count()); - RING_DBG("[TLS] Heartbeat PMTUD : retransmission timeout set to: %ld ms", HEARTBEAT_RETRANS_TIMEOUT.count()); - - // server side: managing pong in state established - // client side: managing ping on heartbeat - uint16_t bytesToSend; - mtuProbe_ = MTUS.cbegin(); - RING_DBG("[TLS] Heartbeat PMTUD : client side"); - - while (mtuProbe_ != MTUS.cend()){ - bytesToSend = (*mtuProbe_) - 3 - tls_overhead - UDP_HEADER_SIZE - transportOverhead_; + RING_DBG() << "[TLS] PMTUD: starting probing with " << HEARTBEAT_RETRANS_TIMEOUT.count() + << "ms of retransmission timeout"; + + gnutls_heartbeat_set_timeouts(session_, + HEARTBEAT_RETRANS_TIMEOUT.count(), + HEARTBEAT_TOTAL_TIMEOUT.count()); + + int errno_send = GNUTLS_E_SUCCESS; + mtuProbe_ = MTUS_[0]; + for (auto mtu: MTUS_) { + gnutls_dtls_set_mtu(session_, mtu); + auto data_mtu = gnutls_dtls_get_data_mtu(session_); + RING_DBG() << "[TLS] PMTUD: mtu " << mtu + << ", payload " << data_mtu; + auto bytesToSend = data_mtu - 3; // want to know why -3? ask gnutls! do { - RING_DBG("[TLS] Heartbeat PMTUD : ping with mtu %d and effective payload %d", *mtuProbe_, bytesToSend); errno_send = gnutls_heartbeat_ping(session_, bytesToSend, HEARTBEAT_TRIES, GNUTLS_HEARTBEAT_WAIT); - RING_DBG("[TLS] Heartbeat PMTUD : ping sequence over with errno %d: %s", errno_send, - gnutls_strerror(errno_send)); } while (errno_send == GNUTLS_E_AGAIN || errno_send == GNUTLS_E_INTERRUPTED); - if (errno_send == GNUTLS_E_SUCCESS) { - ++mtuProbe_; - } else if (errno_send == GNUTLS_E_TIMEDOUT){ // timeout is considered as a packet loss, then the good mtu is the precedent. - if (mtuProbe_ == MTUS.cbegin()) { - RING_WARN("[TLS] Heartbeat PMTUD : no response on first ping, setting minimal MTU value @%d", MIN_MTU); - gnutls_dtls_set_mtu(session_, MIN_MTU - UDP_HEADER_SIZE - transportOverhead_); - - } else { - --mtuProbe_; - RING_DBG("[TLS] Heartbeat PMTUD : timed out: Path MTU found @ %d", *mtuProbe_); - gnutls_dtls_set_mtu(session_, *mtuProbe_ - UDP_HEADER_SIZE - transportOverhead_); - } - return; - } else { - RING_WARN("[TLS] Heartbeat PMTUD : client ping failed: error %d: %s", errno_send, - gnutls_strerror(errno_send)); - if (mtuProbe_ != MTUS.begin()) - --mtuProbe_; - gnutls_dtls_set_mtu(session_, *mtuProbe_ - UDP_HEADER_SIZE - transportOverhead_); - return; + if (errno_send != GNUTLS_E_SUCCESS) { + RING_DBG() << "[TLS] PMTUD: mtu " << mtu << " [FAILED]"; + break; } - } - - if (errno_send == GNUTLS_E_SUCCESS) { - RING_WARN("[TLS] Heartbeat PMTUD completed : reached test value %d", MTUS.back()); - --mtuProbe_; // for loop over, setting mtu to last valid mtu + mtuProbe_ = mtu; + RING_DBG() << "[TLS] PMTUD: mtu " << mtu << " [OK]"; } - gnutls_dtls_set_mtu(session_, *mtuProbe_ - UDP_HEADER_SIZE - transportOverhead_); - RING_WARN("[TLS] Heartbeat PMTUD : new mtu set to %d", *mtuProbe_); + if (errno_send == GNUTLS_E_TIMEDOUT) { // timeout is considered as a packet loss, then the good mtu is the precedent + if (mtuProbe_ == MTUS_[0]) { + RING_WARN() << "[TLS] PMTUD: no response on first ping, using minimal MTU value " + << mtuProbe_; + } else { + RING_WARN() << "[TLS] PMTUD: timed out, using last working mtu " + << mtuProbe_; + } + } else if (errno_send != GNUTLS_E_SUCCESS) { + RING_ERR() << "[TLS] PMTUD: failed with gnutls error '" + << gnutls_strerror(errno_send) << '\''; + } else { + RING_DBG() << "[TLS] PMTUD: reached maximal value"; + } } void -TlsSession::TlsSessionImpl::handleDataPacket(std::vector<uint8_t>&& buf, uint64_t pkt_seq) +TlsSession::TlsSessionImpl::handleDataPacket(std::vector<ValueType>&& buf, uint64_t pkt_seq) { // Check for a valid seq. num. delta int64_t seq_delta = pkt_seq - lastRxSeq_; @@ -936,14 +964,14 @@ TlsSession::TlsSessionImpl::handleDataPacket(std::vector<uint8_t>&& buf, uint64_ } else { // too old? if (seq_delta <= -MISS_ORDERING_LIMIT) { - RING_WARN("[dtls] drop old pkt: 0x%lx", pkt_seq); + RING_WARN("[TLS] drop old pkt: 0x%lx", pkt_seq); return; } // No duplicate check as DTLS prevents that for us (replay protection) // accept Out-Of-Order pkt - will be reordered by queue flush operation - RING_WARN("[dtls] OOO pkt: 0x%lx", pkt_seq); + RING_WARN("[TLS] OOO pkt: 0x%lx", pkt_seq); } { @@ -986,15 +1014,14 @@ TlsSession::TlsSessionImpl::flushRxQueue() auto item = std::begin(reorderBuffer_); auto next_offset = item->first; - auto first_offset = next_offset; // Wait for next continous packet until timeout if ((clock::now() - lastReadTime_) >= RX_OOO_TIMEOUT) { // OOO packet timeout - consider waited packets as lost if (auto lost = next_offset - gapOffset_) - RING_WARN("[dtls] %lu lost since 0x%lx", lost, gapOffset_); + RING_WARN("[TLS] %lu lost since 0x%lx", lost, gapOffset_); else - RING_WARN("[dtls] slow flush"); + RING_WARN("[TLS] slow flush"); } else if (next_offset != gapOffset_) return; @@ -1021,6 +1048,13 @@ TlsSession::TlsSessionImpl::flushRxQueue() TlsSessionState TlsSession::TlsSessionImpl::handleStateEstablished(TlsSessionState state) { + // Nothing to do in reliable mode, so just wait for state change + if (transport_.isReliable()) { + std::unique_lock<std::mutex> lk {rxMutex_}; + rxCv_.wait(lk, [this]{ return state_ != TlsSessionState::ESTABLISHED; }); + return state; + } + // block until rx packet or state change { std::unique_lock<std::mutex> lk {rxMutex_}; @@ -1035,18 +1069,13 @@ TlsSession::TlsSessionImpl::handleStateEstablished(TlsSessionState state) auto ret = gnutls_record_recv_seq(session_, rawPktBuf_.data(), rawPktBuf_.size(), &seq[0]); if (ret > 0) { + // Are we in PMTUD phase? if (!pmtudOver_) { - // This is the first application packet recieved after PMTUD - // This packet gives the final MTU. - if (hbPingRecved_ > 0) { - gnutls_dtls_set_mtu(session_, MTUS[hbPingRecved_ - 1] - UDP_HEADER_SIZE - transportOverhead_); - maxPayload_ = gnutls_dtls_get_data_mtu(session_); - } else { - gnutls_dtls_set_mtu(session_, MIN_MTU - UDP_HEADER_SIZE - transportOverhead_); - maxPayload_ = gnutls_dtls_get_data_mtu(session_); - } + mtuProbe_ = MTUS_[std::max(0, hbPingRecved_ - 1)]; + gnutls_dtls_set_mtu(session_, mtuProbe_); + maxPayload_ = gnutls_dtls_get_data_mtu(session_); pmtudOver_ = true; - RING_WARN("[TLS] maxPayload for dtls : %d B", maxPayload_.load()); + RING_DBG() << "[TLS] maxPayload: " << maxPayload_.load(); if (!initFromRecordState(-1)) return TlsSessionState::SHUTDOWN; @@ -1056,11 +1085,11 @@ TlsSession::TlsSessionImpl::handleStateEstablished(TlsSessionState state) handleDataPacket(std::move(rawPktBuf_), array2uint(seq)); // no state change } else if (ret == GNUTLS_E_HEARTBEAT_PING_RECEIVED) { - RING_DBG("[TLS] Heartbeat PMTUD : ping received sending pong"); + RING_DBG("[TLS] PMTUD: ping received sending pong"); auto errno_send = gnutls_heartbeat_pong(session_, 0); if (errno_send != GNUTLS_E_SUCCESS){ - RING_WARN("[TLS] Heartbeat PMTUD : failed on pong with error %d: %s", errno_send, + RING_ERR("[TLS] PMTUD: failed on pong with error %d: %s", errno_send, gnutls_strerror(errno_send)); } else { ++hbPingRecved_; @@ -1106,10 +1135,10 @@ TlsSession::TlsSessionImpl::process() //============================================================================== -TlsSession::TlsSession(const std::shared_ptr<IceTransport>& ice, int ice_comp_id, - const TlsParams& params, const TlsSessionCallbacks& cbs, bool anonymous) +TlsSession::TlsSession(SocketType& transport, const TlsParams& params, + const TlsSessionCallbacks& cbs, bool anonymous) - : pimpl_ { std::make_unique<TlsSessionImpl>(ice, ice_comp_id, params, cbs, anonymous) } + : pimpl_ { std::make_unique<TlsSessionImpl>(transport, params, cbs, anonymous) } {} TlsSession::~TlsSession() @@ -1117,34 +1146,28 @@ TlsSession::~TlsSession() shutdown(); } -const char* -TlsSession::typeName() const -{ - return pimpl_->typeName(); -} - bool -TlsSession::isServer() const +TlsSession::isInitiator() const { - return pimpl_->isServer_; + return !pimpl_->isServer_; } -unsigned int -TlsSession::getMaxPayload() const +bool +TlsSession::isReliable() const { - return pimpl_->maxPayload_; + return pimpl_->transport_.isReliable(); } -// Called by anyone to stop the connection and the FSM thread -void -TlsSession::shutdown() +int +TlsSession::maxPayload() const { - pimpl_->state_ = TlsSessionState::SHUTDOWN; - pimpl_->rxCv_.notify_one(); // unblock waiting FSM + if (pimpl_->state_ == TlsSessionState::SHUTDOWN) + throw std::runtime_error("Getting MTU from non-valid TLS session"); + return gnutls_dtls_get_data_mtu(pimpl_->session_); } const char* -TlsSession::getCurrentCipherSuiteId(std::array<uint8_t, 2>& cs_id) const +TlsSession::currentCipherSuiteId(std::array<uint8_t, 2>& cs_id) const { // get current session cipher suite info gnutls_cipher_algorithm_t cipher, s_cipher = gnutls_cipher_get(pimpl_->session_); @@ -1166,27 +1189,71 @@ TlsSession::getCurrentCipherSuiteId(std::array<uint8_t, 2>& cs_id) const return {}; } -ssize_t -TlsSession::send(const void* data, std::size_t size) +// Called by anyone to stop the connection and the FSM thread +void +TlsSession::shutdown() +{ + pimpl_->state_ = TlsSessionState::SHUTDOWN; + pimpl_->rxCv_.notify_one(); // unblock waiting FSM +} + +std::size_t +TlsSession::write(const ValueType* data, std::size_t size, std::error_code& ec) { - std::lock_guard<std::mutex> lk {pimpl_->txMutex_}; - if (pimpl_->state_ != TlsSessionState::ESTABLISHED) - return GNUTLS_E_INVALID_SESSION; - return pimpl_->send_(static_cast<const uint8_t*>(data), size); + if (pimpl_->state_ != TlsSessionState::ESTABLISHED) { + ec = std::make_error_code(std::errc::broken_pipe); + return 0; + } + + return pimpl_->send(data, size, ec); } -ssize_t -TlsSession::send(const std::vector<uint8_t>& vec) +std::size_t +TlsSession::read(ValueType* data, std::size_t size, std::error_code& ec) { - return send(vec.data(), vec.size()); + std::errc error; + + if (pimpl_->state_ != TlsSessionState::ESTABLISHED) { + ec = std::make_error_code(std::errc::broken_pipe); + return 0; + } + + while (true) { + auto ret = gnutls_record_recv(pimpl_->session_, data, size); + if (ret > 0) { + ec.clear(); + return ret; + } + + if (ret == 0) { + RING_DBG("[TLS] eof"); + shutdown(); + error = std::errc::broken_pipe; + break; + } else if (ret == GNUTLS_E_REHANDSHAKE) { + RING_DBG("[TLS] re-handshake"); + pimpl_->state_ = TlsSessionState::HANDSHAKE; + pimpl_->rxCv_.notify_one(); // unblock waiting FSM + } else if (gnutls_error_is_fatal(ret)) { + RING_ERR("[TLS] fatal error in recv: %s", gnutls_strerror(ret)); + shutdown(); + error = std::errc::io_error; + break; + } + } + + ec = std::make_error_code(error); + return 0; } -uint16_t -TlsSession::getMtu() +void +TlsSession::connect() { - if (pimpl_->state_ == TlsSessionState::SHUTDOWN) - throw std::runtime_error("Getting MTU from dead TLS session"); - return gnutls_dtls_get_mtu(pimpl_->session_); + TlsSessionState state; + do { + state = pimpl_->state_.load(); + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } while (state != TlsSessionState::ESTABLISHED and state != TlsSessionState::SHUTDOWN); } }} // namespace ring::tls diff --git a/src/security/tls_session.h b/src/security/tls_session.h index 7b6c48dce91eac1f60a3877e013bbd69ab7c20a9..58211598315f95bc5527aef880ea50189f5220e3 100644 --- a/src/security/tls_session.h +++ b/src/security/tls_session.h @@ -22,6 +22,7 @@ #pragma once #include "noncopyable.h" +#include "generic_io.h" #include <gnutls/gnutls.h> @@ -32,12 +33,6 @@ #include <chrono> #include <vector> #include <array> -#include <cstdint> - -namespace ring { -class IceTransport; -class IceSocket; -} // namespace ring namespace dht { namespace crypto { struct Certificate; @@ -48,20 +43,18 @@ namespace ring { namespace tls { class DhParams; -static constexpr uint8_t MTUS_TO_TEST = 3; //number of mtus to test in path mtu discovery. -static constexpr int DTLS_MTU {1232}; // (1280 from IPv6 minimum MTU - 40 IPv6 header - 8 UDP header) -static constexpr uint16_t MIN_MTU {512}; - -enum class TlsSessionState { +enum class TlsSessionState +{ SETUP, - COOKIE, // server only + COOKIE, // only used with non-initiator and non-reliable transport HANDSHAKE, - MTU_DISCOVERY, + MTU_DISCOVERY, // only used with non-reliable transport ESTABLISHED, SHUTDOWN }; -struct TlsParams { +struct TlsParams +{ // User CA list for session credentials std::string ca_list; @@ -83,20 +76,22 @@ struct TlsParams { unsigned cert_list_size)> cert_check; }; -/** - * TlsSession - * - * Manages a DTLS connection over an ICE transport. - * This implementation uses a Threadloop to manage IO from ICE and TLS states, - * so IO are asynchronous. - */ -class TlsSession { +/// TlsSession +/// +/// Manages a TLS/DTLS data transport overlayed on a given generic socket. +/// +/// \note API is not thread-safe. +/// +class TlsSession : public GenericSocket<uint8_t> +{ public: + using SocketType = GenericSocket<uint8_t>; using OnStateChangeFunc = std::function<void(TlsSessionState)>; using OnRxDataFunc = std::function<void(std::vector<uint8_t>&&)>; - using OnCertificatesUpdate = std::function<void(const gnutls_datum_t*, const gnutls_datum_t*, unsigned int)>; + using OnCertificatesUpdate = std::function<void(const gnutls_datum_t*, + const gnutls_datum_t*, + unsigned int)>; using VerifyCertificate = std::function<int(gnutls_session_t)>; - using TxDataCompleteFunc = std::function<void(std::size_t bytes_sent)>; // ===> WARNINGS <=== // Following callbacks are called into the FSM thread context @@ -108,38 +103,44 @@ public: VerifyCertificate verifyCertificate; }; - TlsSession(const std::shared_ptr<IceTransport>& ice, int ice_comp_id, const TlsParams& params, - const TlsSessionCallbacks& cbs, bool anonymous=true); + TlsSession(SocketType& transport, const TlsParams& params, const TlsSessionCallbacks& cbs, + bool anonymous=true); ~TlsSession(); - // Returns the TLS session type ('server' or 'client') - const char* typeName() const; + /// Return the name of current cipher. + /// Can be called by onStateChange callback when state == ESTABLISHED + /// to obtain the used cypher suite id. + const char* currentCipherSuiteId(std::array<uint8_t, 2>& cs_id) const; - bool isServer() const; - - // Request TLS thread to stop and quit. IO are not possible after that. + /// Request TLS thread to stop and quit. + /// \note IO operations return error after this call. void shutdown(); - // Return maximum application payload size in bytes - // Returned value must be checked and considered valid only if not 0 (session is initialized) - unsigned int getMaxPayload() const; + void setOnRecv(RecvCb&& cb) override { + (void)cb; + throw std::logic_error("TlsSession::setOnRecv not implemented"); + } + + /// Return true if the TLS session type is a server. + bool isInitiator() const override; + + bool isReliable() const override; + + int maxPayload() const override; - // Can be called by onStateChange callback when state == ESTABLISHED - // to obtain the used cypher suite id. - // Return the name of current cipher. - const char* getCurrentCipherSuiteId(std::array<uint8_t, 2>& cs_id) const; + void connect(); - // Asynchronous sending operation. on_send_complete will be called with a positive number - // for number of bytes sent, or negative for errors, or 0 in case of shutdown (end of session). - int async_send(const void* data, std::size_t size, TxDataCompleteFunc on_send_complete); - int async_send(std::vector<uint8_t>&& data, TxDataCompleteFunc on_send_complete); + /// Synchronous writing. + /// Return a positive number for number of bytes write, or 0 and \a ec set in case of error. + std::size_t write(const ValueType* data, std::size_t size, std::error_code& ec) override; - // Synchronous sending operation. Return negative number (gnutls error) or a positive number - // for bytes sent. - ssize_t send(const void* data, std::size_t size); - ssize_t send(const std::vector<uint8_t>& data); + /// Synchronous reading. + /// Return a positive number for number of bytes read, or 0 and \a ec set in case of error. + std::size_t read(ValueType* data, std::size_t size, std::error_code& ec) override; - uint16_t getMtu(); + bool waitForData(unsigned) const override { + throw std::logic_error("TlsSession::waitForData not implemented"); + } private: class TlsSessionImpl; diff --git a/src/sip/siptransport.cpp b/src/sip/siptransport.cpp index b7d8089e7179680b72ad93f8b9a2ec4b05d7d27f..09b9e941db36658deaed9c2b95fa2db2f8f5d848 100644 --- a/src/sip/siptransport.cpp +++ b/src/sip/siptransport.cpp @@ -182,7 +182,12 @@ SipTransport::getTlsMtu() auto tls_tr = reinterpret_cast<tls::SipsIceTransport::TransportData*>(transport_.get())->self; return tls_tr->getTlsSessionMtu(); } - return ring::tls::DTLS_MTU; + return 1232; /* Hardcoded yes (it's the IPv6 value). + * This method is broken by definition. + * A MTU should not be defined at this layer. + * And a correct value should come from the underlying transport itself, + * not from a constant... + */ } SipTransportBroker::SipTransportBroker(pjsip_endpoint *endpt, diff --git a/src/turn_transport.cpp b/src/turn_transport.cpp index fbe666c4d3f5b92f77605e24b80a694371b1e790..4d70fc3777f666c7f0ee3c2cb13c4ac537ac9fdb 100644 --- a/src/turn_transport.cpp +++ b/src/turn_transport.cpp @@ -29,7 +29,6 @@ #include <pjlib-util.h> #include <pjlib.h> -#include <stdexcept> #include <future> #include <atomic> #include <thread> @@ -46,6 +45,9 @@ namespace ring { using MutexGuard = std::lock_guard<std::mutex>; using MutexLock = std::unique_lock<std::mutex>; +inline +namespace { + enum class RelayState { NONE, @@ -82,31 +84,19 @@ public: cv_.notify_one(); } - void read(std::vector<char>& output) { + template <typename Duration> + bool wait(Duration timeout) { MutexLock lk {mutex_}; - cv_.wait(lk, [&, this]{ - stream_.read(&output[0], output.size()); - return stream_.gcount() > 0 or stop_; - }); - output.resize(stop_ ? 0 : stream_.gcount()); + return cv_.wait_for(lk, timeout, [this]{ return !stream_.eof(); }); } - std::vector<char> readline() { + std::size_t read(char* output, std::size_t size) { MutexLock lk {mutex_}; - std::vector<char> result(3000); - cv_.wait(lk, [&, this] { - if (stop_) - return true; - stream_.getline(&result[0], 3000); - if (stream_) { - result.resize(stream_.gcount()); - return result.size() > 0; - } - return false; + cv_.wait(lk, [&, this]{ + stream_.read(&output[0], size); + return stream_.gcount() > 0 or stop_; }); - if (stop_) - return {}; - return result; + return stop_ ? 0 : stream_.gcount(); } private: @@ -120,6 +110,31 @@ private: friend void operator <<(std::vector<char>&, PeerChannel&); }; +} + +//============================================================================== + +template <class Callable, class... Args> +inline void +PjsipCall(Callable& func, Args... args) +{ + auto status = func(args...); + if (status != PJ_SUCCESS) + throw sip_utils::PjsipFailure(status); +} + +template <class Callable, class... Args> +inline auto +PjsipCallReturn(const Callable& func, Args... args) -> decltype(func(args...)) +{ + auto res = func(args...); + if (!res) + throw sip_utils::PjsipFailure(); + return res; +} + +//============================================================================== + class TurnTransportPimpl { public: @@ -135,6 +150,7 @@ public: std::map<IpAddr, PeerChannel> peerChannels_; + GenericSocket<uint8_t>::RecvCb onRxDataCb; TurnTransportParams settings; pj_caching_pool poolCache {}; pj_pool_t* pool {nullptr}; @@ -193,7 +209,10 @@ TurnTransportPimpl::onRxData(const uint8_t* pkt, unsigned pkt_len, return; } - (channel_it->second) << std::string(reinterpret_cast<const char*>(pkt), pkt_len); + if (onRxDataCb) + onRxDataCb(pkt, pkt_len); + else + (channel_it->second) << std::string(reinterpret_cast<const char*>(pkt), pkt_len); } void @@ -226,45 +245,10 @@ TurnTransportPimpl::ioJob() const pj_time_val delay = {0, 10}; pj_ioqueue_poll(stunConfig.ioqueue, &delay); pj_timer_heap_poll(stunConfig.timer_heap, nullptr); - } -} - -class PjsipError final : public std::exception { -public: - PjsipError() = default; - explicit PjsipError(pj_status_t st) : std::exception() { - char err_msg[PJ_ERR_MSG_SIZE]; - pj_strerror(st, err_msg, sizeof(err_msg)); - what_msg_ += ": "; - what_msg_ += err_msg; } - const char* what() const noexcept override { - return what_msg_.c_str(); - }; -private: - std::string what_msg_ {"PJSIP api error"}; -}; - -template <class Callable, class... Args> -inline void -PjsipCall(Callable& func, Args... args) -{ - auto status = func(args...); - if (status != PJ_SUCCESS) - throw PjsipError(status); -} - -template <class Callable, class... Args> -inline auto -PjsipCallReturn(const Callable& func, Args... args) -> decltype(func(args...)) -{ - auto res = func(args...); - if (!res) - throw PjsipError(); - return res; } -//================================================================================================== +//============================================================================== TurnTransport::TurnTransport(const TurnTransportParams& params) : pimpl_ {new TurnTransportPimpl} @@ -354,6 +338,12 @@ TurnTransport::TurnTransport(const TurnTransportParams& params) TurnTransport::~TurnTransport() {} +bool +TurnTransport::isInitiator() const +{ + return !pimpl_->settings.server; +} + void TurnTransport::permitPeer(const IpAddr& addr) { @@ -363,6 +353,7 @@ TurnTransport::permitPeer(const IpAddr& addr) if (addr.getFamily() != pimpl_->peerRelayAddr.getFamily()) throw std::invalid_argument("mismatching peer address family"); + sip_utils::register_thread(); PjsipCall(pj_turn_sock_set_perm, pimpl_->relay, 1, addr.pjPtr(), 1); } @@ -400,7 +391,7 @@ TurnTransport::sendto(const IpAddr& peer, const char* const buffer, std::size_t reinterpret_cast<const pj_uint8_t*>(buffer), length, peer.pjPtr(), peer.getLength()); if (status != PJ_SUCCESS && status != PJ_EPENDING) - throw PjsipError(status); + throw sip_utils::PjsipFailure(status); return status == PJ_SUCCESS; } @@ -411,40 +402,84 @@ TurnTransport::sendto(const IpAddr& peer, const std::vector<char>& buffer) return sendto(peer, &buffer[0], buffer.size()); } -bool -TurnTransport::writelineto(const IpAddr& peer, const char* const buffer, std::size_t length) +std::size_t +TurnTransport::recvfrom(const IpAddr& peer, char* buffer, std::size_t size) { - if (sendto(peer, buffer, length)) - return sendto(peer, "\n", 1); - return false; + MutexLock lk {pimpl_->apiMutex_}; + auto& channel = pimpl_->peerChannels_.at(peer); + lk.unlock(); + return channel.read(buffer, size); } void TurnTransport::recvfrom(const IpAddr& peer, std::vector<char>& result) { - if (result.empty()) - throw std::runtime_error("TurnTransport::recvfrom() called with an empty output buffer"); + auto res = recvfrom(peer, result.data(), result.size()); + result.resize(res); +} +std::vector<IpAddr> +TurnTransport::peerAddresses() const +{ MutexLock lk {pimpl_->apiMutex_}; - auto& channel = pimpl_->peerChannels_.at(peer); - lk.unlock(); - channel.read(result); + return map_utils::extractKeys(pimpl_->peerChannels_); } -void -TurnTransport::readlinefrom(const IpAddr& peer, std::vector<char>& result) +bool +TurnTransport::waitForData(const IpAddr& peer, unsigned ms_timeout) const { MutexLock lk {pimpl_->apiMutex_}; auto& channel = pimpl_->peerChannels_.at(peer); lk.unlock(); - result = channel.readline(); + return channel.wait(std::chrono::milliseconds(ms_timeout)); } -std::vector<IpAddr> -TurnTransport::peerAddresses() const +//============================================================================== + +ConnectedTurnTransport::ConnectedTurnTransport(TurnTransport& turn, const IpAddr& peer) + : turn_ {turn} + , peer_ {peer} +{} + +bool +ConnectedTurnTransport::waitForData(unsigned ms_timeout) const { - MutexLock lk {pimpl_->apiMutex_}; - return map_utils::extractKeys(pimpl_->peerChannels_); + return turn_.waitForData(peer_, ms_timeout); +} + +std::size_t +ConnectedTurnTransport::write(const ValueType* buf, std::size_t size, std::error_code& ec) +{ + try { + turn_.sendto(peer_, reinterpret_cast<const char*>(buf), size); + } catch (const sip_utils::PjsipFailure& ex) { + ec = ex.code(); + return 0; + } + + ec.clear(); + return size; +} + +std::size_t +ConnectedTurnTransport::read(ValueType* buf, std::size_t size, std::error_code& ec) +{ + if (size > 0) { + try { + size = turn_.recvfrom(peer_, reinterpret_cast<char*>(buf), size); + } catch (const sip_utils::PjsipFailure& ex) { + ec = ex.code(); + return 0; + } + + if (size == 0) { + ec = std::make_error_code(std::errc::broken_pipe); + return 0; + } + } + + ec.clear(); + return size; } } // namespace ring diff --git a/src/turn_transport.h b/src/turn_transport.h index 4b59568eb3202a8f1b7aa95060ac75b506db0eae..f884af7a982ff893ffe02ddb1d6723cd6ed64263 100644 --- a/src/turn_transport.h +++ b/src/turn_transport.h @@ -21,17 +21,20 @@ #pragma once #include "ip_utils.h" +#include "generic_io.h" #include <string> #include <memory> #include <functional> #include <map> +#include <stdexcept> namespace ring { class TurnTransportPimpl; -struct TurnTransportParams { +struct TurnTransportParams +{ IpAddr server; // Plain Credentials @@ -46,7 +49,8 @@ struct TurnTransportParams { std::size_t maxPacketSize {4096}; ///< size of one "logical" packet }; -class TurnTransport { +class TurnTransport +{ public: /// Constructs a TurnTransport connected by TCP to given server. /// @@ -60,6 +64,8 @@ public: ~TurnTransport(); + bool isInitiator() const; + /// Wait for successful connection on the TURN server. /// /// TurnTransport constructor connects asynchronously on the TURN server. @@ -106,10 +112,9 @@ public: /// void recvfrom(const IpAddr& peer, std::vector<char>& data); - /// Work as recvfrom but stop on first '\n' character found. - /// If such character isn't found, stop at /a data vector size. + /// Works as recvfrom() vector version but accept a simple char array. /// - void readlinefrom(const IpAddr& peer, std::vector<char>& data); + std::size_t recvfrom(const IpAddr& peer, char* buffer, std::size_t size); /// Send data to given peer through the TURN tunnel. /// @@ -124,11 +129,9 @@ public: /// Works as sendto() vector version but accept a simple char array. /// - bool sendto(const IpAddr& peer, const char* const buffer, std::size_t length); + bool sendto(const IpAddr& peer, const char* const buffer, std::size_t size); - /// Works as sendto() char array but happend a '\n' character at the end of sent data. - /// - bool writelineto(const IpAddr& peer, const char* const buffer, std::size_t length); + bool waitForData(const IpAddr& peer, unsigned ms_timeout) const; public: // Move semantic only, not copiable @@ -140,4 +143,27 @@ private: std::unique_ptr<TurnTransportPimpl> pimpl_; }; +class ConnectedTurnTransport final : public GenericSocket<uint8_t> +{ +public: + using SocketType = GenericSocket<uint8_t>; + + ConnectedTurnTransport(TurnTransport& turn, const IpAddr& peer); + + bool isReliable() const override { return true; } + bool isInitiator() const override { return turn_.isInitiator(); } + int maxPayload() const override { return 3000; } + + bool waitForData(unsigned ms_timeout) const override; + std::size_t read(ValueType* buf, std::size_t length, std::error_code& ec) override; + std::size_t write(const ValueType* buf, std::size_t length, std::error_code& ec) override; + + void setOnRecv(RecvCb&&) override { throw std::logic_error("ConnectedTurnTransport bad call"); } + +private: + TurnTransport& turn_; + const IpAddr peer_; + RecvCb onRxDataCb_; +}; + } // namespace ring