diff --git a/src/jamidht/p2p.cpp b/src/jamidht/p2p.cpp index d4d53342faff3357b95b3b57775a9046bc2b7069..fefb6b744798cce341ff10f9dc92b62ce2d03478 100644 --- a/src/jamidht/p2p.cpp +++ b/src/jamidht/p2p.cpp @@ -257,8 +257,6 @@ public: } private: - std::map<IpAddr, std::unique_ptr<ConnectedTurnTransport>> turnEndpoints_; - std::map<std::pair<dht::InfoHash, IpAddr>, std::unique_ptr<AbstractSocketEndpoint>> p2pEndpoints_; std::map<std::pair<dht::InfoHash, IpAddr>, std::unique_ptr<TlsSocketEndpoint>> waitForReadyEndpoints_; std::unique_ptr<TurnTransport> turnAuthv4_; std::unique_ptr<TurnTransport> turnAuthv6_; @@ -421,7 +419,7 @@ private: } // Check response validity - std::shared_ptr<AbstractSocketEndpoint> peer_ep; + std::unique_ptr<AbstractSocketEndpoint> peer_ep; if (response_.from != peer_ or response_.id != request.id or response_.addresses.empty()) @@ -446,7 +444,7 @@ private: ice->waitForNegotiation(ICE_NEGOTIATION_TIMEOUT); if (ice->isRunning()) { - peer_ep = std::make_shared<IceSocketEndpoint>(ice, true); + peer_ep = std::make_unique<IceSocketEndpoint>(ice, true); JAMI_DBG("[Account:%s] ICE negotiation succeed. Starting file transfer", parent_.account.getAccountID().c_str()); if (hasPubIp) ice->setInitiatorSession(); @@ -459,7 +457,7 @@ private: // Connect to TURN peer using a raw socket JAMI_DBG() << parent_.account << "[CNX] connecting to TURN relay " << relay_addr.toString(true, true); - peer_ep = std::make_shared<TcpSocketEndpoint>(relay_addr); + peer_ep = std::make_unique<TcpSocketEndpoint>(relay_addr); try { peer_ep->connect(SOCK_TIMEOUT); } catch (const std::logic_error& e) { @@ -488,9 +486,9 @@ private: // Negotiate a TLS session JAMI_DBG() << parent_.account << "[CNX] start TLS session"; tls_ep_ = std::make_unique<TlsSocketEndpoint>( - *peer_ep, parent_.account.identity(), parent_.account.dhParams(), + std::move(peer_ep), parent_.account.identity(), parent_.account.dhParams(), *peerCertificate_); - tls_ep_->setOnStateChange([this, ice=std::move(ice), peer_ep] (tls::TlsSessionState state) { + tls_ep_->setOnStateChange([this, ice=std::move(ice)] (tls::TlsSessionState state) { if (state == tls::TlsSessionState::SHUTDOWN) { JAMI_WARN() << "TLS connection failure from peer " << peer_.toString(); @@ -502,7 +500,6 @@ private: connection_ = std::make_unique<PeerConnection>( [this] { cancel(); }, peer_.toString(), std::move(tls_ep_)); - peer_ep_ = std::move(peer_ep); for (auto &cb : listeners_) { cb(connection_.get()); } @@ -520,7 +517,6 @@ private: PeerConnectionMsg response_; uint64_t waitId_ {0}; std::shared_ptr<dht::crypto::Certificate> peerCertificate_; - std::shared_ptr<AbstractSocketEndpoint> peer_ep_; std::unique_ptr<PeerConnection> connection_; std::unique_ptr<TlsSocketEndpoint> tls_ep_; @@ -641,7 +637,7 @@ DhtPeerConnector::Impl::onTurnPeerConnection(const IpAddr& peer_addr) JAMI_DBG() << account << "[CNX] start TLS session over TURN socket"; auto peer_h = std::make_shared<dht::InfoHash>(); tls_turn_ep_[peer_addr] = - std::make_unique<TlsTurnEndpoint>(*turn_ep, + std::make_unique<TlsTurnEndpoint>(std::move(turn_ep), account.identity(), account.dhParams(), [peer_h, this] (const dht::crypto::Certificate& cert) { @@ -669,12 +665,6 @@ DhtPeerConnector::Impl::onTurnPeerConnection(const IpAddr& peer_addr) tls_turn_ep_.erase(peer_addr); } }); - - // note: operating this way let endpoint to be deleted safely in case of exceptions - { - std::lock_guard<std::mutex> lock(turnMutex_); - turnEndpoints_.emplace(std::make_pair(peer_addr, std::move(turn_ep))); - } } void @@ -687,10 +677,6 @@ DhtPeerConnector::Impl::onTurnPeerDisconnection(const IpAddr& peer_addr) JAMI_WARN() << account << "[CNX] disconnection from peer " << peer_addr.toString(true, true); servers_.erase(it); connectedPeers_.erase(peer_addr); - { - std::lock_guard<std::mutex> lock(turnMutex_); - turnEndpoints_.erase(peer_addr); - } } void @@ -877,7 +863,7 @@ DhtPeerConnector::Impl::answerToRequest(PeerConnectionMsg&& request, auto idx = std::make_pair(peer_h, ice->getRemoteAddress(0)); auto it = waitForReadyEndpoints_.emplace( idx, - std::make_unique<TlsSocketEndpoint>(*peer_ep, account.identity(), account.dhParams(), + std::make_unique<TlsSocketEndpoint>(std::move(peer_ep), account.identity(), account.dhParams(), [peer_h, this](const dht::crypto::Certificate &cert) { dht::InfoHash peer_h_found; return validatePeerCertificate(cert, peer_h_found) @@ -885,7 +871,6 @@ DhtPeerConnector::Impl::answerToRequest(PeerConnectionMsg&& request, } ) ); - p2pEndpoints_.emplace(idx, std::move(peer_ep)); it.first->second->setOnStateChange([this, idx=std::move(idx)] (tls::TlsSessionState state) { if (waitForReadyEndpoints_.find(idx) == waitForReadyEndpoints_.end()) { @@ -989,19 +974,7 @@ DhtPeerConnector::Impl::eventLoop() auto peer = it->first.second; // tmp copy to prevent use-after-free below servers_.erase(it); // Remove the file transfer if p2p - auto p2p_it = std::find_if(p2pEndpoints_.begin(), p2pEndpoints_.end(), - [&dev_h, &peer](const auto &element) { - return (element.first.first == dev_h && - element.first.second == peer); - }); - if (p2p_it != p2pEndpoints_.end()) - p2pEndpoints_.erase(p2p_it); connectedPeers_.erase(peer); - // Else it's via TURN! - { - std::lock_guard<std::mutex> lock(turnMutex_); - turnEndpoints_.erase(peer); - } Manager::instance().dataTransfers->close(id); } break; diff --git a/src/jamidht/sips_transport_ice.cpp b/src/jamidht/sips_transport_ice.cpp index 1ca89dd3b95f31341c2ae870a437741c997d8dc1..e413ffc071004130063f38f824282675c94246f6 100644 --- a/src/jamidht/sips_transport_ice.cpp +++ b/src/jamidht/sips_transport_ice.cpp @@ -236,7 +236,7 @@ SipsIceTransport::SipsIceTransport(pjsip_endpoint* endpt, std::memset(&localCertInfo_, 0, sizeof(pj_ssl_cert_info)); std::memset(&remoteCertInfo_, 0, sizeof(pj_ssl_cert_info)); - iceSocket_ = std::make_unique<IceSocketTransport>(ice_, comp_id, PJSIP_TRANSPORT_IS_RELIABLE(&trData_.base)); + auto iceSocket = std::make_unique<IceSocketTransport>(ice_, comp_id, PJSIP_TRANSPORT_IS_RELIABLE(&trData_.base)); TlsSession::TlsSessionCallbacks cbs = { /*.onStateChange = */[this](TlsSessionState state){ onTlsStateChange(state); }, @@ -245,7 +245,7 @@ SipsIceTransport::SipsIceTransport(pjsip_endpoint* endpt, unsigned int n){ onCertificatesUpdate(l, r, n); }, /*.verifyCertificate = */[this](gnutls_session_t session){ return verifyCertificate(session); } }; - tls_ = std::make_unique<TlsSession>(*iceSocket_, param, cbs); + tls_ = std::make_unique<TlsSession>(std::move(iceSocket), param, cbs); if (pjsip_transport_register(base.tpmgr, &base) != PJ_SUCCESS) throw std::runtime_error("Can't register PJSIP transport."); diff --git a/src/jamidht/sips_transport_ice.h b/src/jamidht/sips_transport_ice.h index c5290c8068049ffc8738516d902da1acf902533b..0d0bfe00f390f64e8c24f19ffa3899ce86fbfb6c 100644 --- a/src/jamidht/sips_transport_ice.h +++ b/src/jamidht/sips_transport_ice.h @@ -111,7 +111,6 @@ private: decltype(PJSIP_TP_STATE_DISCONNECTED) state; }; - std::unique_ptr<IceSocketTransport> iceSocket_; std::unique_ptr<TlsSession> tls_; std::mutex txMutex_ {}; diff --git a/src/peer_connection.cpp b/src/peer_connection.cpp index 4aef29e137a92718dd1bc3c086657609208d82a8..8ccc3414a733ea0aef61a1667b0fdef4e8314b57 100644 --- a/src/peer_connection.cpp +++ b/src/peer_connection.cpp @@ -95,9 +95,31 @@ class TlsTurnEndpoint::Impl public: static constexpr auto TLS_TIMEOUT = std::chrono::seconds(20); - Impl(ConnectedTurnTransport& tr, - std::function<bool(const dht::crypto::Certificate&)>&& cert_check) - : turn {tr}, peerCertificateCheckFunc {std::move(cert_check)} {} + Impl(std::unique_ptr<ConnectedTurnTransport>&& turn_ep, + std::function<bool(const dht::crypto::Certificate&)>&& cert_check, + const Identity& local_identity, + const std::shared_future<tls::DhParams>& dh_params) + : peerCertificateCheckFunc {std::move(cert_check)} { + + // Add TLS over TURN + tls::TlsSession::TlsSessionCallbacks tls_cbs = { + /*.onStateChange = */[this](tls::TlsSessionState state){ onTlsStateChange(state); }, + /*.onRxData = */[this](std::vector<uint8_t>&& buf){ onTlsRxData(std::move(buf)); }, + /*.onCertificatesUpdate = */[this](const gnutls_datum_t* l, const gnutls_datum_t* r, + unsigned int n){ onTlsCertificatesUpdate(l, r, n); }, + /*.verifyCertificate = */[this](gnutls_session_t session){ return verifyCertificate(session); } + }; + tls::TlsParams tls_param = { + /*.ca_list = */ "", + /*.peer_ca = */ nullptr, + /*.cert = */ local_identity.second, + /*.cert_key = */ local_identity.first, + /*.dh_params = */ dh_params, + /*.timeout = */ Impl::TLS_TIMEOUT, + /*.cert_check = */ nullptr, + }; + tls = std::make_unique<tls::TlsSession>(std::move(turn_ep), tls_param, tls_cbs); + } ~Impl(); @@ -108,7 +130,6 @@ public: void onTlsCertificatesUpdate(const gnutls_datum_t*, const gnutls_datum_t*, unsigned int); std::unique_ptr<tls::TlsSession> tls; - ConnectedTurnTransport& turn; std::function<bool(const dht::crypto::Certificate&)> peerCertificateCheckFunc; dht::crypto::Certificate peerCertificate; OnStateChangeCb onStateChangeCb_; @@ -154,31 +175,12 @@ TlsTurnEndpoint::Impl::onTlsCertificatesUpdate(UNUSED const gnutls_datum_t* loca UNUSED unsigned int remote_count) {} -TlsTurnEndpoint::TlsTurnEndpoint(ConnectedTurnTransport& turn_ep, +TlsTurnEndpoint::TlsTurnEndpoint(std::unique_ptr<ConnectedTurnTransport>&& turn_ep, const Identity& local_identity, const std::shared_future<tls::DhParams>& dh_params, std::function<bool(const dht::crypto::Certificate&)>&& cert_check) - : pimpl_ { std::make_unique<Impl>(turn_ep, std::move(cert_check)) } -{ - // Add TLS over TURN - tls::TlsSession::TlsSessionCallbacks tls_cbs = { - /*.onStateChange = */[this](tls::TlsSessionState state){ pimpl_->onTlsStateChange(state); }, - /*.onRxData = */[this](std::vector<uint8_t>&& buf){ pimpl_->onTlsRxData(std::move(buf)); }, - /*.onCertificatesUpdate = */[this](const gnutls_datum_t* l, const gnutls_datum_t* r, - unsigned int n){ pimpl_->onTlsCertificatesUpdate(l, r, n); }, - /*.verifyCertificate = */[this](gnutls_session_t session){ return pimpl_->verifyCertificate(session); } - }; - tls::TlsParams tls_param = { - /*.ca_list = */ "", - /*.peer_ca = */ nullptr, - /*.cert = */ local_identity.second, - /*.cert_key = */ local_identity.first, - /*.dh_params = */ dh_params, - /*.timeout = */ Impl::TLS_TIMEOUT, - /*.cert_check = */ nullptr, - }; - pimpl_->tls = std::make_unique<tls::TlsSession>(turn_ep, tls_param, tls_cbs); -} + : pimpl_ { std::make_unique<Impl>(std::move(turn_ep), std::move(cert_check), local_identity, dh_params) } +{} TlsTurnEndpoint::~TlsTurnEndpoint() = default; @@ -390,11 +392,53 @@ class TlsSocketEndpoint::Impl public: static constexpr auto TLS_TIMEOUT = std::chrono::seconds(20); - Impl(AbstractSocketEndpoint& ep, const dht::crypto::Certificate& peer_cert) - : tr {ep}, peerCertificate {peer_cert} {} + Impl(std::unique_ptr<AbstractSocketEndpoint>&& ep, + const dht::crypto::Certificate& peer_cert, + const Identity& local_identity, + const std::shared_future<tls::DhParams>& dh_params) + : peerCertificate {peer_cert} { + tls::TlsSession::TlsSessionCallbacks tls_cbs = { + /*.onStateChange = */[this](tls::TlsSessionState state){ onTlsStateChange(state); }, + /*.onRxData = */[this](std::vector<uint8_t>&& buf){ onTlsRxData(std::move(buf)); }, + /*.onCertificatesUpdate = */[this](const gnutls_datum_t* l, const gnutls_datum_t* r, + unsigned int n){ onTlsCertificatesUpdate(l, r, n); }, + /*.verifyCertificate = */[this](gnutls_session_t session){ return verifyCertificate(session); } + }; + tls::TlsParams tls_param = { + /*.ca_list = */ "", + /*.peer_ca = */ nullptr, + /*.cert = */ local_identity.second, + /*.cert_key = */ local_identity.first, + /*.dh_params = */ dh_params, + /*.timeout = */ Impl::TLS_TIMEOUT, + /*.cert_check = */ nullptr, + }; + tls = std::make_unique<tls::TlsSession>(std::move(ep), tls_param, tls_cbs); + } - Impl(AbstractSocketEndpoint &ep, std::function<bool(const dht::crypto::Certificate &)>&& cert_check) - : tr{ep}, peerCertificateCheckFunc{std::move(cert_check)}, peerCertificate {null_cert} {} + Impl(std::unique_ptr<AbstractSocketEndpoint>&& ep, + std::function<bool(const dht::crypto::Certificate &)>&& cert_check, + const Identity& local_identity, + const std::shared_future<tls::DhParams>& dh_params) + : peerCertificateCheckFunc{std::move(cert_check)}, peerCertificate {null_cert} { + tls::TlsSession::TlsSessionCallbacks tls_cbs = { + /*.onStateChange = */[this](tls::TlsSessionState state){ onTlsStateChange(state); }, + /*.onRxData = */[this](std::vector<uint8_t>&& buf){ onTlsRxData(std::move(buf)); }, + /*.onCertificatesUpdate = */[this](const gnutls_datum_t* l, const gnutls_datum_t* r, + unsigned int n){ onTlsCertificatesUpdate(l, r, n); }, + /*.verifyCertificate = */[this](gnutls_session_t session){ return verifyCertificate(session); } + }; + tls::TlsParams tls_param = { + /*.ca_list = */ "", + /*.peer_ca = */ nullptr, + /*.cert = */ local_identity.second, + /*.cert_key = */ local_identity.first, + /*.dh_params = */ dh_params, + /*.timeout = */ Impl::TLS_TIMEOUT, + /*.cert_check = */ nullptr, + }; + tls = std::make_unique<tls::TlsSession>(std::move(ep), tls_param, tls_cbs); + } // TLS callbacks int verifyCertificate(gnutls_session_t); @@ -403,7 +447,6 @@ public: void onTlsCertificatesUpdate(const gnutls_datum_t*, const gnutls_datum_t*, unsigned int); std::unique_ptr<tls::TlsSession> tls; - AbstractSocketEndpoint& tr; const dht::crypto::Certificate& peerCertificate; dht::crypto::Certificate null_cert; std::function<bool(const dht::crypto::Certificate &)> peerCertificateCheckFunc; @@ -453,56 +496,20 @@ TlsSocketEndpoint::Impl::onTlsCertificatesUpdate(UNUSED const gnutls_datum_t* lo UNUSED unsigned int remote_count) {} -TlsSocketEndpoint::TlsSocketEndpoint(AbstractSocketEndpoint& tr, +TlsSocketEndpoint::TlsSocketEndpoint(std::unique_ptr<AbstractSocketEndpoint>&& tr, const Identity& local_identity, const std::shared_future<tls::DhParams>& dh_params, const dht::crypto::Certificate& peer_cert) - : pimpl_ { std::make_unique<Impl>(tr, peer_cert) } -{ - // Add TLS over TURN - tls::TlsSession::TlsSessionCallbacks tls_cbs = { - /*.onStateChange = */[this](tls::TlsSessionState state){ pimpl_->onTlsStateChange(state); }, - /*.onRxData = */[this](std::vector<uint8_t>&& buf){ pimpl_->onTlsRxData(std::move(buf)); }, - /*.onCertificatesUpdate = */[this](const gnutls_datum_t* l, const gnutls_datum_t* r, - unsigned int n){ pimpl_->onTlsCertificatesUpdate(l, r, n); }, - /*.verifyCertificate = */[this](gnutls_session_t session){ return pimpl_->verifyCertificate(session); } - }; - tls::TlsParams tls_param = { - /*.ca_list = */ "", - /*.peer_ca = */ nullptr, - /*.cert = */ local_identity.second, - /*.cert_key = */ local_identity.first, - /*.dh_params = */ dh_params, - /*.timeout = */ Impl::TLS_TIMEOUT, - /*.cert_check = */ nullptr, - }; - pimpl_->tls = std::make_unique<tls::TlsSession>(tr, tls_param, tls_cbs); -} - -TlsSocketEndpoint::TlsSocketEndpoint(AbstractSocketEndpoint& tr, + : pimpl_ { std::make_unique<Impl>(std::move(tr), peer_cert, local_identity, dh_params) } +{} + +TlsSocketEndpoint::TlsSocketEndpoint(std::unique_ptr<AbstractSocketEndpoint>&& tr, const Identity& local_identity, const std::shared_future<tls::DhParams>& dh_params, std::function<bool(const dht::crypto::Certificate&)>&& cert_check) - : pimpl_ { std::make_unique<Impl>(tr, std::move(cert_check)) } -{ - // Add TLS over TURN - tls::TlsSession::TlsSessionCallbacks tls_cbs = { - /*.onStateChange = */[this](tls::TlsSessionState state){ pimpl_->onTlsStateChange(state); }, - /*.onRxData = */[this](std::vector<uint8_t>&& buf){ pimpl_->onTlsRxData(std::move(buf)); }, - /*.onCertificatesUpdate = */[this](const gnutls_datum_t* l, const gnutls_datum_t* r, - unsigned int n){ pimpl_->onTlsCertificatesUpdate(l, r, n); }, - /*.verifyCertificate = */[this](gnutls_session_t session){ return pimpl_->verifyCertificate(session); } - }; - tls::TlsParams tls_param = { - /*.ca_list = */ "", - /*.peer_ca = */ nullptr, - /*.cert = */ local_identity.second, - /*.cert_key = */ local_identity.first, - /*.dh_params = */ dh_params, - /*.timeout = */ Impl::TLS_TIMEOUT, - /*.cert_check = */ nullptr, - }; - pimpl_->tls = std::make_unique<tls::TlsSession>(tr, tls_param, tls_cbs); + : pimpl_ { std::make_unique<Impl>(std::move(tr), std::move(cert_check), local_identity, dh_params) } +{ + } diff --git a/src/peer_connection.h b/src/peer_connection.h index 0f15665b3c66d1bab0ce308101a08bc1331f218d..9ddcbb700ad167bbc5bc11d2b52e1aff467d010e 100644 --- a/src/peer_connection.h +++ b/src/peer_connection.h @@ -83,7 +83,7 @@ public: using Identity = std::pair<std::shared_ptr<dht::crypto::PrivateKey>, std::shared_ptr<dht::crypto::Certificate>>; - TlsTurnEndpoint(ConnectedTurnTransport& turn, + TlsTurnEndpoint(std::unique_ptr<ConnectedTurnTransport>&& turn, const Identity& local_identity, const std::shared_future<tls::DhParams>& dh_params, std::function<bool(const dht::crypto::Certificate&)>&& cert_check); @@ -181,11 +181,11 @@ public: using Identity = std::pair<std::shared_ptr<dht::crypto::PrivateKey>, std::shared_ptr<dht::crypto::Certificate>>; - TlsSocketEndpoint(AbstractSocketEndpoint& tr, + TlsSocketEndpoint(std::unique_ptr<AbstractSocketEndpoint>&& tr, const Identity& local_identity, const std::shared_future<tls::DhParams>& dh_params, const dht::crypto::Certificate& peer_cert); - TlsSocketEndpoint(AbstractSocketEndpoint& tr, + TlsSocketEndpoint(std::unique_ptr<AbstractSocketEndpoint>&& tr, const Identity& local_identity, const std::shared_future<tls::DhParams>& dh_params, std::function<bool(const dht::crypto::Certificate&)>&& cert_check); diff --git a/src/security/tls_session.cpp b/src/security/tls_session.cpp index 9279ab06e5be7a203ea55849c196a7b8c5232658..2522add27fc775e33622031d4493070d0373478e 100644 --- a/src/security/tls_session.cpp +++ b/src/security/tls_session.cpp @@ -179,14 +179,14 @@ public: const TlsSessionCallbacks callbacks_; const bool anonymous_; - TlsSessionImpl(SocketType& transport, const TlsParams& params, + TlsSessionImpl(std::unique_ptr<SocketType>&& transport, const TlsParams& params, const TlsSessionCallbacks& cbs, bool anonymous); ~TlsSessionImpl(); const char* typeName() const; - SocketType& transport_; + std::unique_ptr<SocketType> transport_; // State protectors std::mutex stateMutex_; @@ -267,15 +267,15 @@ public: void pathMtuHeartbeat(); }; -TlsSession::TlsSessionImpl::TlsSessionImpl(SocketType& transport, +TlsSession::TlsSessionImpl::TlsSessionImpl(std::unique_ptr<SocketType>&& transport, const TlsParams& params, const TlsSessionCallbacks& cbs, bool anonymous) - : isServer_(not transport.isInitiator()) + : isServer_(not transport->isInitiator()) , params_(params) , callbacks_(cbs) , anonymous_(anonymous) - , transport_ { transport } + , transport_ { std::move(transport) } , cacred_(nullptr) , sacred_(nullptr) , xcred_(nullptr) @@ -283,8 +283,8 @@ TlsSession::TlsSessionImpl::TlsSessionImpl(SocketType& transport, [this] { process(); }, [this] { cleanup(); }) { - if (not transport_.isReliable()) { - transport_.setOnRecv([this](const ValueType* buf, size_t len) { + if (not transport_->isReliable()) { + transport_->setOnRecv([this](const ValueType* buf, size_t len) { std::lock_guard<std::mutex> lk {rxMutex_}; if (rxQueue_.size() == INPUT_MAX_SIZE) { rxQueue_.pop_front(); // drop oldest packet if input buffer is full @@ -308,8 +308,8 @@ TlsSession::TlsSessionImpl::~TlsSessionImpl() stateCondition_.notify_all(); rxCv_.notify_all(); thread_.join(); - if (not transport_.isReliable()) - transport_.setOnRecv(nullptr); + if (not transport_->isReliable()) + transport_->setOnRecv(nullptr); } const char* @@ -331,7 +331,7 @@ TlsSession::TlsSessionImpl::setupClient() { int ret; - if (not transport_.isReliable()) { + if (not transport_->isReliable()) { ret = gnutls_init(&session_, GNUTLS_CLIENT | GNUTLS_DATAGRAM); // uncoment to reactivate PMTUD // JAMI_DBG("[TLS] set heartbeat reception for retrocompatibility check on server"); @@ -357,7 +357,7 @@ TlsSession::TlsSessionImpl::setupServer() { int ret; - if (not transport_.isReliable()) { + if (not transport_->isReliable()) { ret = gnutls_init(&session_, GNUTLS_SERVER | GNUTLS_DATAGRAM); // uncoment to reactivate PMTUD @@ -473,7 +473,7 @@ TlsSession::TlsSessionImpl::commonSessionInit() if (anonymous_) { // Force anonymous connection, see handleStateHandshake how we handle failures ret = gnutls_priority_set_direct(session_, - transport_.isReliable() ? TLS_FULL_PRIORITY_STRING : DTLS_FULL_PRIORITY_STRING, + transport_->isReliable() ? TLS_FULL_PRIORITY_STRING : DTLS_FULL_PRIORITY_STRING, nullptr); if (ret != GNUTLS_E_SUCCESS) { JAMI_ERR("[TLS] TLS priority set failed: %s", gnutls_strerror(ret)); @@ -493,7 +493,7 @@ TlsSession::TlsSessionImpl::commonSessionInit() } else { // Use a classic non-encrypted CERTIFICATE exchange method (less anonymous) ret = gnutls_priority_set_direct(session_, - transport_.isReliable() ? TLS_CERT_PRIORITY_STRING : DTLS_CERT_PRIORITY_STRING, + transport_->isReliable() ? TLS_CERT_PRIORITY_STRING : DTLS_CERT_PRIORITY_STRING, nullptr); if (ret != GNUTLS_E_SUCCESS) { JAMI_ERR("[TLS] TLS priority set failed: %s", gnutls_strerror(ret)); @@ -509,14 +509,14 @@ TlsSession::TlsSessionImpl::commonSessionInit() } gnutls_certificate_send_x509_rdn_sequence(session_, 0); - if (not transport_.isReliable()) { + if (not transport_->isReliable()) { // DTLS hanshake timeouts auto re_tx_timeout = duration2ms(DTLS_RETRANSMIT_TIMEOUT); gnutls_dtls_set_timeouts(session_, re_tx_timeout, std::max(duration2ms(params_.timeout), re_tx_timeout)); // gnutls DTLS mtu = maximum payload size given by transport - gnutls_dtls_set_mtu(session_, transport_.maxPayload()); + gnutls_dtls_set_mtu(session_, transport_->maxPayload()); } // Stuff for transport callbacks @@ -553,7 +553,7 @@ TlsSession::TlsSessionImpl::send(const ValueType* tx_data, std::size_t tx_size, std::size_t total_written = 0; std::size_t max_tx_sz; - if (transport_.isReliable()) + if (transport_->isReliable()) max_tx_sz = tx_size; else max_tx_sz = gnutls_dtls_get_data_mtu(session_); @@ -592,7 +592,7 @@ TlsSession::TlsSessionImpl::sendRaw(const void* buf, size_t size) std::error_code ec; unsigned retry_count = 0; do { - auto n = transport_.write(reinterpret_cast<const ValueType*>(buf), size, ec); + auto n = transport_->write(reinterpret_cast<const ValueType*>(buf), size, ec); if (!ec) { // log only on success ++stTxRawPacketCnt_; @@ -639,9 +639,9 @@ TlsSession::TlsSessionImpl::sendRawVec(const giovec_t* iov, int iovcnt) ssize_t TlsSession::TlsSessionImpl::recvRaw(void* buf, size_t size) { - if (transport_.isReliable()) { + if (transport_->isReliable()) { std::error_code ec; - auto count = transport_.read(reinterpret_cast<ValueType*>(buf), size, ec); + auto count = transport_->read(reinterpret_cast<ValueType*>(buf), size, ec); if (!ec) return count; gnutls_transport_set_errno(session_, ec.value()); @@ -667,9 +667,9 @@ TlsSession::TlsSessionImpl::recvRaw(void* buf, size_t size) int TlsSession::TlsSessionImpl::waitForRawData(std::chrono::milliseconds timeout) { - if (transport_.isReliable()) { + if (transport_->isReliable()) { std::error_code ec; - auto err = transport_.waitForData(timeout, ec); + auto err = transport_->waitForData(timeout, ec); if (err <= 0) { // shutdown? if (state_ == TlsSessionState::SHUTDOWN) { @@ -738,7 +738,7 @@ TlsSession::TlsSessionImpl::cleanup() { std::lock_guard<std::mutex> lk(sessionMutex_); if (session_) { - if (transport_.isReliable()) + if (transport_->isReliable()) gnutls_bye(session_, GNUTLS_SHUT_RDWR); else gnutls_bye(session_, GNUTLS_SHUT_WR); // not wait for a peer answer @@ -750,7 +750,7 @@ TlsSession::TlsSessionImpl::cleanup() if (cookie_key_.data) gnutls_free(cookie_key_.data); - transport_.shutdown(); + transport_->shutdown(); } TlsSessionState @@ -771,7 +771,7 @@ TlsSession::TlsSessionImpl::handleStateSetup(UNUSED TlsSessionState state) return setupClient(); // Extra step for DTLS-like transports - if (not transport_.isReliable()) { + if (transport_ and not transport_->isReliable()) { gnutls_key_generate(&cookie_key_, GNUTLS_COOKIE_KEY_SIZE); return TlsSessionState::COOKIE; } @@ -898,7 +898,7 @@ TlsSession::TlsSessionImpl::handleStateHandshake(TlsSessionState state) // Re-setup TLS algorithms priority list with only certificate based cipher suites ret = gnutls_priority_set_direct(session_, - transport_.isReliable() ? TLS_CERT_PRIORITY_STRING : DTLS_CERT_PRIORITY_STRING, + transport_ and transport_->isReliable() ? TLS_CERT_PRIORITY_STRING : DTLS_CERT_PRIORITY_STRING, nullptr); if (ret != GNUTLS_E_SUCCESS) { JAMI_ERR("[TLS] session TLS cert-only priority set failed: %s", gnutls_strerror(ret)); @@ -928,13 +928,13 @@ TlsSession::TlsSessionImpl::handleStateHandshake(TlsSessionState state) callbacks_.onCertificatesUpdate(local, remote, remote_count); } - return transport_.isReliable() ? TlsSessionState::ESTABLISHED : TlsSessionState::MTU_DISCOVERY; + return transport_ and transport_->isReliable() ? TlsSessionState::ESTABLISHED : TlsSessionState::MTU_DISCOVERY; } TlsSessionState TlsSession::TlsSessionImpl::handleStateMtuDiscovery(UNUSED TlsSessionState state) { - mtuProbe_ = transport_.maxPayload(); + mtuProbe_ = transport_ and transport_->maxPayload(); assert(mtuProbe_ >= MIN_MTU); MTUS_ = {MIN_MTU, std::max((mtuProbe_ + MIN_MTU)/2, MIN_MTU), mtuProbe_}; @@ -990,7 +990,7 @@ TlsSession::TlsSessionImpl::pathMtuHeartbeat() // when the remote (server) has a IPV6 interface selected by ICE, and local (client) has a IPV4 selected, // the path MTU discovery triggers errors for packets too big on server side because of different IP headers overhead. // Hence we have to signal to the TLS session to reduce the MTU on client size accordingly. - if (transport_.localAddr().isIpv4() and transport_.remoteAddr().isIpv6()) { + if (transport_ and transport_->localAddr().isIpv4() and transport_->remoteAddr().isIpv6()) { mtuOffset = ASYMETRIC_TRANSPORT_MTU_OFFSET; JAMI_WARN() << "[TLS] local/remote IP protocol version not alike, use an MTU offset of " << ASYMETRIC_TRANSPORT_MTU_OFFSET << " bytes to compensate"; @@ -1132,7 +1132,7 @@ TlsSessionState TlsSession::TlsSessionImpl::handleStateEstablished(TlsSessionState state) { // Nothing to do in reliable mode, so just wait for state change - if (transport_.isReliable()) { + if (transport_ and transport_->isReliable()) { auto disconnected = [this]() -> bool { return state_.load() != TlsSessionState::ESTABLISHED or newState_.load() != TlsSessionState::NONE; @@ -1233,10 +1233,10 @@ TlsSession::TlsSessionImpl::process() //============================================================================== -TlsSession::TlsSession(SocketType& transport, const TlsParams& params, +TlsSession::TlsSession(std::unique_ptr<SocketType>&& transport, const TlsParams& params, const TlsSessionCallbacks& cbs, bool anonymous) - : pimpl_ { std::make_unique<TlsSessionImpl>(transport, params, cbs, anonymous) } + : pimpl_ { std::make_unique<TlsSessionImpl>(std::move(transport), params, cbs, anonymous) } {} TlsSession::~TlsSession() @@ -1251,7 +1251,9 @@ TlsSession::isInitiator() const bool TlsSession::isReliable() const { - return pimpl_->transport_.isReliable(); + if (!pimpl_->transport_) + return false; + return pimpl_->transport_->isReliable(); } int @@ -1259,7 +1261,9 @@ TlsSession::maxPayload() const { if (pimpl_->state_ == TlsSessionState::SHUTDOWN) throw std::runtime_error("Getting maxPayload from non-valid TLS session"); - return pimpl_->transport_.maxPayload(); + if (!pimpl_->transport_) + return 0; + return pimpl_->transport_->maxPayload(); } const char* @@ -1378,7 +1382,11 @@ TlsSession::waitForReady(const duration& timeout) int TlsSession::waitForData(std::chrono::milliseconds timeout, std::error_code& ec) const { - if (!pimpl_->transport_.waitForData(timeout, ec)) + if (!pimpl_->transport_) { + ec = std::make_error_code(std::errc::broken_pipe); + return -1; + } + if (!pimpl_->transport_->waitForData(timeout, ec)) return 0; return 1; } diff --git a/src/security/tls_session.h b/src/security/tls_session.h index e61e1491156c24a2093b9cd4b6eaab50e8cc7a01..284dfb6f334a42cdb4360816633d0e19e88c976f 100644 --- a/src/security/tls_session.h +++ b/src/security/tls_session.h @@ -108,7 +108,7 @@ public: VerifyCertificate verifyCertificate; }; - TlsSession(SocketType& transport, const TlsParams& params, const TlsSessionCallbacks& cbs, + TlsSession(std::unique_ptr<SocketType>&& transport, const TlsParams& params, const TlsSessionCallbacks& cbs, bool anonymous=true); ~TlsSession();