diff --git a/src/generic_io.h b/src/generic_io.h index 8e0ae80eb9d2a5c16040f0b7b6224215d72f012e..097e71501ab99cd22e0fa4bcb753d03108f162ba 100644 --- a/src/generic_io.h +++ b/src/generic_io.h @@ -38,10 +38,14 @@ class GenericSocket public: using ValueType = T; - virtual ~GenericSocket() = default; + virtual ~GenericSocket() { shutdown(); } using RecvCb = std::function<ssize_t(const ValueType* buf, std::size_t len)>; + /// Close established connection + /// \note Terminate oustanding blocking read operations with an empty error code, but a 0 read size. + virtual void shutdown() {} + /// Set Rx callback /// \warning This method is here for backward compatibility /// and because async IO are not implemented yet. diff --git a/src/peer_connection.cpp b/src/peer_connection.cpp index 4b0fa7b81888bbff5e28ff86df271c11d131935f..bd62da9a0e03a3bb7dc1bfe0f8983469ff8a9824 100644 --- a/src/peer_connection.cpp +++ b/src/peer_connection.cpp @@ -174,6 +174,12 @@ TlsTurnEndpoint::TlsTurnEndpoint(ConnectedTurnTransport& turn_ep, TlsTurnEndpoint::~TlsTurnEndpoint() = default; +void +TlsTurnEndpoint::shutdown() +{ + pimpl_->tls->shutdown(); +} + bool TlsTurnEndpoint::isInitiator() const { @@ -444,6 +450,8 @@ public: ~PeerConnectionImpl() { ctrlChannel << std::make_unique<StopCtrlMsg>(); + endpoint_->shutdown(); + } const Account& account; @@ -514,9 +522,6 @@ PeerConnection::PeerConnectionImpl::eventLoop() break; case CtrlMsgType::STOP: - endpoint_.reset(); - inputs_.clear(); - outputs_.clear(); return; default: RING_ERR("BUG: got unhandled control msg!"); break; @@ -526,7 +531,7 @@ PeerConnection::PeerConnectionImpl::eventLoop() // Then handles IO streams std::vector<uint8_t> buf(IO_BUFFER_SIZE); std::error_code ec; - handle_stream_list(inputs_, [&](auto& stream){ + handle_stream_list(inputs_, [&](auto& stream) { if (!stream->read(buf)) return false; auto size = endpoint_->write(buf, ec); @@ -536,9 +541,11 @@ PeerConnection::PeerConnectionImpl::eventLoop() return false; throw std::system_error(ec); }); - handle_stream_list(outputs_, [&](auto& stream){ - endpoint_->read(buf, ec); - return buf.size() != 0 and stream->write(buf); + handle_stream_list(outputs_, [&](auto& stream) { + auto size = endpoint_->read(buf, ec); + if (!ec) + return size > 0 and stream->write(buf); + throw std::system_error(ec); }); } } diff --git a/src/peer_connection.h b/src/peer_connection.h index 289ff0bbc7bc84cebddd73643e1d4c7d47986521..bdf4c517041ebc14c8897ccce2edbb38ea2e2913 100644 --- a/src/peer_connection.h +++ b/src/peer_connection.h @@ -79,6 +79,7 @@ public: dht::crypto::TrustList& trust_list); ~TlsTurnEndpoint(); + void shutdown() override; bool isReliable() const override { return true; } bool isInitiator() const override; int maxPayload() const override; diff --git a/src/ringdht/p2p.cpp b/src/ringdht/p2p.cpp index bd652699a76d40ff3f3a89f1db85a533bfce8a8d..95ab147cabeb57eada4afcb8cb47ab3b493200b6 100644 --- a/src/ringdht/p2p.cpp +++ b/src/ringdht/p2p.cpp @@ -190,7 +190,12 @@ public: : account {account} , loopFut_ {std::async(std::launch::async, [this]{ eventLoop(); })} {} - ~Impl() { ctrl << makeMsg<CtrlMsgType::STOP>(); } + ~Impl() { + servers_.clear(); + clients_.clear(); + turn_.reset(); + ctrl << makeMsg<CtrlMsgType::STOP>(); + } RingAccount& account; Channel<std::unique_ptr<CtrlMsgBase>> ctrl; @@ -208,8 +213,7 @@ private: void onTurnPeerConnection(const IpAddr&); void onTurnPeerDisconnection(const IpAddr&); void onRequestMsg(PeerConnectionMsg&&); - void onTrustedRequestMsg(PeerConnectionMsg&&, - const std::shared_ptr<dht::crypto::Certificate>&); + void onTrustedRequestMsg(PeerConnectionMsg&&, const std::shared_ptr<dht::crypto::Certificate>&); void onResponseMsg(PeerConnectionMsg&&); void onAddDevice(const dht::InfoHash&, const std::shared_ptr<dht::crypto::Certificate>&, @@ -486,11 +490,10 @@ DhtPeerConnector::Impl::eventLoop() { // Loop until STOP msg while (true) { - decltype(ctrl)::value_type msg; + std::unique_ptr<CtrlMsgBase> msg; ctrl >> msg; switch (msg->type()) { case CtrlMsgType::STOP: - turn_.reset(); return; case CtrlMsgType::TURN_PEER_CONNECT: diff --git a/src/security/tls_session.cpp b/src/security/tls_session.cpp index e026b73012680a971183e4131da04af46e915b52..e1faa6e2498c391305858d09eed7e09c1712fd70 100644 --- a/src/security/tls_session.cpp +++ b/src/security/tls_session.cpp @@ -1195,6 +1195,7 @@ TlsSession::shutdown() { pimpl_->state_ = TlsSessionState::SHUTDOWN; pimpl_->rxCv_.notify_one(); // unblock waiting FSM + pimpl_->transport_.shutdown(); } std::size_t diff --git a/src/turn_transport.cpp b/src/turn_transport.cpp index 4d70fc3777f666c7f0ee3c2cb13c4ac537ac9fdb..16b26b39d4afa0bdc3443545c382509bbd10d7bb 100644 --- a/src/turn_transport.cpp +++ b/src/turn_transport.cpp @@ -60,15 +60,14 @@ class PeerChannel public: PeerChannel() {} ~PeerChannel() { - MutexGuard lk {mutex_}; - stop_ = true; - cv_.notify_all(); + stop(); } PeerChannel(PeerChannel&&o) { MutexGuard lk {o.mutex_}; stream_ = std::move(o.stream_); } + PeerChannel& operator =(PeerChannel&& o) { std::lock(mutex_, o.mutex_); MutexGuard lk1 {mutex_, std::adopt_lock}; @@ -86,22 +85,42 @@ public: template <typename Duration> bool wait(Duration timeout) { - MutexLock lk {mutex_}; - return cv_.wait_for(lk, timeout, [this]{ return !stream_.eof(); }); + std::lock(apiMutex_, mutex_); + MutexGuard lk_api {apiMutex_, std::adopt_lock}; + MutexLock lk {mutex_, std::adopt_lock}; + return cv_.wait_for(lk, timeout, [this]{ return stop_ or !stream_.eof(); }); } std::size_t read(char* output, std::size_t size) { - MutexLock lk {mutex_}; + std::lock(apiMutex_, mutex_); + MutexGuard lk_api {apiMutex_, std::adopt_lock}; + MutexLock lk {mutex_, std::adopt_lock}; cv_.wait(lk, [&, this]{ + if (stop_) + return true; stream_.read(&output[0], size); - return stream_.gcount() > 0 or stop_; + return stream_.gcount() > 0; }); return stop_ ? 0 : stream_.gcount(); } + void stop() noexcept { + { + MutexGuard lk {mutex_}; + if (stop_) + return; + stop_ = true; + } + cv_.notify_all(); + + // Make sure that no thread is blocked into read() or wait() methods + MutexGuard lk_api {apiMutex_}; + } + private: PeerChannel(const PeerChannel&o) = delete; PeerChannel& operator =(const PeerChannel& o) = delete; + std::mutex apiMutex_ {}; std::mutex mutex_ {}; std::condition_variable cv_ {}; std::stringstream stream_ {}; @@ -167,14 +186,18 @@ public: TurnTransportPimpl::~TurnTransportPimpl() { - if (relay) - pj_turn_sock_destroy(relay); + if (relay) { + try { + pj_turn_sock_destroy(relay); + } catch (...) { + RING_ERR() << "exception during pj_turn_sock_destroy() call (ignored)"; + } + } ioJobQuit = true; if (ioWorker.joinable()) ioWorker.join(); - if (pool) - pj_pool_release(pool); pj_caching_pool_destroy(&poolCache); + } void @@ -286,19 +309,19 @@ TurnTransport::TurnTransport(const TurnTransportParams& params) pj_bzero(&relay_cb, sizeof(relay_cb)); relay_cb.on_rx_data = [](pj_turn_sock* relay, void* pkt, unsigned pkt_len, const pj_sockaddr_t* peer_addr, unsigned addr_len) { - auto tr = static_cast<TurnTransport*>(pj_turn_sock_get_user_data(relay)); - tr->pimpl_->onRxData(reinterpret_cast<uint8_t*>(pkt), pkt_len, peer_addr, addr_len); + auto pimpl = static_cast<TurnTransportPimpl*>(pj_turn_sock_get_user_data(relay)); + pimpl->onRxData(reinterpret_cast<uint8_t*>(pkt), pkt_len, peer_addr, addr_len); }; relay_cb.on_state = [](pj_turn_sock* relay, pj_turn_state_t old_state, pj_turn_state_t new_state) { - auto tr = static_cast<TurnTransport*>(pj_turn_sock_get_user_data(relay)); - tr->pimpl_->onTurnState(old_state, new_state); + auto pimpl = static_cast<TurnTransportPimpl*>(pj_turn_sock_get_user_data(relay)); + pimpl->onTurnState(old_state, new_state); }; relay_cb.on_peer_connection = [](pj_turn_sock* relay, pj_uint32_t conn_id, const pj_sockaddr_t* peer_addr, unsigned addr_len, pj_status_t status) { - auto tr = static_cast<TurnTransport*>(pj_turn_sock_get_user_data(relay)); - tr->pimpl_->onPeerConnection(conn_id, peer_addr, addr_len, status); + auto pimpl = static_cast<TurnTransportPimpl*>(pj_turn_sock_get_user_data(relay)); + pimpl->onPeerConnection(conn_id, peer_addr, addr_len, status); }; // TURN socket config @@ -309,7 +332,7 @@ TurnTransport::TurnTransport(const TurnTransportParams& params) // TURN socket creation PjsipCall(pj_turn_sock_create, &pimpl_->stunConfig, server.getFamily(), PJ_TURN_TP_TCP, - &relay_cb, &turn_sock_cfg, this, &pimpl_->relay); + &relay_cb, &turn_sock_cfg, &*this->pimpl_, &pimpl_->relay); // TURN allocation setup pj_turn_alloc_param turn_alloc_param; @@ -335,8 +358,16 @@ TurnTransport::TurnTransport(const TurnTransportParams& params) nullptr, &cred, &turn_alloc_param); } -TurnTransport::~TurnTransport() -{} +TurnTransport::~TurnTransport() = default; + +void +TurnTransport::shutdown(const IpAddr& addr) +{ + MutexLock lk {pimpl_->apiMutex_}; + auto& channel = pimpl_->peerChannels_.at(addr); + lk.unlock(); + channel.stop(); +} bool TurnTransport::isInitiator() const @@ -441,6 +472,12 @@ ConnectedTurnTransport::ConnectedTurnTransport(TurnTransport& turn, const IpAddr , peer_ {peer} {} +void +ConnectedTurnTransport::shutdown() +{ + turn_.shutdown(peer_); +} + bool ConnectedTurnTransport::waitForData(unsigned ms_timeout) const { diff --git a/src/turn_transport.h b/src/turn_transport.h index 6f2d87984cd6fe16db6d50867bab6cee98c37136..c655f6fc32d40d2c67e46eec2a8b5e641ca0e6c8 100644 --- a/src/turn_transport.h +++ b/src/turn_transport.h @@ -64,6 +64,8 @@ public: ~TurnTransport(); + void shutdown(const IpAddr& addr); + bool isInitiator() const; /// Wait for successful connection on the TURN server. @@ -150,6 +152,7 @@ public: ConnectedTurnTransport(TurnTransport& turn, const IpAddr& peer); + void shutdown() override; bool isReliable() const override { return true; } bool isInitiator() const override { return turn_.isInitiator(); } int maxPayload() const override { return 3000; }