Skip to content
Snippets Groups Projects
Commit ef23ce0c authored by Sébastien Blin's avatar Sébastien Blin Committed by Adrien Béraud
Browse files

multiplexed_socket: use shared_ptr for ChannelInfo

This fix a potential crash into waitForData where the channel is deleted
after the check and before locking. Also notify when deleting. Also
add isDestroying on channel to avoid to wait indefinitely on destroyed channel

GitLab: #549
Change-Id: Ib37c4bab31def4ca6f594fc1338a018138217765
parent 602f0a68
No related branches found
No related tags found
No related merge requests found
......@@ -55,6 +55,7 @@ using time_point = clock::time_point;
struct ChannelInfo
{
std::atomic_bool isDestroying {false};
std::deque<uint8_t> buf {};
std::mutex mutex {};
std::condition_variable cv {};
......@@ -99,10 +100,12 @@ public:
std::lock_guard<std::mutex> lkSockets(socketsMutex);
socks = std::move(sockets);
for (auto& [key, channelData] : channelDatas_) {
if (channelData)
if (channelData) {
channelData->isDestroying = true;
channelData->cv.notify_all();
}
}
}
for (auto& socket : socks) {
// Just trigger onShutdown() to make client know
// No need to write the EOF for the channel, the write will fail because endpoint is
......@@ -180,7 +183,7 @@ public:
std::thread eventLoopThread_ {};
// Multiplexed available datas
std::map<uint16_t, std::unique_ptr<ChannelInfo>> channelDatas_ {};
std::map<uint16_t, std::shared_ptr<ChannelInfo>> channelDatas_ {};
std::mutex channelCbsMtx_ {};
std::map<uint16_t, GenericSocket<uint8_t>::RecvCb> channelCbs_ {};
std::atomic_bool isShutdown_ {false};
......@@ -255,7 +258,7 @@ MultiplexedSocket::Impl::onAccept(const std::string& name, uint16_t channel)
std::lock_guard<std::mutex> lkSockets(socketsMutex);
auto& channelData = channelDatas_[channel];
if (not channelData)
channelData = std::make_unique<ChannelInfo>();
channelData = std::make_shared<ChannelInfo>();
auto& channelSocket = sockets[channel];
if (not channelSocket)
channelSocket = std::make_shared<ChannelSocket>(parent_.weak(), name, channel);
......@@ -371,7 +374,7 @@ MultiplexedSocket::Impl::onRequest(const std::string& name, uint16_t channel)
"the request will close the previous one");
sockets.erase(sockIt);
}
channelDatas_.emplace(channel, std::make_unique<ChannelInfo>());
channelDatas_.emplace(channel, std::make_shared<ChannelInfo>());
sockets.emplace(channel, channelSocket);
}
}
......@@ -419,17 +422,24 @@ MultiplexedSocket::Impl::handleControlPacket(std::vector<uint8_t>&& pkt)
msgpack::unpacked result;
msgpack::unpack(result, (const char*) pkt.data(), pkt.size(), off);
auto object = result.get();
if (shared->pimpl_->handleProtocolMsg(object))
auto& pimpl = *shared->pimpl_;
if (pimpl.handleProtocolMsg(object))
continue;
auto req = object.as<ChannelRequest>();
if (req.state == ChannelRequestState::ACCEPT) {
shared->pimpl_->onAccept(req.name, req.channel);
pimpl.onAccept(req.name, req.channel);
} else if (req.state == ChannelRequestState::DECLINE) {
std::lock_guard<std::mutex> lkSockets(shared->pimpl_->socketsMutex);
shared->pimpl_->channelDatas_.erase(req.channel);
shared->pimpl_->sockets.erase(req.channel);
} else if (shared->pimpl_->onRequest_) {
shared->pimpl_->onRequest(req.name, req.channel);
std::lock_guard<std::mutex> lkSockets(pimpl.socketsMutex);
auto& channelDatas = pimpl.channelDatas_;
auto dataIt = channelDatas.find(req.channel);
if (dataIt != channelDatas.end() && dataIt->second) {
dataIt->second->isDestroying = true;
dataIt->second->cv.notify_all();
channelDatas.erase(dataIt);
}
pimpl.sockets.erase(req.channel);
} else if (pimpl.onRequest_) {
pimpl.onRequest(req.name, req.channel);
}
}
} catch (const std::exception& e) {
......@@ -447,6 +457,7 @@ MultiplexedSocket::Impl::handleChannelPacket(uint16_t channel, std::vector<uint8
if (channel > 0 && sockIt->second && dataIt->second) {
if (pkt.size() == 0) {
sockIt->second->shutdown();
dataIt->second->isDestroying = true;
dataIt->second->cv.notify_all();
channelDatas_.erase(dataIt);
sockets.erase(sockIt);
......@@ -555,7 +566,7 @@ MultiplexedSocket::addChannel(const std::string& name)
if (!socket) {
auto& channel = pimpl_->channelDatas_[c];
if (!channel)
channel = std::make_unique<ChannelInfo>();
channel = std::make_shared<ChannelInfo>();
socket = std::make_shared<ChannelSocket>(weak(), name, c);
return socket;
}
......@@ -691,18 +702,17 @@ MultiplexedSocket::waitForData(const uint16_t& channel,
ec = std::make_error_code(std::errc::broken_pipe);
return -1;
}
std::unique_lock lkSockets {pimpl_->socketsMutex};
auto dataIt = pimpl_->channelDatas_.find(channel);
if (dataIt == pimpl_->channelDatas_.end()) {
if (dataIt == pimpl_->channelDatas_.end() or not dataIt->second) {
ec = std::make_error_code(std::errc::broken_pipe);
return -1;
}
auto& channelData = dataIt->second;
if (!channelData) {
return -1;
}
auto channelData = dataIt->second;
lkSockets.unlock();
std::unique_lock<std::mutex> lk {channelData->mutex};
channelData->cv.wait_for(lk, timeout, [&] {
return !channelData->buf.empty() or pimpl_->isShutdown_;
return channelData->isDestroying or !channelData->buf.empty() or pimpl_->isShutdown_;
});
return channelData->buf.size();
}
......@@ -712,6 +722,7 @@ MultiplexedSocket::setOnRecv(const uint16_t& channel, GenericSocket<uint8_t>::Re
{
std::deque<uint8_t> recv;
{
// NOTE: here socketsMtx is already locked via onAccept
std::lock_guard<std::mutex> lk(pimpl_->channelCbsMtx_);
pimpl_->channelCbs_[channel] = cb;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment