Skip to content
Snippets Groups Projects
Commit 5aec4100 authored by Adrien Béraud's avatar Adrien Béraud
Browse files

maintain treated message list to avoid leak

this piece of code could run on a spacecraft

Change-Id: I7f47ed35b5d03dc449bd8d9459682f2101c25518
parent 738aedb0
No related branches found
No related tags found
No related merge requests found
......@@ -23,6 +23,7 @@
#include <cstdio>
#include <ios>
#include <filesystem>
#include <map>
#ifndef _WIN32
#include <sys/stat.h> // mode_t
......@@ -91,5 +92,21 @@ int removeAll(const std::filesystem::path& path, bool erase = false);
*/
int accessFile(const std::string& file, int mode);
class IdList
{
public:
IdList() = default;
IdList(std::filesystem::path p): path(std::move(p)) {
load();
}
bool add(uint64_t id);
private:
void load();
std::filesystem::path path;
std::map<uint64_t, std::chrono::system_clock::time_point> ids;
std::chrono::system_clock::time_point last_maintain;
};
} // namespace fileutils
} // namespace dhtnet
......@@ -382,8 +382,8 @@ public:
explicit Impl(std::shared_ptr<ConnectionManager::Config> config_)
: config_ {std::move(createConfig(config_))}
, rand_ {config_->rng ? *config_->rng : dht::crypto::getSeededRandomEngine<std::mt19937_64>()}
, treatedMessages_ {config_->cachePath / "treatedMessages"}
{
loadTreatedMessages();
if(!config_->ioContext) {
config_->ioContext = std::make_shared<asio::io_context>();
ioContextRunner_ = std::make_unique<std::thread>([context = config_->ioContext, l=config_->logger]() {
......@@ -509,15 +509,12 @@ public:
tls::CertificateStore& certStore() const { return *config_->certStore; }
mutable std::mutex messageMutex_ {};
std::set<std::string, std::less<>> treatedMessages_ {};
void loadTreatedMessages();
void saveTreatedMessages() const;
fileutils::IdList treatedMessages_;
/// \return true if the given DHT message identifier has been treated
/// \note if message has not been treated yet this method st/ore this id and returns true at
/// further calls
bool isMessageTreated(std::string_view id);
bool isMessageTreated(dht::Value::Id id);
const std::shared_ptr<dht::log::Logger>& logger() const { return config_->logger; }
......@@ -1125,7 +1122,7 @@ ConnectionManager::Impl::onDhtConnected(const dht::crypto::PublicKey& devicePk)
auto shared = w.lock();
if (!shared)
return false;
if (shared->isMessageTreated(to_hex_string(req.id))) {
if (shared->isMessageTreated(req.id)) {
// Message already treated. Just ignore
return true;
}
......@@ -1539,81 +1536,11 @@ ConnectionManager::Impl::dhParams() const
std::bind(tls::DhParams::loadDhParams, config_->cachePath / "dhParams"));
}
template<typename ID = dht::Value::Id>
std::set<ID, std::less<>>
loadIdList(const std::filesystem::path& path)
{
std::set<ID, std::less<>> ids;
std::ifstream file(path);
if (!file.is_open()) {
//JAMI_DBG("Could not load %s", path.c_str());
return ids;
}
std::string line;
while (std::getline(file, line)) {
if constexpr (std::is_same<ID, std::string>::value) {
ids.emplace(std::move(line));
} else if constexpr (std::is_integral<ID>::value) {
ID vid;
if (auto [p, ec] = std::from_chars(line.data(), line.data() + line.size(), vid, 16);
ec == std::errc()) {
ids.emplace(vid);
}
}
}
return ids;
}
template<typename List = std::set<dht::Value::Id>>
void
saveIdList(const std::filesystem::path& path, const List& ids)
{
std::ofstream file(path, std::ios::trunc | std::ios::binary);
if (!file.is_open()) {
//JAMI_ERR("Could not save to %s", path.c_str());
return;
}
for (auto& c : ids)
file << std::hex << c << "\n";
}
void
ConnectionManager::Impl::loadTreatedMessages()
{
std::lock_guard<std::mutex> lock(messageMutex_);
auto path = config_->cachePath / "treatedMessages";
treatedMessages_ = loadIdList<std::string>(path.string());
if (treatedMessages_.empty()) {
auto messages = loadIdList(path.string());
for (const auto& m : messages)
treatedMessages_.emplace(to_hex_string(m));
}
}
void
ConnectionManager::Impl::saveTreatedMessages() const
{
dht::ThreadPool::io().run([w = weak_from_this()]() {
if (auto sthis = w.lock()) {
auto& this_ = *sthis;
std::lock_guard<std::mutex> lock(this_.messageMutex_);
fileutils::check_dir(this_.config_->cachePath.c_str());
saveIdList<decltype(this_.treatedMessages_)>(this_.config_->cachePath / "treatedMessages",
this_.treatedMessages_);
}
});
}
bool
ConnectionManager::Impl::isMessageTreated(std::string_view id)
ConnectionManager::Impl::isMessageTreated(dht::Value::Id id)
{
std::lock_guard<std::mutex> lock(messageMutex_);
auto res = treatedMessages_.emplace(id);
if (res.second) {
saveTreatedMessages();
return false;
}
return true;
return treatedMessages_.add(id);
}
/**
......
......@@ -408,5 +408,73 @@ accessFile(const std::filesystem::path& file, int mode)
#endif
}
constexpr auto ID_TIMEOUT = std::chrono::hours(24);
void
IdList::load()
{
size_t pruned = 0;
auto now = std::chrono::system_clock::now();
try {
std::ifstream file(path, std::ios::binary);
msgpack::unpacker unp;
auto timeout = now - ID_TIMEOUT;
while (file.is_open() && !file.eof()) {
unp.reserve_buffer(8 * 1024);
file.read(unp.buffer(), unp.buffer_capacity());
unp.buffer_consumed(file.gcount());
msgpack::unpacked result;
while (unp.next(result)) {
auto kv = result.get().as<std::pair<uint64_t, std::chrono::system_clock::time_point>>();
if (kv.second > timeout)
ids.insert(std::move(kv));
else
pruned++;
}
}
} catch (const std::exception& e) {
// discard corrupted files
std::error_code ec;
std::filesystem::remove(path, ec);
}
last_maintain = now;
if (pruned) {
std::ofstream file(path, std::ios::trunc | std::ios::binary);
for (auto& kv : ids)
msgpack::pack(file, kv);
}
}
bool
IdList::add(uint64_t id)
{
auto now = std::chrono::system_clock::now();
auto r = ids.emplace(id, now);
if (r.second) {
auto timeout = now - ID_TIMEOUT;
if (last_maintain > timeout) {
// append
std::ofstream file(path, std::ios::app | std::ios::binary);
if (file.is_open()) {
msgpack::pack(file, *r.first);
}
} else {
// maintain and save
std::ofstream file(path, std::ios::trunc | std::ios::binary);
for (auto it = ids.begin(); it != ids.end();) {
if (it->second < timeout) {
it = ids.erase(it);
} else {
msgpack::pack(file, *it);
++it;
}
}
last_maintain = now;
}
return true;
}
return false;
}
} // namespace fileutils
} // namespace dhtnet
......@@ -40,12 +40,14 @@ private:
void testPath();
void testReadDirectory();
void testLoadFile();
void testIdList();
CPPUNIT_TEST_SUITE(FileutilsTest);
CPPUNIT_TEST(testCheckDir);
CPPUNIT_TEST(testPath);
CPPUNIT_TEST(testReadDirectory);
CPPUNIT_TEST(testLoadFile);
CPPUNIT_TEST(testIdList);
CPPUNIT_TEST_SUITE_END();
static constexpr auto tmpFileName = "temp_file";
......@@ -61,7 +63,7 @@ CPPUNIT_TEST_SUITE_NAMED_REGISTRATION(FileutilsTest, FileutilsTest::name());
void
FileutilsTest::setUp()
{
char template_name[] = {"ring_unit_tests_XXXXXX"};
char template_name[] = {"unit_tests_XXXXXX"};
// Generate a temporary directory with a file inside
auto directory = mkdtemp(template_name);
......@@ -133,7 +135,31 @@ FileutilsTest::testLoadFile()
CPPUNIT_ASSERT(file.at(3) == 'G');
}
void
FileutilsTest::testIdList()
{
auto path = TEST_PATH / "idList";
IdList list(path);
list.add(1);
list.add(2);
IdList list2(path);
CPPUNIT_ASSERT(!list.add(1));
CPPUNIT_ASSERT(!list.add(2));
CPPUNIT_ASSERT(!list2.add(1));
CPPUNIT_ASSERT(!list2.add(2));
CPPUNIT_ASSERT(list2.add(10));
CPPUNIT_ASSERT(list2.add(11));
list = {path};
CPPUNIT_ASSERT(list.add(5));
CPPUNIT_ASSERT(list.add(6));
CPPUNIT_ASSERT(!list.add(1));
CPPUNIT_ASSERT(!list.add(2));
CPPUNIT_ASSERT(!list.add(10));
CPPUNIT_ASSERT(!list.add(11));
CPPUNIT_ASSERT(removeAll(path) == 0);
}
}}} // namespace dhtnet::test::fileutils
JAMI_TEST_RUNNER(dhtnet::fileutils::test::FileutilsTest::name());
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment