diff --git a/src/jamidht/multiplexed_socket.cpp b/src/jamidht/multiplexed_socket.cpp index fdf217b53c7afbf92017db3999c0aa49b9dcadcc..80dada485f2f0df8b17759ad4d7e31789e90130d 100644 --- a/src/jamidht/multiplexed_socket.cpp +++ b/src/jamidht/multiplexed_socket.cpp @@ -53,14 +53,6 @@ namespace jami { using clock = std::chrono::steady_clock; using time_point = clock::time_point; -struct ChannelInfo -{ - std::atomic_bool isDestroying {false}; - std::deque<uint8_t> buf {}; - std::mutex mutex {}; - std::condition_variable cv {}; -}; - class MultiplexedSocket::Impl { public: @@ -99,12 +91,6 @@ public: { std::lock_guard<std::mutex> lkSockets(socketsMutex); socks = std::move(sockets); - for (auto& [key, channelData] : channelDatas_) { - if (channelData) { - channelData->isDestroying = true; - channelData->cv.notify_all(); - } - } } for (auto& socket : socks) { // Just trigger onShutdown() to make client know @@ -132,6 +118,13 @@ public: clearSockets(); } + std::shared_ptr<ChannelSocket> makeSocket(const std::string& name, uint16_t channel) { + auto& channelSocket = sockets[channel]; + if (not channelSocket) + channelSocket = std::make_shared<ChannelSocket>(parent_.weak(), name, channel); + return channelSocket; + } + /** * Handle packets on the TLS endpoint and parse RTP */ @@ -182,10 +175,6 @@ public: std::atomic_bool stop {false}; std::thread eventLoopThread_ {}; - // Multiplexed available datas - std::map<uint16_t, std::shared_ptr<ChannelInfo>> channelDatas_ {}; - std::mutex channelCbsMtx_ {}; - std::map<uint16_t, GenericSocket<uint8_t>::RecvCb> channelCbs_ {}; std::atomic_bool isShutdown_ {false}; std::mutex writeMtx {}; @@ -258,13 +247,7 @@ void MultiplexedSocket::Impl::onAccept(const std::string& name, uint16_t channel) { std::lock_guard<std::mutex> lkSockets(socketsMutex); - auto& channelData = channelDatas_[channel]; - if (not channelData) - channelData = std::make_shared<ChannelInfo>(); - auto& channelSocket = sockets[channel]; - if (not channelSocket) - channelSocket = std::make_shared<ChannelSocket>(parent_.weak(), name, channel); - onChannelReady_(deviceId, channelSocket); + onChannelReady_(deviceId, makeSocket(name, channel)); std::lock_guard<std::mutex> lk(channelCbsMutex); auto channelCbsIt = channelCbs.find(channel); if (channelCbsIt != channelCbs.end()) { @@ -367,18 +350,8 @@ MultiplexedSocket::Impl::onRequest(const std::string& name, uint16_t channel) auto accept = onRequest_(deviceId, channel, name); std::shared_ptr<ChannelSocket> channelSocket; if (accept) { - channelSocket = std::make_shared<ChannelSocket>(parent_.weak(), name, channel); - { - std::lock_guard<std::mutex> lkSockets(socketsMutex); - auto sockIt = sockets.find(channel); - if (sockIt != sockets.end()) { - JAMI_WARN("A channel is already present on that socket, accepting " - "the request will close the previous one"); - sockets.erase(sockIt); - } - channelDatas_.emplace(channel, std::make_shared<ChannelInfo>()); - sockets.emplace(channel, channelSocket); - } + std::lock_guard<std::mutex> lkSockets(socketsMutex); + channelSocket = makeSocket(name, channel); } // Answer to ChannelRequest if accepted @@ -432,14 +405,11 @@ MultiplexedSocket::Impl::handleControlPacket(std::vector<uint8_t>&& pkt) pimpl.onAccept(req.name, req.channel); } else if (req.state == ChannelRequestState::DECLINE) { std::lock_guard<std::mutex> lkSockets(pimpl.socketsMutex); - auto& channelDatas = pimpl.channelDatas_; - auto dataIt = channelDatas.find(req.channel); - if (dataIt != channelDatas.end() && dataIt->second) { - dataIt->second->isDestroying = true; - dataIt->second->cv.notify_all(); - channelDatas.erase(dataIt); + auto channel = pimpl.sockets.find(req.channel); + if (channel != pimpl.sockets.end()) { + channel->second->stop(); + pimpl.sockets.erase(channel); } - pimpl.sockets.erase(req.channel); } else if (pimpl.onRequest_) { pimpl.onRequest(req.name, req.channel); } @@ -455,37 +425,14 @@ MultiplexedSocket::Impl::handleChannelPacket(uint16_t channel, std::vector<uint8 { std::lock_guard<std::mutex> lkSockets(socketsMutex); auto sockIt = sockets.find(channel); - auto dataIt = channelDatas_.find(channel); - if (channel > 0 && sockIt->second && dataIt->second) { + if (channel > 0 && sockIt->second) { if (pkt.size() == 0) { sockIt->second->shutdown(); - dataIt->second->isDestroying = true; - dataIt->second->cv.notify_all(); - channelDatas_.erase(dataIt); sockets.erase(sockIt); } else { - GenericSocket<uint8_t>::RecvCb cb; - { - std::lock_guard<std::mutex> lk(channelCbsMtx_); - auto cbIt = channelCbs_.find(channel); - if (cbIt != channelCbs_.end()) { - cb = cbIt->second; - } - } - if (cb) { - cb(&pkt[0], pkt.size()); - return; - } - { - std::lock_guard<std::mutex> lkSockets(dataIt->second->mutex); - dataIt->second->buf.insert(dataIt->second->buf.end(), - std::make_move_iterator(pkt.begin()), - std::make_move_iterator(pkt.end())); - dataIt->second->cv.notify_all(); - } + sockIt->second->onRecv(std::move(pkt)); } } else if (pkt.size() != 0) { - std::string p = std::string(pkt.begin(), pkt.end()); JAMI_WARN("Non existing channel: %u", channel); } } @@ -562,16 +509,9 @@ MultiplexedSocket::addChannel(const std::string& name) std::lock_guard<std::mutex> lk(pimpl_->socketsMutex); for (int i = 1; i < UINT16_MAX; ++i) { auto c = (offset + i) % UINT16_MAX; - if (c == CONTROL_CHANNEL || c == PROTOCOL_CHANNEL) + if (c == CONTROL_CHANNEL || c == PROTOCOL_CHANNEL || pimpl_->sockets.find(c) != pimpl_->sockets.end()) continue; - auto& socket = pimpl_->sockets[c]; - if (!socket) { - auto& channel = pimpl_->channelDatas_[c]; - if (!channel) - channel = std::make_shared<ChannelInfo>(); - socket = std::make_shared<ChannelSocket>(weak(), name, c); - return socket; - } + return pimpl_->makeSocket(name, c); } return {}; } @@ -626,34 +566,6 @@ MultiplexedSocket::maxPayload() const return pimpl_->endpoint->maxPayload(); } -std::size_t -MultiplexedSocket::read(const uint16_t& channel, uint8_t* buf, std::size_t len, std::error_code& ec) -{ - if (pimpl_->isShutdown_) { - ec = std::make_error_code(std::errc::broken_pipe); - return -1; - } - std::lock_guard<std::mutex> lkSockets(pimpl_->socketsMutex); - auto dataIt = pimpl_->channelDatas_.find(channel); - if (dataIt == pimpl_->channelDatas_.end() || !dataIt->second) { - ec = std::make_error_code(std::errc::broken_pipe); - return -1; - } - std::size_t size; - { - std::lock_guard<std::mutex> lkSockets(dataIt->second->mutex); - auto& chanBuf = dataIt->second->buf; - size = std::min(len, chanBuf.size()); - - for (std::size_t i = 0; i < size; ++i) - buf[i] = chanBuf[i]; - - chanBuf.erase(chanBuf.begin(), chanBuf.begin() + size); - } - - return size; -} - std::size_t MultiplexedSocket::write(const uint16_t& channel, const uint8_t* buf, @@ -697,56 +609,6 @@ MultiplexedSocket::write(const uint16_t& channel, return res; } -int -MultiplexedSocket::waitForData(const uint16_t& channel, - std::chrono::milliseconds timeout, - std::error_code& ec) const -{ - if (pimpl_->isShutdown_) { - ec = std::make_error_code(std::errc::broken_pipe); - return -1; - } - std::unique_lock lkSockets {pimpl_->socketsMutex}; - auto dataIt = pimpl_->channelDatas_.find(channel); - if (dataIt == pimpl_->channelDatas_.end() or not dataIt->second) { - ec = std::make_error_code(std::errc::broken_pipe); - return -1; - } - auto channelData = dataIt->second; - lkSockets.unlock(); - std::unique_lock<std::mutex> lk {channelData->mutex}; - channelData->cv.wait_for(lk, timeout, [&] { - return channelData->isDestroying or !channelData->buf.empty() or pimpl_->isShutdown_; - }); - return channelData->buf.size(); -} - -void -MultiplexedSocket::setOnRecv(const uint16_t& channel, GenericSocket<uint8_t>::RecvCb&& cb) -{ - // Re run on ioPool, socketsMtx can be locked here (via onAccept), so retrigger - // to avoid double lock - dht::ThreadPool::io().run([w = weak(), channel, cb = std::move(cb)]() { - if (auto shared = w.lock()) { - std::lock_guard<std::mutex> lkSockets(shared->pimpl_->socketsMutex); - std::deque<uint8_t> recv; - { - std::lock_guard<std::mutex> lk(shared->pimpl_->channelCbsMtx_); - shared->pimpl_->channelCbs_[channel] = cb; - - auto dataIt = shared->pimpl_->channelDatas_.find(channel); - if (dataIt != shared->pimpl_->channelDatas_.end() && dataIt->second) { - std::lock_guard<std::mutex> lk(dataIt->second->mutex); - recv = std::move(dataIt->second->buf); - } - } - if (!recv.empty() && cb) { - cb(&recv[0], recv.size()); - } - } - }); -} - void MultiplexedSocket::shutdown() { @@ -851,6 +713,11 @@ public: std::string name {}; uint16_t channel {}; std::weak_ptr<MultiplexedSocket> endpoint {}; + + std::deque<uint8_t> buf {}; + std::mutex mutex {}; + std::condition_variable cv {}; + GenericSocket<uint8_t>::RecvCb cb {}; }; ChannelSocket::ChannelSocket(std::weak_ptr<MultiplexedSocket> endpoint, @@ -912,10 +779,28 @@ ChannelSocket::maxPayload() const void ChannelSocket::setOnRecv(RecvCb&& cb) { - if (auto ep = pimpl_->endpoint.lock()) - ep->setOnRecv(pimpl_->channel, std::move(cb)); + std::lock_guard<std::mutex> lkSockets(pimpl_->mutex); + pimpl_->cb = std::move(cb); + if (!pimpl_->buf.empty() && pimpl_->cb) { + pimpl_->cb(&pimpl_->buf[0], pimpl_->buf.size()); + pimpl_->buf = {}; + } } +void +ChannelSocket::onRecv(std::vector<uint8_t>&& pkt) { + std::lock_guard<std::mutex> lkSockets(pimpl_->mutex); + if (pimpl_->cb) { + pimpl_->cb(&pkt[0], pkt.size()); + return; + } + pimpl_->buf.insert(pimpl_->buf.end(), + std::make_move_iterator(pkt.begin()), + std::make_move_iterator(pkt.end())); + pimpl_->cv.notify_all(); +} + + std::shared_ptr<IceTransport> ChannelSocket::underlyingICE() const { @@ -942,6 +827,7 @@ ChannelSocket::stop() pimpl_->isShutdown_ = true; if (pimpl_->shutdownCb_) pimpl_->shutdownCb_(); + pimpl_->cv.notify_all(); } void @@ -958,16 +844,16 @@ ChannelSocket::shutdown() } std::size_t -ChannelSocket::read(ValueType* buf, std::size_t len, std::error_code& ec) +ChannelSocket::read(ValueType* outBuf, std::size_t len, std::error_code& ec) { - if (auto ep = pimpl_->endpoint.lock()) { - int res = ep->read(pimpl_->channel, buf, len, ec); - if (ec) - JAMI_ERR("Error when reading on channel: %s", ec.message().c_str()); - return res; - } - ec = std::make_error_code(std::errc::broken_pipe); - return -1; + std::lock_guard<std::mutex> lkSockets(pimpl_->mutex); + std::size_t size = std::min(len, pimpl_->buf.size()); + + for (std::size_t i = 0; i < size; ++i) + outBuf[i] = pimpl_->buf[i]; + + pimpl_->buf.erase(pimpl_->buf.begin(), pimpl_->buf.begin() + size); + return size; } std::size_t @@ -997,14 +883,11 @@ ChannelSocket::write(const ValueType* buf, std::size_t len, std::error_code& ec) int ChannelSocket::waitForData(std::chrono::milliseconds timeout, std::error_code& ec) const { - if (auto ep = pimpl_->endpoint.lock()) { - auto res = ep->waitForData(pimpl_->channel, timeout, ec); - if (ec) - JAMI_ERR("Error when waiting on channel: %s", ec.message().c_str()); - return res; - } - ec = std::make_error_code(std::errc::broken_pipe); - return -1; + std::unique_lock<std::mutex> lk {pimpl_->mutex}; + pimpl_->cv.wait_for(lk, timeout, [&] { + return !pimpl_->buf.empty() or pimpl_->isShutdown_; + }); + return pimpl_->buf.size(); } void diff --git a/src/jamidht/multiplexed_socket.h b/src/jamidht/multiplexed_socket.h index d75e420d072bc9dc04ea010530625854b8239a01..fd793e477392bec8dc09e4b9e08928d27d288cdf 100644 --- a/src/jamidht/multiplexed_socket.h +++ b/src/jamidht/multiplexed_socket.h @@ -103,15 +103,10 @@ public: */ void setOnChannelReady(uint16_t channel, onChannelReadyCb&& cb); - std::size_t read(const uint16_t& channel, uint8_t* buf, std::size_t len, std::error_code& ec); std::size_t write(const uint16_t& channel, const uint8_t* buf, std::size_t len, std::error_code& ec); - int waitForData(const uint16_t& channel, - std::chrono::milliseconds timeout, - std::error_code&) const; - void setOnRecv(const uint16_t& channel, GenericSocket<uint8_t>::RecvCb&& cb); /** * This will close all channels and send a TLS EOF on the main socket. @@ -221,6 +216,8 @@ public: */ void setOnRecv(RecvCb&&) override; + void onRecv(std::vector<uint8_t>&& pkt); + std::shared_ptr<IceTransport> underlyingICE() const; /**