diff --git a/tests/connectionManager.cpp b/tests/connectionManager.cpp index 070ed83fa418ee397a204a21a321825f8cfde2a5..58aecfba7a21ad50ee87f85a8f6cc916316613b6 100644 --- a/tests/connectionManager.cpp +++ b/tests/connectionManager.cpp @@ -95,6 +95,7 @@ private: void testDeclineICERequest(); void testChannelRcvShutdown(); void testChannelSenderShutdown(); + void testMultiChannelShutdown(); void testCloseConnectionWith(); void testShutdownCallbacks(); void testFloodSocket(); @@ -124,6 +125,7 @@ private: CPPUNIT_TEST(testAcceptsICERequest); CPPUNIT_TEST(testChannelRcvShutdown); CPPUNIT_TEST(testChannelSenderShutdown); + CPPUNIT_TEST(testMultiChannelShutdown); CPPUNIT_TEST(testCloseConnectionWith); CPPUNIT_TEST(testShutdownCallbacks); CPPUNIT_TEST(testFloodSocket); @@ -887,6 +889,111 @@ ConnectionManagerTest::testChannelSenderShutdown() CPPUNIT_ASSERT(scv.wait_for(lk, 30s, [&] { return shutdownReceived; })); } +void +ConnectionManagerTest::testMultiChannelShutdown() +{ + std::condition_variable cv; + size_t connectedCbCount = 0; + size_t successfullyConnected = 0; + size_t accepted = 0; + size_t shutdownCount = 0; + std::atomic_bool connected = false; + std::set<std::shared_ptr<MultiplexedSocket>> sockets; + bool shut = true; + + bob->connectionManager->onICERequest([](const DeviceId&) { return true; }); + + bob->connectionManager->onChannelRequest([&](const std::shared_ptr<dht::crypto::Certificate>&, const std::string& name) { + if (name.empty()) return false; + std::lock_guard lk {mtx}; + accepted++; + cv.notify_one(); + return true; + }); + + bob->connectionManager->onConnectionReady([&](const DeviceId&, const std::string& name, std::shared_ptr<ChannelSocket> socket) { + if (not socket or name.empty()) return; + socket->setOnRecv([rxbuf = std::make_shared<std::vector<uint8_t>>(), w = std::weak_ptr(socket)](const uint8_t* data, size_t size) { + rxbuf->insert(rxbuf->end(), data, data + size); + if (rxbuf->size() == 32) { + if (auto socket = w.lock()) { + std::error_code ec; + socket->write(rxbuf->data(), rxbuf->size(), ec); + CPPUNIT_ASSERT(!ec); + socket->shutdown(); + } + } + return size; + }); + std::lock_guard lk {mtx}; + sockets.emplace(socket->underlyingSocket()); + }); + + auto onConnect = [&](std::shared_ptr<ChannelSocket> socket, const DeviceId&) { + { + std::lock_guard lk {mtx}; + connectedCbCount++; + if (socket) + successfullyConnected++; + cv.notify_one(); + } + if (socket) { + auto data_sent = dht::PkId::get(socket->name()); + socket->setOnRecv([&, data_sent, rxbuf = std::make_shared<std::vector<uint8_t>>()](const uint8_t* data, size_t size) { + rxbuf->insert(rxbuf->end(), data, data + size); + if (rxbuf->size() == 32) { + CPPUNIT_ASSERT(!std::memcmp(data_sent.data(), rxbuf->data(), data_sent.size())); + } + return size; + }); + socket->onShutdown([&]() { + std::lock_guard lk {mtx}; + shutdownCount++; + cv.notify_one(); + }); + connected = true; + std::error_code ec; + socket->write(data_sent.data(), data_sent.size(), ec); + CPPUNIT_ASSERT(!ec); + } + }; + + // max supported number of channels per side (64k - 2 reserved channels) + static constexpr size_t N = 1024 * 48 - 1; + + for (size_t i = 1; i <= N; ++i) { + alice->connectionManager->connectDevice(bob->id.second, + fmt::format("git://{}", i), + onConnect); + + if (i % 128 == 0) + std::this_thread::sleep_for(15ms); + if (i % 1000 == 0) { + if (shut && connected.exchange(false)) { + shut = false; + decltype(sockets)::node_type toClose; + { + std::lock_guard lk {mtx}; + toClose = sockets.extract(sockets.begin()); + sockets.clear(); + } + fmt::print("Closing connections {} - {}\n", i, fmt::ptr(toClose.value())); + toClose.value()->shutdown(); + } + } + } + + std::unique_lock lk {mtx}; + cv.wait_for(lk, 30s, [&] { return connectedCbCount == N; }); + CPPUNIT_ASSERT_EQUAL(N, connectedCbCount); + CPPUNIT_ASSERT_EQUAL(N, successfullyConnected); + CPPUNIT_ASSERT(successfullyConnected <= accepted); + CPPUNIT_ASSERT(accepted < 2* successfullyConnected); + cv.wait_for(lk, 60s, [&] { return shutdownCount == successfullyConnected; }); + CPPUNIT_ASSERT_EQUAL(successfullyConnected, shutdownCount); + lk.unlock(); +} + void ConnectionManagerTest::testCloseConnectionWith() {