diff --git a/include/opendht/net.h b/include/opendht/net.h index 180c4cf900db66743cc6ace39a9f6d4566f6b88d..fd680e5166a95c61e2d58c694b7b8383295e9195 100644 --- a/include/opendht/net.h +++ b/include/opendht/net.h @@ -42,41 +42,26 @@ struct TransPrefix : public std::array<uint8_t, 2> { struct TransId final : public std::array<uint8_t, 4> { static const constexpr uint16_t INVALID {0}; - TransId() { std::fill_n(begin(), 4, 0); } + TransId() { std::fill(begin(), end(), 0); } TransId(const std::array<char, 4>& o) { std::copy(o.begin(), o.end(), begin()); } TransId(const TransPrefix prefix, uint16_t seqno = 0) { std::copy_n(prefix.begin(), prefix.size(), begin()); *reinterpret_cast<uint16_t*>(data()+prefix.size()) = seqno; } - - TransId(const char* q, size_t l) : array<uint8_t, 4>() { - if (l > 4) { - length = 0; - } else { - std::copy_n(q, l, begin()); - length = l; - } - } - - uint16_t getTid() const { - return *reinterpret_cast<const uint16_t*>(&(*this)[2]); + TransId(uint32_t id) { + *reinterpret_cast<uint32_t*>(data()) = htonl(id); } uint32_t toInt() const { - return *reinterpret_cast<const uint32_t*>(&(*this)[0]); + return ntohl(*reinterpret_cast<const uint32_t*>(&(*this)[0])); } - bool matches(const TransPrefix prefix, uint16_t* tid = nullptr) const { - if (std::equal(begin(), begin()+2, prefix.begin())) { - if (tid) - *tid = getTid(); - return true; - } else - return false; + bool matches(const TransPrefix prefix) const { + return std::equal(begin(), begin()+2, prefix.begin()); } - - unsigned length {4}; }; +TransId unpackTid(msgpack::object& o); + } /* namespace net */ } /* dht */ diff --git a/include/opendht/node.h b/include/opendht/node.h index 22e7e5622fbb1abeec92e902711f6c12f4f2e11a..1f6efda99623f00a328e51a3cfe1daa689e0aba1 100644 --- a/include/opendht/node.h +++ b/include/opendht/node.h @@ -34,9 +34,17 @@ namespace net { struct Request; struct Socket; struct RequestAnswer; -using SocketCb = std::function<void(const Sp<Node>&, RequestAnswer&&)>; } /* namespace net */ +using SocketCb = std::function<void(const Sp<Node>&, net::RequestAnswer&&)>; +using SocketId = uint32_t; +struct Socket { + Socket() {} + Socket(SocketCb&& on_receive) : + on_receive(std::move(on_receive)) {} + SocketCb on_receive {}; +}; + struct Node { const InfoHash id; @@ -108,16 +116,16 @@ struct Node { * * @return the socket. */ - Sp<net::Socket> openSocket(const net::TransId& id, net::SocketCb&& cb); + SocketId openSocket(SocketCb&& cb); - Sp<net::Socket> getSocket(const net::TransId& id) const; + Socket* getSocket(SocketId id); /** * Closes a socket so that no further data will be red on that socket. * * @param socket The socket to close. */ - void closeSocket(const Sp<net::Socket>& socket); + void closeSocket(SocketId id); /** * Resets the state of the node so it's not expired anymore. @@ -156,11 +164,11 @@ private: time_point reply_time {time_point::min()}; /* time of last correct reply received */ unsigned auth_errors {0}; bool expired_ {false}; - uint16_t transaction_id {1}; + SocketId transaction_id; using TransactionDist = std::uniform_int_distribution<decltype(transaction_id)>; std::map<net::TransId, Sp<net::Request>> requests_ {}; - std::map<net::TransId, Sp<net::Socket>> sockets_ {}; + std::map<SocketId, Socket> sockets_; }; } diff --git a/src/dht.cpp b/src/dht.cpp index 4a3485cf2eee8c7f1d70b98f15b756f56569e1cf..b79f4ecfb8cb9093d9a9f5abbf968ce399fbb0e1 100644 --- a/src/dht.cpp +++ b/src/dht.cpp @@ -540,40 +540,40 @@ Dht::searchStep(Sp<Search> sr) continue; for (const auto& l : sr->listeners) { const auto& query = l.second.query; - if (n.getListenTime(query) <= now) { - DHT_LOG.w(sr->id, n.node->id, "[search %s] [node %s] sending 'listen'", - sr->id.toString().c_str(), n.node->toString().c_str()); - - const auto& r = n.listenStatus.find(query); - auto prev_req = r != n.listenStatus.end() ? r->second : nullptr; - - std::weak_ptr<Search> ws = sr; - n.listenStatus[query] = network_engine.sendListen(n.node, sr->id, *query, n.token, prev_req, - [this,ws,query](const net::Request& req, net::RequestAnswer&& answer) mutable - { /* on done */ - if (auto sr = ws.lock()) { - onListenDone(req.node, answer, sr); - scheduler.edit(sr->nextSearchStep, scheduler.time()); - } - }, - [this,ws,query](const net::Request& req, bool over) mutable - { /* on expired */ - if (auto sr = ws.lock()) { - scheduler.edit(sr->nextSearchStep, scheduler.time()); - if (over) - if (auto sn = sr->getNode(req.node)) - sn->listenStatus.erase(query); - } - }, - [this,ws,query](const Sp<Node>& node, net::RequestAnswer&& answer) mutable - { /* on new values */ - if (auto sr = ws.lock()) { - onGetValuesDone(node, answer, sr, query); - scheduler.edit(sr->nextSearchStep, scheduler.time()); - } + if (n.getListenTime(query) > now) + continue; + DHT_LOG.w(sr->id, n.node->id, "[search %s] [node %s] sending 'listen'", + sr->id.toString().c_str(), n.node->toString().c_str()); + + const auto& r = n.listenStatus.find(query); + auto prev_req = r != n.listenStatus.end() ? r->second : nullptr; + + std::weak_ptr<Search> ws = sr; + n.listenStatus[query] = network_engine.sendListen(n.node, sr->id, *query, n.token, prev_req, + [this,ws,query](const net::Request& req, net::RequestAnswer&& answer) mutable + { /* on done */ + if (auto sr = ws.lock()) { + onListenDone(req.node, answer, sr); + scheduler.edit(sr->nextSearchStep, scheduler.time()); } - ); - } + }, + [this,ws,query](const net::Request& req, bool over) mutable + { /* on expired */ + if (auto sr = ws.lock()) { + scheduler.edit(sr->nextSearchStep, scheduler.time()); + if (over) + if (auto sn = sr->getNode(req.node)) + sn->listenStatus.erase(query); + } + }, + [this,ws,query](const Sp<Node>& node, net::RequestAnswer&& answer) mutable + { /* on new values */ + if (auto sr = ws.lock()) { + onGetValuesDone(node, answer, sr, query); + scheduler.edit(sr->nextSearchStep, scheduler.time()); + } + } + ); } if (not n.candidate and ++i == LISTEN_NODES) break; diff --git a/src/network_engine.cpp b/src/network_engine.cpp index 4986ae824c4e5c930fec0e16f5ea59a9e92a30c9..04fefd2559004c832bdb88d5075fe384d0d45232 100644 --- a/src/network_engine.cpp +++ b/src/network_engine.cpp @@ -153,7 +153,7 @@ NetworkEngine::tellListener(Sp<Node> node, uint32_t socket_id, const InfoHash& h { auto nnodes = bufferNodes(node->getFamily(), hash, want, nodes, nodes6); try { - sendNodesValues(node->getAddr(), TransId((char*)&socket_id, 4), nnodes.first, nnodes.second, values, query, ntoken); + sendNodesValues(node->getAddr(), TransId(socket_id), nnodes.first, nnodes.second, values, query, ntoken); } catch (const std::overflow_error& e) { DHT_LOG.e("Can't send value: buffer not large enough !"); } @@ -401,7 +401,7 @@ NetworkEngine::process(std::unique_ptr<ParsedMessage>&& msg, const SockAddr& fro auto node = cache.getNode(msg->id, from, now, true, msg->is_client); if (msg->type == MessageType::Error or msg->type == MessageType::Reply) { - auto rsocket = node->getSocket(msg->tid); + auto rsocket = node->getSocket(msg->tid.toInt()); auto req = node->getRequest(msg->tid); /* either response for a request or data for an opened socket */ @@ -954,20 +954,21 @@ NetworkEngine::sendListen(Sp<Node> n, RequestExpiredCb&& on_expired, SocketCb&& socket_cb) { - Sp<Socket> socket; - auto tid = TransId { TransPrefix::LISTEN, previous ? previous->tid.getTid() : n->getNewTid() }; + uint32_t socket; + auto tid = TransId { TransPrefix::LISTEN, n->getNewTid() }; if (previous and previous->node == n) { - socket = previous->socket; + socket = previous->getSocket(); } else { if (previous) DHT_LOG.e(hash, "[node %s] trying refresh listen contract with wrong node", previous->node->toString().c_str()); - socket = n->openSocket(TransPrefix::GET_VALUES, std::move(socket_cb)); + socket = n->openSocket(std::move(socket_cb)); } if (not socket) { DHT_LOG.e(hash, "[node %s] unable to get a valid socket for listen. Aborting listen", n->toString().c_str()); return {}; } + TransId sid(socket); msgpack::sbuffer buffer; msgpack::packer<msgpack::sbuffer> pk(&buffer); @@ -978,8 +979,8 @@ NetworkEngine::sendListen(Sp<Node> n, pk.pack(std::string("id")); pk.pack(myid); pk.pack(std::string("h")); pk.pack(hash); pk.pack(std::string("token")); packToken(pk, token); - pk.pack(std::string("sid")); pk.pack_bin(socket->id.size()); - pk.pack_bin_body((const char*)socket->id.data(), socket->id.size()); + pk.pack(std::string("sid")); pk.pack_bin(sid.size()); + pk.pack_bin_body((const char*)sid.data(), sid.size()); if (has_query) { pk.pack(std::string("q")); pk.pack(query); } diff --git a/src/node.cpp b/src/node.cpp index 7974b36e0d333e7539f38c83af3f540ee3717ded..ba8f4cb9371b4fb731b590d9358d967144c74a4d 100644 --- a/src/node.cpp +++ b/src/node.cpp @@ -32,10 +32,10 @@ constexpr std::chrono::seconds Node::MAX_RESPONSE_TIME; Node::Node(const InfoHash& id, const SockAddr& addr, bool client) -: id(id), addr(addr), is_client(client) +: id(id), addr(addr), is_client(client), sockets_() { crypto::random_device rd; - transaction_id = TransactionDist{1}(rd); + transaction_id = std::uniform_int_distribution<SocketId>{1}(rd); } /* This is our definition of a known-good node. */ @@ -112,7 +112,7 @@ Node::cancelRequest(const Sp<net::Request>& req) { if (req) { req->cancel(); - closeSocket(req->getSocket()); + closeSocket(req->closeSocket()); requests_.erase(req->getTid()); } } @@ -128,31 +128,31 @@ Node::setExpired() sockets_.clear(); } - -Sp<net::Socket> -Node::openSocket(const net::TransId& tid, net::SocketCb&& cb) +SocketId +Node::openSocket(SocketCb&& cb) { - auto s = sockets_.emplace(tid, std::make_shared<net::Socket>(tid, cb)); - //if (not s.second) - // DHT_LOG.e(id, "[node %s] socket (tid: %d) already opened!", id.toString().c_str(), tid.toInt()); - //else - // DHT_LOG.w("Opened socket (tid: %d), %lu opened", s.first->second->id, sockets_.size()); - return s.first->second; -} + if (++transaction_id == 0) + transaction_id = 1; + auto sock = Socket(std::move(cb)); + auto s = sockets_.emplace(transaction_id, std::move(sock)); + if (not s.second) + s.first->second = std::move(sock); + return transaction_id; +} -Sp<net::Socket> -Node::getSocket(const net::TransId& tid) const +Socket* +Node::getSocket(SocketId id) { - auto it = sockets_.find(tid); - return it == sockets_.end() ? nullptr : it->second; + auto it = sockets_.find(id); + return it == sockets_.end() ? nullptr : &it->second; } void -Node::closeSocket(const Sp<net::Socket>& socket) +Node::closeSocket(SocketId id) { - if (socket) { - sockets_.erase(socket->id); + if (id) { + sockets_.erase(id); //DHT_LOG.w("Closing socket (tid: %d), %lu remaining", socket->id, sockets_.size()); } } diff --git a/src/parsed_message.h b/src/parsed_message.h index f2f691d6fc9d663c83db869ba2e6bbe893f18775..1164342edba36045a0462ec38f6ffdedc928ab79 100644 --- a/src/parsed_message.h +++ b/src/parsed_message.h @@ -19,6 +19,7 @@ #include "infohash.h" #include "sockaddr.h" +#include "net.h" #include <map> @@ -131,7 +132,7 @@ ParsedMessage::msgpack_unpack(msgpack::object msg) auto v = findMapValue(msg, "p"); if (auto t = findMapValue(msg, "t")) - tid = {t->as<std::array<char, 4>>()}; + tid = unpackTid(*t); if (auto rv = findMapValue(msg, "v")) ua = rv->as<std::string>(); @@ -198,7 +199,7 @@ ParsedMessage::msgpack_unpack(msgpack::object msg) } if (auto t = findMapValue(req, "sid")) - socket_id = {t->as<std::array<char, 4>>()}; + socket_id = unpackTid(*t); if (auto rid = findMapValue(req, "id")) id = {*rid}; diff --git a/src/request.h b/src/request.h index d75554f1ee05e9d99cba3e96dd196731b971d2af..0b222eba8d3dd09c04e64d8585e610c23e0c126d 100644 --- a/src/request.h +++ b/src/request.h @@ -28,23 +28,6 @@ namespace net { class NetworkEngine; struct ParsedMessage; -struct RequestAnswer; -using SocketCb = std::function<void(const Sp<Node>&, RequestAnswer&&)>; - -/** - * Open route to a node for continous incoming packets. - * A socket lets a remote node send us continuous packets treated using a - * given callback. This is intended to provide an easy management of - * specific updates nodes can send. For e.g, this is used in the case of the - * "listen" operation for treating updates a node has for a given storage. - */ -struct Socket { - Socket() {} - Socket(TransId id, SocketCb on_receive) : - id(id), on_receive(on_receive) {} - TransId id; - SocketCb on_receive {}; -}; /*! * @class Request @@ -81,12 +64,13 @@ struct Request { Blob&& msg, std::function<void(const Request&, ParsedMessage&&)> on_done, std::function<void(const Request&, bool)> on_expired, - Sp<Socket> socket = {}) : + uint32_t socket = 0) : node(node), on_done(on_done), on_expired(on_expired), tid(tid), msg(std::move(msg)), socket(socket) { } TransId getTid() const { return tid; } - const Sp<Socket>& getSocket() const { return socket; } + uint32_t getSocket() { return socket; } + uint32_t closeSocket() { auto ret = socket; socket = 0; return ret; } void setExpired() { if (pending()) { @@ -134,7 +118,7 @@ private: const TransId tid; /* the request id. */ Blob msg {}; /* the serialized message. */ - Sp<Socket> socket; /* the socket used for further reponses. */ + uint32_t socket; /* the socket used for further reponses. */ }; } /* namespace net */ diff --git a/src/utils.cpp b/src/utils.cpp index 87826e1fca2a0002726e2960bc77fed6f46db122..cb2fd5b39cdda0d19cf678861edd3d16ce67c9ef 100644 --- a/src/utils.cpp +++ b/src/utils.cpp @@ -19,6 +19,7 @@ #include "utils.h" #include "sockaddr.h" #include "default_types.h" +#include "net.h" /* An IPv4 equivalent to IN6_IS_ADDR_UNSPECIFIED */ #ifndef IN_IS_ADDR_UNSPECIFIED @@ -160,4 +161,18 @@ findMapValue(msgpack::object& map, const std::string& key) { return nullptr; } +namespace net { + +TransId +unpackTid(msgpack::object& o) { + switch (o.type) { + case msgpack::type::POSITIVE_INTEGER: + return o.as<uint32_t>(); + default: + return o.as<std::array<char, 4>>(); + } +} + +} + }