diff --git a/src/multiplexed_socket.cpp b/src/multiplexed_socket.cpp index b34bf07f17bb9346a4715d7c76ecb2e53a295534..81a777cdf02e7ee564a437b8a6f0ce9fc80dde48 100644 --- a/src/multiplexed_socket.cpp +++ b/src/multiplexed_socket.cpp @@ -60,12 +60,13 @@ public: Impl(MultiplexedSocket& parent, std::shared_ptr<asio::io_context> ctx, const DeviceId& deviceId, - std::unique_ptr<TlsSocketEndpoint> endpoint, + std::unique_ptr<TlsSocketEndpoint> ep, std::shared_ptr<dht::log::Logger> logger) : parent_(parent) , ctx_(std::move(ctx)) , deviceId(deviceId) - , endpoint(std::move(endpoint)) + , endpoint(std::move(ep)) + , nextChannel_(endpoint->isInitiator() ? 0x0001u : 0x8000u) , eventLoopThread_ {[this] { try { eventLoop(); @@ -193,6 +194,7 @@ public: std::mutex socketsMutex {}; std::map<uint16_t, std::shared_ptr<ChannelSocket>> sockets {}; + uint16_t nextChannel_; // Main loop to parse incoming packets std::atomic_bool stop {false}; @@ -551,20 +553,16 @@ MultiplexedSocket::~MultiplexedSocket() {} std::shared_ptr<ChannelSocket> MultiplexedSocket::addChannel(const std::string& name) { - // Note: because both sides can request the same channel number at the same time - // it's better to use a random channel number instead of just incrementing the request. - thread_local dht::crypto::random_device rd; - std::uniform_int_distribution<uint16_t> dist; - auto offset = dist(rd); std::lock_guard<std::mutex> lk(pimpl_->socketsMutex); - for (int i = 1; i < UINT16_MAX; ++i) { - auto c = (offset + i) % UINT16_MAX; - if (c == CONTROL_CHANNEL || c == PROTOCOL_CHANNEL - || pimpl_->sockets.find(c) != pimpl_->sockets.end()) - continue; - auto channel = pimpl_->makeSocket(name, c, true); - return channel; - } + if (pimpl_->sockets.size() < UINT16_MAX) + for (unsigned i = 0; i < UINT16_MAX; ++i) { + auto c = pimpl_->nextChannel_++; + if (c == CONTROL_CHANNEL + || c == PROTOCOL_CHANNEL + || pimpl_->sockets.find(c) != pimpl_->sockets.end()) + continue; + return pimpl_->makeSocket(name, c, true); + } return {}; }