diff --git a/CMakeLists.txt b/CMakeLists.txt index 0fbc9ee7012a656de7e2240c1f1f3e553c2850be..a764ad566702442d2f911bbc048b73b16bf5d042 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -151,7 +151,6 @@ list (APPEND opendht_SOURCES src/dhtrunner.cpp src/log.cpp src/peer_discovery.cpp - src/network_utils.h src/network_utils.cpp src/thread_pool.cpp ) @@ -179,6 +178,7 @@ list (APPEND opendht_HEADERS include/opendht/log_enable.h include/opendht/peer_discovery.h include/opendht/thread_pool.h + include/opendht/network_utils.h include/opendht.h ) diff --git a/include/opendht/dht.h b/include/opendht/dht.h index 266ff0a22cd693e289375e66d069b9186b47451d..4519a1de2e266d1b410f0ebfba358502982baa27 100644 --- a/include/opendht/dht.h +++ b/include/opendht/dht.h @@ -68,7 +68,7 @@ public: * Initialise the Dht with two open sockets (for IPv4 and IP6) * and an ID for the node. */ - Dht(const int& s, const int& s6, const Config& config, const Logger& l = {}); + Dht(std::unique_ptr<net::DatagramSocket>&& sock, const Config& config, const Logger& l = {}); virtual ~Dht(); /** @@ -85,6 +85,8 @@ public: return std::max(getStatus(AF_INET), getStatus(AF_INET6)); } + net::DatagramSocket* getSocket() const override { return network_engine.getSocket(); }; + /** * Performs final operations before quitting. */ diff --git a/include/opendht/dht_interface.h b/include/opendht/dht_interface.h index d3dac19d640885fd4db4ae6c4067fbf187991a90..0cfdf3e16628df59dc0a94191582ca2e43328f3b 100644 --- a/include/opendht/dht_interface.h +++ b/include/opendht/dht_interface.h @@ -23,6 +23,10 @@ namespace dht { +namespace net { + class DatagramSocket; +} + class OPENDHT_PUBLIC DhtInterface { public: DhtInterface() = default; @@ -40,6 +44,8 @@ public: virtual NodeStatus getStatus(sa_family_t af) const = 0; virtual NodeStatus getStatus() const = 0; + virtual net::DatagramSocket* getSocket() const { return {}; }; + /** * Get the ID of the DHT node. */ diff --git a/include/opendht/dhtrunner.h b/include/opendht/dhtrunner.h index ec7a07bb6bb352c69957a5eec9f2929cd3432afb..97d578230581ab7c02b6616727e63f87be44be11 100644 --- a/include/opendht/dhtrunner.h +++ b/include/opendht/dhtrunner.h @@ -20,12 +20,13 @@ #pragma once +#include "def.h" #include "infohash.h" #include "value.h" #include "callbacks.h" #include "sockaddr.h" #include "log_enable.h" -#include "def.h" +#include "network_utils.h" #include <thread> #include <mutex> @@ -304,9 +305,7 @@ public: * Returns the currently bound address. * @param f: address family of the bound address to retreive. */ - const SockAddr& getBound(sa_family_t f = AF_INET) const { - return (f == AF_INET) ? bound4 : bound6; - } + SockAddr getBound(sa_family_t f = AF_INET) const; /** * Returns the currently bound port, in host byte order. @@ -412,13 +411,7 @@ public: */ time_point loop() { std::lock_guard<std::mutex> lck(dht_mtx); - time_point wakeup = time_point::min(); - try { - wakeup = loop_(); - } catch (const dht::SocketException& e) { - startNetwork(bound4, bound6); - } - return wakeup; + return loop_(); } /** @@ -470,8 +463,6 @@ private: */ void tryBootstrapContinuously(); - void stopNetwork(); - void startNetwork(const SockAddr sin4, const SockAddr sin6); time_point loop_(); NodeStatus getStatus() const { @@ -509,16 +500,8 @@ private: mutable std::mutex dht_mtx {}; std::thread dht_thread {}; std::condition_variable cv {}; - - std::thread rcv_thread {}; std::mutex sock_mtx {}; - - struct ReceivedPacket { - Blob data; - SockAddr from; - time_point received; - }; - std::queue<ReceivedPacket> rcv {}; + std::queue<std::unique_ptr<net::ReceivedPacket>> rcv {}; /** true if currently actively boostraping */ std::atomic_bool bootstraping {false}; @@ -526,7 +509,6 @@ private: std::vector<std::pair<std::string,std::string>> bootstrap_nodes_all {}; std::vector<std::pair<std::string,std::string>> bootstrap_nodes {}; std::thread bootstrap_thread {}; - /** protects bootstrap_nodes, bootstrap_thread */ std::mutex bootstrap_mtx {}; std::condition_variable bootstrap_cv {}; @@ -535,17 +517,11 @@ private: std::mutex storage_mtx {}; std::atomic_bool running {false}; - std::atomic_bool running_network {false}; NodeStatus status4 {NodeStatus::Disconnected}, status6 {NodeStatus::Disconnected}; StatusCallback statusCb {nullptr}; - int stop_writefd {-1}; - int s4 {-1}, s6 {-1}; - SockAddr bound4 {}; - SockAddr bound6 {}; - /** Push notification token */ std::string pushToken_; diff --git a/include/opendht/network_engine.h b/include/opendht/network_engine.h index 9e623825d07e871207b0e952f27f33a267f695b1..f3da45f4b4ee8457fb485653a7b9c139ad344dbd 100644 --- a/include/opendht/network_engine.h +++ b/include/opendht/network_engine.h @@ -28,6 +28,7 @@ #include "rng.h" #include "rate_limiter.h" #include "log_enable.h" +#include "network_utils.h" #include <vector> #include <string> @@ -205,8 +206,8 @@ public: using RequestCb = std::function<void(const Request&, RequestAnswer&&)>; using RequestExpiredCb = std::function<void(const Request&, bool)>; - NetworkEngine(Logger& log, Scheduler& scheduler, const int& s = -1, const int& s6 = -1); - NetworkEngine(InfoHash& myid, NetId net, const int& s, const int& s6, Logger& log, Scheduler& scheduler, + NetworkEngine(Logger& log, Scheduler& scheduler, std::unique_ptr<DatagramSocket>&& sock); + NetworkEngine(InfoHash& myid, NetId net, std::unique_ptr<DatagramSocket>&& sock, Logger& log, Scheduler& scheduler, decltype(NetworkEngine::onError)&& onError, decltype(NetworkEngine::onNewNode)&& onNewNode, decltype(NetworkEngine::onReportedAddr)&& onReportedAddr, @@ -217,7 +218,9 @@ public: decltype(NetworkEngine::onAnnounce)&& onAnnounce, decltype(NetworkEngine::onRefresh)&& onRefresh); - virtual ~NetworkEngine(); + ~NetworkEngine(); + + net::DatagramSocket* getSocket() const { return dht_socket.get(); }; void clear(); @@ -242,7 +245,7 @@ public: void tellListenerExpired(Sp<Node> n, Tid socket_id, const InfoHash& hash, const Blob& ntoken, const std::vector<Value::Id>& values); bool isRunning(sa_family_t af) const; - inline want_t want () const { return dht_socket >= 0 && dht_socket6 >= 0 ? (WANT4 | WANT6) : -1; } + inline want_t want () const { return dht_socket->hasIPv4() and dht_socket->hasIPv6() ? (WANT4 | WANT6) : -1; } void connectivityChanged(sa_family_t); @@ -469,7 +472,7 @@ private: // basic wrapper for socket sendto function - int send(const char *buf, size_t len, int flags, const SockAddr& addr); + int send(const SockAddr& addr, const char *buf, size_t len, bool confirmed = false); void sendValueParts(const TransId& tid, const std::vector<Blob>& svals, const SockAddr& addr); std::vector<Blob> packValueHeader(msgpack::sbuffer&, const std::vector<Sp<Value>>&); @@ -511,8 +514,7 @@ private: /* DHT info */ const InfoHash& myid; const NetId network {0}; - const int& dht_socket; - const int& dht_socket6; + const std::unique_ptr<DatagramSocket> dht_socket; const Logger& DHT_LOG; NodeCache cache {}; diff --git a/include/opendht/network_utils.h b/include/opendht/network_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..277fbb6617b13afbf40b763921c71a69e3290c63 --- /dev/null +++ b/include/opendht/network_utils.h @@ -0,0 +1,113 @@ +/* + * Copyright (C) 2019 Savoir-faire Linux Inc. + * Author(s) : Adrien Béraud <adrien.beraud@savoirfairelinux.com> + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see <https://www.gnu.org/licenses/>. + */ +#pragma once + +#include "def.h" + +#include "sockaddr.h" +#include "utils.h" + +#ifdef _WIN32 +#include <ws2tcpip.h> +#include <winsock2.h> +#else +#include <sys/socket.h> +#include <netinet/in.h> +#include <unistd.h> +#endif + +#include <functional> +#include <thread> +#include <atomic> +#include <iostream> + +namespace dht { +namespace net { + +int bindSocket(const SockAddr& addr, SockAddr& bound); + +bool setNonblocking(int fd, bool nonblocking = true); + +#ifdef _WIN32 +void udpPipe(int fds[2]); +#endif +struct ReceivedPacket { + Blob data; + SockAddr from; + time_point received; +}; + +class OPENDHT_PUBLIC DatagramSocket { +public: + using OnReceive = std::function<void(std::unique_ptr<ReceivedPacket>&& packet)>; + virtual ~DatagramSocket() {}; + + virtual int sendTo(const SockAddr& dest, const uint8_t* data, size_t size, bool replied) = 0; + + inline void setOnReceive(OnReceive&& cb) { + rx_callback = std::move(cb); + } + + virtual const SockAddr& getBound(sa_family_t family = AF_UNSPEC) const = 0; + virtual bool hasIPv4() const = 0; + virtual bool hasIPv6() const = 0; + + in_port_t getPort(sa_family_t family = AF_UNSPEC) const { + return getBound(family).getPort(); + } + + virtual void stop() = 0; +protected: + + inline void onReceived(std::unique_ptr<ReceivedPacket>&& packet) { + if (rx_callback) + rx_callback(std::move(packet)); + } +private: + OnReceive rx_callback; +}; + +class OPENDHT_PUBLIC UdpSocket : public DatagramSocket { +public: + UdpSocket(in_port_t port); + UdpSocket(const SockAddr& bind4, const SockAddr& bind6); + ~UdpSocket(); + + int sendTo(const SockAddr& dest, const uint8_t* data, size_t size, bool replied) override; + + const SockAddr& getBound(sa_family_t family = AF_UNSPEC) const override { + return (family == AF_INET6) ? bound6 : bound4; + } + + bool hasIPv4() const override { return s4 != -1; } + bool hasIPv6() const override { return s6 != -1; } + + void stop() override; +private: + int s4 {-1}; + int s6 {-1}; + int stopfd {-1}; + SockAddr bound4, bound6; + std::thread rcv_thread {}; + std::atomic_bool running {false}; + + void openSockets(const SockAddr& bind4, const SockAddr& bind6); +}; + +} +} diff --git a/include/opendht/securedht.h b/include/opendht/securedht.h index e09af77eb7a3a5ca8f70387861a27c00e9ff6f8e..1274825a388dda1d6716dad37b9a6e530c3c301c 100644 --- a/include/opendht/securedht.h +++ b/include/opendht/securedht.h @@ -207,6 +207,9 @@ public: NodeStatus getStatus() const override { return dht_->getStatus(); } + net::DatagramSocket* getSocket() const override { + return dht_->getSocket(); + }; bool isRunning(sa_family_t af = 0) const override { return dht_->isRunning(af); } diff --git a/include/opendht/sockaddr.h b/include/opendht/sockaddr.h index c622573aa58f698aafa0dcf40c7867ab5cce7d1b..143a39815e72f036bf750d6f4b980e6029f1f934 100644 --- a/include/opendht/sockaddr.h +++ b/include/opendht/sockaddr.h @@ -119,7 +119,7 @@ public: /** * Returns the address family or AF_UNSPEC if the address is not set. */ - sa_family_t getFamily() const { return len > sizeof(sa_family_t) ? addr->sa_family : AF_UNSPEC; } + sa_family_t getFamily() const { return len ? addr->sa_family : AF_UNSPEC; } /** * Resize the managed structure to the appropriate size (if needed), diff --git a/src/dht.cpp b/src/dht.cpp index 42f9e6426ca0f5ffdf67717fee853947cae32fd1..80a6640817230bae47a550c737172a92ffb2c0ee 100644 --- a/src/dht.cpp +++ b/src/dht.cpp @@ -1689,11 +1689,11 @@ Dht::~Dht() s.second->clear(); } -Dht::Dht() : store(), network_engine(DHT_LOG, scheduler) {} +Dht::Dht() : store(), network_engine(DHT_LOG, scheduler, {}) {} -Dht::Dht(const int& s, const int& s6, const Config& config, const Logger& l) - : DhtInterface(l), myid(config.node_id ? config.node_id : InfoHash::getRandom()), store(), store_quota(), - network_engine(myid, config.network, s, s6, DHT_LOG, scheduler, +Dht::Dht(std::unique_ptr<net::DatagramSocket>&& sock, const Config& config, const Logger& l) + : myid(config.node_id ? config.node_id : InfoHash::getRandom()), store(), store_quota(), + network_engine(myid, config.network, std::move(sock), DHT_LOG, scheduler, std::bind(&Dht::onError, this, _1, _2), std::bind(&Dht::onNewNode, this, _1, _2), std::bind(&Dht::onReportedAddr, this, _1, _2), @@ -1708,13 +1708,14 @@ Dht::Dht(const int& s, const int& s6, const Config& config, const Logger& l) maintain_storage(config.maintain_storage) { scheduler.syncTime(); - if (s < 0 && s6 < 0) - return; - if (s >= 0) { + auto s = network_engine.getSocket(); + if (not s or (not s->hasIPv4() and not s->hasIPv6())) + throw DhtException("Opened socket required"); + if (s->hasIPv4()) { buckets4 = {Bucket {AF_INET}}; buckets4.is_client = config.is_bootstrap; } - if (s6 >= 0) { + if (s->hasIPv6()) { buckets6 = {Bucket {AF_INET6}}; buckets6.is_client = config.is_bootstrap; } diff --git a/src/dhtrunner.cpp b/src/dhtrunner.cpp index 03dabe29fd0a483096a4511af67dc9936127a4e7..5fbd45e4664a503f5289de352568af7fe6425171 100644 --- a/src/dhtrunner.cpp +++ b/src/dhtrunner.cpp @@ -102,9 +102,21 @@ DhtRunner::run(const SockAddr& local4, const SockAddr& local6, const DhtRunner:: { if (running) return; - startNetwork(local4, local6); - auto dht = std::unique_ptr<DhtInterface>(new Dht(s4, s6, SecureDht::getConfig(config.dht_config), context.logger ? *context.logger : Logger{})); + auto sock = std::unique_ptr<net::UdpSocket>(new net::UdpSocket(local4, local6)); + sock->setOnReceive([&] (std::unique_ptr<net::ReceivedPacket>&& pkt) { + { + std::lock_guard<std::mutex> lck(sock_mtx); + if (rcv.size() >= RX_QUEUE_MAX_SIZE) { + std::cerr << "Dropping packet: queue is full!" << std::endl; + rcv.pop(); + } + rcv.emplace(std::move(pkt)); + } + cv.notify_all(); + }); + + auto dht = std::unique_ptr<DhtInterface>(new Dht(std::move(sock), SecureDht::getConfig(config.dht_config))); dht_ = std::unique_ptr<SecureDht>(new SecureDht(std::move(dht), config.dht_config)); #ifdef OPENDHT_PROXY_CLIENT @@ -126,15 +138,10 @@ DhtRunner::run(const SockAddr& local4, const SockAddr& local6, const DhtRunner:: running = true; if (not config.threaded) return; - dht_thread = std::thread([this, local4, local6]() { + dht_thread = std::thread([this]() { while (running) { std::unique_lock<std::mutex> lk(dht_mtx); - time_point wakeup; - try { - wakeup = loop_(); - } catch (const dht::SocketException& e) { - startNetwork(local4, local6); - } + time_point wakeup = loop_(); auto hasJobToDo = [this]() { if (not running) @@ -213,7 +220,9 @@ DhtRunner::shutdown(ShutdownCallback cb) { void DhtRunner::join() { - stopNetwork(); + if (dht_) + if (auto sock = dht_->getSocket()) + sock->stop(); running = false; cv.notify_all(); bootstrap_cv.notify_all(); @@ -223,8 +232,6 @@ DhtRunner::join() dht_thread.join(); if (bootstrap_thread.joinable()) bootstrap_thread.join(); - if (rcv_thread.joinable()) - rcv_thread.join(); if (peerDiscovery_) { peerDiscovery_->join(); @@ -243,6 +250,15 @@ DhtRunner::join() } } +SockAddr +DhtRunner::getBound(sa_family_t af) const { + std::lock_guard<std::mutex> lck(dht_mtx); + if (dht_) + if (auto sock = dht_->getSocket()) + return sock->getBound(af); + return SockAddr{}; +} + void DhtRunner::dumpTables() const { @@ -475,7 +491,7 @@ DhtRunner::loop_() size_t dropped {0}; if (not received.empty()) { auto now = clock::now(); - while (not received.empty() and now - received.front().received > RX_QUEUE_MAX_DELAY) { + while (not received.empty() and now - received.front()->received > RX_QUEUE_MAX_DELAY) { received.pop(); dropped++; } @@ -485,10 +501,10 @@ DhtRunner::loop_() if (not received.empty()) { while (not received.empty()) { auto& pck = received.front(); - if (clock::now() - pck.received > RX_QUEUE_MAX_DELAY) + if (clock::now() - pck->received > RX_QUEUE_MAX_DELAY) dropped++; else - wakeup = dht->periodic(pck.data.data(), pck.data.size(), pck.from); + wakeup = dht->periodic(pck->data.data(), pck->data.size(), pck->from); received.pop(); } } else { @@ -497,7 +513,7 @@ DhtRunner::loop_() } if (dropped) - std::cerr << "Dropped %zu packets with high delay" << dropped << std::endl; + std::cerr << "Dropped " << dropped << " packets with high delay" << std::endl; NodeStatus nstatus4 = dht->getStatus(AF_INET); NodeStatus nstatus6 = dht->getStatus(AF_INET6); @@ -520,167 +536,6 @@ DhtRunner::loop_() return wakeup; } - -int bindSocket(const SockAddr& addr, SockAddr& bound) -{ - bool is_ipv6 = addr.getFamily() == AF_INET6; - int sock = socket(is_ipv6 ? PF_INET6 : PF_INET, SOCK_DGRAM, 0); - if (sock < 0) - throw DhtException(std::string("Can't open socket: ") + strerror(sock)); - int set = 1; -#ifdef SO_NOSIGPIPE - setsockopt(sock, SOL_SOCKET, SO_NOSIGPIPE, (const char*)&set, sizeof(set)); -#endif - if (is_ipv6) - setsockopt(sock, IPPROTO_IPV6, IPV6_V6ONLY, (const char*)&set, sizeof(set)); - net::set_nonblocking(sock); - int rc = bind(sock, addr.get(), addr.getLength()); - if(rc < 0) { - rc = errno; - close(sock); - throw DhtException("Can't bind socket on " + addr.toString() + " " + strerror(rc)); - } - sockaddr_storage ss; - socklen_t ss_len = sizeof(ss); - getsockname(sock, (sockaddr*)&ss, &ss_len); - bound = {ss, ss_len}; - return sock; -} - -void -DhtRunner::stopNetwork() -{ - running_network = false; - if (stop_writefd != -1) { - if (write(stop_writefd, "\0", 1) == -1) { - perror("write"); - } - } -} - -void -DhtRunner::startNetwork(const SockAddr sin4, const SockAddr sin6) -{ - stopNetwork(); - if (rcv_thread.joinable()) - rcv_thread.join(); - - int stopfds[2]; -#ifndef _WIN32 - auto status = pipe(stopfds); - if (status == -1) { - throw DhtException(std::string("Can't open pipe: ") + strerror(errno)); - } -#else - net::udpPipe(stopfds); -#endif - int stop_readfd = stopfds[0]; - stop_writefd = stopfds[1]; - - s4 = -1; - s6 = -1; - - bound4 = {}; - if (sin4) { - try { - s4 = bindSocket(sin4, bound4); - } catch (const DhtException& e) { - std::cerr << "Can't bind inet socket: " << e.what() << std::endl; - } - } - -#if 1 - bound6 = {}; - if (sin6) { - try { - s6 = bindSocket(sin6, bound6); - } catch (const DhtException& e) { - std::cerr << "Can't bind inet6 socket: " << e.what() << std::endl; - } - } -#endif - - if (s4 == -1 && s6 == -1) { - throw DhtException("Can't bind socket"); - } - - running_network = true; - rcv_thread = std::thread([this, stop_readfd]() { - try { - while (running_network) { - fd_set readfds; - - FD_ZERO(&readfds); - FD_SET(stop_readfd, &readfds); - if(s4 >= 0) - FD_SET(s4, &readfds); - if(s6 >= 0) - FD_SET(s6, &readfds); - - int selectFd = std::max({s4, s6, stop_readfd}) + 1; - int rc = select(selectFd, &readfds, nullptr, nullptr, nullptr); - if(rc < 0) { - if(errno != EINTR) { - perror("select"); - std::this_thread::sleep_for( std::chrono::seconds(1) ); - } - } - - if (not running_network) - break; - - if (rc > 0) { - std::array<uint8_t, 1024 * 64> buf; - sockaddr_storage from; - socklen_t from_len = sizeof(from); - - if (FD_ISSET(stop_readfd, &readfds)) { - if (recv(stop_readfd, (char*)buf.data(), buf.size(), 0) < 0) { - std::cerr << "Got stop packet error: " << strerror(errno) << std::endl; - break; - } - } - else if (s4 >= 0 && FD_ISSET(s4, &readfds)) - rc = recvfrom(s4, (char*)buf.data(), buf.size(), 0, (sockaddr*)&from, &from_len); - else if (s6 >= 0 && FD_ISSET(s6, &readfds)) - rc = recvfrom(s6, (char*)buf.data(), buf.size(), 0, (sockaddr*)&from, &from_len); - else - continue; - - if (rc > 0) { - { - std::lock_guard<std::mutex> lck(sock_mtx); - if (rcv.size() >= RX_QUEUE_MAX_SIZE) { - std::cerr << "Dropping packet: queue is full!" << std::endl; - rcv.pop(); - } - rcv.emplace(ReceivedPacket {Blob {buf.begin(), buf.begin()+rc}, SockAddr(from, from_len), clock::now()}); - } - cv.notify_all(); - } else if (rc == -1) { - std::cerr << "Error receiving packet: " << strerror(errno) << std::endl; - } - } - } - } catch (const std::exception& e) { - std::cerr << "Error in DHT networking thread: " << e.what() << std::endl; - } - if (s4 >= 0) - close(s4); - if (s6 >= 0) - close(s6); - s4 = -1; - s6 = -1; - bound4 = {}; - bound6 = {}; - if (stop_readfd != -1) - close(stop_readfd); - if (stop_writefd != -1) - close(stop_writefd); - stop_writefd = -1; - }); -} - void DhtRunner::get(InfoHash hash, GetCallback vcb, DoneCallback dcb, Value::Filter f, Where w) { diff --git a/src/network_engine.cpp b/src/network_engine.cpp index 4f3e98bf9c137871fcb98b0357b1298a2e926562..269dafc163857f78dfe96115487eda356641e06d 100644 --- a/src/network_engine.cpp +++ b/src/network_engine.cpp @@ -99,11 +99,11 @@ RequestAnswer::RequestAnswer(ParsedMessage&& msg) nodes6(std::move(msg.nodes6)) {} -NetworkEngine::NetworkEngine(Logger& log, Scheduler& scheduler, const int& s, const int& s6) - : myid(zeroes), dht_socket(s), dht_socket6(s6), DHT_LOG(log), scheduler(scheduler) +NetworkEngine::NetworkEngine(Logger& log, Scheduler& scheduler, std::unique_ptr<DatagramSocket>&& sock) + : myid(zeroes), dht_socket(std::move(sock)), DHT_LOG(log), scheduler(scheduler) {} -NetworkEngine::NetworkEngine(InfoHash& myid, NetId net, const int& s, const int& s6, Logger& log, Scheduler& scheduler, +NetworkEngine::NetworkEngine(InfoHash& myid, NetId net, std::unique_ptr<DatagramSocket>&& sock, Logger& log, Scheduler& scheduler, decltype(NetworkEngine::onError)&& onError, decltype(NetworkEngine::onNewNode)&& onNewNode, decltype(NetworkEngine::onReportedAddr)&& onReportedAddr, @@ -122,7 +122,7 @@ NetworkEngine::NetworkEngine(InfoHash& myid, NetId net, const int& s, const int& onListen(std::move(onListen)), onAnnounce(std::move(onAnnounce)), onRefresh(std::move(onRefresh)), - myid(myid), network(net), dht_socket(s), dht_socket6(s6), DHT_LOG(log), scheduler(scheduler) + myid(myid), network(net), dht_socket(std::move(sock)), DHT_LOG(log), scheduler(scheduler) {} NetworkEngine::~NetworkEngine() { @@ -170,7 +170,7 @@ NetworkEngine::tellListenerRefreshed(Sp<Node> n, Tid socket_id, const InfoHash&, } // send response - send(buffer.data(), buffer.size(), 0, n->getAddr()); + send(n->getAddr(), buffer.data(), buffer.size()); } void @@ -200,7 +200,7 @@ NetworkEngine::tellListenerExpired(Sp<Node> n, Tid socket_id, const InfoHash&, c } // send response - send(buffer.data(), buffer.size(), 0, n->getAddr()); + send(n->getAddr(), buffer.data(), buffer.size()); } @@ -209,11 +209,11 @@ NetworkEngine::isRunning(sa_family_t af) const { switch (af) { case 0: - return dht_socket >= 0 || dht_socket6 >= 0; + return dht_socket->hasIPv4() or dht_socket->hasIPv6(); case AF_INET: - return dht_socket >= 0; + return dht_socket->hasIPv4(); case AF_INET6: - return dht_socket6 >= 0; + return dht_socket->hasIPv6(); default: return false; } @@ -254,9 +254,7 @@ NetworkEngine::requestStep(Sp<Request> sreq) req.on_expired(req, false); } - auto err = send((char*)req.msg.data(), req.msg.size(), - (node.getReplyTime() >= now - UDP_REPLY_TIME) ? 0 : MSG_CONFIRM, - node.getAddr()); + auto err = send(node.getAddr(), (char*)req.msg.data(), req.msg.size(), node.getReplyTime() < now - UDP_REPLY_TIME); if (err == ENETUNREACH || err == EHOSTUNREACH || err == EAFNOSUPPORT || @@ -610,33 +608,9 @@ insertAddr(msgpack::packer<msgpack::sbuffer>& pk, const SockAddr& addr) } int -NetworkEngine::send(const char *buf, size_t len, int flags, const SockAddr& addr) +NetworkEngine::send(const SockAddr& addr, const char *buf, size_t len, bool confirmed) { - if (not addr) - return EFAULT; - - int s; - if (addr.getFamily() == AF_INET) - s = dht_socket; - else if (addr.getFamily() == AF_INET6) - s = dht_socket6; - else - s = -1; - - if (s < 0) - return EAFNOSUPPORT; -#ifdef MSG_NOSIGNAL - flags |= MSG_NOSIGNAL; -#endif - if (sendto(s, buf, len, flags, addr.get(), addr.getLength()) == -1) { - int err = errno; - DHT_LOG.e("Can't send message to %s: %s", addr.toString().c_str(), strerror(err)); - if (err == EPIPE) { - throw SocketException(EPIPE); - } - return err; - } - return 0; + return dht_socket ? dht_socket->sendTo(addr, (const uint8_t*)buf, len, confirmed) : ENOTCONN; } Sp<Request> @@ -696,7 +670,7 @@ NetworkEngine::sendPong(const SockAddr& addr, Tid tid) { pk.pack(KEY_NETID); pk.pack(network); } - send(buffer.data(), buffer.size(), 0, addr); + send(addr, buffer.data(), buffer.size()); } Sp<Request> @@ -898,7 +872,7 @@ NetworkEngine::sendValueParts(const TransId& tid, const std::vector<Blob>& svals pk.pack(std::string("o")); pk.pack(start); pk.pack(std::string("d")); pk.pack_bin(end-start); pk.pack_bin_body((const char*)v.data()+start, end-start); - send(buffer.data(), buffer.size(), 0, addr); + send(addr, buffer.data(), buffer.size()); start = end; } while (start != v.size()); i++; @@ -957,7 +931,7 @@ NetworkEngine::sendNodesValues(const SockAddr& addr, Tid tid, const Blob& nodes, } // send response - send(buffer.data(), buffer.size(), 0, addr); + send(addr, buffer.data(), buffer.size()); // send parts if (not svals.empty()) @@ -1100,7 +1074,7 @@ NetworkEngine::sendListenConfirmation(const SockAddr& addr, Tid tid) { pk.pack(KEY_NETID); pk.pack(network); } - send(buffer.data(), buffer.size(), 0, addr); + send(addr, buffer.data(), buffer.size()); } Sp<Request> @@ -1234,7 +1208,7 @@ NetworkEngine::sendValueAnnounced(const SockAddr& addr, Tid tid, Value::Id vid) pk.pack(KEY_NETID); pk.pack(network); } - send(buffer.data(), buffer.size(), 0, addr); + send(addr, buffer.data(), buffer.size()); } void @@ -1266,7 +1240,7 @@ NetworkEngine::sendError(const SockAddr& addr, pk.pack(KEY_NETID); pk.pack(network); } - send(buffer.data(), buffer.size(), 0, addr); + send(addr, buffer.data(), buffer.size()); } void diff --git a/src/network_utils.cpp b/src/network_utils.cpp index 59ab926fb4baffbc43cf19eda2b26bd018811ba3..773225950d30bb3820bdfd8ee7a3e00dd89b42d2 100644 --- a/src/network_utils.cpp +++ b/src/network_utils.cpp @@ -20,7 +20,6 @@ #ifdef _WIN32 #include "utils.h" -#include "sockaddr.h" #include <io.h> #include <string> #include <cstring> @@ -30,11 +29,40 @@ #include <fcntl.h> #endif +#include <iostream> + namespace dht { namespace net { +int +bindSocket(const SockAddr& addr, SockAddr& bound) +{ + bool is_ipv6 = addr.getFamily() == AF_INET6; + int sock = socket(is_ipv6 ? PF_INET6 : PF_INET, SOCK_DGRAM, 0); + if (sock < 0) + throw DhtException(std::string("Can't open socket: ") + strerror(sock)); + int set = 1; +#ifdef SO_NOSIGPIPE + setsockopt(sock, SOL_SOCKET, SO_NOSIGPIPE, (const char*)&set, sizeof(set)); +#endif + if (is_ipv6) + setsockopt(sock, IPPROTO_IPV6, IPV6_V6ONLY, (const char*)&set, sizeof(set)); + net::setNonblocking(sock); + int rc = bind(sock, addr.get(), addr.getLength()); + if (rc < 0) { + rc = errno; + close(sock); + throw DhtException("Can't bind socket on " + addr.toString() + " " + strerror(rc)); + } + sockaddr_storage ss; + socklen_t ss_len = sizeof(ss); + getsockname(sock, (sockaddr*)&ss, &ss_len); + bound = {ss, ss_len}; + return sock; +} + bool -set_nonblocking(int fd, bool nonblocking) +setNonblocking(int fd, bool nonblocking) { #ifdef _WIN32 unsigned long mode = !!nonblocking; @@ -77,5 +105,192 @@ void udpPipe(int fds[2]) } #endif +UdpSocket::UdpSocket(in_port_t port) { + SockAddr bind4; + bind4.setFamily(AF_INET); + bind4.setPort(port); + SockAddr bind6; + bind6.setFamily(AF_INET6); + bind6.setPort(port); + openSockets(bind4, bind6); +} + +UdpSocket::UdpSocket(const SockAddr& bind4, const SockAddr& bind6) { + openSockets(bind4, bind6); +} + +UdpSocket::~UdpSocket() { + stop(); + if (rcv_thread.joinable()) + rcv_thread.join(); +} + +int +UdpSocket::sendTo(const SockAddr& dest, const uint8_t* data, size_t size, bool replied) { + if (not dest) + return EFAULT; + + int s; + switch (dest.getFamily()) { + case AF_INET: s = s4; break; + case AF_INET6: s = s6; break; + default: s = -1; break; + } + + if (s < 0) + return EAFNOSUPPORT; + + int flags = 0; +#ifdef MSG_CONFIRM + if (replied) + flags |= MSG_CONFIRM; +#endif +#ifdef MSG_NOSIGNAL + flags |= MSG_NOSIGNAL; +#endif + + if (sendto(s, data, size, flags, dest.get(), dest.getLength()) == -1) { + int err = errno; + std::cerr << "Can't send message to " << dest.toString() << ": " << strerror(err) << std::endl; + if (err == EPIPE) { + auto bind4 = std::move(bound4), bind6 = std::move(bound6); + openSockets(bind4, bind6); + return sendTo(dest, data, size, false); + } + return err; + } + return 0; +} + +void +UdpSocket::openSockets(const SockAddr& bind4, const SockAddr& bind6) +{ + stop(); + if (rcv_thread.joinable()) + rcv_thread.join(); + + int stopfds[2]; +#ifndef _WIN32 + auto status = pipe(stopfds); + if (status == -1) { + throw DhtException(std::string("Can't open pipe: ") + strerror(errno)); + } +#else + udpPipe(stopfds); +#endif + int stop_readfd = stopfds[0]; + stopfd = stopfds[1]; + + s4 = -1; + s6 = -1; + + bound4 = {}; + if (bind4) { + try { + s4 = bindSocket(bind4, bound4); + } catch (const DhtException& e) { + std::cerr << "Can't bind inet socket: " << e.what() << std::endl; + } + } + +#if 1 + bound6 = {}; + if (bind6) { + try { + s6 = bindSocket(bind6, bound6); + } catch (const DhtException& e) { + std::cerr << "Can't bind inet6 socket: " << e.what() << std::endl; + } + } +#endif + + if (s4 == -1 && s6 == -1) { + throw DhtException("Can't bind socket"); + } + + running = true; + rcv_thread = std::thread([this, stop_readfd]() { + try { + while (running) { + fd_set readfds; + + FD_ZERO(&readfds); + FD_SET(stop_readfd, &readfds); + if(s4 >= 0) + FD_SET(s4, &readfds); + if(s6 >= 0) + FD_SET(s6, &readfds); + + int selectFd = std::max({s4, s6, stop_readfd}) + 1; + int rc = select(selectFd, &readfds, nullptr, nullptr, nullptr); + if (rc < 0) { + if (errno != EINTR) { + perror("select"); + std::this_thread::sleep_for( std::chrono::seconds(1) ); + } + } + + if (not running) + break; + + if (rc > 0) { + std::array<uint8_t, 1024 * 64> buf; + sockaddr_storage from; + socklen_t from_len = sizeof(from); + + if (FD_ISSET(stop_readfd, &readfds)) { + if (recv(stop_readfd, (char*)buf.data(), buf.size(), 0) < 0) { + std::cerr << "Got stop packet error: " << strerror(errno) << std::endl; + break; + } + } + else if (s4 >= 0 && FD_ISSET(s4, &readfds)) + rc = recvfrom(s4, (char*)buf.data(), buf.size(), 0, (sockaddr*)&from, &from_len); + else if (s6 >= 0 && FD_ISSET(s6, &readfds)) + rc = recvfrom(s6, (char*)buf.data(), buf.size(), 0, (sockaddr*)&from, &from_len); + else + continue; + + if (rc > 0) { + auto pkt = std::unique_ptr<ReceivedPacket>(new ReceivedPacket); + pkt->data = {buf.begin(), buf.begin()+rc}; + pkt->from = {from, from_len}; + pkt->received = clock::now(); + onReceived(std::move(pkt)); + } else if (rc == -1) { + std::cerr << "Error receiving packet: " << strerror(errno) << std::endl; + } + } + } + } catch (const std::exception& e) { + std::cerr << "Error in DHT networking thread: " << e.what() << std::endl; + } + if (s4 >= 0) + close(s4); + if (s6 >= 0) + close(s6); + s4 = -1; + s6 = -1; + bound4 = {}; + bound6 = {}; + if (stop_readfd != -1) + close(stop_readfd); + if (stopfd != -1) + close(stopfd); + stopfd = -1; + }); +} + +void +UdpSocket::stop() +{ + if (running.exchange(false)) { + auto sfd = stopfd; + if (sfd != -1 && write(sfd, "\0", 1) == -1) { + std::cerr << "can't write to stop fd" << std::endl; + } + } +} + } } diff --git a/src/network_utils.h b/src/network_utils.h deleted file mode 100644 index bc63904cbc04d64ff5775b90d1f7ae830d70484d..0000000000000000000000000000000000000000 --- a/src/network_utils.h +++ /dev/null @@ -1,43 +0,0 @@ -/* - * Copyright (C) 2019 Savoir-faire Linux Inc. - * Author(s) : Adrien Béraud <adrien.beraud@savoirfairelinux.com> - * - * This program is free software; you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation; either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see <https://www.gnu.org/licenses/>. - */ -#pragma once - -#ifdef _WIN32 -#include <ws2tcpip.h> -#include <winsock2.h> -#else -#include <sys/socket.h> -#include <netinet/in.h> -#include <unistd.h> -#endif - -#ifndef IPV6_JOIN_GROUP -#define IPV6_JOIN_GROUP IPV6_ADD_MEMBERSHIP -#endif - -namespace dht { -namespace net { - -bool set_nonblocking(int fd, bool nonblocking = true); - -#ifdef _WIN32 -void udpPipe(int fds[2]); -#endif - -} -} diff --git a/src/peer_discovery.cpp b/src/peer_discovery.cpp index 821ba41513f4b787b3a0b6d2b5e28bf56879ffb9..6341c250f674526d04ce08c3bc4dfaaafa9a1ce8 100644 --- a/src/peer_discovery.cpp +++ b/src/peer_discovery.cpp @@ -35,6 +35,10 @@ typedef SSIZE_T ssize_t; #endif #include <fcntl.h> +#ifndef IPV6_JOIN_GROUP +#define IPV6_JOIN_GROUP IPV6_ADD_MEMBERSHIP +#endif + namespace dht { // Organization-local Scope multicast @@ -175,7 +179,7 @@ PeerDiscovery::DomainPeerDiscovery::initialize_socket(sa_family_t domain) if (sockfd < 0) { throw std::runtime_error(std::string("Socket Creation Error: ") + strerror(errno)); } - net::set_nonblocking(sockfd); + net::setNonblocking(sockfd); return sockfd; } @@ -502,7 +506,6 @@ PeerDiscovery::PeerDiscovery(in_port_t port) } catch(const std::exception& e) { std::cerr << "Can't start peer discovery (IPv6): " << e.what() << std::endl; } - } PeerDiscovery::~PeerDiscovery(){} diff --git a/src/utils.cpp b/src/utils.cpp index e267961e6370503163e2db4a798964b1330be285..9b7aabbbc2d4720b761d07e4bf7da2d993daf23f 100644 --- a/src/utils.cpp +++ b/src/utils.cpp @@ -97,9 +97,11 @@ SockAddr::setAddress(const char* address) case AF_INET6: addr = &getIPv6().sin6_addr; break; + default: + throw std::runtime_error("Unknown address family"); } - if (not addr or inet_pton(family, address, addr) <= 0) - throw std::runtime_error("Can't parse IP address"); + if (inet_pton(family, address, addr) <= 0) + throw std::runtime_error(std::string("Can't parse IP address: ") + strerror(errno)); } std::string