From 5967c89c23dd6d03a1e200143b38ee744c609bb8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?S=C3=A9bastien=20Blin?=
 <sebastien.blin@savoirfairelinux.com>
Date: Tue, 21 Nov 2017 12:57:57 -0500
Subject: [PATCH] proxy: don't reinit dhtproxy multiple times and forward
 encrypted messages

+ Having a null dht_proxy_client can lead to using nullptr.
+ securedht should forward messages if we are using a proxy.
+ Protect callback segfault
---
 include/opendht/dht.h              |  3 ++
 include/opendht/dht_interface.h    |  2 +
 include/opendht/dht_proxy_client.h |  1 +
 include/opendht/dhtrunner.h        |  3 ++
 include/opendht/securedht.h        | 14 +++++
 src/dht_proxy_client.cpp           | 47 +++++++++++------
 src/dht_proxy_server.cpp           |  4 +-
 src/dhtrunner.cpp                  | 83 +++++++++++++++++++-----------
 src/securedht.cpp                  |  5 +-
 9 files changed, 113 insertions(+), 49 deletions(-)

diff --git a/include/opendht/dht.h b/include/opendht/dht.h
index 33aaef39..67d4c9b9 100644
--- a/include/opendht/dht.h
+++ b/include/opendht/dht.h
@@ -70,6 +70,9 @@ public:
     Dht(int s, int s6, Config config);
     virtual ~Dht();
 
+
+    virtual void start(const std::string& ) {};
+
     /**
      * Get the ID of the node.
      */
diff --git a/include/opendht/dht_interface.h b/include/opendht/dht_interface.h
index bfaa7601..b0b683b5 100644
--- a/include/opendht/dht_interface.h
+++ b/include/opendht/dht_interface.h
@@ -28,6 +28,8 @@ public:
     DhtInterface() = default;
     virtual ~DhtInterface() = default;
 
+    virtual void start(const std::string& host) = 0;
+
     // [[deprecated]]
     using Status = NodeStatus;
     // [[deprecated]]
diff --git a/include/opendht/dht_proxy_client.h b/include/opendht/dht_proxy_client.h
index 27531a48..79f63026 100644
--- a/include/opendht/dht_proxy_client.h
+++ b/include/opendht/dht_proxy_client.h
@@ -44,6 +44,7 @@ public:
      * and an ID for the node.
      */
     explicit DhtProxyClient(const std::string& serverHost);
+    void start(const std::string& serverHost);
     virtual ~DhtProxyClient();
 
     /**
diff --git a/include/opendht/dhtrunner.h b/include/opendht/dhtrunner.h
index 1d4a645c..0936bd21 100644
--- a/include/opendht/dhtrunner.h
+++ b/include/opendht/dhtrunner.h
@@ -382,6 +382,9 @@ public:
     }
     void enableProxy(bool proxify);
 #endif // OPENDHT_PROXY_CLIENT
+#if OPENDHT_PROXY_SERVER
+    void forwardAllMessages(bool forward);
+#endif // OPENDHT_PROXY_SERVER
 
     static std::vector<SockAddr> getAddrInfo(const std::string& host, const std::string& service);
 private:
diff --git a/include/opendht/securedht.h b/include/opendht/securedht.h
index e2efed5f..8678874b 100644
--- a/include/opendht/securedht.h
+++ b/include/opendht/securedht.h
@@ -298,6 +298,16 @@ public:
         dht_->connectivityChanged();
     }
 
+    void start(const std::string& host) {
+        dht_->start(host);
+    }
+
+#if OPENDHT_PROXY_SERVER
+    void forwardAllMessages(bool forward) {
+        force_forward_ = forward;
+    }
+#endif //OPENDHT_PROXY_SERVER
+
 private:
     std::unique_ptr<DhtInterface> dht_;
     // prevent copy
@@ -317,6 +327,10 @@ private:
     std::map<InfoHash, Sp<const crypto::PublicKey>> nodesPubKeys_ {};
 
     std::uniform_int_distribution<Value::Id> rand_id {};
+
+#if OPENDHT_PROXY_SERVER
+    std::atomic_bool force_forward_ {false};
+#endif //OPENDHT_PROXY_SERVER
 };
 
 const ValueType CERTIFICATE_TYPE = {
diff --git a/src/dht_proxy_client.cpp b/src/dht_proxy_client.cpp
index fd5b3124..d60b21cf 100644
--- a/src/dht_proxy_client.cpp
+++ b/src/dht_proxy_client.cpp
@@ -35,17 +35,14 @@ namespace dht {
 DhtProxyClient::DhtProxyClient(const std::string& serverHost)
 : serverHost_(serverHost), scheduler(DHT_LOG), currentProxyInfos_(new Json::Value())
 {
-    auto confirm_proxy_time = scheduler.time() + std::chrono::seconds(5);
-    nextProxyConfirmation = scheduler.add(confirm_proxy_time, std::bind(&DhtProxyClient::confirmProxy, this));
-    auto confirm_connectivity = scheduler.time() + std::chrono::seconds(5);
-    nextConnectivityConfirmation = scheduler.add(confirm_connectivity, std::bind(&DhtProxyClient::confirmConnectivity, this));
-
-    getConnectivityStatus();
+    if (!serverHost_.empty())
+        start(serverHost_);
 }
 
 void
 DhtProxyClient::confirmProxy()
 {
+    if (serverHost_.empty()) return;
     // Retrieve the connectivity each hours if connected, else every 5 seconds.
     auto disconnected_old_status =  statusIpv4_ == NodeStatus::Disconnected && statusIpv6_ == NodeStatus::Disconnected;
     getConnectivityStatus();
@@ -58,6 +55,19 @@ DhtProxyClient::confirmProxy()
     scheduler.edit(nextProxyConfirmation, confirm_proxy_time);
 }
 
+void
+DhtProxyClient::start(const std::string& serverHost)
+{
+    serverHost_ = serverHost;
+    if (serverHost_.empty()) return;
+    auto confirm_proxy_time = scheduler.time() + std::chrono::seconds(5);
+    nextProxyConfirmation = scheduler.add(confirm_proxy_time, std::bind(&DhtProxyClient::confirmProxy, this));
+    auto confirm_connectivity = scheduler.time() + std::chrono::seconds(5);
+    nextConnectivityConfirmation = scheduler.add(confirm_connectivity, std::bind(&DhtProxyClient::confirmConnectivity, this));
+
+    getConnectivityStatus();
+}
+
 void
 DhtProxyClient::confirmConnectivity()
 {
@@ -90,7 +100,8 @@ DhtProxyClient::cancelAllListeners()
     for (auto& listener: listeners_) {
         if (listener.thread && listener.thread->joinable()) {
             // Close connection to stop listener?
-            restbed::Http::close(listener.req);
+            if (listener.req)
+                restbed::Http::close(listener.req);
             listener.thread->join();
         }
     }
@@ -101,7 +112,8 @@ DhtProxyClient::shutdown(ShutdownCallback cb)
 {
     cancelAllOperations();
     cancelAllListeners();
-    cb();
+    if (cb)
+        cb();
 }
 
 NodeStatus
@@ -165,7 +177,7 @@ DhtProxyClient::get(const InfoHash& key, GetCallback cb, DoneCallback donecb,
                         Json::Reader reader;
                         if (reader.parse(body, json)) {
                             auto value = std::make_shared<Value>(json);
-                            if (not filterChain or filterChain(*value))
+                            if ((not filterChain or filterChain(*value)) && cb)
                                 cb({value});
                         } else {
                             *ok = false;
@@ -176,7 +188,8 @@ DhtProxyClient::get(const InfoHash& key, GetCallback cb, DoneCallback donecb,
                 *ok = false;
             }
         }).wait();
-        donecb(*ok, {});
+        if (donecb)
+            donecb(*ok, {});
         if (!ok) {
             // Connection failed, update connectivity
             getConnectivityStatus();
@@ -225,7 +238,8 @@ DhtProxyClient::put(const InfoHash& key, Sp<Value> val, DoneCallback cb, time_po
                 *ok = false;
             }
         }).wait();
-        cb(*ok, {});
+        if (cb)
+            cb(*ok, {});
         if (!ok) {
             // Connection failed, update connectivity
             getConnectivityStatus();
@@ -358,7 +372,7 @@ DhtProxyClient::listen(const InfoHash& key, GetCallback cb, Value::Filter&& filt
                             Json::Reader reader;
                             if (reader.parse(body, json)) {
                                 auto value = std::make_shared<Value>(json);
-                                if (not filterChain or filterChain(*value))
+                                if ((not filterChain or filterChain(*value)) && cb)
                                     cb({value});
                             }
                         }
@@ -386,8 +400,10 @@ DhtProxyClient::cancelListen(const InfoHash&, size_t token)
         if (listener.token == token) {
             if (listener.thread->joinable()) {
                 // Close connection to stop listener?
-                restbed::Http::close(listener.req);
-                listener.thread->join();
+                if (listener.req)
+                    restbed::Http::close(listener.req);
+                if (listener.thread->joinable())
+                    listener.thread->join();
                 listeners_.erase(it);
                 return true;
             }
@@ -441,6 +457,7 @@ DhtProxyClient::restartListeners()
         restbed::Uri uri(HTTP_PROTO + serverHost_ + "/" + listener.key);
         auto req = std::make_shared<restbed::Request>(uri);
         req->set_method("LISTEN");
+        listener.req = req;
         listener.thread = std::move(std::unique_ptr<std::thread>(new std::thread([this, filterChain, cb, req]()
             {
                 auto settings = std::make_shared<restbed::Settings>();
@@ -464,7 +481,7 @@ DhtProxyClient::restartListeners()
                                 Json::Reader reader;
                                 if (reader.parse(body, json)) {
                                     auto value = std::make_shared<Value>(json);
-                                    if (not filterChain or filterChain(*value))
+                                    if ((not filterChain or filterChain(*value)) && cb)
                                         cb({value});
                                 }
                             }
diff --git a/src/dht_proxy_server.cpp b/src/dht_proxy_server.cpp
index 6881025f..058f25d9 100644
--- a/src/dht_proxy_server.cpp
+++ b/src/dht_proxy_server.cpp
@@ -28,8 +28,6 @@
 #include <json/json.h>
 #include <limits>
 
-#include <iostream>
-
 using namespace std::placeholders;
 
 namespace dht {
@@ -108,6 +106,8 @@ DhtProxyServer::DhtProxyServer(std::shared_ptr<DhtRunner> dht, in_port_t port)
             }
         }
     });
+
+    dht->forwardAllMessages(true);
 }
 
 DhtProxyServer::~DhtProxyServer()
diff --git a/src/dhtrunner.cpp b/src/dhtrunner.cpp
index 53f573f4..7c5f879b 100644
--- a/src/dhtrunner.cpp
+++ b/src/dhtrunner.cpp
@@ -123,8 +123,7 @@ DhtRunner::run(const SockAddr& local4, const SockAddr& local6, DhtRunner::Config
 void
 DhtRunner::shutdown(ShutdownCallback cb) {
 #if OPENDHT_PROXY_CLIENT
-    if (dht_via_proxy_)
-        dht_via_proxy_->shutdown(cb);
+    dht_via_proxy_->shutdown(cb);
 #endif
     std::lock_guard<std::mutex> lck(storage_mtx);
     pending_ops_prio.emplace([=](SecureDht& dht) mutable {
@@ -165,7 +164,7 @@ void
 DhtRunner::dumpTables() const
 {
     std::lock_guard<std::mutex> lck(dht_mtx);
-    activeDht()->dumpTables();
+    activeDht()->dumpTables(); // NOTE: NOT USED by RingAccount
 }
 
 InfoHash
@@ -181,7 +180,7 @@ DhtRunner::getNodeId() const
 {
     if (!activeDht())
         return {};
-    return activeDht()->getNodeId();
+    return activeDht()->getNodeId(); // NOTE: This is OK, return the SecureDht id
 }
 
 
@@ -190,7 +189,7 @@ DhtRunner::getStoreSize() const {
     std::lock_guard<std::mutex> lck(dht_mtx);
     if (!activeDht())
         return {};
-    return activeDht()->getStoreSize();
+    return activeDht()->getStoreSize(); // NOTE: NOT USED by RingAccount
 }
 
 void
@@ -198,7 +197,7 @@ DhtRunner::setStorageLimit(size_t limit) {
     std::lock_guard<std::mutex> lck(dht_mtx);
     if (!activeDht())
         throw std::runtime_error("dht is not running");
-    return activeDht()->setStorageLimit(limit);
+    return activeDht()->setStorageLimit(limit); // NOTE: NOT USED by RingAccount
 }
 
 std::vector<NodeExport>
@@ -206,7 +205,7 @@ DhtRunner::exportNodes() const {
     std::lock_guard<std::mutex> lck(dht_mtx);
     if (!dht_)
         return {};
-    return activeDht()->exportNodes();
+    return activeDht()->exportNodes(); // NOTE: TBD Should be OK
 }
 
 std::vector<ValuesExport>
@@ -214,38 +213,38 @@ DhtRunner::exportValues() const {
     std::lock_guard<std::mutex> lck(dht_mtx);
     if (!activeDht())
         return {};
-    return activeDht()->exportValues();
+    return activeDht()->exportValues(); // NOTE: TBD Should be OK
 }
 
 void
 DhtRunner::setLoggers(LogMethod error, LogMethod warn, LogMethod debug) {
     std::lock_guard<std::mutex> lck(dht_mtx);
-    activeDht()->setLoggers(std::forward<LogMethod>(error), std::forward<LogMethod>(warn), std::forward<LogMethod>(debug));
+    activeDht()->setLoggers(std::forward<LogMethod>(error), std::forward<LogMethod>(warn), std::forward<LogMethod>(debug)); // NOTE: TBD Should be OK
 }
 
 void
 DhtRunner::setLogFilter(const InfoHash& f) {
     std::lock_guard<std::mutex> lck(dht_mtx);
-    activeDht()->setLogFilter(f);
+    activeDht()->setLogFilter(f); // NOTE: NOT USED by RingAccount
 }
 
 void
 DhtRunner::registerType(const ValueType& type) {
     std::lock_guard<std::mutex> lck(dht_mtx);
-    activeDht()->registerType(type);
+    activeDht()->registerType(type); // NOTE: NOT USED by RingAccount
 }
 
 void
 DhtRunner::importValues(const std::vector<ValuesExport>& values) {
     std::lock_guard<std::mutex> lck(dht_mtx);
-    activeDht()->importValues(values);
+    activeDht()->importValues(values); // NOTE: TBD Should be OK
 }
 
 unsigned
 DhtRunner::getNodesStats(sa_family_t af, unsigned *good_return, unsigned *dubious_return, unsigned *cached_return, unsigned *incoming_return) const
 {
     std::lock_guard<std::mutex> lck(dht_mtx);
-    const auto stats = activeDht()->getNodesStats(af);
+    const auto stats = activeDht()->getNodesStats(af); // NOTE: TBD Should be OK
     if (good_return)
         *good_return = stats.good_nodes;
     if (dubious_return)
@@ -261,51 +260,51 @@ NodeStats
 DhtRunner::getNodesStats(sa_family_t af) const
 {
     std::lock_guard<std::mutex> lck(dht_mtx);
-    return activeDht()->getNodesStats(af);
+    return activeDht()->getNodesStats(af); // NOTE: TBD Should be OK
 }
 
 std::vector<unsigned>
 DhtRunner::getNodeMessageStats(bool in) const
 {
     std::lock_guard<std::mutex> lck(dht_mtx);
-    return activeDht()->getNodeMessageStats(in);
+    return activeDht()->getNodeMessageStats(in); // NOTE: NOT USED by RingAccount
 }
 
 std::string
 DhtRunner::getStorageLog() const
 {
     std::lock_guard<std::mutex> lck(dht_mtx);
-    return activeDht()->getStorageLog();
+    return activeDht()->getStorageLog(); // NOTE: NOT USED by RingAccount
 }
 std::string
 DhtRunner::getStorageLog(const InfoHash& f) const
 {
     std::lock_guard<std::mutex> lck(dht_mtx);
-    return activeDht()->getStorageLog(f);
+    return activeDht()->getStorageLog(f); // NOTE: NOT USED by RingAccount
 }
 std::string
 DhtRunner::getRoutingTablesLog(sa_family_t af) const
 {
     std::lock_guard<std::mutex> lck(dht_mtx);
-    return activeDht()->getRoutingTablesLog(af);
+    return activeDht()->getRoutingTablesLog(af); // NOTE: NOT USED by RingAccount
 }
 std::string
 DhtRunner::getSearchesLog(sa_family_t af) const
 {
     std::lock_guard<std::mutex> lck(dht_mtx);
-    return activeDht()->getSearchesLog(af);
+    return activeDht()->getSearchesLog(af); // NOTE: NOT USED by RingAccount
 }
 std::string
 DhtRunner::getSearchLog(const InfoHash& f, sa_family_t af) const
 {
     std::lock_guard<std::mutex> lck(dht_mtx);
-    return activeDht()->getSearchLog(f, af);
+    return activeDht()->getSearchLog(f, af); // NOTE: NOT USED by RingAccount
 }
 std::vector<SockAddr>
 DhtRunner::getPublicAddress(sa_family_t af)
 {
     std::lock_guard<std::mutex> lck(dht_mtx);
-    return activeDht()->getPublicAddress(af);
+    return activeDht()->getPublicAddress(af); // NOTE: TBD Should be OK
 }
 std::vector<std::string>
 DhtRunner::getPublicAddressStr(sa_family_t af)
@@ -319,12 +318,13 @@ DhtRunner::getPublicAddressStr(sa_family_t af)
 void
 DhtRunner::registerCertificate(std::shared_ptr<crypto::Certificate> cert) {
     std::lock_guard<std::mutex> lck(dht_mtx);
-    activeDht()->registerCertificate(cert);
+    activeDht()->registerCertificate(cert); // NOTE: NOT USED by RingAccount
 }
 void
 DhtRunner::setLocalCertificateStore(CertificateStoreQuery&& query_method) {
     std::lock_guard<std::mutex> lck(dht_mtx);
-    activeDht()->setLocalCertificateStore(std::forward<CertificateStoreQuery>(query_method));
+    dht_via_proxy_->setLocalCertificateStore(std::forward<CertificateStoreQuery>(query_method));
+    dht_->setLocalCertificateStore(std::forward<CertificateStoreQuery>(query_method));
 }
 
 time_point
@@ -430,6 +430,15 @@ DhtRunner::doRun(const SockAddr& sin4, const SockAddr& sin6, SecureDht::Config c
     );
     dht_ = std::unique_ptr<SecureDht>(new SecureDht(std::move(dht), config));
 
+#if OPENDHT_PROXY_CLIENT
+    if (!dht_via_proxy_) {
+        auto dht_via_proxy = std::unique_ptr<DhtInterface>(
+            new DhtProxyClient(config_.proxy_server)
+        );
+        dht_via_proxy_ = std::unique_ptr<SecureDht>(new SecureDht(std::move(dht_via_proxy), config_.dht_config));
+    }
+#endif
+
     rcv_thread = std::thread([this,s4,s6]() {
         try {
             while (true) {
@@ -522,7 +531,7 @@ DhtRunner::listen(InfoHash hash, GetCallback vcb, Value::Filter f, Where w)
             auto tokenProxy = 0, tokenClassic = 0;
             if (!use_proxy)
                 tokenClassic = dht_->listen(hash, vcb, std::move(f), std::move(w));
-            else if (dht_via_proxy_)
+            else
                 tokenProxy = dht_via_proxy_->listen(hash, vcb, std::move(f), std::move(w));
 #else
         pending_ops.emplace([=](SecureDht& dht) mutable {
@@ -573,7 +582,7 @@ DhtRunner::cancelListen(InfoHash h, size_t token)
             if (listener->tokenClassicDht != 0) {
                 dht_->cancelListen(h, listener->tokenClassicDht);
             }
-            if (dht_via_proxy_ && listener->tokenProxyDht > 0) {
+            if (listener->tokenProxyDht != 0) {
                 dht_via_proxy_->cancelListen(h, listener->tokenProxyDht);
             }
 #else
@@ -874,14 +883,15 @@ DhtRunner::activeDht() const
 #if OPENDHT_PROXY_CLIENT
 void
 DhtRunner::enableProxy(bool proxify) {
-    if (proxify) {
-        // If no proxy url in the config, use 127.0.0.1:8000
-        std::string serverHost = config_.proxy_server.empty() ? "127.0.0.1:8000" : config_.proxy_server;
-        // Init the proxy client
+    if (!dht_via_proxy_) {
         auto dht_via_proxy = std::unique_ptr<DhtInterface>(
-            new DhtProxyClient(serverHost)
+            new DhtProxyClient(config_.proxy_server)
         );
         dht_via_proxy_ = std::unique_ptr<SecureDht>(new SecureDht(std::move(dht_via_proxy), config_.dht_config));
+    }
+    if (proxify) {
+        // Init the proxy client
+        dht_via_proxy_->start(config_.proxy_server);
         // add current listeners
         for (auto& listener: listeners_) {
             auto tokenProxy = dht_via_proxy_->listen(listener->hash, listener->gcb, std::move(listener->f), std::move(listener->w));
@@ -894,7 +904,7 @@ DhtRunner::enableProxy(bool proxify) {
         loop_(); // Restart the classic DHT.
         // We doesn't need to maintain the connection with the proxy.
         // Delete it
-        dht_via_proxy_.reset(nullptr);
+        dht_via_proxy_->shutdown({});
         // update all proxyToken for all proxyListener
         auto it = listeners_.begin();
         for (; it != listeners_.end(); ++it) {
@@ -911,4 +921,15 @@ DhtRunner::enableProxy(bool proxify) {
     }
 }
 #endif // OPENDHT_PROXY_CLIENT
+
+#if OPENDHT_PROXY_SERVER
+void
+DhtRunner::forwardAllMessages(bool forward)
+{
+#if OPENDHT_PROXY_CLIENT
+    dht_via_proxy_->forwardAllMessages(forward);
+#endif // OPENDHT_PROXY_CLIENT
+    dht_->forwardAllMessages(forward);
+}
+#endif // OPENDHT_PROXY_SERVER
 }
diff --git a/src/securedht.cpp b/src/securedht.cpp
index c6fd389e..9dd8b75b 100644
--- a/src/securedht.cpp
+++ b/src/securedht.cpp
@@ -228,8 +228,11 @@ SecureDht::getCallbackFilter(GetCallback cb, Value::Filter&& filter)
         for (const auto& v : values) {
             // Decrypt encrypted values
             if (v->isEncrypted()) {
-                if (not key_)
+                if (not key_) {
+                    if (force_forward_)
+                        tmpvals.push_back(v);
                     continue;
+                }
                 try {
                     Value decrypted_val (decrypt(*v));
                     if (decrypted_val.recipient == getId()) {
-- 
GitLab