diff --git a/src/ringdht/sips_transport_ice.cpp b/src/ringdht/sips_transport_ice.cpp index 5146e3ab331be91ccf2c7e2a9c4e96290b060046..a15236e458e1a4eb05da0bd1df16f7fbeb19a566 100644 --- a/src/ringdht/sips_transport_ice.cpp +++ b/src/ringdht/sips_transport_ice.cpp @@ -260,6 +260,14 @@ SipsIceTransport::~SipsIceTransport() { RING_DBG("~SipIceTransport@%p {tr=%p}", this, &trData_.base); + // Flush send queue with ENOTCONN error + for (auto tdata : txQueue_) { + tdata->op_key.tdata = nullptr; + if (tdata->op_key.callback) + tdata->op_key.callback(&trData_.base, tdata->op_key.token, + -PJ_RETURN_OS_ERROR(OSERR_ENOTCONN)); + } + auto base = getTransportBase(); Manager::instance().unregisterEventHandler((uintptr_t)this); @@ -307,6 +315,39 @@ SipsIceTransport::handleEvents() } } + // Handle SIP transport -> TLS + decltype(txQueue_) tx_queue; + { + std::lock_guard<std::mutex> l(txMutex_); + if (syncTx_) { + tx_queue = std::move(txQueue_); + txQueue_.clear(); + } + } + + bool fatal = false; + for (auto tdata : tx_queue) { + pj_status_t status; + if (!fatal) { + const std::size_t size = tdata->buf.cur - tdata->buf.start; + auto ret = tls_->send(tdata->buf.start, size); + if (gnutls_error_is_fatal(ret)) { + RING_ERR("[TLS] fatal error during sending: %s", gnutls_strerror(ret)); + tls_->shutdown(); + fatal = true; + } + if (ret < 0) + status = tls_status_from_err(ret); + else + status = ret; + } else + status = -PJ_RETURN_OS_ERROR(OSERR_ENOTCONN); + + tdata->op_key.tdata = nullptr; + if (tdata->op_key.callback) + tdata->op_key.callback(&trData_.base, tdata->op_key.token, status); + } + // Handle TLS -> SIP transport decltype(rxPending_) rx; { @@ -438,9 +479,13 @@ SipsIceTransport::updateTransportState(pjsip_transport_state state) std::memset(&ev.tls_info, 0, sizeof(ev.tls_info)); ev.state = state; - tlsConnected_ = state == PJSIP_TP_STATE_CONNECTED; - getInfo(&ev.ssl_info, tlsConnected_); - if (tlsConnected_) + bool connected = state == PJSIP_TP_STATE_CONNECTED; + { + std::lock_guard<std::mutex> lk {txMutex_}; + syncTx_ = true; + } + getInfo(&ev.ssl_info, connected); + if (connected) ev.state_info.status = ev.ssl_info.verify_status ? PJSIP_TLS_ECERTVERIF : PJ_SUCCESS; else ev.state_info.status = PJ_SUCCESS; // TODO: use last gnu error @@ -613,8 +658,8 @@ SipsIceTransport::certGetCn(const pj_str_t* gen_name, pj_str_t* cn) } pj_status_t -SipsIceTransport::send(pjsip_tx_data *tdata, const pj_sockaddr_t *rem_addr, - int addr_len, void *token, +SipsIceTransport::send(pjsip_tx_data* tdata, const pj_sockaddr_t* rem_addr, + int addr_len, void* token, pjsip_transport_callback callback) { // Sanity check @@ -629,29 +674,28 @@ SipsIceTransport::send(pjsip_tx_data *tdata, const pj_sockaddr_t *rem_addr, addr_len==sizeof(pj_sockaddr_in6)), PJ_EINVAL); - tdata->op_key.tdata = tdata; - tdata->op_key.token = token; - tdata->op_key.callback = callback; - - // Asynchronous send + // Check in we are able to send it in synchronous way first const std::size_t size = tdata->buf.cur - tdata->buf.start; - auto ret = tls_->async_send(tdata->buf.start, size, [=](std::size_t bytes_sent) { - // WARN: This code is called in the context of the TlsSession thread - if (bytes_sent == 0) - bytes_sent = -PJ_RETURN_OS_ERROR(OSERR_ENOTCONN); - tdata->op_key.tdata = nullptr; - if (tdata->op_key.callback) - tdata->op_key.callback(&trData_.base, tdata->op_key.token, bytes_sent); - }); - - // Shutdown on fatal errors - if (gnutls_error_is_fatal(ret)) { - tdata->op_key.tdata = nullptr; - RING_ERR("[TLS] send failed: %s", gnutls_strerror(ret)); - tls_->shutdown(); - return tls_status_from_err(ret); + std::unique_lock<std::mutex> lk {txMutex_}; + if (syncTx_ and txQueue_.empty()) { + auto ret = tls_->send(tdata->buf.start, size); + lk.unlock(); + + // Shutdown on fatal error, else ignore it + if (gnutls_error_is_fatal(ret)) { + RING_ERR("[TLS] fatal error during sending: %s", gnutls_strerror(ret)); + tls_->shutdown(); + return tls_status_from_err(ret); + } + + return PJ_SUCCESS; } + // Asynchronous sending + tdata->op_key.tdata = tdata; + tdata->op_key.token = token; + tdata->op_key.callback = callback; + txQueue_.push_back(tdata); return PJ_EPENDING; } diff --git a/src/ringdht/sips_transport_ice.h b/src/ringdht/sips_transport_ice.h index 6433dd62e3cbdf691490a43b7a58ceff1efb5f89..1061a8ff5017f9f93f87b3db24bb7e6b0951dcbb 100644 --- a/src/ringdht/sips_transport_ice.h +++ b/src/ringdht/sips_transport_ice.h @@ -107,7 +107,11 @@ private: }; std::unique_ptr<TlsSession> tls_; - std::atomic_bool tlsConnected_ {false}; // set by updateTransportState + + std::mutex txMutex_ {}; + std::condition_variable txCv_ {}; + std::list<pjsip_tx_data*> txQueue_ {}; + bool syncTx_ {false}; // true if we can send data synchronously (cnx established) std::mutex stateChangeEventsMutex_ {}; std::list<ChangeStateEventData> stateChangeEvents_ {}; diff --git a/src/security/tls_session.cpp b/src/security/tls_session.cpp index 74fa615f5cd31286d93a307d09099c5bb87cbb5c..fc65fa483398a41c54fc04cfd3007a071e302858 100644 --- a/src/security/tls_session.cpp +++ b/src/security/tls_session.cpp @@ -163,7 +163,7 @@ TlsSession::TlsSession(std::shared_ptr<IceTransport> ice, int ice_comp_id, [this] { cleanup(); }) { socket_->setOnRecv([this](uint8_t* buf, size_t len) { - std::lock_guard<std::mutex> lk {ioMutex_}; + std::lock_guard<std::mutex> lk {rxMutex_}; if (rxQueue_.size() == INPUT_MAX_SIZE) { rxQueue_.pop_front(); // drop oldest packet if input buffer is full ++stRxRawPacketDropCnt_; @@ -171,7 +171,7 @@ TlsSession::TlsSession(std::shared_ptr<IceTransport> ice, int ice_comp_id, rxQueue_.emplace_back(buf, buf+len); ++stRxRawPacketCnt_; stRxRawBytesCnt_ += len; - ioCv_.notify_one(); + rxCv_.notify_one(); return len; }); @@ -196,7 +196,9 @@ TlsSession::typeName() const void TlsSession::dump_io_stats() const { - RING_WARN("[TLS] RxRawPckt=%zu (%zu bytes)", stRxRawPacketCnt_, stRxRawBytesCnt_); + RING_DBG("[TLS] RxRawPkt=%zu (%zu bytes) - TxRawPkt=%zu (%zu bytes)", + stRxRawPacketCnt_.load(), stRxRawBytesCnt_.load(), + stTxRawPacketCnt_.load(), stTxRawBytesCnt_.load()); } TlsSessionState @@ -366,7 +368,7 @@ void TlsSession::shutdown() { state_ = TlsSessionState::SHUTDOWN; - ioCv_.notify_one(); // unblock waiting FSM + rxCv_.notify_one(); // unblock waiting FSM } const char* @@ -392,28 +394,31 @@ TlsSession::getCurrentCipherSuiteId(std::array<uint8_t, 2>& cs_id) const return {}; } -// Called by application to send data to encrypt. ssize_t -TlsSession::async_send(void* data, std::size_t size, TxDataCompleteFunc on_send_complete) +TlsSession::send(const void* data, std::size_t size) { - std::lock_guard<std::mutex> lk {ioMutex_}; - txQueue_.emplace_back(TxData {data, size, on_send_complete}); - ioCv_.notify_one(); - return GNUTLS_E_SUCCESS; + std::lock_guard<std::mutex> lk {txMutex_}; + if (state_ != TlsSessionState::ESTABLISHED) + return GNUTLS_E_INVALID_SESSION; + return send_(static_cast<const uint8_t*>(data), size); } ssize_t -TlsSession::send(const TxData& tx_data) +TlsSession::send(const std::vector<uint8_t>& vec) +{ + return send(vec.data(), vec.size()); +} + +ssize_t +TlsSession::send_(const uint8_t* tx_data, std::size_t tx_size) { std::size_t max_tx_sz = gnutls_dtls_get_data_mtu(session_); - std::size_t tx_size = tx_data.size; - auto ptr = static_cast<uint8_t*>(tx_data.ptr); // Split user data into MTU-suitable chunck size_t total_written = 0; while (total_written < tx_size) { auto chunck_sz = std::min(max_tx_sz, tx_size - total_written); - auto nwritten = gnutls_record_send(session_, ptr + total_written, chunck_sz); + auto nwritten = gnutls_record_send(session_, tx_data + total_written, chunck_sz); if (nwritten <= 0) { /* Normally we would have to retry record_send but our internal * state has not changed, so we have to ask for more data first. @@ -466,7 +471,7 @@ TlsSession::sendRawVec(const giovec_t* iov, int iovcnt) ssize_t TlsSession::recvRaw(void* buf, size_t size) { - std::lock_guard<std::mutex> lk {ioMutex_}; + std::lock_guard<std::mutex> lk {rxMutex_}; if (rxQueue_.empty()) { gnutls_transport_set_errno(session_, EAGAIN); return -1; @@ -487,8 +492,8 @@ TlsSession::recvRaw(void* buf, size_t size) int TlsSession::waitForRawData(unsigned timeout) { - std::unique_lock<std::mutex> lk {ioMutex_}; - if (not ioCv_.wait_for(lk, std::chrono::milliseconds(timeout), + std::unique_lock<std::mutex> lk {rxMutex_}; + if (not rxCv_.wait_for(lk, std::chrono::milliseconds(timeout), [this]{ return !rxQueue_.empty() or state_ == TlsSessionState::SHUTDOWN; })) return 0; @@ -519,12 +524,6 @@ TlsSession::cleanup() { state_ = TlsSessionState::SHUTDOWN; // be sure to block any user operations - // Flush pending application send requests with a 0 bytes-sent result - for (auto& txdata : txQueue_) { - if (txdata.onComplete) - txdata.onComplete(0); - } - if (session_) { // DTLS: not use GNUTLS_SHUT_RDWR to not wait for a peer answer gnutls_bye(session_, GNUTLS_SHUT_WR); @@ -564,8 +563,8 @@ TlsSession::handleStateCookie(TlsSessionState state) std::size_t count; { // block until rx packet or shutdown - std::unique_lock<std::mutex> lk {ioMutex_}; - if (!ioCv_.wait_for(lk, COOKIE_TIMEOUT, + std::unique_lock<std::mutex> lk {rxMutex_}; + if (!rxCv_.wait_for(lk, COOKIE_TIMEOUT, [this]{ return !rxQueue_.empty() or state_ == TlsSessionState::SHUTDOWN; })) { RING_ERR("[TLS] SYN cookie failed: timeout"); @@ -584,7 +583,7 @@ TlsSession::handleStateCookie(TlsSessionState state) // Peek and verify front packet { - std::lock_guard<std::mutex> lk {ioMutex_}; + std::lock_guard<std::mutex> lk {rxMutex_}; auto& pkt = rxQueue_.front(); std::memset(&prestate_, 0, sizeof(prestate_)); ret = gnutls_dtls_cookie_verify(&cookie_key_, nullptr, 0, @@ -602,7 +601,7 @@ TlsSession::handleStateCookie(TlsSessionState state) // Drop front packet { - std::lock_guard<std::mutex> lk {ioMutex_}; + std::lock_guard<std::mutex> lk {rxMutex_}; rxQueue_.pop_front(); } @@ -709,33 +708,12 @@ TlsSessionState TlsSession::handleStateEstablished(TlsSessionState state) { // block until rx/tx packet or state change - std::unique_lock<std::mutex> lk {ioMutex_}; - ioCv_.wait(lk, [this]{ return !txQueue_.empty() or !rxQueue_.empty() or state_ != TlsSessionState::ESTABLISHED; }); + std::unique_lock<std::mutex> lk {rxMutex_}; + rxCv_.wait(lk, [this]{ return !rxQueue_.empty() or state_ != TlsSessionState::ESTABLISHED; }); state = state_.load(); if (state != TlsSessionState::ESTABLISHED) return state; - // Handle TX data from application - if (not txQueue_.empty()) { - decltype(txQueue_) tx_queue = std::move(txQueue_); - txQueue_.clear(); - lk.unlock(); - for (const auto& txdata : tx_queue) { - while (state_ == TlsSessionState::ESTABLISHED) { - auto bytes_sent = send(txdata); - auto fatal = gnutls_error_is_fatal(bytes_sent); - if (bytes_sent < 0 and !fatal) - continue; - if (txdata.onComplete) - txdata.onComplete(bytes_sent); - if (fatal) - return TlsSessionState::SHUTDOWN; - break; - } - } - lk.lock(); - } - // Handle RX data from network if (!rxQueue_.empty()) { std::vector<uint8_t> buf(8*1024); diff --git a/src/security/tls_session.h b/src/security/tls_session.h index 9cc221ffadf31cd8179157d31956df947b256183..cc6989192db08c003a8a630b75f1d8918e17cde2 100644 --- a/src/security/tls_session.h +++ b/src/security/tls_session.h @@ -38,6 +38,7 @@ #include <utility> #include <vector> #include <map> +#include <atomic> namespace ring { class IceTransport; @@ -154,7 +155,13 @@ public: // Asynchronous sending operation. on_send_complete will be called with a positive number // for number of bytes sent, or negative for errors, or 0 in case of shutdown (end of session). - ssize_t async_send(void* data, std::size_t size, TxDataCompleteFunc on_send_complete); + int async_send(const void* data, std::size_t size, TxDataCompleteFunc on_send_complete); + int async_send(std::vector<uint8_t>&& data, TxDataCompleteFunc on_send_complete); + + // Synchronous sending operation. Return negative number (gnutls error) or a positive number + // for bytes sent. + ssize_t send(const void* data, std::size_t size); + ssize_t send(const std::vector<uint8_t>& data); private: using clock = std::chrono::steady_clock; @@ -178,29 +185,23 @@ private: std::atomic<unsigned int> maxPayload_ {0}; // IO GnuTLS <-> ICE - struct TxData { - void* const ptr; - std::size_t size; - TxDataCompleteFunc onComplete; - }; - - std::mutex ioMutex_ {}; - std::condition_variable ioCv_ {}; - std::list<TxData> txQueue_ {}; + std::mutex txMutex_ {}; + std::mutex rxMutex_ {}; + std::condition_variable rxCv_ {}; std::list<std::vector<uint8_t>> rxQueue_ {}; - ssize_t send(const TxData&); + ssize_t send_(const uint8_t* tx_data, std::size_t tx_size); ssize_t sendRaw(const void*, size_t); ssize_t sendRawVec(const giovec_t*, int); ssize_t recvRaw(void*, size_t); int waitForRawData(unsigned); - // Statistics (also protected by mutex ioMutex_) - std::size_t stRxRawPacketCnt_ {0}; - std::size_t stRxRawBytesCnt_ {0}; - std::size_t stRxRawPacketDropCnt_ {0}; - std::size_t stTxRawPacketCnt_ {0}; - std::size_t stTxRawBytesCnt_ {0}; + // Statistics + std::atomic<std::size_t> stRxRawPacketCnt_ {0}; + std::atomic<std::size_t> stRxRawBytesCnt_ {0}; + std::atomic<std::size_t> stRxRawPacketDropCnt_ {0}; + std::atomic<std::size_t> stTxRawPacketCnt_ {0}; + std::atomic<std::size_t> stTxRawBytesCnt_ {0}; void dump_io_stats() const; // GnuTLS backend and connection state