From cf6d9fd8c29bb262d9224477957ccca000749a1c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrien=20B=C3=A9raud?= <adrien.beraud@savoirfairelinux.com>
Date: Mon, 25 Nov 2019 15:43:23 -0500
Subject: [PATCH] proxy client: add lock for requests

---
 include/opendht/dht_proxy_client.h |  2 ++
 src/dht_proxy_client.cpp           | 33 ++++++++++++++++++++++--------
 2 files changed, 27 insertions(+), 8 deletions(-)

diff --git a/include/opendht/dht_proxy_client.h b/include/opendht/dht_proxy_client.h
index 80198a7f..8f26a0cd 100644
--- a/include/opendht/dht_proxy_client.h
+++ b/include/opendht/dht_proxy_client.h
@@ -349,6 +349,8 @@ private:
      */
     asio::io_context httpContext_;
     std::shared_ptr<http::Resolver> resolver_;
+
+    mutable std::mutex requestLock_;
     std::map<unsigned int /*id*/, std::shared_ptr<http::Request>> requests_;
     /*
      * Thread for executing the http io_context.run() blocking call
diff --git a/src/dht_proxy_client.cpp b/src/dht_proxy_client.cpp
index 275923dd..8159fb3e 100644
--- a/src/dht_proxy_client.cpp
+++ b/src/dht_proxy_client.cpp
@@ -334,11 +334,15 @@ DhtProxyClient::get(const InfoHash& key, GetCallback cb, DoneCallback donecb, Va
                     }
                     loopSignal_();
                 }
+                std::lock_guard<std::mutex> l(requestLock_);
                 requests_.erase(reqid);
             }
         });
+        {
+            std::lock_guard<std::mutex> l(requestLock_);
+            requests_[reqid] = request;
+        }
         request->send();
-        requests_[reqid] = request;
     }
     catch (const std::exception &e){
         if (logger_)
@@ -470,11 +474,15 @@ DhtProxyClient::doPut(const InfoHash& key, Sp<Value> val, DoneCallbackSimple cb,
                 }
                 if (cb)
                     cb(ok);
+                std::lock_guard<std::mutex> l(requestLock_);
                 requests_.erase(reqid);
             }
         });
+        {
+            std::lock_guard<std::mutex> l(requestLock_);
+            requests_[reqid] = request;
+        }
         request->send();
-        requests_[reqid] = request;
     }
     catch (const std::exception &e){
         if (logger_)
@@ -590,14 +598,17 @@ DhtProxyClient::queryProxyInfo(std::shared_ptr<InfoState> infoState, sa_family_t
                     if (not infoState->cancel)
                         onProxyInfos(proxyInfos, family);
                 }
+                std::lock_guard<std::mutex> l(requestLock_);
                 requests_.erase(reqid);
             }
         });
 
         if (infoState->cancel.load())
             return;
-
-        requests_[reqid] = request;
+        {
+            std::lock_guard<std::mutex> l(requestLock_);
+            requests_[reqid] = request;
+        }
         request->send();
     }
     catch (const std::exception &e){
@@ -686,10 +697,10 @@ DhtProxyClient::listen(const InfoHash& key, ValueCallback cb, Value::Filter filt
     if (logger_)
         logger_->d("[proxy:client] [listen] [search %s]", key.to_c_str());
 
+    std::lock_guard<std::mutex> lock(searchLock_);
     auto& search = searches_[key];
     auto query = std::make_shared<Query>(Select{}, std::move(where));
     return search.ops.listen(cb, query, filter, [this, key](Sp<Query>, ValueCallback cb, SyncCallback) -> size_t {
-        std::lock_guard<std::mutex> lock(searchLock_);
         // Find search
         auto search = searches_.find(key);
         if (search == searches_.end()) {
@@ -863,13 +874,15 @@ DhtProxyClient::handleExpireListener(const asio::error_code &ec, const InfoHash&
                         requests_.erase(reqid);
                     }
                 });
+                {
+                    std::lock_guard<std::mutex> l(requestLock_);
+                    requests_[reqid] = request;
+                }
                 request->send();
-                requests_[reqid] = request;
             }
             catch (const std::exception &e){
                 if (logger_)
                      logger_->e("[proxy:client] [unsubscribe %s] failed: %s", key.to_c_str(), e.what());
-                requests_.erase(reqid);
             }
         } else {
             // stop the request
@@ -962,11 +975,15 @@ DhtProxyClient::sendListen(const restinio::http_request_header_t header,
                     if (response.status_code == 0)
                         opFailed();
                 }
+                std::lock_guard<std::mutex> l(requestLock_);
                 requests_.erase(reqid);
             }
         });
+        {
+            std::lock_guard<std::mutex> l(requestLock_);
+            requests_[reqid] = request;
+        }
         request->send();
-        requests_[reqid] = request;
     }
     catch (const std::exception &e){
         if (logger_)
-- 
GitLab