Commit 4b2386c8 authored by Adrien Béraud, committed by Sébastien Blin

MultiplexedSocket: cleanup, improve concurrency

Refactor structures to improve memory consistency
and reduce contention on the main mutex.

Change-Id: I3dcf1b94d96d51bddff4446b7011605821c4f1b6
parent f6a41b7e
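At its core, the change moves the per-channel receive state (buffer, mutex, condition variable and receive callback, previously kept in MultiplexedSocket::Impl as channelDatas_ and channelCbs_) into ChannelSocket itself. Demultiplexing then only needs to find the channel under socketsMutex and hand the packet over, while readers and waiters block on the channel's own mutex. A minimal sketch of the resulting shape, using hypothetical stand-in types rather than the actual Jami classes:

#include <algorithm>
#include <condition_variable>
#include <cstddef>
#include <cstdint>
#include <deque>
#include <functional>
#include <map>
#include <memory>
#include <mutex>
#include <vector>

// Hypothetical stand-ins for illustration only; the real classes are in the diff below.
struct ChannelSocketSketch {
    // Per-channel receive state now lives with the channel socket itself.
    std::mutex mutex;
    std::condition_variable cv;
    std::deque<uint8_t> buf;
    std::function<void(const uint8_t*, std::size_t)> cb;

    void onRecv(std::vector<uint8_t>&& pkt) {
        std::lock_guard<std::mutex> lk(mutex);
        if (cb) {                                        // a registered callback consumes data immediately
            cb(pkt.data(), pkt.size());
            return;
        }
        buf.insert(buf.end(), pkt.begin(), pkt.end());   // otherwise buffer for read()/waitForData()
        cv.notify_all();
    }

    std::size_t read(uint8_t* out, std::size_t len) {
        std::lock_guard<std::mutex> lk(mutex);           // only the channel's own mutex is needed
        std::size_t n = std::min(len, buf.size());
        std::copy_n(buf.begin(), n, out);
        buf.erase(buf.begin(), buf.begin() + n);
        return n;
    }
};

struct MultiplexerSketch {
    // The demultiplexer keeps only the channel map; no per-channel buffers or callbacks anymore.
    std::mutex socketsMutex;
    std::map<uint16_t, std::shared_ptr<ChannelSocketSketch>> sockets;

    void handleChannelPacket(uint16_t channel, std::vector<uint8_t>&& pkt) {
        std::lock_guard<std::mutex> lk(socketsMutex);    // held for lookup and dispatch
        auto it = sockets.find(channel);
        if (it != sockets.end() && it->second)
            it->second->onRecv(std::move(pkt));          // buffering happens under the channel's mutex
    }
};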
@@ -53,14 +53,6 @@ namespace jami {
using clock = std::chrono::steady_clock;
using time_point = clock::time_point;
struct ChannelInfo
{
std::atomic_bool isDestroying {false};
std::deque<uint8_t> buf {};
std::mutex mutex {};
std::condition_variable cv {};
};
class MultiplexedSocket::Impl
{
public:
@@ -99,12 +91,6 @@ public:
{
std::lock_guard<std::mutex> lkSockets(socketsMutex);
socks = std::move(sockets);
for (auto& [key, channelData] : channelDatas_) {
if (channelData) {
channelData->isDestroying = true;
channelData->cv.notify_all();
}
}
}
for (auto& socket : socks) {
// Just trigger onShutdown() to make client know
@@ -132,6 +118,13 @@ public:
clearSockets();
}
std::shared_ptr<ChannelSocket> makeSocket(const std::string& name, uint16_t channel) {
auto& channelSocket = sockets[channel];
if (not channelSocket)
channelSocket = std::make_shared<ChannelSocket>(parent_.weak(), name, channel);
return channelSocket;
}
/**
* Handle packets on the TLS endpoint and parse RTP
*/
@@ -182,10 +175,6 @@ public:
std::atomic_bool stop {false};
std::thread eventLoopThread_ {};
// Multiplexed available datas
std::map<uint16_t, std::shared_ptr<ChannelInfo>> channelDatas_ {};
std::mutex channelCbsMtx_ {};
std::map<uint16_t, GenericSocket<uint8_t>::RecvCb> channelCbs_ {};
std::atomic_bool isShutdown_ {false};
std::mutex writeMtx {};
@@ -258,13 +247,7 @@ void
MultiplexedSocket::Impl::onAccept(const std::string& name, uint16_t channel)
{
std::lock_guard<std::mutex> lkSockets(socketsMutex);
auto& channelData = channelDatas_[channel];
if (not channelData)
channelData = std::make_shared<ChannelInfo>();
auto& channelSocket = sockets[channel];
if (not channelSocket)
channelSocket = std::make_shared<ChannelSocket>(parent_.weak(), name, channel);
onChannelReady_(deviceId, channelSocket);
onChannelReady_(deviceId, makeSocket(name, channel));
std::lock_guard<std::mutex> lk(channelCbsMutex);
auto channelCbsIt = channelCbs.find(channel);
if (channelCbsIt != channelCbs.end()) {
@@ -367,18 +350,8 @@ MultiplexedSocket::Impl::onRequest(const std::string& name, uint16_t channel)
auto accept = onRequest_(deviceId, channel, name);
std::shared_ptr<ChannelSocket> channelSocket;
if (accept) {
channelSocket = std::make_shared<ChannelSocket>(parent_.weak(), name, channel);
{
std::lock_guard<std::mutex> lkSockets(socketsMutex);
auto sockIt = sockets.find(channel);
if (sockIt != sockets.end()) {
JAMI_WARN("A channel is already present on that socket, accepting "
"the request will close the previous one");
sockets.erase(sockIt);
}
channelDatas_.emplace(channel, std::make_shared<ChannelInfo>());
sockets.emplace(channel, channelSocket);
}
std::lock_guard<std::mutex> lkSockets(socketsMutex);
channelSocket = makeSocket(name, channel);
}
// Answer to ChannelRequest if accepted
@@ -432,14 +405,11 @@ MultiplexedSocket::Impl::handleControlPacket(std::vector<uint8_t>&& pkt)
pimpl.onAccept(req.name, req.channel);
} else if (req.state == ChannelRequestState::DECLINE) {
std::lock_guard<std::mutex> lkSockets(pimpl.socketsMutex);
auto& channelDatas = pimpl.channelDatas_;
auto dataIt = channelDatas.find(req.channel);
if (dataIt != channelDatas.end() && dataIt->second) {
dataIt->second->isDestroying = true;
dataIt->second->cv.notify_all();
channelDatas.erase(dataIt);
auto channel = pimpl.sockets.find(req.channel);
if (channel != pimpl.sockets.end()) {
channel->second->stop();
pimpl.sockets.erase(channel);
}
pimpl.sockets.erase(req.channel);
} else if (pimpl.onRequest_) {
pimpl.onRequest(req.name, req.channel);
}
@@ -455,37 +425,14 @@ MultiplexedSocket::Impl::handleChannelPacket(uint16_t channel, std::vector<uint8
{
std::lock_guard<std::mutex> lkSockets(socketsMutex);
auto sockIt = sockets.find(channel);
auto dataIt = channelDatas_.find(channel);
if (channel > 0 && sockIt->second && dataIt->second) {
if (channel > 0 && sockIt->second) {
if (pkt.size() == 0) {
sockIt->second->shutdown();
dataIt->second->isDestroying = true;
dataIt->second->cv.notify_all();
channelDatas_.erase(dataIt);
sockets.erase(sockIt);
} else {
GenericSocket<uint8_t>::RecvCb cb;
{
std::lock_guard<std::mutex> lk(channelCbsMtx_);
auto cbIt = channelCbs_.find(channel);
if (cbIt != channelCbs_.end()) {
cb = cbIt->second;
}
}
if (cb) {
cb(&pkt[0], pkt.size());
return;
}
{
std::lock_guard<std::mutex> lkSockets(dataIt->second->mutex);
dataIt->second->buf.insert(dataIt->second->buf.end(),
std::make_move_iterator(pkt.begin()),
std::make_move_iterator(pkt.end()));
dataIt->second->cv.notify_all();
}
sockIt->second->onRecv(std::move(pkt));
}
} else if (pkt.size() != 0) {
std::string p = std::string(pkt.begin(), pkt.end());
JAMI_WARN("Non existing channel: %u", channel);
}
}
@@ -562,16 +509,9 @@ MultiplexedSocket::addChannel(const std::string& name)
std::lock_guard<std::mutex> lk(pimpl_->socketsMutex);
for (int i = 1; i < UINT16_MAX; ++i) {
auto c = (offset + i) % UINT16_MAX;
if (c == CONTROL_CHANNEL || c == PROTOCOL_CHANNEL)
if (c == CONTROL_CHANNEL || c == PROTOCOL_CHANNEL || pimpl_->sockets.find(c) != pimpl_->sockets.end())
continue;
auto& socket = pimpl_->sockets[c];
if (!socket) {
auto& channel = pimpl_->channelDatas_[c];
if (!channel)
channel = std::make_shared<ChannelInfo>();
socket = std::make_shared<ChannelSocket>(weak(), name, c);
return socket;
}
return pimpl_->makeSocket(name, c);
}
return {};
}
@@ -626,34 +566,6 @@ MultiplexedSocket::maxPayload() const
return pimpl_->endpoint->maxPayload();
}
std::size_t
MultiplexedSocket::read(const uint16_t& channel, uint8_t* buf, std::size_t len, std::error_code& ec)
{
if (pimpl_->isShutdown_) {
ec = std::make_error_code(std::errc::broken_pipe);
return -1;
}
std::lock_guard<std::mutex> lkSockets(pimpl_->socketsMutex);
auto dataIt = pimpl_->channelDatas_.find(channel);
if (dataIt == pimpl_->channelDatas_.end() || !dataIt->second) {
ec = std::make_error_code(std::errc::broken_pipe);
return -1;
}
std::size_t size;
{
std::lock_guard<std::mutex> lkSockets(dataIt->second->mutex);
auto& chanBuf = dataIt->second->buf;
size = std::min(len, chanBuf.size());
for (std::size_t i = 0; i < size; ++i)
buf[i] = chanBuf[i];
chanBuf.erase(chanBuf.begin(), chanBuf.begin() + size);
}
return size;
}
std::size_t
MultiplexedSocket::write(const uint16_t& channel,
const uint8_t* buf,
@@ -697,56 +609,6 @@ MultiplexedSocket::write(const uint16_t& channel,
return res;
}
int
MultiplexedSocket::waitForData(const uint16_t& channel,
std::chrono::milliseconds timeout,
std::error_code& ec) const
{
if (pimpl_->isShutdown_) {
ec = std::make_error_code(std::errc::broken_pipe);
return -1;
}
std::unique_lock lkSockets {pimpl_->socketsMutex};
auto dataIt = pimpl_->channelDatas_.find(channel);
if (dataIt == pimpl_->channelDatas_.end() or not dataIt->second) {
ec = std::make_error_code(std::errc::broken_pipe);
return -1;
}
auto channelData = dataIt->second;
lkSockets.unlock();
std::unique_lock<std::mutex> lk {channelData->mutex};
channelData->cv.wait_for(lk, timeout, [&] {
return channelData->isDestroying or !channelData->buf.empty() or pimpl_->isShutdown_;
});
return channelData->buf.size();
}
void
MultiplexedSocket::setOnRecv(const uint16_t& channel, GenericSocket<uint8_t>::RecvCb&& cb)
{
// Re run on ioPool, socketsMtx can be locked here (via onAccept), so retrigger
// to avoid double lock
dht::ThreadPool::io().run([w = weak(), channel, cb = std::move(cb)]() {
if (auto shared = w.lock()) {
std::lock_guard<std::mutex> lkSockets(shared->pimpl_->socketsMutex);
std::deque<uint8_t> recv;
{
std::lock_guard<std::mutex> lk(shared->pimpl_->channelCbsMtx_);
shared->pimpl_->channelCbs_[channel] = cb;
auto dataIt = shared->pimpl_->channelDatas_.find(channel);
if (dataIt != shared->pimpl_->channelDatas_.end() && dataIt->second) {
std::lock_guard<std::mutex> lk(dataIt->second->mutex);
recv = std::move(dataIt->second->buf);
}
}
if (!recv.empty() && cb) {
cb(&recv[0], recv.size());
}
}
});
}
void
MultiplexedSocket::shutdown()
{
@@ -851,6 +713,11 @@ public:
std::string name {};
uint16_t channel {};
std::weak_ptr<MultiplexedSocket> endpoint {};
std::deque<uint8_t> buf {};
std::mutex mutex {};
std::condition_variable cv {};
GenericSocket<uint8_t>::RecvCb cb {};
};
ChannelSocket::ChannelSocket(std::weak_ptr<MultiplexedSocket> endpoint,
@@ -912,10 +779,28 @@ ChannelSocket::maxPayload() const
void
ChannelSocket::setOnRecv(RecvCb&& cb)
{
if (auto ep = pimpl_->endpoint.lock())
ep->setOnRecv(pimpl_->channel, std::move(cb));
std::lock_guard<std::mutex> lkSockets(pimpl_->mutex);
pimpl_->cb = std::move(cb);
if (!pimpl_->buf.empty() && pimpl_->cb) {
pimpl_->cb(&pimpl_->buf[0], pimpl_->buf.size());
pimpl_->buf = {};
}
}
void
ChannelSocket::onRecv(std::vector<uint8_t>&& pkt) {
std::lock_guard<std::mutex> lkSockets(pimpl_->mutex);
if (pimpl_->cb) {
pimpl_->cb(&pkt[0], pkt.size());
return;
}
pimpl_->buf.insert(pimpl_->buf.end(),
std::make_move_iterator(pkt.begin()),
std::make_move_iterator(pkt.end()));
pimpl_->cv.notify_all();
}
std::shared_ptr<IceTransport>
ChannelSocket::underlyingICE() const
{
@@ -942,6 +827,7 @@ ChannelSocket::stop()
pimpl_->isShutdown_ = true;
if (pimpl_->shutdownCb_)
pimpl_->shutdownCb_();
pimpl_->cv.notify_all();
}
void
@@ -958,16 +844,16 @@ ChannelSocket::shutdown()
}
std::size_t
ChannelSocket::read(ValueType* buf, std::size_t len, std::error_code& ec)
ChannelSocket::read(ValueType* outBuf, std::size_t len, std::error_code& ec)
{
if (auto ep = pimpl_->endpoint.lock()) {
int res = ep->read(pimpl_->channel, buf, len, ec);
if (ec)
JAMI_ERR("Error when reading on channel: %s", ec.message().c_str());
return res;
}
ec = std::make_error_code(std::errc::broken_pipe);
return -1;
std::lock_guard<std::mutex> lkSockets(pimpl_->mutex);
std::size_t size = std::min(len, pimpl_->buf.size());
for (std::size_t i = 0; i < size; ++i)
outBuf[i] = pimpl_->buf[i];
pimpl_->buf.erase(pimpl_->buf.begin(), pimpl_->buf.begin() + size);
return size;
}
std::size_t
@@ -997,14 +883,11 @@ ChannelSocket::write(const ValueType* buf, std::size_t len, std::error_code& ec)
int
ChannelSocket::waitForData(std::chrono::milliseconds timeout, std::error_code& ec) const
{
if (auto ep = pimpl_->endpoint.lock()) {
auto res = ep->waitForData(pimpl_->channel, timeout, ec);
if (ec)
JAMI_ERR("Error when waiting on channel: %s", ec.message().c_str());
return res;
}
ec = std::make_error_code(std::errc::broken_pipe);
return -1;
std::unique_lock<std::mutex> lk {pimpl_->mutex};
pimpl_->cv.wait_for(lk, timeout, [&] {
return !pimpl_->buf.empty() or pimpl_->isShutdown_;
});
return pimpl_->buf.size();
}
void
@@ -103,15 +103,10 @@ public:
*/
void setOnChannelReady(uint16_t channel, onChannelReadyCb&& cb);
std::size_t read(const uint16_t& channel, uint8_t* buf, std::size_t len, std::error_code& ec);
std::size_t write(const uint16_t& channel,
const uint8_t* buf,
std::size_t len,
std::error_code& ec);
int waitForData(const uint16_t& channel,
std::chrono::milliseconds timeout,
std::error_code&) const;
void setOnRecv(const uint16_t& channel, GenericSocket<uint8_t>::RecvCb&& cb);
/**
* This will close all channels and send a TLS EOF on the main socket.
@@ -221,6 +216,8 @@ public:
*/
void setOnRecv(RecvCb&&) override;
void onRecv(std::vector<uint8_t>&& pkt);
std::shared_ptr<IceTransport> underlyingICE() const;
/**
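Taken together, a consumer of the refactored ChannelSocket has two receive paths, both visible in the diff above: register a callback with setOnRecv(), which onRecv() then invokes directly (any bytes buffered before registration are flushed to it), or block with waitForData() and drain the channel buffer with read(). A hedged usage sketch follows; the include path, processBytes() and the exact RecvCb signature are assumptions inferred from the diff, not verified API details:

#include <chrono>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <system_error>

#include "jamidht/multiplexed_socket.h"   // assumed header path for ChannelSocket

// Placeholder for application logic; not part of the Jami API.
static void processBytes(const uint8_t*, std::size_t) {}

// Callback path: data is delivered from ChannelSocket::onRecv() under the channel's mutex.
static void consumeWithCallback(const std::shared_ptr<jami::ChannelSocket>& sock)
{
    sock->setOnRecv([](const uint8_t* data, std::size_t len) {
        processBytes(data, len);
        return len;   // return type assumed compatible with GenericSocket<uint8_t>::RecvCb
    });
}

// Polling path: waitForData() blocks on the channel's condition variable until data
// arrives, the timeout elapses, or the channel is shut down, then returns the buffered size.
static void consumeOnce(const std::shared_ptr<jami::ChannelSocket>& sock)
{
    std::error_code ec;
    if (sock->waitForData(std::chrono::milliseconds(500), ec) > 0 && !ec) {
        uint8_t tmp[1024];
        auto n = sock->read(tmp, sizeof(tmp), ec);   // drains up to sizeof(tmp) bytes
        if (!ec)
            processBytes(tmp, n);
    }
}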