From 08c7356481735377df067c23a1ed6e7b64c70d6f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrien=20B=C3=A9raud?= <adrien.beraud@savoirfairelinux.com>
Date: Fri, 12 Jun 2020 02:26:05 -0400
Subject: [PATCH] dht: move socket to SearchNode

---
 include/opendht/network_engine.h |  5 ++--
 src/dht.cpp                      | 44 +++++++++++---------------------
 src/network_engine.cpp           | 25 +++---------------
 src/node.cpp                     |  1 -
 src/request.h                    | 14 +++-------
 src/search.h                     | 18 +++++++++----
 6 files changed, 38 insertions(+), 69 deletions(-)

diff --git a/include/opendht/network_engine.h b/include/opendht/network_engine.h
index 13bfae4a..3523125f 100644
--- a/include/opendht/network_engine.h
+++ b/include/opendht/network_engine.h
@@ -362,10 +362,9 @@ public:
                            const InfoHash& hash,
                            const Query& query,
                            const Blob& token,
-                           Sp<Request> previous,
+                           Tid socketId,
                            RequestCb&& on_done,
-                           RequestExpiredCb&& on_expired,
-                           SocketCb&& socket_cb);
+                           RequestExpiredCb&& on_expired);
     /**
      * Send a "announce" request to a given node.
      *
diff --git a/src/dht.cpp b/src/dht.cpp
index 31dd27b2..b2419aa3 100644
--- a/src/dht.cpp
+++ b/src/dht.cpp
@@ -591,33 +591,30 @@ Dht::searchSynchedNodeListen(const Sp<Search>& sr, SearchNode& n)
     std::weak_ptr<Search> ws = sr;
     for (const auto& l : sr->listeners) {
         const auto& query = l.second.query;
-        auto list_token = l.first;
-        if (n.getListenTime(query, listenExp) > scheduler.time())
+        
+        auto r = n.listenStatus.find(query);
+        if (n.getListenTime(r, listenExp) > scheduler.time())
             continue;
         // if (logger_)
         //     logger_->d(sr->id, n.node->id, "[search %s] [node %s] sending 'listen'",
         //        sr->id.toString().c_str(), n.node->toString().c_str());
 
-        auto r = n.listenStatus.find(query);
         if (r == n.listenStatus.end()) {
             r = n.listenStatus.emplace(std::piecewise_construct,
                 std::forward_as_tuple(query),
                 std::forward_as_tuple(
-                [ws,list_token](const std::vector<Sp<Value>>& values, bool expired){
-                    if (auto sr = ws.lock()) {
-                        auto l = sr->listeners.find(list_token);
-                        if (l != sr->listeners.end()) {
-                            l->second.get_cb(values, expired);
-                        }
-                    }
-                }, [ws,list_token] (ListenSyncStatus status) {
-                    if (auto sr = ws.lock()) {
-                        auto l = sr->listeners.find(list_token);
-                        if (l != sr->listeners.end()) {
-                            l->second.sync_cb(status);
+                    l.second.get_cb,
+                    l.second.sync_cb,
+                    n.node->openSocket([this,ws,query](const Sp<Node>& node, net::RequestAnswer&& answer) mutable {
+                        /* on new values */
+                        if (auto sr = ws.lock()) {
+                            scheduler.edit(sr->nextSearchStep, scheduler.time());
+                            sr->insertNode(node, scheduler.time(), answer.ntoken);
+                            if (auto sn = sr->getNode(node)) {
+                                sn->onValues(query, std::move(answer), types, scheduler);
+                            }
                         }
-                    }
-                })).first;
+                    }))).first;
             r->second.cacheExpirationJob = scheduler.add(time_point::max(), [this,ws,query,node=n.node]{
                 if (auto sr = ws.lock()) {
                     if (auto sn = sr->getNode(node)) {
@@ -626,8 +623,7 @@ Dht::searchSynchedNodeListen(const Sp<Search>& sr, SearchNode& n)
                 }
             });
         }
-        auto prev_req = r != n.listenStatus.end() ? r->second.req : nullptr;
-        auto new_req = network_engine.sendListen(n.node, sr->id, *query, n.token, prev_req,
+        auto new_req = network_engine.sendListen(n.node, sr->id, *query, n.token, r->second.socketId,
             [this,ws,query](const net::Request& req, net::RequestAnswer&& answer) mutable
             { /* on done */
                 if (auto sr = ws.lock()) {
@@ -647,16 +643,6 @@ Dht::searchSynchedNodeListen(const Sp<Search>& sr, SearchNode& n)
                         if (auto sn = sr->getNode(req.node))
                             sn->listenStatus.erase(query);
                 }
-            },
-            [this,ws,query](const Sp<Node>& node, net::RequestAnswer&& answer) mutable
-            { /* on new values */
-                if (auto sr = ws.lock()) {
-                    scheduler.edit(sr->nextSearchStep, scheduler.time());
-                    sr->insertNode(node, scheduler.time(), answer.ntoken);
-                    if (auto sn = sr->getNode(node)) {
-                        sn->onValues(query, std::move(answer), types, scheduler);
-                    }
-                }
             }
         );
         // Here the request may have failed and the CachedListenStatus removed
diff --git a/src/network_engine.cpp b/src/network_engine.cpp
index e91e7d1e..10991629 100644
--- a/src/network_engine.cpp
+++ b/src/network_engine.cpp
@@ -1073,27 +1073,11 @@ NetworkEngine::sendListen(Sp<Node> n,
         const InfoHash& hash,
         const Query& query,
         const Blob& token,
-        Sp<Request> previous,
+        Tid socketId,
         RequestCb&& on_done,
-        RequestExpiredCb&& on_expired,
-        SocketCb&& socket_cb)
+        RequestExpiredCb&& on_expired)
 {
-    Tid socket;
     Tid tid (n->getNewTid());
-    if (previous and previous->node == n) {
-        socket = previous->getSocket();
-    } else {
-        if (previous)
-            if (logger_)
-                logger_->e(hash, "[node %s] trying refresh listen contract with wrong node", previous->node->toString().c_str());
-        socket = n->openSocket(std::move(socket_cb));
-    }
-
-    if (not socket) {
-        if (logger_)
-            logger_->e(hash, "[node %s] unable to get a valid socket for listen. Aborting listen", n->toString().c_str());
-        return {};
-    }
     msgpack::sbuffer buffer;
     msgpack::packer<msgpack::sbuffer> pk(&buffer);
     pk.pack_map(5+(config.network?1:0));
@@ -1104,7 +1088,7 @@ NetworkEngine::sendListen(Sp<Node> n,
       pk.pack(KEY_VERSION);   pk.pack(1);
       pk.pack(KEY_REQ_H);     pk.pack(hash);
       pk.pack(KEY_REQ_TOKEN); packToken(pk, token);
-      pk.pack(KEY_REQ_SID);   pk.pack(socket);
+      pk.pack(KEY_REQ_SID);   pk.pack(socketId);
       if (has_query) {
           pk.pack(KEY_REQ_QUERY); pk.pack(query);
       }
@@ -1126,8 +1110,7 @@ NetworkEngine::sendListen(Sp<Node> n,
         [=](const Request& req_status, bool done) { /* on expired */
             if (on_expired)
                 on_expired(req_status, done);
-        },
-        socket
+        }
     );
     sendRequest(req);
     ++out_stats.listen;
diff --git a/src/node.cpp b/src/node.cpp
index 41829e6c..d207cb68 100644
--- a/src/node.cpp
+++ b/src/node.cpp
@@ -116,7 +116,6 @@ Node::cancelRequest(const Sp<net::Request>& req)
 {
     if (req) {
         req->cancel();
-        closeSocket(req->closeSocket());
         requests_.erase(req->getTid());
     }
 }
diff --git a/src/request.h b/src/request.h
index 92aa8dbe..5c468c04 100644
--- a/src/request.h
+++ b/src/request.h
@@ -73,24 +73,19 @@ struct Request {
             Sp<Node> node,
             Blob&& msg,
             std::function<void(const Request&, ParsedMessage&&)> on_done,
-            std::function<void(const Request&, bool)> on_expired,
-            Tid socket = 0) :
-        node(node), tid(tid), type(type), on_done(on_done), on_expired(on_expired), msg(std::move(msg)), socket(socket) { }
+            std::function<void(const Request&, bool)> on_expired) :
+        node(node), tid(tid), type(type), on_done(on_done), on_expired(on_expired), msg(std::move(msg)) { }
     Request(MessageType type, Tid tid,
             Sp<Node> node,
             Blob&& msg,
             std::function<void(const Request&, ParsedMessage&&)> on_done,
             std::function<bool(const Request&, DhtProtocolException&&)> on_error,
-            std::function<void(const Request&, bool)> on_expired,
-            Tid socket = 0) :
-        node(node), tid(tid), type(type), on_done(on_done), on_error(on_error), on_expired(on_expired), msg(std::move(msg)), socket(socket) { }
+            std::function<void(const Request&, bool)> on_expired) :
+        node(node), tid(tid), type(type), on_done(on_done), on_error(on_error), on_expired(on_expired), msg(std::move(msg)) { }
 
     Tid getTid() const { return tid; }
     MessageType getType() const { return type; }
 
-    Tid getSocket() const { return socket; }
-    Tid closeSocket() { auto ret = socket; socket = 0; return ret; }
-
     void setExpired() {
         if (pending()) {
             state_ = Request::State::EXPIRED;
@@ -152,7 +147,6 @@ private:
 
     Blob msg {};                      /* the serialized message. */
     std::vector<Blob> parts;
-    Tid socket;   /* the socket used for further reponses. */
 };
 
 } /* namespace net  */
diff --git a/src/search.h b/src/search.h
index b02f2a7b..f2a5a583 100644
--- a/src/search.h
+++ b/src/search.h
@@ -65,10 +65,16 @@ struct Dht::SearchNode {
         ValueCache cache;
         Sp<Scheduler::Job> cacheExpirationJob {};
         Sp<net::Request> req {};
-        CachedListenStatus(ValueStateCallback&& cb, SyncCallback&& scb)
-         : cache(std::forward<ValueStateCallback>(cb), std::forward<SyncCallback>(scb)) {}
+        Tid socketId {0};
+        CachedListenStatus(ValueStateCallback&& cb, SyncCallback scb, Tid sid)
+         : cache(std::forward<ValueStateCallback>(cb), std::forward<SyncCallback>(scb)), socketId(sid) {}
         CachedListenStatus(CachedListenStatus&&) = delete;
         CachedListenStatus(const CachedListenStatus&) = delete;
+        ~CachedListenStatus() {
+            if (socketId and req and req->node) {
+                req->node->closeSocket(socketId);
+            }
+        }
     };
     using NodeListenerStatus = std::map<Sp<Query>, CachedListenStatus>;
 
@@ -344,13 +350,15 @@ struct Dht::SearchNode {
      * Assuming the node is synced, should the "listen" request with Query q be
      * sent to this node now ?
      */
-    time_point getListenTime(const Sp<Query>& q, duration listen_expire) const {
-        auto listen_status = listenStatus.find(q);
-        if (listen_status == listenStatus.end() or not listen_status->second.req)
+    time_point getListenTime(const decltype(listenStatus)::const_iterator listen_status, duration listen_expire) const {
+        if (listen_status == listenStatus.cend() or not listen_status->second.req)
             return time_point::min();
         return listen_status->second.req->pending() ? time_point::max() :
             listen_status->second.req->reply_time + listen_expire - REANNOUNCE_MARGIN;
     }
+    time_point getListenTime(const Sp<Query>& q, duration listen_expire) const {
+        return getListenTime(listenStatus.find(q), listen_expire);
+    }
 
     /**
      * Is this node expired or candidate
-- 
GitLab