-
When shutdown is called from a callback before it is called from the destructor, sockets could be destroyed. Change-Id: I0fd92646726969c0d9b947a51fecbed495439fb8
When shutdown is called from a callback before it is called from the destructor, sockets could be destroyed. Change-Id: I0fd92646726969c0d9b947a51fecbed495439fb8
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
multiplexed_socket.cpp 18.42 KiB
/*
* Copyright (C) 2019 Savoir-faire Linux Inc.
* Author: Sébastien Blin <sebastien.blin@savoirfairelinux.com>
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
#include "logger.h"
#include "manager.h"
#include "multiplexed_socket.h"
#include "peer_connection.h"
#include "ice_transport.h"
#include <deque>
#include <opendht/thread_pool.h>
namespace jami
{
/// Size (in bytes) of the char buffer used by IO operations.
static constexpr std::size_t IO_BUFFER_SIZE = 8192;
/**
 * Receive state kept for one channel: the bytes received and not yet read,
 * plus the mutex/condition variable pair used to wait for new bytes.
 */
struct ChannelInfo
{
    /// Bytes received on the channel and not consumed yet.
    std::deque<uint8_t> buf;
    /// Mutex paired with `cv`.
    std::mutex mutex;
    /// Notified when `buf` receives data or the channel is closed.
    std::condition_variable cv;
};
/**
 * Private implementation of MultiplexedSocket.
 *
 * Holds the TLS endpoint, the per-channel sockets and receive buffers, the
 * user callbacks, and the asynchronous task that runs eventLoop().
 */
class MultiplexedSocket::Impl
{
public:
    /**
     * Store the endpoint and start eventLoop() on an asynchronous task.
     * @param parent    Owning MultiplexedSocket
     * @param deviceId  Identifier of the peer device
     * @param endpoint  TLS endpoint used for all reads and writes
     */
    Impl(MultiplexedSocket& parent, const std::string& deviceId, std::unique_ptr<TlsSocketEndpoint> endpoint)
        : parent_(parent), deviceId(deviceId), endpoint(std::move(endpoint))
        , eventLoopFut_ { std::async(std::launch::async, [this] {
            try {
                eventLoop();
            } catch (const std::exception& e) {
                JAMI_ERR() << "[CNX] peer connection event loop failure: " << e.what();
            }
        }) } { }

    /**
     * Shut the socket down if that was not done yet; otherwise only stop and
     * drop the channel sockets that are still registered.
     */
    ~Impl() {
        if (!isShutdown_) {
            shutdown();
        } else {
            stopAllSockets();
        }
    }

    /**
     * Stop the event loop, notify the shutdown callback, shut the endpoint
     * down and stop every channel socket. Only the first call has an effect.
     */
    void shutdown() {
        // exchange() tests and sets the flag in one atomic step, so two
        // concurrent callers cannot both run the shutdown sequence.
        if (isShutdown_.exchange(true)) return;
        stop.store(true);
        if (onShutdown_) onShutdown_();
        if (endpoint) {
            endpoint->setOnStateChange({});
            endpoint->shutdown();
        }
        // Just trigger onShutdown() to make client know
        // No need to write the EOF for the channel, the write will fail because endpoint is already shutdown
        stopAllSockets();
    }

    /**
     * Read from the TLS endpoint and dispatch the decoded messages
     */
    void eventLoop();
    /**
     * Triggered when a new control packet is received
     */
    void handleControlPacket(const std::vector<uint8_t>&& pkt);
    /**
     * Triggered when a new packet on a channel is received
     */
    void handleChannelPacket(uint16_t channel, const std::vector<uint8_t>&& pkt);

    void setOnReady(OnConnectionReadyCb&& cb) { onChannelReady_ = std::move(cb); }
    void setOnRequest(OnConnectionRequestCb&& cb) { onRequest_ = std::move(cb); }

    msgpack::unpacker pac_;
    MultiplexedSocket& parent_;
    OnConnectionReadyCb onChannelReady_;
    OnConnectionRequestCb onRequest_;
    OnShutdownCb onShutdown_;
    std::string deviceId;
    // Main socket
    std::unique_ptr<TlsSocketEndpoint> endpoint;
    std::mutex socketsMutex;
    std::map<uint16_t, std::shared_ptr<ChannelSocket>> sockets;
    // Contains callback triggered when a channel is ready
    std::mutex channelCbsMutex;
    std::map<uint16_t, onChannelReadyCb> channelCbs;
    // Set to true to make the event loop exit
    std::atomic_bool stop {false};
    // Multiplexed available datas
    std::map<uint16_t, std::unique_ptr<ChannelInfo>> channelDatas_ {};
    std::mutex channelCbsMtx_ {};
    std::map<uint16_t, GenericSocket<uint8_t>::RecvCb> channelCbs_ {};
    std::atomic_bool isShutdown_ {false};
    // Main loop to parse incoming packets.
    // Keep this member LAST: members are constructed in declaration order and
    // destroyed in reverse, so the event loop only starts once every other
    // member exists and is waited for before any of them is destroyed.
    std::future<void> eventLoopFut_;

private:
    /// Stop every registered channel socket, then forget them all.
    void stopAllSockets() {
        std::lock_guard<std::mutex> lkSockets(socketsMutex);
        for (auto& socket : sockets) {
            if (socket.second) socket.second->stop();
        }
        sockets.clear();
    }
};
void
MultiplexedSocket::Impl::eventLoop()
{
endpoint->setOnStateChange([this](tls::TlsSessionState state) {
if (state == tls::TlsSessionState::SHUTDOWN && !isShutdown_) {
JAMI_INFO("Tls endpoint is down, shutdown multiplexed socket");
shutdown();
}
});
std::error_code ec;
while (!stop) {
if (!endpoint) {
shutdown();
return;
}
pac_.reserve_buffer(IO_BUFFER_SIZE);
int size = endpoint->read(reinterpret_cast<uint8_t*>(&pac_.buffer()[0]), IO_BUFFER_SIZE, ec);
if (size < 0) {
if (ec) JAMI_ERR("Read error detected: %s", ec.message().c_str());
break;
}
if (size == 0) {
// We can close the socket
shutdown();
break;
}
pac_.buffer_consumed(size);
msgpack::object_handle oh;
while (pac_.next(oh) && !stop) {
try {
auto msg = oh.get().as<ChanneledMessage>();
if (msg.channel == 0) {
handleControlPacket(std::move(msg.data));
} else {
handleChannelPacket(msg.channel, std::move(msg.data));
}
} catch (const msgpack::unpack_error &e) {
JAMI_WARN("Error when decoding msgpack message: %s", e.what());
}
}
}
}
/**
 * Handle a packet received on the control channel.
 *
 * The payload is decoded as a sequence of ChannelRequest objects on the IO
 * thread pool. For each request:
 *  - ACCEPT: make sure the channel socket and its buffer exist, then notify
 *    the ready callbacks;
 *  - DECLINE: drop the channel socket and its buffer;
 *  - anything else: ask onRequest_ (when set) whether to accept, register the
 *    channel when accepted, send the answer on the control channel and notify
 *    the ready callbacks when accepted.
 *
 * Any exception raised while handling a request, or a failed write of the
 * answer, sets `stop`.
 *
 * @param pkt  Raw msgpack payload of the control message (copied into the
 *             task because the parameter is const)
 */
void
MultiplexedSocket::Impl::handleControlPacket(const std::vector<uint8_t>&& pkt)
{
    // Run this on dedicated thread because some callbacks can take time
    dht::ThreadPool::io().run([this, pkt = std::move(pkt)]() {
        // Nothing to decode. Returning here also avoids passing memcpy() the
        // possibly null data() pointer of an empty vector.
        if (pkt.empty())
            return;

        msgpack::unpacker pac;
        pac.reserve_buffer(pkt.size());
        memcpy(pac.buffer(), pkt.data(), pkt.size());
        pac.buffer_consumed(pkt.size());
        msgpack::object_handle oh;
        while (pac.next(oh)) {
            try {
                auto req = oh.get().as<ChannelRequest>();
                if (req.state == ChannelRequestState::ACCEPT) {
                    std::lock_guard<std::mutex> lkSockets(socketsMutex);
                    auto& channel = channelDatas_[req.channel];
                    if (not channel)
                        channel = std::make_unique<ChannelInfo>();
                    auto& channelSocket = sockets[req.channel];
                    if (not channelSocket)
                        channelSocket = std::make_shared<ChannelSocket>(parent_.weak(), req.name, req.channel);
                    // The callbacks are optional: calling an empty one would
                    // throw and stop the whole connection (see catch below).
                    if (onChannelReady_)
                        onChannelReady_(deviceId, channelSocket);
                    std::lock_guard<std::mutex> lk(channelCbsMutex);
                    auto channelCbsIt = channelCbs.find(req.channel);
                    if (channelCbsIt != channelCbs.end() && channelCbsIt->second) {
                        (channelCbsIt->second)();
                    }
                } else if (req.state == ChannelRequestState::DECLINE) {
                    std::lock_guard<std::mutex> lkSockets(socketsMutex);
                    channelDatas_.erase(req.channel);
                    sockets.erase(req.channel);
                } else if (onRequest_) {
                    auto accept = onRequest_(deviceId, req.channel, req.name);
                    std::shared_ptr<ChannelSocket> channelSocket;
                    if (accept) {
                        channelSocket = std::make_shared<ChannelSocket>(parent_.weak(), req.name, req.channel);
                        {
                            std::lock_guard<std::mutex> lkSockets(socketsMutex);
                            auto sockIt = sockets.find(req.channel);
                            if (sockIt != sockets.end()) {
                                JAMI_WARN("A channel is already present on that socket, accepting the request will close the previous one");
                                sockets.erase(sockIt);
                            }
                            channelDatas_.emplace(req.channel, std::make_unique<ChannelInfo>());
                            sockets.emplace(req.channel, channelSocket);
                        }
                    }

                    // Answer to ChannelRequest if accepted
                    ChannelRequest val;
                    val.channel = req.channel;
                    val.name = req.name;
                    val.state = accept ? ChannelRequestState::ACCEPT : ChannelRequestState::DECLINE;
                    std::stringstream ss;
                    msgpack::pack(ss, val);
                    std::error_code ec;
                    auto toSend = ss.str();
                    int wr = parent_.write(CONTROL_CHANNEL, reinterpret_cast<const uint8_t*>(&toSend[0]), toSend.size(), ec);
                    if (wr < 0) {
                        if (ec)
                            JAMI_ERR("The write operation failed with error: %s", ec.message().c_str());
                        stop.store(true);
                        return;
                    }

                    if (accept) {
                        if (onChannelReady_)
                            onChannelReady_(deviceId, channelSocket);
                        std::lock_guard<std::mutex> lk(channelCbsMutex);
                        auto channelCbsIt = channelCbs.find(req.channel);
                        if (channelCbsIt != channelCbs.end() && channelCbsIt->second) {
                            (channelCbsIt->second)();
                        }
                    }
                }
            } catch (const std::exception& e) {
                JAMI_ERR("Error on the control channel: %s", e.what());
                stop.store(true);
            }
        }
    });
}
/**
 * Handle a packet received on a data channel.
 *
 * An empty payload means the peer closed the channel: the channel socket is
 * shut down, waiters are woken up and the channel is forgotten. Otherwise the
 * payload is given to the receive callback registered for the channel, or
 * appended to the channel buffer when there is no callback.
 *
 * @param channel  Channel the packet was received on
 * @param pkt      Payload of the packet
 */
void
MultiplexedSocket::Impl::handleChannelPacket(uint16_t channel, const std::vector<uint8_t>&& pkt)
{
    auto sockIt = sockets.find(channel);
    auto dataIt = channelDatas_.find(channel);
    // find() returns end() for an unknown channel and end() must never be
    // dereferenced, so compare against it before looking at ->second.
    const bool known = channel > 0
        && sockIt != sockets.end() && sockIt->second
        && dataIt != channelDatas_.end() && dataIt->second;
    if (!known) {
        JAMI_WARN("Non existing channel: %u", channel);
        return;
    }

    if (pkt.empty()) {
        // Keep the socket alive on our own reference: shutdown() may end up
        // in Impl::shutdown(), which clears `sockets` and would otherwise
        // invalidate the iterator (and possibly destroy the socket) under us.
        auto channelSocket = sockIt->second;
        channelSocket->shutdown();
        dataIt->second->cv.notify_all();
        channelDatas_.erase(dataIt);
        std::lock_guard<std::mutex> lkSockets(socketsMutex);
        // Erase by key: a no-op if the entry is already gone.
        sockets.erase(channel);
        return;
    }

    std::unique_lock<std::mutex> lk(channelCbsMtx_);
    auto cbIt = channelCbs_.find(channel);
    if (cbIt != channelCbs_.end()) {
        // Copy the callback while the lock is held so that a concurrent
        // setOnRecv() cannot modify it while it is being invoked.
        auto recvCb = cbIt->second;
        lk.unlock();
        recvCb(pkt.data(), pkt.size());
        return;
    }
    lk.unlock();

    {
        // `buf` is the state waitForData() waits on under this mutex.
        std::lock_guard<std::mutex> lkData(dataIt->second->mutex);
        dataIt->second->buf.insert(dataIt->second->buf.end(), pkt.begin(), pkt.end());
    }
    dataIt->second->cv.notify_all();
}
MultiplexedSocket::MultiplexedSocket(const std::string& deviceId, std::unique_ptr<TlsSocketEndpoint> endpoint)
: pimpl_(std::make_unique<Impl>(*this, deviceId, std::move(endpoint)))
{
}
/// Nothing to do explicitly: destroying pimpl_ runs Impl's destructor.
MultiplexedSocket::~MultiplexedSocket() = default;
/**
 * Create a channel socket on the first channel id (starting at 1) that has no
 * socket yet, creating the channel buffer as well when it is missing.
 * @param name  Name given to the new channel
 * @return the new channel socket, or an empty pointer when every id is taken
 */
std::shared_ptr<ChannelSocket>
MultiplexedSocket::addChannel(const std::string& name)
{
    std::lock_guard<std::mutex> lk(pimpl_->socketsMutex);
    for (int id = 1; id < UINT16_MAX; ++id) {
        auto& slot = pimpl_->sockets[id];
        if (slot)
            continue; // this id already has a socket

        auto& info = pimpl_->channelDatas_[id];
        if (!info)
            info = std::make_unique<ChannelInfo>();
        slot = std::make_shared<ChannelSocket>(weak(), name, id);
        return slot;
    }
    return {};
}
/**
 * @return a copy of the device identifier given at construction
 */
std::string
MultiplexedSocket::deviceId() const
{
return pimpl_->deviceId;
}
/**
 * Set the callback invoked (with the device id and the channel socket) when a
 * channel becomes ready. Replaces any previous callback.
 * @param cb  Callback to store
 */
void
MultiplexedSocket::setOnReady(OnConnectionReadyCb&& cb)
{
pimpl_->onChannelReady_ = std::move(cb);
}
/**
 * Set the callback consulted when the peer asks to open a channel; its result
 * decides whether the request is accepted. Replaces any previous callback.
 * @param cb  Callback to store
 */
void
MultiplexedSocket::setOnRequest(OnConnectionRequestCb&& cb)
{
pimpl_->onRequest_ = std::move(cb);
}
/**
 * Register the callback invoked when the given channel becomes ready.
 * Replaces any callback previously registered for that channel.
 * @param channel  Channel the callback is attached to
 * @param cb       Callback to store
 */
void
MultiplexedSocket::setOnChannelReady(uint16_t channel, onChannelReadyCb&& cb)
{
// NOTE(review): channelCbs is read under channelCbsMutex by the control packet
// handler, but this write does not take that mutex — confirm callers never
// race with incoming control packets.
pimpl_->channelCbs[channel] = std::move(cb);
}
/**
 * @return always true
 */
bool
MultiplexedSocket::isReliable() const
{
return true;
}
/**
 * @return the endpoint's isInitiator() value, or false (with a warning) when
 *         there is no endpoint
 */
bool
MultiplexedSocket::isInitiator() const
{
    const auto& ep = pimpl_->endpoint;
    if (ep)
        return ep->isInitiator();

    JAMI_WARN("No endpoint found for socket");
    return false;
}
/**
 * @return the endpoint's maxPayload() value, or 0 (with a warning) when there
 *         is no endpoint
 */
int
MultiplexedSocket::maxPayload() const
{
    const auto& ep = pimpl_->endpoint;
    if (ep)
        return ep->maxPayload();

    JAMI_WARN("No endpoint found for socket");
    return 0;
}
/**
 * Move up to `len` buffered bytes of a channel into `buf`.
 *
 * Does not wait: only the bytes already buffered for the channel are
 * returned, and they are removed from the buffer.
 *
 * @param channel  Channel to read from
 * @param buf      Destination buffer, at least `len` bytes long
 * @param len      Maximum number of bytes to copy
 * @param ec       Set to broken_pipe when the socket is shut down or the
 *                 channel is unknown
 * @return the number of bytes copied (possibly 0); on error, -1 converted to
 *         std::size_t
 */
std::size_t
MultiplexedSocket::read(const uint16_t& channel, uint8_t* buf, std::size_t len, std::error_code& ec)
{
    if (pimpl_->isShutdown_) {
        ec = std::make_error_code(std::errc::broken_pipe);
        return static_cast<std::size_t>(-1);
    }
    auto dataIt = pimpl_->channelDatas_.find(channel);
    if (dataIt == pimpl_->channelDatas_.end()
        || !dataIt->second) {
        ec = std::make_error_code(std::errc::broken_pipe);
        return static_cast<std::size_t>(-1);
    }
    auto& channelData = *dataIt->second;
    // The buffer is shared with waitForData(), which inspects it under this
    // mutex; take it for the whole copy-and-erase sequence.
    std::lock_guard<std::mutex> lk(channelData.mutex);
    auto& chanBuf = channelData.buf;
    auto size = std::min(len, chanBuf.size());

    for (std::size_t i = 0; i < size; ++i) {
        buf[i] = chanBuf[i];
    }

    chanBuf.erase(chanBuf.begin(), chanBuf.begin() + size);

    return size;
}
std::size_t
MultiplexedSocket::write(const uint16_t& channel, const uint8_t* buf, std::size_t len, std::error_code& ec)
{
if (pimpl_->isShutdown_) {
ec = std::make_error_code(std::errc::broken_pipe);
return -1;
}
if (len > UINT16_MAX) {
ec = std::make_error_code(std::errc::message_size);
return -1;
}
if (!pimpl_->endpoint) {
JAMI_WARN("No endpoint found for socket");
ec = std::make_error_code(std::errc::broken_pipe);
return -1;
}
msgpack::sbuffer buffer;
msgpack::packer<msgpack::sbuffer> pk(&buffer);
pk.pack_array(2);
pk.pack(channel);
pk.pack_bin(len);
pk.pack_bin_body((const char*)buf, len);
int res = pimpl_->endpoint->write((const unsigned char*)buffer.data(), buffer.size(), ec);
if (res < 0) {
if (ec)
JAMI_ERR("Error when writing on socket: %s", ec.message().c_str());
shutdown();
}
return res;
}
/**
 * Wait until a channel has buffered data or the timeout expires.
 *
 * @param channel  Channel to wait on
 * @param timeout  Maximum time to wait
 * @param ec       Set to broken_pipe when the socket is shut down or the
 *                 channel is unknown
 * @return the number of buffered bytes once the wait ends (0 on timeout),
 *         or -1 on error
 */
int
MultiplexedSocket::waitForData(const uint16_t& channel, std::chrono::milliseconds timeout, std::error_code& ec) const
{
    if (pimpl_->isShutdown_) {
        ec = std::make_error_code(std::errc::broken_pipe);
        return -1;
    }
    auto dataIt = pimpl_->channelDatas_.find(channel);
    // A missing entry and a null entry are the same failure for the caller;
    // report both through `ec` so that every -1 carries an error code.
    if (dataIt == pimpl_->channelDatas_.end() || !dataIt->second) {
        ec = std::make_error_code(std::errc::broken_pipe);
        return -1;
    }
    auto& channelData = dataIt->second;
    std::unique_lock<std::mutex> lk {channelData->mutex};
    channelData->cv.wait_for(lk, timeout, [&]{
        return !channelData->buf.empty();
    });
    return channelData->buf.size();
}
/**
 * Register the callback that receives the data arriving on a channel.
 * Replaces any callback previously registered for that channel.
 * @param channel  Channel the callback is attached to
 * @param cb       Callback to store
 */
void
MultiplexedSocket::setOnRecv(const uint16_t& channel, GenericSocket<uint8_t>::RecvCb&& cb)
{
    std::lock_guard<std::mutex> lk(pimpl_->channelCbsMtx_);
    // `cb` is a named rvalue reference, hence an lvalue: without std::move it
    // would be copied instead of moved.
    pimpl_->channelCbs_[channel] = std::move(cb);
}
/**
 * Shut the multiplexed socket down; forwards to the implementation's
 * shutdown(), which returns immediately when already shut down.
 */
void
MultiplexedSocket::shutdown()
{
pimpl_->shutdown();
}
/**
 * Set the callback invoked when the socket shuts down. Replaces any previous
 * callback. When the socket is already shut down, the new callback is invoked
 * immediately.
 * @param cb  Callback to store (may be empty)
 */
void
MultiplexedSocket::onShutdown(OnShutdownCb&& cb)
{
    pimpl_->onShutdown_ = std::move(cb);
    // Same guard as in Impl::shutdown(): never invoke an empty callback.
    if (pimpl_->isShutdown_ && pimpl_->onShutdown_) {
        pimpl_->onShutdown_();
    }
}
std::shared_ptr<IceTransport>
MultiplexedSocket::underlyingICE() const
{
return pimpl_->endpoint->underlyingICE();
}
////////////////////////////////////////////////////////////////
/**
 * Private state of a ChannelSocket: its name, its channel id, a weak
 * reference to the multiplexed socket it belongs to, and its shutdown state.
 */
class ChannelSocket::Impl
{
public:
    Impl(std::weak_ptr<MultiplexedSocket> endpoint, const std::string& name, const uint16_t& channel)
        : name {name}
        , channel {channel}
        , endpoint {std::move(endpoint)}
    {}

    ~Impl() = default;

    /// Callback invoked when the channel is stopped.
    OnShutdownCb shutdownCb_;
    /// Set once the channel has been stopped.
    std::atomic_bool isShutdown_ {false};

    /// Name of the channel.
    std::string name;
    /// Channel id on the multiplexed socket.
    uint16_t channel;
    /// Multiplexed socket carrying this channel (not owned).
    std::weak_ptr<MultiplexedSocket> endpoint;
};
/**
 * Build a channel socket bound to one channel of a multiplexed socket.
 * @param endpoint  Multiplexed socket carrying the channel (not owned)
 * @param name      Name of the channel
 * @param channel   Channel id on the multiplexed socket
 */
ChannelSocket::ChannelSocket(std::weak_ptr<MultiplexedSocket> endpoint, const std::string& name, const uint16_t& channel)
    // `endpoint` is taken by value, so move it on instead of copying it again.
    : pimpl_ { std::make_unique<Impl>(std::move(endpoint), name, channel) }
{ }
/// Nothing to do explicitly: destroying pimpl_ releases the channel state.
ChannelSocket::~ChannelSocket() = default;
/**
 * @return the device id of the multiplexed socket, or an empty string when
 *         that socket no longer exists
 */
std::string
ChannelSocket::deviceId() const
{
    const auto ep = pimpl_->endpoint.lock();
    if (!ep)
        return {};
    return ep->deviceId();
}
/**
 * @return a copy of the channel name given at construction
 */
std::string
ChannelSocket::name() const
{
return pimpl_->name;
}
/**
 * @return the channel id given at construction
 */
uint16_t
ChannelSocket::channel() const
{
return pimpl_->channel;
}
/**
 * @return the multiplexed socket's isReliable() value, or false when that
 *         socket no longer exists
 */
bool
ChannelSocket::isReliable() const
{
    const auto ep = pimpl_->endpoint.lock();
    return ep ? ep->isReliable() : false;
}
/**
 * @return the multiplexed socket's isInitiator() value, or false when that
 *         socket no longer exists
 */
bool
ChannelSocket::isInitiator() const
{
    const auto ep = pimpl_->endpoint.lock();
    return ep ? ep->isInitiator() : false;
}
/**
 * @return the multiplexed socket's maxPayload() value, or -1 when that socket
 *         no longer exists
 */
int
ChannelSocket::maxPayload() const
{
    const auto ep = pimpl_->endpoint.lock();
    return ep ? ep->maxPayload() : -1;
}
/**
 * Register the receive callback for this channel on the multiplexed socket.
 * Does nothing when that socket no longer exists.
 * @param cb  Callback to register
 */
void
ChannelSocket::setOnRecv(RecvCb&& cb)
{
if (auto ep = pimpl_->endpoint.lock())
ep->setOnRecv(pimpl_->channel, std::move(cb));
}
/**
 * @return the multiplexed socket's underlyingICE() value, or an empty pointer
 *         when that socket no longer exists
 */
std::shared_ptr<IceTransport>
ChannelSocket::underlyingICE() const
{
    const auto mtx = pimpl_->endpoint.lock();
    if (!mtx)
        return {};
    return mtx->underlyingICE();
}
void
ChannelSocket::stop()
{
pimpl_->isShutdown_ = true;
if (pimpl_->shutdownCb_) pimpl_->shutdownCb_();
}
/**
 * Stop the channel locally, then write an empty payload on the channel when
 * the multiplexed socket still exists.
 */
void
ChannelSocket::shutdown()
{
    stop();

    const auto ep = pimpl_->endpoint.lock();
    if (!ep)
        return;

    // The result of this write is deliberately ignored.
    std::error_code ec;
    ep->write(pimpl_->channel, nullptr, 0, ec);
}
/**
 * Read buffered data of this channel through the multiplexed socket.
 * The channel is shut down when the read reports an error.
 * @param buf  Destination buffer
 * @param len  Maximum number of bytes to read
 * @param ec   Set by the multiplexed socket, or to broken_pipe when that
 *             socket no longer exists
 * @return the read result; on error, -1 converted to std::size_t
 */
std::size_t
ChannelSocket::read(ValueType* buf, std::size_t len, std::error_code& ec)
{
    const auto ep = pimpl_->endpoint.lock();
    if (!ep) {
        ec = std::make_error_code(std::errc::broken_pipe);
        return -1;
    }

    const int nbRead = ep->read(pimpl_->channel, buf, len, ec);
    if (nbRead < 0) {
        if (ec)
            JAMI_ERR("Error when reading on channel: %s", ec.message().c_str());
        shutdown();
    }
    return nbRead;
}
/**
 * Write data on this channel through the multiplexed socket.
 * The channel is shut down when the write reports an error.
 * @param buf  Bytes to send
 * @param len  Number of bytes to send
 * @param ec   Set by the multiplexed socket, or to broken_pipe when that
 *             socket no longer exists
 * @return the write result; on error, -1 converted to std::size_t
 */
std::size_t
ChannelSocket::write(const ValueType* buf, std::size_t len, std::error_code& ec)
{
    const auto ep = pimpl_->endpoint.lock();
    if (!ep) {
        ec = std::make_error_code(std::errc::broken_pipe);
        return -1;
    }

    const int nbWritten = ep->write(pimpl_->channel, buf, len, ec);
    if (nbWritten < 0) {
        if (ec)
            JAMI_ERR("Error when writing on channel: %s", ec.message().c_str());
        shutdown();
    }
    return nbWritten;
}
/**
 * Wait for data on this channel through the multiplexed socket.
 * @param timeout  Maximum time to wait
 * @param ec       Set by the multiplexed socket, or to broken_pipe when that
 *                 socket no longer exists
 * @return the multiplexed socket's waitForData() result, or -1 when that
 *         socket no longer exists
 */
int
ChannelSocket::waitForData(std::chrono::milliseconds timeout, std::error_code& ec) const
{
    const auto ep = pimpl_->endpoint.lock();
    if (!ep) {
        ec = std::make_error_code(std::errc::broken_pipe);
        return -1;
    }

    const auto available = ep->waitForData(pimpl_->channel, timeout, ec);
    if (available < 0 && ec)
        JAMI_ERR("Error when waiting on channel: %s", ec.message().c_str());
    return available;
}
/**
 * Set the callback invoked when the channel is stopped. Replaces any previous
 * callback. When the channel is already stopped, the new callback is invoked
 * immediately.
 * @param cb  Callback to store (may be empty)
 */
void
ChannelSocket::onShutdown(OnShutdownCb&& cb)
{
    pimpl_->shutdownCb_ = std::move(cb);
    // Same guard as in stop(): never invoke an empty callback.
    if (pimpl_->isShutdown_ && pimpl_->shutdownCb_) {
        pimpl_->shutdownCb_();
    }
}
}