From aee08022fcee8788075aa9426d24c132e4ec0e56 Mon Sep 17 00:00:00 2001 From: Adrien Beraud <adrien.beraud@savoirfairelinux.com> Date: Mon, 27 Mar 2023 12:12:15 -0400 Subject: [PATCH] dht: use asio for networking, scheduling --- CMakeLists.txt | 8 +- c/opendht.cpp | 50 +++- include/opendht/callbacks.h | 13 + include/opendht/default_types.h | 22 +- include/opendht/dht.h | 44 ++- include/opendht/dht_interface.h | 14 +- include/opendht/dht_proxy_client.h | 30 +- include/opendht/dht_proxy_server.h | 2 +- include/opendht/dhtrunner.h | 47 +--- include/opendht/infohash.h | 15 +- include/opendht/network_engine.h | 41 +-- include/opendht/network_utils.h | 138 +--------- include/opendht/node.h | 10 +- include/opendht/node_export.h | 10 +- include/opendht/peer_discovery.h | 4 +- include/opendht/securedht.h | 14 +- include/opendht/sockaddr.h | 311 --------------------- include/opendht/udp_socket.h | 106 +++++++ include/opendht/utils.h | 39 +++ include/opendht/value.h | 1 - src/Makefile.am | 1 - src/default_types.cpp | 9 +- src/dht.cpp | 425 +++++++++++++++++------------ src/dht_proxy_client.cpp | 188 ++++++------- src/dht_proxy_server.cpp | 21 +- src/dhtrunner.cpp | 325 ++++++++++------------ src/network_engine.cpp | 200 ++++++++------ src/network_utils.cpp | 362 ++++-------------------- src/node.cpp | 5 +- src/node_cache.cpp | 2 +- src/parsed_message.h | 31 ++- src/peer_discovery.cpp | 3 +- src/request.h | 2 + src/search.h | 110 +++++--- src/storage.h | 3 +- src/udp_socket.cpp | 150 ++++++++++ src/utils.cpp | 8 +- tests/dhtproxytester.cpp | 2 + tests/dhtrunnertester.cpp | 4 +- tests/httptester.cpp | 1 + tools/tools_common.h | 6 +- 41 files changed, 1273 insertions(+), 1504 deletions(-) delete mode 100644 include/opendht/sockaddr.h create mode 100644 include/opendht/udp_socket.h create mode 100644 src/udp_socket.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index b33cdb43..921771d5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -84,9 +84,7 @@ if (NOT MSVC) ) endif() - if (OPENDHT_HTTP OR OPENDHT_PEER_DISCOVERY) - find_path(ASIO_INCLUDE_DIR asio.hpp REQUIRED) - endif () + find_path(ASIO_INCLUDE_DIR asio.hpp REQUIRED) if (OPENDHT_HTTP) find_package(Restinio REQUIRED) @@ -210,12 +208,12 @@ list (APPEND opendht_SOURCES src/log.cpp src/network_utils.cpp src/thread_pool.cpp + src/udp_socket.cpp ) list (APPEND opendht_HEADERS include/opendht/def.h include/opendht/utils.h - include/opendht/sockaddr.h include/opendht/rng.h include/opendht/crypto.h include/opendht/infohash.h @@ -234,7 +232,7 @@ list (APPEND opendht_HEADERS include/opendht/log.h include/opendht/log_enable.h include/opendht/thread_pool.h - include/opendht/network_utils.h + include/opendht/udp_socket.h include/opendht.h ) diff --git a/c/opendht.cpp b/c/opendht.cpp index d1fe0811..8fd933aa 100644 --- a/c/opendht.cpp +++ b/c/opendht.cpp @@ -31,6 +31,7 @@ extern "C" { #endif #include <errno.h> +#include <cstdlib> const char* dht_version() { @@ -358,14 +359,53 @@ int dht_runner_run_config(dht_runner* r, in_port_t port, const dht_runner_config return 0; } +asio::ip::udp::endpoint sockaddr_to_endpoint(struct sockaddr* addr, socklen_t addr_len) { + if (addr->sa_family == AF_INET) { + auto addr_in = reinterpret_cast<struct sockaddr_in*>(addr); + asio::ip::address_v4 address(ntohl(addr_in->sin_addr.s_addr)); + unsigned short port = ntohs(addr_in->sin_port); + return asio::ip::udp::endpoint(address, port); + } else if (addr->sa_family == AF_INET6) { + auto addr_in6 = reinterpret_cast<struct sockaddr_in6*>(addr); + asio::ip::address_v6::bytes_type bytes; + std::memcpy(bytes.data(), &addr_in6->sin6_addr, 16); + asio::ip::address_v6 address(bytes, addr_in6->sin6_scope_id); + unsigned short port = ntohs(addr_in6->sin6_port); + return asio::ip::udp::endpoint(address, port); + } else { + throw std::runtime_error("Invalid sockaddr family"); + } +} + +struct sockaddr* endpoint_to_sockaddr(const asio::ip::udp::endpoint& endpoint, socklen_t* addr_len = nullptr) { + const struct sockaddr* addr_src = endpoint.data(); + auto len = endpoint.size(); + struct sockaddr* addr_copy = reinterpret_cast<struct sockaddr*>(std::malloc(len)); + std::memcpy(addr_copy, addr_src, len); + if (addr_len) + *addr_len = len; + return addr_copy; +} + +struct sockaddr** endpoints_to_sockaddrs(const std::vector<asio::ip::udp::endpoint>& endpoints) { + size_t num_endpoints = endpoints.size(); + struct sockaddr** sockaddrs = reinterpret_cast<struct sockaddr**>(std::malloc((num_endpoints + 1) * sizeof(struct sockaddr*))); + for (size_t i = 0; i < num_endpoints; ++i) { + sockaddrs[i] = endpoint_to_sockaddr(endpoints[i]); + } + sockaddrs[num_endpoints] = nullptr; // Null-terminate the array + return sockaddrs; +} + + void dht_runner_ping(dht_runner* r, struct sockaddr* addr, socklen_t addr_len, dht_done_cb done_cb, void* cb_user_data) { auto runner = reinterpret_cast<dht::DhtRunner*>(r); if (done_cb) { - runner->bootstrap(dht::SockAddr(addr, addr_len), [done_cb, cb_user_data](bool ok){ + runner->bootstrap(sockaddr_to_endpoint(addr, addr_len), [done_cb, cb_user_data](bool ok){ done_cb(ok, cb_user_data); }); } else { - runner->bootstrap(dht::SockAddr(addr, addr_len)); + runner->bootstrap(sockaddr_to_endpoint(addr, addr_len)); } } @@ -505,11 +545,7 @@ struct sockaddr** dht_runner_get_public_address(const dht_runner* r) { auto addrs = const_cast<dht::DhtRunner*>(runner)->getPublicAddress(); if (addrs.empty()) return nullptr; - auto ret = (struct sockaddr**)malloc(sizeof(struct sockaddr*) * (addrs.size() + 1)); - for (size_t i=0; i<addrs.size(); i++) - ret[i] = addrs[i].release(); - ret[addrs.size()] = nullptr; - return ret; + return endpoints_to_sockaddrs(addrs); } #ifdef __cplusplus diff --git a/include/opendht/callbacks.h b/include/opendht/callbacks.h index 88eeff42..e83aef6a 100644 --- a/include/opendht/callbacks.h +++ b/include/opendht/callbacks.h @@ -44,6 +44,18 @@ enum class NodeStatus { Connecting, // 1+ nodes Connected // 1+ good nodes }; +struct OPENDHT_PUBLIC DhtNodeStatus { + NodeStatus ipv4; + NodeStatus ipv6; + inline NodeStatus get() const { + if (ipv4 == NodeStatus::Connected or ipv6 == NodeStatus::Connected) + return NodeStatus::Connected; + if (ipv4 == NodeStatus::Connecting or ipv6 == NodeStatus::Connecting) + return NodeStatus::Connecting; + return NodeStatus::Disconnected; + } +}; +using StatusCallback = std::function<void(DhtNodeStatus)>; inline constexpr const char* statusToStr(NodeStatus status) { @@ -79,6 +91,7 @@ struct OPENDHT_PUBLIC NodeStats { struct OPENDHT_PUBLIC NodeInfo { InfoHash id; InfoHash node_id; + PkId node_long_id; NodeStats ipv4 {}; NodeStats ipv6 {}; size_t ongoing_ops {0}; diff --git a/include/opendht/default_types.h b/include/opendht/default_types.h index 0c5ad1f0..e2ad052b 100644 --- a/include/opendht/default_types.h +++ b/include/opendht/default_types.h @@ -19,7 +19,8 @@ #pragma once #include "value.h" -#include "sockaddr.h" + +#include <asio/ip/udp.hpp> namespace dht { enum class ImStatus : uint8_t { @@ -215,9 +216,10 @@ private: public: static const ValueType TYPE; - IpServiceAnnouncement(sa_family_t family = AF_UNSPEC, in_port_t p = 0) { - addr.setFamily(family); - addr.setPort(p); + IpServiceAnnouncement(sa_family_t family = AF_UNSPEC, in_port_t p = 0) + : addr() { + addr.address(family == AF_INET ? asio::ip::address(asio::ip::address_v4::any()) : asio::ip::address(asio::ip::address_v6::any())); + addr.port(p); } IpServiceAnnouncement(const SockAddr& sa) : addr(sa) {} @@ -229,23 +231,25 @@ public: template <typename Packer> void msgpack_pack(Packer& pk) const { - pk.pack_bin(addr.getLength()); - pk.pack_bin_body((const char*)addr.get(), addr.getLength()); + pk.pack_bin(addr.size()); + pk.pack_bin_body((const char*)addr.data(), addr.size()); } virtual void msgpack_unpack(const msgpack::object& o) { + //{(sockaddr*)o.via.bin.ptr, (socklen_t)o.via.bin.size}; + //auto p = if (o.type == msgpack::type::BIN) - addr = {(sockaddr*)o.via.bin.ptr, (socklen_t)o.via.bin.size}; + addr = *reinterpret_cast<const asio::ip::udp::endpoint*>((sockaddr*)o.via.bin.ptr); else throw msgpack::type_error(); } in_port_t getPort() const { - return addr.getPort(); + return addr.port(); } void setPort(in_port_t p) { - addr.setPort(p); + addr.port(p); } const SockAddr& getPeerAddr() const { diff --git a/include/opendht/dht.h b/include/opendht/dht.h index 33595718..4b26013d 100644 --- a/include/opendht/dht.h +++ b/include/opendht/dht.h @@ -29,6 +29,8 @@ #include "callbacks.h" #include "dht_interface.h" +#include <asio/strand.hpp> + #include <string> #include <array> #include <vector> @@ -44,6 +46,7 @@ namespace dht { namespace net { struct Request; +class UdpSocket; } /* namespace net */ struct Storage; @@ -65,10 +68,7 @@ public: * Initialise the Dht with two open sockets (for IPv4 and IP6) * and an ID for the node. */ - Dht(std::unique_ptr<net::DatagramSocket>&& sock, const Config& config, const Sp<Logger>& l = {}); - - Dht(std::unique_ptr<net::DatagramSocket>&& sock, const Config& config, const Logger& l = {}) - : Dht(std::move(sock), config, std::make_shared<Logger>(l)) {} + Dht(std::shared_ptr<net::strand> strand, std::unique_ptr<net::DatagramSocket>&& sock, const Config& config, const Sp<Logger>& l = {}); virtual ~Dht(); @@ -89,8 +89,14 @@ public: NodeStatus getStatus() const override { return std::max(getStatus(AF_INET), getStatus(AF_INET6)); } + void addOnConnectedCallback(std::function<void()> cb) { + onConnectCallbacks_.emplace(std::move(cb)); + } + void addOnStateChangeCallback(StatusCallback cb) { + onStateChangeCallbacks_.emplace_back(std::move(cb)); + } - net::DatagramSocket* getSocket() const override { return network_engine.getSocket(); }; + const net::DatagramSocket* getSocket() const override { return network_engine.getSocket(); }; /** * Performs final operations before quitting. @@ -133,10 +139,10 @@ public: void pingNode(SockAddr, DoneCallbackSimple&& cb={}) override; - time_point periodic(const uint8_t *buf, size_t buflen, SockAddr, const time_point& now) override; + /*time_point periodic(const uint8_t *buf, size_t buflen, SockAddr, const time_point& now) override; time_point periodic(const uint8_t *buf, size_t buflen, const sockaddr* from, socklen_t fromlen, const time_point& now) override { return periodic(buf, buflen, SockAddr(from, fromlen), now); - } + }*/ /** * Get a value by searching on all available protocols (IPv4, IPv6), @@ -379,6 +385,7 @@ private: uint64_t secret {}; uint64_t oldsecret {}; + asio::steady_timer rotateRecretsJob; // registred types TypeStore types; @@ -399,10 +406,10 @@ private: std::vector<std::pair<std::string,std::string>> bootstrap_nodes {}; std::chrono::steady_clock::duration bootstrap_period {BOOTSTRAP_PERIOD}; - Sp<Scheduler::Job> bootstrapJob {}; + asio::steady_timer bootstrapJob; std::map<InfoHash, Storage> store; - std::map<SockAddr, StorageBucket, SockAddr::ipCmp> store_quota; + std::map<SockAddr, StorageBucket, ipCmp> store_quota; size_t total_values {0}; size_t total_store_size {0}; size_t max_store_keys {MAX_HASHES}; @@ -418,15 +425,18 @@ private: // timing - Scheduler scheduler; - Sp<Scheduler::Job> nextNodesConfirmation {}; - Sp<Scheduler::Job> nextStorageMaintenance {}; + asio::steady_timer nextNodesConfirmation; + asio::steady_timer nextStorageMaintenance; + asio::steady_timer expirationJob; + asio::steady_timer statusCheckJob; net::NetworkEngine network_engine; using ReportedAddr = std::pair<unsigned, SockAddr>; std::vector<ReportedAddr> reported_addr; std::string persistPath; + std::vector<StatusCallback> onStateChangeCallbacks_ {}; + std::queue<std::function<void()>> onConnectCallbacks_ {}; // are we a bootstrap node ? // note: Any running node can be used as a bootstrap node. @@ -449,7 +459,7 @@ private: // Storage void storageAddListener(const InfoHash& id, const Sp<Node>& node, size_t tid, Query&& = {}, int version = 0); - bool storageStore(const InfoHash& id, const Sp<Value>& value, time_point created, const SockAddr& sa = {}, bool permanent = false); + bool storageStore(const InfoHash& id, const Sp<Value>& value, time_point created, const SockAddr* sa = nullptr, bool permanent = false); bool storageRefresh(const InfoHash& id, Value::Id vid); void expireStore(); void expireStorage(InfoHash h); @@ -493,7 +503,8 @@ private: void onNewNode(const Sp<Node>& node, int confirm); const Sp<Node> findNode(const InfoHash& id, sa_family_t af) const; bool trySearchInsert(const Sp<Node>& node); - + void scheduleStatusCheck(); + // Searches inline SearchMap& searches(sa_family_t af) { return dht(af).searches; } inline const SearchMap& searches(sa_family_t af) const { return dht(af).searches; } @@ -503,6 +514,7 @@ private: * infohash (id), using the specified IP version (IPv4 or IPv6). */ Sp<Search> search(const InfoHash& id, sa_family_t af, GetCallback = {}, QueryCallback = {}, DoneCallback = {}, Value::Filter = {}, const Sp<Query>& q = {}); + void scheduleNodeConfirmation(const time_point& step); void announce(const InfoHash& id, sa_family_t af, Sp<Value> value, DoneCallback callback, time_point created=time_point::max(), bool permanent = false); size_t listenTo(const InfoHash& id, sa_family_t af, ValueCallback cb, Value::Filter f = {}, const Sp<Query>& q = {}); @@ -520,9 +532,13 @@ private: void confirmNodes(); void expire(); + void onStateChanged(); void onConnected(); void onDisconnected(); + inline const time_point& time() const { return network_engine.time(); } + inline asio::io_context& context() const { return network_engine.context(); } + /** * Generic function to execute when a 'get' request has completed. * diff --git a/include/opendht/dht_interface.h b/include/opendht/dht_interface.h index 157a2c08..a99f88e4 100644 --- a/include/opendht/dht_interface.h +++ b/include/opendht/dht_interface.h @@ -21,13 +21,14 @@ #include "infohash.h" #include "log_enable.h" #include "node_export.h" +#include "callbacks.h" #include <queue> namespace dht { namespace net { - class DatagramSocket; +class DatagramSocket; } class OPENDHT_PUBLIC DhtInterface { @@ -49,11 +50,10 @@ public: virtual NodeStatus getStatus(sa_family_t af) const = 0; virtual NodeStatus getStatus() const = 0; - void addOnConnectedCallback(std::function<void()> cb) { - onConnectCallbacks_.emplace(std::move(cb)); - } + virtual void addOnConnectedCallback(std::function<void()> cb) = 0; + virtual void addOnStateChangeCallback(StatusCallback cb) = 0; - virtual net::DatagramSocket* getSocket() const { return {}; }; + virtual const net::DatagramSocket* getSocket() const { return {}; }; /** * Get the ID of the DHT node. @@ -92,9 +92,6 @@ public: virtual void pingNode(SockAddr, DoneCallbackSimple&& cb={}) = 0; - virtual time_point periodic(const uint8_t *buf, size_t buflen, SockAddr, const time_point& now) = 0; - virtual time_point periodic(const uint8_t *buf, size_t buflen, const sockaddr* from, socklen_t fromlen, const time_point& now) = 0; - /** * Get a value by searching on all available protocols (IPv4, IPv6), * and call the provided get callback when values are found at key. @@ -277,7 +274,6 @@ public: protected: std::shared_ptr<Logger> logger_ {}; - std::queue<std::function<void()>> onConnectCallbacks_ {}; }; } // namespace dht diff --git a/include/opendht/dht_proxy_client.h b/include/opendht/dht_proxy_client.h index fdb4120f..871ff8d2 100644 --- a/include/opendht/dht_proxy_client.h +++ b/include/opendht/dht_proxy_client.h @@ -54,7 +54,7 @@ public: explicit DhtProxyClient( std::shared_ptr<crypto::Certificate> serverCA, crypto::Identity clientIdentity, - std::function<void()> loopSignal, const std::string& serverHost, + std::shared_ptr<asio::io_context::strand> strand, const std::string& serverHost, const std::string& pushClientId = "", std::shared_ptr<Logger> logger = {}); void setHeaderFields(http::Request& request); @@ -97,6 +97,13 @@ public: NodeStatus getStatus() const override { return std::max(getStatus(AF_INET), getStatus(AF_INET6)); } + void addOnConnectedCallback(std::function<void()> cb) { + onConnectCallbacks_.emplace(std::move(cb)); + } + void addOnStateChangeCallback(StatusCallback cb) { + onStateChangeCallbacks_.emplace_back(std::move(cb)); + } + void onStateChanged(); /** * Performs final operations before quitting. @@ -213,11 +220,6 @@ public: */ void pushNotificationReceived(const std::map<std::string, std::string>& notification) override; - time_point periodic(const uint8_t*, size_t, SockAddr, const time_point& now) override; - time_point periodic(const uint8_t* buf, size_t buflen, const sockaddr* from, socklen_t fromlen, const time_point& now) override { - return periodic(buf, buflen, SockAddr(from, fromlen), now); - } - /** * Similar to Dht::get, but sends a Query to filter data remotely. * @param key the key for which to query data for. @@ -285,7 +287,6 @@ public: } void connectivityChanged() override { getProxyInfos(); - loopSignal_(); } private: @@ -357,6 +358,8 @@ private: SockAddr publicAddressV4_; SockAddr publicAddressV6_; std::atomic_bool launchConnectedCbs_ {false}; + std::vector<StatusCallback> onStateChangeCallbacks_ {}; + std::queue<std::function<void()>> onConnectCallbacks_ {}; InfoHash myid {}; @@ -367,7 +370,10 @@ private: * ASIO I/O Context for sockets in httpClient_ * Note: Each context is used in one thread only */ - asio::io_context httpContext_; + //asio::io_context httpContext_; + //std::shared_ptr<asio::io_context> httpContext_; + std::shared_ptr<asio::io_context::strand> strand_; + asio::io_context& context() const { return strand_->context(); } mutable std::mutex resolverLock_; std::shared_ptr<http::Resolver> resolver_; @@ -387,12 +393,6 @@ private: size_t listenerToken_ {0}; std::map<InfoHash, ProxySearch> searches_; - /** - * Callbacks should be executed in the main thread. - */ - std::mutex lockCallbacks_; - std::vector<std::function<void()>> callbacks_; - Sp<InfoState> infoState_; /** @@ -438,8 +438,6 @@ private: #endif #endif - const std::function<void()> loopSignal_; - #ifdef OPENDHT_PUSH_NOTIFICATIONS std::string fillBody(bool resubscribe); void getPushRequest(Json::Value&) const; diff --git a/include/opendht/dht_proxy_server.h b/include/opendht/dht_proxy_server.h index 9a97a2ed..88054151 100644 --- a/include/opendht/dht_proxy_server.h +++ b/include/opendht/dht_proxy_server.h @@ -25,7 +25,6 @@ #include "infohash.h" #include "proxy.h" #include "scheduler.h" -#include "sockaddr.h" #include "value.h" #include "http.h" @@ -89,6 +88,7 @@ public: const ProxyServerConfig& config = {}, const std::shared_ptr<dht::Logger>& logger = {}); + void stop(); virtual ~DhtProxyServer(); DhtProxyServer(const DhtProxyServer& other) = delete; diff --git a/include/opendht/dhtrunner.h b/include/opendht/dhtrunner.h index 4d15a009..9d4761b1 100644 --- a/include/opendht/dhtrunner.h +++ b/include/opendht/dhtrunner.h @@ -24,10 +24,10 @@ #include "infohash.h" #include "value.h" #include "callbacks.h" -#include "sockaddr.h" #include "log_enable.h" -#include "network_utils.h" #include "node_export.h" +#include "network_utils.h" +#include "udp_socket.h" #include <thread> #include <mutex> @@ -52,10 +52,7 @@ struct SecureDhtConfig; * thread that will update the DHT when appropriate. */ class OPENDHT_PUBLIC DhtRunner { - public: - using StatusCallback = std::function<void(NodeStatus, NodeStatus)>; - struct Config { SecureDhtConfig dht_config {}; bool threaded {true}; @@ -74,6 +71,7 @@ public: struct Context { std::shared_ptr<Logger> logger {}; std::unique_ptr<net::DatagramSocket> sock; + std::shared_ptr<asio::io_context> ioContext {}; std::shared_ptr<PeerDiscovery> peerDiscovery {}; StatusCallback statusChangedCallback {}; CertificateStoreQuery certificateStore {}; @@ -407,20 +405,6 @@ public: void run(const Config& config, Context&& context); - void setOnStatusChanged(StatusCallback&& cb) { - statusCb = std::move(cb); - } - - /** - * In non-threaded mode, the user should call this method - * regularly and everytime a new packet is received. - * @return the next op - */ - time_point loop() { - std::lock_guard<std::mutex> lck(dht_mtx); - return loop_(); - } - /** * Gracefuly disconnect from network. */ @@ -475,17 +459,20 @@ private: Stopping }; - time_point loop_(); - - NodeStatus getStatus() const { - return std::max(status4, status6); - } - bool checkShutdown(); void opEnded(); DoneCallback bindOpDoneCallback(DoneCallback&& cb); DoneCallbackSimple bindOpDoneCallback(DoneCallbackSimple&& cb); + inline void post(std::function<void()>&& op, bool prio = false) { + ioContext_->post(asio::bind_executor(*strand_, std::move(op))); + } + + void postOp(std::function<void(SecureDht&)>&& op, bool prio = false); + + std::shared_ptr<asio::io_context> ioContext_; + std::shared_ptr<asio::io_context::strand> strand_; + /** DHT instance */ std::unique_ptr<SecureDht> dht_; @@ -504,11 +491,7 @@ private: mutable std::mutex dht_mtx {}; std::thread dht_thread {}; std::condition_variable cv {}; - std::mutex sock_mtx {}; - net::PacketList rcv {}; - decltype(rcv) rcv_free {}; - - std::queue<std::function<void(SecureDht&)>> pending_ops_prio {}; + /*std::queue<std::function<void(SecureDht&)>> pending_ops_prio {};*/ std::queue<std::function<void(SecureDht&)>> pending_ops {}; std::mutex storage_mtx {}; @@ -516,10 +499,6 @@ private: std::atomic_size_t ongoing_ops {0}; std::vector<ShutdownCallback> shutdownCallbacks_; - NodeStatus status4 {NodeStatus::Disconnected}, - status6 {NodeStatus::Disconnected}; - StatusCallback statusCb {nullptr}; - /** PeerDiscovery Parameters */ std::shared_ptr<PeerDiscovery> peerDiscovery_; diff --git a/include/opendht/infohash.h b/include/opendht/infohash.h index 25753614..870af8f7 100644 --- a/include/opendht/infohash.h +++ b/include/opendht/infohash.h @@ -22,19 +22,7 @@ #include "rng.h" #include <msgpack.hpp> - -#ifndef _WIN32 -#include <netinet/in.h> -#include <netdb.h> -#ifdef __ANDROID__ -typedef uint16_t in_port_t; -#endif -#else -#include <iso646.h> -#include <ws2tcpip.h> -typedef uint16_t sa_family_t; -typedef uint16_t in_port_t; -#endif +#include <asio/ip/udp.hpp> #include <iostream> #include <iomanip> @@ -50,6 +38,7 @@ typedef uint16_t in_port_t; namespace dht { + using byte = uint8_t; namespace crypto { diff --git a/include/opendht/network_engine.h b/include/opendht/network_engine.h index 94dbdb91..f97ba820 100644 --- a/include/opendht/network_engine.h +++ b/include/opendht/network_engine.h @@ -19,16 +19,19 @@ #pragma once +#include "network_utils.h" #include "node_cache.h" #include "value.h" #include "infohash.h" #include "node.h" -#include "scheduler.h" #include "utils.h" #include "rng.h" #include "rate_limiter.h" #include "log_enable.h" -#include "network_utils.h" +#include "udp_socket.h" + +#include <asio/strand.hpp> +#include <asio/ip/udp.hpp> #include <vector> #include <string> @@ -44,10 +47,6 @@ struct Request; struct Socket; struct TransId; -#ifndef MSG_CONFIRM -#define MSG_CONFIRM 0 -#endif - struct NetworkConfig { NetId network {0}; ssize_t max_req_per_sec {0}; @@ -217,10 +216,10 @@ public: NetworkEngine( InfoHash& myid, NetworkConfig config, + const Sp<strand>& strand, std::unique_ptr<DatagramSocket>&& sock, const Sp<Logger>& log, std::mt19937_64& rd, - Scheduler& scheduler, decltype(NetworkEngine::onError)&& onError, decltype(NetworkEngine::onNewNode)&& onNewNode, decltype(NetworkEngine::onReportedAddr)&& onReportedAddr, @@ -233,7 +232,9 @@ public: ~NetworkEngine(); - net::DatagramSocket* getSocket() const { return dht_socket.get(); }; + asio::io_context& context() const { return strand_->context(); } + + const net::DatagramSocket* getSocket() const { return dht_socket.get(); }; void clear(); @@ -441,8 +442,8 @@ public: */ void processMessage(const uint8_t *buf, size_t buflen, SockAddr addr); - Sp<Node> insertNode(const InfoHash& id, const SockAddr& addr) { - auto n = cache.getNode(id, addr, scheduler.time(), 0); + Sp<Node> insertNode(const InfoHash& id, const SockAddr& addr) { + auto n = cache.getNode(id, addr, time(), 0); onNewNode(n, 0); return n; } @@ -474,6 +475,12 @@ public: size_t getPartialCount() const { return partial_messages.size(); } + /** + * Accessors for the common time reference used for synchronizing + * operations. + */ + inline const time_point& time() const { return now; } + inline const time_point& syncTime() { return (now = clock::now()); } private: @@ -499,7 +506,7 @@ private: static constexpr size_t MTU {1280}; static constexpr size_t MAX_PACKET_VALUE_SIZE {600}; - static constexpr size_t MAX_MESSAGE_VALUE_SIZE {56 * 1024}; + static constexpr size_t MAX_MESSAGE_VALUE_SIZE {16 * MTU}; static const std::string my_v; @@ -530,7 +537,7 @@ private: // basic wrapper for socket sendto function - int send(const SockAddr& addr, const char *buf, size_t len, bool confirmed = false); + asio::error_code send(const SockAddr& addr, const char *buf, size_t len, bool confirmed = false); void sendValueParts(Tid tid, const std::vector<Blob>& svals, const SockAddr& addr); std::vector<Blob> packValueHeader(msgpack::sbuffer&, std::vector<Sp<Value>>::const_iterator, std::vector<Sp<Value>>::const_iterator) const; @@ -573,9 +580,10 @@ private: void deserializeNodes(ParsedMessage& msg, const SockAddr& from); /* DHT info */ + Sp<strand> strand_; const InfoHash& myid; const NetworkConfig config {}; - const std::unique_ptr<DatagramSocket> dht_socket; + std::unique_ptr<DatagramSocket> dht_socket; Sp<Logger> logger_; std::mt19937_64& rd; @@ -583,7 +591,7 @@ private: // global limiting should be triggered by at least 8 different IPs using IpLimiter = RateLimiter; - using IpLimiterMap = std::map<SockAddr, IpLimiter, SockAddr::ipCmp>; + using IpLimiterMap = std::map<SockAddr, IpLimiter, ipCmp>; IpLimiterMap address_rate_limiter; RateLimiter rate_limiter; ssize_t limiter_maintenance {0}; @@ -594,10 +602,9 @@ private: MessageStats in_stats {}, out_stats {}; std::set<SockAddr> blacklist {}; - - Scheduler& scheduler; - bool logIncoming_ {false}; + + time_point now {clock::now()}; }; } /* namespace net */ diff --git a/include/opendht/network_utils.h b/include/opendht/network_utils.h index 238b876f..1b2783db 100644 --- a/include/opendht/network_utils.h +++ b/include/opendht/network_utils.h @@ -19,143 +19,21 @@ #include "def.h" -#include "sockaddr.h" -#include "utils.h" -#include "log_enable.h" - -#ifdef _WIN32 -#include <ws2tcpip.h> -#include <winsock2.h> -#else -#include <sys/socket.h> -#include <netinet/in.h> -#include <unistd.h> -#endif - +#include <asio.hpp> +#include <chrono> +#include <cstdint> #include <functional> -#include <thread> -#include <atomic> -#include <mutex> #include <list> +#include <memory> +#include <vector> namespace dht { namespace net { -static const constexpr in_port_t DHT_DEFAULT_PORT = 4222; -static const constexpr size_t RX_QUEUE_MAX_SIZE = 1024 * 64; -static const constexpr std::chrono::milliseconds RX_QUEUE_MAX_DELAY(650); - -int bindSocket(const SockAddr& addr, SockAddr& bound); - -bool setNonblocking(int fd, bool nonblocking = true); - -#ifdef _WIN32 -void udpPipe(int fds[2]); -#endif -struct ReceivedPacket { - Blob data; - SockAddr from; - time_point received; -}; -using PacketList = std::list<ReceivedPacket>; - -class OPENDHT_PUBLIC DatagramSocket { -public: - /** A function that takes a list of new received packets and - * optionally returns consumed packets for recycling. - **/ - using OnReceive = std::function<PacketList(PacketList&& packets)>; - virtual ~DatagramSocket() {}; - - virtual int sendTo(const SockAddr& dest, const uint8_t* data, size_t size, bool replied) = 0; - - inline void setOnReceive(OnReceive&& cb) { - std::lock_guard<std::mutex> lk(lock); - rx_callback = std::move(cb); - } - - virtual bool hasIPv4() const = 0; - virtual bool hasIPv6() const = 0; - - SockAddr getBound(sa_family_t family = AF_UNSPEC) const { - std::lock_guard<std::mutex> lk(lock); - return getBoundRef(family); - } - in_port_t getPort(sa_family_t family = AF_UNSPEC) const { - std::lock_guard<std::mutex> lk(lock); - return getBoundRef(family).getPort(); - } +using asio::ip::udp; +//using time_point = std::chrono::high_resolution_clock::time_point; - virtual const SockAddr& getBoundRef(sa_family_t family = AF_UNSPEC) const = 0; - - /** Virtual resolver mothod allows to implement custom resolver */ - virtual std::vector<SockAddr> resolve(const std::string& host, const std::string& service = {}) { - return SockAddr::resolve(host, service); - } - - virtual void stop() = 0; -protected: - - PacketList getNewPacket() { - PacketList pkts; - if (toRecycle_.empty()) { - pkts.emplace_back(); - } else { - auto begIt = toRecycle_.begin(); - auto begItNext = std::next(begIt); - pkts.splice(pkts.end(), toRecycle_, begIt, begItNext); - } - return pkts; - } - - inline void onReceived(PacketList&& packets) { - std::lock_guard<std::mutex> lk(lock); - if (rx_callback) { - auto r = rx_callback(std::move(packets)); - if (not r.empty() and toRecycle_.size() < RX_QUEUE_MAX_SIZE) - toRecycle_.splice(toRecycle_.end(), std::move(r)); - } - } -protected: - mutable std::mutex lock; -private: - OnReceive rx_callback; - PacketList toRecycle_; -}; - -class OPENDHT_PUBLIC UdpSocket : public DatagramSocket { -public: - UdpSocket(in_port_t port, const std::shared_ptr<Logger>& l = {}); - UdpSocket(const SockAddr& bind4, const SockAddr& bind6, const std::shared_ptr<Logger>& l = {}); - ~UdpSocket(); - - int sendTo(const SockAddr& dest, const uint8_t* data, size_t size, bool replied) override; - - const SockAddr& getBoundRef(sa_family_t family = AF_UNSPEC) const override { - return (family == AF_INET6) ? bound6 : bound4; - } - - bool hasIPv4() const override { - std::lock_guard<std::mutex> lk(lock); - return s4 != -1; - } - bool hasIPv6() const override { - std::lock_guard<std::mutex> lk(lock); - return s6 != -1; - } - - void stop() override; -private: - std::shared_ptr<Logger> logger; - int s4 {-1}; - int s6 {-1}; - int stopfd {-1}; - SockAddr bound4, bound6; - std::thread rcv_thread {}; - std::atomic_bool running {false}; - - void openSockets(const SockAddr& bind4, const SockAddr& bind6); -}; +static const constexpr in_port_t DHT_DEFAULT_PORT = 4222; } } diff --git a/include/opendht/node.h b/include/opendht/node.h index 907892c2..59027660 100644 --- a/include/opendht/node.h +++ b/include/opendht/node.h @@ -21,7 +21,6 @@ #include "infohash.h" // includes socket structures #include "utils.h" -#include "sockaddr.h" #include "node_export.h" #include <list> @@ -50,16 +49,13 @@ struct Node { Node(const InfoHash& id, const SockAddr& addr, std::mt19937_64& rd, bool client=false); Node(const InfoHash& id, SockAddr&& addr, std::mt19937_64& rd, bool client=false); - Node(const InfoHash& id, const sockaddr* sa, socklen_t salen, std::mt19937_64& rd) - : Node(id, SockAddr(sa, salen), rd) {} + //Node(const InfoHash& id, const sockaddr* sa, socklen_t salen, std::mt19937_64& rd) + // : Node(id, SockAddr(sa, salen), rd) {} const InfoHash& getId() const { return id; } const SockAddr& getAddr() const { return addr; } - std::string getAddrStr() const { - return addr.toString(); - } bool isClient() const { return is_client; } bool isIncoming() { return time > reply_time; } @@ -96,7 +92,7 @@ struct Node { ne.addr = addr; return ne; } - sa_family_t getFamily() const { return addr.getFamily(); } + sa_family_t getFamily() const { return addr.protocol().family(); } void update(const SockAddr&); diff --git a/include/opendht/node_export.h b/include/opendht/node_export.h index bb451acc..2402d61f 100644 --- a/include/opendht/node_export.h +++ b/include/opendht/node_export.h @@ -18,7 +18,9 @@ #include "def.h" #include "infohash.h" -#include "sockaddr.h" +#include <asio/ip/udp.hpp> + +#include <string_view> #include <string_view> @@ -27,7 +29,7 @@ using namespace std::literals; struct OPENDHT_PUBLIC NodeExport { InfoHash id; - SockAddr addr; + asio::ip::udp::endpoint addr; template <typename Packer> void msgpack_pack(Packer& pk) const @@ -36,8 +38,8 @@ struct OPENDHT_PUBLIC NodeExport { pk.pack("id"sv); pk.pack(id); pk.pack("addr"sv); - pk.pack_bin(addr.getLength()); - pk.pack_bin_body((const char*)addr.get(), (size_t)addr.getLength()); + pk.pack_bin(addr.size()); + pk.pack_bin_body((const char*)addr.data(), addr.size()); } void msgpack_unpack(msgpack::object o); diff --git a/include/opendht/peer_discovery.h b/include/opendht/peer_discovery.h index 610e444f..f5a626ba 100644 --- a/include/opendht/peer_discovery.h +++ b/include/opendht/peer_discovery.h @@ -20,8 +20,8 @@ #pragma once #include "def.h" -#include "sockaddr.h" #include "infohash.h" +#include "utils.h" #include "log_enable.h" #include <thread> @@ -36,7 +36,7 @@ class OPENDHT_PUBLIC PeerDiscovery { public: static constexpr in_port_t DEFAULT_PORT = 8888; - using ServiceDiscoveredCallback = std::function<void(msgpack::object&&, SockAddr&&)>; + using ServiceDiscoveredCallback = std::function<void(msgpack::object&&, asio::ip::udp::endpoint)>; PeerDiscovery(in_port_t port = DEFAULT_PORT, std::shared_ptr<asio::io_context> ioContext = {}, std::shared_ptr<Logger> logger = {}); ~PeerDiscovery(); diff --git a/include/opendht/securedht.h b/include/opendht/securedht.h index 19e3d248..c2d2b791 100644 --- a/include/opendht/securedht.h +++ b/include/opendht/securedht.h @@ -204,12 +204,6 @@ public: std::vector<SockAddr> getPublicAddress(sa_family_t family = 0) override { return dht_->getPublicAddress(family); } - time_point periodic(const uint8_t *buf, size_t buflen, SockAddr sa, const time_point& now) override { - return dht_->periodic(buf, buflen, std::move(sa), now); - } - time_point periodic(const uint8_t *buf, size_t buflen, const sockaddr* from, socklen_t fromlen, const time_point& now) override { - return dht_->periodic(buf, buflen, from, fromlen, now); - } NodeStatus updateStatus(sa_family_t af) override { return dht_->updateStatus(af); } @@ -219,7 +213,13 @@ public: NodeStatus getStatus() const override { return dht_->getStatus(); } - net::DatagramSocket* getSocket() const override { + void addOnConnectedCallback(std::function<void()> cb) override { + dht_->addOnConnectedCallback(std::move(cb)); + } + void addOnStateChangeCallback(StatusCallback cb) override { + dht_->addOnStateChangeCallback(std::move(cb)); + } + const net::DatagramSocket* getSocket() const override { return dht_->getSocket(); }; bool isRunning(sa_family_t af = 0) const override { diff --git a/include/opendht/sockaddr.h b/include/opendht/sockaddr.h deleted file mode 100644 index 717515cd..00000000 --- a/include/opendht/sockaddr.h +++ /dev/null @@ -1,311 +0,0 @@ -/* - * Copyright (C) 2014-2022 Savoir-faire Linux Inc. - * Author : Adrien Béraud <adrien.beraud@savoirfairelinux.com> - * - * This program is free software; you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation; either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see <https://www.gnu.org/licenses/>. - */ - -#pragma once - -#include "def.h" - -#ifndef _WIN32 -#include <sys/socket.h> -#include <netinet/in.h> -#include <arpa/inet.h> -#ifdef __ANDROID__ -typedef uint16_t in_port_t; -#endif -#else -#include <iso646.h> -#include <stdint.h> -#include <winsock2.h> -#include <ws2def.h> -#include <ws2tcpip.h> -typedef uint16_t sa_family_t; -typedef uint16_t in_port_t; -#endif - -#include <string> -#include <memory> -#include <vector> -#include <stdexcept> -#include <stdlib.h> - -#include <cstring> -#include <cstddef> - -namespace dht { - -OPENDHT_PUBLIC std::string print_addr(const sockaddr* sa, socklen_t slen); -OPENDHT_PUBLIC std::string print_addr(const sockaddr_storage& ss, socklen_t sslen); - -/** - * A Socket Address (sockaddr*), with abstraction for IPv4, IPv6 address families. - */ -class OPENDHT_PUBLIC SockAddr { -public: - SockAddr() {} - SockAddr(const SockAddr& o) { - set(o.get(), o.getLength()); - } - SockAddr(SockAddr&& o) noexcept : addr(std::move(o.addr)), len(o.len) { - o.len = 0; - } - - /** - * Build from existing address. - */ - SockAddr(const sockaddr* sa, socklen_t length) { - if (length > static_cast<socklen_t>(sizeof(sockaddr_storage))) - throw std::runtime_error("Socket address length is too large"); - set(sa, length); - } - SockAddr(const sockaddr* sa) { - socklen_t len = 0; - if (sa) { - if (sa->sa_family == AF_INET) - len = sizeof(sockaddr_in); - else if(sa->sa_family == AF_INET6) - len = sizeof(sockaddr_in6); - else - throw std::runtime_error("Unknown address family"); - } - set(sa, len); - } - - /** - * Build from an existing sockaddr_storage structure. - */ - SockAddr(const sockaddr_storage& ss, socklen_t len) : SockAddr((const sockaddr*)&ss, len) {} - - static std::vector<SockAddr> resolve(const std::string& host, const std::string& service = {}); - - bool operator<(const SockAddr& o) const { - if (len != o.len) - return len < o.len; - return std::memcmp((const uint8_t*)get(), (const uint8_t*)o.get(), len) < 0; - } - - bool equals(const SockAddr& o) const { - return len == o.len - && std::memcmp((const uint8_t*)get(), (const uint8_t*)o.get(), len) == 0; - } - SockAddr& operator=(const SockAddr& o) { - set(o.get(), o.getLength()); - return *this; - } - SockAddr& operator=(SockAddr&& o) { - len = o.len; - o.len = 0; - addr = std::move(o.addr); - return *this; - } - - std::string toString() const { - return print_addr(get(), getLength()); - } - - /** - * Returns the address family or AF_UNSPEC if the address is not set. - */ - sa_family_t getFamily() const { return len ? addr->sa_family : AF_UNSPEC; } - - /** - * Resize the managed structure to the appropriate size (if needed), - * in which case the sockaddr structure is cleared to zero, - * and set the address family field (sa_family). - */ - void setFamily(sa_family_t af) { - socklen_t new_length; - switch(af) { - case AF_INET: - new_length = sizeof(sockaddr_in); - break; - case AF_INET6: - new_length = sizeof(sockaddr_in6); - break; - default: - new_length = 0; - } - if (new_length != len) { - len = new_length; - if (len) addr.reset((sockaddr*)::calloc(len, 1)); - else addr.reset(); - } - if (len) - addr->sa_family = af; - } - - /** - * Set Network Interface to any - */ - void setAny() { - auto family = getFamily(); - switch(family) { - case AF_INET: - getIPv4().sin_addr.s_addr = htonl(INADDR_ANY); - break; - case AF_INET6: - getIPv6().sin6_addr = in6addr_any; - break; - } - } - - /** - * Retreive the port (in host byte order) or 0 if the address is not - * of a supported family. - */ - in_port_t getPort() const { - switch(getFamily()) { - case AF_INET: - return ntohs(getIPv4().sin_port); - case AF_INET6: - return ntohs(getIPv6().sin6_port); - default: - return 0; - } - } - /** - * Set the port. The address must be of a supported family. - * @param p The port in host byte order. - */ - void setPort(in_port_t p) { - switch(getFamily()) { - case AF_INET: - getIPv4().sin_port = htons(p); - break; - case AF_INET6: - getIPv6().sin6_port = htons(p); - break; - } - } - - /** - * Set the address part of the socket address from a numeric IP address (string representation). - * Family must be already set. Throws in case of parse failue. - */ - void setAddress(const char* address); - - /** - * Returns the accessible byte length at the pointer returned by #get(). - * If zero, #get() returns null. - */ - socklen_t getLength() const { return len; } - - /** - * An address is defined to be true if its length is not zero. - */ - explicit operator bool() const noexcept { - return len; - } - - /** - * Returns the address to the managed sockaddr structure. - * The accessible length is returned by #getLength(). - */ - const sockaddr* get() const { return addr.get(); } - - /** - * Returns the address to the managed sockaddr structure. - * The accessible length is returned by #getLength(). - */ - sockaddr* get() { return addr.get(); } - - inline const sockaddr_in& getIPv4() const { - return *reinterpret_cast<const sockaddr_in*>(get()); - } - inline const sockaddr_in6& getIPv6() const { - return *reinterpret_cast<const sockaddr_in6*>(get()); - } - inline sockaddr_in& getIPv4() { - return *reinterpret_cast<sockaddr_in*>(get()); - } - inline sockaddr_in6& getIPv6() { - return *reinterpret_cast<sockaddr_in6*>(get()); - } - - /** - * Releases the ownership of the managed object, if any. - * The caller is responsible for deleting the object with free(). - */ - inline sockaddr* release() { - len = 0; - return addr.release(); - } - - /** - * Return true if address is a loopback IP address. - */ - bool isLoopback() const; - - /** - * Return true if address is not a public IP address. - */ - bool isPrivate() const; - - bool isUnspecified() const; - - bool isMappedIPv4() const; - SockAddr getMappedIPv4(); - SockAddr getMappedIPv6(); - - /** - * A comparator to classify IP addresses, only considering the - * first 64 bits in IPv6. - */ - struct ipCmp { - bool operator()(const SockAddr& a, const SockAddr& b) const { - if (a.len != b.len) - return a.len < b.len; - socklen_t start, len; - switch(a.getFamily()) { - case AF_INET: - start = offsetof(sockaddr_in, sin_addr); - len = sizeof(in_addr); - break; - case AF_INET6: - start = offsetof(sockaddr_in6, sin6_addr); - // don't consider more than 64 bits (IPv6) - len = 8; - break; - default: - start = 0; - len = a.len; - break; - } - return std::memcmp((uint8_t*)a.get()+start, - (uint8_t*)b.get()+start, len) < 0; - } - }; -private: - struct free_delete { void operator()(void* p) { ::free(p); } }; - std::unique_ptr<sockaddr, free_delete> addr {}; - socklen_t len {0}; - - void set(const sockaddr* sa, socklen_t length) { - if (len != length) { - len = length; - if (len) addr.reset((sockaddr*)::malloc(len)); - else addr.reset(); - } - if (len) - std::memcpy((uint8_t*)get(), (const uint8_t*)sa, len); - } - -}; - -OPENDHT_PUBLIC bool operator==(const SockAddr& a, const SockAddr& b); - -} diff --git a/include/opendht/udp_socket.h b/include/opendht/udp_socket.h new file mode 100644 index 00000000..7286108d --- /dev/null +++ b/include/opendht/udp_socket.h @@ -0,0 +1,106 @@ +#pragma once + +#include <asio.hpp> +#include <chrono> +#include <cstdint> +#include <functional> +#include <memory> +#include <vector> + +namespace dht { +namespace net { + +using asio::ip::udp; +using strand = asio::io_context::strand; + +struct ReceivedPacket { + std::vector<uint8_t> data; + udp::endpoint from; + //time_point received; +}; + + +class DatagramSocket { +public: + /** A function that takes a list of new received packets and + * optionally returns consumed packets for recycling. + **/ + //using OnReceive = std::function<PacketList(PacketList&& packets)>; + using ReceiveCallback = std::function<void(const ReceivedPacket&)>; + + virtual ~DatagramSocket() {}; + + virtual asio::error_code sendTo(const uint8_t* data, size_t size, const udp::endpoint& dest) = 0; + + virtual void setOnReceive(const ReceiveCallback& cb) = 0;/* { + //std::lock_guard<std::mutex> lk(lock); + //rx_callback = std::move(cb); + }*/ + + virtual bool hasIPv4() const = 0; + virtual bool hasIPv6() const = 0; + + virtual udp::endpoint getBound(sa_family_t family = AF_UNSPEC) const = 0; + in_port_t getPort(sa_family_t family = AF_UNSPEC) const { + //std::lock_guard<std::mutex> lk(lock); + return getBound(family).port(); + } + + //virtual const udp::endpoint& getBoundRef(sa_family_t family = AF_UNSPEC) const = 0; + + virtual void stop() = 0; +/*protected: + + PacketList getNewPacket() { + PacketList pkts; + if (toRecycle_.empty()) { + pkts.emplace_back(); + } else { + auto begIt = toRecycle_.begin(); + auto begItNext = std::next(begIt); + pkts.splice(pkts.end(), toRecycle_, begIt, begItNext); + } + return pkts; + } + + inline void onReceived(PacketList&& packets) { + std::lock_guard<std::mutex> lk(lock); + if (rx_callback) { + auto r = rx_callback(std::move(packets)); + if (not r.empty() and toRecycle_.size() < RX_QUEUE_MAX_SIZE) + toRecycle_.splice(toRecycle_.end(), std::move(r)); + } + } +protected: + mutable std::mutex lock;*/ +private: + //ReceiveCallback rx_callback; + //PacketList toRecycle_; +}; + + +class UdpSocket : public DatagramSocket { +public: + UdpSocket(std::shared_ptr<strand> strand, const udp::endpoint& ipv4_endpoint, const udp::endpoint& ipv6_endpoint); + + void setOnReceive(const ReceiveCallback& callback); + void start_receive(); + + void stop(); + + void sendToAsync(std::vector<uint8_t> data, const udp::endpoint& to); + asio::error_code sendTo(const uint8_t* buf, size_t len, const udp::endpoint& to); + + bool hasIPv4() const; + bool hasIPv6() const; + + udp::endpoint getBound(sa_family_t af) const; +private: + class SocketHandler; + //std::shared_ptr<strand> strand_; + std::shared_ptr<SocketHandler> ipv4_handler_; + std::shared_ptr<SocketHandler> ipv6_handler_; +}; + +} // namespace net +} // namespace dht \ No newline at end of file diff --git a/include/opendht/utils.h b/include/opendht/utils.h index 691e936e..aec41214 100644 --- a/include/opendht/utils.h +++ b/include/opendht/utils.h @@ -21,6 +21,8 @@ #include "def.h" #include <msgpack.hpp> +#include <asio/ip/address.hpp> +#include <asio/ip/udp.hpp> #include <chrono> #include <random> @@ -62,6 +64,11 @@ void erase_if(std::map<Key, Item>& map, const Condition& condition) OPENDHT_PUBLIC std::pair<std::string, std::string> splitPort(const std::string& s); +inline +std::string print_addr(const asio::ip::udp::endpoint& endpoint) { + return asio::ip::detail::endpoint(endpoint.address(), endpoint.port()).to_string(); +} + class OPENDHT_PUBLIC DhtException : public std::runtime_error { public: DhtException(const std::string &str = "") : @@ -182,4 +189,36 @@ inline msgpack::object* findMapValue(const msgpack::object& map, std::string_vie return findMapValue(map, key.data(), key.size()); } +using SockAddr = asio::ip::udp::endpoint; + +/** + * A comparator to classify IP addresses, only considering the + * first 64 bits in IPv6. + */ +struct ipCmp { + bool operator()(const SockAddr& a, const SockAddr& b) const { + return a < b; + /*if (a.len != b.len) + return a.len < b.len; + socklen_t start, len; + switch(a.getFamily()) { + case AF_INET: + start = offsetof(sockaddr_in, sin_addr); + len = sizeof(in_addr); + break; + case AF_INET6: + start = offsetof(sockaddr_in6, sin6_addr); + // don't consider more than 64 bits (IPv6) + len = 8; + break; + default: + start = 0; + len = a.len; + break; + } + return std::memcmp((uint8_t*)a.get()+start, + (uint8_t*)b.get()+start, len) < 0;*/ + } +}; + } // namespace dht diff --git a/include/opendht/value.h b/include/opendht/value.h index ae8776f4..180f6cef 100644 --- a/include/opendht/value.h +++ b/include/opendht/value.h @@ -22,7 +22,6 @@ #include "infohash.h" #include "crypto.h" #include "utils.h" -#include "sockaddr.h" #include <msgpack.hpp> diff --git a/src/Makefile.am b/src/Makefile.am index 7bfc015a..88821630 100644 --- a/src/Makefile.am +++ b/src/Makefile.am @@ -45,7 +45,6 @@ nobase_include_HEADERS = \ ../include/opendht/scheduler.h \ ../include/opendht/rate_limiter.h \ ../include/opendht/utils.h \ - ../include/opendht/sockaddr.h \ ../include/opendht/infohash.h \ ../include/opendht/node.h \ ../include/opendht/value.h \ diff --git a/src/default_types.cpp b/src/default_types.cpp index f5c3ae2c..9efc3d68 100644 --- a/src/default_types.cpp +++ b/src/default_types.cpp @@ -54,14 +54,7 @@ DhtMessage::ServiceFilter(const std::string& s) std::ostream& operator<< (std::ostream& s, const IpServiceAnnouncement& v) { - if (v.addr) { - s << "Peer: "; - s << "port " << v.getPort(); - char hbuf[NI_MAXHOST]; - if (getnameinfo(v.addr.get(), v.addr.getLength(), hbuf, sizeof(hbuf), nullptr, 0, NI_NUMERICHOST) == 0) { - s << " addr " << std::string(hbuf, strlen(hbuf)); - } - } + s << "Peer: " << v.addr; return s; } diff --git a/src/dht.cpp b/src/dht.cpp index 9eb1d271..beca46c6 100644 --- a/src/dht.cpp +++ b/src/dht.cpp @@ -36,12 +36,6 @@ namespace dht { using namespace std::placeholders; -constexpr std::chrono::minutes Dht::MAX_STORAGE_MAINTENANCE_EXPIRE_TIME; -constexpr std::chrono::minutes Dht::SEARCH_EXPIRE_TIME; -constexpr std::chrono::seconds Dht::BOOTSTRAP_PERIOD; -constexpr duration Dht::LISTEN_EXPIRE_TIME; -constexpr duration Dht::LISTEN_EXPIRE_TIME_PUBLIC; -constexpr duration Dht::REANNOUNCE_MARGIN; static constexpr size_t MAX_REQUESTS_PER_SEC {8 * 1024}; static constexpr duration BOOTSTRAP_PERIOD_MAX {std::chrono::hours(24)}; @@ -50,9 +44,13 @@ Dht::updateStatus(sa_family_t af) { auto& d = dht(af); auto old = d.status; - d.status = d.getStatus(scheduler.time()); + d.status = d.getStatus(time()); if (d.status != old) { + if (logger_) + logger_->d("status for %s changed %s to %s", + af == AF_INET ? "v4" : "v6" , statusToStr(old), statusToStr(d.status)); auto& other = dht(af == AF_INET ? AF_INET6 : AF_INET); + onStateChanged(); if (other.status == NodeStatus::Disconnected && d.status == NodeStatus::Disconnected) { onDisconnected(); } else if (other.status == NodeStatus::Connected || d.status == NodeStatus::Connected) { @@ -108,7 +106,7 @@ Dht::shutdown(ShutdownCallback cb, bool stop) } // Last store maintenance - scheduler.syncTime(); + network_engine.syncTime(); auto remaining = std::make_shared<int>(0); auto str_donecb = [=](bool, const std::vector<Sp<Node>>&) { --*remaining; @@ -160,15 +158,38 @@ Dht::getPublicAddress(sa_family_t family) std::vector<SockAddr> ret; ret.reserve(!family ? reported_addr.size() : reported_addr.size()/2); for (const auto& addr : reported_addr) - if (!family || family == addr.second.getFamily()) + if (!family || family == addr.second.protocol().family()) ret.emplace_back(addr.second); return ret; } +void +Dht::scheduleNodeConfirmation(const time_point& step) +{ + nextNodesConfirmation.expires_at(step); + nextNodesConfirmation.async_wait([this](const asio::error_code &ec){ + if (ec != asio::error::operation_aborted) + confirmNodes(); + }); +} + +void +Dht::scheduleStatusCheck() +{ + statusCheckJob.expires_at(time()); + statusCheckJob.async_wait([this](const asio::error_code &ec){ + if (ec != asio::error::operation_aborted) { + network_engine.syncTime(); + updateStatus(AF_INET); + updateStatus(AF_INET6); + } + }); +} + bool Dht::trySearchInsert(const Sp<Node>& node) { - const auto& now = scheduler.time(); + const auto& now = time(); if (not node) return false; auto& srs = searches(node->getFamily()); @@ -177,21 +198,21 @@ Dht::trySearchInsert(const Sp<Node>& node) // insert forward for (auto it = closest; it != srs.end(); it++) { - auto& s = *it->second; - if (s.insertNode(node, now)) { + auto& s = it->second; + if (s->insertNode(node, now)) { inserted = true; - scheduler.edit(s.nextSearchStep, now); - } else if (not s.expired and not s.done) + s->scheduleStep(now); + } else if (not s->expired and not s->done) break; } // insert backward for (auto it = closest; it != srs.begin();) { --it; - auto& s = *it->second; - if (s.insertNode(node, now)) { + auto& s = it->second; + if (s->insertNode(node, now)) { inserted = true; - scheduler.edit(s.nextSearchStep, now); - } else if (not s.expired and not s.done) + s->scheduleStep(now); + } else if (not s->expired and not s->done) break; } return inserted; @@ -215,14 +236,15 @@ Dht::reportedAddr(const SockAddr& addr) void Dht::onNewNode(const Sp<Node>& node, int confirm) { - const auto& now = scheduler.time(); + const auto& now = time(); auto& b = buckets(node->getFamily()); auto wasEmpty = confirm < 2 && b.grow_time < now - std::chrono::minutes(5); if (b.onNewNode(node, confirm, now, myid, network_engine) or confirm) { trySearchInsert(node); if (wasEmpty) { - scheduler.edit(nextNodesConfirmation, now + std::chrono::seconds(1)); + scheduleNodeConfirmation(now + std::chrono::seconds(1)); } + scheduleStatusCheck(); } } @@ -249,7 +271,7 @@ Dht::expireBuckets(RoutingTable& list) void Dht::expireSearches() { - auto t = scheduler.time() - SEARCH_EXPIRE_TIME; + auto t = time() - SEARCH_EXPIRE_TIME; auto expired = [&](std::pair<const InfoHash, Sp<Search>>& srp) { auto& sr = *srp.second; auto b = sr.callbacks.empty() && sr.announce.empty() && sr.listeners.empty() && sr.step_time < t; @@ -270,7 +292,7 @@ Dht::searchNodeGetDone(const net::Request& req, std::weak_ptr<Search> ws, Sp<Query> query) { - const auto& now = scheduler.time(); + const auto& now = time(); if (auto sr = ws.lock()) { sr->insertNode(req.node, now, answer.ntoken); if (auto srn = sr->getNode(req.node)) { @@ -284,11 +306,15 @@ Dht::searchNodeGetDone(const net::Request& req, srn->getStatus[q] = std::move(dummy_req); } } - auto syncTime = srn->getSyncTime(scheduler.time()); - if (srn->syncJob) - scheduler.edit(srn->syncJob, syncTime); - else - srn->syncJob = scheduler.add(syncTime, std::bind(&Dht::searchStep, this, ws)); + auto syncTime = srn->getSyncTime(time()); + if (!srn->syncJob) + srn->syncJob = std::make_unique<asio::steady_timer>(network_engine.context()); + srn->syncJob->expires_at(syncTime); + srn->syncJob->async_wait([this, ws](const asio::error_code &ec){ + if (ec == asio::error::operation_aborted) + return; + searchStep(ws); + }); } onGetValuesDone(req.node, answer, sr, query); } @@ -306,7 +332,7 @@ Dht::searchNodeGetExpired(const net::Request& status, if (over) srn->getStatus.erase(query); } - scheduler.edit(sr->nextSearchStep, scheduler.time()); + sr->scheduleStep(time()); } } @@ -373,7 +399,7 @@ Dht::searchSendGetValues(Sp<Search> sr, SearchNode* pn, bool update) if (sr->done or sr->currentlySolicitedNodeCount() >= MAX_REQUESTED_SEARCH_NODES) return nullptr; - const auto& now = scheduler.time(); + const auto& now = time(); std::weak_ptr<Search> ws = sr; auto cb = sr->callbacks.begin(); @@ -437,7 +463,9 @@ Dht::searchSendGetValues(Sp<Search> sr, SearchNode* pn, bool update) return nullptr; } -void Dht::searchSendAnnounceValue(const Sp<Search>& sr) { +void +Dht::searchSendAnnounceValue(const Sp<Search>& sr) +{ if (sr->announce.empty()) return; unsigned i = 0; @@ -447,7 +475,7 @@ void Dht::searchSendAnnounceValue(const Sp<Search>& sr) { { /* when put done */ if (auto sr = ws.lock()) { onAnnounceDone(req.node, answer, sr); - scheduler.edit(sr->nextSearchStep, scheduler.time()); + sr->scheduleStep(time()); } }; @@ -455,7 +483,7 @@ void Dht::searchSendAnnounceValue(const Sp<Search>& sr) { { /* when put expired */ if (over) if (auto sr = ws.lock()) - scheduler.edit(sr->nextSearchStep, scheduler.time()); + sr->scheduleStep(time()); }; auto onSelectDone = @@ -463,14 +491,14 @@ void Dht::searchSendAnnounceValue(const Sp<Search>& sr) { { /* on probing done */ auto sr = ws.lock(); if (not sr) return; - const auto& now = scheduler.time(); - sr->insertNode(req.node, scheduler.time(), answer.ntoken); + const auto& now = time(); + sr->insertNode(req.node, time(), answer.ntoken); auto sn = sr->getNode(req.node); if (not sn) return; if (not sn->isSynced(now)) { /* Search is now unsynced. Let's call searchStep to sync again. */ - scheduler.edit(sr->nextSearchStep, now); + sr->scheduleStep(now); return; } for (auto& a : sr->announce) { @@ -491,7 +519,11 @@ void Dht::searchSendAnnounceValue(const Sp<Search>& sr) { auto next_refresh_time = now + getType(a.value->type).expiration; auto& acked = sn->acked[a.value->id]; - scheduler.cancel(acked.refresh); + if (acked.refresh) { + acked.refresh->cancel(); + acked.refresh.reset(); + } + //scheduler.cancel(acked.refresh); /* only put the value if the node doesn't already have it */ if (not hasValue or seq_no < a.value->seq) { if (logger_) @@ -523,7 +555,7 @@ void Dht::searchSendAnnounceValue(const Sp<Search>& sr) { network_engine.sendAnnounceValue(sn->node, sr->id, v, created, sn->token, onDone, onExpired), next_refresh_time }; - scheduler.edit(sr->nextSearchStep, scheduler.time()); + sr->scheduleStep(time()); return true; } } @@ -541,17 +573,24 @@ void Dht::searchSendAnnounceValue(const Sp<Search>& sr) { acked = {std::move(ack_req), next_refresh_time}; /* step to clear announces */ - scheduler.edit(sr->nextSearchStep, now); + sr->scheduleStep(now); } if (a.permanent) { - acked.refresh = scheduler.add(next_refresh_time - REANNOUNCE_MARGIN, std::bind(&Dht::searchStep, this, ws)); + if (!acked.refresh) + acked.refresh = std::make_unique<asio::steady_timer>(network_engine.context()); + acked.refresh->expires_at(next_refresh_time - REANNOUNCE_MARGIN); + acked.refresh->async_wait([this, ws](const asio::error_code &ec){ + if (ec == asio::error::operation_aborted) + return; + searchStep(ws); + }); } } }; static const auto PROBE_QUERY = std::make_shared<Query>(Select {}.field(Value::Field::Id).field(Value::Field::SeqNum)); - const auto& now = scheduler.time(); + const auto& now = time(); for (auto& np : sr->nodes) { auto& n = *np; if (not n.isSynced(now)) @@ -606,7 +645,7 @@ Dht::searchSynchedNodeListen(const Sp<Search>& sr, SearchNode& n) const auto& query = l.second.query; auto r = n.listenStatus.find(query); - if (n.getListenTime(r, listenExp) > scheduler.time()) + if (n.getListenTime(r, listenExp) > time()) continue; // if (logger_) // logger_->d(sr->id, n.node->id, "[search %s] [node %s] sending 'listen'", @@ -621,28 +660,36 @@ Dht::searchSynchedNodeListen(const Sp<Search>& sr, SearchNode& n) n.node->openSocket([this,ws,query](const Sp<Node>& node, net::RequestAnswer&& answer) mutable { /* on new values */ if (auto sr = ws.lock()) { - scheduler.edit(sr->nextSearchStep, scheduler.time()); - sr->insertNode(node, scheduler.time(), answer.ntoken); + sr->scheduleStep(time()); + sr->insertNode(node, time(), answer.ntoken); if (auto sn = sr->getNode(node)) { - sn->onValues(query, std::move(answer), types, scheduler); + sn->onValues(query, std::move(answer), types, time()); } } }))).first; - r->second.cacheExpirationJob = scheduler.add(time_point::max(), [this,ws,query,node=n.node]{ - if (auto sr = ws.lock()) { - if (auto sn = sr->getNode(node)) { - sn->expireValues(query, scheduler); - } - } - }); + if (!r->second.cacheExpirationJob) { + r->second.cacheExpirationJob = std::make_unique<asio::steady_timer>(network_engine.context()); + r->second.onCacheExpired = [this,ws,query,node=n.node]{ + if (auto sr = ws.lock()) + if (auto sn = sr->getNode(node)) { + sn->expireValues(query, time()); + } + }; + } } auto new_req = network_engine.sendListen(n.node, sr->id, *query, n.token, r->second.socketId, [this,ws,query](const net::Request& req, net::RequestAnswer&& answer) mutable { /* on done */ if (auto sr = ws.lock()) { - scheduler.edit(sr->nextSearchStep, scheduler.time()); + sr->scheduleStep(time()); if (auto sn = sr->getNode(req.node)) { - auto job = scheduler.add(sn->getListenTime(query, getListenExpiration()), std::bind(&Dht::searchStep, this, ws)); + auto job = std::make_unique<asio::steady_timer>(network_engine.context()); + job->expires_at(sn->getListenTime(query, getListenExpiration())); + job->async_wait([this, ws](const asio::error_code &ec){ + if (ec == asio::error::operation_aborted) + return; + searchStep(ws); + }); sn->onListenSynced(query, true, std::move(job)); } onListenDone(req.node, answer, sr); @@ -651,7 +698,7 @@ Dht::searchSynchedNodeListen(const Sp<Search>& sr, SearchNode& n) [this,ws,query](const net::Request& req, bool over) mutable { /* on request expired */ if (auto sr = ws.lock()) { - scheduler.edit(sr->nextSearchStep, scheduler.time()); + sr->scheduleStep(time()); if (over) if (auto sn = sr->getNode(req.node)) sn->listenStatus.erase(query); @@ -673,8 +720,10 @@ Dht::searchStep(std::weak_ptr<Search> ws) { auto sr = ws.lock(); if (not sr or sr->expired or sr->done) return; - - const auto& now = scheduler.time(); + if (logger_) + logger_->d(sr->id, "[search %s IPv%c] step (%d requests)", + sr->id.toString().c_str(), sr->af == AF_INET ? '4' : '6', sr->currentlySolicitedNodeCount()); + const auto& now = network_engine.syncTime(); /*if (auto req_count = sr->currentlySolicitedNodeCount()) if (logger_) logger_->d(sr->id, "[search %s IPv%c] step (%d requests)", @@ -753,7 +802,7 @@ Dht::searchStep(std::weak_ptr<Search> ws) } unsigned Dht::refill(Dht::Search& sr) { - const auto& now = scheduler.time(); + const auto& now = time(); sr.refill_time = now; /* we search for up to SEARCH_NODES good nodes. */ auto cached_nodes = network_engine.getCachedNodes(sr.id, sr.af, SEARCH_NODES); @@ -824,14 +873,15 @@ Dht::search(const InfoHash& id, sa_family_t af, GetCallback gcb, QueryCallback q sr->expired = false; sr->nodes.clear(); sr->nodes.reserve(SEARCH_NODES+1); - sr->nextSearchStep = scheduler.add(time_point::max(), std::bind(&Dht::searchStep, this, std::weak_ptr<Search>(sr))); + sr->nextSearchStep = std::make_unique<asio::steady_timer>(context()); + sr->onSearchStep = std::bind(&Dht::searchStep, this, std::weak_ptr<Search>(sr)); if (logger_) logger_->w(id, "[search %s IPv%c] new search", id.toString().c_str(), (af == AF_INET) ? '4' : '6'); if (search_id == 0) search_id++; } - - sr->get(f, q, qcb, gcb, dcb, scheduler); + + sr->get(f, q, qcb, gcb, dcb, time()); refill(*sr); return sr; @@ -849,7 +899,7 @@ Dht::announce(const InfoHash& id, auto srp = srs.find(id); if (auto sr = srp == srs.end() ? search(id, af) : srp->second) { sr->put(value, callback, created, permanent); - scheduler.edit(sr->nextSearchStep, scheduler.time()); + sr->scheduleStep(time()); } else if (callback) { callback(false, {}); } @@ -860,7 +910,7 @@ Dht::listenTo(const InfoHash& id, sa_family_t af, ValueCallback cb, Value::Filte { if (!isRunning(af)) return 0; - // logger__ERR("[search %s IPv%c] search_time is now in %lfs", sr->id.toString().c_str(), (sr->af == AF_INET) ? '4' : '6', print_dt(tm-clock::now())); + // logger__ERR("[search %s IPv%c] search_time is now in %lfs", sr->id.toString().c_str(), (sr->af == AF_INET) ? '4' : '6', print_dt(tm-time())); //logger__WARN("listenTo %s", id.toString().c_str()); auto& srs = searches(af); @@ -870,7 +920,7 @@ Dht::listenTo(const InfoHash& id, sa_family_t af, ValueCallback cb, Value::Filte throw DhtException("Can't create search"); if (logger_) logger_->w(id, "[search %s IPv%c] listen", id.to_c_str(), (af == AF_INET) ? '4' : '6'); - return sr->listen(cb, std::move(f), q, scheduler); + return sr->listen(cb, std::move(f), q, time()); } size_t @@ -881,7 +931,7 @@ Dht::listen(const InfoHash& id, ValueCallback cb, Value::Filter f, Where where) logger_->w(id, "Listen called with invalid key"); return 0; } - scheduler.syncTime(); + network_engine.syncTime(); auto token = ++listener_token; auto gcb = OpValueCache::cacheCallback(std::move(cb), [this, id, token]{ @@ -892,7 +942,7 @@ Dht::listen(const InfoHash& id, ValueCallback cb, Value::Filter f, Where where) auto filter = Value::Filter::chain(std::move(f), query->where.getFilter()); auto st = store.find(id); if (st == store.end() && store.size() < max_store_keys) - st = store.emplace(id, scheduler.time() + MAX_STORAGE_MAINTENANCE_EXPIRE_TIME).first; + st = store.emplace(id, time() + MAX_STORAGE_MAINTENANCE_EXPIRE_TIME).first; size_t tokenlocal = 0; if (st != store.end()) { @@ -915,7 +965,7 @@ Dht::listen(const InfoHash& id, ValueCallback cb, Value::Filter f, Where where) bool Dht::cancelListen(const InfoHash& id, size_t token) { - scheduler.syncTime(); + network_engine.syncTime(); auto it = listeners.find(token); if (it == listeners.end()) { @@ -934,7 +984,7 @@ Dht::cancelListen(const InfoHash& id, size_t token) if (token) { auto srp = srs.find(id); if (srp != srs.end()) - srp->second->cancelListen(token, scheduler); + srp->second->cancelListen(token, time()); } }; searches_cancel_listen(dht4.searches, std::get<1>(it->second)); @@ -972,10 +1022,10 @@ Dht::put(const InfoHash& id, Sp<Value> val, DoneCallback callback, time_point cr } if (val->id == Value::INVALID_ID) val->id = std::uniform_int_distribution<Value::Id>{1}(rd); - scheduler.syncTime(); - const auto& now = scheduler.time(); + network_engine.syncTime(); + const auto& now = time(); created = std::min(now, created); - storageStore(id, val, created, {}, permanent); + storageStore(id, val, created, nullptr, permanent); if (logger_) logger_->d(id, "put: adding %s -> %s", id.toString().c_str(), val->toString().c_str()); @@ -1042,7 +1092,7 @@ Dht::get(const InfoHash& id, GetCallback getcb, DoneCallback donecb, Value::Filt donecb(false, {}); return; } - scheduler.syncTime(); + network_engine.syncTime(); auto op = std::make_shared<GetStatus<std::map<Value::Id, Sp<Value>>>>(); auto gcb = [getcb, donecb, op](const std::vector<Sp<Value>>& vals) { @@ -1089,7 +1139,7 @@ void Dht::query(const InfoHash& id, QueryCallback cb, DoneCallback done_cb, Quer done_cb(false, {}); return; } - scheduler.syncTime(); + network_engine.syncTime(); auto op = std::make_shared<GetStatus<std::vector<Sp<FieldValueIndex>>>>(); auto f = q.where.getFilter(); auto qcb = [cb, done_cb, op](const std::vector<Sp<FieldValueIndex>>& fields){ @@ -1256,9 +1306,9 @@ Dht::storageChanged(const InfoHash& id, Storage& st, const Sp<Value>& v, bool ne } bool -Dht::storageStore(const InfoHash& id, const Sp<Value>& value, time_point created, const SockAddr& sa, bool permanent) +Dht::storageStore(const InfoHash& id, const Sp<Value>& value, time_point created, const SockAddr* sa, bool permanent) { - const auto& now = scheduler.time(); + const auto& now = time(); created = std::min(created, now); auto expiration = permanent ? time_point::max() : created + getType(value->type).expiration; if (expiration < now) @@ -1270,21 +1320,28 @@ Dht::storageStore(const InfoHash& id, const Sp<Value>& value, time_point created return false; auto st_i = store.emplace(id, now); st = st_i.first; - if (maintain_storage and st_i.second) - scheduler.add(st->second.maintenance_time, std::bind(&Dht::dataPersistence, this, id)); + //if (maintain_storage and st_i.second) + // scheduler.add(st->second.maintenance_time, std::bind(&Dht::dataPersistence, this, id)); } StorageBucket* store_bucket {nullptr}; if (sa) - store_bucket = &store_quota[sa]; + store_bucket = &store_quota[*sa]; auto store = st->second.store(id, value, created, expiration, store_bucket); if (auto vs = store.first) { total_store_size += store.second.size_diff; total_values += store.second.values_diff; - scheduler.cancel(vs->expiration_job); + if (vs->expiration_job) + vs->expiration_job->cancel(); if (not permanent) { - vs->expiration_job = scheduler.add(expiration, std::bind(&Dht::expireStorage, this, id)); + if (not vs->expiration_job) + vs->expiration_job = std::make_unique<asio::steady_timer>(context()); + vs->expiration_job->expires_at(expiration); + vs->expiration_job->async_wait([this, id](const asio::error_code &ec){ + if (ec != asio::error::operation_aborted) + expireStorage(id); + }); } if (total_store_size > max_store_size) { auto value = vs->data; @@ -1303,7 +1360,7 @@ Dht::storageStore(const InfoHash& id, const Sp<Value>& value, time_point created void Dht::storageAddListener(const InfoHash& id, const Sp<Node>& node, size_t socket_id, Query&& query, int version) { - const auto& now = scheduler.time(); + const auto& now = time(); auto st = store.find(id); if (st == store.end()) { if (store.size() >= max_store_keys) @@ -1330,7 +1387,7 @@ Dht::expireStore(decltype(store)::iterator i) { const auto& id = i->first; auto& st = i->second; - auto stats = st.expire(id, scheduler.time()); + auto stats = st.expire(id, time()); if (not stats.second.empty()) { storageRemoved(id, st, stats.second, -stats.first); } @@ -1339,6 +1396,7 @@ Dht::expireStore(decltype(store)::iterator i) void Dht::expireStorage(InfoHash h) { + network_engine.syncTime(); auto i = store.find(h); if (i != store.end()) expireStore(i); @@ -1413,8 +1471,8 @@ Dht::expireStore() auto exp_value = largest->second.getOldest(); auto storage = store.find(exp_value.first); if (storage != store.end()) { - if (logger_) - logger_->w("Storage quota full: discarding value from %s at %s %016" PRIx64, largest->first.toString().c_str(), exp_value.first.to_c_str(), exp_value.second); + //if (logger_) + // logger_->w("Storage quota full: discarding value from %s at %s %016" PRIx64, largest->first.toString().c_str(), exp_value.first.to_c_str(), exp_value.second); if (auto value = storage->second.remove(exp_value.first, exp_value.second)) { storageRemoved(storage->first, storage->second, {value}, value->size()); @@ -1437,12 +1495,12 @@ Dht::expireStore() void Dht::connectivityChanged(sa_family_t af) { - const auto& now = scheduler.time(); - scheduler.edit(nextNodesConfirmation, now); + const auto& now = time(); + scheduleNodeConfirmation(now); buckets(af).connectivityChanged(now); network_engine.connectivityChanged(af); reported_addr.erase(std::remove_if(reported_addr.begin(), reported_addr.end(), [&](const ReportedAddr& addr){ - return addr.second.getFamily() == af; + return addr.second.protocol().family() == af; }), reported_addr.end()); startBootstrap(); // will only happen if disconnected } @@ -1453,28 +1511,36 @@ Dht::rotateSecrets() oldsecret = secret; secret = std::uniform_int_distribution<uint64_t>{}(rd); uniform_duration_distribution<> time_dist(std::chrono::minutes(15), std::chrono::minutes(45)); - auto rotate_secrets_time = scheduler.time() + time_dist(rd); - scheduler.add(rotate_secrets_time, std::bind(&Dht::rotateSecrets, this)); + auto rotate_secrets_time = time() + time_dist(rd); + //scheduler.add(rotate_secrets_time, std::bind(&Dht::rotateSecrets, this)); + rotateRecretsJob.expires_at(rotate_secrets_time); + rotateRecretsJob.async_wait([this](const asio::error_code &ec){ + if (ec != asio::error::operation_aborted) { + network_engine.syncTime(); + rotateSecrets(); + } + }); } Blob Dht::makeToken(const SockAddr& addr, bool old) const { + in_port_t port = addr.port(); + const void *ip; size_t iplen; - in_port_t port; + asio::ip::address_v4 v4; + asio::ip::address_v6 v6; - auto family = addr.getFamily(); - if (family == AF_INET) { - const auto& sin = addr.getIPv4(); - ip = &sin.sin_addr; + const auto& a = addr.address(); + if (a.is_v4()) { + v4 = a.to_v4(); + ip = &v4; iplen = 4; - port = sin.sin_port; - } else if (family == AF_INET6) { - const auto& sin6 = addr.getIPv6(); - ip = &sin6.sin6_addr; + } else if (a.is_v6()) { + v6 = a.to_v6(); + ip = &v6; iplen = 16; - port = sin6.sin6_port; } else { return {}; } @@ -1491,7 +1557,7 @@ Dht::makeToken(const SockAddr& addr, bool old) const bool Dht::tokenMatch(const Blob& token, const SockAddr& addr) const { - if (not addr or token.size() != TOKEN_SIZE) + if (token.size() != TOKEN_SIZE) return false; if (token == makeToken(addr, false)) return true; @@ -1503,7 +1569,7 @@ Dht::tokenMatch(const Blob& token, const SockAddr& addr) const NodeStats Dht::getNodesStats(sa_family_t af) const { - NodeStats stats = dht(af).getNodesStats(scheduler.time(), myid); + NodeStats stats = dht(af).getNodesStats(time(), myid); stats.node_cache_size = network_engine.getNodeCacheSize(af); return stats; } @@ -1532,7 +1598,7 @@ Dht::Kad::getNodesStats(time_point now, const InfoHash& myid) const void Dht::dumpBucket(const Bucket& b, std::ostream& out) const { - const auto& now = scheduler.time(); + const auto& now = time(); using namespace std::chrono; out << b.first << " count: " << b.nodes.size() << " updated: " << print_time_relative(now, b.time); if (b.cached) @@ -1557,7 +1623,7 @@ Dht::dumpBucket(const Bucket& b, std::ostream& out) const void Dht::dumpSearch(const Search& sr, std::ostream& out) const { - const auto& now = scheduler.time(); + const auto& now = time(); const auto& listen_expire = getListenExpiration(); using namespace std::chrono; out << std::endl << "Search IPv" << (sr.af == AF_INET6 ? '6' : '4') << ' ' << sr.id << " gets: " << sr.callbacks.size(); @@ -1635,7 +1701,7 @@ Dht::dumpSearch(const Search& sr, std::ostream& out) const out << "] "; } } - out << n.node->getAddrStr() << std::endl; + out << n.node->getAddr() << std::endl; } } @@ -1678,7 +1744,7 @@ Dht::getStorageLog() const if (ip.second.size()) q_map.emplace(ip.second.size(), &ip.first); for (auto ip = q_map.rbegin(); ip != q_map.rend(); ++ip) - out << "IP " << ip->second->toString() << " uses " << ip->first << " bytes" << std::endl; + out << "IP " << *ip->second << " uses " << ip->first << " bytes" << std::endl; out << std::endl; out << "Total " << store.size() << " storages, " << total_values << " values ("; if (total_store_size < 1024) @@ -1789,15 +1855,21 @@ fromDhtConfig(const Config& config) return netConf; } -Dht::Dht(std::unique_ptr<net::DatagramSocket>&& sock, const Config& config, const Sp<Logger>& l) +Dht::Dht(std::shared_ptr<net::strand> strand, std::unique_ptr<net::DatagramSocket>&& sock, const Config& config, const Sp<Logger>& l) : DhtInterface(l), myid(config.node_id ? config.node_id : InfoHash::getRandom(rd)), + rotateRecretsJob(strand->context()), + bootstrapJob(strand->context()), store(), store_quota(), max_store_keys(config.max_store_keys ? (int)config.max_store_keys : MAX_HASHES), max_store_size(config.max_store_size ? (int)config.max_store_size : DEFAULT_STORAGE_LIMIT), max_searches(config.max_searches ? (int)config.max_searches : MAX_SEARCHES), - network_engine(myid, fromDhtConfig(config), std::move(sock), logger_, rd, scheduler, + nextNodesConfirmation(strand->context()), + nextStorageMaintenance(strand->context()), + expirationJob(strand->context()), + statusCheckJob(strand->context()), + network_engine(myid, fromDhtConfig(config), strand, std::move(sock), logger_, rd, std::bind(&Dht::onError, this, _1, _2), std::bind(&Dht::onNewNode, this, _1, _2), std::bind(&Dht::onReportedAddr, this, _1, _2), @@ -1812,7 +1884,7 @@ Dht::Dht(std::unique_ptr<net::DatagramSocket>&& sock, const Config& config, cons maintain_storage(config.maintain_storage), public_stable(config.public_stable) { - scheduler.syncTime(); + network_engine.syncTime(); auto s = network_engine.getSocket(); if (not s or (not s->hasIPv4() and not s->hasIPv6())) throw DhtException("Opened socket required"); @@ -1828,7 +1900,7 @@ Dht::Dht(std::unique_ptr<net::DatagramSocket>&& sock, const Config& config, cons search_id = std::uniform_int_distribution<decltype(search_id)>{}(rd); uniform_duration_distribution<> time_dis {std::chrono::seconds(3), std::chrono::seconds(5)}; - nextNodesConfirmation = scheduler.add(scheduler.time() + time_dis(rd), std::bind(&Dht::confirmNodes, this)); + scheduleNodeConfirmation(time() + time_dis(rd)); // Fill old secret secret = std::uniform_int_distribution<uint64_t>{}(rd); @@ -1890,7 +1962,7 @@ Dht::bucketMaintenance(RoutingTable& list) bool sent {false}; for (auto b = list.begin(); b != list.end(); ++b) { - if (b->time < scheduler.time() - std::chrono::minutes(10) || b->nodes.empty()) { + if (b->time < time() - std::chrono::minutes(10) || b->nodes.empty()) { /* This bucket hasn't seen any positive confirmation for a long time. Pick a random id in this bucket's range, and send a request to a random node. */ @@ -1927,14 +1999,13 @@ Dht::bucketMaintenance(RoutingTable& list) if (logger_) logger_->d(id, n->id, "[node %s] sending find %s for bucket maintenance", n->toString().c_str(), id.toString().c_str()); - //auto start = scheduler.time(); + //auto start = time(); network_engine.sendFindNode(n, id, want, nullptr, [this,n](const net::Request&, bool over) { if (over) { - const auto& end = scheduler.time(); - // using namespace std::chrono; + const auto& end = time(); // if (logger_) // logger_->d(n->id, "[node %s] bucket maintenance op expired after %s", n->toString().c_str(), print_duration(end-start).c_str()); - scheduler.edit(nextNodesConfirmation, end + Node::MAX_RESPONSE_TIME); + scheduleNodeConfirmation(end + Node::MAX_RESPONSE_TIME); } }); sent = true; @@ -1947,7 +2018,7 @@ Dht::bucketMaintenance(RoutingTable& list) void Dht::dataPersistence(InfoHash id) { - const auto& now = scheduler.time(); + const auto& now = time(); auto str = store.find(id); if (str != store.end() and now > str->second.maintenance_time) { if (logger_) @@ -1955,14 +2026,18 @@ Dht::dataPersistence(InfoHash id) id.toString().c_str(), str->second.valueCount(), str->second.totalSize()); maintainStorage(*str); str->second.maintenance_time = now + MAX_STORAGE_MAINTENANCE_EXPIRE_TIME; - scheduler.add(str->second.maintenance_time, std::bind(&Dht::dataPersistence, this, id)); + str->second.maintenance_job->expires_at(str->second.maintenance_time); + str->second.maintenance_job->async_wait([this,id](const asio::error_code& ec) { + if (ec != asio::error::operation_aborted) + dataPersistence(id); + }); } } size_t Dht::maintainStorage(decltype(store)::value_type& storage, bool force, const DoneCallback& donecb) { - const auto& now = scheduler.time(); + const auto& now = time(); size_t announce_per_af = 0; auto maintain = [&](sa_family_t af){ @@ -1996,34 +2071,30 @@ Dht::maintainStorage(decltype(store)::value_type& storage, bool force, const Don return announce_per_af; } -time_point -Dht::periodic(const uint8_t *buf, size_t buflen, SockAddr from, const time_point& now) -{ - scheduler.syncTime(now); - if (buflen) { - try { - network_engine.processMessage(buf, buflen, std::move(from)); - } catch (const std::exception& e) { - if (logger_) - logger_->w("Can't process message: %s", e.what()); - } - } - return scheduler.run(); -} - void Dht::expire() { uniform_duration_distribution<> time_dis(std::chrono::minutes(2), std::chrono::minutes(6)); - auto expire_stuff_time = scheduler.time() + duration(time_dis(rd)); + auto expire_stuff_time = time() + duration(time_dis(rd)); expireBuckets(dht4.buckets); expireBuckets(dht6.buckets); expireStore(); expireSearches(); - scheduler.add(expire_stuff_time, std::bind(&Dht::expire, this)); + expirationJob.expires_at(expire_stuff_time); + expirationJob.async_wait([this](const asio::error_code& ec) { + if (ec != asio::error::operation_aborted) + expire(); + }); } +void +Dht::onStateChanged(){ + for (auto& cb : onStateChangeCallbacks_) + cb(DhtNodeStatus{dht4.status, dht6.status}); +} + + void Dht::onConnected() { @@ -2038,7 +2109,7 @@ Dht::onConnected() void Dht::onDisconnected() { - if (not bootstrapJob) + if (bootstrapJob.expires_from_now() <= std::chrono::seconds(0)) bootstrap(); } @@ -2049,35 +2120,42 @@ Dht::bootstrap() return; if (logger_) logger_->d(myid, "Bootstraping"); + asio::ip::udp::resolver resolver(network_engine.context()); for (const auto& boootstrap : bootstrap_nodes) { try { - auto ips = network_engine.getSocket()->resolve(boootstrap.first, boootstrap.second); - for (auto& ip : ips) { - if (ip.getPort() == 0) - ip.setPort(net::DHT_DEFAULT_PORT); - pingNode(ip); + for (const auto& ip : resolver.resolve(boootstrap.first, boootstrap.second)) { + auto endpoint = ip.endpoint(); + if (endpoint.port() == 0) + endpoint.port(net::DHT_DEFAULT_PORT); + pingNode(std::move(endpoint)); } } catch (const std::exception& e) { if (logger_) logger_->e(myid, "Can't resolve %s:%s: %s", boootstrap.first.c_str(), boootstrap.second.c_str(), e.what()); } } - scheduler.cancel(bootstrapJob); - bootstrapJob = scheduler.add(scheduler.time() + bootstrap_period, std::bind(&Dht::bootstrap, this)); - bootstrap_period = std::min(bootstrap_period * 2, BOOTSTRAP_PERIOD_MAX); + bootstrapJob.expires_at(time() + bootstrap_period); + bootstrapJob.async_wait([this](const asio::error_code& ec) { + if (ec != asio::error::operation_aborted) + bootstrap(); + }); } void Dht::startBootstrap() { stopBootstrap(); - bootstrapJob = scheduler.add(scheduler.time(), std::bind(&Dht::bootstrap, this)); + bootstrapJob.expires_at(time()); + bootstrapJob.async_wait([this](const asio::error_code& ec) { + if (ec != asio::error::operation_aborted) + bootstrap(); + }); } void Dht::stopBootstrap() { - scheduler.cancel(bootstrapJob); + bootstrapJob.cancel(); bootstrap_period = BOOTSTRAP_PERIOD; } @@ -2086,7 +2164,7 @@ Dht::confirmNodes() { using namespace std::chrono; bool soon = false; - const auto& now = scheduler.time(); + const auto& now = network_engine.syncTime(); if (dht4.searches.empty() and dht4.status == NodeStatus::Connected) { if (logger_) @@ -2118,7 +2196,7 @@ Dht::confirmNodes() : uniform_duration_distribution<> {seconds(60), seconds(180)}; auto confirm_nodes_time = now + time_dis(rd); - scheduler.edit(nextNodesConfirmation, confirm_nodes_time); + scheduleNodeConfirmation(confirm_nodes_time); } std::vector<ValuesExport> @@ -2148,7 +2226,7 @@ Dht::exportValues() const void Dht::importValues(const std::vector<ValuesExport>& import) { - const auto& now = scheduler.time(); + const auto& now = time(); for (const auto& value : import) { if (value.second.empty()) @@ -2189,7 +2267,7 @@ Dht::importValues(const std::vector<ValuesExport>& import) std::vector<NodeExport> Dht::exportNodes() const { - const auto& now = scheduler.time(); + const auto& now = time(); std::vector<NodeExport> nodes; const auto b4 = dht4.buckets.findBucket(myid); if (b4 != dht4.buckets.end()) { @@ -2221,19 +2299,18 @@ Dht::exportNodes() const void Dht::insertNode(const InfoHash& id, const SockAddr& addr) { - if (addr.getFamily() != AF_INET && addr.getFamily() != AF_INET6) - return; - scheduler.syncTime(); + network_engine.syncTime(); network_engine.insertNode(id, addr); } void Dht::pingNode(SockAddr sa, DoneCallbackSimple&& cb) { - scheduler.syncTime(); if (logger_) - logger_->d("Sending ping to %s", sa.toString().c_str()); - auto& count = dht(sa.getFamily()).pending_pings; + logger_->d("Sending ping to %s", print_addr(sa).c_str()); + network_engine.syncTime(); + auto af = sa.protocol().family(); + auto& count = dht(af).pending_pings; count++; network_engine.sendPing(std::move(sa), [&count,cb](const net::Request&, net::RequestAnswer&&) { count--; @@ -2246,6 +2323,7 @@ Dht::pingNode(SockAddr sa, DoneCallbackSimple&& cb) cb(false); } }); + updateStatus(af); } void @@ -2262,7 +2340,7 @@ Dht::onError(Sp<net::Request> req, net::DhtProtocolException e) { n->token.clear(); n->last_get_reply = time_point::min(); searchSendGetValues(sr); - scheduler.edit(sr->nextSearchStep, scheduler.time()); + sr->scheduleStep(time()); break; } } @@ -2276,8 +2354,7 @@ Dht::onError(Sp<net::Request> req, net::DhtProtocolException e) { void Dht::onReportedAddr(const InfoHash& /*id*/, const SockAddr& addr) { - if (addr) - reportedAddr(addr); + reportedAddr(addr); } net::RequestAnswer @@ -2289,7 +2366,7 @@ Dht::onPing(Sp<Node>) net::RequestAnswer Dht::onFindNode(Sp<Node> node, const InfoHash& target, want_t want) { - const auto& now = scheduler.time(); + const auto& now = time(); net::RequestAnswer answer; answer.ntoken = makeToken(node->getAddr(), false); if (want & WANT4) @@ -2310,7 +2387,7 @@ Dht::onGetValues(Sp<Node> node, const InfoHash& hash, want_t, const Query& query net::DhtProtocolException::GET_NO_INFOHASH }; } - const auto& now = scheduler.time(); + const auto& now = time(); net::RequestAnswer answer {}; auto st = store.find(hash); answer.ntoken = makeToken(node->getAddr(), false); @@ -2399,7 +2476,7 @@ void Dht::onGetValuesDone(const Sp<Node>& node, searchSendGetValues(sr); // Force to recompute the next step time - scheduler.edit(sr->nextSearchStep, scheduler.time()); + sr->scheduleStep(time()); } } @@ -2432,9 +2509,9 @@ Dht::onListenDone(const Sp<Node>& /* node */, net::RequestAnswer& /* answer */, // sr->id.toString().c_str(), node->toString().c_str(), answer.values.size()); if (not sr->done) { - const auto& now = scheduler.time(); - searchSendGetValues(sr); - scheduler.edit(sr->nextSearchStep, now); + const auto& now = time(); + searchSendGetValues(sr); + sr->scheduleStep(now); } } @@ -2462,7 +2539,7 @@ Dht::onAnnounce(Sp<Node> n, { // We store a value only if we think we're part of the // SEARCH_NODES nodes around the target id. - auto closest_nodes = buckets(node.getFamily()).findClosestNodes(hash, scheduler.time(), SEARCH_NODES); + auto closest_nodes = buckets(node.getFamily()).findClosestNodes(hash, time(), SEARCH_NODES); if (closest_nodes.size() >= TARGET_NODES and hash.xorCmp(closest_nodes.back()->id, myid) < 0) { if (logger_) logger_->w(hash, node.id, "[node %s] announce too far from the target. Dropping value.", node.toString().c_str()); @@ -2470,7 +2547,7 @@ Dht::onAnnounce(Sp<Node> n, } } - auto created = std::min(creation_date, scheduler.time()); + auto created = std::min(creation_date, time()); for (const auto& v : values) { if (v->id == Value::INVALID_ID) { if (logger_) @@ -2493,7 +2570,7 @@ Dht::onAnnounce(Sp<Node> n, if (logger_) logger_->d(hash, node.id, "[store %s] editing %s", hash.toString().c_str(), vc->toString().c_str()); - storageStore(hash, vc, created, node.getAddr()); + storageStore(hash, vc, created, &node.getAddr()); } else { if (logger_) logger_->d(hash, node.id, "[store %s] rejecting edition of %s because of storage policy", @@ -2501,12 +2578,12 @@ Dht::onAnnounce(Sp<Node> n, } } } else { - // Allow the value to be edited by the storage policy + // Allow the value to be stored by the storage policy const auto& type = getType(vc->type); if (type.storePolicy(hash, vc, node.id, node.getAddr())) { // if (logger_) // logger_->d(hash, node.id, "[store %s] storing %s", hash.toString().c_str(), std::to_string(vc->id).c_str()); - storageStore(hash, vc, created, node.getAddr()); + storageStore(hash, vc, created, &node.getAddr()); } else { if (logger_) logger_->d(hash, node.id, "[store %s] rejecting storage of %s", @@ -2542,7 +2619,7 @@ Dht::onRefresh(Sp<Node> node, const InfoHash& hash, const Blob& token, const Val bool Dht::storageRefresh(const InfoHash& id, Value::Id vid) { - const auto& now = scheduler.time(); + const auto& now = time(); auto s = store.find(id); if (s != store.end()) { // Values like for a permanent put can be refreshed. So, inform remote listeners that the value @@ -2566,9 +2643,15 @@ Dht::storageRefresh(const InfoHash& id, Value::Id vid) auto expiration = s->second.refresh(id, now, vid, types); if (expiration.first) { - scheduler.cancel(expiration.first->expiration_job); + if (expiration.first->expiration_job) + expiration.first->expiration_job->cancel(); if (expiration.second != time_point::max()) { - expiration.first->expiration_job = scheduler.add(expiration.second, std::bind(&Dht::expireStorage, this, id)); + //expiration.first->expiration_job = scheduler.add(expiration.second, std::bind(&Dht::expireStorage, this, id)); + expiration.first->expiration_job = std::make_unique<asio::steady_timer>(context(), expiration.second); + expiration.first->expiration_job->async_wait([this, id](const asio::error_code& ec) { + if (ec != asio::error::operation_aborted) + expireStorage(id); + }); } } return true; diff --git a/src/dht_proxy_client.cpp b/src/dht_proxy_client.cpp index c6685299..23c1c039 100644 --- a/src/dht_proxy_client.cpp +++ b/src/dht_proxy_client.cpp @@ -24,6 +24,8 @@ #include "utils.h" #include <http_parser.h> + +#include <asio/ip/address.hpp> #include <deque> @@ -108,18 +110,17 @@ DhtProxyClient::DhtProxyClient() {} DhtProxyClient::DhtProxyClient( std::shared_ptr<dht::crypto::Certificate> serverCA, dht::crypto::Identity clientIdentity, - std::function<void()> signal, const std::string& serverHost, + std::shared_ptr<asio::io_context::strand> strand, const std::string& serverHost, const std::string& pushClientId, std::shared_ptr<dht::Logger> logger) : DhtInterface(logger) , proxyUrl_(serverHost) , clientIdentity_(clientIdentity), serverCertificate_(serverCA) , pushClientId_(pushClientId), pushSessionId_(getRandomSessionId()) - , loopSignal_(signal) + , strand_(strand) , jsonReader_(Json::CharReaderBuilder{}.newCharReader()) { - localAddrv4_.setFamily(AF_INET); - localAddrv6_.setFamily(AF_INET6); - + /*localAddrv4_.setFamily(AF_INET); + localAddrv6_.setFamily(AF_INET6);*/ jsonBuilder_["commentStyle"] = "None"; jsonBuilder_["indentation"] = ""; if (logger_) { @@ -131,13 +132,13 @@ DhtProxyClient::DhtProxyClient( clientIdentity_.second->toString(false/*chain*/).c_str()); } // run http client - httpClientThread_ = std::thread([this](){ + /*httpClientThread_ = std::thread([this](){ try { if (logger_) logger_->d("[proxy:client] starting io_context"); - // Ensures the httpContext_ won't run out of work - auto work = asio::make_work_guard(httpContext_); - httpContext_.run(); + // Ensures the context() won't run out of work + auto work = asio::make_work_guard(context()); + context().run(); if (logger_) logger_->d("[proxy:client] http client io_context stopped"); } @@ -145,7 +146,7 @@ DhtProxyClient::DhtProxyClient( if (logger_) logger_->e("[proxy:client] run error: %s", ex.what()); } - }); + });*/ if (!proxyUrl_.empty()) startProxy(); } @@ -159,12 +160,10 @@ DhtProxyClient::startProxy() if (logger_) logger_->d("[proxy:client] start proxy with %s", proxyUrl_.c_str()); - nextProxyConfirmationTimer_ = std::make_unique<asio::steady_timer>(httpContext_, std::chrono::steady_clock::now()); + nextProxyConfirmationTimer_ = std::make_unique<asio::steady_timer>(context(), std::chrono::steady_clock::now()); nextProxyConfirmationTimer_->async_wait(std::bind(&DhtProxyClient::handleProxyConfirm, this, std::placeholders::_1)); - listenerRestartTimer_ = std::make_unique<asio::steady_timer>(httpContext_); - - loopSignal_(); + listenerRestartTimer_ = std::make_unique<asio::steady_timer>(context()); } void @@ -203,8 +202,8 @@ DhtProxyClient::stop() for (auto& request : requests_) request.second->cancel(); } - if (not httpContext_.stopped()) - httpContext_.stop(); + if (not context().stopped()) + context().stop(); if (httpClientThread_.joinable()) httpClientThread_.join(); requests_.clear(); @@ -286,21 +285,6 @@ DhtProxyClient::isRunning(sa_family_t af) const } } -time_point -DhtProxyClient::periodic(const uint8_t*, size_t, SockAddr, const time_point& /*now*/) -{ - // Exec all currently stored callbacks - decltype(callbacks_) callbacks; - { - std::lock_guard<std::mutex> lock(lockCallbacks_); - callbacks = std::move(callbacks_); - } - for (auto& callback : callbacks) - callback(); - callbacks.clear(); - return time_point::max(); -} - void DhtProxyClient::setHeaderFields(http::Request& request){ request.set_header_field(restinio::http_field_t::accept, "*/*"); @@ -351,15 +335,11 @@ DhtProxyClient::get(const InfoHash& key, GetCallback cb, DoneCallback donecb, Va values.emplace_back(std::move(value)); } if (not values.empty() and cb) { - { - std::lock_guard<std::mutex> lock(lockCallbacks_); - callbacks_.emplace_back([opstate, cb, values = std::move(values)](){ - if (not opstate->stop.load() and not cb(values)){ - opstate->stop.store(true); - } - }); - } - loopSignal_(); + context().post([opstate, cb, values = std::move(values)](){ + if (not opstate->stop.load() and not cb(values)){ + opstate->stop.store(true); + } + }); } } catch(const std::exception& e) { if (logger_) @@ -376,14 +356,10 @@ DhtProxyClient::get(const InfoHash& key, GetCallback cb, DoneCallback donecb, Va opFailed(); } if (donecb) { - { - std::lock_guard<std::mutex> lock(lockCallbacks_); - callbacks_.emplace_back([donecb, opstate](){ - donecb(opstate->ok, {}); - opstate->stop.store(true); - }); - } - loopSignal_(); + context().post([donecb, opstate](){ + donecb(opstate->ok, {}); + opstate->stop.store(true); + }); } if (not isDestroying_) { std::lock_guard<std::mutex> l(requestLock_); @@ -419,7 +395,7 @@ DhtProxyClient::put(const InfoHash& key, Sp<Value> val, DoneCallback cb, time_po auto& search = searches_[key]; if (val->id) { auto id = val->id; - auto refreshPutTimer = std::make_unique<asio::steady_timer>(httpContext_, proxy::OP_TIMEOUT - proxy::OP_MARGIN); + auto refreshPutTimer = std::make_unique<asio::steady_timer>(context(), proxy::OP_TIMEOUT - proxy::OP_MARGIN); refreshPutTimer->async_wait(std::bind(&DhtProxyClient::handleRefreshPut, this, std::placeholders::_1, key, id)); search.puts.erase(id); search.puts.emplace(std::piecewise_construct, @@ -433,12 +409,10 @@ DhtProxyClient::put(const InfoHash& key, Sp<Value> val, DoneCallback cb, time_po if (ok) *ok = result; if (cb) { - std::lock_guard<std::mutex> lock(lockCallbacks_); - callbacks_.emplace_back([cb, result](){ + context().post([cb, result](){ cb(result, {}); }); } - loopSignal_(); }, created, permanent); } @@ -475,10 +449,10 @@ DhtProxyClient::buildRequest(const std::string& target) auto resolver = resolver_; l.unlock(); if (not resolver) - resolver = std::make_shared<http::Resolver>(httpContext_, proxyUrl_, logger_); + resolver = std::make_shared<http::Resolver>(context(), proxyUrl_, logger_); auto request = target.empty() - ? std::make_shared<http::Request>(httpContext_, resolver) - : std::make_shared<http::Request>(httpContext_, resolver, target); + ? std::make_shared<http::Request>(context(), resolver) + : std::make_shared<http::Request>(context(), resolver, target); if (serverCertificate_) request->set_certificate_authority(serverCertificate_); if (clientIdentity_.first and clientIdentity_.second) @@ -529,7 +503,7 @@ DhtProxyClient::doPut(const InfoHash& key, Sp<Value> val, DoneCallbackSimple cb, auto it = search.pendingPuts.find(val); if (it != search.pendingPuts.end()) { auto sok = std::make_shared<std::atomic_bool>(ok); - auto refreshPutTimer = std::make_unique<asio::steady_timer>(httpContext_, proxy::OP_TIMEOUT - proxy::OP_MARGIN); + auto refreshPutTimer = std::make_unique<asio::steady_timer>(context(), proxy::OP_TIMEOUT - proxy::OP_MARGIN); refreshPutTimer->async_wait(std::bind(&DhtProxyClient::handleRefreshPut, this, std::placeholders::_1, key, id)); search.puts.emplace(std::piecewise_construct, std::forward_as_tuple(id), @@ -636,7 +610,7 @@ DhtProxyClient::getProxyInfos() if (logger_) logger_->d("[proxy:client] [status] sending request"); - auto resolver = std::make_shared<http::Resolver>(httpContext_, proxyUrl_, logger_); + auto resolver = std::make_shared<http::Resolver>(context(), proxyUrl_, logger_); queryProxyInfo(infoState, resolver, AF_INET); queryProxyInfo(infoState, resolver, AF_INET6); std::lock_guard<std::mutex> l(resolverLock_); @@ -649,7 +623,7 @@ DhtProxyClient::queryProxyInfo(const Sp<InfoState>& infoState, const Sp<http::Re if (logger_) logger_->d("[proxy:client] [status] query ipv%i info", family == AF_INET ? 4 : 6); try { - auto request = std::make_shared<http::Request>(httpContext_, resolver, family); + auto request = std::make_shared<http::Request>(context(), resolver, family); if (serverCertificate_) request->set_certificate_authority(serverCertificate_); auto reqid = request->id(); @@ -715,7 +689,7 @@ DhtProxyClient::onProxyInfos(const Json::Value& proxyInfos, const sa_family_t fa if (logger_) logger_->e("[proxy:client] [info] request failed for %s", family == AF_INET ? "ipv4" : "ipv6"); status = NodeStatus::Disconnected; - if (pubAddress) { + if (pubAddress != SockAddr{}) { pubAddress = {}; ipChanged = true; } @@ -728,15 +702,20 @@ DhtProxyClient::onProxyInfos(const Json::Value& proxyInfos, const sa_family_t fa stats4_ = NodeStats(proxyInfos["ipv4"]); stats6_ = NodeStats(proxyInfos["ipv6"]); auto publicIp = parsePublicAddress(proxyInfos["public_ip"]); - ipChanged = pubAddress && pubAddress.toString() != publicIp.toString(); + ipChanged = pubAddress != SockAddr{} && pubAddress != publicIp; pubAddress = publicIp; if (proxyInfos.isMember("local_ip")) { std::string localIp = proxyInfos["local_ip"].asString(); - if (localAddress.toString() != localIp) { - localAddress.setAddress(localIp.c_str()); - ipChanged = (bool)localAddress; + auto localAddr = asio::ip::make_address(localIp); + if (localAddr != localAddress.address()) { + localAddress = SockAddr(localAddr, localAddress.port()); + ipChanged = localAddress != SockAddr{}; } + /*if (localAddress.toString() != localIp) { + localAddress.setAddress(localIp.c_str()); + ipChanged = localAddress != SockAddr{}; + }*/ } if (!ipChanged && stats4_.good_nodes + stats6_.good_nodes) @@ -752,14 +731,16 @@ DhtProxyClient::onProxyInfos(const Json::Value& proxyInfos, const sa_family_t fa } } auto newStatus = std::max(statusIpv4_, statusIpv6_); + if (newStatus != oldStatus) { + onStateChanged(); + } if (newStatus == NodeStatus::Connected) { if (oldStatus == NodeStatus::Disconnected || oldStatus == NodeStatus::Connecting || launchConnectedCbs_) { launchConnectedCbs_ = false; listenerRestartTimer_->expires_at(std::chrono::steady_clock::now()); listenerRestartTimer_->async_wait(std::bind(&DhtProxyClient::restartListeners, this, std::placeholders::_1)); if (not onConnectCallbacks_.empty()) { - std::lock_guard<std::mutex> lock(lockCallbacks_); - callbacks_.emplace_back([cbs = std::move(onConnectCallbacks_)]() mutable { + context().post([cbs = std::move(onConnectCallbacks_)]() mutable { while (not cbs.empty()) { cbs.front()(); cbs.pop(); @@ -777,18 +758,25 @@ DhtProxyClient::onProxyInfos(const Json::Value& proxyInfos, const sa_family_t fa nextProxyConfirmationTimer_->expires_at(next); nextProxyConfirmationTimer_->async_wait(std::bind(&DhtProxyClient::handleProxyConfirm, this, std::placeholders::_1)); } - l.unlock(); - loopSignal_(); } +void +DhtProxyClient::onStateChanged(){ + for (auto& cb : onStateChangeCallbacks_) + cb(DhtNodeStatus{statusIpv4_, statusIpv6_}); +} + + SockAddr DhtProxyClient::parsePublicAddress(const Json::Value& val) { auto public_ip = val.asString(); auto hostAndService = splitPort(public_ip); - auto sa = SockAddr::resolve(hostAndService.first); - if (sa.empty()) return {}; - return sa.front().getMappedIPv4(); + auto sa = asio::ip::make_address(hostAndService.first); + if (sa == asio::ip::address{}) return {}; + if (sa.is_v6() and sa.to_v6().is_v4_mapped()) + return SockAddr(asio::ip::make_address_v4(asio::ip::v4_mapped, sa.to_v6()), 0); + return SockAddr(sa, 0); } std::vector<SockAddr> @@ -796,8 +784,8 @@ DhtProxyClient::getPublicAddress(sa_family_t family) { std::lock_guard<std::mutex> l(lockCurrentProxyInfos_); std::vector<SockAddr> result; - if (publicAddressV6_ && family != AF_INET) result.emplace_back(publicAddressV6_); - if (publicAddressV4_ && family != AF_INET6) result.emplace_back(publicAddressV4_); + if (publicAddressV6_ != SockAddr{} && family != AF_INET) result.emplace_back(publicAddressV6_); + if (publicAddressV4_ != SockAddr{} && family != AF_INET6) result.emplace_back(publicAddressV4_); return result; } @@ -857,7 +845,7 @@ DhtProxyClient::listen(const InfoHash& key, ValueCallback cb, Value::Filter filt * (if the proxy crash for any reason) */ if (!l->second.refreshSubscriberTimer) - l->second.refreshSubscriberTimer = std::make_unique<asio::steady_timer>(httpContext_); + l->second.refreshSubscriberTimer = std::make_unique<asio::steady_timer>(context()); l->second.refreshSubscriberTimer->expires_at(std::chrono::steady_clock::now() + proxy::OP_TIMEOUT - proxy::OP_MARGIN); l->second.refreshSubscriberTimer->async_wait(std::bind(&DhtProxyClient::handleResubscribe, this, @@ -924,7 +912,7 @@ DhtProxyClient::cancelListen(const InfoHash& key, size_t gtoken) // define real cancel listen only once if (not it->second.opExpirationTimer) - it->second.opExpirationTimer = std::make_unique<asio::steady_timer>(httpContext_, ops.getExpiration()); + it->second.opExpirationTimer = std::make_unique<asio::steady_timer>(context(), ops.getExpiration()); else it->second.opExpirationTimer->expires_at(ops.getExpiration()); it->second.opExpirationTimer->async_wait(std::bind(&DhtProxyClient::handleExpireListener, this, std::placeholders::_1, key)); @@ -1056,14 +1044,10 @@ DhtProxyClient::sendListen(const restinio::http_request_header_t& header, auto value = std::make_shared<Value>(json); if (cb){ auto expired = json.get("expired", Json::Value(false)).asBool(); - { - std::lock_guard<std::mutex> lock(lockCallbacks_); - callbacks_.emplace_back([cb, value, opstate, expired]() { - if (not opstate->stop.load() and not cb({value}, expired, system_clock::time_point::min())) - opstate->stop.store(true); - }); - } - loopSignal_(); + context().post([cb, value, opstate, expired]() { + if (not opstate->stop.load() and not cb({value}, expired, system_clock::time_point::min())) + opstate->stop.store(true); + }); } } } catch(const std::exception& e) { @@ -1111,7 +1095,6 @@ DhtProxyClient::opFailed() statusIpv6_ = NodeStatus::Disconnected; } getConnectivityStatus(); - loopSignal_(); } void @@ -1147,7 +1130,7 @@ DhtProxyClient::restartListeners(const asio::error_code &ec) *ok = result; }, time_point::max(), true); if (!put.second.refreshPutTimer) { - put.second.refreshPutTimer = std::make_unique<asio::steady_timer>(httpContext_); + put.second.refreshPutTimer = std::make_unique<asio::steady_timer>(context()); } put.second.refreshPutTimer->expires_at(std::chrono::steady_clock::now() + proxy::OP_TIMEOUT - proxy::OP_MARGIN); put.second.refreshPutTimer->async_wait(std::bind(&DhtProxyClient::handleRefreshPut, this, @@ -1205,7 +1188,6 @@ DhtProxyClient::pushNotificationReceived(const std::map<std::string, std::string statusIpv4_ = NodeStatus::Connected; statusIpv6_ = NodeStatus::Connected; } - auto launchLoop = false; try { auto sessionId = notification.find("s"); if (sessionId != notification.end() and sessionId->second != pushSessionId_) { @@ -1224,7 +1206,7 @@ DhtProxyClient::pushNotificationReceived(const std::map<std::string, std::string auto vid = std::stoull(vidIt->second); auto& put = search.puts.at(vid); if (!put.refreshPutTimer) - put.refreshPutTimer = std::make_unique<asio::steady_timer>(httpContext_, std::chrono::steady_clock::now()); + put.refreshPutTimer = std::make_unique<asio::steady_timer>(context(), std::chrono::steady_clock::now()); else put.refreshPutTimer->expires_at(std::chrono::steady_clock::now()); put.refreshPutTimer->async_wait(std::bind(&DhtProxyClient::handleRefreshPut, this, std::placeholders::_1, key, vid)); @@ -1266,23 +1248,19 @@ DhtProxyClient::pushNotificationReceived(const std::map<std::string, std::string getline(ss, substr, ','); ids.emplace_back(std::stoull(substr)); } - { - std::lock_guard<std::mutex> lockCb(lockCallbacks_); - callbacks_.emplace_back([this, key, token, opstate, ids, sendTime]() { - if (opstate->stop) - return; - std::lock_guard<std::mutex> lock(searchLock_); - auto s = searches_.find(key); - if (s == searches_.end()) - return; - auto l = s->second.listeners.find(token); - if (l == s->second.listeners.end()) - return; - if (not opstate->stop and not l->second.cache.onValuesExpired(ids, sendTime)) - opstate->stop = true; - }); - } - launchLoop = true; + context().post([this, key, token, opstate, ids, sendTime]() { + if (opstate->stop) + return; + std::lock_guard<std::mutex> lock(searchLock_); + auto s = searches_.find(key); + if (s == searches_.end()) + return; + auto l = s->second.listeners.find(token); + if (l == s->second.listeners.end()) + return; + if (not opstate->stop and not l->second.cache.onValuesExpired(ids, sendTime)) + opstate->stop = true; + }); } } } @@ -1290,8 +1268,6 @@ DhtProxyClient::pushNotificationReceived(const std::map<std::string, std::string if (logger_) logger_->e("[proxy:client] [push] receive error: %s", e.what()); } - if (launchLoop) - loopSignal_(); #else (void) notification; #endif @@ -1314,7 +1290,7 @@ DhtProxyClient::resubscribe(const InfoHash& key, const size_t token, Listener& l header.method(restinio::http_method_subscribe()); header.request_target("/" + key.toString()); if (!listener.refreshSubscriberTimer){ - listener.refreshSubscriberTimer = std::make_unique<asio::steady_timer>(httpContext_); + listener.refreshSubscriberTimer = std::make_unique<asio::steady_timer>(context()); } listener.refreshSubscriberTimer->expires_at(std::chrono::steady_clock::now() + proxy::OP_TIMEOUT - proxy::OP_MARGIN); diff --git a/src/dht_proxy_server.cpp b/src/dht_proxy_server.cpp index 7ac9a48a..c48a1fa6 100644 --- a/src/dht_proxy_server.cpp +++ b/src/dht_proxy_server.cpp @@ -279,7 +279,7 @@ DhtProxyServer::DhtProxyServer(const std::shared_ptr<DhtRunner>& dht, ioContext_, std::forward<restinio::run_on_this_thread_settings_t<RestRouterTraitsTls>>(std::move(settings)) ); - // run http server + // run https server serverThread_ = std::thread([this]{ httpsServer_->open_async([]{/*ok*/}, [](std::exception_ptr ex){ std::rethrow_exception(ex); @@ -466,8 +466,17 @@ DhtProxyServer::io_context() const return *ioContext_; } -DhtProxyServer::~DhtProxyServer() +void +DhtProxyServer::stop() { + if (logger_) + logger_->d("[proxy:server] closing http server"); + ioContext_->stop(); + if (serverThread_.joinable()) + serverThread_.join(); + if (logger_) + logger_->d("[proxy:server] http server closed"); + if (not persistPath_.empty()) { if (logger_) logger_->d("Saving proxy state to %.*s", (int)persistPath_.size(), persistPath_.c_str()); @@ -495,13 +504,13 @@ DhtProxyServer::~DhtProxyServer() pushListeners_.clear(); #endif } - if (logger_) - logger_->d("[proxy:server] closing http server"); +} + +DhtProxyServer::~DhtProxyServer() +{ ioContext_->stop(); if (serverThread_.joinable()) serverThread_.join(); - if (logger_) - logger_->d("[proxy:server] http server closed"); } template< typename ServerSettings > diff --git a/src/dhtrunner.cpp b/src/dhtrunner.cpp index 449e0380..fc7c7972 100644 --- a/src/dhtrunner.cpp +++ b/src/dhtrunner.cpp @@ -24,7 +24,6 @@ #include "dhtrunner.h" #include "securedht.h" -#include "network_utils.h" #ifdef OPENDHT_PEER_DISCOVERY #include "peer_discovery.h" #endif @@ -65,24 +64,22 @@ DhtRunner::~DhtRunner() void DhtRunner::run(in_port_t port, Config& config, Context&& context) { - config.bind4.setFamily(AF_INET); - config.bind4.setPort(port); - config.bind6.setFamily(AF_INET6); - config.bind6.setPort(port); + config.bind4 = {asio::ip::address_v4::any(), port}; + config.bind6 = {asio::ip::address_v6::any(), port}; run(config, std::move(context)); } void DhtRunner::run(const char* ip4, const char* ip6, const char* service, Config& config, Context&& context) { - auto res4 = SockAddr::resolve(ip4, service); - auto res6 = SockAddr::resolve(ip6, service); - if (res4.empty()) - res4.emplace_back(); - if (res6.empty()) - res6.emplace_back(); - config.bind4 = std::move(res4.front()); - config.bind6 = std::move(res6.front()); + auto s = asio::ip::udp::resolver::query(service); + asio::ip::udp::resolver resolver(*context.ioContext); + auto res4 = resolver.resolve(ip4, service); + auto res6 = resolver.resolve(ip6, service); + if (not res4.empty()) + config.bind4 = std::move(*res4.begin()); + if (not res6.empty()) + config.bind6 = std::move(*res6.begin()); run(config, std::move(context)); } @@ -100,31 +97,25 @@ DhtRunner::run(const Config& config, Context&& context) try { auto local4 = config.bind4; auto local6 = config.bind6; - if (not local4 and not local6) { - if (context.logger) - context.logger->w("[runner %p] No address to bind specified in the configuration, using default addresses", this); - local4.setFamily(AF_INET); - local6.setFamily(AF_INET6); - } auto state_path = config.dht_config.node_config.persist_path; if (not state_path.empty()) state_path += "_port.txt"; - if (not state_path.empty() && (local4.getPort() == 0 || local6.getPort() == 0)) { + if (not state_path.empty() && (local4.port() == 0 || local6.port() == 0)) { std::ifstream inConfig(state_path); if (inConfig.is_open()) { in_port_t port; if (inConfig >> port) { - if (local4.getPort() == 0) { + if (local4.port() == 0) { if (context.logger) context.logger->d("[runner %p] Using IPv4 port %hu from saved configuration", this, port); - local4.setPort(port); + local4.port(port); } } if (inConfig >> port) { - if (local6.getPort() == 0) { + if (local6.port() == 0) { if (context.logger) context.logger->d("[runner %p] Using IPv6 port %hu from saved configuration", this, port); - local6.setPort(port); + local6.port(port); } } } @@ -140,34 +131,24 @@ DhtRunner::run(const Config& config, Context&& context) identityAnnouncedCb_ = context.identityAnnouncedCb; #endif + if (not context.ioContext) + context.ioContext.reset(new asio::io_context); + + ioContext_ = context.ioContext; + strand_ = std::make_shared<asio::io_context::strand>(asio::io_context::strand(*ioContext_)); + if (config.proxy_server.empty()) { if (not context.sock) { - context.sock.reset(new net::UdpSocket(local4, local6, context.logger)); + //if (context.logger) + // context.logger->d("[runner %p] Creating new socket with local addresses %s and %s", this, local4.toString().c_str(), local6.toString().c_str()); + context.sock.reset(new net::UdpSocket(strand_, local4, local6/*, context.logger*/)); } - context.sock->setOnReceive([&] (net::PacketList&& pkts) { - net::PacketList ret; - { - std::lock_guard<std::mutex> lck(sock_mtx); - rcv.splice(rcv.end(), std::move(pkts)); - size_t dropped = 0; - while (rcv.size() > net::RX_QUEUE_MAX_SIZE) { - rcv.pop_front(); - dropped++; - } - if (dropped and logger_) { - logger_->w("[runner %p] dropped %zu packets: queue is full!", this, dropped); - } - ret = std::move(rcv_free); - } - cv.notify_all(); - return ret; - }); if (not state_path.empty()) { std::ofstream outConfig(state_path); - outConfig << context.sock->getBoundRef(AF_INET).getPort() << std::endl; - outConfig << context.sock->getBoundRef(AF_INET6).getPort() << std::endl; + outConfig << context.sock->getPort(AF_INET) << std::endl; + outConfig << context.sock->getPort(AF_INET6) << std::endl; } - auto dht = std::make_unique<Dht>(std::move(context.sock), SecureDht::getConfig(config.dht_config), context.logger); + auto dht = std::make_unique<Dht>(strand_, std::move(context.sock), SecureDht::getConfig(config.dht_config), context.logger); dht_ = std::make_unique<SecureDht>(std::move(dht), config.dht_config, std::move(context.identityAnnouncedCb), context.logger); } else { enableProxy(true); @@ -180,8 +161,20 @@ DhtRunner::run(const Config& config, Context&& context) throw; } + dht_->addOnStateChangeCallback([this](const DhtNodeStatus& status){ + if (status.get() != NodeStatus::Connecting) { + auto pending = std::move(pending_ops); + while (not pending.empty()) { + pending.front()(*dht_); + pending.pop(); + } + } + }); + if (context.statusChangedCallback) { - statusCb = std::move(context.statusChangedCallback); + if (logger_) + logger_->d("[dhtrunner] starting io_context"); + dht_->addOnStateChangeCallback(std::move(context.statusChangedCallback)); } if (context.certificateStore) { dht_->setLocalCertificateStore(std::move(context.certificateStore)); @@ -190,32 +183,18 @@ DhtRunner::run(const Config& config, Context&& context) if (not config.threaded) return; dht_thread = std::thread([this]() { - while (running != State::Idle) { - std::unique_lock<std::mutex> lk(dht_mtx); - time_point wakeup = loop_(); - - auto hasJobToDo = [this]() { - if (running == State::Idle) - return true; - { - std::lock_guard<std::mutex> lck(sock_mtx); - if (not rcv.empty()) - return true; - } - { - std::lock_guard<std::mutex> lck(storage_mtx); - if (not pending_ops_prio.empty()) - return true; - auto s = getStatus(); - if (not pending_ops.empty() and (s == NodeStatus::Connected or s == NodeStatus::Disconnected)) - return true; - } - return false; - }; - if (wakeup == time_point::max()) - cv.wait(lk, hasJobToDo); - else - cv.wait_until(lk, wakeup, hasJobToDo); + try { + if (logger_) + logger_->d("[dhtrunner] starting io_context"); + // Ensures the context won't run out of work + auto work = asio::make_work_guard(*ioContext_); + ioContext_->run(); + if (logger_) + logger_->d("[dhtrunner] io_context stopped"); + } + catch(const std::exception& ex){ + if (logger_) + logger_->e("[dhtrunner] run error: %s", ex.what()); } }); @@ -234,7 +213,7 @@ DhtRunner::run(const Config& config, Context&& context) auto netId = config.dht_config.node_config.network; if (config.peer_discovery) { peerDiscovery_->startDiscovery<NodeInsertionPack>(PEER_DISCOVERY_DHT_SERVICE, [this, netId](NodeInsertionPack&& v, SockAddr&& addr){ - addr.setPort(v.port); + addr.port(v.port); if (v.nodeId != dht_->getNodeId() && netId == v.net){ bootstrap(v.nodeId, addr); } @@ -247,14 +226,16 @@ DhtRunner::run(const Config& config, Context&& context) adc.nodeId = dht_->getNodeId(); if (auto socket = dht_->getSocket()) { // IPv4 - if (const auto& bound4 = socket->getBoundRef(AF_INET)) { - adc.port = bound4.getPort(); + const auto& bound4 = socket->getBound(AF_INET); + if (bound4 != asio::ip::udp::endpoint{}) { + adc.port = bound4.port(); msgpack::pack(sbuf_node, adc); peerDiscovery_->startPublish(AF_INET, PEER_DISCOVERY_DHT_SERVICE, sbuf_node); } // IPv6 - if (const auto& bound6 = socket->getBoundRef(AF_INET6)) { - adc.port = bound6.getPort(); + const auto& bound6 = socket->getBound(AF_INET6); + if (bound6 != asio::ip::udp::endpoint{}) { + adc.port = bound6.port(); sbuf_node.clear(); msgpack::pack(sbuf_node, adc); peerDiscovery_->startPublish(AF_INET6, PEER_DISCOVERY_DHT_SERVICE, sbuf_node); @@ -267,15 +248,16 @@ DhtRunner::run(const Config& config, Context&& context) void DhtRunner::shutdown(ShutdownCallback cb, bool stop) { - std::unique_lock<std::mutex> lck(storage_mtx); + //std::unique_lock<std::mutex> lck(storage_mtx); auto expected = State::Running; if (not running.compare_exchange_strong(expected, State::Stopping)) { if (expected == State::Stopping and ongoing_ops) { if (cb) shutdownCallbacks_.emplace_back(std::move(cb)); + //asio::post(*ioContext_, [cb] { cb(); }); + // ioContext_->post([cb] { cb(); }); } else if (cb) { - lck.unlock(); cb(); } return; @@ -284,14 +266,28 @@ DhtRunner::shutdown(ShutdownCallback cb, bool stop) { logger_->d("[runner %p] state changed to Stopping, %zu ongoing ops", this, ongoing_ops.load()); ongoing_ops++; shutdownCallbacks_.emplace_back(std::move(cb)); - pending_ops.emplace([=](SecureDht&) mutable { + post([=]() mutable { auto onShutdown = [this]{ opEnded(); }; if (dht_) dht_->shutdown(onShutdown, stop); else opEnded(); }); - cv.notify_all(); + //cv.notify_all(); +} + +void +DhtRunner::postOp(std::function<void(SecureDht&)>&& op, bool prio) { + ongoing_ops++; + /*if (prio) + ioContext_->post(asio::bind_executor(*strand_, std::move(op))); + else*/ + ioContext_->post(asio::bind_executor(*strand_, [this, prio, op=std::move(op)]{ + if (dht_ && (prio || dht_->getStatus() != NodeStatus::Connecting)) + op(*dht_); + else + pending_ops.emplace(std::move(op)); + })); } void @@ -320,7 +316,6 @@ bool DhtRunner::checkShutdown() { decltype(shutdownCallbacks_) cbs; { - std::lock_guard<std::mutex> lck(storage_mtx); if (running != State::Stopping or ongoing_ops) return false; cbs = std::move(shutdownCallbacks_); @@ -342,31 +337,29 @@ DhtRunner::join() if (peerDiscovery_) peerDiscovery_->stop(); #endif - if (dht_) + /*if (dht_) if (auto sock = dht_->getSocket()) - sock->stop(); + sock->stop();*/ if (logger_) logger_->d("[runner %p] state changed to Idle", this); } - if (dht_thread.joinable()) + if (dht_thread.joinable()) { + ioContext_->stop(); dht_thread.join(); + } { std::lock_guard<std::mutex> lck(storage_mtx); if (ongoing_ops and logger_) { logger_->w("[runner %p] stopping with %zu remaining ops", this, ongoing_ops.load()); } - pending_ops = decltype(pending_ops)(); - pending_ops_prio = decltype(pending_ops_prio)(); ongoing_ops = 0; shutdownCallbacks_.clear(); } { std::lock_guard<std::mutex> lck(dht_mtx); resetDht(); - status4 = NodeStatus::Disconnected; - status6 = NodeStatus::Disconnected; } } @@ -384,7 +377,7 @@ DhtRunner::getBoundPort(sa_family_t af) const { std::lock_guard<std::mutex> lck(dht_mtx); if (dht_) if (auto sock = dht_->getSocket()) - return sock->getPort(af); + return sock->getBound(af).port(); return 0; } @@ -512,8 +505,8 @@ DhtRunner::getNodeInfo() const { info.ipv4 = dht_->getNodesStats(AF_INET); info.ipv6 = dht_->getNodesStats(AF_INET6); if (auto sock = dht_->getSocket()) { - info.bound4 = sock->getBoundRef(AF_INET).getPort(); - info.bound6 = sock->getBoundRef(AF_INET6).getPort(); + info.bound4 = sock->getBound(AF_INET).port(); + info.bound6 = sock->getBound(AF_INET6).port(); } } info.ongoing_ops = ongoing_ops; @@ -525,17 +518,17 @@ DhtRunner::getNodeInfo(std::function<void(std::shared_ptr<NodeInfo>)> cb) { std::lock_guard<std::mutex> lck(storage_mtx); ongoing_ops++; - pending_ops_prio.emplace([cb = std::move(cb), this](SecureDht& dht){ + post([cb = std::move(cb), this](){ auto sinfo = std::make_shared<NodeInfo>(); auto& info = *sinfo; - info.id = dht.getId(); - info.node_id = dht.getNodeId(); - info.ipv4 = dht.getNodesStats(AF_INET); - info.ipv6 = dht.getNodesStats(AF_INET6); - std::tie(info.storage_size, info.storage_values) = dht.getStoreSize(); - if (auto sock = dht.getSocket()) { - info.bound4 = sock->getBoundRef(AF_INET).getPort(); - info.bound6 = sock->getBoundRef(AF_INET6).getPort(); + info.id = dht_->getId(); + info.node_id = dht_->getNodeId(); + info.ipv4 = dht_->getNodesStats(AF_INET); + info.ipv6 = dht_->getNodesStats(AF_INET6); + std::tie(info.storage_size, info.storage_values) = dht_->getStoreSize(); + if (auto sock = dht_->getSocket()) { + info.bound4 = sock->getBound(AF_INET).port(); + info.bound6 = sock->getBound(AF_INET6).port(); } info.ongoing_ops = ongoing_ops; cb(std::move(sinfo)); @@ -594,7 +587,12 @@ DhtRunner::getPublicAddressStr(sa_family_t af) const { auto addrs = getPublicAddress(af); std::vector<std::string> ret(addrs.size()); - std::transform(addrs.begin(), addrs.end(), ret.begin(), [](const SockAddr& a) { return a.toString(); }); + std::transform(addrs.begin(), addrs.end(), ret.begin(), [](const SockAddr& a) { + return asio::ip::detail::endpoint(a.address(), a.port()).to_string(); + /*std::ostringstream ss; + ss << a; + return ss.str();*/ + }); return ret; } @@ -603,8 +601,8 @@ DhtRunner::getPublicAddress(std::function<void(std::vector<SockAddr>&&)> cb, sa_ { std::lock_guard<std::mutex> lck(storage_mtx); ongoing_ops++; - pending_ops_prio.emplace([cb = std::move(cb), this, af](SecureDht& dht){ - cb(dht.getPublicAddress(af)); + post([cb = std::move(cb), this, af](){ + cb(dht_->getPublicAddress(af)); opEnded(); }); cv.notify_all(); @@ -622,7 +620,7 @@ DhtRunner::setLocalCertificateStore(CertificateStoreQuery&& query_method) { if (dht_) dht_->setLocalCertificateStore(std::forward<CertificateStoreQuery>(query_method)); } - +/* time_point DhtRunner::loop_() { @@ -698,22 +696,24 @@ DhtRunner::loop_() } return wakeup; -} +}*/ void DhtRunner::get(InfoHash hash, GetCallback vcb, DoneCallback dcb, Value::Filter f, Where w) { - std::unique_lock<std::mutex> lck(storage_mtx); if (running != State::Running) { - lck.unlock(); if (dcb) dcb(false, {}); return; } - ongoing_ops++; - pending_ops.emplace([=](SecureDht& dht) mutable { - dht.get(hash, std::move(vcb), bindOpDoneCallback(std::move(dcb)), std::move(f), std::move(w)); + postOp([ + hash, + vcb = std::move(vcb), + dcb=bindOpDoneCallback(std::move(dcb)), + f = std::move(f), + w = std::move(w) + ](SecureDht& dht) mutable { + dht.get(hash, std::move(vcb), std::move(dcb), std::move(f), std::move(w)); }); - cv.notify_all(); } void @@ -729,11 +729,9 @@ DhtRunner::query(const InfoHash& hash, QueryCallback cb, DoneCallback done_cb, Q if (done_cb) done_cb(false, {}); return; } - ongoing_ops++; - pending_ops.emplace([=](SecureDht& dht) mutable { + postOp([=](SecureDht& dht) mutable { dht.query(hash, std::move(cb), bindOpDoneCallback(std::move(done_cb)), std::move(q)); }); - cv.notify_all(); } std::future<size_t> @@ -746,8 +744,8 @@ DhtRunner::listen(InfoHash hash, ValueCallback vcb, Value::Filter f, Where w) ret_token->set_value(0); return ret_token->get_future(); } - pending_ops.emplace([=](SecureDht& dht) mutable { - ret_token->set_value(dht.listen(hash, std::move(vcb), std::move(f), std::move(w))); + post([=]() mutable { + ret_token->set_value(dht_->listen(hash, std::move(vcb), std::move(f), std::move(w))); }); cv.notify_all(); return ret_token->get_future(); @@ -765,26 +763,21 @@ DhtRunner::cancelListen(InfoHash h, size_t token) std::lock_guard<std::mutex> lck(storage_mtx); if (running != State::Running) return; - ongoing_ops++; - pending_ops.emplace([=](SecureDht& dht) { + postOp([=](SecureDht& dht) { dht.cancelListen(h, token); opEnded(); }); - cv.notify_all(); } void DhtRunner::cancelListen(InfoHash h, std::shared_future<size_t> ftoken) { - std::lock_guard<std::mutex> lck(storage_mtx); if (running != State::Running) return; - ongoing_ops++; - pending_ops.emplace([this, h, ftoken = std::move(ftoken)](SecureDht& dht) { + postOp([this, h, ftoken = std::move(ftoken)](SecureDht& dht) { dht.cancelListen(h, ftoken.get()); opEnded(); }); - cv.notify_all(); } void @@ -796,14 +789,12 @@ DhtRunner::put(InfoHash hash, Value&& value, DoneCallback cb, time_point created if (cb) cb(false, {}); return; } - ongoing_ops++; - pending_ops.emplace([=, + postOp([=, cb = std::move(cb), sv = std::make_shared<Value>(std::move(value)) ] (SecureDht& dht) mutable { dht.put(hash, sv, bindOpDoneCallback(std::move(cb)), created, permanent); }); - cv.notify_all(); } void @@ -815,11 +806,9 @@ DhtRunner::put(InfoHash hash, std::shared_ptr<Value> value, DoneCallback cb, tim if (cb) cb(false, {}); return; } - ongoing_ops++; - pending_ops.emplace([=, value = std::move(value), cb = std::move(cb)](SecureDht& dht) mutable { + postOp([=, value = std::move(value), cb = std::move(cb)](SecureDht& dht) mutable { dht.put(hash, value, bindOpDoneCallback(std::move(cb)), created, permanent); }); - cv.notify_all(); } void @@ -832,20 +821,18 @@ void DhtRunner::cancelPut(const InfoHash& h, Value::Id id) { std::lock_guard<std::mutex> lck(storage_mtx); - pending_ops.emplace([=](SecureDht& dht) { - dht.cancelPut(h, id); + post([=]() { + dht_->cancelPut(h, id); }); - cv.notify_all(); } void DhtRunner::cancelPut(const InfoHash& h, const std::shared_ptr<Value>& value) { std::lock_guard<std::mutex> lck(storage_mtx); - pending_ops.emplace([=](SecureDht& dht) { - dht.cancelPut(h, value->id); + post([=]() { + dht_->cancelPut(h, value->id); }); - cv.notify_all(); } void @@ -857,14 +844,12 @@ DhtRunner::putSigned(InfoHash hash, std::shared_ptr<Value> value, DoneCallback c if (cb) cb(false, {}); return; } - ongoing_ops++; - pending_ops.emplace([=, + postOp([=, cb = std::move(cb), value = std::move(value) ](SecureDht& dht) mutable { dht.putSigned(hash, value, bindOpDoneCallback(std::move(cb)), permanent); }); - cv.notify_all(); } void @@ -888,14 +873,12 @@ DhtRunner::putEncrypted(InfoHash hash, InfoHash to, std::shared_ptr<Value> value if (cb) cb(false, {}); return; } - ongoing_ops++; - pending_ops.emplace([=, + postOp([=, cb = std::move(cb), value = std::move(value) ] (SecureDht& dht) mutable { dht.putEncrypted(hash, to, value, bindOpDoneCallback(std::move(cb)), permanent); }); - cv.notify_all(); } void @@ -920,7 +903,7 @@ DhtRunner::putEncrypted(InfoHash hash, const std::shared_ptr<crypto::PublicKey>& return; } ongoing_ops++; - pending_ops.emplace([=, + postOp([=, cb = std::move(cb), value = std::move(value) ] (SecureDht& dht) mutable { @@ -939,8 +922,8 @@ void DhtRunner::bootstrap(const std::string& host, const std::string& service) { std::lock_guard<std::mutex> lck(storage_mtx); - pending_ops_prio.emplace([host, service] (SecureDht& dht) mutable { - dht.addBootstrap(host, service); + post([this, host, service] () mutable { + dht_->addBootstrap(host, service); }); cv.notify_all(); } @@ -949,8 +932,8 @@ void DhtRunner::bootstrap(const std::string& hostService) { std::lock_guard<std::mutex> lck(storage_mtx); - pending_ops_prio.emplace([host_service = splitPort(hostService)] (SecureDht& dht) mutable { - dht.addBootstrap(host_service.first, host_service.second); + post([this, host_service = splitPort(hostService)] () mutable { + dht_->addBootstrap(host_service.first, host_service.second); }); cv.notify_all(); } @@ -959,8 +942,8 @@ void DhtRunner::clearBootstrap() { std::lock_guard<std::mutex> lck(storage_mtx); - pending_ops_prio.emplace([] (SecureDht& dht) mutable { - dht.clearBootstrap(); + post([this] () mutable { + dht_->clearBootstrap(); }); cv.notify_all(); } @@ -974,15 +957,15 @@ DhtRunner::bootstrap(std::vector<SockAddr> nodes, DoneCallbackSimple cb) } std::lock_guard<std::mutex> lck(storage_mtx); ongoing_ops++; - pending_ops_prio.emplace([ + post([this, cb = bindOpDoneCallback(std::move(cb)), nodes = std::move(nodes) - ] (SecureDht& dht) mutable { + ] () mutable { auto rem = cb ? std::make_shared<std::pair<size_t, bool>>(nodes.size(), false) : nullptr; for (auto& node : nodes) { - if (node.getPort() == 0) - node.setPort(net::DHT_DEFAULT_PORT); - dht.pingNode(std::move(node), [rem,cb](bool ok) { + if (node.port() == 0) + node.port(net::DHT_DEFAULT_PORT); + dht_->pingNode(std::move(node), [rem,cb](bool ok) { auto& r = *rem; r.first--; r.second |= ok; @@ -1005,8 +988,10 @@ DhtRunner::bootstrap(SockAddr addr, DoneCallbackSimple cb) return; } ongoing_ops++; - pending_ops_prio.emplace([addr = std::move(addr), cb = bindOpDoneCallback(std::move(cb))](SecureDht& dht) mutable { - dht.pingNode(std::move(addr), std::move(cb)); + if (addr.address().is_unspecified()) + addr.address(addr.address().is_v4() ? asio::ip::address(asio::ip::address_v4::loopback()) : asio::ip::address(asio::ip::address_v6::loopback())); + post([this, addr = std::move(addr), cb = bindOpDoneCallback(std::move(cb))]() mutable { + dht_->pingNode(std::move(addr), std::move(cb)); }); cv.notify_all(); } @@ -1017,8 +1002,8 @@ DhtRunner::bootstrap(const InfoHash& id, const SockAddr& address) std::lock_guard<std::mutex> lck(storage_mtx); if (running != State::Running) return; - pending_ops_prio.emplace([id, address](SecureDht& dht) mutable { - dht.insertNode(id, address); + post([this, id, address]() mutable { + dht_->insertNode(id, address); }); cv.notify_all(); } @@ -1029,9 +1014,9 @@ DhtRunner::bootstrap(std::vector<NodeExport> nodes) std::lock_guard<std::mutex> lck(storage_mtx); if (running != State::Running) return; - pending_ops_prio.emplace([nodes = std::move(nodes)](SecureDht& dht) { + post([this, nodes = std::move(nodes)]() { for (auto& node : nodes) - dht.insertNode(node); + dht_->insertNode(node); }); cv.notify_all(); } @@ -1040,8 +1025,8 @@ void DhtRunner::connectivityChanged() { std::lock_guard<std::mutex> lck(storage_mtx); - pending_ops_prio.emplace([=](SecureDht& dht) { - dht.connectivityChanged(); + post([=]() { + dht_->connectivityChanged(); #ifdef OPENDHT_PEER_DISCOVERY if (peerDiscovery_) peerDiscovery_->connectivityChanged(); @@ -1058,14 +1043,12 @@ DhtRunner::findCertificate(InfoHash hash, std::function<void(const Sp<crypto::Ce cb({}); return; } - ongoing_ops++; - pending_ops.emplace([this, hash, cb = std::move(cb)] (SecureDht& dht) { + postOp([this, hash, cb = std::move(cb)] (SecureDht& dht) { dht.findCertificate(hash, [this, cb = std::move(cb)](const Sp<crypto::Certificate>& crt){ cb(crt); opEnded(); }); }); - cv.notify_all(); } void @@ -1104,13 +1087,7 @@ DhtRunner::enableProxy(bool proxify) auto dht_via_proxy = std::make_unique<DhtProxyClient>( config_.server_ca, config_.client_identity, - [this]{ - if (config_.threaded) { - std::lock_guard<std::mutex> lck(storage_mtx); - pending_ops_prio.emplace([=](SecureDht&) mutable {}); - cv.notify_all(); - } - }, + strand_, config_.proxy_server, config_.push_node_id, logger_); if (not config_.push_token.empty()) dht_via_proxy->setPushNotificationToken(config_.push_token); @@ -1182,7 +1159,7 @@ DhtRunner::pushNotificationReceived(const std::map<std::string, std::string>& da { #if defined(OPENDHT_PROXY_CLIENT) && defined(OPENDHT_PUSH_NOTIFICATIONS) std::lock_guard<std::mutex> lck(storage_mtx); - pending_ops_prio.emplace([=](SecureDht&) { + post([=]() { if (dht_) dht_->pushNotificationReceived(data); }); diff --git a/src/network_engine.cpp b/src/network_engine.cpp index a92915f8..43d3d3c7 100644 --- a/src/network_engine.cpp +++ b/src/network_engine.cpp @@ -53,6 +53,8 @@ struct NetworkEngine::PartialMessage { time_point start; time_point last_part; std::unique_ptr<ParsedMessage> msg; + std::unique_ptr<asio::steady_timer> timeout; + std::unique_ptr<asio::steady_timer> final_timeout; }; std::vector<Blob> @@ -83,10 +85,10 @@ RequestAnswer::RequestAnswer(ParsedMessage&& msg) {} NetworkEngine::NetworkEngine(InfoHash& myid, NetworkConfig c, + const Sp<strand>& strand, std::unique_ptr<DatagramSocket>&& sock, const Sp<Logger>& log, std::mt19937_64& rand, - Scheduler& scheduler, decltype(NetworkEngine::onError)&& onError, decltype(NetworkEngine::onNewNode)&& onNewNode, decltype(NetworkEngine::onReportedAddr)&& onReportedAddr, @@ -96,6 +98,7 @@ NetworkEngine::NetworkEngine(InfoHash& myid, NetworkConfig c, decltype(NetworkEngine::onListen)&& onListen, decltype(NetworkEngine::onAnnounce)&& onAnnounce, decltype(NetworkEngine::onRefresh)&& onRefresh) : + strand_(strand), onError(std::move(onError)), onNewNode(std::move(onNewNode)), onReportedAddr(std::move(onReportedAddr)), @@ -107,9 +110,12 @@ NetworkEngine::NetworkEngine(InfoHash& myid, NetworkConfig c, onRefresh(std::move(onRefresh)), myid(myid), config(c), dht_socket(std::move(sock)), logger_(log), rd(rand), cache(rd), - rate_limiter(config.max_req_per_sec), - scheduler(scheduler) -{} + rate_limiter(config.max_req_per_sec) +{ + dht_socket->setOnReceive([this](const ReceivedPacket& pkt) { + processMessage(pkt.data.data(), pkt.data.size(), pkt.from); + }); +} NetworkEngine::~NetworkEngine() { clear(); @@ -124,7 +130,7 @@ NetworkEngine::tellListener(const Sp<Node>& node, Tid socket_id, const InfoHash& auto nnodes = bufferNodes(node->getFamily(), hash, want, nodes, nodes6); try { if (version >= 1) { - sendUpdateValues(node, hash, std::move(values), scheduler.time(), ntoken, socket_id); + sendUpdateValues(node, hash, std::move(values), time(), ntoken, socket_id); } else { sendNodesValues(node->getAddr(), socket_id, nnodes.first, nnodes.second, std::move(values), query, ntoken); } @@ -275,7 +281,7 @@ NetworkEngine::requestStep(Sp<Request> sreq) if (not req.pending()) return; - auto now = scheduler.time(); + const auto& now = time(); auto& node = *req.node; if (req.isExpired(now)) { // if (logger_) @@ -289,18 +295,19 @@ NetworkEngine::requestStep(Sp<Request> sreq) } auto err = send(node.getAddr(), (char*)req.msg.data(), req.msg.size(), node.getReplyTime() < now - UDP_REPLY_TIME); - if (err == ENETUNREACH || - err == EHOSTUNREACH || - err == EAFNOSUPPORT || - err == EPIPE || - err == EPERM) + + if (err == asio::error::network_unreachable || + err == asio::error::host_unreachable || + err == asio::error::address_family_not_supported || + err == asio::error::broken_pipe || + err == asio::error::no_permission) { node.setExpired(); if (not node.id) requests.erase(req.tid); } else { req.last_try = now; - if (err != EAGAIN) { + if (err != asio::error::try_again) { ++req.attempt_count; req.attempt_duration += req.attempt_duration + uniform_duration_distribution<>(0ms, ((duration)Node::MAX_RESPONSE_TIME)/4)(rd); @@ -309,9 +316,15 @@ NetworkEngine::requestStep(Sp<Request> sreq) } } std::weak_ptr<Request> wreq = sreq; - scheduler.add(req.last_try + req.attempt_duration, [this,wreq] { - if (auto req = wreq.lock()) - requestStep(req); + if (!req.expiration_timer) + req.expiration_timer = std::make_unique<asio::steady_timer>(context()); + req.expiration_timer->expires_at(req.last_try + req.attempt_duration); + req.expiration_timer->async_wait([this,wreq](const asio::error_code& ec) { + if (ec != asio::error::operation_aborted) { + syncTime(); + if (auto req = wreq.lock()) + requestStep(req); + } }); } } @@ -326,7 +339,7 @@ NetworkEngine::sendRequest(const Sp<Request>& request) auto& node = *request->node; if (not node.id) requests.emplace(request->tid, request); - request->start = scheduler.time(); + request->start = time(); node.requested(request); requestStep(request); } @@ -336,7 +349,7 @@ NetworkEngine::sendRequest(const Sp<Request>& request) bool NetworkEngine::rateLimit(const SockAddr& addr) { - const auto& now = scheduler.time(); + const auto& now = time(); // occasional IP limiter maintenance (a few times every second at max rate) if (limiter_maintenance++ == config.max_peer_req_per_sec) { @@ -360,24 +373,20 @@ NetworkEngine::rateLimit(const SockAddr& addr) bool NetworkEngine::isMartian(const SockAddr& addr) { - if (addr.getPort() == 0) + if (addr.port() == 0) return true; - switch(addr.getFamily()) { + switch(addr.protocol().family()) { case AF_INET: { - const auto& sin = addr.getIPv4(); - const uint8_t* address = (const uint8_t*)&sin.sin_addr; - return (address[0] == 0) || - ((address[0] & 0xE0) == 0xE0); + const auto& sin = addr.address().to_v4().to_bytes(); + return (sin[0] == 0) || + ((sin[0] & 0xE0) == 0xE0); } case AF_INET6: { - if (addr.getLength() < sizeof(sockaddr_in6)) - return true; - const auto& sin6 = addr.getIPv6(); - const uint8_t* address = (const uint8_t*)&sin6.sin6_addr; - return address[0] == 0xFF || - (address[0] == 0xFE && (address[1] & 0xC0) == 0x80) || - memcmp(address, InfoHash::zero().data(), 16) == 0 || - memcmp(address, v4prefix, 12) == 0; + const auto& sin6 = addr.address().to_v6().to_bytes(); + return sin6[0] == 0xFF || + (sin6[0] == 0xFE && (sin6[1] & 0xC0) == 0x80) || + memcmp(sin6.data(), InfoHash::zero().data(), 16) == 0 || + memcmp(sin6.data(), v4prefix, 12) == 0; } default: return true; @@ -402,16 +411,17 @@ NetworkEngine::isNodeBlacklisted(const SockAddr& addr) const void NetworkEngine::processMessage(const uint8_t *buf, size_t buflen, SockAddr f) { - auto from = f.getMappedIPv4(); + syncTime(); + const auto& from = f;/*.getMappedIPv4(); if (isMartian(from)) { if (logger_) logger_->w("Received packet from martian node %s", from.toString().c_str()); return; - } + }*/ if (isNodeBlacklisted(from)) { - if (logger_) - logger_->w("Received packet from blacklisted node %s", from.toString().c_str()); + //if (logger_) + // logger_->w("Received packet from blacklisted node %s", from.toString().c_str()); return; } @@ -433,7 +443,7 @@ NetworkEngine::processMessage(const uint8_t *buf, size_t buflen, SockAddr f) return; } - const auto& now = scheduler.time(); + const auto& now = time(); // partial value data if (msg->type == MessageType::ValueData) { @@ -445,7 +455,7 @@ NetworkEngine::processMessage(const uint8_t *buf, size_t buflen, SockAddr f) rateLimit(from); return; } - if (!pmsg_it->second.from.equals(from)) { + if (pmsg_it->second.from != from) { if (logger_) logger_->d("Received partial message data from unexpected IP address"); rateLimit(from); @@ -463,8 +473,13 @@ NetworkEngine::processMessage(const uint8_t *buf, size_t buflen, SockAddr f) } catch (...) { return; } - } else - scheduler.add(now + RX_TIMEOUT, std::bind(&NetworkEngine::maintainRxBuffer, this, msg->tid)); + } else { + pmsg_it->second.timeout->expires_at(now + RX_TIMEOUT); + pmsg_it->second.timeout->async_wait([this, t=msg->tid](const asio::error_code& ec) { + if (ec != asio::error::operation_aborted) + maintainRxBuffer(t); + }); + } } return; } @@ -499,8 +514,18 @@ NetworkEngine::processMessage(const uint8_t *buf, size_t buflen, SockAddr f) pmsg.msg = std::move(msg); pmsg.start = now; pmsg.last_part = now; - scheduler.add(now + RX_MAX_PACKET_TIME, std::bind(&NetworkEngine::maintainRxBuffer, this, k)); - scheduler.add(now + RX_TIMEOUT, std::bind(&NetworkEngine::maintainRxBuffer, this, k)); + //scheduler.add(now + RX_MAX_PACKET_TIME, std::bind(&NetworkEngine::maintainRxBuffer, this, k)); + //scheduler.add(now + RX_TIMEOUT, std::bind(&NetworkEngine::maintainRxBuffer, this, k)); + pmsg.timeout = std::make_unique<asio::steady_timer>(context(), now + RX_TIMEOUT); + pmsg.final_timeout = std::make_unique<asio::steady_timer>(context(), now + RX_MAX_PACKET_TIME); + pmsg.timeout->async_wait([this,k](const asio::error_code& ec) { + if (ec != asio::error::operation_aborted) + maintainRxBuffer(k); + }); + pmsg.final_timeout->async_wait([this,k](const asio::error_code& ec) { + if (ec != asio::error::operation_aborted) + maintainRxBuffer(k); + }); } else if (logger_) logger_->e("Partial message with given TID %u already exists", k); } @@ -509,7 +534,7 @@ NetworkEngine::processMessage(const uint8_t *buf, size_t buflen, SockAddr f) void NetworkEngine::process(std::unique_ptr<ParsedMessage>&& msg, const SockAddr& from) { - const auto& now = scheduler.time(); + const auto& now = time(); auto node = cache.getNode(msg->id, from, now, true, msg->is_client); if (msg->type == MessageType::ValueUpdate) { @@ -566,10 +591,10 @@ NetworkEngine::process(std::unique_ptr<ParsedMessage>&& msg, const SockAddr& fro if (not req->setError(DhtProtocolException {msg->error_code})) onError(req, DhtProtocolException {msg->error_code}); } else { - if (logIncoming_) + /*if (logIncoming_) if (logger_) logger_->w(msg->id, "[node %s %s] received unknown error message %u", - msg->id.toString().c_str(), from.toString().c_str(), msg->error_code); + msg->id.toString().c_str(), from.toString().c_str(), msg->error_code);*/ } break; } @@ -581,7 +606,7 @@ NetworkEngine::process(std::unique_ptr<ParsedMessage>&& msg, const SockAddr& fro or r.getType() == MessageType::Refresh) { r.node->authSuccess(); } - r.reply_time = scheduler.time(); + r.reply_time = time(); deserializeNodes(*msg, from); r.setDone(std::move(*msg)); @@ -613,7 +638,7 @@ NetworkEngine::process(std::unique_ptr<ParsedMessage>&& msg, const SockAddr& fro // logger_->d(msg->target, node->id, "[node %s] got 'find' request for %s (%d)", node->toString().c_str(), msg->target.toString().c_str(), msg->want); ++in_stats.find; RequestAnswer answer = onFindNode(node, msg->target, msg->want); - auto nnodes = bufferNodes(from.getFamily(), msg->target, msg->want, answer.nodes4, answer.nodes6); + auto nnodes = bufferNodes(from.protocol().family(), msg->target, msg->want, answer.nodes4, answer.nodes6); sendNodesValues(from, msg->tid, nnodes.first, nnodes.second, {}, {}, answer.ntoken); break; } @@ -622,7 +647,7 @@ NetworkEngine::process(std::unique_ptr<ParsedMessage>&& msg, const SockAddr& fro // logger_->d(msg->info_hash, node->id, "[node %s] got 'get' request for %s", node->toString().c_str(), msg->info_hash.toString().c_str()); ++in_stats.get; RequestAnswer answer = onGetValues(node, msg->info_hash, msg->want, msg->query); - auto nnodes = bufferNodes(from.getFamily(), msg->info_hash, msg->want, answer.nodes4, answer.nodes6); + auto nnodes = bufferNodes(from.protocol().family(), msg->info_hash, msg->want, answer.nodes4, answer.nodes6); sendNodesValues(from, msg->tid, nnodes.first, nnodes.second, answer.values, msg->query, answer.ntoken); break; } @@ -652,7 +677,7 @@ NetworkEngine::process(std::unique_ptr<ParsedMessage>&& msg, const SockAddr& fro logger_->d(msg->info_hash, node->id, "[node %s] got 'listen' request for %s", node->toString().c_str(), msg->info_hash.toString().c_str()); ++in_stats.listen; RequestAnswer answer = onListen(node, msg->info_hash, msg->token, msg->socket_id, std::move(msg->query), msg->version); - auto nnodes = bufferNodes(from.getFamily(), msg->info_hash, msg->want, answer.nodes4, answer.nodes6); + auto nnodes = bufferNodes(from.protocol().family(), msg->info_hash, msg->want, answer.nodes4, answer.nodes6); sendListenConfirmation(from, msg->tid); break; } @@ -682,19 +707,23 @@ NetworkEngine::process(std::unique_ptr<ParsedMessage>&& msg, const SockAddr& fro void insertAddr(msgpack::packer<msgpack::sbuffer>& pk, const SockAddr& addr) { - size_t addr_len = std::min<size_t>(addr.getLength(), - (addr.getFamily() == AF_INET) ? sizeof(in_addr) : sizeof(in6_addr)); - void* addr_ptr = (addr.getFamily() == AF_INET) ? (void*)&addr.getIPv4().sin_addr - : (void*)&addr.getIPv6().sin6_addr; pk.pack("sa"); - pk.pack_bin(addr_len); - pk.pack_bin_body((char*)addr_ptr, addr_len); + const auto& a = addr.address(); + if (a.is_v4()) { + auto bytes = a.to_v4().to_bytes(); + pk.pack_bin(bytes.size()); + pk.pack_bin_body((char*)bytes.data(), bytes.size()); + } else { + auto bytes = a.to_v6().to_bytes(); + pk.pack_bin(bytes.size()); + pk.pack_bin_body((char*)bytes.data(), bytes.size()); + } } -int +asio::error_code NetworkEngine::send(const SockAddr& addr, const char *buf, size_t len, bool confirmed) { - return dht_socket ? dht_socket->sendTo(addr, (const uint8_t*)buf, len, confirmed) : ENOTCONN; + return dht_socket ? dht_socket->sendTo((const uint8_t*)buf, len, addr) : asio::error::not_connected; } Sp<Request> @@ -854,18 +883,15 @@ NetworkEngine::sendGetValues(const Sp<Node>& n, const InfoHash& info_hash, const SockAddr deserializeIPv4(const uint8_t* ni) { SockAddr addr; - addr.setFamily(AF_INET); - auto& sin = addr.getIPv4(); - std::memcpy(&sin.sin_addr, ni, 4); - std::memcpy(&sin.sin_port, ni + 4, 2); + addr.address(asio::ip::address_v4({ni[0], ni[1], ni[2], ni[3]})); + addr.port(ntohs(*reinterpret_cast<const uint16_t*>(ni + 4))); return addr; } SockAddr deserializeIPv6(const uint8_t* ni) { SockAddr addr; - addr.setFamily(AF_INET6); - auto& sin6 = addr.getIPv6(); - std::memcpy(&sin6.sin6_addr, ni, 16); - std::memcpy(&sin6.sin6_port, ni + 16, 2); + addr.address(asio::ip::address_v6({ni[0], ni[1], ni[2], ni[3], ni[4], ni[5], ni[6], ni[7], + ni[8], ni[9], ni[10], ni[11], ni[12], ni[13], ni[14], ni[15]})); + addr.port(ntohs(*reinterpret_cast<const uint16_t*>(ni + 16))); return addr; } @@ -875,21 +901,21 @@ NetworkEngine::deserializeNodes(ParsedMessage& msg, const SockAddr& from) { throw DhtProtocolException {DhtProtocolException::WRONG_NODE_INFO_BUF_LEN}; } // deserialize nodes - const auto& now = scheduler.time(); + const auto& now = time(); for (unsigned i = 0, n = msg.nodes4_raw.size() / NODE4_INFO_BUF_LEN; i < n; i++) { const uint8_t* ni = msg.nodes4_raw.data() + i * NODE4_INFO_BUF_LEN; const auto& ni_id = *reinterpret_cast<const InfoHash*>(ni); if (ni_id == myid) continue; SockAddr addr = deserializeIPv4(ni + ni_id.size()); - if (addr.isLoopback() and from.getFamily() == AF_INET) { - auto port = addr.getPort(); + if (addr.address().is_loopback() and from.protocol().family() == AF_INET) { + auto port = addr.port(); addr = from; - addr.setPort(port); + addr.port(port); } if (isMartian(addr) || isNodeBlacklisted(addr)) continue; - msg.nodes4.emplace_back(cache.getNode(ni_id, addr, now, false)); + msg.nodes4.emplace_back(cache.getNode(ni_id, std::move(addr), now, false)); onNewNode(msg.nodes4.back(), 0); } for (unsigned i = 0, n = msg.nodes6_raw.size() / NODE6_INFO_BUF_LEN; i < n; i++) { @@ -898,14 +924,14 @@ NetworkEngine::deserializeNodes(ParsedMessage& msg, const SockAddr& from) { if (ni_id == myid) continue; SockAddr addr = deserializeIPv6(ni + ni_id.size()); - if (addr.isLoopback() and from.getFamily() == AF_INET6) { - auto port = addr.getPort(); + if (addr.address().is_loopback() and from.protocol().family() == AF_INET6) { + auto port = addr.port(); addr = from; - addr.setPort(port); + addr.port(port); } if (isMartian(addr) || isNodeBlacklisted(addr)) continue; - msg.nodes6.emplace_back(cache.getNode(ni_id, addr, now, false)); + msg.nodes6.emplace_back(cache.getNode(ni_id, std::move(addr), now, false)); onNewNode(msg.nodes6.back(), 0); } } @@ -1037,21 +1063,25 @@ NetworkEngine::bufferNodes(sa_family_t af, const InfoHash& id, std::vector<Sp<No bnodes.resize(NODE4_INFO_BUF_LEN * nnode); for (size_t i=0; i<nnode; i++) { const Node& n = *nodes[i]; - const auto& sin = n.getAddr().getIPv4(); + const auto& sin = n.getAddr(); + auto port = htons(sin.port()); + auto addr = sin.address().to_v4().to_bytes(); auto dest = bnodes.data() + NODE4_INFO_BUF_LEN * i; memcpy(dest, n.id.data(), HASH_LEN); - memcpy(dest + HASH_LEN, &sin.sin_addr, sizeof(in_addr)); - memcpy(dest + HASH_LEN + sizeof(in_addr), &sin.sin_port, sizeof(in_port_t)); + memcpy(dest + HASH_LEN, addr.data(), sizeof(in_addr)); + memcpy(dest + HASH_LEN + sizeof(in_addr), &port, sizeof(in_port_t)); } } else if (af == AF_INET6) { bnodes.resize(NODE6_INFO_BUF_LEN * nnode); for (size_t i=0; i<nnode; i++) { const Node& n = *nodes[i]; - const auto& sin6 = n.getAddr().getIPv6(); + const auto& sin6 = n.getAddr(); + auto port = htons(sin6.port()); + auto addr = sin6.address().to_v6().to_bytes(); auto dest = bnodes.data() + NODE6_INFO_BUF_LEN * i; memcpy(dest, n.id.data(), HASH_LEN); - memcpy(dest + HASH_LEN, &sin6.sin6_addr, sizeof(in6_addr)); - memcpy(dest + HASH_LEN + sizeof(in6_addr), &sin6.sin6_port, sizeof(in_port_t)); + memcpy(dest + HASH_LEN, addr.data(), sizeof(in6_addr)); + memcpy(dest + HASH_LEN + sizeof(in6_addr), &port, sizeof(in_port_t)); } } return bnodes; @@ -1158,7 +1188,7 @@ NetworkEngine::sendAnnounceValue(const Sp<Node>& n, msgpack::packer<msgpack::sbuffer> pk(&buffer); pk.pack_map(5+(config.network?1:0)); - bool add_created = created < scheduler.time(); + bool add_created = created < time(); pk.pack(KEY_A); pk.pack_map(add_created ? 5 : 4); pk.pack(KEY_REQ_ID); pk.pack(myid); pk.pack(KEY_REQ_H); pk.pack(infohash); @@ -1244,13 +1274,13 @@ NetworkEngine::sendUpdateValues(const Sp<Node>& n, msgpack::packer<msgpack::sbuffer> pk(&buffer); pk.pack_map(5+(config.network?1:0)); - pk.pack(KEY_A); pk.pack_map((created < scheduler.time() ? 7 : 6)); + pk.pack(KEY_A); pk.pack_map((created < time() ? 7 : 6)); pk.pack(KEY_REQ_ID); pk.pack(myid); pk.pack(KEY_VERSION); pk.pack(1); pk.pack(KEY_REQ_H); pk.pack(infohash); pk.pack(KEY_REQ_SID); pk.pack(sid); auto v = packValueHeader(buffer, begin, end); - if (created < scheduler.time()) { + if (created < time()) { pk.pack(KEY_REQ_CREATION); pk.pack(to_time_t(created)); } @@ -1385,11 +1415,11 @@ NetworkEngine::maintainRxBuffer(Tid tid) { auto msg = partial_messages.find(tid); if (msg != partial_messages.end()) { - const auto& now = scheduler.time(); + const auto& now = time(); if (msg->second.start + RX_MAX_PACKET_TIME < now || msg->second.last_part + RX_TIMEOUT < now) { - if (logger_) - logger_->w("Dropping expired partial message from %s", msg->second.from.toString().c_str()); + //if (logger_) + // logger_->w("Dropping expired partial message from %s", msg->second.from.toString().c_str()); partial_messages.erase(msg); } } diff --git a/src/network_utils.cpp b/src/network_utils.cpp index 8f523b8c..a7bb8ac8 100644 --- a/src/network_utils.cpp +++ b/src/network_utils.cpp @@ -34,330 +34,84 @@ namespace dht { namespace net { +/* -int -bindSocket(const SockAddr& addr, SockAddr& bound) -{ - bool is_ipv6 = addr.getFamily() == AF_INET6; - int sock = socket(is_ipv6 ? PF_INET6 : PF_INET, SOCK_DGRAM, 0); - if (sock < 0) - throw DhtException(std::string("Can't open socket: ") + strerror(sock)); - int set = 1; -#ifdef SO_NOSIGPIPE - setsockopt(sock, SOL_SOCKET, SO_NOSIGPIPE, (const char*)&set, sizeof(set)); -#endif - if (is_ipv6) - setsockopt(sock, IPPROTO_IPV6, IPV6_V6ONLY, (const char*)&set, sizeof(set)); - net::setNonblocking(sock); - int rc = bind(sock, addr.get(), addr.getLength()); - if (rc < 0) { - rc = errno; - close(sock); - throw DhtException("Can't bind socket on " + addr.toString() + " " + strerror(rc)); - } - sockaddr_storage ss; - socklen_t ss_len = sizeof(ss); - getsockname(sock, (sockaddr*)&ss, &ss_len); - bound = {ss, ss_len}; - return sock; -} - -bool -setNonblocking(int fd, bool nonblocking) -{ -#ifdef _WIN32 - unsigned long mode = !!nonblocking; - int rc = ioctlsocket(fd, FIONBIO, &mode); - return rc == 0; -#else - int rc = fcntl(fd, F_GETFL, 0); - if (rc < 0) - return false; - rc = fcntl(fd, F_SETFL, nonblocking ? (rc | O_NONBLOCK) : (rc & ~O_NONBLOCK)); - return rc >= 0; -#endif -} - -#ifdef _WIN32 -void udpPipe(int fds[2]) +UdpSocket::UdpSocket(in_port_t port, const std::shared_ptr<Logger>& l) + : logger(l) { - int lst = socket(AF_INET, SOCK_DGRAM, 0); - if (lst < 0) - throw DhtException(std::string("Can't open socket: ") + strerror(WSAGetLastError())); - sockaddr_in inaddr; - sockaddr addr; - memset(&inaddr, 0, sizeof(inaddr)); - memset(&addr, 0, sizeof(addr)); - inaddr.sin_family = AF_INET; - inaddr.sin_addr.s_addr = htonl(INADDR_LOOPBACK); - inaddr.sin_port = 0; - int yes = 1; - setsockopt(lst, SOL_SOCKET, SO_REUSEADDR, (char*)&yes, sizeof(yes)); - int rc = bind(lst, (sockaddr*)&inaddr, sizeof(inaddr)); - if (rc < 0) { - close(lst); - throw DhtException("Can't bind socket on " + print_addr((sockaddr*)&inaddr, sizeof(inaddr)) + " " + strerror(rc)); + asio::io_service io_service; + asio::ip::udp::endpoint endpoint(asio::ip::udp::v4(), port); + s4.open(endpoint.protocol()); + s4.set_option(asio::ip::udp::socket::reuse_address(true)); + s4.bind(endpoint); + s4.non_blocking(true); + + try { + asio::ip::udp::endpoint endpoint(asio::ip::udp::v6(), port); + s6.open(endpoint.protocol()); + s6.set_option(asio::ip::udp::socket::reuse_address(true)); + s6.bind(endpoint); + s6.non_blocking(true); + } catch (...) { } - socklen_t len = sizeof(addr); - getsockname(lst, &addr, &len); - fds[0] = lst; - fds[1] = socket(AF_INET, SOCK_DGRAM, 0); - connect(fds[1], &addr, len); -} -#endif - -UdpSocket::UdpSocket(in_port_t port, const std::shared_ptr<Logger>& l) : logger(l) { - SockAddr bind4; - bind4.setFamily(AF_INET); - bind4.setPort(port); - SockAddr bind6; - bind6.setFamily(AF_INET6); - bind6.setPort(port); - std::lock_guard<std::mutex> lk(lock); - openSockets(bind4, bind6); } -UdpSocket::UdpSocket(const SockAddr& bind4, const SockAddr& bind6, const std::shared_ptr<Logger>& l) : logger(l) +UdpSocket::UdpSocket(const SockAddr& bind4, const SockAddr& bind6, const std::shared_ptr<Logger>& l) + : logger(l) { - std::lock_guard<std::mutex> lk(lock); - openSockets(bind4, bind6); -} - -UdpSocket::~UdpSocket() { - stop(); - if (rcv_thread.joinable()) - rcv_thread.join(); -} - -int -UdpSocket::sendTo(const SockAddr& dest, const uint8_t* data, size_t size, bool replied) { - if (not dest) - return EFAULT; - - int s; - switch (dest.getFamily()) { - case AF_INET: s = s4; break; - case AF_INET6: s = s6; break; - default: s = -1; break; + if (bind4.isSet()) { + asio::io_service io_service; + asio::ip::udp::endpoint endpoint(asio::ip::udp::v4(), bind4.getPort()); + s4.open(endpoint.protocol()); + s4.set_option(asio::ip::udp::socket::reuse_address(true)); + s4.bind(endpoint); + s4.non_blocking(true); } - if (s < 0) - return EAFNOSUPPORT; - - int flags = 0; -#ifdef MSG_CONFIRM - if (replied) - flags |= MSG_CONFIRM; -#else - (void) replied; -#endif -#ifdef MSG_NOSIGNAL - flags |= MSG_NOSIGNAL; -#endif - - if (sendto(s, (const char*)data, size, flags, dest.get(), dest.getLength()) == -1) { - int err = errno; - if (logger) - logger->d("Can't send message to %s: %s", dest.toString().c_str(), strerror(err)); - if (err == EPIPE || err == ENOTCONN || err == ECONNRESET) { - std::lock_guard<std::mutex> lk(lock); - auto bind4 = std::move(bound4), bind6 = std::move(bound6); - openSockets(bind4, bind6); - return sendTo(dest, data, size, false); + if (bind6.isSet()) { + try { + asio::io_service io_service; + asio::ip::udp::endpoint endpoint(asio::ip::udp::v6(), bind6.getPort()); + s6.open(endpoint.protocol()); + s6.set_option(asio::ip::udp::socket::reuse_address(true)); + s6.bind(endpoint); + s6.non_blocking(true); + } catch (...) { } - return err; } - return 0; } -void -UdpSocket::openSockets(const SockAddr& bind4, const SockAddr& bind6) +UdpSocket::~UdpSocket() { stop(); - if (rcv_thread.joinable()) - rcv_thread.join(); - - int stopfds[2]; -#ifndef _WIN32 - auto status = pipe(stopfds); - if (status == -1) { - throw DhtException(std::string("Can't open pipe: ") + strerror(errno)); - } -#else - udpPipe(stopfds); -#endif - int stop_readfd = stopfds[0]; - - stopfd = stopfds[1]; - s4 = -1; - s6 = -1; - - bound4 = {}; - if (bind4) { - try { - s4 = bindSocket(bind4, bound4); - } catch (const DhtException& e) { - if (logger) - logger->e("Can't bind inet socket: %s", e.what()); - } - } - -#if 1 - bound6 = {}; - if (bind6) { - if (bind6.getPort() == 0) { - // Attempt to use the same port as IPv4 with IPv6 - if (auto p4 = bound4.getPort()) { - auto b6 = bind6; - b6.setPort(p4); - try { - s6 = bindSocket(b6, bound6); - } catch (const DhtException& e) { - if (logger) - logger->e("Can't bind inet6 socket: %s", e.what()); - } - } - } - if (s6 == -1) { - try { - s6 = bindSocket(bind6, bound6); - } catch (const DhtException& e) { - if (logger) - logger->e("Can't bind inet6 socket: %s", e.what()); - } - } - } -#endif +} - if (s4 == -1 && s6 == -1) { - throw DhtException("Can't bind socket"); +int UdpSocket::sendTo(const SockAddr& dest, const uint8_t* data, size_t size, bool replied) +{ + asio::ip::udp::socket* sock = nullptr; + if (dest.isV4()) + sock = &s4; + else if (dest.isV6()) + sock = &s6; + else + return -1; + + try { + asio::ip::udp::endpoint endpoint(dest.toAsio()); + return sock->send_to(asio::buffer(data, size), endpoint); + } catch (...) { + return -1; } - - running = true; - rcv_thread = std::thread([this, stop_readfd, ls4=s4, ls6=s6]() mutable { - int selectFd = std::max({ls4, ls6, stop_readfd}) + 1; - try { - while (running) { - fd_set readfds; - - FD_ZERO(&readfds); - FD_SET(stop_readfd, &readfds); - if(ls4 >= 0) - FD_SET(ls4, &readfds); - if(ls6 >= 0) - FD_SET(ls6, &readfds); - - int rc = select(selectFd, &readfds, nullptr, nullptr, nullptr); - if (rc < 0) { - if (errno != EINTR) { - if (logger) - logger->e("Select error: %s", strerror(errno)); - std::this_thread::sleep_for(std::chrono::seconds(1)); - } - } - - if (not running) - break; - - if (rc > 0) { - std::array<uint8_t, 1024 * 64> buf; - sockaddr_storage from; - socklen_t from_len = sizeof(from); - - if (FD_ISSET(stop_readfd, &readfds)) { - if (recv(stop_readfd, (char*)buf.data(), buf.size(), 0) < 0) { - if (logger) - logger->e("Got stop packet error: %s", strerror(errno)); - break; - } - } - else if (ls4 >= 0 && FD_ISSET(ls4, &readfds)) - rc = recvfrom(ls4, (char*)buf.data(), buf.size(), 0, (sockaddr*)&from, &from_len); - else if (ls6 >= 0 && FD_ISSET(ls6, &readfds)) - rc = recvfrom(ls6, (char*)buf.data(), buf.size(), 0, (sockaddr*)&from, &from_len); - else - continue; - - if (rc > 0) { - auto pkts = getNewPacket(); - auto& pkt = pkts.front(); - pkt.data.insert(pkt.data.end(), buf.begin(), buf.begin()+rc); - pkt.from = {from, from_len}; - pkt.received = clock::now(); - onReceived(std::move(pkts)); - } else if (rc == -1) { - if (logger) - logger->e("Error receiving packet: %s", strerror(errno)); - int err = errno; - if (err == EPIPE || err == ENOTCONN || err == ECONNRESET) { - if (not running) break; - std::unique_lock<std::mutex> lk(lock, std::try_to_lock); - if (lk.owns_lock()) { - if (not running) break; - if (ls4 >= 0) { - close(ls4); - try { - ls4 = bindSocket(bound4, bound4); - } catch (const DhtException& e) { - if (logger) - logger->e("Can't bind inet socket: %s", e.what()); - } - } - if (ls6 >= 0) { - close(ls6); - try { - ls6 = bindSocket(bound6, bound6); - } catch (const DhtException& e) { - if (logger) - logger->e("Can't bind inet6 socket: %s", e.what()); - } - } - if (ls4 < 0 && ls6 < 0) - break; - s4 = ls4; - s6 = ls6; - selectFd = std::max({ls4, ls6, stop_readfd}) + 1; - } else { - break; - } - } - } - } - } - } catch (const std::exception& e) { - if (logger) - logger->e("Error in UdpSocket rx thread: %s", e.what()); - } - if (ls4 >= 0) - close(ls4); - if (ls6 >= 0) - close(ls6); - if (stop_readfd != -1) - close(stop_readfd); - if (stopfd != -1) - close(stopfd); - std::unique_lock<std::mutex> lk(lock, std::try_to_lock); - if (lk.owns_lock()) { - s4 = -1; - s6 = -1; - bound4 = {}; - bound6 = {}; - stopfd = -1; - } - }); } -void -UdpSocket::stop() +void UdpSocket::stop() { - if (running.exchange(false)) { - auto sfd = stopfd; - if (sfd != -1 && write(sfd, "\0", 1) == -1) { - if (logger) - logger->e("Can't write to stop fd"); - } - } + s4.cancel(); + s4.close(); + s6.cancel(); + s6.close(); } +*/ } } diff --git a/src/node.cpp b/src/node.cpp index 1e7df287..c7f9cda9 100644 --- a/src/node.cpp +++ b/src/node.cpp @@ -167,7 +167,7 @@ Node::toString() const std::ostream& operator<< (std::ostream& s, const Node& h) { - s << h.id << " " << h.addr.toString(); + s << h.id << " " << h.addr; return s; } @@ -188,7 +188,8 @@ NodeExport::msgpack_unpack(msgpack::object o) if (maddr.via.bin.size > sizeof(sockaddr_storage)) throw msgpack::type_error(); id.msgpack_unpack(o.via.map.ptr[0].val); - addr = {(const sockaddr*)maddr.via.bin.ptr, (socklen_t)maddr.via.bin.size}; + //addr = {(const sockaddr*)maddr.via.bin.ptr, (socklen_t)maddr.via.bin.size}; + // TODO } std::ostream& operator<< (std::ostream& s, const NodeExport& h) diff --git a/src/node_cache.cpp b/src/node_cache.cpp index 34366c4e..7ef8f8e6 100644 --- a/src/node_cache.cpp +++ b/src/node_cache.cpp @@ -38,7 +38,7 @@ Sp<Node> NodeCache::getNode(const InfoHash& id, const SockAddr& addr, time_point now, bool confirm, bool client) { if (not id) return std::make_shared<Node>(id, addr, rd, client); - return cache(addr.getFamily()).getNode(id, addr, now, confirm, client, rd); + return cache(addr.protocol().family()).getNode(id, addr, now, confirm, client, rd); } std::vector<Sp<Node>> diff --git a/src/parsed_message.h b/src/parsed_message.h index a1441028..208f6fcb 100644 --- a/src/parsed_message.h +++ b/src/parsed_message.h @@ -18,7 +18,6 @@ #pragma once #include "infohash.h" -#include "sockaddr.h" #include "net.h" #include <map> @@ -313,15 +312,29 @@ ParsedMessage::msgpack_unpack(const msgpack::object& msg) throw msgpack::type_error(); auto l = parsedReq.sa->via.bin.size; if (l == sizeof(in_addr)) { - addr.setFamily(AF_INET); - auto& a = addr.getIPv4(); - a.sin_port = 0; - std::copy_n(parsedReq.sa->via.bin.ptr, l, (char*)&a.sin_addr); + addr = {asio::ip::address_v4({ + (uint8_t)parsedReq.sa->via.bin.ptr[0], + (uint8_t)parsedReq.sa->via.bin.ptr[1], + (uint8_t)parsedReq.sa->via.bin.ptr[2], + (uint8_t)parsedReq.sa->via.bin.ptr[3]}), 0}; } else if (l == sizeof(in6_addr)) { - addr.setFamily(AF_INET6); - auto& a = addr.getIPv6(); - a.sin6_port = 0; - std::copy_n(parsedReq.sa->via.bin.ptr, l, (char*)&a.sin6_addr); + addr = {asio::ip::address_v6({ + (uint8_t)parsedReq.sa->via.bin.ptr[0], + (uint8_t)parsedReq.sa->via.bin.ptr[1], + (uint8_t)parsedReq.sa->via.bin.ptr[2], + (uint8_t)parsedReq.sa->via.bin.ptr[3], + (uint8_t)parsedReq.sa->via.bin.ptr[4], + (uint8_t)parsedReq.sa->via.bin.ptr[5], + (uint8_t)parsedReq.sa->via.bin.ptr[6], + (uint8_t)parsedReq.sa->via.bin.ptr[7], + (uint8_t)parsedReq.sa->via.bin.ptr[8], + (uint8_t)parsedReq.sa->via.bin.ptr[9], + (uint8_t)parsedReq.sa->via.bin.ptr[10], + (uint8_t)parsedReq.sa->via.bin.ptr[11], + (uint8_t)parsedReq.sa->via.bin.ptr[12], + (uint8_t)parsedReq.sa->via.bin.ptr[13], + (uint8_t)parsedReq.sa->via.bin.ptr[14], + (uint8_t)parsedReq.sa->via.bin.ptr[15]}), 0}; } } else addr = {}; diff --git a/src/peer_discovery.cpp b/src/peer_discovery.cpp index aa031765..d9da2509 100644 --- a/src/peer_discovery.cpp +++ b/src/peer_discovery.cpp @@ -19,7 +19,6 @@ */ #include "peer_discovery.h" -#include "network_utils.h" #include "utils.h" #include <asio.hpp> @@ -153,7 +152,7 @@ PeerDiscovery::DomainPeerDiscovery::loopListener() return; } if (cb) - cb(std::move(o.val), SockAddr{ receiveFrom_.data(), (socklen_t)receiveFrom_.size() }); + cb(std::move(o.val), receiveFrom_); } } else { throw msgpack::type_error{}; diff --git a/src/request.h b/src/request.h index 318d120a..4d1ad44b 100644 --- a/src/request.h +++ b/src/request.h @@ -21,6 +21,7 @@ #include "net.h" #include "value.h" +#include <asio/steady_timer.hpp> namespace dht { struct Node; @@ -140,6 +141,7 @@ private: duration attempt_duration {((duration)Node::MAX_RESPONSE_TIME)/2}; time_point start {time_point::min()}; /* time when the request is created. */ time_point last_try {time_point::min()}; /* time of the last attempt to process the request. */ + std::unique_ptr<asio::steady_timer> expiration_timer; std::function<void(const Request&, ParsedMessage&&)> on_done {}; std::function<bool(const Request&, DhtProtocolException&&)> on_error {}; diff --git a/src/search.h b/src/search.h index 717bf8ac..48e4d2e1 100644 --- a/src/search.h +++ b/src/search.h @@ -23,6 +23,10 @@ #include "listener.h" #include "value_cache.h" #include "op_cache.h" +#include "node.h" +#include "dht.h" + +#include <asio/steady_timer.hpp> namespace dht { @@ -52,7 +56,7 @@ struct Dht::SearchNode { struct AnnounceStatus { Sp<net::Request> req {}; - Sp<Scheduler::Job> refresh {}; + std::unique_ptr<asio::steady_timer> refresh {}; time_point refresh_time; AnnounceStatus(){}; AnnounceStatus(Sp<net::Request> r, time_point t): req(std::move(r)), refresh_time(t) @@ -72,8 +76,9 @@ struct Dht::SearchNode { struct CachedListenStatus { ValueCache cache; - Sp<Scheduler::Job> refresh {}; - Sp<Scheduler::Job> cacheExpirationJob {}; + std::unique_ptr<asio::steady_timer> refresh {}; + std::unique_ptr<asio::steady_timer> cacheExpirationJob {}; + std::function<void()> onCacheExpired {}; Sp<net::Request> req {}; Tid socketId {0}; CachedListenStatus(ValueStateCallback&& cb, SyncCallback scb, Tid sid) @@ -85,6 +90,14 @@ struct Dht::SearchNode { req->node->closeSocket(socketId); } } + + void scheduleCacheExpired(const time_point& t) { + cacheExpirationJob->expires_at(t); + cacheExpirationJob->async_wait([cb = onCacheExpired](const asio::error_code &ec){ + if (ec != asio::error::operation_aborted) + cb(); + }); + } }; using NodeListenerStatus = std::map<Sp<Query>, CachedListenStatus>; @@ -101,7 +114,7 @@ struct Dht::SearchNode { Blob token {}; /* last token the node sent to us after a get request */ time_point last_get_reply {time_point::min()}; /* last time received valid token */ - Sp<Scheduler::Job> syncJob {}; + std::unique_ptr<asio::steady_timer> syncJob {}; bool candidate {false}; /* A search node is candidate if the search is/was synced and this node is a new candidate for inclusion. */ @@ -238,18 +251,17 @@ struct Dht::SearchNode { getStatus.clear(); } - void onValues(const Sp<Query>& q, net::RequestAnswer&& answer, const TypeStore& types, Scheduler& scheduler) + void onValues(const Sp<Query>& q, net::RequestAnswer&& answer, const TypeStore& types, const time_point& now) { auto l = listenStatus.find(q); if (l != listenStatus.end()) { - auto next = l->second.cache.onValues(answer.values, + l->second.scheduleCacheExpired(l->second.cache.onValues(answer.values, answer.refreshed_values, - answer.expired_values, types, scheduler.time()); - scheduler.edit(l->second.cacheExpirationJob, next); + answer.expired_values, types, now)); } } - void onListenSynced(const Sp<Query>& q, bool synced = true, Sp<Scheduler::Job> refreshJob = {}) { + void onListenSynced(const Sp<Query>& q, bool synced = true, std::unique_ptr<asio::steady_timer> refreshJob = {}) { auto l = listenStatus.find(q); if (l != listenStatus.end()) { if (l->second.refresh) @@ -259,11 +271,10 @@ struct Dht::SearchNode { } } - void expireValues(const Sp<Query>& q, Scheduler& scheduler) { + void expireValues(const Sp<Query>& q, const time_point& now) { auto l = listenStatus.find(q); if (l != listenStatus.end()) { - auto next = l->second.cache.expireValues(scheduler.time()); - scheduler.edit(l->second.cacheExpirationJob, next); + l->second.scheduleCacheExpired(l->second.cache.expireValues(now)); } } @@ -409,7 +420,17 @@ struct Dht::Search { uint16_t tid; time_point refill_time {time_point::min()}; time_point step_time {time_point::min()}; /* the time of the last search step */ - Sp<Scheduler::Job> nextSearchStep {}; + std::unique_ptr<asio::steady_timer> nextSearchStep {}; + std::function<void()> onSearchStep {}; + void scheduleStep(const time_point& now) { + if (nextSearchStep) { + nextSearchStep->expires_at(now); + nextSearchStep->async_wait([cb=onSearchStep](const asio::error_code &ec){ + if (ec != asio::error::operation_aborted) + cb(); + }); + } + } bool expired {false}; /* no node, or all nodes expired */ bool done {false}; /* search is over, cached for later */ @@ -432,7 +453,9 @@ struct Dht::Search { /* Cache */ SearchCache cache; - Sp<Scheduler::Job> opExpirationJob; + //std::unique_ptr<asio::steady_timer> opExpirationJob; + std::unique_ptr<asio::steady_timer> opExpirationJob {}; + ~Search() { if (opExpirationJob) @@ -537,46 +560,57 @@ struct Dht::Search { bool isAnnounced(Value::Id id) const; bool isListening(time_point now, duration exp) const; - void get(const Value::Filter& f, const Sp<Query>& q, const QueryCallback& qcb, const GetCallback& gcb, const DoneCallback& dcb, Scheduler& scheduler) { + void get(const Value::Filter& f, const Sp<Query>& q, const QueryCallback& qcb, const GetCallback& gcb, const DoneCallback& dcb, const time_point& now) { if (gcb or qcb) { if (not cache.get(f, q, gcb, dcb)) { - const auto& now = scheduler.time(); callbacks.emplace(now, Get { now, f, q, qcb, gcb, dcb }); - scheduler.edit(nextSearchStep, now); + scheduleStep(now); } } } - size_t listen(const ValueCallback& cb, Value::Filter&& f, const Sp<Query>& q, Scheduler& scheduler) { + size_t listen(const ValueCallback& cb, Value::Filter&& f, const Sp<Query>& q, const time_point& now) { //DHT_LOG.e(id, "[search %s IPv%c] listen", id.toString().c_str(), (af == AF_INET) ? '4' : '6'); return cache.listen(cb, q, std::move(f), [&](const Sp<Query>& q, ValueCallback vcb, SyncCallback scb){ done = false; auto token = ++listener_token; listeners.emplace(token, SearchListener{q, vcb, scb}); - scheduler.edit(nextSearchStep, scheduler.time()); + scheduleStep(now); return token; }); } - void cancelListen(size_t token, Scheduler& scheduler) { - cache.cancelListen(token, scheduler.time()); - if (not opExpirationJob) - opExpirationJob = scheduler.add(time_point::max(), [this,&scheduler]{ - auto nextExpire = cache.expire(scheduler.time(), [&](size_t t){ - const auto& ll = listeners.find(t); - if (ll != listeners.cend()) { - auto query = ll->second.query; - listeners.erase(ll); - if (listeners.empty()) { - for (auto& sn : nodes) sn->cancelListen(); - } else if (query) { - for (auto& sn : nodes) sn->cancelListen(query); - } - } - }); - scheduler.edit(opExpirationJob, nextExpire); - }); - scheduler.edit(opExpirationJob, cache.getExpiration()); + time_point expireListeners(const time_point& now) { + return cache.expire(now, [&](size_t t){ + const auto& ll = listeners.find(t); + if (ll != listeners.cend()) { + auto query = ll->second.query; + listeners.erase(ll); + if (listeners.empty()) { + for (auto& sn : nodes) sn->cancelListen(); + } else if (query) { + for (auto& sn : nodes) sn->cancelListen(query); + } + } + }); + } + + void scheduleListenerExpiration() { + opExpirationJob->async_wait([this](const asio::error_code &ec){ + if (ec != asio::error::operation_aborted) { + opExpirationJob->expires_at(expireListeners(clock::now())); + scheduleListenerExpiration(); + } + }); + } + + void cancelListen(size_t token, const time_point& now) { + if (cache.cancelListen(token, now)) { + if (opExpirationJob) { + opExpirationJob->expires_at(cache.getExpiration()); + scheduleListenerExpiration(); + } + } } std::vector<Sp<Value>> getPut() const { diff --git a/src/storage.h b/src/storage.h index 2533214e..8ee0afce 100644 --- a/src/storage.h +++ b/src/storage.h @@ -72,7 +72,7 @@ struct ValueStorage { Sp<Value> data {}; time_point created {}; time_point expiration {}; - Sp<Scheduler::Job> expiration_job {}; + std::unique_ptr<asio::steady_timer> expiration_job {}; StorageBucket* store_bucket {nullptr}; ValueStorage() {} @@ -83,6 +83,7 @@ struct ValueStorage { struct Storage { time_point maintenance_time {}; + std::unique_ptr<asio::steady_timer> maintenance_job {}; std::map<Sp<Node>, std::map<size_t, Listener>> listeners; std::map<size_t, LocalListener> local_listeners {}; size_t listener_token {1}; diff --git a/src/udp_socket.cpp b/src/udp_socket.cpp new file mode 100644 index 00000000..64486f55 --- /dev/null +++ b/src/udp_socket.cpp @@ -0,0 +1,150 @@ +#include "udp_socket.h" + +namespace dht { +namespace net { + + +class UdpSocket::SocketHandler : public std::enable_shared_from_this<UdpSocket::SocketHandler> { +public: + SocketHandler(const std::shared_ptr<strand>& strand, const udp::endpoint& endpoint) + : strand_(strand), socket_(strand->context()) + { + asio::error_code ec; + socket_.open(endpoint.protocol(), ec); + if (ec) + throw std::runtime_error("Failed to open socket: " + ec.message()); + + socket_.set_option(asio::socket_base::reuse_address(true), ec); + if (ec) + throw std::runtime_error("Failed to set socket option: " + ec.message()); + + socket_.bind(endpoint, ec); + if (ec) + throw std::runtime_error("Failed to bind socket: " + ec.message()); + } + void set_receive_callback(const ReceiveCallback& callback) { + receive_callback_ = callback; + } + + void start_receive() { + receive_next(); + } + + void stop() { + socket_.close(); + } + + void send_to_async(std::vector<uint8_t> data, const udp::endpoint& to) { + auto ctx = std::make_shared<std::vector<uint8_t>>(std::move(data)); + socket_.async_send_to(asio::buffer(*ctx), to, [ctx](const asio::error_code&, std::size_t) {}); + } + + asio::error_code send_to(const uint8_t* buf, size_t len, const udp::endpoint& to) { + asio::error_code ec; + socket_.send_to(asio::buffer(buf, len), to, 0, ec); + return ec; + } + + udp::endpoint getBound() const { + return socket_.local_endpoint(); + } + +private: + void receive_next() { + socket_.async_receive_from(asio::buffer(receive_buffer_), + remote_endpoint_, + asio::bind_executor(*strand_, [self = shared_from_this()](const asio::error_code& error, std::size_t bytes) { + self->handle_receive(error, bytes); + self->receive_next(); + })); + } + + void handle_receive(const asio::error_code& error, std::size_t bytes) { + if (!error) { + if (receive_callback_) { + receive_callback_(ReceivedPacket{std::vector<uint8_t>(receive_buffer_.begin(), receive_buffer_.begin() + bytes), + remote_endpoint_/*, std::chrono::high_resolution_clock::now()*/}); + } + } else { + // Handle error + } + } + + std::shared_ptr<strand> strand_; + udp::socket socket_; + udp::endpoint remote_endpoint_; + std::array<uint8_t, 65536> receive_buffer_; + ReceiveCallback receive_callback_; +}; + +UdpSocket::UdpSocket(std::shared_ptr<strand> strand, const udp::endpoint& ipv4_endpoint, const udp::endpoint& ipv6_endpoint) + //: strand_(strand) +{ + try { + ipv4_handler_ = std::make_shared<SocketHandler>(strand, ipv4_endpoint); + } catch (const std::exception&) { + + } + + // Try to use the same port for IPv6 if not specified + udp::endpoint ipv6_endpoint_stack; + if (ipv6_endpoint.port() == 0) { + ipv6_endpoint_stack = ipv6_endpoint; + ipv6_endpoint_stack.port(ipv4_handler_->getBound().port()); + } + const auto& ipv6_endpoint_to_use = ipv6_endpoint.port() == 0 ? ipv6_endpoint_stack : ipv6_endpoint; + + try { + ipv6_handler_ = std::make_shared<SocketHandler>(strand, ipv6_endpoint_to_use); + } catch (const std::exception&) { + + } + if (!ipv4_handler_ && !ipv6_handler_) + throw std::runtime_error("Failed to bind sockets"); +} + +void UdpSocket::setOnReceive(const ReceiveCallback& callback) { + ipv4_handler_->set_receive_callback(callback); + ipv6_handler_->set_receive_callback(callback); + start_receive(); +} + +void UdpSocket::start_receive() { + ipv4_handler_->start_receive(); + ipv6_handler_->start_receive(); +} + +void UdpSocket::stop() { + ipv4_handler_->stop(); + ipv6_handler_->stop(); +} + +void UdpSocket::sendToAsync(std::vector<uint8_t> data, const udp::endpoint& to) { + auto handler = (to.address().is_v4()) ? ipv4_handler_ : ipv6_handler_; + handler->send_to_async(std::move(data), to); +} + +asio::error_code UdpSocket::sendTo(const uint8_t* buf, size_t len, const udp::endpoint& to) { + auto handler = (to.address().is_v4()) ? ipv4_handler_ : ipv6_handler_; + return handler->send_to(buf, len, to); +} + +bool UdpSocket::hasIPv4() const { + return ipv4_handler_ != nullptr; +} + +bool UdpSocket::hasIPv6() const { + return ipv6_handler_ != nullptr; +} + +udp::endpoint UdpSocket::getBound(sa_family_t af) const { + if (af == AF_INET && ipv4_handler_) { + return ipv4_handler_->getBound(); + } else if (af == AF_INET6 && ipv6_handler_) { + return ipv6_handler_->getBound(); + } + throw std::runtime_error("Invalid address family"); +} + +} // namespace net +} // namespace dht diff --git a/src/utils.cpp b/src/utils.cpp index 9bb945bd..3f81562f 100644 --- a/src/utils.cpp +++ b/src/utils.cpp @@ -21,7 +21,6 @@ #endif #include "utils.h" -#include "sockaddr.h" #include "default_types.h" /* An IPv4 equivalent to IN6_IS_ADDR_UNSPECIFIED */ @@ -41,8 +40,6 @@ const char* version() { const HexMap hex_map = {}; -static constexpr std::array<uint8_t, 12> MAPPED_IPV4_PREFIX {{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff}}; - std::pair<std::string, std::string> splitPort(const std::string& s) { if (s.empty()) @@ -63,6 +60,7 @@ splitPort(const std::string& s) { return {s.substr(0,found), s.substr(found+1)}; } +/* std::vector<SockAddr> SockAddr::resolve(const std::string& host, const std::string& service) { @@ -102,6 +100,7 @@ SockAddr::setAddress(const char* address) if (inet_pton(family, address, addr) <= 0) throw std::runtime_error(std::string("Can't parse IP address: ") + strerror(errno)); } +*/ std::string print_addr(const sockaddr* sa, socklen_t slen) @@ -127,6 +126,7 @@ print_addr(const sockaddr_storage& ss, socklen_t sslen) return print_addr((const sockaddr*)&ss, sslen); } +/* bool SockAddr::isUnspecified() const { @@ -232,7 +232,7 @@ SockAddr::getMappedIPv6() bool operator==(const SockAddr& a, const SockAddr& b) { return a.equals(b); -} +}*/ time_point from_time_t(std::time_t t) { auto dt = system_clock::from_time_t(t) - system_clock::now(); diff --git a/tests/dhtproxytester.cpp b/tests/dhtproxytester.cpp index d2b541e9..6e70adbb 100644 --- a/tests/dhtproxytester.cpp +++ b/tests/dhtproxytester.cpp @@ -62,6 +62,8 @@ DhtProxyTester::tearDown() { nodePeer.join(); nodeClient.join(); + serverProxy->stop(); + bool done = false; std::condition_variable cv; std::mutex cv_m; diff --git a/tests/dhtrunnertester.cpp b/tests/dhtrunnertester.cpp index f56caad4..90ed63bd 100644 --- a/tests/dhtrunnertester.cpp +++ b/tests/dhtrunnertester.cpp @@ -64,9 +64,9 @@ DhtRunnerTester::tearDown() { void DhtRunnerTester::testConstructors() { CPPUNIT_ASSERT(node1.getBoundPort()); - CPPUNIT_ASSERT_EQUAL(node1.getBoundPort(), node1.getBound().getPort()); + CPPUNIT_ASSERT_EQUAL(node1.getBoundPort(), node1.getBound().port()); CPPUNIT_ASSERT(node2.getBoundPort()); - CPPUNIT_ASSERT_EQUAL(node2.getBoundPort(), node2.getBound().getPort()); + CPPUNIT_ASSERT_EQUAL(node2.getBoundPort(), node2.getBound().port()); dht::DhtRunner::Config config {}; dht::DhtRunner::Context context {}; diff --git a/tests/httptester.cpp b/tests/httptester.cpp index 342b2b95..92320b70 100644 --- a/tests/httptester.cpp +++ b/tests/httptester.cpp @@ -47,6 +47,7 @@ HttpTester::setUp() { void HttpTester::tearDown() { + serverProxy->stop(); serverProxy.reset(); nodePeer->join(); } diff --git a/tools/tools_common.h b/tools/tools_common.h index cfeade02..d599d77f 100644 --- a/tools/tools_common.h +++ b/tools/tools_common.h @@ -24,7 +24,7 @@ #include <opendht.h> #include <opendht/log.h> #include <opendht/crypto.h> -#include <opendht/network_utils.h> +//#include <opendht/network_utils.h> #ifdef OPENDHT_INDEXATION #include <opendht/indexation/pht.h> #endif @@ -187,8 +187,8 @@ getDhtConfig(dht_params& params) context.logger = dht::log::getStdLogger(); } if (context.logger) { - context.statusChangedCallback = [logger = *context.logger](dht::NodeStatus status4, dht::NodeStatus status6) { - logger.WARN("Connectivity changed: IPv4: %s, IPv6: %s", dht::statusToStr(status4), dht::statusToStr(status6)); + context.statusChangedCallback = [logger = context.logger](dht::DhtNodeStatus status) { + logger->WARN("Connectivity changed: IPv4: %s, IPv6: %s", dht::statusToStr(status.ipv4), dht::statusToStr(status.ipv6)); }; } return {std::move(config), std::move(context)}; -- GitLab