diff --git a/src/jamidht/multiplexed_socket.cpp b/src/jamidht/multiplexed_socket.cpp index 929243e0924f49db240b103359acb7f1d9063847..4e7062a3e0a230c0c4bccc437318867fafacefe7 100644 --- a/src/jamidht/multiplexed_socket.cpp +++ b/src/jamidht/multiplexed_socket.cpp @@ -104,11 +104,13 @@ public: /** * Triggered when a new control packet is received */ - void handleControlPacket(const std::vector<uint8_t>&& pkt); + void handleControlPacket(std::vector<uint8_t>&& pkt); /** * Triggered when a new packet on a channel is received */ - void handleChannelPacket(uint16_t channel, const std::vector<uint8_t>&& pkt); + void handleChannelPacket(uint16_t channel, std::vector<uint8_t>&& pkt); + void onRequest(const std::string& name, uint16_t channel); + void onAccept(const std::string& name, uint16_t channel); void setOnReady(OnConnectionReadyCb&& cb) { onChannelReady_ = std::move(cb); } void setOnRequest(OnConnectionRequestCb&& cb) { onRequest_ = std::move(cb); } @@ -179,7 +181,6 @@ MultiplexedSocket::Impl::eventLoop() while (pac_.next(oh) && !stop) { try { auto msg = oh.get().as<ChanneledMessage>(); - if (msg.channel == 0) handleControlPacket(std::move(msg.data)); else @@ -192,96 +193,106 @@ MultiplexedSocket::Impl::eventLoop() } void -MultiplexedSocket::Impl::handleControlPacket(const std::vector<uint8_t>&& pkt) +MultiplexedSocket::Impl::onAccept(const std::string& name, uint16_t channel) +{ + std::lock_guard<std::mutex> lkSockets(socketsMutex); + auto& channelData = channelDatas_[channel]; + if (not channelData) + channelData = std::make_unique<ChannelInfo>(); + auto& channelSocket = sockets[channel]; + if (not channelSocket) + channelSocket = std::make_shared<ChannelSocket>(parent_.weak(), name, channel); + onChannelReady_(deviceId, channelSocket); + std::lock_guard<std::mutex> lk(channelCbsMutex); + auto channelCbsIt = channelCbs.find(channel); + if (channelCbsIt != channelCbs.end()) { + (channelCbsIt->second)(); + } +} + +void +MultiplexedSocket::Impl::onRequest(const std::string& name, uint16_t channel) +{ + auto accept = onRequest_(deviceId, channel, name); + std::shared_ptr<ChannelSocket> channelSocket; + if (accept) { + channelSocket = std::make_shared<ChannelSocket>(parent_.weak(), name, channel); + { + std::lock_guard<std::mutex> lkSockets(socketsMutex); + auto sockIt = sockets.find(channel); + if (sockIt != sockets.end()) { + JAMI_WARN("A channel is already present on that socket, accepting " + "the request will close the previous one"); + sockets.erase(sockIt); + } + channelDatas_.emplace(channel, std::make_unique<ChannelInfo>()); + sockets.emplace(channel, channelSocket); + } + } + + // Answer to ChannelRequest if accepted + ChannelRequest val; + val.channel = channel; + val.name = name; + val.state = accept ? ChannelRequestState::ACCEPT : ChannelRequestState::DECLINE; + std::stringstream ss; + msgpack::pack(ss, val); + std::error_code ec; + auto toSend = ss.str(); + int wr = parent_.write(CONTROL_CHANNEL, + reinterpret_cast<const uint8_t*>(&toSend[0]), + toSend.size(), + ec); + if (wr < 0) { + if (ec) + JAMI_ERR("The write operation failed with error: %s", ec.message().c_str()); + stop.store(true); + return; + } + + if (accept) { + onChannelReady_(deviceId, channelSocket); + std::lock_guard<std::mutex> lk(channelCbsMutex); + auto channelCbsIt = channelCbs.find(channel); + if (channelCbsIt != channelCbs.end()) { + channelCbsIt->second(); + } + } +} + +void +MultiplexedSocket::Impl::handleControlPacket(std::vector<uint8_t>&& pkt) { // Run this on dedicated thread because some callbacks can take time - dht::ThreadPool::io().run([this, pkt = std::move(pkt)]() { + dht::ThreadPool::io().run([w = parent_.weak(), pkt = std::move(pkt)]() { try { size_t off = 0; while (off != pkt.size()) { msgpack::unpacked result; msgpack::unpack(result, (const char*) pkt.data(), pkt.size(), off); auto req = result.get().as<ChannelRequest>(); - if (req.state == ChannelRequestState::ACCEPT) { - std::lock_guard<std::mutex> lkSockets(socketsMutex); - auto& channel = channelDatas_[req.channel]; - if (not channel) - channel = std::make_unique<ChannelInfo>(); - auto& channelSocket = sockets[req.channel]; - if (not channelSocket) - channelSocket = std::make_shared<ChannelSocket>(parent_.weak(), - req.name, - req.channel); - onChannelReady_(deviceId, channelSocket); - std::lock_guard<std::mutex> lk(channelCbsMutex); - auto channelCbsIt = channelCbs.find(req.channel); - if (channelCbsIt != channelCbs.end()) { - (channelCbsIt->second)(); - } - } else if (req.state == ChannelRequestState::DECLINE) { - std::lock_guard<std::mutex> lkSockets(socketsMutex); - channelDatas_.erase(req.channel); - sockets.erase(req.channel); - } else if (onRequest_) { - auto accept = onRequest_(deviceId, req.channel, req.name); - std::shared_ptr<ChannelSocket> channelSocket; - if (accept) { - channelSocket = std::make_shared<ChannelSocket>(parent_.weak(), - req.name, - req.channel); - { - std::lock_guard<std::mutex> lkSockets(socketsMutex); - auto sockIt = sockets.find(req.channel); - if (sockIt != sockets.end()) { - JAMI_WARN("A channel is already present on that socket, accepting " - "the request will close the previous one"); - sockets.erase(sockIt); - } - channelDatas_.emplace(req.channel, std::make_unique<ChannelInfo>()); - sockets.emplace(req.channel, channelSocket); - } - } - - // Answer to ChannelRequest if accepted - ChannelRequest val; - val.channel = req.channel; - val.name = req.name; - val.state = accept ? ChannelRequestState::ACCEPT : ChannelRequestState::DECLINE; - std::stringstream ss; - msgpack::pack(ss, val); - std::error_code ec; - auto toSend = ss.str(); - int wr = parent_.write(CONTROL_CHANNEL, - reinterpret_cast<const uint8_t*>(&toSend[0]), - toSend.size(), - ec); - if (wr < 0) { - if (ec) - JAMI_ERR("The write operation failed with error: %s", - ec.message().c_str()); - stop.store(true); - return; - } - - if (accept) { - onChannelReady_(deviceId, channelSocket); - std::lock_guard<std::mutex> lk(channelCbsMutex); - auto channelCbsIt = channelCbs.find(req.channel); - if (channelCbsIt != channelCbs.end()) { - (channelCbsIt->second)(); - } + if (auto shared = w.lock()) { + if (req.state == ChannelRequestState::ACCEPT) { + shared->pimpl_->onAccept(req.name, req.channel); + } else if (req.state == ChannelRequestState::DECLINE) { + std::lock_guard<std::mutex> lkSockets(shared->pimpl_->socketsMutex); + shared->pimpl_->channelDatas_.erase(req.channel); + shared->pimpl_->sockets.erase(req.channel); + } else if (shared->pimpl_->onRequest_) { + shared->pimpl_->onRequest(req.name, req.channel); } } } } catch (const std::exception& e) { JAMI_ERR("Error on the control channel: %s", e.what()); - stop.store(true); + if (auto shared = w.lock()) + shared->pimpl_->stop.store(true); } }); } void -MultiplexedSocket::Impl::handleChannelPacket(uint16_t channel, const std::vector<uint8_t>&& pkt) +MultiplexedSocket::Impl::handleChannelPacket(uint16_t channel, std::vector<uint8_t>&& pkt) { auto sockIt = sockets.find(channel); auto dataIt = channelDatas_.find(channel);