From 13902b0f19f702c43bc34e24a8d42db6a8d81a31 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrien=20B=C3=A9raud?= <adrien.beraud@savoirfairelinux.com>
Date: Tue, 24 Feb 2015 11:43:52 -0500
Subject: [PATCH] run filter after decryption & security verifications

---
 include/opendht/securedht.h |  6 +++---
 src/dhtrunner.cpp           |  8 ++++----
 src/securedht.cpp           | 25 +++++++++++++++----------
 3 files changed, 22 insertions(+), 17 deletions(-)

diff --git a/include/opendht/securedht.h b/include/opendht/securedht.h
index b672ebdb..3311200f 100644
--- a/include/opendht/securedht.h
+++ b/include/opendht/securedht.h
@@ -86,9 +86,9 @@ public:
      * If the signature can't be checked, or if the data can't be decrypted, it is not returned.
      * Public, non-signed & non-encrypted data is retransmitted as-is.
      */
-    void get(const InfoHash& id, GetCallback cb, DoneCallback donecb, Value::Filter = Value::AllFilter());
+    void get(const InfoHash& id, GetCallback cb, DoneCallback donecb, Value::Filter&& = {});
 
-    size_t listen(const InfoHash& id, GetCallback cb, Value::Filter = Value::AllFilter());
+    size_t listen(const InfoHash& id, GetCallback cb, Value::Filter&& = {});
 
     /**
      * Will take ownership of the value, sign it using our private key and put it in the DHT.
@@ -127,7 +127,7 @@ private:
     SecureDht(const SecureDht&) = delete;
     SecureDht& operator=(const SecureDht&) = delete;
 
-    GetCallback getCallbackFilter(GetCallback);
+    GetCallback getCallbackFilter(GetCallback, Value::Filter&&);
 
     std::shared_ptr<crypto::PrivateKey> key_ {};
     std::shared_ptr<crypto::Certificate> certificate_ {};
diff --git a/src/dhtrunner.cpp b/src/dhtrunner.cpp
index a74ccde9..0f325e53 100644
--- a/src/dhtrunner.cpp
+++ b/src/dhtrunner.cpp
@@ -221,8 +221,8 @@ void
 DhtRunner::get(InfoHash hash, Dht::GetCallback vcb, Dht::DoneCallback dcb, Value::Filter f)
 {
     std::lock_guard<std::mutex> lck(storage_mtx);
-    pending_ops.emplace([=](SecureDht& dht) {
-        dht.get(hash, vcb, dcb, f);
+    pending_ops.emplace([=](SecureDht& dht) mutable {
+        dht.get(hash, vcb, dcb, std::move(f));
     });
     cv.notify_all();
 }
@@ -238,8 +238,8 @@ DhtRunner::listen(InfoHash hash, Dht::GetCallback vcb, Value::Filter f)
 {
     std::lock_guard<std::mutex> lck(storage_mtx);
     auto ret_token = std::make_shared<std::promise<size_t>>();
-    pending_ops.emplace([=](SecureDht& dht) {
-        ret_token->set_value(dht.listen(hash, vcb, f));
+    pending_ops.emplace([=](SecureDht& dht) mutable {
+        ret_token->set_value(dht.listen(hash, vcb, std::move(f)));
     });
     cv.notify_all();
     return ret_token->get_future();
diff --git a/src/securedht.cpp b/src/securedht.cpp
index cc94e98d..a82355b6 100644
--- a/src/securedht.cpp
+++ b/src/securedht.cpp
@@ -183,7 +183,7 @@ SecureDht::findCertificate(const InfoHash& node, std::function<void(const std::s
 
 
 Dht::GetCallback
-SecureDht::getCallbackFilter(GetCallback cb)
+SecureDht::getCallbackFilter(GetCallback cb, Value::Filter&& filter)
 {
     return [=](const std::vector<std::shared_ptr<Value>>& values) {
         std::vector<std::shared_ptr<Value>> tmpvals {};
@@ -193,8 +193,10 @@ SecureDht::getCallbackFilter(GetCallback cb)
                 try {
                     Value decrypted_val (decrypt(*v));
                     if (decrypted_val.recipient == getId()) {
-                        if (decrypted_val.owner.checkSignature(decrypted_val.getToSign(), decrypted_val.signature))
-                            tmpvals.push_back(std::make_shared<Value>(std::move(decrypted_val)));
+                        if (decrypted_val.owner.checkSignature(decrypted_val.getToSign(), decrypted_val.signature)) {
+                            if (not filter or filter(decrypted_val))
+                                tmpvals.push_back(std::make_shared<Value>(std::move(decrypted_val)));
+                        }
                         else
                             DHT_WARN("Signature verification failed for %s", v->toString().c_str());
                     }
@@ -205,14 +207,17 @@ SecureDht::getCallbackFilter(GetCallback cb)
             }
             // Check signed values
             else if (v->isSigned()) {
-                if (v->owner.checkSignature(v->getToSign(), v->signature))
-                    tmpvals.push_back(v);
+                if (v->owner.checkSignature(v->getToSign(), v->signature)) {
+                    if (not filter or filter(*v))
+                        tmpvals.push_back(v);
+                }
                 else
                     DHT_WARN("Signature verification failed for %s", v->toString().c_str());
             }
             // Forward normal values
             else {
-                tmpvals.push_back(v);
+                if (not filter or filter(*v))
+                    tmpvals.push_back(v);
             }
         }
         if (cb && not tmpvals.empty())
@@ -222,15 +227,15 @@ SecureDht::getCallbackFilter(GetCallback cb)
 }
 
 void
-SecureDht::get(const InfoHash& id, GetCallback cb, DoneCallback donecb, Value::Filter filter)
+SecureDht::get(const InfoHash& id, GetCallback cb, DoneCallback donecb, Value::Filter&& f)
 {
-    Dht::get(id, getCallbackFilter(cb), donecb, filter);
+    Dht::get(id, getCallbackFilter(cb, std::forward<Value::Filter>(f)), donecb);
 }
 
 size_t
-SecureDht::listen(const InfoHash& id, GetCallback cb, Value::Filter filter)
+SecureDht::listen(const InfoHash& id, GetCallback cb, Value::Filter&& f)
 {
-    return Dht::listen(id, getCallbackFilter(cb), filter);
+    return Dht::listen(id, getCallbackFilter(cb, std::forward<Value::Filter>(f)));
 }
 
 void
-- 
GitLab