From 151d31411a6d236d956d8a1954af6054b4318da7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrien=20B=C3=A9raud?= <adrien.beraud@savoirfairelinux.com> Date: Thu, 20 Nov 2014 01:42:34 -0500 Subject: [PATCH] Add initial code --- Makefile.am | 12 + autogen.sh | 5 + configure.ac | 20 + dhtcpp.pc.in | 9 + include/dht.h | 37 + include/dhtcpp/crypto.h | 139 ++ include/dhtcpp/dht.h | 594 ++++++++ include/dhtcpp/dhtrunner.h | 165 +++ include/dhtcpp/infohash.h | 159 ++ include/dhtcpp/securedht.h | 158 ++ include/dhtcpp/serialize.h | 316 ++++ include/dhtcpp/value.h | 328 +++++ m4/ax_cxx_compile_stdcxx_11.m4 | 135 ++ src/Makefile.am | 22 + src/crypto.cpp | 404 +++++ src/dht.cpp | 2532 ++++++++++++++++++++++++++++++++ src/dhtrunner.cpp | 329 +++++ src/infohash.cpp | 82 ++ src/securedht.cpp | 292 ++++ src/value.cpp | 221 +++ 20 files changed, 5959 insertions(+) create mode 100644 Makefile.am create mode 100755 autogen.sh create mode 100644 configure.ac create mode 100644 dhtcpp.pc.in create mode 100644 include/dht.h create mode 100644 include/dhtcpp/crypto.h create mode 100644 include/dhtcpp/dht.h create mode 100644 include/dhtcpp/dhtrunner.h create mode 100644 include/dhtcpp/infohash.h create mode 100644 include/dhtcpp/securedht.h create mode 100644 include/dhtcpp/serialize.h create mode 100644 include/dhtcpp/value.h create mode 100644 m4/ax_cxx_compile_stdcxx_11.m4 create mode 100644 src/Makefile.am create mode 100644 src/crypto.cpp create mode 100644 src/dht.cpp create mode 100644 src/dhtrunner.cpp create mode 100644 src/infohash.cpp create mode 100644 src/securedht.cpp create mode 100644 src/value.cpp diff --git a/Makefile.am b/Makefile.am new file mode 100644 index 00000000..cce6cbcf --- /dev/null +++ b/Makefile.am @@ -0,0 +1,12 @@ +SUBDIRS = src +ACLOCAL_AMFLAGS = -I m4 + +DOC_FILES = \ + README.md \ + LICENSE + +EXTRA_DIST = \ + $(DOC_FILES) + +pkgconfigdir = $(libdir)/pkgconfig +pkgconfig_DATA = dhtcpp.pc diff --git a/autogen.sh b/autogen.sh new file mode 100755 index 00000000..34658210 --- /dev/null +++ b/autogen.sh @@ -0,0 +1,5 @@ +test -f AUTHORS || touch AUTHORS +test -f ChangeLog || touch ChangeLog +test -f NEWS || touch NEWS +test -f README || cp -f README.md README +autoreconf --install --verbose -Wall \ No newline at end of file diff --git a/configure.ac b/configure.ac new file mode 100644 index 00000000..44ea0ba9 --- /dev/null +++ b/configure.ac @@ -0,0 +1,20 @@ +AC_INIT(dhtcpp, 0.1) +AC_CONFIG_AUX_DIR(ac) +AM_INIT_AUTOMAKE +AC_CONFIG_HEADERS([config.h]) +AC_CONFIG_MACRO_DIR([m4]) + +AC_PROG_CC +AC_PROG_CXX +AM_PROG_AR + +LT_INIT() +LT_LANG(C++) + +AX_CXX_COMPILE_STDCXX_11([noext],[mandatory]) + +PKG_PROG_PKG_CONFIG() +PKG_CHECK_MODULES([GNUTLS], [gnutls >= 3.1]) + +AC_CONFIG_FILES([Makefile src/Makefile dhtcpp.pc]) +AC_OUTPUT \ No newline at end of file diff --git a/dhtcpp.pc.in b/dhtcpp.pc.in new file mode 100644 index 00000000..a4381383 --- /dev/null +++ b/dhtcpp.pc.in @@ -0,0 +1,9 @@ +prefix=@prefix@ +exec_prefix=@exec_prefix@ +libdir=@libdir@ +includedir=@includedir@ +Name: DhtCpp +Description: Lightweight C++11 Distributed Hash Table library +Version: @VERSION@ +Libs: -L${libdir} -ldhtcpp +Cflags: -I${includedir} \ No newline at end of file diff --git a/include/dht.h b/include/dht.h new file mode 100644 index 00000000..d1ed824b --- /dev/null +++ b/include/dht.h @@ -0,0 +1,37 @@ +/* + * Copyright (C) 2014 Savoir-Faire Linux Inc. + * Author : Adrien Béraud <adrien.beraud@savoirfairelinux.com> + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Additional permission under GNU GPL version 3 section 7: + * + * If you modify this program, or any covered work, by linking or + * combining it with the OpenSSL project's OpenSSL library (or a + * modified version of that library), containing parts covered by the + * terms of the OpenSSL or SSLeay licenses, Savoir-Faire Linux Inc. + * grants you additional permission to convey the resulting work. + * Corresponding Source for a non-source form of such a combination + * shall include the source code for the parts of OpenSSL used as well + * as that of the covered work. + */ + +#pragma once + +#include "dhtcpp/dht.h" +#include "dhtcpp/value.h" +#include "dhtcpp/infohash.h" +#include "dhtcpp/securedht.h" +#include "dhtcpp/dhtrunner.h" diff --git a/include/dhtcpp/crypto.h b/include/dhtcpp/crypto.h new file mode 100644 index 00000000..5b0606f2 --- /dev/null +++ b/include/dhtcpp/crypto.h @@ -0,0 +1,139 @@ +/* + * Copyright (C) 2014 Savoir-Faire Linux Inc. + * Author : Adrien Béraud <adrien.beraud@savoirfairelinux.com> + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Additional permission under GNU GPL version 3 section 7: + * + * If you modify this program, or any covered work, by linking or + * combining it with the OpenSSL project's OpenSSL library (or a + * modified version of that library), containing parts covered by the + * terms of the OpenSSL or SSLeay licenses, Savoir-Faire Linux Inc. + * grants you additional permission to convey the resulting work. + * Corresponding Source for a non-source form of such a combination + * shall include the source code for the parts of OpenSSL used as well + * as that of the covered work. + */ + +#pragma once + +#include "serialize.h" + +extern "C" { +#include <gnutls/gnutls.h> +#include <gnutls/abstract.h> +#include <gnutls/x509.h> +} + +#include <vector> +#include <memory> + +namespace dht { +namespace crypto { + +struct PrivateKey; +struct Certificate; + +typedef std::pair<std::shared_ptr<PrivateKey>, std::shared_ptr<Certificate>> Identity; + +/** + * Generate an RSA key pair (2048 bits) and a certificate. + * If a certificate authority (ca) is provided, it will be used to + * sign the certificate, otherwise the certificate will be self-signed. + */ +Identity generateIdentity(const std::string& name = "dhtnode", Identity ca = {}); + +struct PublicKey : public Serializable +{ + PublicKey() {} + PublicKey(gnutls_pubkey_t k) : pk(k) {} + PublicKey(const Blob& pk); + PublicKey(PublicKey&& o) noexcept : pk(o.pk) { o.pk = nullptr; }; + + ~PublicKey(); + operator bool() const { return pk; } + + PublicKey& operator=(PublicKey&& o) noexcept; + + InfoHash getId() const; + bool checkSignature(const Blob& data, const Blob& signature) const; + Blob encrypt(const Blob&) const; + + void pack(Blob& b) const override; + + void unpack(Blob::const_iterator& begin, Blob::const_iterator& end) override; + + gnutls_pubkey_t pk {}; +private: + PublicKey(const PublicKey&) = delete; + PublicKey& operator=(const PublicKey&) = delete; +}; + +struct PrivateKey +{ + PrivateKey() {} + //PrivateKey(gnutls_privkey_t k) : key(k) {} + PrivateKey(gnutls_x509_privkey_t k); + PrivateKey(PrivateKey&& o) noexcept : key(o.key), x509_key(o.x509_key) + { o.key = nullptr; o.x509_key = nullptr; }; + PrivateKey& operator=(PrivateKey&& o) noexcept; + + PrivateKey(const Blob& import); + ~PrivateKey(); + operator bool() const { return key; } + PublicKey getPublicKey() const; + Blob serialize() const; + Blob sign(const Blob&) const; + Blob decrypt(const Blob& cypher) const; + + /** + * Generate a new RSA key pair + */ + static PrivateKey generate(); + +private: + PrivateKey(const PrivateKey&) = delete; + PrivateKey& operator=(const PrivateKey&) = delete; + gnutls_privkey_t key {}; + gnutls_x509_privkey_t x509_key {}; + + friend dht::crypto::Identity dht::crypto::generateIdentity(const std::string&, dht::crypto::Identity); +}; + +struct Certificate : public Serializable { + Certificate() {} + Certificate(gnutls_x509_crt_t crt) : cert(crt) {} + Certificate(const Blob& crt); + Certificate(Certificate&& o) noexcept : cert(o.cert) { o.cert = nullptr; }; + Certificate& operator=(Certificate&& o) noexcept; + + ~Certificate(); + operator bool() const { return cert; } + PublicKey getPublicKey() const; + void pack(Blob& b) const override; + void unpack(Blob::const_iterator& begin, Blob::const_iterator& end) override; + +private: + Certificate(const Certificate&) = delete; + Certificate& operator=(const Certificate&) = delete; + gnutls_x509_crt_t cert {}; + + friend dht::crypto::Identity dht::crypto::generateIdentity(const std::string&, dht::crypto::Identity); +}; + + +} +} diff --git a/include/dhtcpp/dht.h b/include/dhtcpp/dht.h new file mode 100644 index 00000000..ecc3d2eb --- /dev/null +++ b/include/dhtcpp/dht.h @@ -0,0 +1,594 @@ +/* +Copyright (c) 2009-2014 Juliusz Chroboczek +Copyright (c) 2014 Savoir-Faire Linux Inc. + +Authors : Adrien Béraud <adrien.beraud@savoirfairelinux.com>, + Juliusz Chroboczek <jch@pps.univ–paris–diderot.fr> + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ + +#pragma once + +#include "infohash.h" +#include "value.h" + +#include <sys/socket.h> +#include <netinet/in.h> +#include <netdb.h> + +#include <string> +#include <array> +#include <vector> +#include <map> +#include <list> +#include <functional> +#include <algorithm> +#include <memory> + +namespace dht { + +/** + * Main Dht class. + * Provides a Distributed Hash Table node. + * + * Must be given open UDP sockets and ::periodic must be + * called regularly. + */ +class Dht { +public: + + enum class Status { + Disconnected, // 0 nodes + Connecting, // 1+ nodes + Connected // 4+ good nodes + }; + + typedef std::function<bool(const std::vector<std::shared_ptr<Value>>& values)> GetCallback; + typedef std::function<void(bool success)> DoneCallback; + + struct NodeExport { + InfoHash id; + sockaddr_storage ss; + socklen_t sslen; + }; + + Dht() {} + + /** + * Initialise the Dht with two open sockets (for IPv4 and IP6) + * and an ID for the node. + */ + Dht(int s, int s6, const InfoHash& id); + virtual ~Dht(); + + /** + * Get the ID of the node, which was provided in the constructor. + */ + inline const InfoHash& getId() const { return myid; } + + /** + * Get the current status of the node for the given family. + */ + Status getStatus(sa_family_t af) const; + + /** + * Returns true if the node have access to an open socket + * for the provided family. + */ + bool isRunning(sa_family_t af) const; + + /** + * Enable or disable logging of DHT internal messages + */ + void setLoggers(LogMethod&& error = NOLOG, LogMethod&& warn = NOLOG, LogMethod&& debug = NOLOG); + + virtual void registerType(const ValueType& type) { + types[type.id] = type; + } + const ValueType& getType(ValueType::Id type_id) const { + const auto& t_it = types.find(type_id); + return (t_it == types.end()) ? ValueType::USER_DATA : t_it->second; + } + + /** + * Insert a node in the main routing table. + * The node is not pinged, so this should be + * used to bootstrap efficiently from previously known nodes. + */ + bool insertNode(const InfoHash& id, const sockaddr*, socklen_t); + bool insertNode(const NodeExport& n) { + return insertNode(n.id, reinterpret_cast<const sockaddr*>(&n.ss), n.sslen); + } + + int pingNode(const sockaddr*, socklen_t); + + void periodic(const uint8_t *buf, size_t buflen, const sockaddr *from, socklen_t fromlen, time_t *tosleep); + + /** + * Get a value by searching on all available protocols (IPv4, IPv6), + * and call the callback when some values are found. + * The operation will start as soon as the node is connected to the network. + * GetCallback will be called every time new values are found, until + * GetCallback returns false or the search completes. + * Then, DoneCallback is called. + */ + void get(const InfoHash& id, GetCallback cb, DoneCallback donecb=nullptr, Value::Filter = Value::AllFilter()); + + /** + * Get locally stored data for the given hash. + */ + std::vector<std::shared_ptr<Value>> getLocal(const InfoHash& id, Value::Filter f = Value::AllFilter()) const; + + /** + * Get locally stored data for the given hash and value id. + */ + std::shared_ptr<Value> getLocal(const InfoHash& id, const Value::Id& vid) const; + + /** + * Announce a value on all available protocols (IPv4, IPv6), and + * automatically re-announce when it's about to expire. + * The operation will start as soon as the node is connected to the network. + * The done callback will be called once, when the first announce succeeds, or fails. + * + * A "put" operation will never end by itself because the value will need to be + * reannounced on a regular basis. + * User can call #cancelPut(InfoHash, Value::Id) to cancel a put operation. + */ + void put(const InfoHash&, Value&&, DoneCallback cb=nullptr); + + /** + * Get data currently being put at the given hash. + */ + std::vector<std::shared_ptr<Value>> getPut(const InfoHash&); + + /** + * Get data currently being put at the given hash with the given id. + */ + std::shared_ptr<Value> getPut(const InfoHash&, const Value::Id&); + + /** + * Stop any put/announce operation at the given location, + * for values with the given id. + */ + bool cancelPut(const InfoHash&, const Value::Id&); + + /** + * Get the list of good nodes for local storage saving purposes + * The list is ordered to minimize the back-to-work delay. + */ + std::vector<NodeExport> exportNodes(); + + typedef std::pair<InfoHash, Blob> ValuesExport; + std::vector<ValuesExport> exportValues() const; + void importValues(const std::vector<ValuesExport>&); + + int getNodesStats(sa_family_t af, unsigned *good_return, unsigned *dubious_return, unsigned *cached_return, unsigned *incoming_return) const; + void dumpTables() const; + + /* This must be provided by the user. */ + static bool isBlacklisted(const sockaddr*, socklen_t) { return false; } + +protected: + LogMethod DHT_DEBUG = NOLOG; + LogMethod DHT_WARN = NOLOG; + LogMethod DHT_ERROR = NOLOG; + +private: + + /* When performing a search, we search for up to SEARCH_NODES closest nodes + to the destination, and use the additional ones to backtrack if any of + the target 8 turn out to be dead. */ + static const unsigned SEARCH_NODES {14}; + + /* The maximum number of values we store for a given hash. */ + static const unsigned MAX_VALUES {2048}; + + /* The maximum number of hashes we're willing to track. */ + static const unsigned MAX_HASHES {16384}; + + /* The maximum number of searches we keep data about. */ + static const unsigned MAX_SEARCHES {1024}; + + /* A search with no nodes will timeout after this time. */ + static const time_t SEARCH_TIMEOUT {60}; + + /* The time after which we can send get requests for + a search in case of no answers. */ + static const time_t SEARCH_GET_STEP {15}; + + /* The time after which we consider a search to be expirable. */ + static const time_t SEARCH_EXPIRE_TIME {62 * 60}; + + /* The maximum number of nodes that we snub. There is probably little + reason to increase this value. */ + static const unsigned BLACKLISTED_MAX {10}; + + static const long unsigned MAX_REQUESTS_PER_SEC; + + static const time_t TOKEN_EXPIRE_TIME {10 * 60}; + + static const unsigned TOKEN_SIZE {64}; + + struct Node { + InfoHash id {}; + sockaddr_storage ss; + socklen_t sslen {0}; + time_t time {0}; /* time of last message received */ + time_t reply_time {0}; /* time of last correct reply received */ + time_t pinged_time {0}; /* time of last request */ + unsigned pinged {0}; /* how many requests we sent since last reply */ + + Node() { + std::fill_n((uint8_t*)&ss, sizeof(ss), 0); + } + Node(const InfoHash& id, const sockaddr* sa, socklen_t salen, time_t t, time_t reply_time) + : id(id), sslen(salen), time(t), reply_time(reply_time) { + std::copy_n((const uint8_t*)sa, salen, (uint8_t*)&ss); + } + bool isGood(time_t now) const; + NodeExport exportNode() const { return NodeExport {id, ss, sslen}; } + }; + + struct Bucket { + Bucket() {} + Bucket(sa_family_t af, const InfoHash& f = {}, time_t t = 0) + : af(af), first(f), time(t) {} + sa_family_t af {0}; + InfoHash first {}; + time_t time {0}; /* time of last reply in this bucket */ + std::list<Node> nodes {}; + sockaddr_storage cached {}; /* the address of a likely candidate */ + socklen_t cachedlen {0}; + + /** Return a random node in a bucket. */ + Node* randomNode(); + }; + + class RoutingTable : public std::list<Bucket> { + public: + using std::list<Bucket>::list; + + InfoHash middle(const RoutingTable::const_iterator&) const; + + RoutingTable::iterator findBucket(const InfoHash& id); + RoutingTable::const_iterator findBucket(const InfoHash& id) const; + + /** + * Returns true if the id is in the bucket's range. + */ + inline bool contains(const RoutingTable::const_iterator& bucket, const InfoHash& id) const { + return InfoHash::cmp(bucket->first, id) <= 0 + && (std::next(bucket) == end() || InfoHash::cmp(id, std::next(bucket)->first) < 0); + } + + /** + * Returns a random id in the bucket's range. + */ + InfoHash randomId(const RoutingTable::const_iterator& bucket) const; + + /** + * Split a bucket in two equal parts. + */ + bool split(const RoutingTable::iterator& b); + }; + + struct SearchNode { + SearchNode() {} + SearchNode(const InfoHash& id) : id(id) {} + + struct AnnounceStatus { + time_t request_time; /* the time of the last unanswered announce request */ + time_t reply_time; /* the time of the last announce confirmation */ + }; + typedef std::map<Value::Id, AnnounceStatus> AnnounceStatusMap; + + /** + * Can we use this node to announce ? + */ + bool isSynced(time_t now) const { + return /*pinged < 3 && replied &&*/ reply_time > now - 15 * 60; + } + + time_t getAnnounceTime(AnnounceStatusMap::const_iterator ack, const ValueType& type) const { + if (ack == acked.end()) + return request_time + 5; + return std::max<time_t>({ack->second.reply_time + type.expiration - 3, ack->second.request_time + 5, request_time + 5}); + } + time_t getAnnounceTime(Value::Id vid, const ValueType& type) const { + return getAnnounceTime(acked.find(vid), type); + } + + InfoHash id {}; + sockaddr_storage ss {}; + socklen_t sslen {0}; + time_t request_time {0}; /* the time of the last unanswered request */ + time_t reply_time {0}; /* the time of the last reply with a token */ + unsigned pinged {0}; + Blob token {}; + + AnnounceStatusMap acked {}; /* announcement status for a given value id */ + + // Generic temporary flag. + // Must be reset to false after use by the algorithm. + bool pending {false}; + }; + + struct Announce { + std::shared_ptr<Value> value; + DoneCallback callback; + }; + + /** + * A search is a pointer to the nodes we think are responsible + * for storing values for a given hash. + * + * A Search has 3 states: + * - Idle (nothing to do) + * - Syncing (Some nodes not synced) + * - Announcing (Some announces not performed on all nodes) + */ + struct Search { + uint16_t tid; + sa_family_t af; + time_t step_time {0}; /* the time of the last search_step */ + InfoHash id {}; + std::vector<Announce> announce {}; + std::vector<std::pair<Value::Filter, GetCallback>> callbacks {}; + DoneCallback done_callback {nullptr}; + bool done {false}; + std::vector<SearchNode> nodes {SEARCH_NODES+1}; + + bool insertNode(const InfoHash& id, const sockaddr*, socklen_t, time_t now, bool confirmed=false, const Blob& token={}); + void insertBucket(const Bucket&, time_t now); + + /** + * Can we use this search to announce ? + */ + bool isSynced(time_t now) const; + + /** + * Are all values that are registred for announcement announced ? + */ + bool isAnnounced(const std::map<ValueType::Id, ValueType>& types, time_t now) const { + auto at = getAnnounceTime(types); + return at && at < now; + } + + /** + * ret = 0 : no announce required. + * ret > 0 : (re-)announce required at time ret. + */ + time_t getAnnounceTime(const std::map<ValueType::Id, ValueType>& types) const; + + time_t getNextStepTime(const std::map<ValueType::Id, ValueType>& types, time_t now) const; + }; + + struct ValueStorage { + std::shared_ptr<Value> data {}; + time_t time {0}; + + ValueStorage() {} + ValueStorage(const std::shared_ptr<Value>& v, time_t t) : data(v), time(t) {} + }; + + struct Storage { + InfoHash id; + std::vector<ValueStorage> values; + }; + + enum class MessageType { + Error = 0, + Reply, + Ping, + FindNode, + GetValues, + AnnounceValue + }; + + struct TransPrefix : public std::array<uint8_t, 2> { + TransPrefix(const std::string& str) : std::array<uint8_t, 2>({(uint8_t)str[0], (uint8_t)str[1]}) {} + static const TransPrefix PING; + static const TransPrefix FIND_NODE; + static const TransPrefix GET_VALUES; + static const TransPrefix ANNOUNCE_VALUES; + }; + + /* Transaction-ids are 4-bytes long, with the first two bytes identifying + * the kind of request, and the remaining two a sequence number in + * host order. + */ + struct TransId final : public std::array<uint8_t, 4> { + TransId() {} + TransId(const TransPrefix prefix, uint16_t seqno = 0) { + std::copy_n(prefix.begin(), prefix.size(), begin()); + *reinterpret_cast<uint16_t*>(data()+prefix.size()) = seqno; + } + + TransId(const char* q, size_t l) : array<uint8_t, 4>() { + if (l > 4) { + length = 0; + } else { + std::copy_n(q, l, begin()); + length = l; + } + } + + bool matches(const TransPrefix prefix, uint16_t *seqno_return = nullptr) const { + if (std::equal(begin(), begin()+1, prefix.begin())) { + if (seqno_return) + *seqno_return = *reinterpret_cast<const uint16_t*>(&(*this)[2]); + return true; + } else + return false; + } + + unsigned length {4}; + }; + + // prevent copy + Dht(const Dht&) = delete; + Dht& operator=(const Dht&) = delete; + + int dht_socket {-1}; + int dht_socket6 {-1}; + + time_t search_time {0}; + time_t confirm_nodes_time {0}; + time_t rotate_secrets_time {0}; + + InfoHash myid {}; + static const uint8_t my_v[9]; + std::array<uint8_t, 8> secret {}; + std::array<uint8_t, 8> oldsecret {}; + + std::map<ValueType::Id, ValueType> types; + + // the stuff + RoutingTable buckets {}; + RoutingTable buckets6 {}; + std::vector<Storage> store {}; + std::list<Search> searches {}; + uint16_t search_id {0}; + + sockaddr_storage blacklist[BLACKLISTED_MAX] {}; + unsigned next_blacklisted = 0; + + struct timeval now {}; + time_t mybucket_grow_time {0}, mybucket6_grow_time {0}; + time_t expire_stuff_time {0}; + time_t rate_limit_time {0}; + + long unsigned rate_limit_tokens {MAX_REQUESTS_PER_SEC}; + + // Networking & packet handling + int send(const void* buf, size_t len, int flags, const sockaddr*, socklen_t); + int sendPing(const sockaddr*, socklen_t, TransId tid); + int sendPong(const sockaddr*, socklen_t, TransId tid); + + int sendFindNode(const sockaddr*, socklen_t, TransId tid, + const InfoHash& target, int want, int confirm); + + int sendNodesValues(const sockaddr*, socklen_t, TransId tid, + const uint8_t *nodes, unsigned nodes_len, + const uint8_t *nodes6, unsigned nodes6_len, + Storage *st, const Blob& token); + + int sendClosestNodes(const sockaddr*, socklen_t, TransId tid, + const InfoHash& id, int want, const Blob& token={}, + Storage *st=nullptr); + + int sendGetValues(const sockaddr*, socklen_t, TransId tid, + const InfoHash& infohash, int want, int confirm); + + int sendAnnounceValue(const sockaddr*, socklen_t, TransId tid, + const InfoHash& infohas, const Value& data, + const Blob& token, int confirm); + + int sendValueAnnounced(const sockaddr*, socklen_t, TransId, Value::Id); + + int sendError(const sockaddr*, socklen_t, TransId tid, int code, const char *message); + + void processMessage(const uint8_t *buf, size_t buflen, const sockaddr *from, socklen_t fromlen); + MessageType parseMessage(const uint8_t *buf, size_t buflen, + TransId& tid, + InfoHash& id_return, InfoHash& info_hash_return, + InfoHash& target_return, in_port_t& port_return, + Blob& token, Value::Id& value_id, + uint8_t *nodes_return, unsigned *nodes_len, + uint8_t *nodes6_return, unsigned *nodes6_len, + std::vector<std::shared_ptr<Value>>& values_return, + int *want_return, uint16_t& error_code); + + void rotateSecrets(); + + Blob makeToken(const sockaddr *sa, bool old) const; + bool tokenMatch(const Blob& token, const sockaddr *sa) const; + + // Storage + Storage* findStorage(const InfoHash& id); + const Storage* findStorage(const InfoHash& id) const { + return const_cast<Dht*>(this)->findStorage(id); + } + + ValueStorage* storageStore(const InfoHash& id, const std::shared_ptr<Value>& value); + void expireStorage(); + + // Buckets + Bucket* findBucket(const InfoHash& id, sa_family_t af) { + RoutingTable::iterator b; + switch (af) { + case AF_INET: + b = buckets.findBucket(id); + return b == buckets.end() ? nullptr : &(*b); + case AF_INET6: + b = buckets6.findBucket(id); + return b == buckets6.end() ? nullptr : &(*b); + default: + return nullptr; + } + } + const Bucket* findBucket(const InfoHash& id, sa_family_t af) const { + return const_cast<Dht*>(this)->findBucket(id, af); + } + + void expireBuckets(RoutingTable&); + int sendCachedPing(Bucket& b); + bool bucketMaintenance(RoutingTable&); + static unsigned insertClosestNode(uint8_t *nodes, unsigned numnodes, const InfoHash& id, const Node& n); + unsigned bufferClosestNodes(uint8_t *nodes, unsigned numnodes, const InfoHash& id, const Bucket& b) const; + void dumpBucket(const Bucket& b, std::ostream& out) const; + + // Nodes + Node* newNode(const InfoHash& id, const sockaddr*, socklen_t, int confirm); + Node* findNode(const InfoHash& id, sa_family_t af); + const Node* findNode(const InfoHash& id, sa_family_t af) const; + + void pinged(Node& n, Bucket *b = nullptr); + + void blacklistNode(const InfoHash* id, const sockaddr*, socklen_t); + bool isNodeBlacklisted(const sockaddr*, socklen_t) const; + static bool isMartian(const sockaddr*, socklen_t); + + // Searches + + /** + * Low-level method that will perform a search on the DHT for the + * specified infohash (id), using the specified IP version (IPv4 or IPv6). + * The values can be filtered by an arbitrary provided filter. + */ + Search* search(const InfoHash& id, sa_family_t af, GetCallback = nullptr, DoneCallback = nullptr, Value::Filter = Value::AllFilter()); + void announce(const InfoHash& id, sa_family_t af, const std::shared_ptr<Value>& value, DoneCallback callback); + + std::list<Search>::iterator newSearch(); + void bootstrapSearch(Search& sr); + Search *findSearch(unsigned short tid, sa_family_t af); + void expireSearches(); + bool searchSendGetValues(Search& sr, SearchNode *n = nullptr); + void searchStep(Search& sr); + void dumpSearch(const Search& sr, std::ostream& out) const; + + bool rateLimit(); + bool neighbourhoodMaintenance(RoutingTable&); + + static void *dht_memmem(const void *haystack, size_t haystacklen, const void *needle, size_t needlelen); + +}; + +} diff --git a/include/dhtcpp/dhtrunner.h b/include/dhtcpp/dhtrunner.h new file mode 100644 index 00000000..09b883ed --- /dev/null +++ b/include/dhtcpp/dhtrunner.h @@ -0,0 +1,165 @@ +/* + * Copyright (C) 2014 Savoir-Faire Linux Inc. + * Author : Adrien Béraud <adrien.beraud@savoirfairelinux.com> + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Additional permission under GNU GPL version 3 section 7: + * + * If you modify this program, or any covered work, by linking or + * combining it with the OpenSSL project's OpenSSL library (or a + * modified version of that library), containing parts covered by the + * terms of the OpenSSL or SSLeay licenses, Savoir-Faire Linux Inc. + * grants you additional permission to convey the resulting work. + * Corresponding Source for a non-source form of such a combination + * shall include the source code for the parts of OpenSSL used as well + * as that of the covered work. + */ + +#pragma once + +#include "securedht.h" + +#include <thread> +#include <random> +#include <mutex> +#include <atomic> +#include <condition_variable> +#include <exception> + +#include <unistd.h> // close(fd) + +namespace dht { + +/** + * Provides a thread-safe interface to run the (secure) DHT. + * The class will open sockets on the provided port and will + * either wait for (expectedly frequent) calls to loop() or start an internal + * thread that will update the DHT when appropriate. + */ +class DhtRunner { + +public: + typedef std::function<void(Dht::Status, Dht::Status)> StatusCallback; + + DhtRunner() {} + virtual ~DhtRunner() { + join(); + } + + void get(InfoHash hash, Dht::GetCallback vcb, Dht::DoneCallback dcb=nullptr, Value::Filter f = Value::AllFilter()); + void get(const std::string& key, Dht::GetCallback vcb, Dht::DoneCallback dcb=nullptr, Value::Filter f = Value::AllFilter()); + + void put(InfoHash hash, Value&& value, Dht::DoneCallback cb=nullptr); + void put(const std::string& key, Value&& value, Dht::DoneCallback cb=nullptr); + + void putSigned(InfoHash hash, Value&& value, Dht::DoneCallback cb=nullptr); + void putSigned(const std::string& key, Value&& value, Dht::DoneCallback cb=nullptr); + + void putEncrypted(InfoHash hash, InfoHash to, Value&& value, Dht::DoneCallback cb=nullptr); + void putEncrypted(const std::string& key, InfoHash to, Value&& value, Dht::DoneCallback cb=nullptr); + + void bootstrap(const std::vector<sockaddr_storage>& nodes); + void bootstrap(const std::vector<Dht::NodeExport>& nodes); + + void dumpTables() const + { + std::unique_lock<std::mutex> lck(dht_mtx); + dht->dumpTables(); + } + + InfoHash getId() const { + if (!dht) + return {}; + return dht->getId(); + } + + std::vector<Dht::NodeExport> exportNodes() const { + std::unique_lock<std::mutex> lck(dht_mtx); + if (!dht) + return {}; + return dht->exportNodes(); + } + + std::vector<Dht::ValuesExport> exportValues() const { + std::unique_lock<std::mutex> lck(dht_mtx); + if (!dht) + return {}; + return dht->exportValues(); + } + + void setLoggers(LogMethod&& error = NOLOG, LogMethod&& warn = NOLOG, LogMethod&& debug = NOLOG) { + std::unique_lock<std::mutex> lck(dht_mtx); + dht->setLoggers(std::forward<LogMethod>(error), std::forward<LogMethod>(warn), std::forward<LogMethod>(debug)); + } + + void registerType(const ValueType& type) { + std::unique_lock<std::mutex> lck(dht_mtx); + dht->registerType(type); + } + + void importValues(const std::vector<Dht::ValuesExport>& values) { + std::unique_lock<std::mutex> lck(dht_mtx); + dht->importValues(values); + } + + bool isRunning() const { + return running; + } + + /** + * If threaded is false, loop() must be called periodically. + */ + void run(in_port_t port, const crypto::Identity identity, bool threaded = false, StatusCallback cb = nullptr); + + void loop() { + std::unique_lock<std::mutex> lck(dht_mtx); + loop_(); + } + + void join(); + +private: + + void doRun(in_port_t port, const crypto::Identity identity); + void loop_(); + + std::unique_ptr<SecureDht> dht {}; + mutable std::mutex dht_mtx {}; + std::thread dht_thread {}; + std::condition_variable cv {}; + + std::thread rcv_thread {}; + std::mutex sock_mtx {}; + std::vector<std::pair<Blob, sockaddr_storage>> rcv {}; + std::atomic<time_t> tosleep {0}; + + // IPC temporary storage + std::vector<std::tuple<InfoHash, Dht::GetCallback, Dht::DoneCallback, Value::Filter>> dht_gets {}; + std::vector<std::tuple<InfoHash, Value, Dht::DoneCallback>> dht_puts {}; + std::vector<std::tuple<InfoHash, Value, Dht::DoneCallback>> dht_sputs {}; + std::vector<std::tuple<InfoHash, InfoHash, Value, Dht::DoneCallback>> dht_eputs {}; + std::vector<sockaddr_storage> bootstrap_ips {}; + std::vector<Dht::NodeExport> bootstrap_nodes {}; + std::mutex storage_mtx {}; + + std::atomic<bool> running {false}; + + Dht::Status status4 {Dht::Status::Disconnected}, + status6 {Dht::Status::Disconnected}; + StatusCallback statusCb {nullptr}; +}; + +} diff --git a/include/dhtcpp/infohash.h b/include/dhtcpp/infohash.h new file mode 100644 index 00000000..31b82128 --- /dev/null +++ b/include/dhtcpp/infohash.h @@ -0,0 +1,159 @@ +/* + * Copyright (C) 2014 Savoir-Faire Linux Inc. + * Author : Adrien Béraud <adrien.beraud@savoirfairelinux.com> + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Additional permission under GNU GPL version 3 section 7: + * + * If you modify this program, or any covered work, by linking or + * combining it with the OpenSSL project's OpenSSL library (or a + * modified version of that library), containing parts covered by the + * terms of the OpenSSL or SSLeay licenses, Savoir-Faire Linux Inc. + * grants you additional permission to convey the resulting work. + * Corresponding Source for a non-source form of such a combination + * shall include the source code for the parts of OpenSSL used as well + * as that of the covered work. + */ + +#pragma once + +#include <iostream> +#include <iomanip> +#include <array> +#include <vector> + +#include <cstring> + +// bytes +#define HASH_LEN 20 + +namespace dht { + +class DhtException : public std::runtime_error { + public: + DhtException(const std::string &str = "") : + std::runtime_error("DhtException occured: " + str) {} +}; + + +/** + * Represents an InfoHash. + * An InfoHash is a byte array of HASH_LEN bytes. + * InfoHashes identify nodes and values in the Dht. + */ +class InfoHash final : public std::array<uint8_t, HASH_LEN> { +public: + constexpr InfoHash() : std::array<uint8_t, HASH_LEN>() {} + constexpr InfoHash(const std::array<uint8_t, HASH_LEN>& h) : std::array<uint8_t, HASH_LEN>(h) {} + InfoHash(const uint8_t* h, size_t h_len=HASH_LEN) : std::array<uint8_t, HASH_LEN>() { + memcpy(data(), h, std::min((size_t)HASH_LEN, h_len)); + } + + /** + * Constructor from an hexadecimal string (without "0x"). + * hex must be at least 2.HASH_LEN characters long. + * If too long, only the first 2.HASH_LEN characters are read. + */ + InfoHash(const std::string& hex); + + /** + * Find the lowest 1 bit in an id. + * Result will allways be lower than 8*HASH_LEN + */ + inline unsigned lowbit() const { + int i, j; + for(i = HASH_LEN-1; i >= 0; i--) + if((*this)[i] != 0) + break; + if(i < 0) + return -1; + for(j = 7; j >= 0; j--) + if(((*this)[i] & (0x80 >> j)) != 0) + break; + return 8 * i + j; + } + + /** + * Forget about the ``XOR-metric''. An id is just a path from the + * root of the tree, so bits are numbered from the start. + */ + static inline int cmp(const InfoHash& __restrict__ id1, const InfoHash& __restrict__ id2) { + return std::memcmp(id1.data(), id2.data(), HASH_LEN); + } + + /** Find how many bits two ids have in common. */ + static inline unsigned + commonBits(const InfoHash& id1, const InfoHash& id2) + { + unsigned i, j; + uint8_t x; + for(i = 0; i < HASH_LEN; i++) { + if(id1[i] != id2[i]) + break; + } + + if(i == HASH_LEN) + return 8*HASH_LEN; + + x = id1[i] ^ id2[i]; + + j = 0; + while((x & 0x80) == 0) { + x <<= 1; + j++; + } + + return 8 * i + j; + } + + /** Determine whether id1 or id2 is closer to this */ + int + xorCmp(const InfoHash& id1, const InfoHash& id2) const + { + unsigned i; + for(i = 0; i < HASH_LEN; i++) { + uint8_t xor1, xor2; + if(id1[i] == id2[i]) + continue; + xor1 = id1[i] ^ (*this)[i]; + xor2 = id2[i] ^ (*this)[i]; + if(xor1 < xor2) + return -1; + else + return 1; + } + return 0; + } + + static inline InfoHash get(const std::string& data) { + return get((const uint8_t*)data.data(), data.size()); + } + + static inline InfoHash get(const std::vector<uint8_t>& data) { + return get(data.data(), data.size()); + } + + /** + * Computes the hash from a given data buffer of size data_len. + */ + static InfoHash get(const uint8_t* data, size_t data_len); + + friend std::ostream& operator<< (std::ostream& s, const InfoHash& h); + + std::string toString() const; +}; + +} diff --git a/include/dhtcpp/securedht.h b/include/dhtcpp/securedht.h new file mode 100644 index 00000000..d44dee06 --- /dev/null +++ b/include/dhtcpp/securedht.h @@ -0,0 +1,158 @@ +/* + * Copyright (C) 2014 Savoir-Faire Linux Inc. + * Author : Adrien Béraud <adrien.beraud@savoirfairelinux.com> + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Additional permission under GNU GPL version 3 section 7: + * + * If you modify this program, or any covered work, by linking or + * combining it with the OpenSSL project's OpenSSL library (or a + * modified version of that library), containing parts covered by the + * terms of the OpenSSL or SSLeay licenses, Savoir-Faire Linux Inc. + * grants you additional permission to convey the resulting work. + * Corresponding Source for a non-source form of such a combination + * shall include the source code for the parts of OpenSSL used as well + * as that of the covered work. + */ + +#pragma once + +#include "dht.h" +#include "crypto.h" + +#include <map> +#include <vector> +#include <memory> + +namespace dht { + +class SecureDht : private Dht { +public: + + typedef std::function<void(bool)> SignatureCheckCallback; + + SecureDht() {} + + /** + * s, s6: bound socket descriptors for IPv4 and IPv6, respectively. + * For the Dht to be initialised, at least one of them must be >= 0. + * id: the identity to use for the crypto layer and to compute + * our own hash on the Dht. + */ + SecureDht(int s, int s6, crypto::Identity id); + + virtual ~SecureDht(); + + using Dht::periodic; + using Dht::pingNode; + using Dht::insertNode; + using Dht::exportNodes; + using Dht::exportValues; + using Dht::importValues; + using Dht::getStatus; + using Dht::dumpTables; + using Dht::put; + using Dht::setLoggers; + + InfoHash getId() const { + return key_->getPublicKey().getId(); + } + + ValueType secureType(ValueType&& type); + + ValueType secureType(const ValueType& type) { + ValueType tmp_type = type; + return secureType(std::move(tmp_type)); + } + + virtual void registerType(const ValueType& type) { + Dht::registerType(secureType(type)); + } + virtual void registerType(ValueType&& type) { + Dht::registerType(secureType(std::forward<ValueType>(type))); + } + virtual void registerInsecureType(const ValueType& type) { + Dht::registerType(type); + } + + /** + * "Secure" get(), that will check the signature of signed data, and decrypt encrypted data. + * If the signature can't be checked, or if the data can't be decrypted, it is not returned. + * Public, non-signed & non-encrypted data is retransmitted as-is. + */ + void get(const InfoHash& id, GetCallback cb, DoneCallback donecb, Value::Filter = Value::AllFilter()); + + /** + * Will take ownership of the value, sign it using our private key and put it in the DHT. + */ + void putSigned(const InfoHash& hash, Value&& data, DoneCallback callback); + + /** + * Will sign the data using our private key, encrypt it using the recipient' public key, + * and put it in the DHT. + * The operation will be immediate if the recipient' public key is known (otherwise it will be retrived first). + */ + void putEncrypted(const InfoHash& hash, const InfoHash& to, const std::shared_ptr<Value>& val, DoneCallback callback); + void putEncrypted(const InfoHash& hash, const InfoHash& to, Value&& v, DoneCallback callback) { + putEncrypted(hash, to, std::make_shared<Value>(std::move(v)), callback); + } + + /** + * Take ownership of the value and sign it using our private key. + */ + void sign(Value& v) const; + + Value encrypt(Value& v, const crypto::PublicKey& to) const; + + Value decrypt(const Value& v); + + void findCertificate(const InfoHash& node, std::function<void(const std::shared_ptr<crypto::Certificate>)> cb); + + const std::shared_ptr<crypto::Certificate> registerCertificate(const InfoHash& node, const Blob& publicKey); + const std::shared_ptr<crypto::Certificate> getCertificate(const InfoHash& node) const; + +private: + // prevent copy + SecureDht(const SecureDht&) = delete; + SecureDht& operator=(const SecureDht&) = delete; + + std::shared_ptr<crypto::PrivateKey> key_ {}; + std::shared_ptr<crypto::Certificate> certificate_ {}; + + std::map<InfoHash, std::shared_ptr<crypto::Certificate>> nodesCertificates_ {}; +}; + +const ValueType CERTIFICATE_TYPE = {8, "Certificate", 60 * 60 * 24 * 7, + // A certificate can only be stored at it's public key ID. + [](InfoHash id, std::shared_ptr<Value>& v, InfoHash, const sockaddr*, socklen_t) { + try { + crypto::Certificate crt(v->data); + // TODO check certificate signature + return crt.getPublicKey().getId() == id; + } catch (const std::exception& e) {} + return false; + }, + [](InfoHash id, const std::shared_ptr<Value>& o, std::shared_ptr<Value>& n, InfoHash, const sockaddr*, socklen_t) { + try { + crypto::Certificate crt_old(o->data); + crypto::Certificate crt_new(n->data); + return crt_old.getPublicKey().getId() == crt_new.getPublicKey().getId(); + } catch (const std::exception& e) {} + return false; + } +}; + +} diff --git a/include/dhtcpp/serialize.h b/include/dhtcpp/serialize.h new file mode 100644 index 00000000..7b0fc75c --- /dev/null +++ b/include/dhtcpp/serialize.h @@ -0,0 +1,316 @@ +/** + * Copyright (c) 2013, Simone Pellegrini All rights reserved. + * Copyright (c) 2014 Savoir-Faire Linux. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * - Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * - Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ + +#pragma once + +#include <vector> +#include <string> +#include <tuple> +#include <numeric> +#include <limits> + +typedef std::vector<uint8_t> Blob; + +template <class T> +inline void serialize(const T&, Blob&); + +namespace detail { + + template<std::size_t> struct int_{}; + +} + +// get_size +template <class T> +size_t get_size(const T& obj); + +namespace detail { + + typedef uint16_t serialized_size_t; + + template <class T> + struct get_size_helper; + + template <class T> + struct get_size_helper<std::vector<T>> { + static size_t value(const std::vector<T>& obj) { + return std::accumulate(obj.begin(), obj.end(), sizeof(serialized_size_t), + [](const size_t& acc, const T& cur) { return acc+get_size(cur); }); + } + }; + + template <> + struct get_size_helper<std::string> { + static size_t value(const std::string& obj) { + return sizeof(serialized_size_t) + obj.length()*sizeof(uint8_t); + } + }; + + template <class tuple_type> + inline size_t get_tuple_size(const tuple_type& obj, int_<0>) { + constexpr size_t idx = std::tuple_size<tuple_type>::value-1; + return get_size(std::get<idx>(obj)); + } + + template <class tuple_type, size_t pos> + inline size_t get_tuple_size(const tuple_type& obj, int_<pos>) { + constexpr size_t idx = std::tuple_size<tuple_type>::value-pos-1; + size_t acc = get_size(std::get<idx>(obj)); + + // recur + return acc+get_tuple_size(obj, int_<pos-1>()); + } + + template <class ...T> + struct get_size_helper<std::tuple<T...>> { + static size_t value(const std::tuple<T...>& obj) { + return get_tuple_size(obj, int_<sizeof...(T)-1>()); + } + }; + + template <class T> + struct get_size_helper { + static size_t value(const T&) { return sizeof(T); } + }; + +} + +template <class T> +inline size_t get_size(const T& obj) { + return detail::get_size_helper<T>::value(obj); +} + +namespace detail { + + template <class T> + class serialize_helper; + + template <class T> + void serializer(const T& obj, Blob::iterator&); + + template <class tuple_type> + inline void serialize_tuple(const tuple_type& obj, Blob::iterator& res, int_<0>) { + constexpr size_t idx = std::tuple_size<tuple_type>::value-1; + serializer(std::get<idx>(obj), res); + } + + template <class tuple_type, size_t pos> + inline void serialize_tuple(const tuple_type& obj, Blob::iterator& res, int_<pos>) { + constexpr size_t idx = std::tuple_size<tuple_type>::value-pos-1; + serializer(std::get<idx>(obj), res); + + // recur + serialize_tuple(obj, res, int_<pos-1>()); + } + + template <class... T> + struct serialize_helper<std::tuple<T...>> { + static void apply(const std::tuple<T...>& obj, Blob::iterator& res) { + detail::serialize_tuple(obj, res, detail::int_<sizeof...(T)-1>()); + } + + }; + + template <> + struct serialize_helper<std::string> { + static void apply(const std::string& obj, Blob::iterator& res) { + // store the number of elements of this vector at the beginning + if (obj.length() > std::numeric_limits<serialized_size_t>::max()) + throw std::length_error("string is too long"); + serializer(static_cast<serialized_size_t>(obj.length()), res); + for(const auto& cur : obj) { serializer(cur, res); } + } + + }; + + template <class T> + struct serialize_helper<std::vector<T>> { + static void apply(const std::vector<T>& obj, Blob::iterator& res) { + // store the number of elements of this vector at the beginning + if (obj.size() > std::numeric_limits<serialized_size_t>::max()) + throw std::length_error("vector is too large"); + serializer(static_cast<serialized_size_t>(obj.size()), res); + for(const auto& cur : obj) { serializer(cur, res); } + } + + }; + + template <class T> + struct serialize_helper { + static void apply(const T& obj, Blob::iterator& res) { + const uint8_t* ptr = reinterpret_cast<const uint8_t*>(&obj); + std::copy(ptr,ptr+sizeof(T),res); + res+=sizeof(T); + } + + }; + + template <class T> + inline void serializer(const T& obj, Blob::iterator& res) { + serialize_helper<T>::apply(obj,res); + } + +} // end detail namespace + +template <class T> +inline void serialize(const T& obj, Blob& res) { + + size_t offset = res.size(); + size_t size = get_size(obj); + res.resize(res.size() + size); + + Blob::iterator it = res.begin()+offset; + detail::serializer(obj,it); + if (res.begin() + offset + size != it) + throw std::logic_error("error serializing object"); +} + +namespace detail { + + template <class T> + struct deserialize_helper; + + template <class T> + struct deserialize_helper { + static T apply(Blob::const_iterator& begin, + Blob::const_iterator end) { + if (begin+sizeof(T)>end) + throw std::length_error("error deserializing object"); + T val; + std::copy(begin, begin+sizeof(T), reinterpret_cast<uint8_t*>(&val)); + begin+=sizeof(T); + return val; + } + }; + + template <class T> + struct deserialize_helper<std::vector<T>> { + static std::vector<T> apply(Blob::const_iterator& begin, + Blob::const_iterator end) + { + // retrieve the number of elements + serialized_size_t size = deserialize_helper<serialized_size_t>::apply(begin,end); + + std::vector<T> vect(size); + for(size_t i=0; i<size; ++i) { + vect[i] = std::move(deserialize_helper<T>::apply(begin,end)); + } + return vect; + } + }; + + template <> + struct deserialize_helper<std::string> { + static std::string apply(Blob::const_iterator& begin, + Blob::const_iterator end) + { + // retrieve the number of elements + serialized_size_t size = deserialize_helper<serialized_size_t>::apply(begin,end); + + if (size == 0u) return std::string(); + std::string str(size,'\0'); + for(size_t i=0; i<size; ++i) { + str.at(i) = deserialize_helper<uint8_t>::apply(begin,end); + } + return str; + } + }; + + template <class tuple_type> + inline void deserialize_tuple(tuple_type& obj, + Blob::const_iterator& begin, + Blob::const_iterator end, int_<0>) { + constexpr size_t idx = std::tuple_size<tuple_type>::value-1; + typedef typename std::tuple_element<idx,tuple_type>::type T; + + std::get<idx>(obj) = std::move(deserialize_helper<T>::apply(begin, end)); + } + + template <class tuple_type, size_t pos> + inline void deserialize_tuple(tuple_type& obj, + Blob::const_iterator& begin, + Blob::const_iterator end, int_<pos>) { + constexpr size_t idx = std::tuple_size<tuple_type>::value-pos-1; + typedef typename std::tuple_element<idx,tuple_type>::type T; + std::get<idx>(obj) = std::move(deserialize_helper<T>::apply(begin, end)); + + // recur + deserialize_tuple(obj, begin, end, int_<pos-1>()); + } + + template <class... T> + struct deserialize_helper<std::tuple<T...>> { + static std::tuple<T...> apply(Blob::const_iterator& begin, + Blob::const_iterator end) + { + //return std::make_tuple(deserialize(begin,begin+sizeof(T),T())...); + std::tuple<T...> ret; + deserialize_tuple(ret, begin, end, int_<sizeof...(T)-1>()); + return ret; + } + + }; + +} + +template <class T> +inline T deserialize(Blob::const_iterator& begin, const Blob::const_iterator& end) { + return detail::deserialize_helper<T>::apply(begin, end); +} + +template <class T> +inline T deserialize(const Blob& res) { + Blob::const_iterator it = res.begin(); + return deserialize<T>(it, res.end()); +} + +namespace dht { + + struct Serializable { + /** + * Append serialized object to res. + */ + virtual void pack(Blob& res) const = 0; + Blob getPacked() const { + Blob ret; + pack(ret); + return ret; + } + + /** + * Read serialized object from {begin, end}. + */ + virtual void unpack(Blob::const_iterator& begin, Blob::const_iterator& end) = 0; + void unpackBlob(const Blob& data) { + auto cib = data.cbegin(), cie = data.cend(); + unpack(cib, cie); + } + + virtual ~Serializable() = default; +}; + +} diff --git a/include/dhtcpp/value.h b/include/dhtcpp/value.h new file mode 100644 index 00000000..e10ce056 --- /dev/null +++ b/include/dhtcpp/value.h @@ -0,0 +1,328 @@ +/* + * Copyright (C) 2014 Savoir-Faire Linux Inc. + * Author : Adrien Béraud <adrien.beraud@savoirfairelinux.com> + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Additional permission under GNU GPL version 3 section 7: + * + * If you modify this program, or any covered work, by linking or + * combining it with the OpenSSL project's OpenSSL library (or a + * modified version of that library), containing parts covered by the + * terms of the OpenSSL or SSLeay licenses, Savoir-Faire Linux Inc. + * grants you additional permission to convey the resulting work. + * Corresponding Source for a non-source form of such a combination + * shall include the source code for the parts of OpenSSL used as well + * as that of the covered work. + */ + +#pragma once + +#include "infohash.h" +#include "crypto.h" +#include "serialize.h" + +#include <sys/socket.h> +#include <netdb.h> + +#include <string> +#include <sstream> +#include <bitset> +#include <vector> +#include <iostream> +#include <algorithm> +#include <functional> +#include <memory> + +#include <cstdarg> + +namespace dht { + +/** + * Wrapper for logging methods + */ +struct LogMethod { + LogMethod() = default; + + template<typename T> + LogMethod( T&& t) : func(std::forward<T>(t)) {} + + void operator()(char const* format, ...) const { + va_list args; + va_start(args, format); + func(format, args); + va_end(args); + } + + void logPrintable(const uint8_t *buf, size_t buflen) const { + std::string buf_clean(buflen, '\0'); + for (size_t i=0; i<buflen; i++) + buf_clean[i] = buf[i] >= 32 && buf[i] <= 126 ? buf[i] : '.'; + (*this)("%s", buf_clean.c_str()); + } +private: + std::function<void(char const*, va_list)> func; +}; + +/** + * Dummy function used to disable logging + */ +constexpr int NOLOG(char const*, va_list) { return 0; } + + +class Value; + +typedef std::function<bool(InfoHash, std::shared_ptr<Value>&, InfoHash, const sockaddr*, socklen_t)> StorePolicy; +typedef std::function<bool(InfoHash, const std::shared_ptr<Value>&, std::shared_ptr<Value>&, InfoHash, const sockaddr*, socklen_t)> EditPolicy; + +struct ValueType { + typedef uint16_t Id; + ValueType () {} + + ValueType (Id id, std::string name, time_t e = 60 * 60) + : id(id), name(name), expiration(e) {} + + ValueType (Id id, std::string name, time_t e, StorePolicy&& sp, EditPolicy&& ep) + : id(id), name(name), expiration(e), storePolicy(std::move(sp)), editPolicy(std::move(ep)) {} + + virtual ~ValueType() {} + + bool operator==(const ValueType& o) { + return id == o.id; + } + + // Generic value type + static const ValueType USER_DATA; + + static bool DEFAULT_STORE_POLICY(InfoHash, std::shared_ptr<Value>&, InfoHash, const sockaddr*, socklen_t) { + return true; + } + static bool DEFAULT_EDIT_POLICY(InfoHash, const std::shared_ptr<Value>&, std::shared_ptr<Value>&, InfoHash, const sockaddr*, socklen_t) { + return false; + } + + Id id {0}; + std::string name {}; + time_t expiration {60 * 60}; + StorePolicy storePolicy {DEFAULT_STORE_POLICY}; + EditPolicy editPolicy {DEFAULT_EDIT_POLICY}; +}; + +/** + * A "value" is data potentially stored on the Dht, with some metadata. + * + * It can be an IP:port announced for a service, a public key, or any kind of + * light user-defined data (recommended: less than 512 bytes). + * + * Values are stored at a given InfoHash in the Dht, but also have a + * unique ID to distinguish between values stored at the same location. + */ +struct Value : public Serializable +{ + typedef uint64_t Id; + static const Id INVALID_ID {0}; + + typedef std::function<bool(const Value&)> Filter; + + static const Filter AllFilter() { + return [](const Value&){return true;}; + } + + static Filter TypeFilter(const ValueType& t) { + const auto tid = t.id; + return [tid](const Value& v) { + return v.type == tid; + }; + } + + static Filter chainFilters(Filter& f1, Filter& f2) { + return [f1,f2](const Value& v){ + return f1(v) && f2(v); + }; + } + + /** + * Hold information about how the data is signed/encrypted. + * Class is final because bitset have no virtual destructor. + */ + class ValueFlags final : public std::bitset<3> { + public: + using std::bitset<3>::bitset; + ValueFlags() {} + ValueFlags(bool sign, bool encrypted, bool have_recipient = false) : bitset<3>((sign ? 1:0) | (encrypted ? 2:0) | (have_recipient ? 4:0)) {} + bool isSigned() const { + return (*this)[0]; + } + bool isEncrypted() const { + return (*this)[1]; + } + bool haveRecipient() const { + return (*this)[2]; + } + }; + + bool isEncrypted() const { + return flags.isEncrypted(); + } + bool isSigned() const { + return flags.isSigned(); + } + + Value() {} + + Value (Id id) : id(id) {} + + /** Generic constructor */ + Value(ValueType::Id t, const Blob& data, Id id = INVALID_ID) + : id(id), type(t), data(data) {} + Value(ValueType::Id t, const Serializable& d, Id id = INVALID_ID) + : id(id), type(t), data(d.getPacked()) {} + Value(const ValueType& t, const Serializable& d, Id id = INVALID_ID) + : id(id), type(t.id), data(d.getPacked()) {} + + /** Custom user data constructor */ + Value(const Blob& userdata) : data(userdata) {} + Value(Blob&& userdata) : data(std::move(userdata)) {} + + Value(Value&& o) noexcept + : id(o.id), flags(o.flags), owner(std::move(o.owner)), recipient(o.recipient), + type(o.type), data(std::move(o.data)), seq(o.seq), signature(std::move(o.signature)), cypher(std::move(o.cypher)) {} + + inline bool operator== (const Value& o) { + return id == o.id && + (flags.isEncrypted() ? cypher == o.cypher : + (owner == o.owner && type == o.type && data == o.data && signature == o.signature)); + } + + void setRecipient(const InfoHash& r) { + recipient = r; + flags[2] = true; + } + + void setCypher(Blob&& c) { + cypher = std::move(c); + flags = {true, true, true}; + } + + /** + * Pack part of the data to be signed + */ + void packToSign(Blob& res) const; + Blob getToSign() const; + + /** + * Pack part of the data to be encrypted + */ + void packToEncrypt(Blob& res) const; + Blob getToEncrypt() const; + + void pack(Blob& res) const; + + void unpackBody(Blob::const_iterator& begin, Blob::const_iterator& end); + virtual void unpack(Blob::const_iterator& begin, Blob::const_iterator& end); + + /** print value for debugging */ + friend std::ostream& operator<< (std::ostream& s, const Value& v); + + std::string toString() const { + std::stringstream ss; + ss << *this; + return ss.str(); + } + + Id id {INVALID_ID}; + + // data (part that is signed / encrypted) + + ValueFlags flags {}; + + /** + * Public key of the signer. + */ + crypto::PublicKey owner {}; + + /** + * Hash of the recipient (optional). + * Should only be present for encrypted values. + * Can optionally be present for signed values. + */ + InfoHash recipient {}; + + /** + * Type of data. + */ + ValueType::Id type {ValueType::USER_DATA.id}; + Blob data {}; + + /** + * Sequence number to avoid replay attacks + */ + uint16_t seq {0}; + + /** + * Optional signature. + */ + Blob signature {}; + + /** + * Hold encrypted version of the data. + */ + Blob cypher {}; +}; + + +/* "Peer" announcement + */ +struct ServiceAnnouncement : public Serializable +{ + ServiceAnnouncement(in_port_t p = 0) { + ss.ss_family = 0; + setPort(p); + } + + ServiceAnnouncement(const sockaddr* sa, socklen_t sa_len) { + if (sa) + std::copy_n((const uint8_t*)sa, sa_len, (uint8_t*)&ss); + } + + ServiceAnnouncement(const Blob& b) { + unpackBlob(b); + } + + virtual void pack(Blob& res) const; + virtual void unpack(Blob::const_iterator& begin, Blob::const_iterator& end); + + in_port_t getPort() const { + return ntohs(reinterpret_cast<const sockaddr_in*>(&ss)->sin_port); + } + void setPort(in_port_t p) { + reinterpret_cast<sockaddr_in*>(&ss)->sin_port = htons(p); + } + + sockaddr_storage getPeerAddr() const { + return ss; + } + + static const ValueType TYPE; + static bool storePolicy(InfoHash, std::shared_ptr<Value>&, InfoHash, const sockaddr*, socklen_t); + + /** print value for debugging */ + friend std::ostream& operator<< (std::ostream&, const ServiceAnnouncement&); + +private: + sockaddr_storage ss; +}; + +} diff --git a/m4/ax_cxx_compile_stdcxx_11.m4 b/m4/ax_cxx_compile_stdcxx_11.m4 new file mode 100644 index 00000000..5c10a764 --- /dev/null +++ b/m4/ax_cxx_compile_stdcxx_11.m4 @@ -0,0 +1,135 @@ +# ============================================================================ +# http://www.gnu.org/software/autoconf-archive/ax_cxx_compile_stdcxx_11.html +# ============================================================================ +# +# SYNOPSIS +# +# AX_CXX_COMPILE_STDCXX_11([ext|noext],[mandatory|optional]) +# +# DESCRIPTION +# +# Check for baseline language coverage in the compiler for the C++11 +# standard; if necessary, add switches to CXXFLAGS to enable support. +# +# The first argument, if specified, indicates whether you insist on an +# extended mode (e.g. -std=gnu++11) or a strict conformance mode (e.g. +# -std=c++11). If neither is specified, you get whatever works, with +# preference for an extended mode. +# +# The second argument, if specified 'mandatory' or if left unspecified, +# indicates that baseline C++11 support is required and that the macro +# should error out if no mode with that support is found. If specified +# 'optional', then configuration proceeds regardless, after defining +# HAVE_CXX11 if and only if a supporting mode is found. +# +# LICENSE +# +# Copyright (c) 2008 Benjamin Kosnik <bkoz@redhat.com> +# Copyright (c) 2012 Zack Weinberg <zackw@panix.com> +# Copyright (c) 2013 Roy Stogner <roystgnr@ices.utexas.edu> +# Copyright (c) 2014 Alexey Sokolov <sokolov@google.com> +# +# Copying and distribution of this file, with or without modification, are +# permitted in any medium without royalty provided the copyright notice +# and this notice are preserved. This file is offered as-is, without any +# warranty. + +#serial 4 + +m4_define([_AX_CXX_COMPILE_STDCXX_11_testbody], [[ + template <typename T> + struct check + { + static_assert(sizeof(int) <= sizeof(T), "not big enough"); + }; + + typedef check<check<bool>> right_angle_brackets; + + int a; + decltype(a) b; + + typedef check<int> check_type; + check_type c; + check_type&& cr = static_cast<check_type&&>(c); + + auto d = a; + auto l = [](){}; +]]) + +AC_DEFUN([AX_CXX_COMPILE_STDCXX_11], [dnl + m4_if([$1], [], [], + [$1], [ext], [], + [$1], [noext], [], + [m4_fatal([invalid argument `$1' to AX_CXX_COMPILE_STDCXX_11])])dnl + m4_if([$2], [], [ax_cxx_compile_cxx11_required=true], + [$2], [mandatory], [ax_cxx_compile_cxx11_required=true], + [$2], [optional], [ax_cxx_compile_cxx11_required=false], + [m4_fatal([invalid second argument `$2' to AX_CXX_COMPILE_STDCXX_11])]) + AC_LANG_PUSH([C++])dnl + ac_success=no + AC_CACHE_CHECK(whether $CXX supports C++11 features by default, + ax_cv_cxx_compile_cxx11, + [AC_COMPILE_IFELSE([AC_LANG_SOURCE([_AX_CXX_COMPILE_STDCXX_11_testbody])], + [ax_cv_cxx_compile_cxx11=yes], + [ax_cv_cxx_compile_cxx11=no])]) + if test x$ax_cv_cxx_compile_cxx11 = xyes; then + ac_success=yes + fi + + m4_if([$1], [noext], [], [dnl + if test x$ac_success = xno; then + for switch in -std=gnu++11 -std=gnu++0x; do + cachevar=AS_TR_SH([ax_cv_cxx_compile_cxx11_$switch]) + AC_CACHE_CHECK(whether $CXX supports C++11 features with $switch, + $cachevar, + [ac_save_CXXFLAGS="$CXXFLAGS" + CXXFLAGS="$CXXFLAGS $switch" + AC_COMPILE_IFELSE([AC_LANG_SOURCE([_AX_CXX_COMPILE_STDCXX_11_testbody])], + [eval $cachevar=yes], + [eval $cachevar=no]) + CXXFLAGS="$ac_save_CXXFLAGS"]) + if eval test x\$$cachevar = xyes; then + CXXFLAGS="$CXXFLAGS $switch" + ac_success=yes + break + fi + done + fi]) + + m4_if([$1], [ext], [], [dnl + if test x$ac_success = xno; then + for switch in -std=c++11 -std=c++0x; do + cachevar=AS_TR_SH([ax_cv_cxx_compile_cxx11_$switch]) + AC_CACHE_CHECK(whether $CXX supports C++11 features with $switch, + $cachevar, + [ac_save_CXXFLAGS="$CXXFLAGS" + CXXFLAGS="$CXXFLAGS $switch" + AC_COMPILE_IFELSE([AC_LANG_SOURCE([_AX_CXX_COMPILE_STDCXX_11_testbody])], + [eval $cachevar=yes], + [eval $cachevar=no]) + CXXFLAGS="$ac_save_CXXFLAGS"]) + if eval test x\$$cachevar = xyes; then + CXXFLAGS="$CXXFLAGS $switch" + ac_success=yes + break + fi + done + fi]) + AC_LANG_POP([C++]) + if test x$ax_cxx_compile_cxx11_required = xtrue; then + if test x$ac_success = xno; then + AC_MSG_ERROR([*** A compiler with support for C++11 language features is required.]) + fi + else + if test x$ac_success = xno; then + HAVE_CXX11=0 + AC_MSG_NOTICE([No compiler with C++11 support was found]) + else + HAVE_CXX11=1 + AC_DEFINE(HAVE_CXX11,1, + [define if the compiler supports basic C++11 syntax]) + fi + + AC_SUBST(HAVE_CXX11) + fi +]) diff --git a/src/Makefile.am b/src/Makefile.am new file mode 100644 index 00000000..da89a5e1 --- /dev/null +++ b/src/Makefile.am @@ -0,0 +1,22 @@ +lib_LTLIBRARIES = libdhtcpp.la + +AM_CPPFLAGS = -I../include/dhtcpp +libdhtcpp_la_CXXFLAGS = @CXXFLAGS@ + +libdhtcpp_la_SOURCES = \ + dht.cpp \ + infohash.cpp \ + value.cpp \ + crypto.cpp \ + securedht.cpp \ + dhtrunner.cpp + +nobase_include_HEADERS = \ + ../include/dht.h \ + ../include/dhtcpp/dht.h \ + ../include/dhtcpp/infohash.h \ + ../include/dhtcpp/value.h \ + ../include/dhtcpp/crypto.h \ + ../include/dhtcpp/securedht.h \ + ../include/dhtcpp/dhtrunner.h \ + ../include/dhtcpp/serialize.h diff --git a/src/crypto.cpp b/src/crypto.cpp new file mode 100644 index 00000000..c358039f --- /dev/null +++ b/src/crypto.cpp @@ -0,0 +1,404 @@ +/* + * Copyright (C) 2014 Savoir-Faire Linux Inc. + * Author : Adrien Béraud <adrien.beraud@savoirfairelinux.com> + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Additional permission under GNU GPL version 3 section 7: + * + * If you modify this program, or any covered work, by linking or + * combining it with the OpenSSL project's OpenSSL library (or a + * modified version of that library), containing parts covered by the + * terms of the OpenSSL or SSLeay licenses, Savoir-Faire Linux Inc. + * grants you additional permission to convey the resulting work. + * Corresponding Source for a non-source form of such a combination + * shall include the source code for the parts of OpenSSL used as well + * as that of the covered work. + */ + +#include "securedht.h" + +extern "C" { +#include <gnutls/gnutls.h> +#include <gnutls/abstract.h> +#include <gnutls/x509.h> +} + +#include <random> +#include <sstream> +#include <random> + +static gnutls_digest_algorithm_t get_dig_for_pub(gnutls_pubkey_t pubkey) +{ + gnutls_digest_algorithm_t dig; + int result = gnutls_pubkey_get_preferred_hash_algorithm(pubkey, &dig, nullptr); + if (result < 0) + return GNUTLS_DIG_UNKNOWN; + return dig; +} + +static gnutls_digest_algorithm_t get_dig(gnutls_x509_crt_t crt) +{ + gnutls_pubkey_t pubkey; + gnutls_pubkey_init(&pubkey); + + int result = gnutls_pubkey_import_x509(pubkey, crt, 0); + if (result < 0) { + gnutls_pubkey_deinit(pubkey); + return GNUTLS_DIG_UNKNOWN; + } + + gnutls_digest_algorithm_t dig = get_dig_for_pub(pubkey); + gnutls_pubkey_deinit(pubkey); + return dig; +} + +namespace dht { +namespace crypto { + +PrivateKey::PrivateKey(gnutls_x509_privkey_t k) : x509_key(k) +{ + gnutls_privkey_init(&key); + if (gnutls_privkey_import_x509(key, k, GNUTLS_PRIVKEY_IMPORT_COPY) != GNUTLS_E_SUCCESS) { + key = nullptr; + throw DhtException("Can't load private key !"); + } +} + +PrivateKey::PrivateKey(const Blob& import) +{ + const gnutls_datum_t dt {(unsigned char*)import.data(), static_cast<unsigned>(import.size())}; + gnutls_x509_privkey_init(&x509_key); + int err = gnutls_x509_privkey_import(x509_key, &dt, GNUTLS_X509_FMT_PEM); + if (err != GNUTLS_E_SUCCESS) { + err = gnutls_x509_privkey_import(x509_key, &dt, GNUTLS_X509_FMT_DER); + } + if (err != GNUTLS_E_SUCCESS) { + gnutls_x509_privkey_deinit(x509_key); + throw DhtException("Can't load private key !"); + } + gnutls_privkey_init(&key); + if (gnutls_privkey_import_x509(key, x509_key, GNUTLS_PRIVKEY_IMPORT_COPY) != GNUTLS_E_SUCCESS) { + throw DhtException("Can't load private key !"); + } +} + +PrivateKey::~PrivateKey() +{ + if (key) { + gnutls_privkey_deinit(key); + key = nullptr; + } + if (x509_key) { + gnutls_x509_privkey_deinit(x509_key); + x509_key = nullptr; + } +} + +PrivateKey& +PrivateKey::operator=(PrivateKey&& o) noexcept +{ + if (key) { + gnutls_privkey_deinit(key); + key = nullptr; + } + if (x509_key) { + gnutls_x509_privkey_deinit(x509_key); + x509_key = nullptr; + } + key = o.key; x509_key = o.x509_key; + o.key = nullptr; o.x509_key = nullptr; + return *this; +} + +Blob +PrivateKey::sign(const Blob& data) const +{ + if (!key) + throw DhtException("Can't sign data: no private key set !"); + gnutls_datum_t sig; + const gnutls_datum_t dat {(unsigned char*)data.data(), (unsigned)data.size()}; + if (gnutls_privkey_sign_data(key, GNUTLS_DIG_SHA512, 0, &dat, &sig) != GNUTLS_E_SUCCESS) + throw DhtException("Can't sign data !"); + Blob ret(sig.data, sig.data+sig.size); + gnutls_free(sig.data); + return ret; +} + +Blob +PrivateKey::decrypt(const Blob& cipher) const +{ + if (!key) + throw DhtException("Can't decrypt data without private key !"); + const gnutls_datum_t dat {(uint8_t*)cipher.data(), (unsigned)cipher.size()}; + gnutls_datum_t out; + if (gnutls_privkey_decrypt_data(key, 0, &dat, &out) != GNUTLS_E_SUCCESS) + throw DhtException("Can't decrypt data !"); + Blob ret {out.data, out.data+out.size}; + gnutls_free(out.data); + return ret; +} + +Blob +PrivateKey::serialize() const +{ + if (!x509_key) + return {}; + size_t buf_sz = 8192; + Blob buffer; + buffer.resize(buf_sz); + int err = gnutls_x509_privkey_export_pkcs8(x509_key, GNUTLS_X509_FMT_PEM, nullptr, GNUTLS_PKCS_PLAIN, buffer.data(), &buf_sz); + if (err != GNUTLS_E_SUCCESS) { + std::cerr << "Could not export certificate - " << gnutls_strerror(err) << std::endl; + return {}; + } + buffer.resize(buf_sz); + return buffer; +} + +PublicKey +PrivateKey::getPublicKey() const +{ + gnutls_pubkey_t pk; + gnutls_pubkey_init(&pk); + PublicKey pk_ret {pk}; + if (gnutls_pubkey_import_privkey(pk, key, GNUTLS_KEY_KEY_CERT_SIGN | GNUTLS_KEY_CRL_SIGN, 0) != GNUTLS_E_SUCCESS) + return {}; + return pk_ret; +} + +PublicKey::PublicKey(const Blob& dat) : pk(nullptr) +{ + unpackBlob(dat); +} + +PublicKey::~PublicKey() +{ + if (pk) { + gnutls_pubkey_deinit(pk); + pk = nullptr; + } +} + +PublicKey& +PublicKey::operator=(PublicKey&& o) noexcept +{ + if (pk) + gnutls_pubkey_deinit(pk); + pk = o.pk; + o.pk = nullptr; + return *this; +} + +void +PublicKey::pack(Blob& b) const +{ + std::vector<uint8_t> tmp(2048); + size_t sz = tmp.size(); + int err = gnutls_pubkey_export(pk, GNUTLS_X509_FMT_DER, tmp.data(), &sz); + if (err != GNUTLS_E_SUCCESS) + throw std::invalid_argument(std::string("Could not export public key: ") + gnutls_strerror(err)); + tmp.resize(sz); + serialize<Blob>(tmp, b); +} + +void +PublicKey::unpack(Blob::const_iterator& begin, Blob::const_iterator& end) +{ + Blob tmp = deserialize<Blob>(begin, end); + if (pk) + gnutls_pubkey_deinit(pk); + gnutls_pubkey_init(&pk); + const gnutls_datum_t dat {(uint8_t*)tmp.data(), (unsigned)tmp.size()}; + int err = gnutls_pubkey_import(pk, &dat, GNUTLS_X509_FMT_DER); + if (err != GNUTLS_E_SUCCESS) + throw std::invalid_argument(std::string("Could not read public key: ") + gnutls_strerror(err)); +} + +bool +PublicKey::checkSignature(const Blob& data, const Blob& signature) const { + if (!pk) + return false; + const gnutls_datum_t sig {(uint8_t*)signature.data(), (unsigned)signature.size()}; + const gnutls_datum_t dat {(uint8_t*)data.data(), (unsigned)data.size()}; + int rc = gnutls_pubkey_verify_data2(pk, GNUTLS_SIGN_RSA_SHA512, 0, &dat, &sig); + return rc >= 0; +} + +Blob +PublicKey::encrypt(const Blob& data) const +{ + if (!pk) + throw DhtException("Can't read public key !"); + const gnutls_datum_t dat {(uint8_t*)data.data(), (unsigned)data.size()}; + gnutls_datum_t encrypted; + int err = gnutls_pubkey_encrypt_data(pk, 0, &dat, &encrypted); + if (err != GNUTLS_E_SUCCESS) + throw DhtException(std::string("Can't encrypt data: ") + gnutls_strerror(err)); + Blob ret {encrypted.data, encrypted.data+encrypted.size}; + gnutls_free(encrypted.data); + return ret; +} + +InfoHash +PublicKey::getId() const +{ + InfoHash id; + size_t sz = id.size(); + gnutls_pubkey_get_key_id(pk, 0, id.data(), &sz); + return id; +} + +Certificate::Certificate(const Blob& certData) : cert(nullptr) +{ + unpackBlob(certData); +} + +Certificate& +Certificate::operator=(Certificate&& o) noexcept +{ + if (cert) + gnutls_x509_crt_deinit(cert); + cert = o.cert; + o.cert = nullptr; + return *this; +} + +void +Certificate::unpack(Blob::const_iterator& begin, Blob::const_iterator& end) +{ + if (cert) + gnutls_x509_crt_deinit(cert); + gnutls_x509_crt_init(&cert); + const gnutls_datum_t crt_dt {(uint8_t*)&(*begin), (unsigned)(end-begin)}; + int err = gnutls_x509_crt_import(cert, &crt_dt, GNUTLS_X509_FMT_PEM); + if (err != GNUTLS_E_SUCCESS) { + cert = nullptr; + throw std::invalid_argument(std::string("Could not read certificate - ") + gnutls_strerror(err)); + } +} + +void +Certificate::pack(Blob& b) const +{ + auto b_size = b.size(); + size_t buf_sz = 8192; + b.resize(b_size + buf_sz); + int err = gnutls_x509_crt_export(cert, GNUTLS_X509_FMT_PEM, b.data()+b_size, &buf_sz); + if (err != GNUTLS_E_SUCCESS) { + std::cerr << "Could not export certificate - " << gnutls_strerror(err) << std::endl; + b.resize(b_size); + } + b.resize(b_size + buf_sz); +} + +Certificate::~Certificate() +{ + if (cert) { + gnutls_x509_crt_deinit(cert); + cert = nullptr; + } +} + +PublicKey +Certificate::getPublicKey() const +{ + gnutls_pubkey_t pk; + gnutls_pubkey_init(&pk); + PublicKey pk_ret(pk); + if (gnutls_pubkey_import_x509(pk, cert, 0) != GNUTLS_E_SUCCESS) + return {}; + return pk_ret; +} + +PrivateKey +PrivateKey::generate() +{ + gnutls_x509_privkey_t key; + if (gnutls_x509_privkey_init(&key) != GNUTLS_E_SUCCESS) + throw std::runtime_error("Can't initialize private key."); + if (gnutls_x509_privkey_generate(key, GNUTLS_PK_RSA, 2048, 0) != GNUTLS_E_SUCCESS) { + gnutls_x509_privkey_deinit(key); + throw std::runtime_error("Can't initialize RSA key pair."); + } + return PrivateKey{key}; +} + +crypto::Identity +generateIdentity(const std::string& name, crypto::Identity ca) +{ + int rc = gnutls_global_init(); + if (rc != GNUTLS_E_SUCCESS) + return {}; + + auto shared_key = std::make_shared<PrivateKey>(PrivateKey::generate()); + + gnutls_x509_crt_t cert; + if (gnutls_x509_crt_init(&cert) != GNUTLS_E_SUCCESS) + return {}; + auto shared_crt = std::make_shared<Certificate>(cert); + + gnutls_x509_crt_set_activation_time(cert, time(NULL)); + gnutls_x509_crt_set_expiration_time(cert, time(NULL) + (700 * 24 * 60 * 60)); + if (gnutls_x509_crt_set_key(cert, shared_key->x509_key) != GNUTLS_E_SUCCESS) { + std::cerr << "Error when setting certificate key" << std::endl; + return {}; + } + if (gnutls_x509_crt_set_version(cert, 3) != GNUTLS_E_SUCCESS) { + std::cerr << "Error when setting certificate version" << std::endl; + return {}; + } + + // TODO: compute the subject key using the recommended RFC method + auto pk_id = shared_key->getPublicKey().getId(); + gnutls_x509_crt_set_subject_key_id(cert, &pk_id, sizeof(pk_id)); + + gnutls_x509_crt_set_dn_by_oid(cert, GNUTLS_OID_X520_COMMON_NAME, 0, name.data(), name.length()); + + const std::string& uid_str = shared_key->getPublicKey().getId().toString(); + gnutls_x509_crt_set_dn_by_oid(cert, GNUTLS_OID_LDAP_UID, 0, uid_str.data(), uid_str.length()); + + { + std::random_device rdev; + std::uniform_int_distribution<uint64_t> dist{}; + uint64_t cert_serial = dist(rdev); + gnutls_x509_crt_set_serial(cert, &cert_serial, sizeof(cert_serial)); + } + + if (ca.first && ca.second) { + gnutls_x509_crt_set_key_usage (cert, GNUTLS_KEY_DIGITAL_SIGNATURE | GNUTLS_KEY_DATA_ENCIPHERMENT); + //if (gnutls_x509_crt_sign2(cert, ca.second->cert, ca.first->x509_key, get_dig(cert), 0) != GNUTLS_E_SUCCESS) { + if (gnutls_x509_crt_privkey_sign(cert, ca.second->cert, ca.first->key, get_dig(cert), 0) != GNUTLS_E_SUCCESS) { + std::cerr << "Error when signing certificate" << std::endl; + return {}; + } + } else { + gnutls_x509_crt_set_ca_status(cert, 1); + gnutls_x509_crt_set_key_usage (cert, GNUTLS_KEY_DIGITAL_SIGNATURE | GNUTLS_KEY_KEY_CERT_SIGN); + //if (gnutls_x509_crt_sign2(cert, cert, key, get_dig(cert), 0) != GNUTLS_E_SUCCESS) { + if (gnutls_x509_crt_privkey_sign(cert, cert, shared_key->key, get_dig(cert), 0) != GNUTLS_E_SUCCESS) { + std::cerr << "Error when signing certificate" << std::endl; + return {}; + } + } + + gnutls_global_deinit(); + + return {shared_key, shared_crt}; +} + +} + +} diff --git a/src/dht.cpp b/src/dht.cpp new file mode 100644 index 00000000..09b62926 --- /dev/null +++ b/src/dht.cpp @@ -0,0 +1,2532 @@ +/* +Copyright (c) 2009-2014 Juliusz Chroboczek +Copyright (c) 2014 Savoir-Faire Linux Inc. + +Authors : Adrien Béraud <adrien.beraud@savoirfairelinux.com>, + Juliusz Chroboczek <jch@pps.univ–paris–diderot.fr> + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ + +#include "dht.h" + +extern "C" { +#include <gnutls/gnutls.h> +} + +#include <sys/time.h> + +#ifndef _WIN32 +#include <arpa/inet.h> +#include <sys/types.h> + +#else +#include <w32api.h> +#define WINVER WindowsXP +#include <ws2tcpip.h> +#endif + +#include <algorithm> +#include <sstream> + +#include <unistd.h> +#include <fcntl.h> +#include <cstdarg> +#include <cstring> + +#ifndef MSG_CONFIRM +#define MSG_CONFIRM 0 +#endif + +#ifdef _WIN32 + +#define EAFNOSUPPORT WSAEAFNOSUPPORT +static bool +set_nonblocking(int fd, int nonblocking) +{ + unsigned long mode = !!nonblocking; + int rc = ioctlsocket(fd, FIONBIO, &mode); + if (rc != 0) + errno = WSAGetLastError(); + return rc == 0; +} + +extern const char *inet_ntop(int, const void *, char *, socklen_t); + +#else + +static bool +set_nonblocking(int fd, int nonblocking) +{ + int rc = fcntl(fd, F_GETFL, 0); + if (rc < 0) + return false; + rc = fcntl(fd, F_SETFL, nonblocking?(rc | O_NONBLOCK):(rc & ~O_NONBLOCK)); + if (rc < 0) + return false; + return true; +} + +#endif + +#define WANT4 1 +#define WANT6 2 + +static std::mt19937 rd {std::random_device{}()}; +static std::uniform_int_distribution<uint8_t> rand_byte; + +static const uint8_t v4prefix[16] = { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xFF, 0xFF, 0, 0, 0, 0 +}; + +static std::string +to_hex(const uint8_t *buf, size_t buflen) +{ + std::stringstream s; + s << std::hex; + for (size_t i = 0; i < buflen; i++) + s << std::setfill('0') << std::setw(2) << (unsigned)buf[i]; + s << std::dec; + return s.str(); +} + +namespace dht { + +const Dht::TransPrefix Dht::TransPrefix::PING = {"pn"}; +const Dht::TransPrefix Dht::TransPrefix::FIND_NODE = {"fn"}; +const Dht::TransPrefix Dht::TransPrefix::GET_VALUES = {"gp"}; +const Dht::TransPrefix Dht::TransPrefix::ANNOUNCE_VALUES = {"ap"}; + +const uint8_t Dht::my_v[9] = "1:v4:RNG"; + +static constexpr InfoHash zeroes {}; +static constexpr InfoHash ones = {std::array<uint8_t, HASH_LEN>{ + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF +}}; + +const long unsigned Dht::MAX_REQUESTS_PER_SEC = 400; + +void +Dht::setLoggers(LogMethod&& error, LogMethod&& warn, LogMethod&& debug) +{ + DHT_DEBUG = std::move(debug); + DHT_WARN = std::move(warn); + DHT_ERROR = std::move(error); +} + +Dht::Status +Dht::getStatus(sa_family_t af) const +{ + unsigned good = 0, dubious = 0, cached = 0, incoming = 0; + int tot = getNodesStats(af, &good, &dubious, &cached, &incoming); + if (tot < 1) + return Status::Disconnected; + else if (good < 4) + return Status::Connecting; + return Status::Connected; +} + +bool +Dht::isRunning(sa_family_t af) const +{ + return (af == AF_INET && dht_socket >= 0) + || (af == AF_INET6 && dht_socket6 >= 0); +} + +bool +Dht::isMartian(const sockaddr *sa, socklen_t len) +{ + // Check that sa_family can be accessed safely + if (!sa || len < sizeof(sockaddr_in)) + return true; + + switch(sa->sa_family) { + case AF_INET: { + sockaddr_in *sin = (sockaddr_in*)sa; + const uint8_t *address = (const uint8_t*)&sin->sin_addr; + return sin->sin_port == 0 || + (address[0] == 0) || + (address[0] == 127) || + ((address[0] & 0xE0) == 0xE0); + } + case AF_INET6: { + if (len < sizeof(sockaddr_in6)) + return true; + sockaddr_in6 *sin6 = (sockaddr_in6*)sa; + const uint8_t *address = (const uint8_t*)&sin6->sin6_addr; + return sin6->sin6_port == 0 || + (address[0] == 0xFF) || + (address[0] == 0xFE && (address[1] & 0xC0) == 0x80) || + (memcmp(address, zeroes.data(), 15) == 0 && + (address[15] == 0 || address[15] == 1)) || + (memcmp(address, v4prefix, 12) == 0); + } + + default: + return true; + } +} + +Dht::Node* +Dht::Bucket::randomNode() +{ + if (nodes.empty()) + return nullptr; + std::uniform_int_distribution<unsigned> rand_node(0, nodes.size()-1); + unsigned nn = rand_node(rd); + for (auto& n : nodes) + if (not nn--) return &n; + return &nodes.back(); +} + +InfoHash +Dht::RoutingTable::randomId(const Dht::RoutingTable::const_iterator& it) const +{ + int bit1 = it->first.lowbit(); + int bit2 = std::next(it) != end() ? std::next(it)->first.lowbit() : -1; + int bit = std::max(bit1, bit2) + 1; + + if (bit >= 8*HASH_LEN) + return it->first; + + int b = bit/8; + InfoHash id_return; + std::copy_n(it->first.begin(), b, id_return.begin()); + id_return[b] = it->first[b] & (0xFF00 >> (bit % 8)); + id_return[b] |= rand_byte(rd) >> (bit % 8); + for (unsigned i = b + 1; i < HASH_LEN; i++) + id_return[i] = rand_byte(rd); + return id_return; +} + +InfoHash +Dht::RoutingTable::middle(const RoutingTable::const_iterator& it) const +{ + int bit1 = it->first.lowbit(); + int bit2 = std::next(it) != end() ? std::next(it)->first.lowbit() : -1; + int bit = std::max(bit1, bit2) + 1; + + if (bit >= 8*HASH_LEN) + throw std::out_of_range("End of table"); + + InfoHash id = it->first; + id[bit / 8] |= (0x80 >> (bit % 8)); + return id; +} + +Dht::RoutingTable::iterator +Dht::RoutingTable::findBucket(const InfoHash& id) +{ + if (empty()) + return end(); + auto b = begin(); + while (true) { + auto next = std::next(b); + if (next == end()) + return b; + if (InfoHash::cmp(id, next->first) < 0) + return b; + b = next; + } +} + +Dht::RoutingTable::const_iterator +Dht::RoutingTable::findBucket(const InfoHash& id) const +{ + /* Avoid code duplication for the const version */ + const_iterator it = const_cast<RoutingTable*>(this)->findBucket(id); + return it; +} + +/* Every bucket contains an unordered list of nodes. */ +Dht::Node * +Dht::findNode(const InfoHash& id, sa_family_t af) +{ + Bucket* b = findBucket(id, af); + if (!b) + return nullptr; + for (auto& n : b->nodes) + if (n.id == id) return &n; + return nullptr; +} + +const Dht::Node* +Dht::findNode(const InfoHash& id, sa_family_t af) const +{ + const Bucket* b = findBucket(id, af); + if (!b) + return nullptr; + for (const auto& n : b->nodes) + if (n.id == id) return &n; + return nullptr; +} + +/* This is our definition of a known-good node. */ +bool +Dht::Node::isGood(time_t now) const +{ + return + pinged <= 2 && + reply_time >= now - 7200 && + time >= now - 15 * 60; +} + +/* Every bucket caches the address of a likely node. Ping it. */ +int +Dht::sendCachedPing(Bucket& b) +{ + /* We set family to 0 when there's no cached node. */ + if (b.cached.ss_family == 0) + return 0; + + DHT_DEBUG("Sending ping to cached node."); + int rc = sendPing((sockaddr*)&b.cached, b.cachedlen, TransId{TransPrefix::PING}); + b.cached.ss_family = 0; + b.cachedlen = 0; + return rc; +} + +/* Called whenever we send a request to a node, increases the ping count + and, if that reaches 3, sends a ping to a new candidate. */ +void +Dht::pinged(Node& n, Bucket *b) +{ + n.pinged++; + n.pinged_time = now.tv_sec; + if (n.pinged >= 3) + sendCachedPing(b ? *b : *findBucket(n.id, n.ss.ss_family)); +} + +/* The internal blacklist is an LRU cache of nodes that have sent + incorrect messages. */ +void +Dht::blacklistNode(const InfoHash* id, const sockaddr *sa, socklen_t salen) +{ + DHT_WARN("Blacklisting broken node."); + + if (id) { + /* Make the node easy to discard. */ + Node *n = findNode(*id, sa->sa_family); + if (n) { + n->pinged = 3; + pinged(*n); + } + /* Discard it from any searches in progress. */ + for (auto& sr : searches) { + sr.nodes.erase(std::remove_if (sr.nodes.begin(), sr.nodes.end(), [id](const SearchNode& sn){ + return sn.id == *id; + })); + } + } + /* And make sure we don't hear from it again. */ + memcpy(&blacklist[next_blacklisted], sa, salen); + next_blacklisted = (next_blacklisted + 1) % BLACKLISTED_MAX; +} + +bool +Dht::isNodeBlacklisted(const sockaddr *sa, socklen_t salen) const +{ + if (salen > sizeof(sockaddr_storage)) + return true; + + if (isBlacklisted(sa, salen)) + return true; + + for (unsigned i = 0; i < BLACKLISTED_MAX; i++) { + if (memcmp(&blacklist[i], sa, salen) == 0) + return true; + } + + return false; +} + +/* Split a bucket into two equal parts. */ +bool +Dht::RoutingTable::split(const RoutingTable::iterator& b) +{ + InfoHash new_id; + try { + new_id = middle(b); + } catch (const std::out_of_range& e) { + return false; + } + + // Insert new bucket + insert(std::next(b), Bucket {b->af, new_id, b->time}); + + // Re-assign nodes + std::list<Node> nodes {}; + nodes.splice(nodes.begin(), b->nodes); + while (!nodes.empty()) { + auto n = nodes.begin(); + auto b = findBucket(n->id); + if (b == end()) + nodes.erase(n); + else + b->nodes.splice(b->nodes.begin(), nodes, n); + } + return true; +} + +/* We just learnt about a node, not necessarily a new one. Confirm is 1 if + the node sent a message, 2 if it sent us a reply. */ +Dht::Node* +Dht::newNode(const InfoHash& id, const sockaddr *sa, socklen_t salen, int confirm) +{ + if (isMartian(sa, salen) || isNodeBlacklisted(sa, salen)) + return nullptr; + + auto& list = sa->sa_family == AF_INET ? buckets : buckets6; + auto b = list.findBucket(id); + if (b == list.end() || id == myid) + return nullptr; + + bool mybucket = list.contains(b, myid); + + if (confirm == 2) + b->time = now.tv_sec; + + for (auto& n : b->nodes) { + if (n.id != id) continue; + if (confirm || n.time < now.tv_sec - 15 * 60) { + /* Known node. Update stuff. */ + memcpy((sockaddr*)&n.ss, sa, salen); + if (confirm) + n.time = now.tv_sec; + if (confirm >= 2) { + n.reply_time = now.tv_sec; + n.pinged = 0; + n.pinged_time = 0; + } + if (confirm) { + /* If this node existed in searches but was expired, give it another chance. */ + for (auto& s : searches) { + if (s.af != sa->sa_family) continue; + if (s.insertNode(id, sa, salen, now.tv_sec, true)) { + time_t tm = s.getNextStepTime(types, now.tv_sec); + if (tm != 0 && (search_time == 0 || search_time > tm)) + search_time = tm; + } + } + } + } + return &n; + } + + /* New node. */ + + /* Try adding the node to searches */ + for (auto& s : searches) { + if (s.af != sa->sa_family) continue; + if (s.insertNode(id, sa, salen, now.tv_sec)) { + time_t tm = s.getNextStepTime(types, now.tv_sec); + if (tm != 0 && (search_time == 0 || search_time > tm)) + search_time = tm; + } + } + + if (mybucket) { + if (sa->sa_family == AF_INET) + mybucket_grow_time = now.tv_sec; + else + mybucket6_grow_time = now.tv_sec; + } + + /* First, try to get rid of a known-bad node. */ + for (auto& n : b->nodes) { + if (n.pinged < 3 || n.pinged_time >= now.tv_sec - 15) + continue; + n.id = id; + memcpy((sockaddr*)&n.ss, sa, salen); + n.time = confirm ? now.tv_sec : 0; + n.reply_time = confirm >= 2 ? now.tv_sec : 0; + n.pinged_time = 0; + n.pinged = 0; + return &n; + } + + if (b->nodes.size() >= 8) { + /* Bucket full. Ping a dubious node */ + bool dubious = false; + for (auto& n : b->nodes) { + /* Pick the first dubious node that we haven't pinged in the + last 15 seconds. This gives nodes the time to reply, but + tends to concentrate on the same nodes, so that we get rid + of bad nodes fast. */ + if (!n.isGood(now.tv_sec)) { + dubious = true; + if (n.pinged_time < now.tv_sec - 15) { + DHT_DEBUG("Sending ping to dubious node."); + sendPing((sockaddr*)&n.ss, n.sslen, TransId {TransPrefix::PING}); + n.pinged++; + n.pinged_time = now.tv_sec; + break; + } + } + } + + if (mybucket && (!dubious || list.size() == 1)) { + DHT_DEBUG("Splitting."); + sendCachedPing(*b); + list.split(b); + dumpTables(); + return newNode(id, sa, salen, confirm); + } + + /* No space for this node. Cache it away for later. */ + if (confirm || b->cached.ss_family == 0) { + memcpy(&b->cached, sa, salen); + b->cachedlen = salen; + } + + return nullptr; + } + + /* Create a new node. */ + b->nodes.emplace_front(id, sa, salen, confirm ? now.tv_sec : 0, confirm >= 2 ? now.tv_sec : 0); + return &b->nodes.front(); +} + +/* Called periodically to purge known-bad nodes. Note that we're very + conservative here: broken nodes in the table don't do much harm, we'll + recover as soon as we find better ones. */ +void +Dht::expireBuckets(RoutingTable& list) +{ + for (auto& b : list) { + bool changed = false; + b.nodes.remove_if([&changed](const Node& n) { + if (n.pinged >= 4) { + changed = true; + return true; + } + return false; + }); + if (changed) + sendCachedPing(b); + } + std::uniform_int_distribution<time_t> time_dis(120, 360-1); + expire_stuff_time = now.tv_sec + time_dis(rd); +} + +/* While a search is in progress, we don't necessarily keep the nodes being + walked in the main bucket table. A search in progress is identified by + a unique transaction id, a short (and hence small enough to fit in the + transaction id of the protocol packets). */ + +Dht::Search * +Dht::findSearch(unsigned short tid, sa_family_t af) +{ + auto sr = std::find_if (searches.begin(), searches.end(), [tid,af](const Search& s){ + return s.tid == tid && s.af == af; + }); + return sr == searches.end() ? nullptr : &(*sr); +} + +/* A search contains a list of nodes, sorted by decreasing distance to the + target. We just got a new candidate, insert it at the right spot or + discard it. */ +bool +Dht::Search::insertNode(const InfoHash& nid, + const sockaddr *sa, socklen_t salen, + time_t now, bool confirmed, const Blob& token) +{ + if (sa->sa_family != af) { + //DHT_DEBUG("Attempted to insert node in the wrong family."); + return false; + } + + // Fast track for the case where the node is not relevant for this search + if (nodes.size() == SEARCH_NODES && id.xorCmp(nid, nodes.back().id) > 0) + return false; + + bool found = false; + auto n = std::find_if(nodes.begin(), nodes.end(), [=,&found](const SearchNode& sn) { + if (sn.id == nid) { + found = true; + return true; + } + return id.xorCmp(nid, sn.id) < 0; + }); + if (!found) { + if (n == nodes.end() && nodes.size() == SEARCH_NODES) + return false; + n = nodes.insert(n, SearchNode{ nid }); + if (nodes.size() > SEARCH_NODES) + nodes.pop_back(); + } + + memcpy(&n->ss, sa, salen); + n->sslen = salen; + + if (confirmed) { + /*if (n->pinged >= 3) + DHT_WARN("Resurrecting node !");*/ + n->pinged = 0; + } + if (not token.empty()) { + n->reply_time = now; + n->request_time = 0; + /* n->pinged = 0;*/ + /*if (token.size() > 64) + DHT_DEBUG("Eek! Overlong token."); + else*/ + if (token.size() <= 64) + n->token = token; + } + + return true; +} + +void +Dht::expireSearches() +{ + auto t = now.tv_sec - SEARCH_EXPIRE_TIME; + searches.remove_if([t](const Search& sr) { + return sr.announce.empty() && sr.step_time < t; + }); +} + +bool +Dht::searchSendGetValues(Search& sr, SearchNode *n) +{ + time_t t = now.tv_sec; + if (!n) { + auto ni = std::find_if(sr.nodes.begin(), sr.nodes.end(), [t](const SearchNode& sn) { + return sn.pinged < 3 && !sn.isSynced(t) && sn.request_time < t - 15; + }); + if (ni != sr.nodes.end()) + n = &*ni; + } + + if (!n || n->pinged >= 3 || n->isSynced(t) || n->request_time >= t - 15) + return false; + + { + char hbuf[NI_MAXHOST]; + char sbuf[NI_MAXSERV]; + getnameinfo((sockaddr*)&n->ss, n->sslen, hbuf, sizeof(hbuf), sbuf, sizeof(sbuf), NI_NUMERICHOST | NI_NUMERICSERV); + DHT_WARN("Sending get_values to %s:%s for %s.", hbuf, sbuf, n->id.toString().c_str()); + } + sendGetValues((sockaddr*)&n->ss, n->sslen, TransId {TransPrefix::GET_VALUES, sr.tid}, sr.id, -1, n->reply_time >= t - 15); + n->pinged++; + n->request_time = t; + + /* If the node happens to be in our main routing table, mark it + as pinged. */ + Node *node = findNode(n->id, n->ss.ss_family); + if (node) pinged(*node); + return true; +} + +/* When a search is in progress, we periodically call search_step to send + further requests. */ +void +Dht::searchStep(Search& sr) +{ + if (sr.nodes.empty()) { + // No nodes... yet ? + // Nothing to do, wait for the timeout. + /* + if (sr.step_time == 0) + sr.step_time = now.tv_sec; + if (now.tv_sec - sr.step_time > SEARCH_TIMEOUT) { + DHT_WARN("Search timed out."); + if (sr.done_callback) + sr.done_callback(false); + if (sr.announce.empty()) + sr.done = true; + } + */ + return; + } + + /* Check if the first 8 live nodes have replied. */ + if (sr.isSynced(now.tv_sec)) { + DHT_DEBUG("searchStep (synced)."); + for (auto& a : sr.announce) { + if (!a.value) { + continue; + DHT_ERROR("Trying to announce a null value !"); + } + unsigned i = 0; + bool all_acked = true; + auto vid = a.value->id; + const auto& type = getType(a.value->type); + for (auto& n : sr.nodes) { + if (n.pinged >= 3) + continue; + // A proposed extension to the protocol consists in + // omitting the token when storage tables are full. While + // I don't think this makes a lot of sense -- just sending + // a positive reply is just as good --, let's deal with it. + // if (n.token.empty()) + // n.acked[vid] = now.tv_sec; + auto a_status = n.acked.find(vid); + auto at = n.getAnnounceTime(a_status, type); + if ( at <= now.tv_sec ) { + all_acked = false; + //storageStore(sr.id, a.value); + { + char hbuf[NI_MAXHOST]; + char sbuf[NI_MAXSERV]; + getnameinfo((sockaddr*)&n.ss, n.sslen, hbuf, sizeof(hbuf), sbuf, sizeof(sbuf), NI_NUMERICHOST | NI_NUMERICSERV); + DHT_WARN("Sending announce_value to %s:%s (%s).", hbuf, sbuf, n.id.toString().c_str()); + } + sendAnnounceValue((sockaddr*)&n.ss, sizeof(sockaddr_storage), + TransId {TransPrefix::ANNOUNCE_VALUES, sr.tid}, sr.id, *a.value, + n.token, n.reply_time >= now.tv_sec - 15); + if (a_status == n.acked.end()) { + n.acked[vid] = { .request_time = now.tv_sec }; + } else { + a_status->second.request_time = now.tv_sec; + } + n.pending = true; + } + if (++i == 8) + break; + } + if (all_acked && a.callback) { + a.callback(true); + a.callback = nullptr; + } + } + for (auto& n : sr.nodes) { + if (n.pending) { + n.pending = false; + n.pinged++; + n.request_time = now.tv_sec; + if (auto node = findNode(n.id, n.ss.ss_family)) + pinged(*node); + } + } + DHT_DEBUG("Search done."); + if (sr.done_callback) { + sr.done_callback(true); + sr.done_callback = nullptr; + } + if (sr.announce.empty()) + sr.done = true; + } else { + DHT_DEBUG("searchStep."); + if (sr.step_time + SEARCH_GET_STEP >= now.tv_sec) + return; + if (sr.nodes.empty() && sr.announce.empty()) { + sr.done = true; + return; + } + + unsigned i = 0; + for (auto& sn : sr.nodes) { + i += searchSendGetValues(sr, &sn) ? 1 : 0; + if (i >= 3) + break; + } + } + sr.step_time = now.tv_sec; +} + + +std::list<Dht::Search>::iterator +Dht::newSearch() +{ + auto oldest = searches.begin(); + for (auto i = searches.begin(); i != searches.end(); ++i) { + if (i->done && (oldest->step_time > i->step_time)) + oldest = i; + } + + /* The oldest slot is expired. */ + if (oldest != searches.end() && oldest->announce.empty() && oldest->step_time < now.tv_sec - SEARCH_EXPIRE_TIME) + return oldest; + + /* Allocate a new slot. */ + if (searches.size() < MAX_SEARCHES) { + searches.push_front(Search {}); + return searches.begin(); + } + + /* Oh, well, never mind. Reuse the oldest slot. */ + return oldest; +} + +/* Insert the contents of a bucket into a search structure. */ +void +Dht::Search::insertBucket(const Bucket& b, time_t now) +{ + for (auto& n : b.nodes) + insertNode(n.id, (sockaddr*)&n.ss, n.sslen, now); +} + +bool +Dht::Search::isSynced(time_t now) const +{ + unsigned i = 0; + for (const auto& n : nodes) { + if (n.pinged >= 3) + continue; + if (!n.isSynced(now)) + return false; + if (++i == 8) + break; + } + return i > 0; +} + +time_t +Dht::Search::getAnnounceTime(const std::map<ValueType::Id, ValueType>& types) const +{ + if (nodes.empty()) + return 0; + time_t ret = 0; + for (const auto& a : announce) { + if (!a.value) continue; + auto type_it = types.find(a.value->type); + const ValueType& type = (type_it == types.end()) ? ValueType::USER_DATA : type_it->second; + unsigned i = 0; + for (const auto& n : nodes) { + if (n.pinged >= 3) + continue; + auto at = n.getAnnounceTime(a.value->id, type); + if (at != 0 && (ret == 0 || ret > at)) + ret = at; + if (++i == 8) + break; + } + } + return ret; +} + +time_t +Dht::Search::getNextStepTime(const std::map<ValueType::Id, ValueType>& types, time_t now) const +{ + if (done || nodes.empty()) + return 0; + if (!isSynced(now)) + return step_time + SEARCH_GET_STEP + 1; + return getAnnounceTime(types); +} + +void +Dht::bootstrapSearch(Dht::Search& sr) +{ + auto& list = (sr.af == AF_INET) ? buckets : buckets6; + if (list.empty() || (list.size() == 1 && list.front().nodes.empty())) + return; + DHT_DEBUG("bootstrapSearch."); + auto b = list.findBucket(sr.id); + if (b == list.end()) + return; + + time_t t = now.tv_sec; + sr.insertBucket(*b, t); + + if (sr.nodes.size() < SEARCH_NODES) { + if (std::next(b) != list.end()) + sr.insertBucket(*std::next(b), t); + if (b != list.begin()) + sr.insertBucket(*std::prev(b), t); + } + if (sr.nodes.size() < SEARCH_NODES) + sr.insertBucket(*list.findBucket(myid), t); +} + +/* Start a search. */ +Dht::Search* +Dht::search(const InfoHash& id, sa_family_t af, GetCallback callback, DoneCallback done_callback, Value::Filter filter) +{ + if (!isRunning(af)) { + DHT_ERROR("Unsupported protocol IPv%s bucket for %s", (af == AF_INET) ? "4" : "6", id.toString().c_str()); + if (done_callback) + done_callback(false); + return nullptr; + } + + auto sr = std::find_if (searches.begin(), searches.end(), [id,af](const Search& s) { + return s.id == id && s.af == af; + }); + + time_t t = now.tv_sec; + if (sr != searches.end()) { + sr->done = false; + // Discard any doubtful nodes. + sr->nodes.erase(std::remove_if (sr->nodes.begin(), sr->nodes.end(), [t](const SearchNode& n) { + return n.pinged >= 3 || n.reply_time < t - 7200; + }), sr->nodes.end()); + } else { + sr = newSearch(); + if (sr == searches.end()) { + errno = ENOSPC; + return nullptr; + } + sr->af = af; + sr->tid = search_id++; + sr->step_time = 0; + sr->id = id; + sr->done = false; + sr->nodes = {}; + DHT_DEBUG("New IPv%s search for %s", (af == AF_INET) ? "4" : "6", id.toString().c_str()); + } + + if (callback) + sr->callbacks.emplace_back(filter, callback); + sr->done_callback = done_callback; + + bootstrapSearch(*sr); + searchStep(*sr); + search_time = t; + return &(*sr); +} + +void +Dht::announce(const InfoHash& id, sa_family_t af, const std::shared_ptr<Value>& value, DoneCallback callback) +{ + if (!value) { + if (callback) + callback(false); + return; + } + auto sri = std::find_if (searches.begin(), searches.end(), [id,af](const Search& s) { + return s.id == id && s.af == af; + }); + Search* sr = (sri == searches.end()) ? search(id, af, nullptr, nullptr) : &(*sri); + if (!sr) { + if (callback) + callback(false); + return; + } + sr->done = false; + auto a_sr = std::find_if(sr->announce.begin(), sr->announce.end(), [&](const Announce& a){ + return a.value->id == value->id; + }); + if (a_sr == sr->announce.end()) + sr->announce.emplace_back(Announce {value, callback}); + else { + if (a_sr->value != value) { + a_sr->value = value; + for (auto& n : sr->nodes) + n.acked[value->id] = {0, 0}; + } + a_sr->callback = callback; + } +} + +void +Dht::put(const InfoHash& id, Value&& value, DoneCallback callback) +{ + if (value.id == Value::INVALID_ID) { + std::random_device rdev; + std::uniform_int_distribution<Value::Id> rand_id {}; + value.id = rand_id(rdev); + } + + auto val = std::make_shared<Value>(std::move(value)); + DHT_DEBUG("put: adding %s -> %s", id.toString().c_str(), val->toString().c_str()); + + auto ok = std::make_shared<bool>(false); + auto done = std::make_shared<bool>(false); + auto done4 = std::make_shared<bool>(false); + auto done6 = std::make_shared<bool>(false); + auto donecb = [=]() { + // Callback as soon as the value is announced on one of the available networks + if (callback && !*done && (*ok || (*done4 && *done6))) { + callback(*ok); + *done = true; + } + }; + announce(id, AF_INET, val, [=](bool ok4) { + DHT_DEBUG("search done IPv4 %d", ok4); + *done4 = true; + *ok |= ok4; + donecb(); + }); + announce(id, AF_INET6, val, [=](bool ok6) { + DHT_DEBUG("search done IPv6 %d", ok6); + *done6 = true; + *ok |= ok6; + donecb(); + }); +} + +void +Dht::get(const InfoHash& id, GetCallback getcb, DoneCallback donecb, Value::Filter filter) +{ + /* Try to answer this search locally. */ + if (getcb) { + auto locVals = getLocal(id, filter); + if (not locVals.empty()) { + DHT_DEBUG("Found local data (%d values).", locVals.size()); + getcb(locVals); + } + } + + auto done = std::make_shared<bool>(false); + auto done4 = std::make_shared<bool>(false); + auto done6 = std::make_shared<bool>(false); + auto vals = std::make_shared<std::vector<std::shared_ptr<Value>>>(); + auto done_l = [=]() { + if ((*done4 && *done6) || *done) { + *done = true; + donecb(true); + } + }; + auto cb = [=](const std::vector<std::shared_ptr<Value>>& values) { + if (*done) + return false; + std::vector<std::shared_ptr<Value>> newvals {}; + for (const auto& v : values) { + auto it = std::find_if(vals->cbegin(), vals->cend(), [&](const std::shared_ptr<Value>& sv) { + return sv == v || *sv == *v; + }); + if (it == vals->cend()) { + if (filter(*v)) + newvals.push_back(v); + } + } + if (!newvals.empty()) { + *done = !getcb(newvals); + vals->insert(vals->end(), newvals.begin(), newvals.end()); + } + done_l(); + return !*done; + }; + Dht::search(id, AF_INET, cb, [=](bool) { + *done4 = true; + done_l(); + }); + Dht::search(id, AF_INET6, cb, [=](bool) { + *done6 = true; + done_l(); + }); + +} + +std::vector<std::shared_ptr<Value>> +Dht::getLocal(const InfoHash& id, Value::Filter f) const +{ + auto s = findStorage(id); + if (!s) return {}; + std::vector<std::shared_ptr<Value>> vals; + vals.reserve(s->values.size()); + for (auto& v : s->values) + if (f(*v.data)) vals.push_back(v.data); + return vals; +} + +std::shared_ptr<Value> +Dht::getLocal(const InfoHash& id, const Value::Id& vid) const +{ + if (auto s = findStorage(id)) { + for (auto& v : s->values) + if (v.data->id == vid) return v.data; + } + return {}; +} + +std::vector<std::shared_ptr<Value>> +Dht::getPut(const InfoHash& id) +{ + std::vector<std::shared_ptr<Value>> ret; + for (const auto& search: searches) { + if (search.id != id) + continue; + ret.reserve(ret.size() + search.announce.size()); + for (const auto& a : search.announce) + ret.push_back(a.value); + } + return ret; +} + +std::shared_ptr<Value> +Dht::getPut(const InfoHash& id, const Value::Id& vid) +{ + for (const auto& search : searches) { + if (search.id != id) + continue; + for (const auto& a : search.announce) { + if (a.value->id == vid) + return a.value; + } + } + return nullptr; +} + +bool +Dht::cancelPut(const InfoHash& id, const Value::Id& vid) +{ + bool canceled {false}; + for (auto& search: searches) { + if (search.id != id) + continue; + for (auto it = search.announce.begin(); it != search.announce.end();) { + if (it->value->id == vid) { + canceled = true; + it = search.announce.erase(it); + } + else + ++it; + } + } +} + +/* A struct storage stores all the stored peer addresses for a given info + hash. */ + +Dht::Storage* +Dht::findStorage(const InfoHash& id) +{ + for (auto& st : store) + if (st.id == id) + return &st; + return nullptr; +} + +Dht::ValueStorage* +Dht::storageStore(const InfoHash& id, const std::shared_ptr<Value>& value) +{ + Storage *st = findStorage(id); + if (!st) { + if (store.size() >= MAX_HASHES) + return nullptr; + store.push_back(Storage {id, {}}); + st = &store.back(); + } + + auto it = std::find_if (st->values.begin(), st->values.end(), [&](const ValueStorage& vr) { + return vr.data == value || vr.data->id == value->id; + }); + if (it != st->values.end()) { + /* Already there, only need to refresh */ + it->time = now.tv_sec; + if (it->data != value) { + DHT_DEBUG("Updating %s -> %s", id.toString().c_str(), value->toString().c_str()); + it->data = value; + } + return &*it; + } else { + DHT_DEBUG("Storing %s -> %s", id.toString().c_str(), value->toString().c_str()); + if (st->values.size() >= MAX_VALUES) + return nullptr; + st->values.emplace_back(value, now.tv_sec); + return &st->values.back(); + } +} + +void +Dht::expireStorage() +{ + auto i = store.begin(); + while (i != store.end()) + { + i->values.erase( + std::partition(i->values.begin(), i->values.end(), + [&](const ValueStorage& v) + { + if (!v.data) return true; // should not happen + const auto& type = getType(v.data->type); + bool expired = v.time + type.expiration < now.tv_sec; + if (expired) + DHT_DEBUG("Discarding expired value %s", v.data->toString().c_str()); + return !expired; + }), + i->values.end()); + + if (i->values.size() == 0) { + DHT_DEBUG("Discarding expired value %s", i->id.toString().c_str()); + i = store.erase(i); + } + else + ++i; + } +} + +void +Dht::rotateSecrets() +{ + std::uniform_int_distribution<time_t> time_dist(15*60, 45*60); + rotate_secrets_time = now.tv_sec + time_dist(rd); + + oldsecret = secret; + { + std::random_device rdev; + std::generate_n(secret.begin(), secret.size(), std::bind(rand_byte, std::ref(rdev))); + } +} + +Blob +Dht::makeToken(const sockaddr *sa, bool old) const +{ + void *ip; + size_t iplen; + in_port_t port; + + if (sa->sa_family == AF_INET) { + sockaddr_in *sin = (sockaddr_in*)sa; + ip = &sin->sin_addr; + iplen = 4; + port = htons(sin->sin_port); + } else if (sa->sa_family == AF_INET6) { + sockaddr_in6 *sin6 = (sockaddr_in6*)sa; + ip = &sin6->sin6_addr; + iplen = 16; + port = htons(sin6->sin6_port); + } else { + return {}; + } + + const auto& c1 = old ? oldsecret : secret; + Blob data; + data.reserve(sizeof(secret)+2+iplen); + data.insert(data.end(), c1.begin(), c1.end()); + data.insert(data.end(), (uint8_t*)ip, (uint8_t*)ip+iplen); + data.insert(data.end(), (uint8_t*)&port, ((uint8_t*)&port)+2); + + size_t sz = TOKEN_SIZE; + Blob ret {}; + ret.resize(sz); + gnutls_datum_t gnudata = {data.data(), (unsigned int)data.size()}; + if (gnutls_fingerprint(GNUTLS_DIG_SHA512, &gnudata, ret.data(), &sz) != GNUTLS_E_SUCCESS) + throw DhtException("Can't compute SHA512"); + ret.resize(sz); + return ret; +} + +bool +Dht::tokenMatch(const Blob& token, const sockaddr *sa) const +{ + if (!sa || token.size() != TOKEN_SIZE) + return false; + if (token == makeToken(sa, false)) + return true; + if (token == makeToken(sa, true)) + return true; + return false; +} + +int +Dht::getNodesStats(sa_family_t af, unsigned *good_return, unsigned *dubious_return, unsigned *cached_return, unsigned *incoming_return) const +{ + unsigned good = 0, dubious = 0, cached = 0, incoming = 0; + auto& list = (af == AF_INET) ? buckets : buckets6; + + for (const auto& b : list) { + for (auto& n : b.nodes) { + if (n.isGood(now.tv_sec)) { + good++; + if (n.time > n.reply_time) + incoming++; + } else { + dubious++; + } + } + if (b.cached.ss_family > 0) + cached++; + } + if (good_return) + *good_return = good; + if (dubious_return) + *dubious_return = dubious; + if (cached_return) + *cached_return = cached; + if (incoming_return) + *incoming_return = incoming; + return good + dubious; +} + +void +Dht::dumpBucket(const Bucket& b, std::ostream& out) const +{ + out << b.first << " count " << b.nodes.size() << " age " << (int)(now.tv_sec - b.time); + if (b.cached.ss_family) + out << " (cached)"; + out << std::endl; + for (auto& n : b.nodes) { + std::string buf(INET6_ADDRSTRLEN, '\0'); + unsigned short port; + out << " Node " << n.id << " "; + if (n.ss.ss_family == AF_INET) { + sockaddr_in *sin = (sockaddr_in*)&n.ss; + inet_ntop(AF_INET, &sin->sin_addr, (char*)buf.data(), buf.size()); + port = ntohs(sin->sin_port); + } else if (n.ss.ss_family == AF_INET6) { + sockaddr_in6 *sin6 = (sockaddr_in6*)&n.ss; + inet_ntop(AF_INET6, &sin6->sin6_addr, (char*)buf.data(), buf.size()); + port = ntohs(sin6->sin6_port); + } else { + out << "unknown(" << (unsigned)n.ss.ss_family << ")"; + port = 0; + } + buf.resize(std::char_traits<char>::length(buf.c_str())); + + if (n.ss.ss_family == AF_INET6) + out << "[" << buf << "]:" << port; + else + out << buf << ":" << port; + if (n.time != n.reply_time) + out << " age " << (now.tv_sec - n.time) << ", " << (now.tv_sec - n.reply_time); + else + out << " age " << (now.tv_sec - n.time); + if (n.pinged) + out << " (" << n.pinged << ")"; + if (n.isGood(now.tv_sec)) + out << " (good)"; + out << std::endl; + } +} + +void +Dht::dumpSearch(const Search& sr, std::ostream& out) const +{ + out << std::endl << "Search (IPv" << (sr.af == AF_INET6 ? "6" : "4") << ") " << sr.id; + out << " age " << (now.tv_sec - sr.step_time) << " s"; + if (sr.done) + out << " [done]"; + bool synced = sr.isSynced(now.tv_sec); + out << (synced ? " [synced]" : " [not synced]"); + if (synced && not sr.announce.empty()) { + auto at = sr.getAnnounceTime(types); + if (at && at > now.tv_sec) + out << " [all announced]"; + else + out << " announce at " << at << ", in " << (at-now.tv_sec) << " s."; + } + out << std::endl; + + for (const auto& n : sr.announce) { + out << " Announcement: " << *n.value << std::endl; + } + + unsigned i = 0; + for (const auto& n : sr.nodes) { + out << " Node " << i++ << " id " << n.id << " bits " << InfoHash::commonBits(sr.id, n.id); + if (n.request_time) + out << " req: " << (now.tv_sec - n.request_time) << " s,"; + out << " age:" << (now.tv_sec - n.reply_time) << " s"; + if (n.pinged) + out << " pinged: " << n.pinged; + if (findNode(n.id, AF_INET)) + out << " [known]"; + if (n.reply_time) + out << " [replied]"; + out << (n.isSynced(now.tv_sec) ? " [synced]" : " [not synced]"); + out << std::endl; + } +} + +void +Dht::dumpTables() const +{ + std::stringstream out; + out << "My id " << myid << std::endl; + + out << "Buckets IPv4 :" << std::endl; + for (const auto& b : buckets) + dumpBucket(b, out); + out << "Buckets IPv6 :" << std::endl; + for (const auto& b : buckets6) + dumpBucket(b, out); + + for (const auto& sr : searches) + dumpSearch(sr, out); + out << std::endl; + + for (const auto& st : store) { + out << "Storage " << st.id << " " << st.values.size() << " values:" << std::endl; + for (const auto& v : st.values) + out << " " << *v.data << " (" << (now.tv_sec - v.time) << "s)" << std::endl; + } + + DHT_DEBUG("%s", out.str().c_str()); +} + + +Dht::Dht(int s, int s6, const InfoHash& id) + : dht_socket(s), dht_socket6(s6), myid(id) +{ + if (s < 0 && s6 < 0) + return; + + if (s >= 0) { + buckets = {Bucket {AF_INET}}; + if (!set_nonblocking(s, 1)) + throw DhtException("Can't set socket to non-blocking mode"); + } + + if (s6 >= 0) { + buckets6 = {Bucket {AF_INET6}}; + if (!set_nonblocking(s6, 1)) + throw DhtException("Can't set socket to non-blocking mode"); + } + + std::uniform_int_distribution<decltype(search_id)> searchid_dis {}; + search_id = searchid_dis(rd); + + gettimeofday(&now, nullptr); + + std::uniform_int_distribution<time_t> time_dis {0,3}; + mybucket_grow_time = now.tv_sec; + mybucket6_grow_time = now.tv_sec; + confirm_nodes_time = now.tv_sec + time_dis(rd); + rate_limit_time = now.tv_sec; + + // Fill old secret + { + std::random_device rdev; + std::generate_n(secret.begin(), secret.size(), std::bind(rand_byte, std::ref(rdev))); + } + rotateSecrets(); + + expireBuckets(buckets); + expireBuckets(buckets6); + + DHT_DEBUG("DHT initialised with node ID %s", myid.toString().c_str()); +} + + +Dht::~Dht() +{} + +/* Rate control for requests we receive. */ + +bool +Dht::rateLimit() +{ + if (rate_limit_tokens == 0) { + rate_limit_tokens = std::min(MAX_REQUESTS_PER_SEC, 100 * static_cast<long unsigned>(now.tv_sec - rate_limit_time)); + rate_limit_time = now.tv_sec; + } + + if (rate_limit_tokens == 0) + return false; + + rate_limit_tokens--; + return true; +} + +bool +Dht::neighbourhoodMaintenance(RoutingTable& list) +{ + DHT_DEBUG("neighbourhoodMaintenance"); + + auto b = list.findBucket(myid); + if (b == list.end()) + return false; + + InfoHash id = myid; + id[HASH_LEN-1] = rand_byte(rd); + + std::binomial_distribution<bool> rand_trial(1, 1./8.); + auto q = b; + if (std::next(q) != list.end() && (q->nodes.empty() || rand_trial(rd))) + q = std::next(q); + if (b != list.begin() && (q->nodes.empty() || rand_trial(rd))) { + auto r = std::prev(b); + if (!r->nodes.empty()) + q = r; + } + + /* Since our node-id is the same in both DHTs, it's probably + profitable to query both families. */ + int want = dht_socket >= 0 && dht_socket6 >= 0 ? (WANT4 | WANT6) : -1; + Node *n = q->randomNode(); + if (n) { + DHT_DEBUG("Sending find_node for%s neighborhood maintenance.", q->af == AF_INET6 ? " IPv6" : ""); + sendFindNode((sockaddr*)&n->ss, n->sslen, + TransId {TransPrefix::FIND_NODE}, id, want, + n->reply_time >= now.tv_sec - 15); + pinged(*n, &(*q)); + } + + return true; +} + +bool +Dht::bucketMaintenance(RoutingTable& list) +{ + std::binomial_distribution<bool> rand_trial(1, 1./8.); + std::binomial_distribution<bool> rand_trial_38(1, 1./38.); + + for (auto b = list.begin(); b != list.end(); ++b) { + if (b->time < now.tv_sec - 600 || b->nodes.empty()) { + /* This bucket hasn't seen any positive confirmation for a long + time. Pick a random id in this bucket's range, and send + a request to a random node. */ + InfoHash id = list.randomId(b); + auto q = b; + /* If the bucket is empty, we try to fill it from a neighbour. + We also sometimes do it gratuitiously to recover from + buckets full of broken nodes. */ + if (std::next(b) != list.end() && (q->nodes.empty() || rand_trial(rd))) + q = std::next(b); + if (b != list.begin() && (q->nodes.empty() || rand_trial(rd))) { + auto r = std::prev(b); + if (!r->nodes.empty()) + q = r; + } + + Node *n = q->randomNode(); + if (n) { + int want = -1; + + if (dht_socket >= 0 && dht_socket6 >= 0) { + auto otherbucket = findBucket(id, q->af == AF_INET ? AF_INET6 : AF_INET); + if (otherbucket && otherbucket->nodes.size() < 8) + /* The corresponding bucket in the other family + is emptyish -- querying both is useful. */ + want = WANT4 | WANT6; + else if (rand_trial_38(rd)) + /* Most of the time, this just adds overhead. + However, it might help stitch back one of + the DHTs after a network collapse, so query + both, but only very occasionally. */ + want = WANT4 | WANT6; + } + + DHT_DEBUG("Sending find_node for%s bucket maintenance.", q->af == AF_INET6 ? " IPv6" : ""); + sendFindNode((sockaddr*)&n->ss, n->sslen, + TransId {TransPrefix::FIND_NODE}, id, want, + n->reply_time >= now.tv_sec - 15); + pinged(*n, &(*q)); + /* In order to avoid sending queries back-to-back, + give up for now and reschedule us soon. */ + return true; + } + } + } + return false; +} + +void +Dht::processMessage(const uint8_t *buf, size_t buflen, const sockaddr *from, socklen_t fromlen) +{ + if (buflen == 0) + return; + + //DHT_DEBUG("processMessage %p %lu %p %lu", buf, buflen, from, fromlen); + + MessageType message; + InfoHash id, info_hash, target; + TransId tid; + Blob token {}; + uint8_t nodes[256], nodes6[1024]; + unsigned nodes_len = 256, nodes6_len = 1024; + in_port_t port; + Value::Id value_id; + uint16_t error_code; + + std::vector<std::shared_ptr<Value>> values; + + int want; + uint16_t ttid; + + if (isMartian(from, fromlen)) + return; + + if (isNodeBlacklisted(from, fromlen)) { + DHT_DEBUG("Received packet from blacklisted node."); + return; + } + + if (buf[buflen] != '\0') + throw DhtException("Unterminated message."); + + try { + message = parseMessage(buf, buflen, tid, id, info_hash, target, + port, token, value_id, + nodes, &nodes_len, nodes6, &nodes6_len, + values, &want, error_code); + if (message != MessageType::Error && id == zeroes) + throw DhtException("no or invalid InfoHash"); + } catch (const std::exception& e) { + DHT_DEBUG("Can't process message of size %lu: %s.", buflen, e.what()); + DHT_DEBUG.logPrintable(buf, buflen); + return; + } + + if (id == myid) { + DHT_DEBUG("Received message from self."); + return; + } + + if (message > MessageType::Reply) { + /* Rate limit requests. */ + if (!rateLimit()) { + DHT_DEBUG("Dropping request due to rate limiting."); + return; + } + } + + switch (message) { + case MessageType::Error: + if (tid.length != 4) return; + DHT_WARN("Received error message:"); + DHT_WARN.logPrintable(buf, buflen); + if (error_code == 401 && id != zeroes && tid.matches(TransPrefix::ANNOUNCE_VALUES, &ttid)) { + auto sr = findSearch(ttid, from->sa_family); + if (!sr) return; + DHT_WARN("Received wrong token error for known search %s", sr->id.toString().c_str()); + for (auto& n : sr->nodes) { + if (n.id != id) continue; + newNode(id, from, fromlen, 2); + n.request_time = 0; + n.reply_time = 0; + n.pinged = 0; + } + searchSendGetValues(*sr); + } + break; + case MessageType::Reply: + if (tid.length != 4) { + DHT_ERROR("Broken node truncates transaction ids (len: %d): ", tid.length); + DHT_ERROR.logPrintable(buf, buflen); + /* This is really annoying, as it means that we will + time-out all our searches that go through this node. + Kill it. */ + blacklistNode(&id, from, fromlen); + return; + } + if (tid.matches(TransPrefix::PING)) { + DHT_DEBUG("Pong!"); + newNode(id, from, fromlen, 2); + } else if (tid.matches(TransPrefix::FIND_NODE) or tid.matches(TransPrefix::GET_VALUES)) { + bool gp = false; + Search *sr = nullptr; + if (tid.matches(TransPrefix::GET_VALUES, &ttid)) { + gp = true; + sr = findSearch(ttid, from->sa_family); + } + DHT_DEBUG("Nodes found (%u+%u)%s!", nodes_len/26, nodes6_len/38, gp ? " for get_values" : ""); + if (nodes_len % 26 != 0 || nodes6_len % 38 != 0) { + DHT_WARN("Unexpected length for node info!"); + blacklistNode(&id, from, fromlen); + } else if (gp && sr == NULL) { + DHT_WARN("Unknown search!"); + newNode(id, from, fromlen, 1); + } else { + newNode(id, from, fromlen, 2); + for (unsigned i = 0; i < nodes_len / 26; i++) { + uint8_t *ni = nodes + i * 26; + const InfoHash& ni_id = *reinterpret_cast<InfoHash*>(ni); + if (ni_id == myid) + continue; + sockaddr_in sin { .sin_family = AF_INET }; + memcpy(&sin.sin_addr, ni + ni_id.size(), 4); + memcpy(&sin.sin_port, ni + ni_id.size() + 4, 2); + newNode(ni_id, (sockaddr*)&sin, sizeof(sin), 0); + if (sr && sr->af == AF_INET) { + sr->insertNode(ni_id, (sockaddr*)&sin, sizeof(sin), now.tv_sec); + } + } + for (unsigned i = 0; i < nodes6_len / 38; i++) { + uint8_t *ni = nodes6 + i * 38; + InfoHash* ni_id = reinterpret_cast<InfoHash*>(ni); + if (*ni_id == myid) + continue; + sockaddr_in6 sin6 {.sin6_family = AF_INET6}; + memcpy(&sin6.sin6_addr, ni + HASH_LEN, 16); + memcpy(&sin6.sin6_port, ni + HASH_LEN + 16, 2); + newNode(*ni_id, (sockaddr*)&sin6, sizeof(sin6), 0); + if (sr && sr->af == AF_INET6) { + sr->insertNode(*ni_id, (sockaddr*)&sin6, sizeof(sin6), now.tv_sec); + } + } + if (sr) { + /* Since we received a reply, the number of + requests in flight has decreased. Let's push + another request. */ + /*if (sr->isSynced(now.tv_sec)) { + DHT_DEBUG("Trying to accelerate search!"); + search_time = now.tv_sec; + //sr->step_time = 0; + } else {*/ + searchSendGetValues(*sr); + //} + } + } + if (sr) { + sr->insertNode(id, from, fromlen, now.tv_sec, true, token); + if (!values.empty()) { + DHT_DEBUG("Got %d values !", values.size()); + for (auto& cb : sr->callbacks) { + if (!cb.second) continue; + std::vector<std::shared_ptr<Value>> tmp; + std::copy_if(values.begin(), values.end(), std::back_inserter(tmp), [&](const std::shared_ptr<Value>& v){ + return cb.first(*v); + }); + if (cb.second and not tmp.empty()) + cb.second(tmp); + } + } + if (sr->isSynced(now.tv_sec)) { + search_time = now.tv_sec; + } + } + } else if (tid.matches(TransPrefix::ANNOUNCE_VALUES, &ttid)) { + DHT_DEBUG("Got reply to announce_values."); + Search *sr = findSearch(ttid, from->sa_family); + if (!sr || value_id == Value::INVALID_ID) { + DHT_DEBUG("Unknown search or announce!"); + newNode(id, from, fromlen, 1); + } else { + newNode(id, from, fromlen, 2); + for (auto& sn : sr->nodes) + if (sn.id == id) { + auto it = sn.acked.insert({value_id, {}}); + it.first->second.request_time = 0; + it.first->second.reply_time = now.tv_sec; + sn.request_time = 0; + //sn.reply_time = now.tv_sec; + //sn.acked[value_id] = now.tv_sec; + sn.pinged = 0; + + break; + } + /* See comment for gp above. */ + searchSendGetValues(*sr); + } + } else { + DHT_WARN("Unexpected reply: "); + DHT_WARN.logPrintable(buf, buflen); + } + break; + case MessageType::Ping: + DHT_DEBUG("Got ping (%d)!", tid.length); + newNode(id, from, fromlen, 1); + DHT_DEBUG("Sending pong."); + sendPong(from, fromlen, tid); + break; + case MessageType::FindNode: + DHT_DEBUG("Got \"find node\" request"); + newNode(id, from, fromlen, 1); + DHT_DEBUG("Sending closest nodes (%d).", want); + sendClosestNodes(from, fromlen, tid, target, want); + break; + case MessageType::GetValues: + DHT_DEBUG("Got \"get values\" request"); + newNode(id, from, fromlen, 1); + if (info_hash == zeroes) { + DHT_DEBUG("Eek! Got get_values with no info_hash."); + sendError(from, fromlen, tid, 203, "Get_values with no info_hash"); + break; + } else { + Storage* st = findStorage(info_hash); + Blob ntoken = makeToken(from, false); + if (st && st->values.size() > 0) { + DHT_DEBUG("Sending found%s values.", from->sa_family == AF_INET6 ? " IPv6" : ""); + sendClosestNodes(from, fromlen, tid, info_hash, want, ntoken, st); + } else { + DHT_DEBUG("Sending nodes for get_values."); + sendClosestNodes(from, fromlen, tid, info_hash, want, ntoken); + } + } + break; + case MessageType::AnnounceValue: + DHT_DEBUG("Got \"announce value\" request!"); + newNode(id, from, fromlen, 1); + if (info_hash == zeroes) { + DHT_DEBUG("Announce_value with no info_hash."); + sendError(from, fromlen, tid, 203, "Announce_value with no info_hash"); + break; + } + if (!tokenMatch(token, from)) { + DHT_DEBUG("Incorrect token %s for announce_values.", to_hex(token.data(), token.size()).c_str()); + sendError(from, fromlen, tid, 401, "Announce_value with wrong token"); + break; + } + for (const auto& v : values) { + if (v->id == Value::INVALID_ID) { + DHT_DEBUG("Incorrect value id "); + sendError(from, fromlen, tid, 203, "Announce_value with invalid id"); + continue; + } + auto lv = getLocal(info_hash, v->id); + std::shared_ptr<Value> vc = v; + if (lv) { + const auto& type = getType(lv->type); + if (type.editPolicy(info_hash, lv, vc, id, from, fromlen)) { + DHT_DEBUG("Editing value of type %s belonging to %s at %s.", type.name.c_str(), v->owner.getId().toString().c_str(), info_hash.toString().c_str()); + storageStore(info_hash, vc); + } else { + DHT_WARN("Rejecting edition of type %s belonging to %s at %s because of storage policy.", type.name.c_str(), v->owner.getId().toString().c_str(), info_hash.toString().c_str()); + } + } else { + // Allow the value to be edited by the storage policy + const auto& type = getType(vc->type); + if (type.storePolicy(info_hash, vc, id, from, fromlen)) { + DHT_DEBUG("Storing value of type %s belonging to %s at %s.", type.name.c_str(), v->owner.getId().toString().c_str(), info_hash.toString().c_str()); + storageStore(info_hash, vc); + } else { + DHT_WARN("Rejecting storage of type %s belonging to %s at %s because of storage policy.", type.name.c_str(), v->owner.getId().toString().c_str(), info_hash.toString().c_str()); + } + } + + /* Note that if storage_store failed, we lie to the requestor. + This is to prevent them from backtracking, and hence + polluting the DHT. */ + DHT_DEBUG("Sending announceValue confirmation."); + sendValueAnnounced(from, fromlen, tid, v->id); + } + } +} + +void +Dht::periodic(const uint8_t *buf, size_t buflen, + const sockaddr *from, socklen_t fromlen, + time_t *tosleep) +{ + gettimeofday(&now, nullptr); + + processMessage(buf, buflen, from, fromlen); + + if (now.tv_sec >= rotate_secrets_time) + rotateSecrets(); + + if (now.tv_sec >= expire_stuff_time) { + expireBuckets(buckets); + expireBuckets(buckets6); + expireStorage(); + expireSearches(); + } + + if (search_time > 0 && now.tv_sec >= search_time) { + DHT_DEBUG("search_time"); + search_time = 0; + for (auto& sr : searches) { + time_t tm = sr.getNextStepTime(types, now.tv_sec); + if (tm == 0) continue; + if (tm <= now.tv_sec) { + searchStep(sr); + tm = sr.getNextStepTime(types, now.tv_sec); + } + if (tm != 0 && (search_time == 0 || search_time > tm)) + search_time = tm; + } + if (search_time == 0) + DHT_DEBUG("next search_time : (none)"); + else if (search_time < now.tv_sec) + DHT_DEBUG("next search_time : %lu (ASAP)"); + else + DHT_DEBUG("next search_time : %lu (in %lu s)", search_time, search_time-now.tv_sec); + } + + if (now.tv_sec >= confirm_nodes_time) { + bool soon = false; + + soon |= bucketMaintenance(buckets); + soon |= bucketMaintenance(buckets6); + + if (!soon) { + if (mybucket_grow_time >= now.tv_sec - 150) + soon |= neighbourhoodMaintenance(buckets); + if (mybucket6_grow_time >= now.tv_sec - 150) + soon |= neighbourhoodMaintenance(buckets6); + } + + /* In order to maintain all buckets' age within 600 seconds, worst + case is roughly 27 seconds, assuming the table is 22 bits deep. + We want to keep a margin for neighborhood maintenance, so keep + this within 25 seconds. */ + auto time_dis = soon ? + std::uniform_int_distribution<time_t> {5 , 25} + : std::uniform_int_distribution<time_t> {60, 180}; + confirm_nodes_time = now.tv_sec + time_dis(rd); + + dumpTables(); + } + + if (confirm_nodes_time > now.tv_sec) + *tosleep = confirm_nodes_time - now.tv_sec; + else + *tosleep = 0; + + if (search_time > 0) { + if (search_time <= now.tv_sec) + *tosleep = 0; + else if (*tosleep > search_time - now.tv_sec) + *tosleep = search_time - now.tv_sec; + } +} + +std::vector<Dht::ValuesExport> +Dht::exportValues() const +{ + std::vector<ValuesExport> e {}; + e.reserve(store.size()); + for (const auto& h : store) { + ValuesExport ve; + ve.first = h.id; + serialize<uint16_t>(h.values.size(), ve.second); + for (const auto& v : h.values) { + Blob vde; + serialize<time_t>(v.time, ve.second); + v.data->pack(ve.second); + } + e.push_back(std::move(ve)); + } + return e; +} + +void +Dht::importValues(const std::vector<ValuesExport>& import) +{ + for (const auto& h : import) { + if (h.second.empty()) + continue; + auto b = h.second.begin(), + e = h.second.end(); + try { + const size_t n_vals = deserialize<uint16_t>(b, e); + for (unsigned i = 0; i < n_vals; i++) { + time_t val_time; + Value tmp_val; + try { + val_time = deserialize<time_t>(b, e); + tmp_val.unpack(b, e); + } catch (const std::exception&) { + DHT_ERROR("Error reading value at %s", h.first.toString().c_str()); + continue; + } + auto st = storageStore(h.first, std::make_shared<Value>(std::move(tmp_val))); + st->time = val_time; + } + } catch (const std::exception&) { + DHT_ERROR("Error reading values at %s", h.first.toString().c_str()); + continue; + } + } +} + + +std::vector<Dht::NodeExport> +Dht::exportNodes() +{ + std::vector<NodeExport> nodes; + const auto b4 = buckets.findBucket(myid); + if (b4 != buckets.end()) { + for (auto& n : b4->nodes) + if (n.isGood(now.tv_sec)) + nodes.push_back(n.exportNode()); + } + const auto b6 = buckets6.findBucket(myid); + if (b6 != buckets6.end()) { + for (auto& n : b6->nodes) + if (n.isGood(now.tv_sec)) + nodes.push_back(n.exportNode()); + } + for (auto b = buckets.begin(); b != buckets.end(); ++b) { + if (b == b4) continue; + for (auto& n : b->nodes) + if (n.isGood(now.tv_sec)) + nodes.push_back(n.exportNode()); + } + for (auto b = buckets6.begin(); b != buckets6.end(); ++b) { + if (b == b6) continue; + for (auto& n : b->nodes) + if (n.isGood(now.tv_sec)) + nodes.push_back(n.exportNode()); + } + return nodes; +} + +bool +Dht::insertNode(const InfoHash& id, const sockaddr *sa, socklen_t salen) +{ + if (sa->sa_family != AF_INET && sa->sa_family != AF_INET6) { + errno = EAFNOSUPPORT; + return false; + } + Node *n = newNode(id, sa, salen, 0); + return !!n; +} + +int +Dht::pingNode(const sockaddr *sa, socklen_t salen) +{ + DHT_DEBUG("Sending ping."); + return sendPing(sa, salen, TransId {TransPrefix::PING}); +} + +/* We could use a proper bencoding printer and parser, but the format of + DHT messages is fairly stylised, so this seemed simpler. */ + +#define CHECK(offset, delta, size) \ + if (offset + delta > size) throw std::length_error("Provided buffer is not large enough."); + +#define INC(offset, delta, size) \ + if (delta < 0) throw std::length_error("Provided buffer is not large enough."); \ + CHECK(offset, (size_t)delta, size); \ + offset += delta + +#define COPY(buf, offset, src, delta, size) \ + CHECK(offset, delta, size); \ + memcpy(buf + offset, src, delta); \ + offset += delta; + +#define ADD_V(buf, offset, size) \ + COPY(buf, offset, my_v, sizeof(my_v), size); + +int +Dht::send(const void *buf, size_t len, int flags, const sockaddr *sa, socklen_t salen) +{ + if (salen == 0) + return -1; + + if (isNodeBlacklisted(sa, salen)) { + DHT_DEBUG("Attempting to send to blacklisted node."); + errno = EPERM; + return -1; + } + + int s; + if (sa->sa_family == AF_INET) + s = dht_socket; + else if (sa->sa_family == AF_INET6) + s = dht_socket6; + else + s = -1; + + if (s < 0) { + errno = EAFNOSUPPORT; + return -1; + } + return sendto(s, buf, len, flags, sa, salen); +} + +int +Dht::sendPing(const sockaddr *sa, socklen_t salen, TransId tid) +{ + char buf[512]; + int i = 0, rc; + rc = snprintf(buf + i, 512 - i, "d1:ad2:id20:"); INC(i, rc, 512); + COPY(buf, i, myid.data(), myid.size(), 512); + rc = snprintf(buf + i, 512 - i, "e1:q4:ping1:t%d:", tid.length); + INC(i, rc, 512); + COPY(buf, i, tid.data(), tid.length, 512); + ADD_V(buf, i, 512); + rc = snprintf(buf + i, 512 - i, "1:y1:qe"); INC(i, rc, 512); + return send(buf, i, 0, sa, salen); +} + +int +Dht::sendPong(const sockaddr *sa, socklen_t salen, TransId tid) +{ + char buf[512]; + int i = 0, rc; + rc = snprintf(buf + i, 512 - i, "d1:rd2:id20:"); INC(i, rc, 512); + COPY(buf, i, myid.data(), myid.size(), 512); + rc = snprintf(buf + i, 512 - i, "e1:t%d:", tid.length); INC(i, rc, 512); + COPY(buf, i, tid.data(), tid.length, 512); + ADD_V(buf, i, 512); + rc = snprintf(buf + i, 512 - i, "1:y1:re"); INC(i, rc, 512); + return send(buf, i, 0, sa, salen); +} + +int +Dht::sendFindNode(const sockaddr *sa, socklen_t salen, TransId tid, + const InfoHash& target, int want, int confirm) +{ + constexpr const size_t BUF_SZ = 512; + char buf[BUF_SZ]; + int i = 0, rc; + rc = snprintf(buf + i, BUF_SZ - i, "d1:ad2:id20:"); INC(i, rc, BUF_SZ); + COPY(buf, i, myid.data(), myid.size(), BUF_SZ); + rc = snprintf(buf + i, BUF_SZ - i, "6:target20:"); INC(i, rc, BUF_SZ); + COPY(buf, i, target.data(), target.size(), BUF_SZ); + if (want > 0) { + rc = snprintf(buf + i, BUF_SZ - i, "4:wantl%s%se", + (want & WANT4) ? "2:n4" : "", + (want & WANT6) ? "2:n6" : ""); + INC(i, rc, BUF_SZ); + } + rc = snprintf(buf + i, BUF_SZ - i, "e1:q9:find_node1:t%d:", tid.length); + INC(i, rc, BUF_SZ); + COPY(buf, i, tid.data(), tid.length, BUF_SZ); + ADD_V(buf, i, BUF_SZ); + rc = snprintf(buf + i, BUF_SZ - i, "1:y1:qe"); INC(i, rc, BUF_SZ); + return send(buf, i, confirm ? MSG_CONFIRM : 0, sa, salen); +} + +int +Dht::sendNodesValues(const sockaddr *sa, socklen_t salen, TransId tid, + const uint8_t *nodes, unsigned nodes_len, + const uint8_t *nodes6, unsigned nodes6_len, + Storage *st, const Blob& token) +{ + constexpr const size_t BUF_SZ = 2048 * 64; + char buf[BUF_SZ]; + int i = 0, rc; + + rc = snprintf(buf + i, BUF_SZ - i, "d1:rd2:id20:"); INC(i, rc, BUF_SZ); + COPY(buf, i, myid.data(), myid.size(), BUF_SZ); + if (nodes_len > 0) { + rc = snprintf(buf + i, BUF_SZ - i, "5:nodes%u:", nodes_len); + INC(i, rc, BUF_SZ); + COPY(buf, i, nodes, nodes_len, BUF_SZ); + } + if (nodes6_len > 0) { + rc = snprintf(buf + i, BUF_SZ - i, "6:nodes6%u:", nodes6_len); + INC(i, rc, BUF_SZ); + COPY(buf, i, nodes6, nodes6_len, BUF_SZ); + } + if (not token.empty()) { + rc = snprintf(buf + i, BUF_SZ - i, "5:token%lu:", token.size()); + INC(i, rc, BUF_SZ); + COPY(buf, i, token.data(), token.size(), BUF_SZ); + } + + if (st && st->values.size() > 0) { + /* We treat the storage as a circular list, and serve a randomly + chosen slice. In order to make sure we fit, + we limit ourselves to 50 values. */ + std::uniform_int_distribution<> pos_dis(0, st->values.size()-1); + unsigned j0 = pos_dis(rd); + unsigned j = j0; + unsigned k = 0; + + rc = snprintf(buf + i, BUF_SZ - i, "6:valuesl"); INC(i, rc, BUF_SZ); + do { + Blob packed_value; + st->values[j].data->pack(packed_value); + rc = snprintf(buf + i, BUF_SZ - i, "%lu:", packed_value.size()); INC(i, rc, BUF_SZ); + COPY(buf, i, packed_value.data(), packed_value.size(), BUF_SZ); + k++; + j = (j + 1) % st->values.size(); + } while (j != j0 && k < 50); + rc = snprintf(buf + i, BUF_SZ - i, "e"); INC(i, rc, BUF_SZ); + } + + rc = snprintf(buf + i, BUF_SZ - i, "e1:t%d:", tid.length); INC(i, rc, BUF_SZ); + COPY(buf, i, tid.data(), tid.length, BUF_SZ); + ADD_V(buf, i, BUF_SZ); + rc = snprintf(buf + i, BUF_SZ - i, "1:y1:re"); INC(i, rc, BUF_SZ); + + return send(buf, i, 0, sa, salen); +} + +unsigned +Dht::insertClosestNode(uint8_t *nodes, unsigned numnodes, const InfoHash& id, const Node& n) +{ + unsigned i, size; + + if (n.ss.ss_family == AF_INET) + size = HASH_LEN + sizeof(in_addr) + sizeof(in_port_t); // 26 + else if (n.ss.ss_family == AF_INET6) + size = HASH_LEN + sizeof(in6_addr) + sizeof(in_port_t); // 38 + else + return numnodes; + + for (i = 0; i < numnodes; i++) { + const InfoHash* nid = reinterpret_cast<const InfoHash*>(nodes + size * i); + if (InfoHash::cmp(n.id, *nid) == 0) + return numnodes; + if (id.xorCmp(n.id, *nid) < 0) + break; + } + + if (i >= 8) + return numnodes; + + if (numnodes < 8) + numnodes++; + + if (i < numnodes - 1) + memmove(nodes + size * (i + 1), nodes + size * i, size * (numnodes - i - 1)); + + if (n.ss.ss_family == AF_INET) { + sockaddr_in *sin = (sockaddr_in*)&n.ss; + memcpy(nodes + size * i, n.id.data(), HASH_LEN); + memcpy(nodes + size * i + HASH_LEN, &sin->sin_addr, sizeof(in_addr)); + memcpy(nodes + size * i + HASH_LEN + sizeof(in_addr), &sin->sin_port, 2); + } + else if (n.ss.ss_family == AF_INET6) { + sockaddr_in6 *sin6 = (sockaddr_in6*)&n.ss; + memcpy(nodes + size * i, n.id.data(), HASH_LEN); + memcpy(nodes + size * i + HASH_LEN, &sin6->sin6_addr, sizeof(in6_addr)); + memcpy(nodes + size * i + HASH_LEN + sizeof(in6_addr), &sin6->sin6_port, 2); + } + + return numnodes; +} + +unsigned +Dht::bufferClosestNodes(uint8_t *nodes, unsigned numnodes, const InfoHash& id, const Bucket& b) const +{ + for (auto& n : b.nodes) { + if (n.isGood(now.tv_sec)) + numnodes = insertClosestNode(nodes, numnodes, id, n); + } + return numnodes; +} + +int +Dht::sendClosestNodes(const sockaddr *sa, socklen_t salen, TransId tid, + const InfoHash& id, int want, const Blob& token, Storage *st) +{ + uint8_t nodes[8 * 26]; + uint8_t nodes6[8 * 38]; + unsigned numnodes = 0, numnodes6 = 0; + + if (want < 0) + want = sa->sa_family == AF_INET ? WANT4 : WANT6; + + if ((want & WANT4)) { + auto b = buckets.findBucket(id); + if (b != buckets.end()) { + numnodes = bufferClosestNodes(nodes, numnodes, id, *b); + if (std::next(b) != buckets.end()) + numnodes = bufferClosestNodes(nodes, numnodes, id, *std::next(b)); + if (b != buckets.begin()) + numnodes = bufferClosestNodes(nodes, numnodes, id, *std::prev(b)); + } + } + + if ((want & WANT6)) { + auto b = buckets6.findBucket(id); + if (b != buckets6.end()) { + numnodes6 = bufferClosestNodes(nodes6, numnodes6, id, *b); + if (std::next(b) != buckets6.end()) + numnodes6 = bufferClosestNodes(nodes6, numnodes6, id, *std::next(b)); + if (b != buckets6.begin()) + numnodes6 = bufferClosestNodes(nodes6, numnodes6, id, *std::prev(b)); + } + } + DHT_DEBUG("sending closest nodes (%d+%d nodes.)", numnodes, numnodes6); + + try { + return sendNodesValues(sa, salen, tid, + nodes, numnodes * 26, + nodes6, numnodes6 * 38, + st, token); + } catch (const std::overflow_error& e) { + DHT_ERROR("Can't send value: buffer not large enough !"); + return -1; + } +} + +int +Dht::sendGetValues(const sockaddr *sa, socklen_t salen, + TransId tid, const InfoHash& infohash, + int want, int confirm) +{ + const size_t BUF_SZ = 2048 * 4; + char buf[BUF_SZ]; + size_t i = 0; + int rc; + + rc = snprintf(buf + i, BUF_SZ - i, "d1:ad2:id20:"); INC(i, rc, BUF_SZ); + COPY(buf, i, myid.data(), myid.size(), BUF_SZ); + rc = snprintf(buf + i, BUF_SZ - i, "9:info_hash20:"); INC(i, rc, BUF_SZ); + COPY(buf, i, infohash.data(), infohash.size(), BUF_SZ); + if (want > 0) { + rc = snprintf(buf + i, BUF_SZ - i, "4:wantl%s%se", + (want & WANT4) ? "2:n4" : "", + (want & WANT6) ? "2:n6" : ""); + INC(i, rc, BUF_SZ); + } + rc = snprintf(buf + i, BUF_SZ - i, "e1:q9:get_peers1:t%d:", tid.length); + INC(i, rc, BUF_SZ); + COPY(buf, i, tid.data(), tid.length, BUF_SZ); + ADD_V(buf, i, BUF_SZ); + rc = snprintf(buf + i, BUF_SZ - i, "1:y1:qe"); INC(i, rc, BUF_SZ); + return send(buf, i, confirm ? MSG_CONFIRM : 0, sa, salen); +} + +int +Dht::sendAnnounceValue(const sockaddr *sa, socklen_t salen, TransId tid, + const InfoHash& infohash, const Value& value, + const Blob& token, int confirm) +{ + const size_t BUF_SZ = 2048 * 4; + char buf[BUF_SZ]; + size_t i = 0; + int rc; + + rc = snprintf(buf + i, BUF_SZ - i, "d1:ad2:id%lu:", myid.size()); INC(i, rc, BUF_SZ); + COPY(buf, i, myid.data(), myid.size(), BUF_SZ); + rc = snprintf(buf + i, BUF_SZ - i, "9:info_hash%lu:", infohash.size()); INC(i, rc, BUF_SZ); + COPY(buf, i, infohash.data(), infohash.size(), BUF_SZ); + + Blob packed_value; + value.pack(packed_value); + rc = snprintf(buf + i, BUF_SZ - i, "6:valuesl%lu:", packed_value.size()); INC(i, rc, BUF_SZ); + COPY(buf, i, packed_value.data(), packed_value.size(), BUF_SZ); + rc = snprintf(buf + i, BUF_SZ - i, "e5:token%lu:", token.size()); INC(i, rc, BUF_SZ); + COPY(buf, i, token.data(), token.size(), BUF_SZ); + rc = snprintf(buf + i, BUF_SZ - i, "e1:q13:announce_peer1:t%u:", tid.length); INC(i, rc, BUF_SZ); + COPY(buf, i, tid.data(), tid.length, BUF_SZ); + ADD_V(buf, i, BUF_SZ); + rc = snprintf(buf + i, BUF_SZ - i, "1:y1:qe"); INC(i, rc, BUF_SZ); + + return send(buf, i, confirm ? 0 : MSG_CONFIRM, sa, salen); +} + +int +Dht::sendValueAnnounced(const sockaddr *sa, socklen_t salen, TransId tid, Value::Id vid) +{ + char buf[512]; + int i = 0, rc; + + rc = snprintf(buf + i, 512 - i, "d1:rd2:id20:"); INC(i, rc, 512); + COPY(buf, i, myid.data(), myid.size(), 512); + rc = snprintf(buf + i, 512 - i, "3:vid%lu:", sizeof(Value::Id)); INC(i, rc, 512); + COPY(buf, i, &vid, sizeof(Value::Id), 512); + rc = snprintf(buf + i, 512 - i, "e1:t%u:", tid.length); INC(i, rc, 512); + COPY(buf, i, tid.data(), tid.length, 512); + ADD_V(buf, i, 512); + rc = snprintf(buf + i, 512 - i, "1:y1:re"); INC(i, rc, 512); + return send(buf, i, 0, sa, salen); +} + +int +Dht::sendError(const sockaddr *sa, socklen_t salen, TransId tid, int code, const char *message) +{ + char buf[512]; + int i = 0, rc; + + rc = snprintf(buf + i, 512 - i, "d1:eli%de%d:", code, (int)strlen(message)); + INC(i, rc, 512); + COPY(buf, i, message, (int)strlen(message), 512); + rc = snprintf(buf + i, 512 - i, "e1:t%d:", tid.length); INC(i, rc, 512); + COPY(buf, i, tid.data(), tid.length, 512); + ADD_V(buf, i, 512); + rc = snprintf(buf + i, 512 - i, "1:y1:ee"); INC(i, rc, 512); + return send(buf, i, 0, sa, salen); +} + +#undef CHECK +#undef INC +#undef COPY +#undef ADD_V + +Dht::MessageType +Dht::parseMessage(const uint8_t *buf, size_t buflen, + TransId& tid_return, + InfoHash& id_return, InfoHash& info_hash_return, + InfoHash& target_return, in_port_t& port_return, + Blob& token, Value::Id& value_id, + uint8_t *nodes_return, unsigned *nodes_len, + uint8_t *nodes6_return, unsigned *nodes6_len, + std::vector<std::shared_ptr<Value>>& values_return, + int *want_return, uint16_t& error_code) +{ + const uint8_t *p; + + /* This code will happily crash if the buffer is not NUL-terminated. */ + if (buf[buflen] != '\0') + throw DhtException("Eek! parse_message with unterminated buffer."); + +#define CHECK(ptr, len) if (((uint8_t*)ptr) + (len) > (buf) + (buflen)) throw std::out_of_range("Truncated message."); + + p = (uint8_t*)dht_memmem(buf, buflen, "1:t", 3); + if (p) { + char *q; + size_t l = strtoul((char*)p + 3, &q, 10); + if (q && *q == ':') { + CHECK(q + 1, l); + tid_return = {q+1, l}; + } else + tid_return.length = 0; + } + + p = (uint8_t*)dht_memmem(buf, buflen, "2:id20:", 7); + if (p) { + CHECK(p + 7, HASH_LEN); + memcpy(id_return.data(), p + 7, HASH_LEN); + } else { + id_return = {}; + } + + p = (uint8_t*)dht_memmem(buf, buflen, "9:info_hash20:", 14); + if (p) { + CHECK(p + 14, HASH_LEN); + memcpy(info_hash_return.data(), p + 14, HASH_LEN); + } else { + info_hash_return = {}; + } + + p = (uint8_t*)dht_memmem(buf, buflen, "porti", 5); + if (p) { + char *q; + unsigned long l = strtoul((char*)p + 5, &q, 10); + if (q && *q == 'e' && l < 0x10000) + port_return = l; + else + port_return = 0; + } else + port_return = 0; + + p = (uint8_t*)dht_memmem(buf, buflen, "6:target20:", 11); + if (p) { + CHECK(p + 11, HASH_LEN); + memcpy(target_return.data(), p + 11, HASH_LEN); + } else { + target_return = {}; + } + + p = (uint8_t*)dht_memmem(buf, buflen, "5:token", 7); + if (p) { + char *q; + size_t l = strtoul((char*)p + 7, &q, 10); + if (q && *q == ':' && l > 0 && l <= 128) { + CHECK(q + 1, l); + token.clear(); + token.insert(token.begin(), q + 1, q + 1 + l); + } + } + + if (nodes_len) { + p = (uint8_t*)dht_memmem(buf, buflen, "5:nodes", 7); + if (p) { + char *q; + size_t l = strtoul((char*)p + 7, &q, 10); + if (q && *q == ':' && l > 0 && l < *nodes_len) { + CHECK(q + 1, l); + memcpy(nodes_return, q + 1, l); + *nodes_len = l; + } else + *nodes_len = 0; + } else + *nodes_len = 0; + } + + if (nodes6_len) { + p = (uint8_t*)dht_memmem(buf, buflen, "6:nodes6", 8); + if (p) { + char *q; + size_t l = strtoul((char*)p + 8, &q, 10); + if (q && *q == ':' && l > 0 && l < *nodes6_len) { + CHECK(q + 1, l); + memcpy(nodes6_return, q + 1, l); + *nodes6_len = l; + } else + *nodes6_len = 0; + } else + *nodes6_len = 0; + } + + p = (uint8_t*)dht_memmem(buf, buflen, "6:valuesl", 9); + if (p) { + unsigned i = p - buf + 9; + while (true) { + char *q; + size_t l = strtoul((char*)buf + i, &q, 10); + if (q && *q == ':' && l > 0) { + CHECK(q + 1, l); + i = q + 1 + l - (char*)buf; + Value v; + v.unpackBlob(Blob {q + 1, q + 1 + l}); + values_return.push_back(std::make_shared<Value>(std::move(v))); + } else + break; + } + if (i >= buflen || buf[i] != 'e') + DHT_DEBUG("eek... unexpected end for values."); + } + + p = (uint8_t*)dht_memmem(buf, buflen, "3:vid8:", 7); + if (p) { + CHECK(p + 7, sizeof(value_id)); + memcpy(&value_id, p + 7, sizeof(value_id)); + } else { + value_id = Value::INVALID_ID; + } + + if (want_return) { + p = (uint8_t*)dht_memmem(buf, buflen, "4:wantl", 7); + if (p) { + unsigned i = p - buf + 7; + *want_return = 0; + while (buf[i] > '0' && buf[i] <= '9' && buf[i + 1] == ':' && + i + 2 + buf[i] - '0' < buflen) { + CHECK(buf + i + 2, buf[i] - '0'); + if (buf[i] == '2' && memcmp(buf + i + 2, "n4", 2) == 0) + *want_return |= WANT4; + else if (buf[i] == '2' && memcmp(buf + i + 2, "n6", 2) == 0) + *want_return |= WANT6; + else + DHT_DEBUG("eek... unexpected want flag (%c)", buf[i]); + i += 2 + buf[i] - '0'; + } + if (i >= buflen || buf[i] != 'e') + DHT_DEBUG("eek... unexpected end for want."); + } else { + *want_return = -1; + } + } + + p = (uint8_t*)dht_memmem(buf, buflen, "1:eli", 5); + if (p) { + CHECK(p + 5, sizeof(error_code)); + memcpy(&error_code, p + 5, sizeof(error_code)); + } else { + error_code = 0; + } + +#undef CHECK + + if (dht_memmem(buf, buflen, "1:y1:r", 6)) + return MessageType::Reply; + if (dht_memmem(buf, buflen, "1:y1:e", 6)) + return MessageType::Error; + if (!dht_memmem(buf, buflen, "1:y1:q", 6)) + throw DhtException("Parse error"); + if (dht_memmem(buf, buflen, "1:q4:ping", 9)) + return MessageType::Ping; + if (dht_memmem(buf, buflen, "1:q9:find_node", 14)) + return MessageType::FindNode; + if (dht_memmem(buf, buflen, "1:q9:get_peers", 14)) + return MessageType::GetValues; + if (dht_memmem(buf, buflen, "1:q13:announce_peer", 19)) + return MessageType::AnnounceValue; + throw DhtException("Can't read message type."); +} + +#ifdef HAVE_MEMMEM + +void * +Dht::dht_memmem(const void *haystack, size_t haystacklen, const void *needle, size_t needlelen) +{ + return memmem(haystack, haystacklen, needle, needlelen); +} + +#else + +void * +Dht::dht_memmem(const void *haystack, size_t haystacklen, const void *needle, size_t needlelen) +{ + const char *h = (const char *)haystack; + const char *n = (const char *)needle; + size_t i; + + /* size_t is unsigned */ + if (needlelen > haystacklen) + return NULL; + + for (i = 0; i <= haystacklen - needlelen; i++) { + if (memcmp(h + i, n, needlelen) == 0) + return (void*)(h + i); + } + return NULL; +} + +#endif + +} diff --git a/src/dhtrunner.cpp b/src/dhtrunner.cpp new file mode 100644 index 00000000..d4293c99 --- /dev/null +++ b/src/dhtrunner.cpp @@ -0,0 +1,329 @@ +/* + * Copyright (C) 2014 Savoir-Faire Linux Inc. + * Author : Adrien Béraud <adrien.beraud@savoirfairelinux.com> + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Additional permission under GNU GPL version 3 section 7: + * + * If you modify this program, or any covered work, by linking or + * combining it with the OpenSSL project's OpenSSL library (or a + * modified version of that library), containing parts covered by the + * terms of the OpenSSL or SSLeay licenses, Savoir-Faire Linux Inc. + * grants you additional permission to convey the resulting work. + * Corresponding Source for a non-source form of such a combination + * shall include the source code for the parts of OpenSSL used as well + * as that of the covered work. + */ + +#include "dhtrunner.h" + +namespace dht { + +void +DhtRunner::run(in_port_t port, const crypto::Identity identity, bool threaded, StatusCallback cb) +{ + if (running) + return; + if (rcv_thread.joinable()) + rcv_thread.join(); + statusCb = cb; + running = true; + doRun(port, identity); + if (!threaded) + return; + dht_thread = std::thread([this]() { + while (running) { + std::unique_lock<std::mutex> lk(dht_mtx); + loop_(); + cv.wait_for(lk, std::chrono::seconds( tosleep ), [this]() { + if (!running) return true; + { + std::unique_lock<std::mutex> lck(sock_mtx); + if (!rcv.empty()) return true; + } + { + std::unique_lock<std::mutex> lck(storage_mtx); + if (!dht_gets.empty() || !dht_puts.empty() || !bootstrap_nodes.empty()) + return true; + } + return false; + }); + } + }); +} + +void +DhtRunner::join() +{ + running = false; + cv.notify_all(); + if (dht_thread.joinable()) + dht_thread.join(); + if (rcv_thread.joinable()) + rcv_thread.join(); + { + std::unique_lock<std::mutex> lck(dht_mtx); + dht.reset(); + status4 = Dht::Status::Disconnected; + status6 = Dht::Status::Disconnected; + } +} + +void +DhtRunner::loop_() +{ + if (!dht) return; + time_t tosl; + { + std::unique_lock<std::mutex> lck(sock_mtx); + if (!dht) return; + if (rcv.size()) { + for (const auto& pck : rcv) { + auto& buf = pck.first; + auto& from = pck.second; + dht->periodic(buf.data(), buf.size()-1, (sockaddr*)&from, from.ss_family == AF_INET ? sizeof(sockaddr_in) : sizeof(sockaddr_in6), &tosl); + } + rcv.clear(); + } else { + dht->periodic(nullptr, 0, nullptr, 0, &tosl); + } + } + tosleep = tosl; + { + std::unique_lock<std::mutex> lck(storage_mtx); + + for (auto& get : dht_gets) { + std::cout << "Processing get (" << std::get<0>(get) << ")" << std::endl; + dht->get(std::get<0>(get), std::get<1>(get), std::get<2>(get), std::move(std::get<3>(get))); + } + dht_gets.clear(); + + for (auto& put : dht_eputs) { + auto& id = std::get<0>(put); + auto& val = std::get<2>(put); + std::cout << "Processing encrypted put at " << id << " for " << std::get<1>(put) << " -> " << val << std::endl; + dht->putEncrypted(id, std::get<1>(put), std::move(val), std::get<3>(put)); + } + dht_eputs.clear(); + + for (auto& put : dht_puts) { + auto& id = std::get<0>(put); + auto& val = std::get<1>(put); + std::cout << "Processing put " << id << " -> " << val << std::endl; + dht->put(id, std::move(val), std::get<2>(put)); + } + dht_puts.clear(); + + for (auto& put : dht_sputs) { + auto& id = std::get<0>(put); + auto& val = std::get<1>(put); + std::cout << "Processing signed put " << id << " -> " << val << std::endl; + dht->putSigned(id, std::move(val), std::get<2>(put)); + } + dht_sputs.clear(); + + for (auto& node : bootstrap_nodes) + dht->insertNode(node); + bootstrap_nodes.clear(); + + for (auto& node : bootstrap_ips) { + dht->pingNode((sockaddr*)&node, sizeof(node)); + //std::this_thread::sleep_for( std::chrono::microseconds(/*rand_delay()*/ 10) ); + } + bootstrap_ips.clear(); + } + + if (statusCb) { + Dht::Status nstatus4 = dht->getStatus(AF_INET); + Dht::Status nstatus6 = dht->getStatus(AF_INET6); + if (nstatus4 != status4 || nstatus6 != status6) { + status4 = nstatus4; + status6 = nstatus6; + statusCb(status4, status6); + } + } +} + +void +DhtRunner::doRun(in_port_t port, const crypto::Identity identity) +{ + dht.reset(); + + int s = socket(PF_INET, SOCK_DGRAM, 0); + int s6 = socket(PF_INET6, SOCK_DGRAM, 0); + if(s >= 0) { + sockaddr_in sin { + .sin_family = AF_INET, + .sin_port = htons(port) + }; + int rc = bind(s, (sockaddr*)&sin, sizeof(sin)); + if(rc < 0) + throw DhtException("Can't bind IPv4 socket"); + } + if(s6 >= 0) { + int val = 1; + int rc = setsockopt(s6, IPPROTO_IPV6, IPV6_V6ONLY, (char *)&val, sizeof(val)); + if(rc < 0) { + throw DhtException("setsockopt(IPV6_V6ONLY)"); + } + + /* BEP-32 mandates that we should bind this socket to one of our + global IPv6 addresses. In this simple example, this only + happens if the user used the -b flag. */ + sockaddr_in6 sin6 { + .sin6_family = AF_INET6, + .sin6_port = htons(port) + }; + rc = bind(s6, (sockaddr*)&sin6, sizeof(sin6)); + if(rc < 0) + throw DhtException("Can't bind IPv6 socket"); + } + + dht = std::unique_ptr<SecureDht>(new SecureDht {s, s6, identity}); + + rcv_thread = std::thread([this,s,s6]() { + std::mt19937 engine(std::random_device{}()); + auto rand_delay = std::bind(std::uniform_int_distribution<uint32_t>(0, 1000000), engine); + try { + while (true) { + uint8_t buf[4096 * 64]; + sockaddr_storage from; + socklen_t fromlen; + + struct timeval tv; + fd_set readfds; + tv.tv_sec = tosleep / 5; + tv.tv_usec = rand_delay(); + //std::cout << "Dht::rcv_thread loop " << tv.tv_sec << "." << tv.tv_usec << std::endl; + + FD_ZERO(&readfds); + if(s >= 0) + FD_SET(s, &readfds); + if(s6 >= 0) + FD_SET(s6, &readfds); + int rc = select(s > s6 ? s + 1 : s6 + 1, &readfds, NULL, NULL, &tv); + if(rc < 0) { + if(errno != EINTR) { + perror("select"); + std::this_thread::sleep_for( std::chrono::seconds(1) ); + } + } + + if(!running) + break; + + if(rc > 0) { + fromlen = sizeof(from); + if(s >= 0 && FD_ISSET(s, &readfds)) + rc = recvfrom(s, buf, sizeof(buf) - 1, 0, (struct sockaddr*)&from, &fromlen); + else if(s6 >= 0 && FD_ISSET(s6, &readfds)) + rc = recvfrom(s6, buf, sizeof(buf) - 1, 0, (struct sockaddr*)&from, &fromlen); + else + break; + if (rc > 0) { + buf[rc] = 0; + { + std::unique_lock<std::mutex> lck(sock_mtx); + rcv.emplace_back(Blob {buf, buf+rc+1}, from); + } + cv.notify_all(); + } + } + } + } catch (const std::exception& e) { + std::cerr << "Error int DHT networking thread: " << e.what() << std::endl; + } + if (s >= 0) + close(s); + if (s6 >= 0) + close(s6); + }); +} + +void +DhtRunner::get(InfoHash hash, Dht::GetCallback vcb, Dht::DoneCallback dcb, Value::Filter f) +{ + std::unique_lock<std::mutex> lck(storage_mtx); + dht_gets.emplace_back(hash, vcb, dcb, f); + cv.notify_all(); +} + +void +DhtRunner::get(const std::string& key, Dht::GetCallback vcb, Dht::DoneCallback dcb, Value::Filter f) +{ + get(InfoHash::get(key), vcb, dcb, f); +} + +void +DhtRunner::put(InfoHash hash, Value&& value, Dht::DoneCallback cb) +{ + std::unique_lock<std::mutex> lck(storage_mtx); + dht_puts.emplace_back(hash, std::move(value), cb); + cv.notify_all(); +} + +void +DhtRunner::put(const std::string& key, Value&& value, Dht::DoneCallback cb) +{ + put(InfoHash::get(key), std::forward<Value>(value), cb); +} + +void +DhtRunner::putSigned(InfoHash hash, Value&& value, Dht::DoneCallback cb) +{ + std::unique_lock<std::mutex> lck(storage_mtx); + dht_sputs.emplace_back(hash, std::move(value), cb); + cv.notify_all(); +} + +void +DhtRunner::putSigned(const std::string& key, Value&& value, Dht::DoneCallback cb) +{ + putSigned(InfoHash::get(key), std::forward<Value>(value), cb); +} + +void +DhtRunner::putEncrypted(InfoHash hash, InfoHash to, Value&& value, Dht::DoneCallback cb) +{ + std::unique_lock<std::mutex> lck(storage_mtx); + dht_eputs.emplace_back(hash, to, std::move(value), cb); + cv.notify_all(); +} + +void +DhtRunner::putEncrypted(const std::string& key, InfoHash to, Value&& value, Dht::DoneCallback cb) +{ + putEncrypted(key, to, std::forward<Value>(value), cb); +} + +void +DhtRunner::bootstrap(const std::vector<sockaddr_storage>& nodes) +{ + std::unique_lock<std::mutex> lck(storage_mtx); + bootstrap_ips.insert(bootstrap_ips.end(), nodes.begin(), nodes.end()); + cv.notify_all(); +} + +void +DhtRunner::bootstrap(const std::vector<Dht::NodeExport>& nodes) +{ + std::unique_lock<std::mutex> lck(storage_mtx); + bootstrap_nodes.insert(bootstrap_nodes.end(), nodes.begin(), nodes.end()); + cv.notify_all(); +} + + +} diff --git a/src/infohash.cpp b/src/infohash.cpp new file mode 100644 index 00000000..aabfdc76 --- /dev/null +++ b/src/infohash.cpp @@ -0,0 +1,82 @@ +/* + * Copyright (C) 2014 Savoir-Faire Linux Inc. + * Author : Adrien Béraud <adrien.beraud@savoirfairelinux.com> + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Additional permission under GNU GPL version 3 section 7: + * + * If you modify this program, or any covered work, by linking or + * combining it with the OpenSSL project's OpenSSL library (or a + * modified version of that library), containing parts covered by the + * terms of the OpenSSL or SSLeay licenses, Savoir-Faire Linux Inc. + * grants you additional permission to convey the resulting work. + * Corresponding Source for a non-source form of such a combination + * shall include the source code for the parts of OpenSSL used as well + * as that of the covered work. + */ + +#include "infohash.h" + +extern "C" { +#include <gnutls/gnutls.h> +} + +#include <sstream> + +namespace dht { + +InfoHash::InfoHash(const std::string& hex) { + unsigned in = std::min((size_t)HASH_LEN, hex.size()/2); + for (unsigned i = 0; i < in; i++) { + sscanf(hex.data() + 2*i, "%02x", (unsigned*)(&((*this)[i]))); + } +} + +InfoHash +InfoHash::get(const uint8_t* data, size_t data_len) +{ + InfoHash h; + size_t s = h.size(); + const gnutls_datum_t gnudata = {(uint8_t*)data, (unsigned)data_len}; + const gnutls_digest_algorithm_t algo = (HASH_LEN == 64) ? GNUTLS_DIG_SHA512 : ( + (HASH_LEN == 32) ? GNUTLS_DIG_SHA256 : ( + (HASH_LEN == 20) ? GNUTLS_DIG_SHA1 : + GNUTLS_DIG_NULL )); + static_assert(algo != GNUTLS_DIG_NULL, "Can't find hash function to use."); + int rc = gnutls_fingerprint(algo, &gnudata, h.data(), &s); + if (rc == 0 && s == HASH_LEN) + return h; + throw std::string("Error while hashing"); +} + +std::string +InfoHash::toString() const +{ + std::stringstream ss; + ss << *this; + return ss.str(); +} + +std::ostream& operator<< (std::ostream& s, const InfoHash& h) +{ + s << std::hex; + for (unsigned i=0; i<HASH_LEN; i++) + s << std::setfill('0') << std::setw(2) << (unsigned)h[i]; + s << std::dec; + return s; +} + +} diff --git a/src/securedht.cpp b/src/securedht.cpp new file mode 100644 index 00000000..d6e5745a --- /dev/null +++ b/src/securedht.cpp @@ -0,0 +1,292 @@ +/* + * Copyright (C) 2014 Savoir-Faire Linux Inc. + * Author : Adrien Béraud <adrien.beraud@savoirfairelinux.com> + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Additional permission under GNU GPL version 3 section 7: + * + * If you modify this program, or any covered work, by linking or + * combining it with the OpenSSL project's OpenSSL library (or a + * modified version of that library), containing parts covered by the + * terms of the OpenSSL or SSLeay licenses, Savoir-Faire Linux Inc. + * grants you additional permission to convey the resulting work. + * Corresponding Source for a non-source form of such a combination + * shall include the source code for the parts of OpenSSL used as well + * as that of the covered work. + */ + +#include "securedht.h" + +extern "C" { +#include <gnutls/gnutls.h> +#include <gnutls/abstract.h> +#include <gnutls/x509.h> +} + +#include <random> + +namespace dht { + +SecureDht::SecureDht(int s, int s6, crypto::Identity id) +: Dht(s, s6, InfoHash::get("node:"+id.second->getPublicKey().getId().toString())), key_(id.first), certificate_(id.second) +{ + if (s < 0 && s6 < 0) + return; + + int rc = gnutls_global_init(); + if (rc != GNUTLS_E_SUCCESS) + throw DhtException(std::string("Error initializing GnuTLS: ")+gnutls_strerror(rc)); + + auto certId = certificate_->getPublicKey().getId(); + if (certId != key_->getPublicKey().getId()) + throw DhtException("SecureDht: provided certificate doesn't match private key."); + + registerType(ValueType::USER_DATA); + registerInsecureType(ServiceAnnouncement::TYPE); + registerInsecureType(CERTIFICATE_TYPE); + + Dht::put(certId, Value { + CERTIFICATE_TYPE, + *certificate_ + }, [this](bool ok) { + if (ok) + DHT_DEBUG("SecureDht: public key announced successfully"); + else + DHT_ERROR("SecureDht: error while announcing public key!"); + }); +} + +SecureDht::~SecureDht() +{ + gnutls_global_deinit(); +} + +ValueType +SecureDht::secureType(ValueType&& type) +{ + type.storePolicy = [this,type](InfoHash id, std::shared_ptr<Value>& v, InfoHash nid, const sockaddr* a, socklen_t al) { + if (v->isSigned()) { + if (!v->owner.checkSignature(v->getToSign(), v->signature)) { + DHT_WARN("Signature verification failed"); + return false; + } + else + DHT_WARN("Signature verification succeded"); + } + return type.storePolicy(id, v, nid, a, al); + }; + type.editPolicy = [this,type](InfoHash id, const std::shared_ptr<Value>& o, std::shared_ptr<Value>& n, InfoHash nid, const sockaddr* a, socklen_t al) { + if (!o->isSigned()) + return type.editPolicy(id, o, n, nid, a, al); + if (o->owner != n->owner) { + DHT_WARN("Edition forbidden: owner changed."); + return false; + } + if (!o->owner.checkSignature(n->getToSign(), n->signature)) { + DHT_WARN("Edition forbidden: signature verification failed."); + return false; + } + DHT_WARN("Edition old seq: %d, new seq: %d.", o->seq, n->seq); + if (o->seq == n->seq) { + // If the data is exactly the same, + // it can be reannounced, possibly by someone else. + if (o->getToSign() != n->getToSign()) + return false; + } + else if (n->seq < o->seq) + return false; + return true; + }; + return type; +} + +const std::shared_ptr<crypto::Certificate> +SecureDht::getCertificate(const InfoHash& node) const +{ + if (node == getId()) + return certificate_; + auto it = nodesCertificates_.find(node); + if (it == nodesCertificates_.end()) + return nullptr; + else + return it->second; +} + +const std::shared_ptr<crypto::Certificate> +SecureDht::registerCertificate(const InfoHash& node, const Blob& data) +{ + std::shared_ptr<crypto::Certificate> crt; + try { + crt = std::make_shared<crypto::Certificate>(data); + } catch (const std::exception& e) { + return nullptr; + } + InfoHash h = crt->getPublicKey().getId(); + if (node == h) { + DHT_DEBUG("Registering public key for %s", h.toString().c_str()); + nodesCertificates_[h] = crt; + } else { + DHT_DEBUG("Certificate %s for node %s does not match node id !", h.toString().c_str(), node.toString().c_str()); + return nullptr; + } + auto it = nodesCertificates_.find(h); + if (it == nodesCertificates_.end()) { + return nullptr; + } + return it->second; +} + +void +SecureDht::findCertificate(const InfoHash& node, std::function<void(const std::shared_ptr<crypto::Certificate>)> cb) +{ + std::shared_ptr<crypto::Certificate> b = getCertificate(node); + if (b && *b) { + std::cout << "Using public key from cache for " << node << std::endl; + cb(b); + return; + } + auto found = std::make_shared<bool>(false); + Dht::get(node, [cb,node,found,this](const std::vector<std::shared_ptr<Value>>& vals) { + if (*found) + return false; + for (const auto& v : vals) { + if (auto cert = registerCertificate(node, v->data)) { + *found = true; + std::cout << "Found public key for " << node << std::endl; + cb(cert); + return false; + } + } + return true; + }, [cb,found](bool) { + if (!*found) + cb(nullptr); + }, Value::TypeFilter(CERTIFICATE_TYPE)); +} + + +void +SecureDht::get(const InfoHash& id, GetCallback cb, DoneCallback donecb, Value::Filter filter) +{ + Dht::get(id, + [=](const std::vector<std::shared_ptr<Value>>& values) { + std::vector<std::shared_ptr<Value>> tmpvals {}; + for (const auto& v : values) { + if (v->isEncrypted()) { + try { + Value decrypted_val = std::move(decrypt(*v)); + if (decrypted_val.recipient == getId()) { + auto dv = std::make_shared<Value>(std::move(decrypted_val)); + if (dv->owner.checkSignature(dv->getToSign(), dv->signature)) + tmpvals.push_back(v); + else + DHT_WARN("Signature verification failed for %s", id.toString().c_str()); + } + } catch (const std::exception& e) { + DHT_WARN("Could not decrypt value %s at infohash %s", v->toString().c_str(), id.toString().c_str()); + continue; + } + } else if (v->isSigned()) { + if (v->owner.checkSignature(v->getToSign(), v->signature)) + tmpvals.push_back(v); + else + DHT_WARN("Signature verification failed for %s", id.toString().c_str()); + } else { + tmpvals.push_back(v); + } + } + if (not tmpvals.empty()) + cb(tmpvals); + return true; + }, + donecb, + filter); +} + +void +SecureDht::putSigned(const InfoHash& hash, Value&& val, DoneCallback callback) +{ + if (val.id == Value::INVALID_ID) { + auto id = getId(); + static_assert(sizeof(Value::Id) <= sizeof(InfoHash), "Value::Id can't be larger than InfoHash"); + val.id = *reinterpret_cast<Value::Id*>(id.data()); + } + // TODO search the DHT instead of using the local value + auto p = getPut(hash, val.id); + if (p) { + DHT_WARN("Found previous value being announced."); + val.seq = p->seq + 1; + } + sign(val); + put(hash, std::move(val), callback); +} + +void +SecureDht::putEncrypted(const InfoHash& hash, const InfoHash& to, const std::shared_ptr<Value>& val, DoneCallback callback) +{ + findCertificate(to, [=](const std::shared_ptr<crypto::Certificate> crt) { + if(!crt || !*crt) { + if (callback) + callback(false); + return; + } + DHT_WARN("Encrypting data for PK: %s", crt->getPublicKey().getId().toString().c_str()); + try { + put(hash, encrypt(*val, crt->getPublicKey()), callback); + } catch (const std::exception& e) { + DHT_WARN("Error putting encrypted data: %s", e.what()); + if (callback) + callback(false); + } + }); +} + +void +SecureDht::sign(Value& v) const +{ + if (v.flags.isEncrypted()) + throw DhtException("Can't sign encrypted data."); + v.owner = key_->getPublicKey(); + v.flags = Value::ValueFlags(true, false); + v.signature = key_->sign(v.getToSign()); +} + +Value +SecureDht::encrypt(Value& v, const crypto::PublicKey& to) const +{ + if (v.flags.isEncrypted()) { + throw DhtException("Data is already encrypted."); + } + v.setRecipient(to.getId()); + sign(v); + Value nv {v.id}; + nv.setCypher(to.encrypt(v.getToEncrypt())); + return nv; +} + +Value +SecureDht::decrypt(const Value& v) +{ + if (not v.flags.isEncrypted()) + throw DhtException("Data is not encrypted."); + auto decrypted = key_->decrypt(v.cypher); + Value ret {v.id}; + auto pb = decrypted.cbegin(), pe = decrypted.cend(); + ret.unpackBody(pb, pe); + return ret; +} + +} diff --git a/src/value.cpp b/src/value.cpp new file mode 100644 index 00000000..0263d72c --- /dev/null +++ b/src/value.cpp @@ -0,0 +1,221 @@ +/* + * Copyright (C) 2014 Savoir-Faire Linux Inc. + * Author : Adrien Béraud <adrien.beraud@savoirfairelinux.com> + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Additional permission under GNU GPL version 3 section 7: + * + * If you modify this program, or any covered work, by linking or + * combining it with the OpenSSL project's OpenSSL library (or a + * modified version of that library), containing parts covered by the + * terms of the OpenSSL or SSLeay licenses, Savoir-Faire Linux Inc. + * grants you additional permission to convey the resulting work. + * Corresponding Source for a non-source form of such a combination + * shall include the source code for the parts of OpenSSL used as well + * as that of the covered work. + */ + +#include "value.h" +#include "securedht.h" // print certificate ID + +namespace dht { + +std::ostream& operator<< (std::ostream& s, const Value& v) +{ + s << "Value[id:" << std::hex << v.id << std::dec << " "; + if (v.flags.isSigned()) + s << "signed (v" << v.seq << ") "; + if (v.flags.isEncrypted()) + s << "encrypted "; + else { + if (v.type == ServiceAnnouncement::TYPE.id) { + s << ServiceAnnouncement(v.data); + } else if (v.type == CERTIFICATE_TYPE.id) { + s << "Certificate"; + try { + InfoHash h = crypto::Certificate(v.data).getPublicKey().getId(); + s << " with ID " << h; + } catch (const std::exception& e) { + s << " (invalid)"; + } + } else { + s << "Data (type: " << v.type << " ): "; + s << std::hex; + for (size_t i=0; i<v.data.size(); i++) + s << std::setfill('0') << std::setw(2) << (unsigned)v.data[i]; + s << std::dec; + } + } + s << "]"; + return s; +} + +const ValueType ValueType::USER_DATA = {0, "User Data"}; + +bool +ServiceAnnouncement::storePolicy(InfoHash, std::shared_ptr<Value>& v, InfoHash, const sockaddr* from, socklen_t fromlen) +{ + ServiceAnnouncement request {}; + request.unpackBlob(v->data); + if (request.getPort() == 0) + return false; + ServiceAnnouncement sa_addr {from, fromlen}; + sa_addr.setPort(request.getPort()); + // argument v is modified (not the value). + v = std::make_shared<Value>(ServiceAnnouncement::TYPE, sa_addr, v->id); + return true; +} + +const ValueType ServiceAnnouncement::TYPE = {1, "Service Announcement", 15 * 60, ServiceAnnouncement::storePolicy, ValueType::DEFAULT_EDIT_POLICY}; + +void +Value::packToSign(Blob& res) const +{ + res.push_back(flags.to_ulong()); + if (flags.isEncrypted()) { + res.insert(res.end(), cypher.begin(), cypher.end()); + } else { + if (flags.isSigned()) { + serialize<decltype(seq)>(seq, res); + owner.pack(res); + //res.insert(res.end(), owner.begin(), owner.end()); + if (flags.haveRecipient()) + res.insert(res.end(), recipient.begin(), recipient.end()); + } + serialize<ValueType::Id>(type, res); + serialize<Blob>(data, res); + } +} + +Blob +Value::getToSign() const +{ + Blob ret; + packToSign(ret); + return ret; +} + +/** + * Pack part of the data to be encrypted + */ +void +Value::packToEncrypt(Blob& res) const +{ + packToSign(res); + if (!flags.isEncrypted() && flags.isSigned()) + serialize<Blob>(signature, res); +} + +Blob +Value::getToEncrypt() const +{ + Blob ret; + packToEncrypt(ret); + return ret; +} + +void +Value::pack(Blob& res) const +{ + serialize<Id>(id, res); + packToEncrypt(res); +} + +void +Value::unpackBody(Blob::const_iterator& begin, Blob::const_iterator& end) +{ + // clear optional fields + owner = {}; + recipient = {}; + cypher.clear(); + signature.clear(); + data.clear(); + type = 0; + + flags = {deserialize<uint8_t>(begin, end)}; + if (flags.isEncrypted()) { + cypher = {begin, end}; + begin = end; + } else { + if(flags.isSigned()) { + seq = deserialize<decltype(seq)>(begin, end); + owner.unpack(begin, end); + if (flags.haveRecipient()) + recipient = deserialize<InfoHash>(begin, end); + } + type = deserialize<ValueType::Id>(begin, end); + data = deserialize<Blob>(begin, end); + if (flags.isSigned()) + signature = deserialize<Blob>(begin, end); + } +} + +void +Value::unpack(Blob::const_iterator& begin, Blob::const_iterator& end) +{ + id = deserialize<Id>(begin, end); + unpackBody(begin, end); +} + +std::ostream& operator<< (std::ostream& s, const ServiceAnnouncement& v) +{ + s << "Peer: "; + s << "port " << v.getPort(); + + if (v.ss.ss_family == AF_INET || v.ss.ss_family == AF_INET6) { + char hbuf[NI_MAXHOST]; + if (getnameinfo((sockaddr*)&v.ss, sizeof(v.ss), hbuf, sizeof(hbuf), nullptr, 0, NI_NUMERICHOST) == 0) { + s << " addr " << std::string(hbuf, strlen(hbuf)); + } + } + return s; +} + +void +ServiceAnnouncement::pack(Blob& res) const +{ + serialize<in_port_t>(getPort(), res); + if (ss.ss_family == AF_INET) { + auto sa4 = reinterpret_cast<const sockaddr_in*>(&ss); + serialize<in_addr>(sa4->sin_addr, res); + } else if (ss.ss_family == AF_INET6) { + auto sa6 = reinterpret_cast<const sockaddr_in6*>(&ss); + serialize<in6_addr>(sa6->sin6_addr, res); + } +} + +void +ServiceAnnouncement::unpack(Blob::const_iterator& begin, Blob::const_iterator& end) +{ + setPort(deserialize<in_port_t>(begin, end)); + size_t addr_size = end - begin; + if (addr_size < sizeof(in_addr)) { + ss.ss_family = 0; + } else if (addr_size == sizeof(in_addr)) { + auto sa4 = reinterpret_cast<sockaddr_in*>(&ss); + sa4->sin_family = AF_INET; + sa4->sin_addr = deserialize<in_addr>(begin, end); + } else if (addr_size == sizeof(in6_addr)) { + auto sa6 = reinterpret_cast<sockaddr_in6*>(&ss); + sa6->sin6_family = AF_INET6; + sa6->sin6_addr = deserialize<in6_addr>(begin, end); + } else { + throw std::runtime_error("ServiceAnnouncement parse error."); + } +} + + +} -- GitLab