From b25ecfefc49a63792d6240a9151a6dfdf6c0f9a3 Mon Sep 17 00:00:00 2001 From: Guillaume Roguez <guillaume.roguez@savoirfairelinux.com> Date: Mon, 29 Jan 2018 11:44:13 -0500 Subject: [PATCH] =?UTF-8?q?datatransfer:=20detect=20TCP=C2=A0RST=20event?= =?UTF-8?q?=20at=20initiator=20side?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit To dectect TCP RST event at initiator side this patch does following actions: * add waitForData() implementation everywhere * forward transport errors by TLS session. * use waitForData()/read() inside PeerImplementation eventloop to detect read() broken pipe error transmitted by TLS. * ignore SIGPIPE signal (detected by read now) to not stop the application. Change-Id: Ia5721e11ce52ba606a5395ecda3122b64f4afa6d Reviewed-by: Olivier Soldano <olivier.soldano@savoirfairelinux.com> --- bin/main.cpp | 1 + src/data_transfer.h | 3 +- src/generic_io.h | 9 ++++- src/ice_socket.h | 2 +- src/ice_transport.cpp | 12 +++--- src/ice_transport.h | 2 +- src/peer_connection.cpp | 76 +++++++++++++++++++++++++++--------- src/peer_connection.h | 12 ++---- src/ringdht/p2p.cpp | 7 ++-- src/security/tls_session.cpp | 22 +++++++++-- src/security/tls_session.h | 4 +- src/turn_transport.cpp | 11 +++--- src/turn_transport.h | 4 +- 13 files changed, 110 insertions(+), 55 deletions(-) diff --git a/bin/main.cpp b/bin/main.cpp index 5cada71531..9229c4f928 100644 --- a/bin/main.cpp +++ b/bin/main.cpp @@ -171,6 +171,7 @@ signal_handler(int code) signal(SIGHUP, SIG_DFL); signal(SIGINT, SIG_DFL); signal(SIGTERM, SIG_DFL); + signal(SIGPIPE, SIG_IGN); // Interrupt the process #if REST_API diff --git a/src/data_transfer.h b/src/data_transfer.h index 2b42ccd351..6a74bff6c6 100644 --- a/src/data_transfer.h +++ b/src/data_transfer.h @@ -70,8 +70,7 @@ public: std::streamsize bytesProgress(const DRing::DataTransferId& id) const; /// Create an IncomingFileTransfer object. - /// \return a filename to open where incoming data will be written or an empty string - /// in case of refusal. + /// \return a shared pointer on created Stream object, or nullptr in case of error std::shared_ptr<Stream> onIncomingFileRequest(const std::string& account_id, const std::string& peer_uri, const std::string& display_name, diff --git a/src/generic_io.h b/src/generic_io.h index 4147959c92..e1e62e7f5a 100644 --- a/src/generic_io.h +++ b/src/generic_io.h @@ -66,8 +66,13 @@ public: /// this value gives the maximal size used to send one packet. virtual int maxPayload() const = 0; - // TODO: make a std::chrono version - virtual bool waitForData(unsigned ms_timeout) const = 0; + /// Wait until data to read available, timeout or io error + /// \param ec error code set in case of error (if return value is < 0) + /// \return positive number if data ready for read, 0 in case of timeout or error. + /// \note error code is not set in case of timeout, but set only in case of io error + /// (i.e. socket deconnection). + /// \todo make a std::chrono version for the timeout + virtual int waitForData(unsigned ms_timeout, std::error_code& ec) const = 0; /// Write a given amount of data. /// \param buf data to write. diff --git a/src/ice_socket.h b/src/ice_socket.h index b05ae3f6a0..b131362d29 100644 --- a/src/ice_socket.h +++ b/src/ice_socket.h @@ -79,7 +79,7 @@ public: int maxPayload() const override; - bool waitForData(unsigned ms_timeout) const override; + int waitForData(unsigned ms_timeout, std::error_code& ec) const override; std::size_t write(const ValueType* buf, std::size_t len, std::error_code& ec) override; diff --git a/src/ice_transport.cpp b/src/ice_transport.cpp index 4e02c26378..d7140adbc5 100644 --- a/src/ice_transport.cpp +++ b/src/ice_transport.cpp @@ -1113,8 +1113,9 @@ IceTransport::waitForNegotiation(unsigned timeout) } ssize_t -IceTransport::waitForData(int comp_id, unsigned int timeout) +IceTransport::waitForData(int comp_id, unsigned int timeout, std::error_code& ec) { + (void)ec; ///< \todo handle errors auto& io = pimpl_->compIO_[comp_id]; std::unique_lock<std::mutex> lk(io.mutex); if (!io.cv.wait_for(lk, std::chrono::milliseconds(timeout), @@ -1196,10 +1197,10 @@ IceSocketTransport::maxPayload() const return STANDARD_MTU_SIZE - ip_header_size - UDP_HEADER_SIZE; } -bool -IceSocketTransport::waitForData(unsigned ms_timeout) const +int +IceSocketTransport::waitForData(unsigned ms_timeout, std::error_code& ec) const { - return ice_->waitForData(compId_, ms_timeout) > 0; + return ice_->waitForData(compId_, ms_timeout, ec); } std::size_t @@ -1268,7 +1269,8 @@ IceSocket::waitForData(unsigned int timeout) if (!ice_transport_.get()) return -1; - return ice_transport_->waitForData(compId_, timeout); + std::error_code ec; + return ice_transport_->waitForData(compId_, timeout, ec); } void diff --git a/src/ice_transport.h b/src/ice_transport.h index 836ca585b3..c795970834 100644 --- a/src/ice_transport.h +++ b/src/ice_transport.h @@ -168,7 +168,7 @@ public: int waitForNegotiation(unsigned timeout); - ssize_t waitForData(int comp_id, unsigned int timeout); + ssize_t waitForData(int comp_id, unsigned int timeout, std::error_code& ec); unsigned getComponentCount() const; diff --git a/src/peer_connection.cpp b/src/peer_connection.cpp index 857aea5bed..4c39eb67ea 100644 --- a/src/peer_connection.cpp +++ b/src/peer_connection.cpp @@ -210,6 +210,12 @@ TlsTurnEndpoint::peerCertificate() const return pimpl_->peerCertificate; } +int +TlsTurnEndpoint::waitForData(unsigned ms_timeout, std::error_code& ec) const +{ + return pimpl_->tls->waitForData(ms_timeout, ec); +} + //============================================================================== TcpSocketEndpoint::TcpSocketEndpoint(const IpAddr& addr) @@ -236,23 +242,29 @@ TcpSocketEndpoint::connect() throw std::system_error(errno, std::generic_category()); } -bool -TcpSocketEndpoint::waitForData(unsigned ms_timeout) const -{ - struct timeval tv; - tv.tv_sec = ms_timeout / 1000; - tv.tv_usec = (ms_timeout % 1000) * 1000; - - fd_set read_fds; - FD_ZERO(&read_fds); - FD_SET(sock_, &read_fds); - - while (::select(sock_ + 1, &read_fds, nullptr, nullptr, &tv) >= 0) { +int +TcpSocketEndpoint::waitForData(unsigned ms_timeout, std::error_code& ec) const +{ + for (;;) { + struct timeval tv; + tv.tv_sec = ms_timeout / 1000; + tv.tv_usec = (ms_timeout % 1000) * 1000; + + fd_set read_fds; + FD_ZERO(&read_fds); + FD_SET(sock_, &read_fds); + + auto res = ::select(sock_ + 1, &read_fds, nullptr, nullptr, &tv); + if (res < 0) + break; + if (res == 0) + return 0; // timeout if (FD_ISSET(sock_, &read_fds)) - return true; + return 1; } - return false; + ec.assign(errno, std::generic_category()); + return -1; } std::size_t @@ -392,6 +404,12 @@ TlsSocketEndpoint::connect() pimpl_->tls->connect(); } +int +TlsSocketEndpoint::waitForData(unsigned ms_timeout, std::error_code& ec) const +{ + return pimpl_->tls->waitForData(ms_timeout, ec); +} + //============================================================================== // following namespace prevents an ODR violation with definitions in p2p.cpp @@ -440,12 +458,20 @@ struct AttachOutputCtrlMsg final : CtrlMsg class PeerConnection::PeerConnectionImpl { public: - PeerConnectionImpl(Account& account, const std::string& peer_uri, + PeerConnectionImpl(std::function<void()>&& done, + Account& account, const std::string& peer_uri, std::unique_ptr<SocketType> endpoint) : account {account} , peer_uri {peer_uri} , endpoint_ {std::move(endpoint)} - , eventLoopFut_ {std::async(std::launch::async, [this]{ eventLoop();})} {} + , eventLoopFut_ {std::async(std::launch::async, [this, done=std::move(done)] { + try { + eventLoop(); + } catch (const std::exception& e) { + RING_ERR() << "[CNX] peer connection event loop failure: " << e.what(); + done(); + } + })} {} ~PeerConnectionImpl() { ctrlChannel << std::make_unique<StopCtrlMsg>(); @@ -497,7 +523,18 @@ PeerConnection::PeerConnectionImpl::eventLoop() while (true) { std::unique_ptr<CtrlMsg> msg; if (outputs_.empty() and inputs_.empty()) { - ctrlChannel >> msg; + if (!ctrlChannel.empty()) { + msg = ctrlChannel.receive(); + } else { + std::error_code ec; + if (endpoint_->waitForData(100, ec) > 0) { + std::vector<uint8_t> buf(IO_BUFFER_SIZE); + endpoint_->read(buf, ec); ///< \todo what to do with data from a good read? + if (ec) + throw std::system_error(ec); + } + break; + } } else if (!ctrlChannel.empty()) { msg = ctrlChannel.receive(); } else @@ -551,9 +588,10 @@ PeerConnection::PeerConnectionImpl::eventLoop() //============================================================================== -PeerConnection::PeerConnection(Account& account, const std::string& peer_uri, +PeerConnection::PeerConnection(std::function<void()>&& done, Account& account, + const std::string& peer_uri, std::unique_ptr<GenericSocket<uint8_t>> endpoint) - : pimpl_(std::make_unique<PeerConnectionImpl>(account, peer_uri, std::move(endpoint))) + : pimpl_(std::make_unique<PeerConnectionImpl>(std::move(done), account, peer_uri, std::move(endpoint))) {} PeerConnection::~PeerConnection() diff --git a/src/peer_connection.h b/src/peer_connection.h index 8901d5837e..af3d6ad2ad 100644 --- a/src/peer_connection.h +++ b/src/peer_connection.h @@ -94,9 +94,7 @@ public: void setOnRecv(RecvCb&&) override { throw std::logic_error("TlsTurnEndpoint::setOnRecv not implemented"); } - bool waitForData(unsigned) const override { - throw std::logic_error("TlsTurnEndpoint::waitForData not implemented"); - } + int waitForData(unsigned, std::error_code&) const override; void connect(); @@ -120,7 +118,7 @@ public: bool isReliable() const override { return true; } bool isInitiator() const override { return true; } int maxPayload() const override { return 1280; } - bool waitForData(unsigned ms_timeout) const override; + int waitForData(unsigned ms_timeout, std::error_code& ec) const override; std::size_t read(ValueType* buf, std::size_t len, std::error_code& ec) override; std::size_t write(const ValueType* buf, std::size_t len, std::error_code& ec) override; @@ -160,9 +158,7 @@ public: void setOnRecv(RecvCb&&) override { throw std::logic_error("TlsSocketEndpoint::setOnRecv not implemented"); } - bool waitForData(unsigned) const override { - throw std::logic_error("TlsSocketEndpoint::waitForData not implemented"); - } + int waitForData(unsigned, std::error_code&) const override; void connect(); @@ -178,7 +174,7 @@ class PeerConnection public: using SocketType = GenericSocket<uint8_t>; - PeerConnection(Account& account, const std::string& peer_uri, + PeerConnection(std::function<void()>&& done, Account& account, const std::string& peer_uri, std::unique_ptr<SocketType> endpoint); ~PeerConnection(); diff --git a/src/ringdht/p2p.cpp b/src/ringdht/p2p.cpp index b72176d729..496bc9f640 100644 --- a/src/ringdht/p2p.cpp +++ b/src/ringdht/p2p.cpp @@ -332,8 +332,8 @@ private: tls_ep->connect(); // Connected! - connection_ = std::make_unique<PeerConnection>(parent_.account, peer_.toString(), - std::move(tls_ep)); + connection_ = std::make_unique<PeerConnection>([this] { cancel(); }, parent_.account, + peer_.toString(), std::move(tls_ep)); peer_ep_ = std::move(peer_ep); connected_ = true; @@ -435,7 +435,8 @@ DhtPeerConnector::Impl::onTurnPeerConnection(const IpAddr& peer_addr) RING_DBG() << account << "[CNX] Accepted TLS-TURN connection from RingID " << peer_h; connectedPeers_.emplace(peer_addr, tls_ep->peerCertificate().getId()); - auto connection = std::make_unique<PeerConnection>(account, peer_addr.toString(), std::move(tls_ep)); + auto connection = std::make_unique<PeerConnection>([] {}, account, peer_addr.toString(), + std::move(tls_ep)); connection->attachOutputStream(std::make_shared<FtpServer>(account.getAccountID(), peer_h.toString())); servers_.emplace(peer_addr, std::move(connection)); diff --git a/src/security/tls_session.cpp b/src/security/tls_session.cpp index eaee14128b..761f725499 100644 --- a/src/security/tls_session.cpp +++ b/src/security/tls_session.cpp @@ -642,7 +642,8 @@ int TlsSession::TlsSessionImpl::waitForRawData(unsigned timeout) { if (transport_.isReliable()) { - if (not transport_.waitForData(timeout)) { + std::error_code ec; + if (transport_.waitForData(timeout, ec) <= 0) { // shutdown? if (state_ == TlsSessionState::SHUTDOWN) { gnutls_transport_set_errno(session_, EINTR); @@ -1069,9 +1070,14 @@ TlsSession::TlsSessionImpl::handleStateEstablished(TlsSessionState state) { // Nothing to do in reliable mode, so just wait for state change if (transport_.isReliable()) { - std::unique_lock<std::mutex> lk {rxMutex_}; - rxCv_.wait(lk, [this]{ return state_ != TlsSessionState::ESTABLISHED; }); - return state; + std::error_code ec; + do { + transport_.waitForData(100, ec); + state = state_.load(); + if (state != TlsSessionState::ESTABLISHED) + return state; + } while (!ec); + return TlsSessionState::SHUTDOWN; } // block until rx packet or state change @@ -1276,4 +1282,12 @@ TlsSession::connect() } while (state != TlsSessionState::ESTABLISHED and state != TlsSessionState::SHUTDOWN); } +int +TlsSession::waitForData(unsigned ms_timeout, std::error_code& ec) const +{ + if (!pimpl_->transport_.waitForData(ms_timeout, ec)) + return 0; + return 1; +} + }} // namespace ring::tls diff --git a/src/security/tls_session.h b/src/security/tls_session.h index 15ffe5a42f..750ce505f5 100644 --- a/src/security/tls_session.h +++ b/src/security/tls_session.h @@ -138,9 +138,7 @@ public: /// Return a positive number for number of bytes read, or 0 and \a ec set in case of error. std::size_t read(ValueType* data, std::size_t size, std::error_code& ec) override; - bool waitForData(unsigned) const override { - throw std::logic_error("TlsSession::waitForData not implemented"); - } + int waitForData(unsigned, std::error_code&) const override; private: class TlsSessionImpl; diff --git a/src/turn_transport.cpp b/src/turn_transport.cpp index 8c67dfa3a4..7c27052390 100644 --- a/src/turn_transport.cpp +++ b/src/turn_transport.cpp @@ -456,9 +456,10 @@ TurnTransport::peerAddresses() const return map_utils::extractKeys(pimpl_->peerChannels_); } -bool -TurnTransport::waitForData(const IpAddr& peer, unsigned ms_timeout) const +int +TurnTransport::waitForData(const IpAddr& peer, unsigned ms_timeout, std::error_code& ec) const { + (void)ec; ///< \todo handle errors MutexLock lk {pimpl_->apiMutex_}; auto& channel = pimpl_->peerChannels_.at(peer); lk.unlock(); @@ -478,10 +479,10 @@ ConnectedTurnTransport::shutdown() turn_.shutdown(peer_); } -bool -ConnectedTurnTransport::waitForData(unsigned ms_timeout) const +int +ConnectedTurnTransport::waitForData(unsigned ms_timeout, std::error_code& ec) const { - return turn_.waitForData(peer_, ms_timeout); + return turn_.waitForData(peer_, ms_timeout, ec); } std::size_t diff --git a/src/turn_transport.h b/src/turn_transport.h index 3101f94ec6..46424deb88 100644 --- a/src/turn_transport.h +++ b/src/turn_transport.h @@ -133,7 +133,7 @@ public: /// bool sendto(const IpAddr& peer, const char* const buffer, std::size_t size); - bool waitForData(const IpAddr& peer, unsigned ms_timeout) const; + int waitForData(const IpAddr& peer, unsigned ms_timeout, std::error_code& ec) const; public: // Move semantic only, not copiable @@ -157,7 +157,7 @@ public: bool isInitiator() const override { return turn_.isInitiator(); } int maxPayload() const override { return 3000; } - bool waitForData(unsigned ms_timeout) const override; + int waitForData(unsigned ms_timeout, std::error_code& ec) const override; std::size_t read(ValueType* buf, std::size_t length, std::error_code& ec) override; std::size_t write(const ValueType* buf, std::size_t length, std::error_code& ec) override; -- GitLab