From 5aec410067947f33089004d7ef639015822c4f16 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrien=20B=C3=A9raud?= <adrien.beraud@savoirfairelinux.com> Date: Thu, 22 Feb 2024 14:15:56 -0500 Subject: [PATCH] maintain treated message list to avoid leak this piece of code could run on a spacecraft Change-Id: I7f47ed35b5d03dc449bd8d9459682f2101c25518 --- include/fileutils.h | 17 ++++++++ src/connectionmanager.cpp | 85 +++------------------------------------ src/fileutils.cpp | 68 +++++++++++++++++++++++++++++++ tests/testFileutils.cpp | 30 +++++++++++++- 4 files changed, 119 insertions(+), 81 deletions(-) diff --git a/include/fileutils.h b/include/fileutils.h index f3bdb5c..434eac0 100644 --- a/include/fileutils.h +++ b/include/fileutils.h @@ -23,6 +23,7 @@ #include <cstdio> #include <ios> #include <filesystem> +#include <map> #ifndef _WIN32 #include <sys/stat.h> // mode_t @@ -91,5 +92,21 @@ int removeAll(const std::filesystem::path& path, bool erase = false); */ int accessFile(const std::string& file, int mode); + +class IdList +{ +public: + IdList() = default; + IdList(std::filesystem::path p): path(std::move(p)) { + load(); + } + bool add(uint64_t id); +private: + void load(); + std::filesystem::path path; + std::map<uint64_t, std::chrono::system_clock::time_point> ids; + std::chrono::system_clock::time_point last_maintain; +}; + } // namespace fileutils } // namespace dhtnet diff --git a/src/connectionmanager.cpp b/src/connectionmanager.cpp index 523d90d..6e0fc4d 100644 --- a/src/connectionmanager.cpp +++ b/src/connectionmanager.cpp @@ -382,8 +382,8 @@ public: explicit Impl(std::shared_ptr<ConnectionManager::Config> config_) : config_ {std::move(createConfig(config_))} , rand_ {config_->rng ? *config_->rng : dht::crypto::getSeededRandomEngine<std::mt19937_64>()} + , treatedMessages_ {config_->cachePath / "treatedMessages"} { - loadTreatedMessages(); if(!config_->ioContext) { config_->ioContext = std::make_shared<asio::io_context>(); ioContextRunner_ = std::make_unique<std::thread>([context = config_->ioContext, l=config_->logger]() { @@ -509,15 +509,12 @@ public: tls::CertificateStore& certStore() const { return *config_->certStore; } mutable std::mutex messageMutex_ {}; - std::set<std::string, std::less<>> treatedMessages_ {}; - - void loadTreatedMessages(); - void saveTreatedMessages() const; + fileutils::IdList treatedMessages_; /// \return true if the given DHT message identifier has been treated /// \note if message has not been treated yet this method st/ore this id and returns true at /// further calls - bool isMessageTreated(std::string_view id); + bool isMessageTreated(dht::Value::Id id); const std::shared_ptr<dht::log::Logger>& logger() const { return config_->logger; } @@ -1125,7 +1122,7 @@ ConnectionManager::Impl::onDhtConnected(const dht::crypto::PublicKey& devicePk) auto shared = w.lock(); if (!shared) return false; - if (shared->isMessageTreated(to_hex_string(req.id))) { + if (shared->isMessageTreated(req.id)) { // Message already treated. Just ignore return true; } @@ -1539,81 +1536,11 @@ ConnectionManager::Impl::dhParams() const std::bind(tls::DhParams::loadDhParams, config_->cachePath / "dhParams")); } -template<typename ID = dht::Value::Id> -std::set<ID, std::less<>> -loadIdList(const std::filesystem::path& path) -{ - std::set<ID, std::less<>> ids; - std::ifstream file(path); - if (!file.is_open()) { - //JAMI_DBG("Could not load %s", path.c_str()); - return ids; - } - std::string line; - while (std::getline(file, line)) { - if constexpr (std::is_same<ID, std::string>::value) { - ids.emplace(std::move(line)); - } else if constexpr (std::is_integral<ID>::value) { - ID vid; - if (auto [p, ec] = std::from_chars(line.data(), line.data() + line.size(), vid, 16); - ec == std::errc()) { - ids.emplace(vid); - } - } - } - return ids; -} - -template<typename List = std::set<dht::Value::Id>> -void -saveIdList(const std::filesystem::path& path, const List& ids) -{ - std::ofstream file(path, std::ios::trunc | std::ios::binary); - if (!file.is_open()) { - //JAMI_ERR("Could not save to %s", path.c_str()); - return; - } - for (auto& c : ids) - file << std::hex << c << "\n"; -} - -void -ConnectionManager::Impl::loadTreatedMessages() -{ - std::lock_guard<std::mutex> lock(messageMutex_); - auto path = config_->cachePath / "treatedMessages"; - treatedMessages_ = loadIdList<std::string>(path.string()); - if (treatedMessages_.empty()) { - auto messages = loadIdList(path.string()); - for (const auto& m : messages) - treatedMessages_.emplace(to_hex_string(m)); - } -} - -void -ConnectionManager::Impl::saveTreatedMessages() const -{ - dht::ThreadPool::io().run([w = weak_from_this()]() { - if (auto sthis = w.lock()) { - auto& this_ = *sthis; - std::lock_guard<std::mutex> lock(this_.messageMutex_); - fileutils::check_dir(this_.config_->cachePath.c_str()); - saveIdList<decltype(this_.treatedMessages_)>(this_.config_->cachePath / "treatedMessages", - this_.treatedMessages_); - } - }); -} - bool -ConnectionManager::Impl::isMessageTreated(std::string_view id) +ConnectionManager::Impl::isMessageTreated(dht::Value::Id id) { std::lock_guard<std::mutex> lock(messageMutex_); - auto res = treatedMessages_.emplace(id); - if (res.second) { - saveTreatedMessages(); - return false; - } - return true; + return treatedMessages_.add(id); } /** diff --git a/src/fileutils.cpp b/src/fileutils.cpp index f700ddb..72e78b6 100644 --- a/src/fileutils.cpp +++ b/src/fileutils.cpp @@ -408,5 +408,73 @@ accessFile(const std::filesystem::path& file, int mode) #endif } +constexpr auto ID_TIMEOUT = std::chrono::hours(24); + +void +IdList::load() +{ + size_t pruned = 0; + auto now = std::chrono::system_clock::now(); + try { + std::ifstream file(path, std::ios::binary); + msgpack::unpacker unp; + auto timeout = now - ID_TIMEOUT; + while (file.is_open() && !file.eof()) { + unp.reserve_buffer(8 * 1024); + file.read(unp.buffer(), unp.buffer_capacity()); + unp.buffer_consumed(file.gcount()); + msgpack::unpacked result; + while (unp.next(result)) { + auto kv = result.get().as<std::pair<uint64_t, std::chrono::system_clock::time_point>>(); + if (kv.second > timeout) + ids.insert(std::move(kv)); + else + pruned++; + } + } + } catch (const std::exception& e) { + // discard corrupted files + std::error_code ec; + std::filesystem::remove(path, ec); + } + last_maintain = now; + if (pruned) { + std::ofstream file(path, std::ios::trunc | std::ios::binary); + for (auto& kv : ids) + msgpack::pack(file, kv); + } +} + +bool +IdList::add(uint64_t id) +{ + auto now = std::chrono::system_clock::now(); + auto r = ids.emplace(id, now); + if (r.second) { + auto timeout = now - ID_TIMEOUT; + if (last_maintain > timeout) { + // append + std::ofstream file(path, std::ios::app | std::ios::binary); + if (file.is_open()) { + msgpack::pack(file, *r.first); + } + } else { + // maintain and save + std::ofstream file(path, std::ios::trunc | std::ios::binary); + for (auto it = ids.begin(); it != ids.end();) { + if (it->second < timeout) { + it = ids.erase(it); + } else { + msgpack::pack(file, *it); + ++it; + } + } + last_maintain = now; + } + return true; + } + return false; +} + } // namespace fileutils } // namespace dhtnet diff --git a/tests/testFileutils.cpp b/tests/testFileutils.cpp index f6fa77e..72ad2d2 100644 --- a/tests/testFileutils.cpp +++ b/tests/testFileutils.cpp @@ -40,12 +40,14 @@ private: void testPath(); void testReadDirectory(); void testLoadFile(); + void testIdList(); CPPUNIT_TEST_SUITE(FileutilsTest); CPPUNIT_TEST(testCheckDir); CPPUNIT_TEST(testPath); CPPUNIT_TEST(testReadDirectory); CPPUNIT_TEST(testLoadFile); + CPPUNIT_TEST(testIdList); CPPUNIT_TEST_SUITE_END(); static constexpr auto tmpFileName = "temp_file"; @@ -61,7 +63,7 @@ CPPUNIT_TEST_SUITE_NAMED_REGISTRATION(FileutilsTest, FileutilsTest::name()); void FileutilsTest::setUp() { - char template_name[] = {"ring_unit_tests_XXXXXX"}; + char template_name[] = {"unit_tests_XXXXXX"}; // Generate a temporary directory with a file inside auto directory = mkdtemp(template_name); @@ -133,7 +135,31 @@ FileutilsTest::testLoadFile() CPPUNIT_ASSERT(file.at(3) == 'G'); } +void +FileutilsTest::testIdList() +{ + auto path = TEST_PATH / "idList"; + IdList list(path); + list.add(1); + list.add(2); + IdList list2(path); + CPPUNIT_ASSERT(!list.add(1)); + CPPUNIT_ASSERT(!list.add(2)); + CPPUNIT_ASSERT(!list2.add(1)); + CPPUNIT_ASSERT(!list2.add(2)); + CPPUNIT_ASSERT(list2.add(10)); + CPPUNIT_ASSERT(list2.add(11)); + list = {path}; + CPPUNIT_ASSERT(list.add(5)); + CPPUNIT_ASSERT(list.add(6)); + CPPUNIT_ASSERT(!list.add(1)); + CPPUNIT_ASSERT(!list.add(2)); + CPPUNIT_ASSERT(!list.add(10)); + CPPUNIT_ASSERT(!list.add(11)); + CPPUNIT_ASSERT(removeAll(path) == 0); +} + + }}} // namespace dhtnet::test::fileutils JAMI_TEST_RUNNER(dhtnet::fileutils::test::FileutilsTest::name()); - -- GitLab