Commit b25ecfef authored by Guillaume Roguez's avatar Guillaume Roguez Committed by Olivier SOLDANO

datatransfer: detect TCP RST event at initiator side

To dectect TCP RST event at initiator side this patch does
following actions:

* add waitForData() implementation everywhere
* forward transport errors by TLS session.
* use waitForData()/read() inside PeerImplementation eventloop
  to detect read() broken pipe error transmitted by TLS.
* ignore SIGPIPE signal (detected by read now) to not stop the application.

Change-Id: Ia5721e11ce52ba606a5395ecda3122b64f4afa6d
Reviewed-by: default avatarOlivier Soldano <olivier.soldano@savoirfairelinux.com>
parent 71771820
......@@ -171,6 +171,7 @@ signal_handler(int code)
signal(SIGHUP, SIG_DFL);
signal(SIGINT, SIG_DFL);
signal(SIGTERM, SIG_DFL);
signal(SIGPIPE, SIG_IGN);
// Interrupt the process
#if REST_API
......
......@@ -70,8 +70,7 @@ public:
std::streamsize bytesProgress(const DRing::DataTransferId& id) const;
/// Create an IncomingFileTransfer object.
/// \return a filename to open where incoming data will be written or an empty string
/// in case of refusal.
/// \return a shared pointer on created Stream object, or nullptr in case of error
std::shared_ptr<Stream> onIncomingFileRequest(const std::string& account_id,
const std::string& peer_uri,
const std::string& display_name,
......
......@@ -66,8 +66,13 @@ public:
/// this value gives the maximal size used to send one packet.
virtual int maxPayload() const = 0;
// TODO: make a std::chrono version
virtual bool waitForData(unsigned ms_timeout) const = 0;
/// Wait until data to read available, timeout or io error
/// \param ec error code set in case of error (if return value is < 0)
/// \return positive number if data ready for read, 0 in case of timeout or error.
/// \note error code is not set in case of timeout, but set only in case of io error
/// (i.e. socket deconnection).
/// \todo make a std::chrono version for the timeout
virtual int waitForData(unsigned ms_timeout, std::error_code& ec) const = 0;
/// Write a given amount of data.
/// \param buf data to write.
......
......@@ -79,7 +79,7 @@ public:
int maxPayload() const override;
bool waitForData(unsigned ms_timeout) const override;
int waitForData(unsigned ms_timeout, std::error_code& ec) const override;
std::size_t write(const ValueType* buf, std::size_t len, std::error_code& ec) override;
......
......@@ -1113,8 +1113,9 @@ IceTransport::waitForNegotiation(unsigned timeout)
}
ssize_t
IceTransport::waitForData(int comp_id, unsigned int timeout)
IceTransport::waitForData(int comp_id, unsigned int timeout, std::error_code& ec)
{
(void)ec; ///< \todo handle errors
auto& io = pimpl_->compIO_[comp_id];
std::unique_lock<std::mutex> lk(io.mutex);
if (!io.cv.wait_for(lk, std::chrono::milliseconds(timeout),
......@@ -1196,10 +1197,10 @@ IceSocketTransport::maxPayload() const
return STANDARD_MTU_SIZE - ip_header_size - UDP_HEADER_SIZE;
}
bool
IceSocketTransport::waitForData(unsigned ms_timeout) const
int
IceSocketTransport::waitForData(unsigned ms_timeout, std::error_code& ec) const
{
return ice_->waitForData(compId_, ms_timeout) > 0;
return ice_->waitForData(compId_, ms_timeout, ec);
}
std::size_t
......@@ -1268,7 +1269,8 @@ IceSocket::waitForData(unsigned int timeout)
if (!ice_transport_.get())
return -1;
return ice_transport_->waitForData(compId_, timeout);
std::error_code ec;
return ice_transport_->waitForData(compId_, timeout, ec);
}
void
......
......@@ -168,7 +168,7 @@ public:
int waitForNegotiation(unsigned timeout);
ssize_t waitForData(int comp_id, unsigned int timeout);
ssize_t waitForData(int comp_id, unsigned int timeout, std::error_code& ec);
unsigned getComponentCount() const;
......
......@@ -210,6 +210,12 @@ TlsTurnEndpoint::peerCertificate() const
return pimpl_->peerCertificate;
}
int
TlsTurnEndpoint::waitForData(unsigned ms_timeout, std::error_code& ec) const
{
return pimpl_->tls->waitForData(ms_timeout, ec);
}
//==============================================================================
TcpSocketEndpoint::TcpSocketEndpoint(const IpAddr& addr)
......@@ -236,23 +242,29 @@ TcpSocketEndpoint::connect()
throw std::system_error(errno, std::generic_category());
}
bool
TcpSocketEndpoint::waitForData(unsigned ms_timeout) const
{
struct timeval tv;
tv.tv_sec = ms_timeout / 1000;
tv.tv_usec = (ms_timeout % 1000) * 1000;
fd_set read_fds;
FD_ZERO(&read_fds);
FD_SET(sock_, &read_fds);
while (::select(sock_ + 1, &read_fds, nullptr, nullptr, &tv) >= 0) {
int
TcpSocketEndpoint::waitForData(unsigned ms_timeout, std::error_code& ec) const
{
for (;;) {
struct timeval tv;
tv.tv_sec = ms_timeout / 1000;
tv.tv_usec = (ms_timeout % 1000) * 1000;
fd_set read_fds;
FD_ZERO(&read_fds);
FD_SET(sock_, &read_fds);
auto res = ::select(sock_ + 1, &read_fds, nullptr, nullptr, &tv);
if (res < 0)
break;
if (res == 0)
return 0; // timeout
if (FD_ISSET(sock_, &read_fds))
return true;
return 1;
}
return false;
ec.assign(errno, std::generic_category());
return -1;
}
std::size_t
......@@ -392,6 +404,12 @@ TlsSocketEndpoint::connect()
pimpl_->tls->connect();
}
int
TlsSocketEndpoint::waitForData(unsigned ms_timeout, std::error_code& ec) const
{
return pimpl_->tls->waitForData(ms_timeout, ec);
}
//==============================================================================
// following namespace prevents an ODR violation with definitions in p2p.cpp
......@@ -440,12 +458,20 @@ struct AttachOutputCtrlMsg final : CtrlMsg
class PeerConnection::PeerConnectionImpl
{
public:
PeerConnectionImpl(Account& account, const std::string& peer_uri,
PeerConnectionImpl(std::function<void()>&& done,
Account& account, const std::string& peer_uri,
std::unique_ptr<SocketType> endpoint)
: account {account}
, peer_uri {peer_uri}
, endpoint_ {std::move(endpoint)}
, eventLoopFut_ {std::async(std::launch::async, [this]{ eventLoop();})} {}
, eventLoopFut_ {std::async(std::launch::async, [this, done=std::move(done)] {
try {
eventLoop();
} catch (const std::exception& e) {
RING_ERR() << "[CNX] peer connection event loop failure: " << e.what();
done();
}
})} {}
~PeerConnectionImpl() {
ctrlChannel << std::make_unique<StopCtrlMsg>();
......@@ -497,7 +523,18 @@ PeerConnection::PeerConnectionImpl::eventLoop()
while (true) {
std::unique_ptr<CtrlMsg> msg;
if (outputs_.empty() and inputs_.empty()) {
ctrlChannel >> msg;
if (!ctrlChannel.empty()) {
msg = ctrlChannel.receive();
} else {
std::error_code ec;
if (endpoint_->waitForData(100, ec) > 0) {
std::vector<uint8_t> buf(IO_BUFFER_SIZE);
endpoint_->read(buf, ec); ///< \todo what to do with data from a good read?
if (ec)
throw std::system_error(ec);
}
break;
}
} else if (!ctrlChannel.empty()) {
msg = ctrlChannel.receive();
} else
......@@ -551,9 +588,10 @@ PeerConnection::PeerConnectionImpl::eventLoop()
//==============================================================================
PeerConnection::PeerConnection(Account& account, const std::string& peer_uri,
PeerConnection::PeerConnection(std::function<void()>&& done, Account& account,
const std::string& peer_uri,
std::unique_ptr<GenericSocket<uint8_t>> endpoint)
: pimpl_(std::make_unique<PeerConnectionImpl>(account, peer_uri, std::move(endpoint)))
: pimpl_(std::make_unique<PeerConnectionImpl>(std::move(done), account, peer_uri, std::move(endpoint)))
{}
PeerConnection::~PeerConnection()
......
......@@ -94,9 +94,7 @@ public:
void setOnRecv(RecvCb&&) override {
throw std::logic_error("TlsTurnEndpoint::setOnRecv not implemented");
}
bool waitForData(unsigned) const override {
throw std::logic_error("TlsTurnEndpoint::waitForData not implemented");
}
int waitForData(unsigned, std::error_code&) const override;
void connect();
......@@ -120,7 +118,7 @@ public:
bool isReliable() const override { return true; }
bool isInitiator() const override { return true; }
int maxPayload() const override { return 1280; }
bool waitForData(unsigned ms_timeout) const override;
int waitForData(unsigned ms_timeout, std::error_code& ec) const override;
std::size_t read(ValueType* buf, std::size_t len, std::error_code& ec) override;
std::size_t write(const ValueType* buf, std::size_t len, std::error_code& ec) override;
......@@ -160,9 +158,7 @@ public:
void setOnRecv(RecvCb&&) override {
throw std::logic_error("TlsSocketEndpoint::setOnRecv not implemented");
}
bool waitForData(unsigned) const override {
throw std::logic_error("TlsSocketEndpoint::waitForData not implemented");
}
int waitForData(unsigned, std::error_code&) const override;
void connect();
......@@ -178,7 +174,7 @@ class PeerConnection
public:
using SocketType = GenericSocket<uint8_t>;
PeerConnection(Account& account, const std::string& peer_uri,
PeerConnection(std::function<void()>&& done, Account& account, const std::string& peer_uri,
std::unique_ptr<SocketType> endpoint);
~PeerConnection();
......
......@@ -332,8 +332,8 @@ private:
tls_ep->connect();
// Connected!
connection_ = std::make_unique<PeerConnection>(parent_.account, peer_.toString(),
std::move(tls_ep));
connection_ = std::make_unique<PeerConnection>([this] { cancel(); }, parent_.account,
peer_.toString(), std::move(tls_ep));
peer_ep_ = std::move(peer_ep);
connected_ = true;
......@@ -435,7 +435,8 @@ DhtPeerConnector::Impl::onTurnPeerConnection(const IpAddr& peer_addr)
RING_DBG() << account << "[CNX] Accepted TLS-TURN connection from RingID " << peer_h;
connectedPeers_.emplace(peer_addr, tls_ep->peerCertificate().getId());
auto connection = std::make_unique<PeerConnection>(account, peer_addr.toString(), std::move(tls_ep));
auto connection = std::make_unique<PeerConnection>([] {}, account, peer_addr.toString(),
std::move(tls_ep));
connection->attachOutputStream(std::make_shared<FtpServer>(account.getAccountID(), peer_h.toString()));
servers_.emplace(peer_addr, std::move(connection));
......
......@@ -642,7 +642,8 @@ int
TlsSession::TlsSessionImpl::waitForRawData(unsigned timeout)
{
if (transport_.isReliable()) {
if (not transport_.waitForData(timeout)) {
std::error_code ec;
if (transport_.waitForData(timeout, ec) <= 0) {
// shutdown?
if (state_ == TlsSessionState::SHUTDOWN) {
gnutls_transport_set_errno(session_, EINTR);
......@@ -1069,9 +1070,14 @@ TlsSession::TlsSessionImpl::handleStateEstablished(TlsSessionState state)
{
// Nothing to do in reliable mode, so just wait for state change
if (transport_.isReliable()) {
std::unique_lock<std::mutex> lk {rxMutex_};
rxCv_.wait(lk, [this]{ return state_ != TlsSessionState::ESTABLISHED; });
return state;
std::error_code ec;
do {
transport_.waitForData(100, ec);
state = state_.load();
if (state != TlsSessionState::ESTABLISHED)
return state;
} while (!ec);
return TlsSessionState::SHUTDOWN;
}
// block until rx packet or state change
......@@ -1276,4 +1282,12 @@ TlsSession::connect()
} while (state != TlsSessionState::ESTABLISHED and state != TlsSessionState::SHUTDOWN);
}
int
TlsSession::waitForData(unsigned ms_timeout, std::error_code& ec) const
{
if (!pimpl_->transport_.waitForData(ms_timeout, ec))
return 0;
return 1;
}
}} // namespace ring::tls
......@@ -138,9 +138,7 @@ public:
/// Return a positive number for number of bytes read, or 0 and \a ec set in case of error.
std::size_t read(ValueType* data, std::size_t size, std::error_code& ec) override;
bool waitForData(unsigned) const override {
throw std::logic_error("TlsSession::waitForData not implemented");
}
int waitForData(unsigned, std::error_code&) const override;
private:
class TlsSessionImpl;
......
......@@ -456,9 +456,10 @@ TurnTransport::peerAddresses() const
return map_utils::extractKeys(pimpl_->peerChannels_);
}
bool
TurnTransport::waitForData(const IpAddr& peer, unsigned ms_timeout) const
int
TurnTransport::waitForData(const IpAddr& peer, unsigned ms_timeout, std::error_code& ec) const
{
(void)ec; ///< \todo handle errors
MutexLock lk {pimpl_->apiMutex_};
auto& channel = pimpl_->peerChannels_.at(peer);
lk.unlock();
......@@ -478,10 +479,10 @@ ConnectedTurnTransport::shutdown()
turn_.shutdown(peer_);
}
bool
ConnectedTurnTransport::waitForData(unsigned ms_timeout) const
int
ConnectedTurnTransport::waitForData(unsigned ms_timeout, std::error_code& ec) const
{
return turn_.waitForData(peer_, ms_timeout);
return turn_.waitForData(peer_, ms_timeout, ec);
}
std::size_t
......
......@@ -133,7 +133,7 @@ public:
///
bool sendto(const IpAddr& peer, const char* const buffer, std::size_t size);
bool waitForData(const IpAddr& peer, unsigned ms_timeout) const;
int waitForData(const IpAddr& peer, unsigned ms_timeout, std::error_code& ec) const;
public:
// Move semantic only, not copiable
......@@ -157,7 +157,7 @@ public:
bool isInitiator() const override { return turn_.isInitiator(); }
int maxPayload() const override { return 3000; }
bool waitForData(unsigned ms_timeout) const override;
int waitForData(unsigned ms_timeout, std::error_code& ec) const override;
std::size_t read(ValueType* buf, std::size_t length, std::error_code& ec) override;
std::size_t write(const ValueType* buf, std::size_t length, std::error_code& ec) override;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment