From 612b55b151d20343e61403613011056b89bf8c87 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrien=20B=C3=A9raud?= <adrien.beraud@savoirfairelinux.com> Date: Mon, 29 May 2023 10:42:04 -0400 Subject: [PATCH] add initial project structure Change-Id: I6a3fb080ff623b312e42d71754480a7ce00b81a0 --- CMakeLists.txt | 37 + include/certstore.h | 198 +++ include/connectionmanager.h | 267 ++++ include/diffie-hellman.h | 73 + include/fileutils.h | 150 ++ include/generic_io.h | 123 ++ include/ice_options.h | 66 + include/ip_utils.h | 349 +++++ include/multiplexed_socket.h | 361 +++++ include/string_utils.h | 222 +++ include/tls_session.h | 165 +++ src/connectionmanager.cpp | 1656 ++++++++++++++++++++++ src/fileutils.cpp | 878 ++++++++++++ src/ice_socket.h | 58 + src/ice_transport.cpp | 1902 ++++++++++++++++++++++++++ src/ice_transport.h | 219 +++ src/ip_utils.cpp | 501 +++++++ src/multiplexed_socket.cpp | 1208 ++++++++++++++++ src/peer_connection.cpp | 452 ++++++ src/peer_connection.h | 142 ++ src/security/certstore.cpp | 673 +++++++++ src/security/diffie-hellman.cpp | 139 ++ src/security/security_const.h | 121 ++ src/security/threadloop.cpp | 135 ++ src/security/threadloop.h | 134 ++ src/security/tls_session.cpp | 1789 ++++++++++++++++++++++++ src/sip_utils.h | 173 +++ src/string_utils.cpp | 167 +++ src/tracepoint/trace-tools.h | 65 + src/tracepoint/tracepoint-def.h | 237 ++++ src/tracepoint/tracepoint.c | 3 + src/tracepoint/tracepoint.h | 60 + src/transport/peer_channel.h | 109 ++ src/upnp/protocol/igd.cpp | 76 + src/upnp/protocol/igd.h | 110 ++ src/upnp/protocol/mapping.cpp | 347 +++++ src/upnp/protocol/mapping.h | 146 ++ src/upnp/protocol/natpmp/nat_pmp.cpp | 775 +++++++++++ src/upnp/protocol/natpmp/nat_pmp.h | 174 +++ src/upnp/protocol/natpmp/pmp_igd.cpp | 63 + src/upnp/protocol/natpmp/pmp_igd.h | 54 + src/upnp/protocol/pupnp/pupnp.cpp | 1599 ++++++++++++++++++++++ src/upnp/protocol/pupnp/pupnp.h | 271 ++++ src/upnp/protocol/pupnp/upnp_igd.cpp | 74 + src/upnp/protocol/pupnp/upnp_igd.h | 106 ++ src/upnp/protocol/upnp_protocol.h | 126 ++ src/upnp/upnp_context.cpp | 1339 ++++++++++++++++++ src/upnp/upnp_context.h | 294 ++++ src/upnp/upnp_control.cpp | 150 ++ src/upnp/upnp_control.h | 78 ++ src/upnp/upnp_thread_util.h | 35 + 51 files changed, 18649 insertions(+) create mode 100644 CMakeLists.txt create mode 100644 include/certstore.h create mode 100644 include/connectionmanager.h create mode 100644 include/diffie-hellman.h create mode 100644 include/fileutils.h create mode 100644 include/generic_io.h create mode 100644 include/ice_options.h create mode 100644 include/ip_utils.h create mode 100644 include/multiplexed_socket.h create mode 100644 include/string_utils.h create mode 100644 include/tls_session.h create mode 100644 src/connectionmanager.cpp create mode 100644 src/fileutils.cpp create mode 100644 src/ice_socket.h create mode 100644 src/ice_transport.cpp create mode 100644 src/ice_transport.h create mode 100644 src/ip_utils.cpp create mode 100644 src/multiplexed_socket.cpp create mode 100644 src/peer_connection.cpp create mode 100644 src/peer_connection.h create mode 100644 src/security/certstore.cpp create mode 100644 src/security/diffie-hellman.cpp create mode 100644 src/security/security_const.h create mode 100644 src/security/threadloop.cpp create mode 100644 src/security/threadloop.h create mode 100644 src/security/tls_session.cpp create mode 100644 src/sip_utils.h create mode 100644 src/string_utils.cpp create mode 100644 src/tracepoint/trace-tools.h create mode 100644 src/tracepoint/tracepoint-def.h create mode 100644 src/tracepoint/tracepoint.c create mode 100644 src/tracepoint/tracepoint.h create mode 100644 src/transport/peer_channel.h create mode 100644 src/upnp/protocol/igd.cpp create mode 100644 src/upnp/protocol/igd.h create mode 100644 src/upnp/protocol/mapping.cpp create mode 100644 src/upnp/protocol/mapping.h create mode 100644 src/upnp/protocol/natpmp/nat_pmp.cpp create mode 100644 src/upnp/protocol/natpmp/nat_pmp.h create mode 100644 src/upnp/protocol/natpmp/pmp_igd.cpp create mode 100644 src/upnp/protocol/natpmp/pmp_igd.h create mode 100644 src/upnp/protocol/pupnp/pupnp.cpp create mode 100644 src/upnp/protocol/pupnp/pupnp.h create mode 100644 src/upnp/protocol/pupnp/upnp_igd.cpp create mode 100644 src/upnp/protocol/pupnp/upnp_igd.h create mode 100644 src/upnp/protocol/upnp_protocol.h create mode 100644 src/upnp/upnp_context.cpp create mode 100644 src/upnp/upnp_context.h create mode 100644 src/upnp/upnp_control.cpp create mode 100644 src/upnp/upnp_control.h create mode 100644 src/upnp/upnp_thread_util.h diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000..227118f --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,37 @@ +cmake_minimum_required(VERSION 3.20) +project(dhtnet) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +find_package (PkgConfig REQUIRED) +find_package(msgpack REQUIRED QUIET CONFIG NAMES msgpack msgpack-cxx) +pkg_check_modules (opendht REQUIRED IMPORTED_TARGET opendht>=2.6.0) +pkg_check_modules (pjproject REQUIRED IMPORTED_TARGET libpjproject) + +set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMSGPACK_NO_BOOST -DMSGPACK_DISABLE_LEGACY_NIL -DMSGPACK_DISABLE_LEGACY_CONVERT") + +# Sources +list (APPEND dhtnet_SOURCES + src/connectionmanager.cpp + src/ice_transport.cpp + src/multiplexed_socket.cpp + src/peer_connection.cpp + src/string_utils.cpp + src/fileutils.cpp + src/security/tls_session.cpp + src/security/certstore.cpp + src/security/threadloop.cpp +) + +list (APPEND dhtnet_HEADERS + include/connectionmanager.h + include/multiplexed_socket.h +) + +add_library(dhtnet ${dhtnet_SOURCES}) +target_link_libraries(dhtnet PUBLIC PkgConfig::opendht msgpack-cxx) +target_include_directories(dhtnet PUBLIC include) +target_compile_definitions(dhtnet PRIVATE + PJ_AUTOCONF=1 +) diff --git a/include/certstore.h b/include/certstore.h new file mode 100644 index 0000000..6a608ef --- /dev/null +++ b/include/certstore.h @@ -0,0 +1,198 @@ +/* + * Copyright (C) 2004-2023 Savoir-faire Linux Inc. + * + * Author: Adrien Béraud <adrien.beraud@savoirfairelinux.com> + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ + +#pragma once + +//#include "security_const.h" +//#include "noncopyable.h" + +#include <opendht/crypto.h> + +#include <string> +#include <vector> +#include <map> +#include <set> +#include <future> +#include <mutex> + +namespace crypto = ::dht::crypto; + +namespace dht { +namespace log { +class Logger; +} +} + +namespace jami { + +using Logger = dht::log::Logger; +namespace tls { + +enum class TrustStatus { UNTRUSTED = 0, TRUSTED }; +TrustStatus trustStatusFromStr(const char* str); +const char* statusToStr(TrustStatus s); + +/** + * Global certificate store. + * Stores system root CAs and any other encountred certificate + */ +class CertificateStore +{ +public: + explicit CertificateStore(const std::string& accountId, std::shared_ptr<Logger> logger); + + std::vector<std::string> getPinnedCertificates() const; + /** + * Return certificate (with full chain) + */ + std::shared_ptr<crypto::Certificate> getCertificate(const std::string& cert_id); + std::shared_ptr<crypto::Certificate> getCertificateLegacy(const std::string& dataDir, const std::string& cert_id); + + std::shared_ptr<crypto::Certificate> findCertificateByName( + const std::string& name, crypto::NameType type = crypto::NameType::UNKNOWN) const; + std::shared_ptr<crypto::Certificate> findCertificateByUID(const std::string& uid) const; + std::shared_ptr<crypto::Certificate> findIssuer( + const std::shared_ptr<crypto::Certificate>& crt) const; + + std::vector<std::string> pinCertificate(const std::vector<uint8_t>& crt, + bool local = true) noexcept; + std::vector<std::string> pinCertificate(crypto::Certificate&& crt, bool local = true); + std::vector<std::string> pinCertificate(const std::shared_ptr<crypto::Certificate>& crt, + bool local = true); + bool unpinCertificate(const std::string&); + + void pinCertificatePath(const std::string& path, + std::function<void(const std::vector<std::string>&)> cb = {}); + unsigned unpinCertificatePath(const std::string&); + + bool setTrustedCertificate(const std::string& id, TrustStatus status); + std::vector<gnutls_x509_crt_t> getTrustedCertificates() const; + + void pinRevocationList(const std::string& id, + const std::shared_ptr<dht::crypto::RevocationList>& crl); + void pinRevocationList(const std::string& id, dht::crypto::RevocationList&& crl) + { + pinRevocationList(id, + std::make_shared<dht::crypto::RevocationList>( + std::forward<dht::crypto::RevocationList>(crl))); + } + void pinOcspResponse(const dht::crypto::Certificate& cert); + + void loadRevocations(crypto::Certificate& crt) const; + + const std::shared_ptr<Logger>& logger() const { + return logger_; + } + +private: + //NON_COPYABLE(CertificateStore); + + + unsigned loadLocalCertificates(); + void pinRevocationList(const std::string& id, const dht::crypto::RevocationList& crl); + std::shared_ptr<Logger> logger_; + + const std::string certPath_; + const std::string crlPath_; + const std::string ocspPath_; + + mutable std::mutex lock_; + std::map<std::string, std::shared_ptr<crypto::Certificate>> certs_; + std::map<std::string, std::vector<std::weak_ptr<crypto::Certificate>>> paths_; + + // globally trusted certificates (root CAs) + std::vector<std::shared_ptr<crypto::Certificate>> trustedCerts_; +}; + +/** + * Keeps track of the allowed and trust status of certificates + * Trusted is the status of top certificates we trust to build our + * certificate chain: root CAs and other configured CAs. + * + * Allowed is the status of certificates we accept for incoming + * connections. + */ +class TrustStore +{ +public: + explicit TrustStore(CertificateStore& certStore) + : certStore_(certStore) + {} + + enum class PermissionStatus { UNDEFINED = 0, ALLOWED, BANNED }; + + static PermissionStatus statusFromStr(const char* str); + static const char* statusToStr(PermissionStatus s); + + bool addRevocationList(dht::crypto::RevocationList&& crl); + + bool setCertificateStatus(const std::string& cert_id, const PermissionStatus status); + bool setCertificateStatus(const std::shared_ptr<crypto::Certificate>& cert, + PermissionStatus status, + bool local = true); + + PermissionStatus getCertificateStatus(const std::string& cert_id) const; + + std::vector<std::string> getCertificatesByStatus(PermissionStatus status) const; + + /** + * Check that the certificate is allowed (valid and permited) for contact. + * Valid means the certificate chain matches with our CA list, + * has valid signatures, expiration dates etc. + * Permited means at least one of the certificate in the chain is + * ALLOWED (if allowPublic is false), and none is BANNED. + * + * @param crt the end certificate of the chain to check + * @param allowPublic if false, requires at least one ALLOWED certificate. + * (not required otherwise). In any case a BANNED + * certificate means permission refusal. + * @return true if the certificate is valid and permitted. + */ + bool isAllowed(const crypto::Certificate& crt, bool allowPublic = false); + +private: + TrustStore(const TrustStore& o) = delete; + TrustStore& operator=(const TrustStore& o) = delete; + TrustStore(TrustStore&& o) = delete; + TrustStore& operator=(TrustStore&& o) = delete; + + void updateKnownCerts(); + bool setCertificateStatus(std::shared_ptr<crypto::Certificate> cert, + const std::string& cert_id, + const TrustStore::PermissionStatus status, + bool local); + void setStoreCertStatus(const crypto::Certificate& crt, bool status); + void rebuildTrust(); + + struct Status + { + bool allowed; + }; + + // unknown certificates with known status + mutable std::recursive_mutex mutex_; + std::map<std::string, Status> unknownCertStatus_; + std::map<std::string, std::pair<std::shared_ptr<crypto::Certificate>, Status>> certStatus_; + dht::crypto::TrustList allowed_; + CertificateStore& certStore_; +}; + +} // namespace tls +} // namespace jami diff --git a/include/connectionmanager.h b/include/connectionmanager.h new file mode 100644 index 0000000..5d2db50 --- /dev/null +++ b/include/connectionmanager.h @@ -0,0 +1,267 @@ +/* + * Copyright (C) 2004-2023 Savoir-faire Linux Inc. + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see <https://www.gnu.org/licenses/>. + */ +#pragma once + +#include "ice_options.h" +#include "multiplexed_socket.h" + +#include <opendht/dhtrunner.h> +#include <opendht/infohash.h> +#include <opendht/value.h> +#include <opendht/default_types.h> +#include <opendht/sockaddr.h> +#include <opendht/logger.h> + +#include <memory> +#include <vector> +#include <string> + +namespace jami { + +class ChannelSocket; +class ConnectionManager; +namespace upnp { +class Controller; +} +namespace tls { +class CertificateStore; +} + +/** + * A PeerConnectionRequest is a request which ask for an initial connection + * It contains the ICE request an ID and if it's an answer + * Transmitted via the UDP DHT + */ +struct PeerConnectionRequest : public dht::EncryptedValue<PeerConnectionRequest> +{ + static const constexpr dht::ValueType& TYPE = dht::ValueType::USER_DATA; + static constexpr const char* key_prefix = "peer:"; ///< base to compute the DHT listen key + dht::Value::Id id = dht::Value::INVALID_ID; + std::string ice_msg {}; + bool isAnswer {false}; + std::string connType {}; // Used for push notifications to know why we open a new connection + MSGPACK_DEFINE_MAP(id, ice_msg, isAnswer, connType) +}; + +/** + * Used to accept or not an incoming ICE connection (default accept) + */ +using onICERequestCallback = std::function<bool(const DeviceId&)>; +/** + * Used to accept or decline an incoming channel request + */ +using ChannelRequestCallback = std::function<bool(const std::shared_ptr<dht::crypto::Certificate>&, + const std::string& /* name */)>; +/** + * Used by connectDevice, when the socket is ready + */ +using ConnectCallback = std::function<void(const std::shared_ptr<ChannelSocket>&, const DeviceId&)>; +/** + * Used when an incoming connection is ready + */ +using ConnectionReadyCallback = std::function< + void(const DeviceId&, const std::string& /* channel_name */, std::shared_ptr<ChannelSocket>)>; + +using iOSConnectedCallback + = std::function<bool(const std::string& /* connType */, dht::InfoHash /* peer_h */)>; + +/** + * Manages connections to other devices + * @note the account MUST be valid if ConnectionManager lives + */ +class ConnectionManager +{ +public: + class Config; + + ConnectionManager(std::shared_ptr<Config> config_); + ~ConnectionManager(); + + /** + * Open a new channel between the account's device and another device + * This method will send a message on the account's DHT, wait a reply + * and then, create a Tls socket with remote peer. + * @param deviceId Remote device + * @param name Name of the channel + * @param cb Callback called when socket is ready ready + * @param noNewSocket Do not negotiate a new socket if there is none + * @param forceNewSocket Negotiate a new socket even if there is one // todo group with previous + * (enum) + * @param connType Type of the connection + */ + void connectDevice(const DeviceId& deviceId, + const std::string& name, + ConnectCallback cb, + bool noNewSocket = false, + bool forceNewSocket = false, + const std::string& connType = ""); + void connectDevice(const std::shared_ptr<dht::crypto::Certificate>& cert, + const std::string& name, + ConnectCallback cb, + bool noNewSocket = false, + bool forceNewSocket = false, + const std::string& connType = ""); + + /** + * Check if we are already connecting to a device with a specific name + * @param deviceId Remote device + * @param name Name of the channel + * @return if connecting + * @note isConnecting is not true just after connectDevice() as connectDevice is full async + */ + bool isConnecting(const DeviceId& deviceId, const std::string& name) const; + + /** + * Close all connections with a current device + * @param peerUri Peer URI + */ + void closeConnectionsWith(const std::string& peerUri); + + /** + * Method to call to listen to incoming requests + * @param deviceId Account's device + */ + void onDhtConnected(const dht::crypto::PublicKey& devicePk); + + /** + * Add a callback to decline or accept incoming ICE connections + * @param cb Callback to trigger + */ + void onICERequest(onICERequestCallback&& cb); + + /** + * Trigger cb on incoming peer channel + * @param cb Callback to trigger + * @note The callback is used to validate + * if the incoming request is accepted or not. + */ + void onChannelRequest(ChannelRequestCallback&& cb); + + /** + * Trigger cb when connection with peer is ready + * @param cb Callback to trigger + */ + void onConnectionReady(ConnectionReadyCallback&& cb); + + /** + * Trigger cb when connection with peer is ready for iOS devices + * @param cb Callback to trigger + */ + void oniOSConnected(iOSConnectedCallback&& cb); + + /** + * @return the number of active sockets + */ + std::size_t activeSockets() const; + + /** + * Log informations for all sockets + */ + void monitor() const; + + /** + * Send beacon on peers supporting it + */ + void connectivityChanged(); + + /** + * Create and return ICE options. + */ + void getIceOptions(std::function<void(IceTransportOptions&&)> cb) noexcept; + IceTransportOptions getIceOptions() const noexcept; + + /** + * Get the published IP address, fallbacks to NAT if family is unspecified + * Prefers the usage of IPv4 if possible. + */ + IpAddr getPublishedIpAddress(uint16_t family = PF_UNSPEC) const; + + /** + * Set published IP address according to given family + */ + void setPublishedAddress(const IpAddr& ip_addr); + + /** + * Store the local/public addresses used to register + */ + void storeActiveIpAddress(std::function<void()>&& cb = {}); + + std::shared_ptr<Config> getConfig(); + +private: + ConnectionManager() = delete; + class Impl; + std::shared_ptr<Impl> pimpl_; +}; + +struct ConnectionManager::Config +{ + /** + * Determine if STUN public address resolution is required to register this account. In this + * case a STUN server hostname must be specified. + */ + bool stunEnabled {false}; + + /** + * The STUN server hostname (optional), used to provide the public IP address in case the + * softphone stay behind a NAT. + */ + std::string stunServer {}; + + /** + * Determine if TURN public address resolution is required to register this account. In this + * case a TURN server hostname must be specified. + */ + bool turnEnabled {false}; + + /** + * The TURN server hostname (optional), used to provide the public IP address in case the + * softphone stay behind a NAT. + */ + std::string turnServer; + std::string turnServerUserName; + std::string turnServerPwd; + std::string turnServerRealm; + + mutable std::mutex cachedTurnMutex {}; + dht::SockAddr cacheTurnV4 {}; + dht::SockAddr cacheTurnV6 {}; + + std::string cachePath {}; + + std::shared_ptr<asio::io_context> ioContext; + std::shared_ptr<dht::DhtRunner> dht; + dht::crypto::Identity id; + + tls::CertificateStore* certStore; + + /** + * UPnP IGD controller and the mutex to access it + */ + bool upnpEnabled; + std::shared_ptr<jami::upnp::Controller> upnpCtrl; + + std::shared_ptr<dht::log::Logger> logger; + + /** + * returns whether or not UPnP is enabled and active + * ie: if it is able to make port mappings + */ + bool getUPnPActive() const; +}; + +} // namespace jami \ No newline at end of file diff --git a/include/diffie-hellman.h b/include/diffie-hellman.h new file mode 100644 index 0000000..e7f8429 --- /dev/null +++ b/include/diffie-hellman.h @@ -0,0 +1,73 @@ +/* + * Copyright (C) 2004-2023 Savoir-faire Linux Inc. + * + * Author: Adrien Béraud <adrien.beraud@savoirfairelinux.com> + * Author: Guillaume Roguez <guillaume.roguez@savoirfairelinux.com> + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ + +#pragma once + +#include <gnutls/gnutls.h> + +#include <vector> +#include <memory> +#include <cstdint> +#include <string> + +namespace jami { +namespace tls { + +class DhParams +{ +public: + DhParams() = default; + DhParams(DhParams&&) = default; + DhParams(const DhParams& other) { *this = other; } + + DhParams& operator=(DhParams&& other) = default; + DhParams& operator=(const DhParams& other); + + /// \brief Construct by taking ownership of given gnutls DH params + /// + /// User should not call gnutls_dh_params_deinit on given \a raw_params. + /// The object is stolen and its live is manager by our object. + explicit DhParams(gnutls_dh_params_t p) + : params_ {p, gnutls_dh_params_deinit} + {} + + /** Deserialize DER or PEM encoded DH-params */ + DhParams(const std::vector<uint8_t>& data); + + gnutls_dh_params_t get() { return params_.get(); } + gnutls_dh_params_t get() const { return params_.get(); } + + explicit inline operator bool() const { return bool(params_); } + + /** Serialize data in PEM format */ + std::vector<uint8_t> serialize() const; + + static DhParams generate(); + + static DhParams loadDhParams(const std::string& path); + +private: + std::unique_ptr<gnutls_dh_params_int, decltype(gnutls_dh_params_deinit)*> + params_ {nullptr, gnutls_dh_params_deinit}; +}; + +} // namespace tls +} // namespace jami diff --git a/include/fileutils.h b/include/fileutils.h new file mode 100644 index 0000000..6264af6 --- /dev/null +++ b/include/fileutils.h @@ -0,0 +1,150 @@ +/* + * Copyright (C) 2004-2023 Savoir-faire Linux Inc. + * + * Author: Rafaël Carré <rafael.carre@savoirfairelinux.com> + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ + +#pragma once + +#include <string> +#include <vector> +#include <chrono> +#include <mutex> +#include <cstdio> +#include <ios> + +#ifndef _WIN32 +#include <sys/stat.h> // mode_t +#define DIR_SEPARATOR_STR "/" // Directory separator string +#define DIR_SEPARATOR_CH '/' // Directory separator char +#define DIR_SEPARATOR_STR_ESC "\\/" // Escaped directory separator string +#else +#define mode_t unsigned +#define DIR_SEPARATOR_STR "\\" // Directory separator string +#define DIR_SEPARATOR_CH '\\' // Directory separator char +#define DIR_SEPARATOR_STR_ESC "//*" // Escaped directory separator string +#endif + +namespace jami { +namespace fileutils { + +/** + * Check directory existence and create it with given mode if it doesn't. + * @param path to check, relative or absolute + * @param dir last directory creation mode + * @param parents default mode for all created directories except the last + */ +bool check_dir(const char* path, mode_t dir = 0755, mode_t parents = 0755); +/*std::string expand_path(const std::string& path);*/ +bool isDirectoryWritable(const std::string& directory); + +bool recursive_mkdir(const std::string& path, mode_t mode = 0755); + +bool isPathRelative(const std::string& path); +/** + * If path is contained in base, return the suffix, otherwise return the full path. + * @param base must not finish with DIR_SEPARATOR_STR, can be empty + * @param path the path + */ +//std::string getCleanPath(const std::string& base, const std::string& path); +/** + * If path is relative, it is appended to base. + */ +//std::string getFullPath(const std::string& base, const std::string& path); + +bool isFile(const std::string& path, bool resolveSymlink = true); +bool isDirectory(const std::string& path); +bool isSymLink(const std::string& path); +bool hasHardLink(const std::string& path); + +std::chrono::system_clock::time_point writeTime(const std::string& path); + +/*void createFileLink(const std::string& src, const std::string& dest, bool hard = false); + +std::string_view getFileExtension(std::string_view filename);*/ + +/** + * Read content of the directory. + * The result is a list of relative (to @param dir) paths of all entries + * in the directory, without "." and "..". + */ +std::vector<std::string> readDirectory(const std::string& dir); + +/** + * Read the full content of a file at path. + * If path is relative, it is appended to default_dir. + */ +std::vector<uint8_t> loadFile(const std::string& path, const std::string& default_dir = {}); +std::string loadTextFile(const std::string& path, const std::string& default_dir = {}); + +void saveFile(const std::string& path, const uint8_t* data, size_t data_size, mode_t mode = 0644); +inline void +saveFile(const std::string& path, const std::vector<uint8_t>& data, mode_t mode = 0644) +{ + saveFile(path, data.data(), data.size(), mode); +} + +/*std::vector<uint8_t> loadCacheFile(const std::string& path, + std::chrono::system_clock::duration maxAge); +std::string loadCacheTextFile(const std::string& path, std::chrono::system_clock::duration maxAge); + +std::vector<uint8_t> readArchive(const std::string& path, const std::string& password = {}); +void writeArchive(const std::string& data, + const std::string& path, + const std::string& password = {});*/ + +std::mutex& getFileLock(const std::string& path); + +/** + * Remove a file with optional erasing of content. + * Return the same value as std::remove(). + */ +//int remove(const std::string& path, bool erase = false); + +/** + * Prune given directory's content and remove it, symlinks are not followed. + * Return 0 if succeed, -1 if directory is not removed (content can be removed partially). + */ +int removeAll(const std::string& path, bool erase = false); + +/** + * Wrappers for fstream opening that will convert paths to wstring + * on windows + */ +void openStream(std::ifstream& file, + const std::string& path, + std::ios_base::openmode mode = std::ios_base::in); +void openStream(std::ofstream& file, + const std::string& path, + std::ios_base::openmode mode = std::ios_base::out); +std::ifstream ifstream(const std::string& path, std::ios_base::openmode mode = std::ios_base::in); +std::ofstream ofstream(const std::string& path, std::ios_base::openmode mode = std::ios_base::out); + +int64_t size(const std::string& path); + +std::string sha3File(const std::string& path); +std::string sha3sum(const std::vector<uint8_t>& buffer); + +/** + * Windows compatibility wrapper for checking read-only attribute + */ +int accessFile(const std::string& file, int mode); + +uint64_t lastWriteTime(const std::string& p); + +} // namespace fileutils +} // namespace jami diff --git a/include/generic_io.h b/include/generic_io.h new file mode 100644 index 0000000..cd8d0e9 --- /dev/null +++ b/include/generic_io.h @@ -0,0 +1,123 @@ +/* + * Copyright (C) 2004-2023 Savoir-faire Linux Inc. + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see <https://www.gnu.org/licenses/>. + */ +#pragma once + +#include "ip_utils.h" + +#include <functional> +#include <vector> +#include <chrono> +#include <system_error> +#include <cstdint> + +#if defined(_MSC_VER) +#include <BaseTsd.h> +using ssize_t = SSIZE_T; +#endif + +namespace jami { + +template<typename T> +class GenericSocket +{ +public: + using ValueType = T; + + virtual ~GenericSocket() { shutdown(); } + + using RecvCb = std::function<ssize_t(const ValueType* buf, std::size_t len)>; + + /// Close established connection + /// \note Terminate outstanding blocking read operations with an empty error code, but a 0 read size. + virtual void shutdown() {} + + /// Set Rx callback + /// \warning This method is here for backward compatibility + /// and because async IO are not implemented yet. + virtual void setOnRecv(RecvCb&& cb) = 0; + + virtual bool isReliable() const = 0; + + virtual bool isInitiator() const = 0; + + /// Return maximum application payload size. + /// This value is negative if the session is not ready to give a valid answer. + /// The value is 0 if such information is irrelevant for the session. + /// If stricly positive, the user must use send() with an input buffer size below or equals + /// to this value if it want to be sure that the transport sent it in an atomic way. + /// Example: in case of non-reliable transport using packet oriented IO, + /// this value gives the maximal size used to send one packet. + virtual int maxPayload() const = 0; + + /// Wait until data to read available, timeout or io error + /// \param ec error code set in case of error (if return value is < 0) + /// \return positive number if data ready for read, 0 in case of timeout or error. + /// \note error code is not set in case of timeout, but set only in case of io error + /// (i.e. socket deconnection). + /// \todo make a std::chrono version for the timeout + virtual int waitForData(std::chrono::milliseconds timeout, std::error_code& ec) const = 0; + + /// Write a given amount of data. + /// \param buf data to write. + /// \param len number of bytes to write. + /// \param ec error code set in case of error. + /// \return number of bytes written, 0 is valid. + /// \warning error checking consists in checking if \a !ec is true, not if returned size is 0 + /// as a write of 0 could be considered a valid operation. + virtual std::size_t write(const ValueType* buf, std::size_t len, std::error_code& ec) = 0; + + /// Read a given amount of data. + /// \param buf data to read. + /// \param len number of bytes to read. + /// \param ec error code set in case of error. + /// \return number of bytes read, 0 is valid. + /// \warning error checking consists in checking if \a !ec is true, not if returned size is 0 + /// as a read of 0 could be considered a valid operation (i.e. non-blocking IO). + virtual std::size_t read(ValueType* buf, std::size_t len, std::error_code& ec) = 0; + + /// write() adaptor for STL containers + template<typename U> + std::size_t write(const U& obj, std::error_code& ec) + { + return write(obj.data(), obj.size() * sizeof(typename U::value_type), ec); + } + + /// read() adaptor for STL containers + template<typename U> + std::size_t read(U& storage, std::error_code& ec) + { + auto res = read(storage.data(), storage.size() * sizeof(typename U::value_type), ec); + if (!ec) + storage.resize(res); + return res; + } + + /// Return the local IP address if known. + /// \note The address is not valid (addr.isUnspecified() returns true) if it's not known + /// or not available. + virtual IpAddr localAddr() const { return {}; } + + /// Return the remote IP address if known. + /// \note The address is not valid (addr.isUnspecified() returns true) if it's not known + /// or not available. + virtual IpAddr remoteAddr() const { return {}; } + +protected: + GenericSocket() = default; +}; + +} // namespace jami diff --git a/include/ice_options.h b/include/ice_options.h new file mode 100644 index 0000000..c26b83e --- /dev/null +++ b/include/ice_options.h @@ -0,0 +1,66 @@ +#pragma once + +#include <functional> +#include <vector> +#include <string> + +#include "ip_utils.h" + +namespace jami { + +class IceTransportFactory; +using IceTransportCompleteCb = std::function<void(bool)>; + +struct StunServerInfo +{ + inline StunServerInfo& setUri(const std::string& args) { + uri = args; + return *this; + } + + std::string uri; // server URI, mandatory +}; + +struct TurnServerInfo +{ + inline TurnServerInfo& setUri(const std::string& args) { + uri = args; + return *this; + } + inline TurnServerInfo& setUsername(const std::string& args) { + username = args; + return *this; + } + inline TurnServerInfo& setPassword(const std::string& args) { + password = args; + return *this; + } + inline TurnServerInfo& setRealm(const std::string& args) { + realm = args; + return *this; + } + + std::string uri; // server URI, mandatory + std::string username; // credentials username (optional, empty if not used) + std::string password; // credentials password (optional, empty if not used) + std::string realm; // credentials realm (optional, empty if not used) +}; + +struct IceTransportOptions +{ + IceTransportFactory* factory {nullptr}; + bool master {true}; + unsigned streamsCount {1}; + unsigned compCountPerStream {1}; + bool upnpEnable {false}; + IceTransportCompleteCb onInitDone {}; + IceTransportCompleteCb onNegoDone {}; + std::vector<StunServerInfo> stunServers; + std::vector<TurnServerInfo> turnServers; + bool tcpEnable {false}; + // Addresses used by the account owning the transport instance. + IpAddr accountLocalAddr {}; + IpAddr accountPublicAddr {}; +}; + +} diff --git a/include/ip_utils.h b/include/ip_utils.h new file mode 100644 index 0000000..c720aa1 --- /dev/null +++ b/include/ip_utils.h @@ -0,0 +1,349 @@ +/* + * Copyright (C) 2004-2023 Savoir-faire Linux Inc. + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see <https://www.gnu.org/licenses/>. + */ +#pragma once + +#ifdef HAVE_CONFIG +#include <config.h> +#endif + +#include <sstream> // include before pjlib.h to fix macros issues with pjlib.h + +extern "C" { +#include <pjlib.h> +} + +#include <ciso646> // fix windows compiler bug + +#ifdef _WIN32 +#ifdef RING_UWP +#define _WIN32_WINNT 0x0A00 +#else +#define _WIN32_WINNT 0x0601 +#endif +#include <ws2tcpip.h> + +// define in mingw +#ifdef interface +#undef interface +#endif +#else +#include <sys/socket.h> +#include <netinet/in.h> +#include <arpa/inet.h> +#include <net/if.h> +#include <sys/ioctl.h> +#include <unistd.h> +#endif + +#include <string> +#include <vector> + +/* An IPv4 equivalent to IN6_IS_ADDR_UNSPECIFIED */ +#ifndef IN_IS_ADDR_UNSPECIFIED +#define IN_IS_ADDR_UNSPECIFIED(a) (((long int) (a)->s_addr) == 0x00000000) +#endif /* IN_IS_ADDR_UNSPECIFIED */ + +#define INVALID_SOCKET (-1) + +namespace jami { + +/** + * Binary representation of an IP address. + */ +class IpAddr +{ +public: + IpAddr() + : IpAddr(AF_UNSPEC) + {} + IpAddr(const IpAddr&) = default; + IpAddr(IpAddr&&) = default; + IpAddr& operator=(const IpAddr&) = default; + IpAddr& operator=(IpAddr&&) = default; + + explicit IpAddr(uint16_t family) + : addr() + { + addr.addr.sa_family = family; + } + + IpAddr(const pj_sockaddr& ip) + : addr(ip) + {} + + IpAddr(const pj_sockaddr& ip, socklen_t len) + : addr() + { + if (len > static_cast<socklen_t>(sizeof(addr))) + throw std::invalid_argument("IpAddr(): length overflows internal storage type"); + memcpy(&addr, &ip, len); + } + + IpAddr(const sockaddr& ip) + : addr() + { + memcpy(&addr, &ip, ip.sa_family == AF_INET6 ? sizeof addr.ipv6 : sizeof addr.ipv4); + } + + IpAddr(const sockaddr_in& ip) + : addr() + { + static_assert(sizeof(ip) <= sizeof(addr), "sizeof(sockaddr_in) too large"); + memcpy(&addr, &ip, sizeof(sockaddr_in)); + } + + IpAddr(const sockaddr_in6& ip) + : addr() + { + static_assert(sizeof(ip) <= sizeof(addr), "sizeof(sockaddr_in6) too large"); + memcpy(&addr, &ip, sizeof(sockaddr_in6)); + } + + IpAddr(const sockaddr_storage& ip) + : IpAddr(*reinterpret_cast<const sockaddr*>(&ip)) + {} + + IpAddr(const in_addr& ip) + : addr() + { + static_assert(sizeof(ip) <= sizeof(addr), "sizeof(in_addr) too large"); + addr.addr.sa_family = AF_INET; + memcpy(&addr.ipv4.sin_addr, &ip, sizeof(in_addr)); + } + + IpAddr(const in6_addr& ip) + : addr() + { + static_assert(sizeof(ip) <= sizeof(addr), "sizeof(in6_addr) too large"); + addr.addr.sa_family = AF_INET6; + memcpy(&addr.ipv6.sin6_addr, &ip, sizeof(in6_addr)); + } + + IpAddr(std::string_view str, pj_uint16_t family = AF_UNSPEC) + : addr() + { + if (str.empty()) { + addr.addr.sa_family = AF_UNSPEC; + return; + } + const pj_str_t pjstring {(char*) str.data(), (pj_ssize_t) str.size()}; + auto status = pj_sockaddr_parse(family, 0, &pjstring, &addr); + if (status != PJ_SUCCESS) + addr.addr.sa_family = AF_UNSPEC; + } + + // Is defined + inline explicit operator bool() const { return isIpv4() or isIpv6(); } + + inline explicit operator bool() { return isIpv4() or isIpv6(); } + + inline operator pj_sockaddr&() { return addr; } + + inline operator const pj_sockaddr&() const { return addr; } + + inline operator pj_sockaddr_in&() { return addr.ipv4; } + + inline operator const pj_sockaddr_in&() const + { + assert(addr.addr.sa_family != AF_INET6); + return addr.ipv4; + } + + inline operator pj_sockaddr_in6&() { return addr.ipv6; } + + inline operator const pj_sockaddr_in6&() const + { + assert(addr.addr.sa_family == AF_INET6); + return addr.ipv6; + } + + inline operator const sockaddr&() const { return reinterpret_cast<const sockaddr&>(addr); } + + inline operator const sockaddr*() const { return reinterpret_cast<const sockaddr*>(&addr); } + + inline const pj_sockaddr* pjPtr() const { return &addr; } + + inline pj_sockaddr* pjPtr() { return &addr; } + + inline operator std::string() const { return toString(); } + + std::string toString(bool include_port = false, bool force_ipv6_brackets = false) const + { + if (addr.addr.sa_family == AF_UNSPEC) + return {}; + std::string str(PJ_INET6_ADDRSTRLEN, (char) 0); + if (include_port) + force_ipv6_brackets = true; + pj_sockaddr_print(&addr, + &(*str.begin()), + PJ_INET6_ADDRSTRLEN, + (include_port ? 1 : 0) | (force_ipv6_brackets ? 2 : 0)); + str.resize(std::char_traits<char>::length(str.c_str())); + return str; + } + + void setPort(uint16_t port) { pj_sockaddr_set_port(&addr, port); } + + inline uint16_t getPort() const + { + if (not *this) + return 0; + return pj_sockaddr_get_port(&addr); + } + + inline socklen_t getLength() const + { + if (not *this) + return 0; + return pj_sockaddr_get_len(&addr); + } + + inline uint16_t getFamily() const { return addr.addr.sa_family; } + + inline bool isIpv4() const { return addr.addr.sa_family == AF_INET; } + + inline bool isIpv6() const { return addr.addr.sa_family == AF_INET6; } + + /** + * Return true if address is a loopback IP address. + */ + bool isLoopback() const; + + /** + * Return true if address is not a public IP address. + */ + bool isPrivate() const; + + bool isUnspecified() const; + + /** + * Return true if address is a valid IPv6. + */ + inline static bool isIpv6(std::string_view address) { return isValid(address, AF_INET6); } + + /** + * Return true if address is a valid IP address of specified family (if provided) or of any kind + * (default). Does not resolve hostnames. + */ + static bool isValid(std::string_view address, pj_uint16_t family = pj_AF_UNSPEC()); + +private: + pj_sockaddr addr {}; +}; + +// IpAddr helpers +inline bool +operator==(const IpAddr& lhs, const IpAddr& rhs) +{ + return !pj_sockaddr_cmp(&lhs, &rhs); +} +inline bool +operator!=(const IpAddr& lhs, const IpAddr& rhs) +{ + return !(lhs == rhs); +} +inline bool +operator<(const IpAddr& lhs, const IpAddr& rhs) +{ + return pj_sockaddr_cmp(&lhs, &rhs) < 0; +} +inline bool +operator>(const IpAddr& lhs, const IpAddr& rhs) +{ + return pj_sockaddr_cmp(&lhs, &rhs) > 0; +} +inline bool +operator<=(const IpAddr& lhs, const IpAddr& rhs) +{ + return pj_sockaddr_cmp(&lhs, &rhs) <= 0; +} +inline bool +operator>=(const IpAddr& lhs, const IpAddr& rhs) +{ + return pj_sockaddr_cmp(&lhs, &rhs) >= 0; +} + +namespace ip_utils { + +static const char* const DEFAULT_INTERFACE = "default"; + +static const unsigned int MAX_INTERFACE = 256; +static const unsigned int MIN_INTERFACE = 1; +enum class subnet_mask { prefix_8bit, prefix_16bit, prefix_24bit, prefix_32bit }; + +std::string getHostname(); + +int getHostName(char* out, size_t out_len); +std::string getGateway(char* localHost, ip_utils::subnet_mask prefix); +IpAddr getLocalGateway(); + +/** + * Return the generic "any host" IP address of the specified family. + * If family is unspecified, default to pj_AF_INET6() (IPv6). + */ +inline IpAddr +getAnyHostAddr(pj_uint16_t family) +{ + return IpAddr(family); +} + +/** + * Return the first host IP address of the specified family. + * If no address of the specified family is found, another family will + * be tried. + * Ex. : if family is pj_AF_INET6() (IPv6/default) and the system does not + * have an IPv6 address, an IPv4 address will be returned if available. + * + * If family is unspecified, default to pj_AF_INET6() if compiled + * with IPv6, or pj_AF_INET() otherwise. + */ +IpAddr getLocalAddr(pj_uint16_t family); + +/** + * Get the IP address of the network interface interface with the specified + * address family, or of any address family if unspecified (default). + */ +IpAddr getInterfaceAddr(const std::string& interface, pj_uint16_t family); + +/** + * List all the interfaces on the system and return + * a vector list containing their name (eth0, eth0:1 ...). + * @param void + * @return std::vector<std::string> A std::string vector + * of interface name available on all of the interfaces on + * the system. + */ +std::vector<std::string> getAllIpInterfaceByName(); + +/** + * List all the interfaces on the system and return + * a vector list containing their IP address. + * @param void + * @return std::vector<std::string> A std::string vector + * of IP address available on all of the interfaces on + * the system. + */ +std::vector<std::string> getAllIpInterface(); + +std::vector<IpAddr> getAddrList(std::string_view name, pj_uint16_t family = pj_AF_UNSPEC()); + +bool haveCommonAddr(const std::vector<IpAddr>& a, const std::vector<IpAddr>& b); + +std::vector<IpAddr> getLocalNameservers(); + +} // namespace ip_utils +} // namespace jami diff --git a/include/multiplexed_socket.h b/include/multiplexed_socket.h new file mode 100644 index 0000000..d8e6e16 --- /dev/null +++ b/include/multiplexed_socket.h @@ -0,0 +1,361 @@ +/* + * Copyright (C) 2004-2023 Savoir-faire Linux Inc. + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see <https://www.gnu.org/licenses/>. + */ +#pragma once + +#include "ip_utils.h" +#include "generic_io.h" + +#include <opendht/default_types.h> +#include <condition_variable> + +#include <cstdint> + +namespace asio { +class io_context; +} + +namespace dht { +namespace log { +class Logger; +} +} + +namespace jami { + +using Logger = dht::log::Logger; +class IceTransport; +class ChannelSocket; +class TlsSocketEndpoint; + +using DeviceId = dht::PkId; +using OnConnectionRequestCb + = std::function<bool(const std::shared_ptr<dht::crypto::Certificate>& /* peer */, + const uint16_t& /* id */, + const std::string& /* name */)>; +using OnConnectionReadyCb + = std::function<void(const DeviceId& /* deviceId */, const std::shared_ptr<ChannelSocket>&)>; +using ChannelReadyCb = std::function<void(void)>; +using OnShutdownCb = std::function<void(void)>; + +static constexpr auto SEND_BEACON_TIMEOUT = std::chrono::milliseconds(3000); +static constexpr uint16_t CONTROL_CHANNEL {0}; +static constexpr uint16_t PROTOCOL_CHANNEL {0xffff}; + +enum class ChannelRequestState { + REQUEST, + ACCEPT, + DECLINE, +}; + +/** + * That msgpack structure is used to request a new channel (id, name) + * Transmitted over the TLS socket + */ +struct ChannelRequest +{ + std::string name {}; + uint16_t channel {0}; + ChannelRequestState state {ChannelRequestState::REQUEST}; + MSGPACK_DEFINE(name, channel, state) +}; + +/** + * A socket divided in channels over a TLS session + */ +class MultiplexedSocket : public std::enable_shared_from_this<MultiplexedSocket> +{ +public: + MultiplexedSocket(std::shared_ptr<asio::io_context> ctx, const DeviceId& deviceId, std::unique_ptr<TlsSocketEndpoint> endpoint); + ~MultiplexedSocket(); + std::shared_ptr<ChannelSocket> addChannel(const std::string& name); + + std::shared_ptr<MultiplexedSocket> shared() + { + return std::static_pointer_cast<MultiplexedSocket>(shared_from_this()); + } + std::shared_ptr<MultiplexedSocket const> shared() const + { + return std::static_pointer_cast<MultiplexedSocket const>(shared_from_this()); + } + std::weak_ptr<MultiplexedSocket> weak() + { + return std::static_pointer_cast<MultiplexedSocket>(shared_from_this()); + } + std::weak_ptr<MultiplexedSocket const> weak() const + { + return std::static_pointer_cast<MultiplexedSocket const>(shared_from_this()); + } + + DeviceId deviceId() const; + bool isReliable() const; + bool isInitiator() const; + int maxPayload() const; + + /** + * Will be triggered when a new channel is ready + */ + void setOnReady(OnConnectionReadyCb&& cb); + /** + * Will be triggered when the peer asks for a new channel + */ + void setOnRequest(OnConnectionRequestCb&& cb); + + std::size_t write(const uint16_t& channel, + const uint8_t* buf, + std::size_t len, + std::error_code& ec); + + /** + * This will close all channels and send a TLS EOF on the main socket. + */ + void shutdown(); + + /** + * This will wait that eventLoop is stopped and stop it if necessary + */ + void join(); + + /** + * Will trigger that callback when shutdown() is called + */ + void onShutdown(OnShutdownCb&& cb); + + /** + * Get informations from socket (channels opened) + */ + void monitor() const; + + const std::shared_ptr<Logger>& logger(); + + /** + * Send a beacon on the socket and close if no response come + * @param timeout + */ + void sendBeacon(const std::chrono::milliseconds& timeout = SEND_BEACON_TIMEOUT); + + /** + * Get peer's certificate + */ + std::shared_ptr<dht::crypto::Certificate> peerCertificate() const; + + IpAddr getLocalAddress() const; + IpAddr getRemoteAddress() const; + + void eraseChannel(uint16_t channel); + +#ifdef LIBJAMI_TESTABLE + /** + * Check if we can send beacon on the socket + */ + bool canSendBeacon() const; + + /** + * Decide if yes or not we answer to beacon + * @param value New value + */ + void answerToBeacon(bool value); + + /** + * Change version sent to the peer + */ + void setVersion(int version); + + /** + * Set a callback to detect beacon messages + */ + void setOnBeaconCb(const std::function<void(bool)>& cb); + + /** + * Set a callback to detect version messages + */ + void setOnVersionCb(const std::function<void(int)>& cb); + + /** + * Send the version + */ + void sendVersion(); +#endif + +private: + class Impl; + std::unique_ptr<Impl> pimpl_; +}; + +class ChannelSocketInterface : public GenericSocket<uint8_t> +{ +public: + using SocketType = GenericSocket<uint8_t>; + + virtual DeviceId deviceId() const = 0; + virtual std::string name() const = 0; + virtual uint16_t channel() const = 0; + /** + * Triggered when a specific channel is ready + * Used by ConnectionManager::connectDevice() + */ + virtual void onReady(ChannelReadyCb&& cb) = 0; + /** + * Will trigger that callback when shutdown() is called + */ + virtual void onShutdown(OnShutdownCb&& cb) = 0; + + virtual void onRecv(std::vector<uint8_t>&& pkt) = 0; +}; + +class ChannelSocketTest : public ChannelSocketInterface +{ +public: + ChannelSocketTest(std::shared_ptr<asio::io_context> ctx, const DeviceId& deviceId, const std::string& name, const uint16_t& channel); + ~ChannelSocketTest(); + + static void link(const std::shared_ptr<ChannelSocketTest>& socket1, + const std::shared_ptr<ChannelSocketTest>& socket2); + + DeviceId deviceId() const override; + std::string name() const override; + uint16_t channel() const override; + + bool isReliable() const override { return true; }; + bool isInitiator() const override { return true; }; + int maxPayload() const override { return 0; }; + + void shutdown() override; + + std::size_t read(ValueType* buf, std::size_t len, std::error_code& ec) override; + std::size_t write(const ValueType* buf, std::size_t len, std::error_code& ec) override; + int waitForData(std::chrono::milliseconds timeout, std::error_code&) const override; + void setOnRecv(RecvCb&&) override; + void onRecv(std::vector<uint8_t>&& pkt) override; + + /** + * Triggered when a specific channel is ready + * Used by ConnectionManager::connectDevice() + */ + void onReady(ChannelReadyCb&& cb) override; + /** + * Will trigger that callback when shutdown() is called + */ + void onShutdown(OnShutdownCb&& cb) override; + + std::vector<uint8_t> rx_buf {}; + mutable std::mutex mutex {}; + mutable std::condition_variable cv {}; + GenericSocket<uint8_t>::RecvCb cb {}; + +private: + const DeviceId pimpl_deviceId; + const std::string pimpl_name; + const uint16_t pimpl_channel; + asio::io_context& ioCtx_; + std::weak_ptr<ChannelSocketTest> remote; + OnShutdownCb shutdownCb_ {[&] { + }}; + std::atomic_bool isShutdown_ {false}; +}; + +/** + * Represents a channel of the multiplexed socket (channel, name) + */ +class ChannelSocket : public ChannelSocketInterface +{ +public: + ChannelSocket(std::weak_ptr<MultiplexedSocket> endpoint, + const std::string& name, + const uint16_t& channel, + bool isInitiator = false, + std::function<void()> rmFromMxSockCb = {}); + ~ChannelSocket(); + + DeviceId deviceId() const override; + std::string name() const override; + uint16_t channel() const override; + bool isReliable() const override; + bool isInitiator() const override; + int maxPayload() const override; + /** + * Like shutdown, but don't send any packet on the socket. + * Used by Multiplexed Socket when the TLS endpoint is already shutting down + */ + void stop(); + + /** + * This will send an empty buffer as a packet (equivalent to EOF) + * Will trigger onShutdown's callback + */ + void shutdown() override; + + void ready(); + /** + * Triggered when a specific channel is ready + * Used by ConnectionManager::connectDevice() + */ + void onReady(ChannelReadyCb&& cb) override; + /** + * Will trigger that callback when shutdown() is called + */ + void onShutdown(OnShutdownCb&& cb) override; + + std::size_t read(ValueType* buf, std::size_t len, std::error_code& ec) override; + /** + * @note len should be < UINT8_MAX, else you will get ec = EMSGSIZE + */ + std::size_t write(const ValueType* buf, std::size_t len, std::error_code& ec) override; + int waitForData(std::chrono::milliseconds timeout, std::error_code&) const override; + + /** + * set a callback when receiving data + * @note: this callback should take a little time and not block + * but you can move it in a thread + */ + void setOnRecv(RecvCb&&) override; + + void onRecv(std::vector<uint8_t>&& pkt) override; + + /** + * Send a beacon on the socket and close if no response come + * @param timeout + */ + void sendBeacon(const std::chrono::milliseconds& timeout = SEND_BEACON_TIMEOUT); + + /** + * Get peer's certificate + */ + std::shared_ptr<dht::crypto::Certificate> peerCertificate() const; + +#ifdef LIBJAMI_TESTABLE + std::shared_ptr<MultiplexedSocket> underlyingSocket() const; +#endif + + // Note: When a channel is accepted, it can receives data ASAP and when finished will be removed + // however, onAccept is it's own thread due to the callbacks. In this case, the channel must be + // deleted in the onAccept. + void answered(); + bool isAnswered() const; + void removable(); + bool isRemovable() const; + + IpAddr getLocalAddress() const; + IpAddr getRemoteAddress() const; + +private: + class Impl; + std::unique_ptr<Impl> pimpl_; +}; + +} // namespace jami + +MSGPACK_ADD_ENUM(jami::ChannelRequestState); diff --git a/include/string_utils.h b/include/string_utils.h new file mode 100644 index 0000000..75661a2 --- /dev/null +++ b/include/string_utils.h @@ -0,0 +1,222 @@ +/* + * Copyright (C) 2004-2023 Savoir-faire Linux Inc. + * + * Author: Tristan Matthews <tristan.matthews@savoirfairelinux.com> + * Author: Adrien Béraud <adrien.beraud@savoirfairelinux.com> + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ + +#pragma once + +#include <cstdint> +#include <string> +#include <vector> +#include <set> +#include <algorithm> +#include <regex> +#include <iterator> +#include <charconv> + +#ifdef _WIN32 +#include <WTypes.h> +#endif + +namespace jami { + +constexpr static const char TRUE_STR[] = "true"; +constexpr static const char FALSE_STR[] = "false"; + +constexpr static const char* +bool_to_str(bool b) noexcept +{ + return b ? TRUE_STR : FALSE_STR; +} + +std::string to_string(double value); + +#ifdef _WIN32 +std::wstring to_wstring(const std::string& str, int codePage = CP_UTF8); +std::string to_string(const std::wstring& wstr, int codePage = CP_UTF8); +#endif + +std::string to_hex_string(uint64_t id); +uint64_t from_hex_string(const std::string& str); + +template<typename T> +T +to_int(std::string_view str, T defaultValue) +{ + T result; + auto [p, ec] = std::from_chars(str.data(), str.data()+str.size(), result); + if (ec == std::errc()) + return result; + else + return defaultValue; +} + +template<typename T> +T +to_int(std::string_view str) +{ + T result; + auto [p, ec] = std::from_chars(str.data(), str.data()+str.size(), result); + if (ec == std::errc()) + return result; + if (ec == std::errc::invalid_argument) + throw std::invalid_argument("Can't parse integer: invalid_argument"); + else if (ec == std::errc::result_out_of_range) + throw std::out_of_range("Can't parse integer: out of range"); + throw std::system_error(std::make_error_code(ec)); +} + +static inline int +stoi(const std::string& str) +{ + return std::stoi(str); +} + +static inline double +stod(const std::string& str) +{ + return std::stod(str); +} + +template<typename... Args> +std::string concat(Args &&... args){ + static_assert((std::is_constructible_v<std::string_view, Args&&> && ...)); + std::string s; + s.reserve((std::string_view{ args }.size() + ...)); + (s.append(std::forward<Args>(args)), ...); + return s; +} + +std::string_view trim(std::string_view s); + +/** + * Split a string_view with an API similar to std::getline. + * @param str The input string stream to iterate on, trimed of line during iteration. + * @param line The output substring. + * @param delim The delimiter. + * @return True if line was set, false if the end of the input was reached. + */ +inline bool +getline_full(std::string_view& str, std::string_view& line, char delim = '\n') +{ + if (str.empty()) + return false; + auto pos = str.find(delim); + line = str.substr(0, pos); + str.remove_prefix(pos < str.size() ? pos + 1 : str.size()); + return true; +} + +/** + * Similar to @getline_full but skips empty results. + */ +inline bool +getline(std::string_view& str, std::string_view& line, char delim = '\n') +{ + do { + if (!getline_full(str, line, delim)) + return false; + } while (line.empty()); + return true; +} + +inline std::vector<std::string_view> +split_string(std::string_view str, char delim) +{ + std::vector<std::string_view> output; + for (auto first = str.data(), second = str.data(), last = first + str.size(); + second != last && first != last; + first = second + 1) { + second = std::find(first, last, delim); + if (first != second) + output.emplace_back(first, second - first); + } + return output; +} + +inline std::vector<std::string_view> +split_string(std::string_view str, std::string_view delims = " ") +{ + std::vector<std::string_view> output; + for (auto first = str.data(), second = str.data(), last = first + str.size(); + second != last && first != last; + first = second + 1) { + second = std::find_first_of(first, last, std::cbegin(delims), std::cend(delims)); + if (first != second) + output.emplace_back(first, second - first); + } + return output; +} + +std::vector<unsigned> split_string_to_unsigned(std::string_view s, char sep); + +void string_replace(std::string& str, const std::string& from, const std::string& to); + +std::string_view string_remove_suffix(std::string_view str, char separator); + +std::string string_join(const std::set<std::string>& set, std::string_view separator = "/"); + +std::set<std::string> string_split_set(std::string& str, std::string_view separator = "/"); + +} // namespace jami + +/* +// Add string operators missing from standard +// see https://groups.google.com/a/isocpp.org/forum/#!topic/std-proposals/1RcShRhrmRc +namespace std { +inline string +operator+(const string& s, const string_view& sv) +{ + return jami::concat(s, sv); +} +inline string +operator+(const string_view& sv, const string& s) +{ + return jami::concat(sv, s); +} +using svmatch = match_results<string_view::const_iterator>; +using svsub_match = sub_match<string_view::const_iterator>; +constexpr string_view svsub_match_view(const svsub_match& submatch) noexcept { + return string_view(&*submatch.first, submatch.second - submatch.first); +} +inline bool +regex_match(string_view sv, + svmatch& m, + const regex& e, + regex_constants::match_flag_type flags = regex_constants::match_default) +{ + return regex_match(sv.begin(), sv.end(), m, e, flags); +} +inline bool +regex_match(string_view sv, + const regex& e, + regex_constants::match_flag_type flags = regex_constants::match_default) +{ + return regex_match(sv.begin(), sv.end(), e, flags); +} +inline bool +regex_search(string_view sv, + svmatch& m, + const regex& e, + regex_constants::match_flag_type flags = regex_constants::match_default) +{ + return regex_search(sv.begin(), sv.end(), m, e, flags); +} +} // namespace std +*/ diff --git a/include/tls_session.h b/include/tls_session.h new file mode 100644 index 0000000..4a4b994 --- /dev/null +++ b/include/tls_session.h @@ -0,0 +1,165 @@ +/* + * Copyright (C) 2004-2023 Savoir-faire Linux Inc. + * + * Author: Adrien Béraud <adrien.beraud@savoirfairelinux.com> + * Author: Guillaume Roguez <guillaume.roguez@savoirfairelinux.com> + * Author: Sébastien Blin <sebastien.blin@savoirfairelinux.com> + * Author: Vsevolod Ivanov <vsevolod.ivanov@savoirfairelinux.com> + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ + +#pragma once + +//#include "noncopyable.h" +#include "generic_io.h" +#include "certstore.h" +#include "diffie-hellman.h" + +#include <gnutls/gnutls.h> +#include <asio/io_context.hpp> + +#include <string> +#include <functional> +#include <memory> +#include <future> +#include <chrono> +#include <vector> +#include <array> + +namespace dht { +namespace crypto { +struct Certificate; +struct PrivateKey; +} // namespace crypto +} // namespace dht + +namespace jami { +namespace tls { + +enum class TlsSessionState { + NONE, + SETUP, + COOKIE, // only used with non-initiator and non-reliable transport + HANDSHAKE, + MTU_DISCOVERY, // only used with non-reliable transport + ESTABLISHED, + SHUTDOWN +}; + +using clock = std::chrono::steady_clock; +using duration = clock::duration; + +struct TlsParams +{ + // User CA list for session credentials + std::string ca_list; + + std::shared_ptr<dht::crypto::Certificate> peer_ca; + + // User identity for credential + std::shared_ptr<dht::crypto::Certificate> cert; + std::shared_ptr<dht::crypto::PrivateKey> cert_key; + + // Diffie-Hellman computed by gnutls_dh_params_init/gnutls_dh_params_generateX + std::shared_future<DhParams> dh_params; + + tls::CertificateStore& certStore; + + // handshake timeout + duration timeout; + + // Callback for certificate checkings + std::function<int(unsigned status, const gnutls_datum_t* cert_list, unsigned cert_list_size)> + cert_check; + + std::shared_ptr<asio::io_context> io_context; + + std::shared_ptr<Logger> logger; +}; + +/// TlsSession +/// +/// Manages a TLS/DTLS data transport overlayed on a given generic socket. +/// +/// \note API is not thread-safe. +/// +class TlsSession : public GenericSocket<uint8_t> +{ +public: + using SocketType = GenericSocket<uint8_t>; + using OnStateChangeFunc = std::function<void(TlsSessionState)>; + using OnRxDataFunc = std::function<void(std::vector<uint8_t>&&)>; + using OnCertificatesUpdate + = std::function<void(const gnutls_datum_t*, const gnutls_datum_t*, unsigned int)>; + using VerifyCertificate = std::function<int(gnutls_session_t)>; + + // ===> WARNINGS <=== + // Following callbacks are called into the FSM thread context + // Do not call blocking routines inside them. + using TlsSessionCallbacks = struct + { + OnStateChangeFunc onStateChange; + OnRxDataFunc onRxData; + OnCertificatesUpdate onCertificatesUpdate; + VerifyCertificate verifyCertificate; + }; + + TlsSession(std::unique_ptr<SocketType>&& transport, + const TlsParams& params, + const TlsSessionCallbacks& cbs, + bool anonymous = true); + ~TlsSession(); + + /// Request TLS thread to stop and quit. + /// \note IO operations return error after this call. + void shutdown() override; + + void setOnRecv(RecvCb&& cb) override + { + (void) cb; + throw std::logic_error("TlsSession::setOnRecv not implemented"); + } + + /// Return true if the TLS session type is a server. + bool isInitiator() const override; + + bool isReliable() const override; + + int maxPayload() const override; + + void waitForReady(const duration& timeout = {}); + + /// Synchronous writing. + /// Return a positive number for number of bytes write, or 0 and \a ec set in case of error. + std::size_t write(const ValueType* data, std::size_t size, std::error_code& ec) override; + + /// Synchronous reading. + /// Return a positive number for number of bytes read, or 0 and \a ec set in case of error. + std::size_t read(ValueType* data, std::size_t size, std::error_code& ec) override; + + int waitForData(std::chrono::milliseconds, std::error_code&) const override; + + std::shared_ptr<dht::crypto::Certificate> peerCertificate() const; + + const std::shared_ptr<dht::log::Logger>& logger() const; + +private: + class TlsSessionImpl; + std::unique_ptr<TlsSessionImpl> pimpl_; +}; + +} // namespace tls +} // namespace jami diff --git a/src/connectionmanager.cpp b/src/connectionmanager.cpp new file mode 100644 index 0000000..96bd6ba --- /dev/null +++ b/src/connectionmanager.cpp @@ -0,0 +1,1656 @@ +/* + * Copyright (C) 2019-2023 Savoir-faire Linux Inc. + * Author: Sébastien Blin <sebastien.blin@savoirfairelinux.com> + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see <https://www.gnu.org/licenses/>. + */ +#include "connectionmanager.h" +#include "peer_connection.h" +#include "upnp/upnp_control.h" +#include "certstore.h" +#include "fileutils.h" +#include "sip_utils.h" +#include "string_utils.h" + +#include <opendht/crypto.h> +#include <opendht/thread_pool.h> +#include <opendht/value.h> +#include <asio.hpp> + +#include <algorithm> +#include <mutex> +#include <map> +#include <condition_variable> +#include <set> +#include <charconv> + +namespace jami { +static constexpr std::chrono::seconds DHT_MSG_TIMEOUT {30}; +static constexpr uint64_t ID_MAX_VAL = 9007199254740992; + +using ValueIdDist = std::uniform_int_distribution<dht::Value::Id>; +using CallbackId = std::pair<jami::DeviceId, dht::Value::Id>; + +struct ConnectionInfo +{ + ~ConnectionInfo() + { + if (socket_) + socket_->join(); + } + + std::mutex mutex_ {}; + bool responseReceived_ {false}; + PeerConnectionRequest response_ {}; + std::unique_ptr<IceTransport> ice_ {nullptr}; + // Used to store currently non ready TLS Socket + std::unique_ptr<TlsSocketEndpoint> tls_ {nullptr}; + std::shared_ptr<MultiplexedSocket> socket_ {}; + std::set<CallbackId> cbIds_ {}; + + std::function<void(bool)> onConnected_; + std::unique_ptr<asio::steady_timer> waitForAnswer_ {}; +}; + +/** + * returns whether or not UPnP is enabled and active_ + * ie: if it is able to make port mappings + */ +bool +ConnectionManager::Config::getUPnPActive() const +{ + if (upnpCtrl) + return upnpCtrl->isReady(); + return false; +} + +class ConnectionManager::Impl : public std::enable_shared_from_this<ConnectionManager::Impl> +{ +public: + explicit Impl(std::shared_ptr<ConnectionManager::Config> config_) + : config_ {std::move(config_)} + {} + ~Impl() {} + + std::shared_ptr<dht::DhtRunner> dht() { return config_->dht; } + const dht::crypto::Identity& identity() const { return config_->id; } + + void removeUnusedConnections(const DeviceId& deviceId = {}) + { + std::vector<std::shared_ptr<ConnectionInfo>> unused {}; + + { + std::lock_guard<std::mutex> lk(infosMtx_); + for (auto it = infos_.begin(); it != infos_.end();) { + auto& [key, info] = *it; + if (info && (!deviceId || key.first == deviceId)) { + unused.emplace_back(std::move(info)); + it = infos_.erase(it); + } else { + ++it; + } + } + } + for (auto& info: unused) { + if (info->tls_) + info->tls_->shutdown(); + if (info->socket_) + info->socket_->shutdown(); + if (info->waitForAnswer_) + info->waitForAnswer_->cancel(); + } + if (!unused.empty()) + dht::ThreadPool::io().run([infos = std::move(unused)]() mutable { infos.clear(); }); + } + + void shutdown() + { + if (isDestroying_.exchange(true)) + return; + { + std::lock_guard<std::mutex> lk(connectCbsMtx_); + // Call all pending callbacks that channel is not ready + for (auto& [deviceId, pcbs] : pendingCbs_) + for (auto& pending : pcbs) + pending.cb(nullptr, deviceId); + pendingCbs_.clear(); + } + removeUnusedConnections(); + } + + struct PendingCb + { + std::string name; + ConnectCallback cb; + dht::Value::Id vid; + }; + + void connectDeviceStartIce(const std::shared_ptr<dht::crypto::PublicKey>& devicePk, + const dht::Value::Id& vid, + const std::string& connType, + std::function<void(bool)> onConnected); + void onResponse(const asio::error_code& ec, const DeviceId& deviceId, const dht::Value::Id& vid); + bool connectDeviceOnNegoDone(const DeviceId& deviceId, + const std::string& name, + const dht::Value::Id& vid, + const std::shared_ptr<dht::crypto::Certificate>& cert); + void connectDevice(const DeviceId& deviceId, + const std::string& uri, + ConnectCallback cb, + bool noNewSocket = false, + bool forceNewSocket = false, + const std::string& connType = ""); + void connectDevice(const std::shared_ptr<dht::crypto::Certificate>& cert, + const std::string& name, + ConnectCallback cb, + bool noNewSocket = false, + bool forceNewSocket = false, + const std::string& connType = ""); + /** + * Send a ChannelRequest on the TLS socket. Triggers cb when ready + * @param sock socket used to send the request + * @param name channel's name + * @param vid channel's id + * @param deviceId to identify the linked ConnectCallback + */ + void sendChannelRequest(std::shared_ptr<MultiplexedSocket>& sock, + const std::string& name, + const DeviceId& deviceId, + const dht::Value::Id& vid); + /** + * Triggered when a PeerConnectionRequest comes from the DHT + */ + void answerTo(IceTransport& ice, + const dht::Value::Id& id, + const std::shared_ptr<dht::crypto::PublicKey>& fromPk); + bool onRequestStartIce(const PeerConnectionRequest& req); + bool onRequestOnNegoDone(const PeerConnectionRequest& req); + void onDhtPeerRequest(const PeerConnectionRequest& req, + const std::shared_ptr<dht::crypto::Certificate>& cert); + + void addNewMultiplexedSocket(const CallbackId& id, const std::shared_ptr<ConnectionInfo>& info); + void onPeerResponse(const PeerConnectionRequest& req); + void onDhtConnected(const dht::crypto::PublicKey& devicePk); + + const std::shared_future<tls::DhParams> dhParams() const; + tls::CertificateStore& certStore() const { return *config_->certStore; } + + mutable std::mutex messageMutex_ {}; + std::set<std::string, std::less<>> treatedMessages_ {}; + + void loadTreatedMessages(); + void saveTreatedMessages() const; + + /// \return true if the given DHT message identifier has been treated + /// \note if message has not been treated yet this method st/ore this id and returns true at + /// further calls + bool isMessageTreated(std::string_view id); + + const std::shared_ptr<dht::log::Logger>& logger() const { return config_->logger; } + + /** + * Published IPv4/IPv6 addresses, used only if defined by the user in account + * configuration + * + */ + IpAddr publishedIp_[2] {}; + + // This will be stored in the configuration + std::string publishedIpAddress_ {}; + + /** + * Published port, used only if defined by the user + */ + pj_uint16_t publishedPort_ {sip_utils::DEFAULT_SIP_PORT}; + + /** + * interface name on which this account is bound + */ + std::string interface_ {"default"}; + + /** + * Get the local interface name on which this account is bound. + */ + const std::string& getLocalInterface() const { return interface_; } + + /** + * Get the published IP address, fallbacks to NAT if family is unspecified + * Prefers the usage of IPv4 if possible. + */ + IpAddr getPublishedIpAddress(uint16_t family = PF_UNSPEC) const; + + /** + * Set published IP address according to given family + */ + void setPublishedAddress(const IpAddr& ip_addr); + + /** + * Store the local/public addresses used to register + */ + void storeActiveIpAddress(std::function<void()>&& cb = {}); + + /** + * Create and return ICE options. + */ + void getIceOptions(std::function<void(IceTransportOptions&&)> cb) noexcept; + IceTransportOptions getIceOptions() const noexcept; + + /** + * Inform that a potential peer device have been found. + * Returns true only if the device certificate is a valid device certificate. + * In that case (true is returned) the account_id parameter is set to the peer account ID. + */ + static bool foundPeerDevice(const std::shared_ptr<dht::crypto::Certificate>& crt, + dht::InfoHash& account_id, const std::shared_ptr<Logger>& logger); + + bool findCertificate(const dht::PkId& id, + std::function<void(const std::shared_ptr<dht::crypto::Certificate>&)>&& cb); + + /** + * returns whether or not UPnP is enabled and active + * ie: if it is able to make port mappings + */ + bool getUPnPActive() const; + + /** + * Triggered when a new TLS socket is ready to use + * @param ok If succeed + * @param deviceId Related device + * @param vid vid of the connection request + * @param name non empty if TLS was created by connectDevice() + */ + void onTlsNegotiationDone(bool ok, + const DeviceId& deviceId, + const dht::Value::Id& vid, + const std::string& name = ""); + + std::shared_ptr<ConnectionManager::Config> config_; + + IceTransportFactory iceFactory_ {}; + + mutable std::mt19937_64 rand; + + iOSConnectedCallback iOSConnectedCb_ {}; + + std::mutex infosMtx_ {}; + // Note: Someone can ask multiple sockets, so to avoid any race condition, + // each device can have multiple multiplexed sockets. + std::map<CallbackId, std::shared_ptr<ConnectionInfo>> infos_ {}; + + std::shared_ptr<ConnectionInfo> getInfo(const DeviceId& deviceId, const dht::Value::Id& id) + { + std::lock_guard<std::mutex> lk(infosMtx_); + auto it = infos_.find({deviceId, id}); + if (it != infos_.end()) + return it->second; + return {}; + } + + std::shared_ptr<ConnectionInfo> getConnectedInfo(const DeviceId& deviceId) + { + std::lock_guard<std::mutex> lk(infosMtx_); + auto it = std::find_if(infos_.begin(), infos_.end(), [&](const auto& item) { + auto& [key, value] = item; + return key.first == deviceId && value && value->socket_; + }); + if (it != infos_.end()) + return it->second; + return {}; + } + + ChannelRequestCallback channelReqCb_ {}; + ConnectionReadyCallback connReadyCb_ {}; + onICERequestCallback iceReqCb_ {}; + + /** + * Stores callback from connectDevice + * @note: each device needs a vector because several connectDevice can + * be done in parallel and we only want one socket + */ + std::mutex connectCbsMtx_ {}; + std::map<DeviceId, std::vector<PendingCb>> pendingCbs_ {}; + + std::vector<PendingCb> extractPendingCallbacks(const DeviceId& deviceId, + const dht::Value::Id vid = 0) + { + std::vector<PendingCb> ret; + std::lock_guard<std::mutex> lk(connectCbsMtx_); + auto pendingIt = pendingCbs_.find(deviceId); + if (pendingIt == pendingCbs_.end()) + return ret; + auto& pendings = pendingIt->second; + if (vid == 0) { + ret = std::move(pendings); + } else { + for (auto it = pendings.begin(); it != pendings.end(); ++it) { + if (it->vid == vid) { + ret.emplace_back(std::move(*it)); + pendings.erase(it); + break; + } + } + } + if (pendings.empty()) + pendingCbs_.erase(pendingIt); + return ret; + } + + std::vector<PendingCb> getPendingCallbacks(const DeviceId& deviceId, + const dht::Value::Id vid = 0) + { + std::vector<PendingCb> ret; + std::lock_guard<std::mutex> lk(connectCbsMtx_); + auto pendingIt = pendingCbs_.find(deviceId); + if (pendingIt == pendingCbs_.end()) + return ret; + auto& pendings = pendingIt->second; + if (vid == 0) { + ret = pendings; + } else { + std::copy_if(pendings.begin(), + pendings.end(), + std::back_inserter(ret), + [&](auto pending) { return pending.vid == vid; }); + } + return ret; + } + + std::shared_ptr<ConnectionManager::Impl> shared() + { + return std::static_pointer_cast<ConnectionManager::Impl>(shared_from_this()); + } + std::shared_ptr<ConnectionManager::Impl const> shared() const + { + return std::static_pointer_cast<ConnectionManager::Impl const>(shared_from_this()); + } + std::weak_ptr<ConnectionManager::Impl> weak() + { + return std::static_pointer_cast<ConnectionManager::Impl>(shared_from_this()); + } + std::weak_ptr<ConnectionManager::Impl const> weak() const + { + return std::static_pointer_cast<ConnectionManager::Impl const>(shared_from_this()); + } + + std::atomic_bool isDestroying_ {false}; +}; + +void +ConnectionManager::Impl::connectDeviceStartIce( + const std::shared_ptr<dht::crypto::PublicKey>& devicePk, + const dht::Value::Id& vid, + const std::string& connType, + std::function<void(bool)> onConnected) +{ + auto deviceId = devicePk->getLongId(); + auto info = getInfo(deviceId, vid); + if (!info) { + onConnected(false); + return; + } + + std::unique_lock<std::mutex> lk(info->mutex_); + auto& ice = info->ice_; + + if (!ice) { + if (config_->logger) + config_->logger->error("No ICE detected"); + onConnected(false); + return; + } + + auto iceAttributes = ice->getLocalAttributes(); + std::ostringstream icemsg; + icemsg << iceAttributes.ufrag << "\n"; + icemsg << iceAttributes.pwd << "\n"; + for (const auto& addr : ice->getLocalCandidates(1)) { + icemsg << addr << "\n"; + if (config_->logger) + config_->logger->debug("Added local ICE candidate {}", addr); + } + + // Prepare connection request as a DHT message + PeerConnectionRequest val; + + val.id = vid; /* Random id for the message unicity */ + val.ice_msg = icemsg.str(); + val.connType = connType; + + auto value = std::make_shared<dht::Value>(std::move(val)); + value->user_type = "peer_request"; + + // Send connection request through DHT + if (config_->logger) + config_->logger->debug("Request connection to {}", deviceId); + dht()->putEncrypted(dht::InfoHash::get(PeerConnectionRequest::key_prefix + + devicePk->getId().toString()), + devicePk, + value, + [l=config_->logger,deviceId](bool ok) { + if (l) + l->debug("Sent connection request to {:s}. Put encrypted {:s}", + deviceId, + (ok ? "ok" : "failed")); + }); + // Wait for call to onResponse() operated by DHT + if (isDestroying_) { + onConnected(true); // This avoid to wait new negotiation when destroying + return; + } + + info->onConnected_ = std::move(onConnected); + info->waitForAnswer_ = std::make_unique<asio::steady_timer>(*config_->ioContext, + std::chrono::steady_clock::now() + + DHT_MSG_TIMEOUT); + info->waitForAnswer_->async_wait( + std::bind(&ConnectionManager::Impl::onResponse, this, std::placeholders::_1, deviceId, vid)); +} + +void +ConnectionManager::Impl::onResponse(const asio::error_code& ec, + const DeviceId& deviceId, + const dht::Value::Id& vid) +{ + if (ec == asio::error::operation_aborted) + return; + auto info = getInfo(deviceId, vid); + if (!info) + return; + + std::unique_lock<std::mutex> lk(info->mutex_); + auto& ice = info->ice_; + if (isDestroying_) { + info->onConnected_(true); // The destructor can wake a pending wait here. + return; + } + if (!info->responseReceived_) { + if (config_->logger) + config_->logger->error("no response from DHT to E2E request."); + info->onConnected_(false); + return; + } + + if (!info->ice_) { + info->onConnected_(false); + return; + } + + auto sdp = ice->parseIceCandidates(info->response_.ice_msg); + + if (not ice->startIce({sdp.rem_ufrag, sdp.rem_pwd}, std::move(sdp.rem_candidates))) { + if (config_->logger) + config_->logger->warn("start ICE failed"); + info->onConnected_(false); + return; + } + info->onConnected_(true); +} + +bool +ConnectionManager::Impl::connectDeviceOnNegoDone( + const DeviceId& deviceId, + const std::string& name, + const dht::Value::Id& vid, + const std::shared_ptr<dht::crypto::Certificate>& cert) +{ + auto info = getInfo(deviceId, vid); + if (!info) + return false; + + std::unique_lock<std::mutex> lk {info->mutex_}; + if (info->waitForAnswer_) { + // Negotiation is done and connected, go to handshake + // and avoid any cancellation at this point. + info->waitForAnswer_->cancel(); + } + auto& ice = info->ice_; + if (!ice || !ice->isRunning()) { + if (config_->logger) + config_->logger->error("No ICE detected or not running"); + return false; + } + + // Build socket + auto endpoint = std::make_unique<IceSocketEndpoint>(std::shared_ptr<IceTransport>( + std::move(ice)), + true); + + // Negotiate a TLS session + if (config_->logger) + config_->logger->debug("Start TLS session - Initied by connectDevice(). Launched by channel: {} - device: {} - vid: {}", name, deviceId, vid); + info->tls_ = std::make_unique<TlsSocketEndpoint>(std::move(endpoint), + certStore(), + identity(), + dhParams(), + *cert); + + info->tls_->setOnReady( + [w = weak(), deviceId = std::move(deviceId), vid = std::move(vid), name = std::move(name)]( + bool ok) { + if (auto shared = w.lock()) + shared->onTlsNegotiationDone(ok, deviceId, vid, name); + }); + return true; +} + +void +ConnectionManager::Impl::connectDevice(const DeviceId& deviceId, + const std::string& name, + ConnectCallback cb, + bool noNewSocket, + bool forceNewSocket, + const std::string& connType) +{ + if (!dht()) { + cb(nullptr, deviceId); + return; + } + if (deviceId.toString() == identity().second->getLongId().toString()) { + cb(nullptr, deviceId); + return; + } + findCertificate(deviceId, + [w = weak(), + deviceId, + name, + cb = std::move(cb), + noNewSocket, + forceNewSocket, + connType](const std::shared_ptr<dht::crypto::Certificate>& cert) { + if (!cert) { + if (auto shared = w.lock()) + if (shared->config_->logger) + shared->config_->logger->error( + "No valid certificate found for device {}", + deviceId); + cb(nullptr, deviceId); + return; + } + if (auto shared = w.lock()) { + shared->connectDevice(cert, + name, + std::move(cb), + noNewSocket, + forceNewSocket, + connType); + } else + cb(nullptr, deviceId); + }); +} + +void +ConnectionManager::Impl::connectDevice(const std::shared_ptr<dht::crypto::Certificate>& cert, + const std::string& name, + ConnectCallback cb, + bool noNewSocket, + bool forceNewSocket, + const std::string& connType) +{ + // Avoid dht operation in a DHT callback to avoid deadlocks + dht::ThreadPool::computation().run([w = weak(), + name = std::move(name), + cert = std::move(cert), + cb = std::move(cb), + noNewSocket, + forceNewSocket, + connType] { + auto devicePk = cert->getSharedPublicKey(); + auto deviceId = devicePk->getLongId(); + auto sthis = w.lock(); + if (!sthis || sthis->isDestroying_) { + cb(nullptr, deviceId); + return; + } + dht::Value::Id vid = ValueIdDist(1, ID_MAX_VAL)(sthis->rand); + auto isConnectingToDevice = false; + { + std::lock_guard<std::mutex> lk(sthis->connectCbsMtx_); + auto pendingsIt = sthis->pendingCbs_.find(deviceId); + if (pendingsIt != sthis->pendingCbs_.end()) { + const auto& pendings = pendingsIt->second; + while (std::find_if(pendings.begin(), pendings.end(), [&](const auto& it){ return it.vid == vid; }) != pendings.end()) { + vid = ValueIdDist(1, ID_MAX_VAL)(sthis->rand); + } + } + // Check if already connecting + isConnectingToDevice = pendingsIt != sthis->pendingCbs_.end(); + // Save current request for sendChannelRequest. + // Note: do not return here, cause we can be in a state where first + // socket is negotiated and first channel is pending + // so return only after we checked the info + if (isConnectingToDevice) + pendingsIt->second.emplace_back(PendingCb {name, std::move(cb), vid}); + else + sthis->pendingCbs_[deviceId] = {{name, std::move(cb), vid}}; + } + + // Check if already negotiated + CallbackId cbId(deviceId, vid); + if (auto info = sthis->getConnectedInfo(deviceId)) { + std::lock_guard<std::mutex> lk(info->mutex_); + if (info->socket_) { + if (sthis->config_->logger) + sthis->config_->logger->debug("Peer already connected to {}. Add a new channel", deviceId); + info->cbIds_.emplace(cbId); + sthis->sendChannelRequest(info->socket_, name, deviceId, vid); + return; + } + } + + if (isConnectingToDevice && !forceNewSocket) { + if (sthis->config_->logger) + sthis->config_->logger->debug("Already connecting to {}, wait for the ICE negotiation", deviceId); + return; + } + if (noNewSocket) { + // If no new socket is specified, we don't try to generate a new socket + for (const auto& pending : sthis->extractPendingCallbacks(deviceId, vid)) + pending.cb(nullptr, deviceId); + return; + } + + // Note: used when the ice negotiation fails to erase + // all stored structures. + auto eraseInfo = [w, cbId] { + if (auto shared = w.lock()) { + // If no new socket is specified, we don't try to generate a new socket + for (const auto& pending : shared->extractPendingCallbacks(cbId.first, cbId.second)) + pending.cb(nullptr, cbId.first); + std::lock_guard<std::mutex> lk(shared->infosMtx_); + shared->infos_.erase(cbId); + } + }; + + // If no socket exists, we need to initiate an ICE connection. + sthis->getIceOptions([w, + deviceId = std::move(deviceId), + devicePk = std::move(devicePk), + name = std::move(name), + cert = std::move(cert), + vid, + connType, + eraseInfo](auto&& ice_config) { + auto sthis = w.lock(); + if (!sthis) { + dht::ThreadPool::io().run([eraseInfo = std::move(eraseInfo)] { eraseInfo(); }); + return; + } + ice_config.tcpEnable = true; + ice_config.onInitDone = [w, + deviceId = std::move(deviceId), + devicePk = std::move(devicePk), + name = std::move(name), + cert = std::move(cert), + vid, + connType, + eraseInfo](bool ok) { + dht::ThreadPool::io().run([w = std::move(w), + devicePk = std::move(devicePk), + vid = std::move(vid), + eraseInfo, + connType, ok] { + auto sthis = w.lock(); + if (!ok && sthis && sthis->config_->logger) + sthis->config_->logger->error("Cannot initialize ICE session."); + if (!sthis || !ok) { + eraseInfo(); + return; + } + sthis->connectDeviceStartIce(devicePk, vid, connType, [=](bool ok) { + if (!ok) { + dht::ThreadPool::io().run([eraseInfo = std::move(eraseInfo)] { eraseInfo(); }); + } + }); + }); + }; + ice_config.onNegoDone = [w, + deviceId, + name, + cert = std::move(cert), + vid, + eraseInfo](bool ok) { + dht::ThreadPool::io().run([w = std::move(w), + deviceId = std::move(deviceId), + name = std::move(name), + cert = std::move(cert), + vid = std::move(vid), + eraseInfo = std::move(eraseInfo), + ok] { + auto sthis = w.lock(); + if (!ok && sthis && sthis->config_->logger) + sthis->config_->logger->error("ICE negotiation failed."); + if (!sthis || !ok || !sthis->connectDeviceOnNegoDone(deviceId, name, vid, cert)) + eraseInfo(); + }); + }; + + auto info = std::make_shared<ConnectionInfo>(); + { + std::lock_guard<std::mutex> lk(sthis->infosMtx_); + sthis->infos_[{deviceId, vid}] = info; + } + std::unique_lock<std::mutex> lk {info->mutex_}; + ice_config.master = false; + ice_config.streamsCount = 1; + ice_config.compCountPerStream = 1; + info->ice_ = sthis->iceFactory_.createUTransport(""); + if (!info->ice_) { + if (sthis->config_->logger) + sthis->config_->logger->error("Cannot initialize ICE session."); + eraseInfo(); + return; + } + // We need to detect any shutdown if the ice session is destroyed before going to the + // TLS session; + info->ice_->setOnShutdown([eraseInfo]() { + dht::ThreadPool::io().run([eraseInfo = std::move(eraseInfo)] { eraseInfo(); }); + }); + info->ice_->initIceInstance(ice_config); + }); + }); +} + +void +ConnectionManager::Impl::sendChannelRequest(std::shared_ptr<MultiplexedSocket>& sock, + const std::string& name, + const DeviceId& deviceId, + const dht::Value::Id& vid) +{ + auto channelSock = sock->addChannel(name); + channelSock->onShutdown([name, deviceId, vid, w = weak()] { + auto shared = w.lock(); + if (shared) + for (const auto& pending : shared->extractPendingCallbacks(deviceId, vid)) + pending.cb(nullptr, deviceId); + }); + channelSock->onReady( + [wSock = std::weak_ptr<ChannelSocket>(channelSock), name, deviceId, vid, w = weak()]() { + auto shared = w.lock(); + auto channelSock = wSock.lock(); + if (shared) + for (const auto& pending : shared->extractPendingCallbacks(deviceId, vid)) + pending.cb(channelSock, deviceId); + }); + + ChannelRequest val; + val.name = channelSock->name(); + val.state = ChannelRequestState::REQUEST; + val.channel = channelSock->channel(); + msgpack::sbuffer buffer(256); + msgpack::pack(buffer, val); + + std::error_code ec; + int res = sock->write(CONTROL_CHANNEL, + reinterpret_cast<const uint8_t*>(buffer.data()), + buffer.size(), + ec); + if (res < 0) { + // TODO check if we should handle errors here + if (config_->logger) + config_->logger->error("sendChannelRequest failed - error: {}", ec.message()); + } +} + +void +ConnectionManager::Impl::onPeerResponse(const PeerConnectionRequest& req) +{ + auto device = req.owner->getLongId(); + if (config_->logger) + config_->logger->debug("New response received from {}", device); + if (auto info = getInfo(device, req.id)) { + std::lock_guard<std::mutex> lk {info->mutex_}; + info->responseReceived_ = true; + info->response_ = std::move(req); + info->waitForAnswer_->expires_at(std::chrono::steady_clock::now()); + info->waitForAnswer_->async_wait(std::bind(&ConnectionManager::Impl::onResponse, + this, + std::placeholders::_1, + device, + req.id)); + } else { + if (config_->logger) + config_->logger->warn("Respond received, but cannot find request"); + } +} + +void +ConnectionManager::Impl::onDhtConnected(const dht::crypto::PublicKey& devicePk) +{ + if (!dht()) + return; + dht()->listen<PeerConnectionRequest>( + dht::InfoHash::get(PeerConnectionRequest::key_prefix + devicePk.getId().toString()), + [w = weak()](PeerConnectionRequest&& req) { + auto shared = w.lock(); + if (!shared) + return false; + if (shared->isMessageTreated(to_hex_string(req.id))) { + // Message already treated. Just ignore + return true; + } + if (req.isAnswer) { + if (shared->config_->logger) + shared->config_->logger->debug("Received request answer from {}", req.owner->getLongId()); + } else { + if (shared->config_->logger) + shared->config_->logger->debug("Received request from {}", req.owner->getLongId()); + } + if (req.isAnswer) { + shared->onPeerResponse(req); + } else { + // Async certificate checking + shared->dht()->findCertificate( + req.from, + [w, req = std::move(req)]( + const std::shared_ptr<dht::crypto::Certificate>& cert) mutable { + auto shared = w.lock(); + if (!shared) + return; + dht::InfoHash peer_h; + if (foundPeerDevice(cert, peer_h, shared->config_->logger)) { +#if TARGET_OS_IOS + if (shared->iOSConnectedCb_(req.connType, peer_h)) + return; +#endif + shared->onDhtPeerRequest(req, cert); + } else { + if (shared->config_->logger) + shared->config_->logger->warn( + "Received request from untrusted peer {}", + req.owner->getLongId()); + } + }); + } + + return true; + }, + dht::Value::UserTypeFilter("peer_request")); +} + +void +ConnectionManager::Impl::onTlsNegotiationDone(bool ok, + const DeviceId& deviceId, + const dht::Value::Id& vid, + const std::string& name) +{ + if (isDestroying_) + return; + // Note: only handle pendingCallbacks here for TLS initied by connectDevice() + // Note: if not initied by connectDevice() the channel name will be empty (because no channel + // asked yet) + auto isDhtRequest = name.empty(); + if (!ok) { + if (isDhtRequest) { + if (config_->logger) + config_->logger->error("TLS connection failure for peer {} - Initied by DHT request. channel: {} - vid: {}", + deviceId, + name, + vid); + if (connReadyCb_) + connReadyCb_(deviceId, "", nullptr); + } else { + if (config_->logger) + config_->logger->error("TLS connection failure for peer {} - Initied by connectDevice. channel: {} - vid: {}", + deviceId, + name, + vid); + for (const auto& pending : extractPendingCallbacks(deviceId)) + pending.cb(nullptr, deviceId); + } + } else { + // The socket is ready, store it + if (isDhtRequest) { + if (config_->logger) + config_->logger->debug("Connection to {} is ready - Initied by DHT request. Vid: {}", + deviceId, + vid); + } else { + if (config_->logger) + config_->logger->debug("Connection to {} is ready - Initied by connectDevice(). channel: {} - vid: {}", + deviceId, + name, + vid); + } + + auto info = getInfo(deviceId, vid); + addNewMultiplexedSocket({deviceId, vid}, info); + // Finally, open the channel and launch pending callbacks + if (info->socket_) { + // Note: do not remove pending there it's done in sendChannelRequest + for (const auto& pending : getPendingCallbacks(deviceId)) { + if (config_->logger) + config_->logger->debug("Send request on TLS socket for channel {} to {}", + pending.name, + deviceId); + sendChannelRequest(info->socket_, pending.name, deviceId, pending.vid); + } + } + } +} + +void +ConnectionManager::Impl::answerTo(IceTransport& ice, + const dht::Value::Id& id, + const std::shared_ptr<dht::crypto::PublicKey>& from) +{ + // NOTE: This is a shortest version of a real SDP message to save some bits + auto iceAttributes = ice.getLocalAttributes(); + std::ostringstream icemsg; + icemsg << iceAttributes.ufrag << "\n"; + icemsg << iceAttributes.pwd << "\n"; + for (const auto& addr : ice.getLocalCandidates(1)) { + icemsg << addr << "\n"; + } + + // Send PeerConnection response + PeerConnectionRequest val; + val.id = id; + val.ice_msg = icemsg.str(); + val.isAnswer = true; + auto value = std::make_shared<dht::Value>(std::move(val)); + value->user_type = "peer_request"; + + if (config_->logger) + config_->logger->debug("Connection accepted, DHT reply to {}", from->getLongId()); + dht()->putEncrypted(dht::InfoHash::get(PeerConnectionRequest::key_prefix + + from->getId().toString()), + from, + value, + [from,l=config_->logger](bool ok) { + if (l) + l->debug("Answer to connection request from {:s}. Put encrypted {:s}", + from->getLongId(), + (ok ? "ok" : "failed")); + }); +} + +bool +ConnectionManager::Impl::onRequestStartIce(const PeerConnectionRequest& req) +{ + auto deviceId = req.owner->getLongId(); + auto info = getInfo(deviceId, req.id); + if (!info) + return false; + + std::unique_lock<std::mutex> lk {info->mutex_}; + auto& ice = info->ice_; + if (!ice) { + if (config_->logger) + config_->logger->error("No ICE detected"); + if (connReadyCb_) + connReadyCb_(deviceId, "", nullptr); + return false; + } + + auto sdp = ice->parseIceCandidates(req.ice_msg); + answerTo(*ice, req.id, req.owner); + if (not ice->startIce({sdp.rem_ufrag, sdp.rem_pwd}, std::move(sdp.rem_candidates))) { + if (config_->logger) + config_->logger->error("Start ICE failed - fallback to TURN"); + ice = nullptr; + if (connReadyCb_) + connReadyCb_(deviceId, "", nullptr); + return false; + } + return true; +} + +bool +ConnectionManager::Impl::onRequestOnNegoDone(const PeerConnectionRequest& req) +{ + auto deviceId = req.owner->getLongId(); + auto info = getInfo(deviceId, req.id); + if (!info) + return false; + + std::unique_lock<std::mutex> lk {info->mutex_}; + auto& ice = info->ice_; + if (!ice) { + if (config_->logger) + config_->logger->error("No ICE detected"); + return false; + } + + // Build socket + auto endpoint = std::make_unique<IceSocketEndpoint>(std::shared_ptr<IceTransport>( + std::move(ice)), + false); + + // init TLS session + auto ph = req.from; + if (config_->logger) + config_->logger->debug("Start TLS session - Initied by DHT request. Device: {} - vid: {}", + req.from, + req.id); + info->tls_ = std::make_unique<TlsSocketEndpoint>( + std::move(endpoint), + certStore(), + identity(), + dhParams(), + [ph, w = weak()](const dht::crypto::Certificate& cert) { + auto shared = w.lock(); + if (!shared) + return false; + auto crt = shared->certStore().getCertificate(cert.getLongId().toString()); + if (!crt) + return false; + return crt->getPacked() == cert.getPacked(); + }); + + info->tls_->setOnReady( + [w = weak(), deviceId = std::move(deviceId), vid = std::move(req.id)](bool ok) { + if (auto shared = w.lock()) + shared->onTlsNegotiationDone(ok, deviceId, vid); + }); + return true; +} + +void +ConnectionManager::Impl::onDhtPeerRequest(const PeerConnectionRequest& req, + const std::shared_ptr<dht::crypto::Certificate>& /*cert*/) +{ + auto deviceId = req.owner->getLongId(); + if (config_->logger) + config_->logger->debug("New connection request from {}", deviceId); + if (!iceReqCb_ || !iceReqCb_(deviceId)) { + if (config_->logger) + config_->logger->debug("Refuse connection from {}", deviceId); + return; + } + + // Because the connection is accepted, create an ICE socket. + getIceOptions([w = weak(), req, deviceId](auto&& ice_config) { + auto shared = w.lock(); + if (!shared) + return; + // Note: used when the ice negotiation fails to erase + // all stored structures. + auto eraseInfo = [w, id = req.id, deviceId] { + if (auto shared = w.lock()) { + // If no new socket is specified, we don't try to generate a new socket + for (const auto& pending : shared->extractPendingCallbacks(deviceId, id)) + pending.cb(nullptr, deviceId); + if (shared->connReadyCb_) + shared->connReadyCb_(deviceId, "", nullptr); + std::lock_guard<std::mutex> lk(shared->infosMtx_); + shared->infos_.erase({deviceId, id}); + } + }; + + ice_config.tcpEnable = true; + ice_config.onInitDone = [w, req, eraseInfo](bool ok) { + auto shared = w.lock(); + if (!shared) + return; + if (!ok) { + if (shared->config_->logger) + shared->config_->logger->error("Cannot initialize ICE session."); + dht::ThreadPool::io().run([eraseInfo = std::move(eraseInfo)] { eraseInfo(); }); + return; + } + + dht::ThreadPool::io().run( + [w = std::move(w), req = std::move(req), eraseInfo = std::move(eraseInfo)] { + auto shared = w.lock(); + if (!shared) + return; + if (!shared->onRequestStartIce(req)) + eraseInfo(); + }); + }; + + ice_config.onNegoDone = [w, req, eraseInfo](bool ok) { + auto shared = w.lock(); + if (!shared) + return; + if (!ok) { + if (shared->config_->logger) + shared->config_->logger->error("ICE negotiation failed."); + dht::ThreadPool::io().run([eraseInfo = std::move(eraseInfo)] { eraseInfo(); }); + return; + } + + dht::ThreadPool::io().run( + [w = std::move(w), req = std::move(req), eraseInfo = std::move(eraseInfo)] { + if (auto shared = w.lock()) + if (!shared->onRequestOnNegoDone(req)) + eraseInfo(); + }); + }; + + // Negotiate a new ICE socket + auto info = std::make_shared<ConnectionInfo>(); + { + std::lock_guard<std::mutex> lk(shared->infosMtx_); + shared->infos_[{deviceId, req.id}] = info; + } + if (shared->config_->logger) + shared->config_->logger->debug("Accepting connection from {}", deviceId); + std::unique_lock<std::mutex> lk {info->mutex_}; + ice_config.streamsCount = 1; + ice_config.compCountPerStream = 1; // TCP + ice_config.master = true; + info->ice_ = shared->iceFactory_.createUTransport(""); + if (not info->ice_) { + if (shared->config_->logger) + shared->config_->logger->error("Cannot initialize ICE session"); + eraseInfo(); + return; + } + // We need to detect any shutdown if the ice session is destroyed before going to the TLS session; + info->ice_->setOnShutdown([eraseInfo]() { + dht::ThreadPool::io().run([eraseInfo = std::move(eraseInfo)] { eraseInfo(); }); + }); + info->ice_->initIceInstance(ice_config); + }); +} + +void +ConnectionManager::Impl::addNewMultiplexedSocket(const CallbackId& id, const std::shared_ptr<ConnectionInfo>& info) +{ + info->socket_ = std::make_shared<MultiplexedSocket>(config_->ioContext, id.first, std::move(info->tls_)); + info->socket_->setOnReady( + [w = weak()](const DeviceId& deviceId, const std::shared_ptr<ChannelSocket>& socket) { + if (auto sthis = w.lock()) + if (sthis->connReadyCb_) + sthis->connReadyCb_(deviceId, socket->name(), socket); + }); + info->socket_->setOnRequest([w = weak()](const std::shared_ptr<dht::crypto::Certificate>& peer, + const uint16_t&, + const std::string& name) { + if (auto sthis = w.lock()) + if (sthis->channelReqCb_) + return sthis->channelReqCb_(peer, name); + return false; + }); + info->socket_->onShutdown([w = weak(), deviceId=id.first, vid=id.second]() { + // Cancel current outgoing connections + dht::ThreadPool::io().run([w, deviceId, vid] { + auto sthis = w.lock(); + if (!sthis) + return; + + std::set<CallbackId> ids; + if (auto info = sthis->getInfo(deviceId, vid)) { + std::lock_guard<std::mutex> lk(info->mutex_); + if (info->socket_) { + ids = std::move(info->cbIds_); + info->socket_->shutdown(); + } + } + for (const auto& cbId : ids) + for (const auto& pending : sthis->extractPendingCallbacks(cbId.first, cbId.second)) + pending.cb(nullptr, deviceId); + + std::lock_guard<std::mutex> lk(sthis->infosMtx_); + sthis->infos_.erase({deviceId, vid}); + }); + }); +} + +const std::shared_future<tls::DhParams> +ConnectionManager::Impl::dhParams() const +{ + return dht::ThreadPool::computation().get<tls::DhParams>( + std::bind(tls::DhParams::loadDhParams, config_->cachePath + DIR_SEPARATOR_STR "dhParams")); + ; +} + +template<typename ID = dht::Value::Id> +std::set<ID, std::less<>> +loadIdList(const std::string& path) +{ + std::set<ID, std::less<>> ids; + std::ifstream file = fileutils::ifstream(path); + if (!file.is_open()) { + //JAMI_DBG("Could not load %s", path.c_str()); + return ids; + } + std::string line; + while (std::getline(file, line)) { + if constexpr (std::is_same<ID, std::string>::value) { + ids.emplace(std::move(line)); + } else if constexpr (std::is_integral<ID>::value) { + ID vid; + if (auto [p, ec] = std::from_chars(line.data(), line.data() + line.size(), vid, 16); + ec == std::errc()) { + ids.emplace(vid); + } + } + } + return ids; +} + +template<typename List = std::set<dht::Value::Id>> +void +saveIdList(const std::string& path, const List& ids) +{ + std::ofstream file = fileutils::ofstream(path, std::ios::trunc | std::ios::binary); + if (!file.is_open()) { + //JAMI_ERR("Could not save to %s", path.c_str()); + return; + } + for (auto& c : ids) + file << std::hex << c << "\n"; +} + +void +ConnectionManager::Impl::loadTreatedMessages() +{ + std::lock_guard<std::mutex> lock(messageMutex_); + auto path = config_->cachePath + DIR_SEPARATOR_STR "treatedMessages"; + treatedMessages_ = loadIdList<std::string>(path); + if (treatedMessages_.empty()) { + auto messages = loadIdList(path); + for (const auto& m : messages) + treatedMessages_.emplace(to_hex_string(m)); + } +} + +void +ConnectionManager::Impl::saveTreatedMessages() const +{ + dht::ThreadPool::io().run([w = weak()]() { + if (auto sthis = w.lock()) { + auto& this_ = *sthis; + std::lock_guard<std::mutex> lock(this_.messageMutex_); + fileutils::check_dir(this_.config_->cachePath.c_str()); + saveIdList<decltype(this_.treatedMessages_)>(this_.config_->cachePath + + DIR_SEPARATOR_STR "treatedMessages", + this_.treatedMessages_); + } + }); +} + +bool +ConnectionManager::Impl::isMessageTreated(std::string_view id) +{ + std::lock_guard<std::mutex> lock(messageMutex_); + auto res = treatedMessages_.emplace(id); + if (res.second) { + saveTreatedMessages(); + return false; + } + return true; +} + +/** + * returns whether or not UPnP is enabled and active_ + * ie: if it is able to make port mappings + */ +bool +ConnectionManager::Impl::getUPnPActive() const +{ + return config_->getUPnPActive(); +} + +IpAddr +ConnectionManager::Impl::getPublishedIpAddress(uint16_t family) const +{ + if (family == AF_INET) + return publishedIp_[0]; + if (family == AF_INET6) + return publishedIp_[1]; + + assert(family == AF_UNSPEC); + + // If family is not set, prefere IPv4 if available. It's more + // likely to succeed behind NAT. + if (publishedIp_[0]) + return publishedIp_[0]; + if (publishedIp_[1]) + return publishedIp_[1]; + return {}; +} + +void +ConnectionManager::Impl::setPublishedAddress(const IpAddr& ip_addr) +{ + if (ip_addr.getFamily() == AF_INET) { + publishedIp_[0] = ip_addr; + } else { + publishedIp_[1] = ip_addr; + } +} + +void +ConnectionManager::Impl::storeActiveIpAddress(std::function<void()>&& cb) +{ + dht()->getPublicAddress([this, cb = std::move(cb)](std::vector<dht::SockAddr>&& results) { + bool hasIpv4 {false}, hasIpv6 {false}; + for (auto& result : results) { + auto family = result.getFamily(); + if (family == AF_INET) { + if (not hasIpv4) { + hasIpv4 = true; + if (config_->logger) + config_->logger->debug("Store DHT public IPv4 address: {}", result); + //JAMI_DBG("Store DHT public IPv4 address : %s", result.toString().c_str()); + setPublishedAddress(*result.get()); + if (config_->upnpCtrl) { + config_->upnpCtrl->setPublicAddress(*result.get()); + } + } + } else if (family == AF_INET6) { + if (not hasIpv6) { + hasIpv6 = true; + if (config_->logger) + config_->logger->debug("Store DHT public IPv6 address: {}", result); + setPublishedAddress(*result.get()); + } + } + if (hasIpv4 and hasIpv6) + break; + } + if (cb) + cb(); + }); +} + +void +ConnectionManager::Impl::getIceOptions(std::function<void(IceTransportOptions&&)> cb) noexcept +{ + storeActiveIpAddress([this, cb = std::move(cb)] { + IceTransportOptions opts = ConnectionManager::Impl::getIceOptions(); + auto publishedAddr = getPublishedIpAddress(); + + if (publishedAddr) { + auto interfaceAddr = ip_utils::getInterfaceAddr(getLocalInterface(), + publishedAddr.getFamily()); + if (interfaceAddr) { + opts.accountLocalAddr = interfaceAddr; + opts.accountPublicAddr = publishedAddr; + } + } + if (cb) + cb(std::move(opts)); + }); +} + +IceTransportOptions +ConnectionManager::Impl::getIceOptions() const noexcept +{ + IceTransportOptions opts; + opts.upnpEnable = getUPnPActive(); + + if (config_->stunEnabled) + opts.stunServers.emplace_back(StunServerInfo().setUri(config_->stunServer)); + if (config_->turnEnabled) { + auto cached = false; + std::lock_guard<std::mutex> lk(config_->cachedTurnMutex); + cached = config_->cacheTurnV4 || config_->cacheTurnV6; + if (config_->cacheTurnV4) { + opts.turnServers.emplace_back(TurnServerInfo() + .setUri(config_->cacheTurnV4.toString()) + .setUsername(config_->turnServerUserName) + .setPassword(config_->turnServerPwd) + .setRealm(config_->turnServerRealm)); + } + // NOTE: first test with ipv6 turn was not concluant and resulted in multiple + // co issues. So this needs some debug. for now just disable + // if (cacheTurnV6 && *cacheTurnV6) { + // opts.turnServers.emplace_back(TurnServerInfo() + // .setUri(cacheTurnV6->toString(true)) + // .setUsername(turnServerUserName_) + // .setPassword(turnServerPwd_) + // .setRealm(turnServerRealm_)); + //} + // Nothing cached, so do the resolution + if (!cached) { + opts.turnServers.emplace_back(TurnServerInfo() + .setUri(config_->turnServer) + .setUsername(config_->turnServerUserName) + .setPassword(config_->turnServerPwd) + .setRealm(config_->turnServerRealm)); + } + } + return opts; +} + +bool +ConnectionManager::Impl::foundPeerDevice(const std::shared_ptr<dht::crypto::Certificate>& crt, + dht::InfoHash& account_id, + const std::shared_ptr<Logger>& logger) +{ + if (not crt) + return false; + + auto top_issuer = crt; + while (top_issuer->issuer) + top_issuer = top_issuer->issuer; + + // Device certificate can't be self-signed + if (top_issuer == crt) { + if (logger) + logger->warn("Found invalid peer device: {}", crt->getLongId()); + return false; + } + + // Check peer certificate chain + // Trust store with top issuer as the only CA + dht::crypto::TrustList peer_trust; + peer_trust.add(*top_issuer); + if (not peer_trust.verify(*crt)) { + if (logger) + logger->warn("Found invalid peer device: {}", crt->getLongId()); + return false; + } + + // Check cached OCSP response + if (crt->ocspResponse and crt->ocspResponse->getCertificateStatus() != GNUTLS_OCSP_CERT_GOOD) { + if (logger) + logger->error("Certificate %s is disabled by cached OCSP response", crt->getLongId()); + return false; + } + + account_id = crt->issuer->getId(); + if (logger) + logger->warn("Found peer device: {} account:{} CA:{}", + crt->getLongId(), + account_id, + top_issuer->getId()); + return true; +} + +bool +ConnectionManager::Impl::findCertificate( + const dht::PkId& id, std::function<void(const std::shared_ptr<dht::crypto::Certificate>&)>&& cb) +{ + if (auto cert = certStore().getCertificate(id.toString())) { + if (cb) + cb(cert); + } else if (cb) + cb(nullptr); + return true; +} + +ConnectionManager::ConnectionManager(std::shared_ptr<ConnectionManager::Config> config_) + : pimpl_ {std::make_shared<Impl>(config_)} +{} + +ConnectionManager::~ConnectionManager() +{ + if (pimpl_) + pimpl_->shutdown(); +} + +void +ConnectionManager::connectDevice(const DeviceId& deviceId, + const std::string& name, + ConnectCallback cb, + bool noNewSocket, + bool forceNewSocket, + const std::string& connType) +{ + pimpl_->connectDevice(deviceId, name, std::move(cb), noNewSocket, forceNewSocket, connType); +} + +void +ConnectionManager::connectDevice(const std::shared_ptr<dht::crypto::Certificate>& cert, + const std::string& name, + ConnectCallback cb, + bool noNewSocket, + bool forceNewSocket, + const std::string& connType) +{ + pimpl_->connectDevice(cert, name, std::move(cb), noNewSocket, forceNewSocket, connType); +} + +bool +ConnectionManager::isConnecting(const DeviceId& deviceId, const std::string& name) const +{ + auto pending = pimpl_->getPendingCallbacks(deviceId); + return std::find_if(pending.begin(), pending.end(), [&](auto p) { return p.name == name; }) + != pending.end(); +} + +void +ConnectionManager::closeConnectionsWith(const std::string& peerUri) +{ + std::vector<std::shared_ptr<ConnectionInfo>> connInfos; + std::set<DeviceId> peersDevices; + { + std::lock_guard<std::mutex> lk(pimpl_->infosMtx_); + for (auto iter = pimpl_->infos_.begin(); iter != pimpl_->infos_.end();) { + auto const& [key, value] = *iter; + auto deviceId = key.first; + auto cert = pimpl_->certStore().getCertificate(deviceId.toString()); + if (cert && cert->issuer && peerUri == cert->issuer->getId().toString()) { + connInfos.emplace_back(value); + peersDevices.emplace(deviceId); + iter = pimpl_->infos_.erase(iter); + } else { + iter++; + } + } + } + // Stop connections to all peers devices + for (const auto& deviceId : peersDevices) { + for (const auto& pending : pimpl_->extractPendingCallbacks(deviceId)) + pending.cb(nullptr, deviceId); + // This will close the TLS Session + pimpl_->removeUnusedConnections(deviceId); + } + for (auto& info : connInfos) { + if (info->socket_) + info->socket_->shutdown(); + if (info->waitForAnswer_) + info->waitForAnswer_->cancel(); + if (info->ice_) { + std::unique_lock<std::mutex> lk {info->mutex_}; + dht::ThreadPool::io().run( + [ice = std::shared_ptr<IceTransport>(std::move(info->ice_))] {}); + } + } +} + +void +ConnectionManager::onDhtConnected(const dht::crypto::PublicKey& devicePk) +{ + pimpl_->onDhtConnected(devicePk); +} + +void +ConnectionManager::onICERequest(onICERequestCallback&& cb) +{ + pimpl_->iceReqCb_ = std::move(cb); +} + +void +ConnectionManager::onChannelRequest(ChannelRequestCallback&& cb) +{ + pimpl_->channelReqCb_ = std::move(cb); +} + +void +ConnectionManager::onConnectionReady(ConnectionReadyCallback&& cb) +{ + pimpl_->connReadyCb_ = std::move(cb); +} + +void +ConnectionManager::oniOSConnected(iOSConnectedCallback&& cb) +{ + pimpl_->iOSConnectedCb_ = std::move(cb); +} + +std::size_t +ConnectionManager::activeSockets() const +{ + std::lock_guard<std::mutex> lk(pimpl_->infosMtx_); + return pimpl_->infos_.size(); +} + +void +ConnectionManager::monitor() const +{ + std::lock_guard<std::mutex> lk(pimpl_->infosMtx_); + auto logger = pimpl_->config_->logger; + if (!logger) + return; + logger->debug("ConnectionManager current status:"); + for (const auto& [_, ci] : pimpl_->infos_) { + if (ci->socket_) + ci->socket_->monitor(); + } + logger->debug("ConnectionManager end status."); +} + +void +ConnectionManager::connectivityChanged() +{ + std::lock_guard<std::mutex> lk(pimpl_->infosMtx_); + for (const auto& [_, ci] : pimpl_->infos_) { + if (ci->socket_) + ci->socket_->sendBeacon(); + } +} + +void +ConnectionManager::getIceOptions(std::function<void(IceTransportOptions&&)> cb) noexcept +{ + return pimpl_->getIceOptions(std::move(cb)); +} + +IceTransportOptions +ConnectionManager::getIceOptions() const noexcept +{ + return pimpl_->getIceOptions(); +} + +IpAddr +ConnectionManager::getPublishedIpAddress(uint16_t family) const +{ + return pimpl_->getPublishedIpAddress(family); +} + +void +ConnectionManager::setPublishedAddress(const IpAddr& ip_addr) +{ + return pimpl_->setPublishedAddress(ip_addr); +} + +void +ConnectionManager::storeActiveIpAddress(std::function<void()>&& cb) +{ + return pimpl_->storeActiveIpAddress(std::move(cb)); +} + +std::shared_ptr<ConnectionManager::Config> +ConnectionManager::getConfig() +{ + return pimpl_->config_; +} + +} // namespace jami diff --git a/src/fileutils.cpp b/src/fileutils.cpp new file mode 100644 index 0000000..be911a6 --- /dev/null +++ b/src/fileutils.cpp @@ -0,0 +1,878 @@ +/* + * Copyright (C) 2004-2023 Savoir-faire Linux Inc. + * + * Author: Rafaël Carré <rafael.carre@savoirfairelinux.com> + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ + +//#include "logger.h" +#include "fileutils.h" +//#include "archiver.h" +//#include "compiler_intrinsics.h" +#include <opendht/crypto.h> + +#ifdef RING_UWP +#include <io.h> // for access and close +#include "ring_signal.h" +#endif + +#ifdef __APPLE__ +#include <TargetConditionals.h> +#endif + +#if defined(__ANDROID__) || (defined(TARGET_OS_IOS) && TARGET_OS_IOS) +#include "client/ring_signal.h" +#endif + +#ifdef _WIN32 +#include <windows.h> +#include "string_utils.h" +#endif + +#include <sys/types.h> +#include <sys/stat.h> + +#ifndef _MSC_VER +#include <libgen.h> +#endif + +#ifdef _MSC_VER +#include "windirent.h" +#else +#include <dirent.h> +#endif + +#include <signal.h> +#include <unistd.h> +#include <fcntl.h> +#ifndef _WIN32 +#include <pwd.h> +#else +#include <shlobj.h> +#define NAME_MAX 255 +#endif +#if !defined __ANDROID__ && !defined _WIN32 +#include <wordexp.h> +#endif + +#include <nettle/sha3.h> + +#include <sstream> +#include <fstream> +#include <iostream> +#include <stdexcept> +#include <limits> +#include <array> + +#include <cstdlib> +#include <cstring> +#include <cerrno> +#include <cstddef> +#include <ciso646> + +#include <pj/ctype.h> +#include <pjlib-util/md5.h> + +#include <filesystem> + +#define PIDFILE ".ring.pid" +#define ERASE_BLOCK 4096 + +namespace jami { +namespace fileutils { + +// returns true if directory exists +bool +check_dir(const char* path, [[maybe_unused]] mode_t dirmode, mode_t parentmode) +{ + DIR* dir = opendir(path); + + if (!dir) { // doesn't exist + if (not recursive_mkdir(path, parentmode)) { + perror(path); + return false; + } +#ifndef _WIN32 + if (chmod(path, dirmode) < 0) { + //JAMI_ERR("fileutils::check_dir(): chmod() failed on '%s', %s", path, strerror(errno)); + return false; + } +#endif + } else + closedir(dir); + return true; +} + +std::string +expand_path(const std::string& path) +{ +#if defined __ANDROID__ || defined _MSC_VER || defined WIN32 || defined __APPLE__ + //JAMI_ERR("Path expansion not implemented, returning original"); + return path; +#else + + std::string result; + + wordexp_t p; + int ret = wordexp(path.c_str(), &p, 0); + + switch (ret) { + case WRDE_BADCHAR: + JAMI_ERR("Illegal occurrence of newline or one of |, &, ;, <, >, " + "(, ), {, }."); + return result; + case WRDE_BADVAL: + JAMI_ERR("An undefined shell variable was referenced"); + return result; + case WRDE_CMDSUB: + JAMI_ERR("Command substitution occurred"); + return result; + case WRDE_SYNTAX: + JAMI_ERR("Shell syntax error"); + return result; + case WRDE_NOSPACE: + JAMI_ERR("Out of memory."); + // This is the only error where we must call wordfree + break; + default: + if (p.we_wordc > 0) + result = std::string(p.we_wordv[0]); + break; + } + + wordfree(&p); + + return result; +#endif +} + +std::mutex& +getFileLock(const std::string& path) +{ + static std::mutex fileLockLock {}; + static std::map<std::string, std::mutex> fileLocks {}; + + std::lock_guard<std::mutex> l(fileLockLock); + return fileLocks[path]; +} + +bool +isFile(const std::string& path, bool resolveSymlink) +{ + if (path.empty()) + return false; +#ifdef _WIN32 + if (resolveSymlink) { + struct _stat64i32 s; + if (_wstat(jami::to_wstring(path).c_str(), &s) == 0) + return S_ISREG(s.st_mode); + } else { + DWORD attr = GetFileAttributes(jami::to_wstring(path).c_str()); + if ((attr != INVALID_FILE_ATTRIBUTES) && !(attr & FILE_ATTRIBUTE_DIRECTORY) + && !(attr & FILE_ATTRIBUTE_REPARSE_POINT)) + return true; + } +#else + if (resolveSymlink) { + struct stat s; + if (stat(path.c_str(), &s) == 0) + return S_ISREG(s.st_mode); + } else { + struct stat s; + if (lstat(path.c_str(), &s) == 0) + return S_ISREG(s.st_mode); + } +#endif + + return false; +} + +bool +isDirectory(const std::string& path) +{ + struct stat s; + if (stat(path.c_str(), &s) == 0) + return s.st_mode & S_IFDIR; + return false; +} + +bool +isDirectoryWritable(const std::string& directory) +{ + return accessFile(directory, W_OK) == 0; +} + +bool +hasHardLink(const std::string& path) +{ +#ifndef _WIN32 + struct stat s; + if (lstat(path.c_str(), &s) == 0) + return s.st_nlink > 1; +#endif + return false; +} + +bool +isSymLink(const std::string& path) +{ +#ifndef _WIN32 + struct stat s; + if (lstat(path.c_str(), &s) == 0) + return S_ISLNK(s.st_mode); +#elif !defined(_MSC_VER) + DWORD attr = GetFileAttributes(jami::to_wstring(path).c_str()); + if (attr & FILE_ATTRIBUTE_REPARSE_POINT) + return true; +#endif + return false; +} + +std::chrono::system_clock::time_point +writeTime(const std::string& path) +{ +#ifndef _WIN32 + struct stat s; + auto ret = stat(path.c_str(), &s); + if (ret) + throw std::runtime_error("Can't check write time for: " + path); + return std::chrono::system_clock::from_time_t(s.st_mtime); +#else +#if RING_UWP + _CREATEFILE2_EXTENDED_PARAMETERS ext_params = {0}; + ext_params.dwSize = sizeof(CREATEFILE2_EXTENDED_PARAMETERS); + ext_params.dwFileAttributes = FILE_ATTRIBUTE_NORMAL; + ext_params.dwFileFlags = FILE_FLAG_NO_BUFFERING; + ext_params.dwSecurityQosFlags = SECURITY_ANONYMOUS; + ext_params.lpSecurityAttributes = nullptr; + ext_params.hTemplateFile = nullptr; + HANDLE h = CreateFile2(jami::to_wstring(path).c_str(), + GENERIC_READ, + FILE_SHARE_READ, + OPEN_EXISTING, + &ext_params); +#elif _WIN32 + HANDLE h = CreateFileW(jami::to_wstring(path).c_str(), + GENERIC_READ, + FILE_SHARE_READ, + nullptr, + OPEN_EXISTING, + FILE_ATTRIBUTE_NORMAL, + nullptr); +#endif + if (h == INVALID_HANDLE_VALUE) + throw std::runtime_error("Can't open: " + path); + FILETIME lastWriteTime; + if (!GetFileTime(h, nullptr, nullptr, &lastWriteTime)) + throw std::runtime_error("Can't check write time for: " + path); + CloseHandle(h); + SYSTEMTIME sTime; + if (!FileTimeToSystemTime(&lastWriteTime, &sTime)) + throw std::runtime_error("Can't check write time for: " + path); + struct tm tm + {}; + tm.tm_year = sTime.wYear - 1900; + tm.tm_mon = sTime.wMonth - 1; + tm.tm_mday = sTime.wDay; + tm.tm_hour = sTime.wHour; + tm.tm_min = sTime.wMinute; + tm.tm_sec = sTime.wSecond; + tm.tm_isdst = -1; + return std::chrono::system_clock::from_time_t(mktime(&tm)); +#endif +} + +bool +createSymlink(const std::string& linkFile, const std::string& target) +{ + try { + std::filesystem::create_symlink(target, linkFile); + } catch (const std::exception& e) { + //JAMI_ERR("Couldn't create soft link: %s", e.what()); + return false; + } + return true; +} + +bool +createHardlink(const std::string& linkFile, const std::string& target) +{ + try { + std::filesystem::create_hard_link(target, linkFile); + } catch (const std::exception& e) { + //JAMI_ERR("Couldn't create hard link: %s", e.what()); + return false; + } + return true; +} + +void +createFileLink(const std::string& linkFile, const std::string& target, bool hard) +{ + if (not hard or not createHardlink(linkFile, target)) + createSymlink(linkFile, target); +} + +std::string_view +getFileExtension(std::string_view filename) +{ + std::string_view result; + auto sep = filename.find_last_of('.'); + if (sep != std::string_view::npos && sep != filename.size() - 1) + result = filename.substr(sep + 1); + if (result.size() >= 8) + return {}; + return result; +} + +bool +isPathRelative(const std::string& path) +{ +#ifndef _WIN32 + return not path.empty() and not(path[0] == '/'); +#else + return not path.empty() and path.find(":") == std::string::npos; +#endif +} + +std::string +getCleanPath(const std::string& base, const std::string& path) +{ + if (base.empty() or path.size() < base.size()) + return path; + auto base_sep = base + DIR_SEPARATOR_STR; + if (path.compare(0, base_sep.size(), base_sep) == 0) + return path.substr(base_sep.size()); + else + return path; +} + +std::string +getFullPath(const std::string& base, const std::string& path) +{ + bool isRelative {not base.empty() and isPathRelative(path)}; + return isRelative ? base + DIR_SEPARATOR_STR + path : path; +} + +std::vector<uint8_t> +loadFile(const std::string& path, const std::string& default_dir) +{ + std::vector<uint8_t> buffer; + std::ifstream file = ifstream(getFullPath(default_dir, path), std::ios::binary); + if (!file) + throw std::runtime_error("Can't read file: " + path); + file.seekg(0, std::ios::end); + auto size = file.tellg(); + if (size > std::numeric_limits<unsigned>::max()) + throw std::runtime_error("File is too big: " + path); + buffer.resize(size); + file.seekg(0, std::ios::beg); + if (!file.read((char*) buffer.data(), size)) + throw std::runtime_error("Can't load file: " + path); + return buffer; +} + +std::string +loadTextFile(const std::string& path, const std::string& default_dir) +{ + std::string buffer; + std::ifstream file = ifstream(getFullPath(default_dir, path)); + if (!file) + throw std::runtime_error("Can't read file: " + path); + file.seekg(0, std::ios::end); + auto size = file.tellg(); + if (size > std::numeric_limits<unsigned>::max()) + throw std::runtime_error("File is too big: " + path); + buffer.resize(size); + file.seekg(0, std::ios::beg); + if (!file.read((char*) buffer.data(), size)) + throw std::runtime_error("Can't load file: " + path); + return buffer; +} + +void +saveFile(const std::string& path, const uint8_t* data, size_t data_size, [[maybe_unused]] mode_t mode) +{ + std::ofstream file = fileutils::ofstream(path, std::ios::trunc | std::ios::binary); + if (!file.is_open()) { + //JAMI_ERR("Could not write data to %s", path.c_str()); + return; + } + file.write((char*) data, data_size); +#ifndef _WIN32 + if (chmod(path.c_str(), mode) < 0) + /*JAMI_WARN("fileutils::saveFile(): chmod() failed on '%s', %s", + path.c_str(), + strerror(errno))*/; +#endif +} + +std::vector<uint8_t> +loadCacheFile(const std::string& path, std::chrono::system_clock::duration maxAge) +{ + // writeTime throws exception if file doesn't exist + auto duration = std::chrono::system_clock::now() - writeTime(path); + if (duration > maxAge) + throw std::runtime_error("file too old"); + + //JAMI_DBG("Loading cache file '%.*s'", (int) path.size(), path.c_str()); + return loadFile(path); +} + +std::string +loadCacheTextFile(const std::string& path, std::chrono::system_clock::duration maxAge) +{ + // writeTime throws exception if file doesn't exist + auto duration = std::chrono::system_clock::now() - writeTime(path); + if (duration > maxAge) + throw std::runtime_error("file too old"); + + //JAMI_DBG("Loading cache file '%.*s'", (int) path.size(), path.c_str()); + return loadTextFile(path); +} + +static size_t +dirent_buf_size([[maybe_unused]] DIR* dirp) +{ + long name_max; +#if defined(HAVE_FPATHCONF) && defined(HAVE_DIRFD) && defined(_PC_NAME_MAX) + name_max = fpathconf(dirfd(dirp), _PC_NAME_MAX); + if (name_max == -1) +#if defined(NAME_MAX) + name_max = (NAME_MAX > 255) ? NAME_MAX : 255; +#else + return (size_t) (-1); +#endif +#else +#if defined(NAME_MAX) + name_max = (NAME_MAX > 255) ? NAME_MAX : 255; +#else +#error "buffer size for readdir_r cannot be determined" +#endif +#endif + size_t name_end = (size_t) offsetof(struct dirent, d_name) + name_max + 1; + return name_end > sizeof(struct dirent) ? name_end : sizeof(struct dirent); +} + +std::vector<std::string> +readDirectory(const std::string& dir) +{ + DIR* dp = opendir(dir.c_str()); + if (!dp) + return {}; + + size_t size = dirent_buf_size(dp); + if (size == (size_t) (-1)) + return {}; + std::vector<uint8_t> buf(size); + dirent* entry; + + std::vector<std::string> files; +#ifndef _WIN32 + while (!readdir_r(dp, reinterpret_cast<dirent*>(buf.data()), &entry) && entry) { +#else + while ((entry = readdir(dp)) != nullptr) { +#endif + std::string fname {entry->d_name}; + if (fname == "." || fname == "..") + continue; + files.emplace_back(std::move(fname)); + } + closedir(dp); + return files; +} // namespace fileutils + +/* +std::vector<uint8_t> +readArchive(const std::string& path, const std::string& pwd) +{ + JAMI_DBG("Reading archive from %s", path.c_str()); + + auto isUnencryptedGzip = [](const std::vector<uint8_t>& data) { + // NOTE: some webserver modify gzip files and this can end with a gunzip in a gunzip + // file. So, to make the readArchive more robust, we can support this case by detecting + // gzip header via 1f8b 08 + // We don't need to support more than 2 level, else somebody may be able to send + // gunzip in loops and abuse. + return data.size() > 3 && data[0] == 0x1f && data[1] == 0x8b && data[2] == 0x08; + }; + + auto decompress = [](std::vector<uint8_t>& data) { + try { + data = archiver::decompress(data); + } catch (const std::exception& e) { + JAMI_ERR("Error decrypting archive: %s", e.what()); + throw e; + } + }; + + std::vector<uint8_t> data; + // Read file + try { + data = loadFile(path); + } catch (const std::exception& e) { + JAMI_ERR("Error loading archive: %s", e.what()); + throw e; + } + + if (isUnencryptedGzip(data)) { + if (!pwd.empty()) + JAMI_WARN() << "A gunzip in a gunzip is detected. A webserver may have a bad config"; + + decompress(data); + } + + if (!pwd.empty()) { + // Decrypt + try { + data = dht::crypto::aesDecrypt(data, pwd); + } catch (const std::exception& e) { + JAMI_ERR("Error decrypting archive: %s", e.what()); + throw e; + } + decompress(data); + } else if (isUnencryptedGzip(data)) { + JAMI_WARN() << "A gunzip in a gunzip is detected. A webserver may have a bad config"; + decompress(data); + } + return data; +} + +void +writeArchive(const std::string& archive_str, const std::string& path, const std::string& password) +{ + JAMI_DBG("Writing archive to %s", path.c_str()); + + if (not password.empty()) { + // Encrypt using provided password + std::vector<uint8_t> data = dht::crypto::aesEncrypt(archiver::compress(archive_str), + password); + // Write + try { + saveFile(path, data); + } catch (const std::runtime_error& ex) { + JAMI_ERR("Export failed: %s", ex.what()); + return; + } + } else { + JAMI_WARN("Unsecured archiving (no password)"); + archiver::compressGzip(archive_str, path); + } +}*/ + +bool +recursive_mkdir(const std::string& path, mode_t mode) +{ +#ifndef _WIN32 + if (mkdir(path.data(), mode) != 0) { +#else + if (_wmkdir(jami::to_wstring(path.data()).c_str()) != 0) { +#endif + if (errno == ENOENT) { + recursive_mkdir(path.substr(0, path.find_last_of(DIR_SEPARATOR_CH)), mode); +#ifndef _WIN32 + if (mkdir(path.data(), mode) != 0) { +#else + if (_wmkdir(jami::to_wstring(path.data()).c_str()) != 0) { +#endif + //JAMI_ERR("Could not create directory."); + return false; + } + } + } // namespace jami + return true; +} + +#ifdef _WIN32 +bool +eraseFile_win32(const std::string& path, bool dosync) +{ + HANDLE h + = CreateFileA(path.c_str(), GENERIC_WRITE, 0, 0, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, 0); + if (h == INVALID_HANDLE_VALUE) { + JAMI_WARN("Can not open file %s for erasing.", path.c_str()); + return false; + } + + LARGE_INTEGER size; + if (!GetFileSizeEx(h, &size)) { + JAMI_WARN("Can not erase file %s: GetFileSizeEx() failed.", path.c_str()); + CloseHandle(h); + return false; + } + if (size.QuadPart == 0) { + CloseHandle(h); + return false; + } + + uint64_t size_blocks = size.QuadPart / ERASE_BLOCK; + if (size.QuadPart % ERASE_BLOCK) + size_blocks++; + + char* buffer; + try { + buffer = new char[ERASE_BLOCK]; + } catch (std::bad_alloc& ba) { + JAMI_WARN("Can not allocate buffer for erasing %s.", path.c_str()); + CloseHandle(h); + return false; + } + memset(buffer, 0x00, ERASE_BLOCK); + + OVERLAPPED ovlp; + if (size.QuadPart < (1024 - 42)) { // a small file can be stored in the MFT record + ovlp.Offset = 0; + ovlp.OffsetHigh = 0; + WriteFile(h, buffer, (DWORD) size.QuadPart, 0, &ovlp); + FlushFileBuffers(h); + } + for (uint64_t i = 0; i < size_blocks; i++) { + uint64_t offset = i * ERASE_BLOCK; + ovlp.Offset = offset & 0x00000000FFFFFFFF; + ovlp.OffsetHigh = offset >> 32; + WriteFile(h, buffer, ERASE_BLOCK, 0, &ovlp); + } + + delete[] buffer; + + if (dosync) + FlushFileBuffers(h); + + CloseHandle(h); + return true; +} + +#else + +bool +eraseFile_posix(const std::string& path, bool dosync) +{ + struct stat st; + if (stat(path.c_str(), &st) == -1) { + //JAMI_WARN("Can not erase file %s: fstat() failed.", path.c_str()); + return false; + } + // Remove read-only flag if possible + chmod(path.c_str(), st.st_mode | (S_IWGRP+S_IWUSR) ); + + int fd = open(path.c_str(), O_WRONLY); + if (fd == -1) { + //JAMI_WARN("Can not open file %s for erasing.", path.c_str()); + return false; + } + + if (st.st_size == 0) { + close(fd); + return false; + } + + lseek(fd, 0, SEEK_SET); + + std::array<char, ERASE_BLOCK> buffer; + buffer.fill(0); + decltype(st.st_size) written(0); + while (written < st.st_size) { + auto ret = write(fd, buffer.data(), buffer.size()); + if (ret < 0) { + //JAMI_WARNING("Error while overriding file with zeros."); + break; + } else + written += ret; + } + + if (dosync) + fsync(fd); + + close(fd); + return written >= st.st_size; +} +#endif + +bool +eraseFile(const std::string& path, bool dosync) +{ +#ifdef _WIN32 + return eraseFile_win32(path, dosync); +#else + return eraseFile_posix(path, dosync); +#endif +} + +int +remove(const std::string& path, bool erase) +{ + if (erase and isFile(path, false) and !hasHardLink(path)) + eraseFile(path, true); + +#ifdef _WIN32 + // use Win32 api since std::remove will not unlink directory in use + if (isDirectory(path)) + return !RemoveDirectory(jami::to_wstring(path).c_str()); +#endif + + return std::remove(path.c_str()); +} + +int +removeAll(const std::string& path, bool erase) +{ + if (path.empty()) + return -1; + if (isDirectory(path) and !isSymLink(path)) { + auto dir = path; + if (dir.back() != DIR_SEPARATOR_CH) + dir += DIR_SEPARATOR_CH; + for (auto& entry : fileutils::readDirectory(dir)) + removeAll(dir + entry, erase); + } + return remove(path, erase); +} + +void +openStream(std::ifstream& file, const std::string& path, std::ios_base::openmode mode) +{ +#ifdef _WIN32 + file.open(jami::to_wstring(path), mode); +#else + file.open(path, mode); +#endif +} + +void +openStream(std::ofstream& file, const std::string& path, std::ios_base::openmode mode) +{ +#ifdef _WIN32 + file.open(jami::to_wstring(path), mode); +#else + file.open(path, mode); +#endif +} + +std::ifstream +ifstream(const std::string& path, std::ios_base::openmode mode) +{ +#ifdef _WIN32 + return std::ifstream(jami::to_wstring(path), mode); +#else + return std::ifstream(path, mode); +#endif +} + +std::ofstream +ofstream(const std::string& path, std::ios_base::openmode mode) +{ +#ifdef _WIN32 + return std::ofstream(jami::to_wstring(path), mode); +#else + return std::ofstream(path, mode); +#endif +} + +int64_t +size(const std::string& path) +{ + int64_t size = 0; + try { + std::ifstream file; + openStream(file, path, std::ios::binary | std::ios::in); + file.seekg(0, std::ios_base::end); + size = file.tellg(); + file.close(); + } catch (...) { + } + return size; +} + +std::string +sha3File(const std::string& path) +{ + sha3_512_ctx ctx; + sha3_512_init(&ctx); + + std::ifstream file; + try { + if (!fileutils::isFile(path)) + return {}; + openStream(file, path, std::ios::binary | std::ios::in); + if (!file) + return {}; + std::vector<char> buffer(8192, 0); + while (!file.eof()) { + file.read(buffer.data(), buffer.size()); + std::streamsize readSize = file.gcount(); + sha3_512_update(&ctx, readSize, (const uint8_t*) buffer.data()); + } + file.close(); + } catch (...) { + return {}; + } + + unsigned char digest[SHA3_512_DIGEST_SIZE]; + sha3_512_digest(&ctx, SHA3_512_DIGEST_SIZE, digest); + + char hash[SHA3_512_DIGEST_SIZE * 2]; + + for (int i = 0; i < SHA3_512_DIGEST_SIZE; ++i) + pj_val_to_hex_digit(digest[i], &hash[2 * i]); + + return {hash, SHA3_512_DIGEST_SIZE * 2}; +} + +std::string +sha3sum(const std::vector<uint8_t>& buffer) +{ + sha3_512_ctx ctx; + sha3_512_init(&ctx); + sha3_512_update(&ctx, buffer.size(), (const uint8_t*) buffer.data()); + + unsigned char digest[SHA3_512_DIGEST_SIZE]; + sha3_512_digest(&ctx, SHA3_512_DIGEST_SIZE, digest); + + char hash[SHA3_512_DIGEST_SIZE * 2]; + + for (int i = 0; i < SHA3_512_DIGEST_SIZE; ++i) + pj_val_to_hex_digit(digest[i], &hash[2 * i]); + + return {hash, SHA3_512_DIGEST_SIZE * 2}; +} + +int +accessFile(const std::string& file, int mode) +{ +#ifdef _WIN32 + return _waccess(jami::to_wstring(file).c_str(), mode); +#else + return access(file.c_str(), mode); +#endif +} + +uint64_t +lastWriteTime(const std::string& p) +{ +#if USE_STD_FILESYSTEM + return std::chrono::duration_cast<std::chrono::milliseconds>( + std::filesystem::last_write_time(std::filesystem::path(p)).time_since_epoch()) + .count(); +#else + struct stat result; + if (stat(p.c_str(), &result) == 0) + return result.st_mtime; + return 0; +#endif +} + +} // namespace fileutils +} // namespace jami diff --git a/src/ice_socket.h b/src/ice_socket.h new file mode 100644 index 0000000..795185d --- /dev/null +++ b/src/ice_socket.h @@ -0,0 +1,58 @@ +/* + * Copyright (C) 2004-2023 Savoir-faire Linux Inc. + * + * Author: Guillaume Roguez <guillaume.roguez@savoirfairelinux.com> + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ +#pragma once + +#include "generic_io.h" + +#include <memory> +#include <functional> + +#if defined(_MSC_VER) +#include <BaseTsd.h> +using ssize_t = SSIZE_T; +#endif + +namespace jami { + +class IceTransport; +using IceRecvCb = std::function<ssize_t(unsigned char* buf, size_t len)>; + +class IceSocket +{ +private: + std::shared_ptr<IceTransport> ice_transport_ {}; + int compId_ = -1; + +public: + IceSocket(std::shared_ptr<IceTransport> iceTransport, int compId) + : ice_transport_(std::move(iceTransport)) + , compId_(compId) + {} + + void close(); + ssize_t send(const unsigned char* buf, size_t len); + ssize_t waitForData(std::chrono::milliseconds timeout); + void setOnRecv(IceRecvCb cb); + uint16_t getTransportOverhead(); + void setDefaultRemoteAddress(const IpAddr& addr); + int getCompId() const { return compId_; }; +}; + +}; // namespace jami diff --git a/src/ice_transport.cpp b/src/ice_transport.cpp new file mode 100644 index 0000000..12c0122 --- /dev/null +++ b/src/ice_transport.cpp @@ -0,0 +1,1902 @@ +/* + * Copyright (C) 2004-2023 Savoir-faire Linux Inc. + * + * Author: Guillaume Roguez <guillaume.roguez@savoirfairelinux.com> + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ + +#include "ice_transport.h" +#include "ice_socket.h" +#include "sip_utils.h" +#include "string_utils.h" +#include "upnp/upnp_control.h" +#include "transport/peer_channel.h" +#include "tracepoint/tracepoint.h" + +#include <opendht/logger.h> +#include <opendht/utils.h> + +#include <pjlib.h> + +#include <map> +#include <atomic> +#include <queue> +#include <mutex> +#include <condition_variable> +#include <thread> +#include <utility> +#include <tuple> +#include <algorithm> +#include <sstream> +#include <chrono> +#include <thread> +#include <cerrno> + +#include "pj/limits.h" + +#define TRY(ret) \ + do { \ + if ((ret) != PJ_SUCCESS) \ + throw std::runtime_error(#ret " failed"); \ + } while (0) + +// Validate that the component ID is within the expected range +#define ASSERT_COMP_ID(compId, compCount) \ + do { \ + if ((compId) == 0 or (compId) > (compCount)) \ + throw std::runtime_error("Invalid component ID " + (std::to_string(compId))); \ + } while (0) + +namespace jami { + +static constexpr unsigned STUN_MAX_PACKET_SIZE {8192}; +static constexpr uint16_t IPV6_HEADER_SIZE = 40; ///< Size in bytes of IPV6 packet header +static constexpr uint16_t IPV4_HEADER_SIZE = 20; ///< Size in bytes of IPV4 packet header +static constexpr int MAX_CANDIDATES {32}; +static constexpr int MAX_DESTRUCTION_TIMEOUT {3000}; +static constexpr int HANDLE_EVENT_DURATION {500}; + +//============================================================================== + +using MutexGuard = std::lock_guard<std::mutex>; +using MutexLock = std::unique_lock<std::mutex>; +using namespace upnp; + +//============================================================================== + +class IceLock +{ + pj_grp_lock_t* lk_; + +public: + IceLock(pj_ice_strans* strans) + : lk_(pj_ice_strans_get_grp_lock(strans)) + { + lock(); + } + + ~IceLock() { unlock(); } + + void lock() { if (lk_) pj_grp_lock_acquire(lk_); } + + void unlock() { if (lk_) pj_grp_lock_release(lk_); } +}; + +class IceTransport::Impl +{ +public: + Impl(std::string_view name); + ~Impl(); + + void initIceInstance(const IceTransportOptions& options); + + void onComplete(pj_ice_strans* ice_st, pj_ice_strans_op op, pj_status_t status); + + void onReceiveData(unsigned comp_id, void* pkt, pj_size_t size); + + /** + * Set/change transport role as initiator. + * Should be called before start method. + */ + bool setInitiatorSession(); + + /** + * Set/change transport role as slave. + * Should be called before start method. + */ + bool setSlaveSession(); + bool createIceSession(pj_ice_sess_role role); + + void getUFragPwd(); + + std::string link() const; + + bool _isInitialized() const; + bool _isStarted() const; + bool _isRunning() const; + bool _isFailed() const; + bool _waitForInitialization(std::chrono::milliseconds timeout); + + const pj_ice_sess_cand* getSelectedCandidate(unsigned comp_id, bool remote) const; + IpAddr getLocalAddress(unsigned comp_id) const; + IpAddr getRemoteAddress(unsigned comp_id) const; + static const char* getCandidateType(const pj_ice_sess_cand* cand); + bool isTcpEnabled() const { return config_.protocol == PJ_ICE_TP_TCP; } + bool addStunConfig(int af); + void requestUpnpMappings(); + bool hasUpnp() const; + // Take a list of address pairs (local/public) and add them as + // reflexive candidates using STUN config. + void addServerReflexiveCandidates(const std::vector<std::pair<IpAddr, IpAddr>>& addrList); + // Generate server reflexive candidates using the published (DHT/Account) address + std::vector<std::pair<IpAddr, IpAddr>> setupGenericReflexiveCandidates(); + // Generate server reflexive candidates using UPNP mappings. + std::vector<std::pair<IpAddr, IpAddr>> setupUpnpReflexiveCandidates(); + void setDefaultRemoteAddress(unsigned comp_id, const IpAddr& addr); + IpAddr getDefaultRemoteAddress(unsigned comp_id) const; + bool handleEvents(unsigned max_msec); + int flushTimerHeapAndIoQueue(); + int checkEventQueue(int maxEventToPoll); + + std::shared_ptr<dht::log::Logger> logger_ {}; + + std::condition_variable_any iceCV_ {}; + + std::string sessionName_ {}; + std::unique_ptr<pj_pool_t, decltype(&pj_pool_release)> pool_ {nullptr, pj_pool_release}; + bool isTcp_ {false}; + bool upnpEnabled_ {false}; + IceTransportCompleteCb on_initdone_cb_ {}; + IceTransportCompleteCb on_negodone_cb_ {}; + pj_ice_strans* icest_ {nullptr}; + unsigned streamsCount_ {0}; + unsigned compCountPerStream_ {0}; + unsigned compCount_ {0}; + std::string local_ufrag_ {}; + std::string local_pwd_ {}; + pj_sockaddr remoteAddr_ {}; + pj_ice_strans_cfg config_ {}; + //std::string last_errmsg_ {}; + + std::atomic_bool is_stopped_ {false}; + + struct Packet + { + Packet(void* pkt, pj_size_t size) + : data {reinterpret_cast<char*>(pkt), reinterpret_cast<char*>(pkt) + size} + {} + std::vector<char> data {}; + }; + + struct ComponentIO + { + std::mutex mutex; + std::condition_variable cv; + std::deque<Packet> queue; + IceRecvCb recvCb; + }; + + // NOTE: Component IDs start from 1, while these three vectors + // are indexed from 0. Conversion from ID to vector index must + // be done properly. + std::vector<ComponentIO> compIO_ {}; + std::vector<PeerChannel> peerChannels_ {}; + std::vector<IpAddr> iceDefaultRemoteAddr_; + + // ICE controlling role. True for controller agents and false for + // controlled agents + std::atomic_bool initiatorSession_ {true}; + + // Local/Public addresses used by the account owning the ICE instance. + IpAddr accountLocalAddr_ {}; + IpAddr accountPublicAddr_ {}; + + // STUN and TURN servers + std::vector<StunServerInfo> stunServers_; + std::vector<TurnServerInfo> turnServers_; + + /** + * Returns the IP of each candidate for a given component in the ICE session + */ + struct LocalCandidate + { + IpAddr addr; + pj_ice_cand_transport transport; + }; + + std::shared_ptr<upnp::Controller> upnp_ {}; + std::mutex upnpMutex_ {}; + std::map<Mapping::key_t, Mapping> upnpMappings_; + std::mutex upnpMappingsMutex_ {}; + + bool onlyIPv4Private_ {true}; + + // IO/Timer events are handled by following thread + std::thread thread_ {}; + std::atomic_bool threadTerminateFlags_ {false}; + + // Wait data on components + mutable std::mutex sendDataMutex_ {}; + std::condition_variable waitDataCv_ = {}; + pj_size_t lastSentLen_ {0}; + bool destroying_ {false}; + onShutdownCb scb {}; + + void cancelOperations() + { + for (auto& c : peerChannels_) + c.stop(); + std::lock_guard<std::mutex> lk(sendDataMutex_); + destroying_ = true; + waitDataCv_.notify_all(); + } +}; + +//============================================================================== + +/** + * Add stun/turn configuration or default host as candidates + */ + +static void +add_stun_server(pj_pool_t& pool, pj_ice_strans_cfg& cfg, const StunServerInfo& info) +{ + if (cfg.stun_tp_cnt >= PJ_ICE_MAX_STUN) + throw std::runtime_error("Too many STUN configurations"); + + IpAddr ip {info.uri}; + + // Given URI cannot be DNS resolved or not IPv4 or IPv6? + // This prevents a crash into PJSIP when ip.toString() is called. + if (ip.getFamily() == AF_UNSPEC) { + /*JAMI_DBG("[ice (%s)] STUN server '%s' not used, unresolvable address", + (cfg.protocol == PJ_ICE_TP_TCP ? "TCP" : "UDP"), + info.uri.c_str());*/ + return; + } + + auto& stun = cfg.stun_tp[cfg.stun_tp_cnt++]; + pj_ice_strans_stun_cfg_default(&stun); + pj_strdup2_with_null(&pool, &stun.server, ip.toString().c_str()); + stun.af = ip.getFamily(); + if (!(stun.port = ip.getPort())) + stun.port = PJ_STUN_PORT; + stun.cfg.max_pkt_size = STUN_MAX_PACKET_SIZE; + stun.conn_type = cfg.stun.conn_type; + /*JAMI_DBG("[ice (%s)] added stun server '%s', port %u", + (cfg.protocol == PJ_ICE_TP_TCP ? "TCP" : "UDP"), + pj_strbuf(&stun.server), + stun.port);*/ +} + +static void +add_turn_server(pj_pool_t& pool, pj_ice_strans_cfg& cfg, const TurnServerInfo& info) +{ + if (cfg.turn_tp_cnt >= PJ_ICE_MAX_TURN) + throw std::runtime_error("Too many TURN servers"); + + IpAddr ip {info.uri}; + + // Same comment as add_stun_server() + if (ip.getFamily() == AF_UNSPEC) { + /*JAMI_DBG("[ice (%s)] TURN server '%s' not used, unresolvable address", + (cfg.protocol == PJ_ICE_TP_TCP ? "TCP" : "UDP"), + info.uri.c_str());*/ + return; + } + + auto& turn = cfg.turn_tp[cfg.turn_tp_cnt++]; + pj_ice_strans_turn_cfg_default(&turn); + pj_strdup2_with_null(&pool, &turn.server, ip.toString().c_str()); + turn.af = ip.getFamily(); + if (!(turn.port = ip.getPort())) + turn.port = PJ_STUN_PORT; + turn.cfg.max_pkt_size = STUN_MAX_PACKET_SIZE; + turn.conn_type = cfg.turn.conn_type; + + // Authorization (only static plain password supported yet) + if (not info.password.empty()) { + turn.auth_cred.type = PJ_STUN_AUTH_CRED_STATIC; + turn.auth_cred.data.static_cred.data_type = PJ_STUN_PASSWD_PLAIN; + pj_strset(&turn.auth_cred.data.static_cred.realm, + (char*) info.realm.c_str(), + info.realm.size()); + pj_strset(&turn.auth_cred.data.static_cred.username, + (char*) info.username.c_str(), + info.username.size()); + pj_strset(&turn.auth_cred.data.static_cred.data, + (char*) info.password.c_str(), + info.password.size()); + } + + /*JAMI_DBG("[ice (%s)] added turn server '%s', port %u", + (cfg.protocol == PJ_ICE_TP_TCP ? "TCP" : "UDP"), + pj_strbuf(&turn.server), + turn.port);*/ +} + +//============================================================================== + +IceTransport::Impl::Impl(std::string_view name) + : sessionName_(name) +{ + if (logger_) + logger_->debug("[ice:{}] Creating IceTransport session for \"{:s}\"", fmt::ptr(this), name); +} + +IceTransport::Impl::~Impl() +{ + if (logger_) + logger_->debug("[ice:{}] destroying {}", fmt::ptr(this), fmt::ptr(icest_)); + + threadTerminateFlags_ = true; + + if (thread_.joinable()) { + thread_.join(); + } + + if (icest_) { + pj_ice_strans* strans = nullptr; + + std::swap(strans, icest_); + + // must be done before ioqueue/timer destruction + if (logger_) + logger_->debug("[ice:{}] Destroying ice_strans {}", pj_ice_strans_get_user_data(strans), fmt::ptr(strans)); + + pj_ice_strans_stop_ice(strans); + pj_ice_strans_destroy(strans); + + // NOTE: This last timer heap and IO queue polling is necessary to close + // TURN socket. + // Because when destroying the TURN session pjproject creates a pj_timer + // to postpone the TURN destruction. This timer is only called if we poll + // the event queue. + + int ret = flushTimerHeapAndIoQueue(); + + if (ret < 0) { + if (logger_) + logger_->error("[ice:{}] IO queue polling failed", fmt::ptr(this)); + } else if (ret > 0) { + if (logger_) + logger_->error("[ice:{}] Unexpected left timer in timer heap. " + "Please report the bug", + fmt::ptr(this)); + } + + if (checkEventQueue(1) > 0) { + if (logger_) + logger_->warn("[ice:{}] Unexpected left events in IO queue", fmt::ptr(this)); + } + + if (config_.stun_cfg.ioqueue) + pj_ioqueue_destroy(config_.stun_cfg.ioqueue); + + if (config_.stun_cfg.timer_heap) + pj_timer_heap_destroy(config_.stun_cfg.timer_heap); + } + + if (logger_) + logger_->debug("[ice:%p] done destroying", fmt::ptr(this)); + if (scb) + scb(); +} + +void +IceTransport::Impl::initIceInstance(const IceTransportOptions& options) +{ + isTcp_ = options.tcpEnable; + upnpEnabled_ = options.upnpEnable; + on_initdone_cb_ = options.onInitDone; + on_negodone_cb_ = options.onNegoDone; + streamsCount_ = options.streamsCount; + compCountPerStream_ = options.compCountPerStream; + compCount_ = streamsCount_ * compCountPerStream_; + compIO_ = std::vector<ComponentIO>(compCount_); + peerChannels_ = std::vector<PeerChannel>(compCount_); + iceDefaultRemoteAddr_.resize(compCount_); + initiatorSession_ = options.master; + accountLocalAddr_ = std::move(options.accountLocalAddr); + accountPublicAddr_ = std::move(options.accountPublicAddr); + stunServers_ = std::move(options.stunServers); + turnServers_ = std::move(options.turnServers); + + if (logger_) + logger_->debug("[ice:{}] Initializing the session - comp count {} - as a {}", + fmt::ptr(this), + compCount_, + initiatorSession_ ? "master" : "slave"); + + if (upnpEnabled_) + upnp_.reset(new upnp::Controller()); + + config_ = options.factory->getIceCfg(); // config copy + if (isTcp_) { + config_.protocol = PJ_ICE_TP_TCP; + config_.stun.conn_type = PJ_STUN_TP_TCP; + config_.turn.conn_type = PJ_TURN_TP_TCP; + } else { + config_.protocol = PJ_ICE_TP_UDP; + config_.stun.conn_type = PJ_STUN_TP_UDP; + config_.turn.conn_type = PJ_TURN_TP_UDP; + } + + pool_.reset( + pj_pool_create(options.factory->getPoolFactory(), "IceTransport.pool", 512, 512, NULL)); + if (not pool_) + throw std::runtime_error("pj_pool_create() failed"); + + // Note: For server reflexive candidates, UPNP mappings will + // be used if available. Then, the public address learnt during + // the account registration process will be added only if it + // differs from the UPNP public address. + // Also note that UPNP candidates should be added first in order + // to have a higher priority when performing the connectivity + // checks. + // STUN configs layout: + // - index 0 : host IPv4 + // - index 1 : host IPv6 + // - index 2 : upnp/generic srflx IPv4. + // - index 3 : generic srflx (if upnp exists and different) + + config_.stun_tp_cnt = 0; + + if (logger_) + logger_->debug("[ice:{}] Add host candidates", fmt::ptr(this)); + addStunConfig(pj_AF_INET()); + addStunConfig(pj_AF_INET6()); + + std::vector<std::pair<IpAddr, IpAddr>> upnpSrflxCand; + if (upnp_) { + requestUpnpMappings(); + upnpSrflxCand = setupUpnpReflexiveCandidates(); + if (not upnpSrflxCand.empty()) { + addServerReflexiveCandidates(upnpSrflxCand); + if (logger_) + logger_->debug("[ice:{}] Added UPNP srflx candidates:", fmt::ptr(this)); + } + } + + auto genericSrflxCand = setupGenericReflexiveCandidates(); + + if (not genericSrflxCand.empty()) { + // Generic srflx candidates will be added only if different + // from upnp candidates. + if (upnpSrflxCand.empty() + or (upnpSrflxCand[0].second.toString() != genericSrflxCand[0].second.toString())) { + addServerReflexiveCandidates(genericSrflxCand); + if (logger_) + logger_->debug("[ice:{}] Added generic srflx candidates:", fmt::ptr(this)); + } + } + + if (upnpSrflxCand.empty() and genericSrflxCand.empty()) { + if (logger_) + logger_->warn("[ice:{}] No server reflexive candidates added", fmt::ptr(this)); + } + + pj_ice_strans_cb icecb; + pj_bzero(&icecb, sizeof(icecb)); + + icecb.on_rx_data = [](pj_ice_strans* ice_st, + unsigned comp_id, + void* pkt, + pj_size_t size, + const pj_sockaddr_t* /*src_addr*/, + unsigned /*src_addr_len*/) { + if (auto* tr = static_cast<Impl*>(pj_ice_strans_get_user_data(ice_st))) + tr->onReceiveData(comp_id, pkt, size); + }; + + icecb.on_ice_complete = [](pj_ice_strans* ice_st, pj_ice_strans_op op, pj_status_t status) { + if (auto* tr = static_cast<Impl*>(pj_ice_strans_get_user_data(ice_st))) + tr->onComplete(ice_st, op, status); + }; + + if (isTcp_) { + icecb.on_data_sent = [](pj_ice_strans* ice_st, pj_ssize_t size) { + if (auto* tr = static_cast<Impl*>(pj_ice_strans_get_user_data(ice_st))) { + std::lock_guard lk(tr->sendDataMutex_); + tr->lastSentLen_ += size; + tr->waitDataCv_.notify_all(); + } + }; + } + + icecb.on_destroy = [](pj_ice_strans* ice_st) { + if (auto* tr = static_cast<Impl*>(pj_ice_strans_get_user_data(ice_st))) + tr->cancelOperations(); // Avoid upper layer to manage this ; Stop read operations + }; + + // Add STUN servers + for (auto& server : stunServers_) + add_stun_server(*pool_, config_, server); + + // Add TURN servers + for (auto& server : turnServers_) + add_turn_server(*pool_, config_, server); + + static constexpr auto IOQUEUE_MAX_HANDLES = std::min(PJ_IOQUEUE_MAX_HANDLES, 64); + TRY(pj_timer_heap_create(pool_.get(), 100, &config_.stun_cfg.timer_heap)); + TRY(pj_ioqueue_create(pool_.get(), IOQUEUE_MAX_HANDLES, &config_.stun_cfg.ioqueue)); + std::ostringstream sessionName {}; + // We use the instance pointer as the PJNATH session name in order + // to easily identify the logs reported by PJNATH. + sessionName << this; + pj_status_t status = pj_ice_strans_create(sessionName.str().c_str(), + &config_, + compCount_, + this, + &icecb, + &icest_); + + if (status != PJ_SUCCESS || icest_ == nullptr) { + throw std::runtime_error("pj_ice_strans_create() failed"); + } + + // Must be created after any potential failure + thread_ = std::thread([this] { + while (not threadTerminateFlags_) { + // NOTE: handleEvents can return false in this case + // but here we don't care if there is event or not. + handleEvents(HANDLE_EVENT_DURATION); + } + }); +} + +bool +IceTransport::Impl::_isInitialized() const +{ + if (auto *icest = icest_) { + auto state = pj_ice_strans_get_state(icest); + return state >= PJ_ICE_STRANS_STATE_SESS_READY and state != PJ_ICE_STRANS_STATE_FAILED; + } + return false; +} + +bool +IceTransport::Impl::_isStarted() const +{ + if (auto *icest = icest_) { + auto state = pj_ice_strans_get_state(icest); + return state >= PJ_ICE_STRANS_STATE_NEGO and state != PJ_ICE_STRANS_STATE_FAILED; + } + return false; +} + +bool +IceTransport::Impl::_isRunning() const +{ + if (auto *icest = icest_) { + auto state = pj_ice_strans_get_state(icest); + return state >= PJ_ICE_STRANS_STATE_RUNNING and state != PJ_ICE_STRANS_STATE_FAILED; + } + return false; +} + +bool +IceTransport::Impl::_isFailed() const +{ + if (auto *icest = icest_) + return pj_ice_strans_get_state(icest) == PJ_ICE_STRANS_STATE_FAILED; + return false; +} + +bool +IceTransport::Impl::handleEvents(unsigned max_msec) +{ + // By tests, never seen more than two events per 500ms + static constexpr auto MAX_NET_EVENTS = 2; + + pj_time_val max_timeout = {0, static_cast<long>(max_msec)}; + pj_time_val timeout = {0, 0}; + unsigned net_event_count = 0; + + pj_timer_heap_poll(config_.stun_cfg.timer_heap, &timeout); + auto hasActiveTimer = timeout.sec != PJ_MAXINT32 || timeout.msec != PJ_MAXINT32; + + // timeout limitation + if (hasActiveTimer) + pj_time_val_normalize(&timeout); + + if (PJ_TIME_VAL_GT(timeout, max_timeout)) { + timeout = max_timeout; + } + + do { + auto n_events = pj_ioqueue_poll(config_.stun_cfg.ioqueue, &timeout); + + // timeout + if (not n_events) + return hasActiveTimer; + + // error + if (n_events < 0) { + const auto err = pj_get_os_error(); + // Kept as debug as some errors are "normal" in regular context + if (logger_) + logger_->debug("[ice:{}] ioqueue error {:d}: {:s}", fmt::ptr(this), err, sip_utils::sip_strerror(err)); + std::this_thread::sleep_for(std::chrono::milliseconds(PJ_TIME_VAL_MSEC(timeout))); + return hasActiveTimer; + } + + net_event_count += n_events; + timeout.sec = timeout.msec = 0; + } while (net_event_count < MAX_NET_EVENTS); + return hasActiveTimer; +} + +int +IceTransport::Impl::flushTimerHeapAndIoQueue() +{ + pj_time_val timerTimeout = {0, 0}; + pj_time_val defaultWaitTime = {0, HANDLE_EVENT_DURATION}; + bool hasActiveTimer = false; + std::chrono::milliseconds totalWaitTime {0}; + auto const start = std::chrono::steady_clock::now(); + // We try to process pending events as fast as possible to + // speed-up the release. + int maxEventToProcess = 10; + + do { + if (checkEventQueue(maxEventToProcess) < 0) + return -1; + + pj_timer_heap_poll(config_.stun_cfg.timer_heap, &timerTimeout); + hasActiveTimer = !(timerTimeout.sec == PJ_MAXINT32 && timerTimeout.msec == PJ_MAXINT32); + + if (hasActiveTimer) { + pj_time_val_normalize(&timerTimeout); + auto waitTime = std::chrono::milliseconds( + std::min(PJ_TIME_VAL_MSEC(timerTimeout), PJ_TIME_VAL_MSEC(defaultWaitTime))); + std::this_thread::sleep_for(waitTime); + totalWaitTime += waitTime; + } + } while (hasActiveTimer && totalWaitTime < std::chrono::milliseconds(MAX_DESTRUCTION_TIMEOUT)); + + auto duration = std::chrono::steady_clock::now() - start; + if (logger_) + logger_->debug("[ice:{}] Timer heap flushed after {}", fmt::ptr(this), dht::print_duration(duration)); + + return static_cast<int>(pj_timer_heap_count(config_.stun_cfg.timer_heap)); +} + +int +IceTransport::Impl::checkEventQueue(int maxEventToPoll) +{ + pj_time_val timeout = {0, 0}; + int eventCount = 0; + int events = 0; + + do { + events = pj_ioqueue_poll(config_.stun_cfg.ioqueue, &timeout); + if (events < 0) { + const auto err = pj_get_os_error(); + if (logger_) + logger_->error("[ice:{}] ioqueue error {:d}: {:s}", fmt::ptr(this), err, sip_utils::sip_strerror(err)); + return events; + } + + eventCount += events; + + } while (events > 0 && eventCount < maxEventToPoll); + + return eventCount; +} + +void +IceTransport::Impl::onComplete(pj_ice_strans*, pj_ice_strans_op op, pj_status_t status) +{ + const char* opname = op == PJ_ICE_STRANS_OP_INIT ? "initialization" + : op == PJ_ICE_STRANS_OP_NEGOTIATION ? "negotiation" + : "unknown_op"; + + const bool done = status == PJ_SUCCESS; + if (done) { + if (logger_) + logger_->debug("[ice:{}] {:s} {:s} success", + fmt::ptr(this), + (config_.protocol == PJ_ICE_TP_TCP ? "TCP" : "UDP"), + opname); + } else { + if (logger_) + logger_->error("[ice:{}] {:s} {:s} failed: {:s}", + fmt::ptr(this), + (config_.protocol == PJ_ICE_TP_TCP ? "TCP" : "UDP"), + opname, + sip_utils::sip_strerror(status)); + } + + if (done and op == PJ_ICE_STRANS_OP_INIT) { + if (initiatorSession_) + setInitiatorSession(); + else + setSlaveSession(); + } + + if (op == PJ_ICE_STRANS_OP_INIT and on_initdone_cb_) + on_initdone_cb_(done); + else if (op == PJ_ICE_STRANS_OP_NEGOTIATION) { + if (done) { + // Dump of connection pairs + if (logger_) + logger_->debug("[ice:{}] {:s} connection pairs ([comp id] local [type] <-> remote [type]):\n{:s}", + fmt::ptr(this), + (config_.protocol == PJ_ICE_TP_TCP ? "TCP" : "UDP"), + link()); + } + if (on_negodone_cb_) + on_negodone_cb_(done); + } + + iceCV_.notify_all(); +} + +std::string +IceTransport::Impl::link() const +{ + std::ostringstream out; + for (unsigned strm = 1; strm <= streamsCount_ * compCountPerStream_; strm++) { + auto absIdx = strm; + auto comp = (strm + 1) / compCountPerStream_; + auto laddr = getLocalAddress(absIdx); + auto raddr = getRemoteAddress(absIdx); + + if (laddr and laddr.getPort() != 0 and raddr and raddr.getPort() != 0) { + out << " [" << comp << "] " << laddr.toString(true, true) << " [" + << getCandidateType(getSelectedCandidate(absIdx, false)) << "] " + << " <-> " << raddr.toString(true, true) << " [" + << getCandidateType(getSelectedCandidate(absIdx, true)) << "] " << '\n'; + } else { + out << " [" << comp << "] disabled\n"; + } + } + return out.str(); +} + +bool +IceTransport::Impl::setInitiatorSession() +{ + if (logger_) + logger_->debug("[ice:{}] as master", fmt::ptr(this)); + initiatorSession_ = true; + if (_isInitialized()) { + auto status = pj_ice_strans_change_role(icest_, PJ_ICE_SESS_ROLE_CONTROLLING); + if (status != PJ_SUCCESS) { + if (logger_) + logger_->error("[ice:{}] role change failed: {:s}", fmt::ptr(this), sip_utils::sip_strerror(status)); + return false; + } + return true; + } + return createIceSession(PJ_ICE_SESS_ROLE_CONTROLLING); +} + +bool +IceTransport::Impl::setSlaveSession() +{ + if (logger_) + logger_->debug("[ice:{}] as slave", fmt::ptr(this)); + initiatorSession_ = false; + if (_isInitialized()) { + auto status = pj_ice_strans_change_role(icest_, PJ_ICE_SESS_ROLE_CONTROLLED); + if (status != PJ_SUCCESS) { + if (logger_) + logger_->error("[ice:{}] role change failed: {:s}", fmt::ptr(this), sip_utils::sip_strerror(status)); + return false; + } + return true; + } + return createIceSession(PJ_ICE_SESS_ROLE_CONTROLLED); +} + +const pj_ice_sess_cand* +IceTransport::Impl::getSelectedCandidate(unsigned comp_id, bool remote) const +{ + ASSERT_COMP_ID(comp_id, compCount_); + + // Return the selected candidate pair. Might not be the nominated pair if + // ICE has not concluded yet, but should be the nominated pair afterwards. + if (not _isRunning()) { + if (logger_) + logger_->error("[ice:{}] ICE transport is not running", fmt::ptr(this)); + return nullptr; + } + + const auto* sess = pj_ice_strans_get_valid_pair(icest_, comp_id); + if (sess == nullptr) { + if (logger_) + logger_->warn("[ice:{}] Component {} has no valid pair (disabled)", fmt::ptr(this), comp_id); + return nullptr; + } + + if (remote) + return sess->rcand; + else + return sess->lcand; +} + +IpAddr +IceTransport::Impl::getLocalAddress(unsigned comp_id) const +{ + ASSERT_COMP_ID(comp_id, compCount_); + + if (auto cand = getSelectedCandidate(comp_id, false)) + return cand->addr; + + return {}; +} + +IpAddr +IceTransport::Impl::getRemoteAddress(unsigned comp_id) const +{ + ASSERT_COMP_ID(comp_id, compCount_); + + if (auto cand = getSelectedCandidate(comp_id, true)) + return cand->addr; + + return {}; +} + +const char* +IceTransport::Impl::getCandidateType(const pj_ice_sess_cand* cand) +{ + auto name = cand ? pj_ice_get_cand_type_name(cand->type) : nullptr; + return name ? name : "?"; +} + +void +IceTransport::Impl::getUFragPwd() +{ + if (icest_) { + pj_str_t local_ufrag, local_pwd; + + pj_ice_strans_get_ufrag_pwd(icest_, &local_ufrag, &local_pwd, nullptr, nullptr); + local_ufrag_.assign(local_ufrag.ptr, local_ufrag.slen); + local_pwd_.assign(local_pwd.ptr, local_pwd.slen); + } +} + +bool +IceTransport::Impl::createIceSession(pj_ice_sess_role role) +{ + if (not icest_) { + return false; + } + + if (pj_ice_strans_init_ice(icest_, role, nullptr, nullptr) != PJ_SUCCESS) { + if (logger_) + logger_->error("[ice:{}] pj_ice_strans_init_ice() failed", fmt::ptr(this)); + return false; + } + + // Fetch some information on local configuration + getUFragPwd(); + + if (logger_) + logger_->debug("[ice:{}] (local) ufrag=%s, pwd=%s", fmt::ptr(this), local_ufrag_.c_str(), local_pwd_.c_str()); + + return true; +} + +bool +IceTransport::Impl::addStunConfig(int af) +{ + if (config_.stun_tp_cnt >= PJ_ICE_MAX_STUN) { + if (logger_) + logger_->error("Max number of STUN configurations reached (%i)", PJ_ICE_MAX_STUN); + return false; + } + + if (af != pj_AF_INET() and af != pj_AF_INET6()) { + if (logger_) + logger_->error("Invalid address familly (%i)", af); + return false; + } + + auto& stun = config_.stun_tp[config_.stun_tp_cnt++]; + + pj_ice_strans_stun_cfg_default(&stun); + stun.cfg.max_pkt_size = STUN_MAX_PACKET_SIZE; + stun.af = af; + stun.conn_type = config_.stun.conn_type; + + if (logger_) + logger_->debug("[ice:{}] added host stun config for {:s} transport", + fmt::ptr(this), + config_.protocol == PJ_ICE_TP_TCP ? "TCP" : "UDP"); + + return true; +} + +void +IceTransport::Impl::requestUpnpMappings() +{ + // Must be called once ! + + std::lock_guard<std::mutex> lock(upnpMutex_); + + if (not upnp_) + return; + + auto transport = isTcpEnabled() ? PJ_CAND_TCP_PASSIVE : PJ_CAND_UDP; + auto portType = transport == PJ_CAND_UDP ? PortType::UDP : PortType::TCP; + + // Request upnp mapping for each component. + for (unsigned id = 1; id <= compCount_; id++) { + // Set port number to 0 to get any available port. + Mapping requestedMap(portType); + + // Request the mapping + Mapping::sharedPtr_t mapPtr = upnp_->reserveMapping(requestedMap); + + // To use a mapping, it must be valid, open and has valid host address. + if (mapPtr and mapPtr->getMapKey() and (mapPtr->getState() == MappingState::OPEN) + and mapPtr->hasValidHostAddress()) { + std::lock_guard<std::mutex> lock(upnpMappingsMutex_); + auto ret = upnpMappings_.emplace(mapPtr->getMapKey(), *mapPtr); + if (ret.second) { + if (logger_) + logger_->debug("[ice:{}] UPNP mapping {:s} successfully allocated", + fmt::ptr(this), + mapPtr->toString(true)); + } else { + if (logger_) + logger_->warn("[ice:{}] UPNP mapping {:s} already in the list!", + fmt::ptr(this), + mapPtr->toString()); + } + } else { + if (logger_) + logger_->warn("[ice:{}] UPNP mapping request failed!", fmt::ptr(this)); + upnp_->releaseMapping(requestedMap); + } + } +} + +bool +IceTransport::Impl::hasUpnp() const +{ + return upnp_ and upnpMappings_.size() == compCount_; +} + +void +IceTransport::Impl::addServerReflexiveCandidates( + const std::vector<std::pair<IpAddr, IpAddr>>& addrList) +{ + if (addrList.size() != compCount_) { + if (logger_) + logger_->warn("[ice:{}] Provided addr list size {} does not match component count {}", + fmt::ptr(this), + addrList.size(), + compCount_); + return; + } + if (compCount_ > PJ_ICE_MAX_COMP) { + if (logger_) + logger_->error("[ice:{}] Too many components", fmt::ptr(this)); + return; + } + + // Add config for server reflexive candidates (UPNP or from DHT). + if (not addStunConfig(pj_AF_INET())) + return; + + assert(config_.stun_tp_cnt > 0 && config_.stun_tp_cnt < PJ_ICE_MAX_STUN); + auto& stun = config_.stun_tp[config_.stun_tp_cnt - 1]; + + for (unsigned id = 1; id <= compCount_; id++) { + auto idx = id - 1; + auto& localAddr = addrList[idx].first; + auto& publicAddr = addrList[idx].second; + + if (logger_) + logger_->debug("[ice:{}] Add srflx reflexive candidates [{:s} : {:s}] for comp {:d}", + fmt::ptr(this), + localAddr.toString(true), + publicAddr.toString(true), + id); + + pj_sockaddr_cp(&stun.cfg.user_mapping[idx].local_addr, localAddr.pjPtr()); + pj_sockaddr_cp(&stun.cfg.user_mapping[idx].mapped_addr, publicAddr.pjPtr()); + + if (isTcpEnabled()) { + if (publicAddr.getPort() == 9) { + stun.cfg.user_mapping[idx].tp_type = PJ_CAND_TCP_ACTIVE; + } else { + stun.cfg.user_mapping[idx].tp_type = PJ_CAND_TCP_PASSIVE; + } + } else { + stun.cfg.user_mapping[idx].tp_type = PJ_CAND_UDP; + } + } + + stun.cfg.user_mapping_cnt = compCount_; +} + +std::vector<std::pair<IpAddr, IpAddr>> +IceTransport::Impl::setupGenericReflexiveCandidates() +{ + if (not accountLocalAddr_) { + if (logger_) + logger_->warn("[ice:{}] Missing local address, generic srflx candidates wont be generated!", + fmt::ptr(this)); + return {}; + } + + if (not accountPublicAddr_) { + if (logger_) + logger_->warn("[ice:{}] Missing public address, generic srflx candidates wont be generated!", + fmt::ptr(this)); + return {}; + } + + std::vector<std::pair<IpAddr, IpAddr>> addrList; + auto isTcp = isTcpEnabled(); + + addrList.reserve(compCount_); + for (unsigned id = 1; id <= compCount_; id++) { + // For TCP, the type is set to active, because most likely the incoming + // connection will be blocked by the NAT. + // For UDP use random port number. + uint16_t port = isTcp ? 9 + : upnp::Controller::generateRandomPort(isTcp ? PortType::TCP + : PortType::UDP); + + accountLocalAddr_.setPort(port); + accountPublicAddr_.setPort(port); + addrList.emplace_back(accountLocalAddr_, accountPublicAddr_); + } + + return addrList; +} + +std::vector<std::pair<IpAddr, IpAddr>> +IceTransport::Impl::setupUpnpReflexiveCandidates() +{ + // Add UPNP server reflexive candidates if available. + if (not hasUpnp()) + return {}; + + std::lock_guard<std::mutex> lock(upnpMappingsMutex_); + + if (upnpMappings_.size() < (size_t)compCount_) { + if (logger_) + logger_->warn("[ice:{}] Not enough mappings {:d}. Expected {:d}", + fmt::ptr(this), + upnpMappings_.size(), + compCount_); + return {}; + } + + std::vector<std::pair<IpAddr, IpAddr>> addrList; + + addrList.reserve(upnpMappings_.size()); + for (auto const& [_, map] : upnpMappings_) { + assert(map.getMapKey()); + IpAddr localAddr {map.getInternalAddress()}; + localAddr.setPort(map.getInternalPort()); + IpAddr publicAddr {map.getExternalAddress()}; + publicAddr.setPort(map.getExternalPort()); + addrList.emplace_back(localAddr, publicAddr); + } + + return addrList; +} + +void +IceTransport::Impl::setDefaultRemoteAddress(unsigned compId, const IpAddr& addr) +{ + ASSERT_COMP_ID(compId, compCount_); + + iceDefaultRemoteAddr_[compId - 1] = addr; + // The port does not matter. Set it 0 to avoid confusion. + iceDefaultRemoteAddr_[compId - 1].setPort(0); +} + +IpAddr +IceTransport::Impl::getDefaultRemoteAddress(unsigned compId) const +{ + ASSERT_COMP_ID(compId, compCount_); + return iceDefaultRemoteAddr_[compId - 1]; +} + +void +IceTransport::Impl::onReceiveData(unsigned comp_id, void* pkt, pj_size_t size) +{ + ASSERT_COMP_ID(comp_id, compCount_); + + jami_tracepoint_if_enabled(ice_transport_recv, + reinterpret_cast<uint64_t>(this), + comp_id, + size, + getRemoteAddress(comp_id).toString().c_str()); + if (size == 0) + return; + + { + auto& io = compIO_[comp_id - 1]; + std::lock_guard<std::mutex> lk(io.mutex); + + if (io.recvCb) { + io.recvCb((uint8_t*) pkt, size); + return; + } + } + + std::error_code ec; + auto err = peerChannels_.at(comp_id - 1).write((const char*) pkt, size, ec); + if (err < 0) { + if (logger_) + logger_->error("[ice:{}] rx: channel is closed", fmt::ptr(this)); + } +} + +bool +IceTransport::Impl::_waitForInitialization(std::chrono::milliseconds timeout) +{ + IceLock lk(icest_); + + if (not iceCV_.wait_for(lk, timeout, [this] { + return threadTerminateFlags_ or _isInitialized() or _isFailed(); + })) { + if (logger_) + logger_->warn("[ice:{}] waitForInitialization: timeout", fmt::ptr(this)); + return false; + } + + return _isInitialized(); +} + +//============================================================================== + +IceTransport::IceTransport(std::string_view name) + : pimpl_ {std::make_unique<Impl>(name)} +{} + +IceTransport::~IceTransport() +{ + cancelOperations(); +} + +const std::shared_ptr<dht::log::Logger>& +IceTransport::logger() const +{ + return pimpl_->logger_; +} + +void +IceTransport::initIceInstance(const IceTransportOptions& options) +{ + pimpl_->initIceInstance(options); + jami_tracepoint(ice_transport_context, reinterpret_cast<uint64_t>(this)); +} + +bool +IceTransport::isInitialized() const +{ + IceLock lk(pimpl_->icest_); + return pimpl_->_isInitialized(); +} + +bool +IceTransport::isStarted() const +{ + IceLock lk(pimpl_->icest_); + return pimpl_->_isStarted(); +} + +bool +IceTransport::isRunning() const +{ + if (!pimpl_->icest_) + return false; + IceLock lk(pimpl_->icest_); + return pimpl_->_isRunning(); +} + +bool +IceTransport::isFailed() const +{ + return pimpl_->_isFailed(); +} + +unsigned +IceTransport::getComponentCount() const +{ + return pimpl_->compCount_; +} + +bool +IceTransport::setSlaveSession() +{ + return pimpl_->setSlaveSession(); +} +bool +IceTransport::setInitiatorSession() +{ + return pimpl_->setInitiatorSession(); +} + +bool +IceTransport::isInitiator() const +{ + if (isInitialized()) { + return pj_ice_strans_get_role(pimpl_->icest_) == PJ_ICE_SESS_ROLE_CONTROLLING; + } + return pimpl_->initiatorSession_; +} + +bool +IceTransport::startIce(const Attribute& rem_attrs, std::vector<IceCandidate>&& rem_candidates) +{ + if (not isInitialized()) { + if (pimpl_->logger_) + pimpl_->logger_->error("[ice:{}] not initialized transport", fmt::ptr(pimpl_.get())); + pimpl_->is_stopped_ = true; + return false; + } + + // pj_ice_strans_start_ice crashes if remote candidates array is empty + if (rem_candidates.empty()) { + if (pimpl_->logger_) + pimpl_->logger_->error("[ice:{}] start failed: no remote candidates", fmt::ptr(pimpl_.get())); + pimpl_->is_stopped_ = true; + return false; + } + + auto comp_cnt = std::max(1u, getComponentCount()); + if (rem_candidates.size() / comp_cnt > PJ_ICE_ST_MAX_CAND - 1) { + std::vector<IceCandidate> rcands; + rcands.reserve(PJ_ICE_ST_MAX_CAND - 1); + if (pimpl_->logger_) + pimpl_->logger_->warn("[ice:{}] too much candidates detected, trim list.", fmt::ptr(pimpl_.get())); + // Just trim some candidates. To avoid to only take host candidates, iterate + // through the whole list and select some host, some turn and peer reflexives + // It should give at least enough infos to negotiate. + auto maxHosts = 8; + auto maxRelays = PJ_ICE_MAX_TURN; + for (auto& c : rem_candidates) { + if (c.type == PJ_ICE_CAND_TYPE_HOST) { + if (maxHosts == 0) + continue; + maxHosts -= 1; + } else if (c.type == PJ_ICE_CAND_TYPE_RELAYED) { + if (maxRelays == 0) + continue; + maxRelays -= 1; + } + if (rcands.size() == PJ_ICE_ST_MAX_CAND - 1) + break; + rcands.emplace_back(std::move(c)); + } + rem_candidates = std::move(rcands); + } + + pj_str_t ufrag, pwd; + if (pimpl_->logger_) + pimpl_->logger_->debug("[ice:{}] negotiation starting ({:d} remote candidates)", + fmt::ptr(pimpl_), + rem_candidates.size()); + + auto status = pj_ice_strans_start_ice(pimpl_->icest_, + pj_strset(&ufrag, + (char*) rem_attrs.ufrag.c_str(), + rem_attrs.ufrag.size()), + pj_strset(&pwd, + (char*) rem_attrs.pwd.c_str(), + rem_attrs.pwd.size()), + rem_candidates.size(), + rem_candidates.data()); + if (status != PJ_SUCCESS) { + if (pimpl_->logger_) + pimpl_->logger_->error("[ice:{}] start failed: {:s}", fmt::ptr(pimpl_.get()), sip_utils::sip_strerror(status)); + pimpl_->is_stopped_ = true; + return false; + } + + return true; +} + +bool +IceTransport::startIce(const SDP& sdp) +{ + if (pimpl_->streamsCount_ != 1) { + if (pimpl_->logger_) + pimpl_->logger_->error(FMT_STRING("Expected exactly one stream per SDP (found {:u} streams)"), pimpl_->streamsCount_); + return false; + } + + if (not isInitialized()) { + if (pimpl_->logger_) + pimpl_->logger_->error(FMT_STRING("[ice:{}] not initialized transport"), fmt::ptr(pimpl_)); + pimpl_->is_stopped_ = true; + return false; + } + + for (unsigned id = 1; id <= getComponentCount(); id++) { + auto candVec = getLocalCandidates(id); + for (auto const& cand : candVec) { + if (pimpl_->logger_) + pimpl_->logger_->debug("[ice:{}] Using local candidate {:s} for comp {:d}", + fmt::ptr(pimpl_), cand, id); + } + } + + if (pimpl_->logger_) + pimpl_->logger_->debug("[ice:{}] negotiation starting ({:u} remote candidates)", + fmt::ptr(pimpl_), sdp.candidates.size()); + pj_str_t ufrag, pwd; + + std::vector<IceCandidate> rem_candidates; + rem_candidates.reserve(sdp.candidates.size()); + IceCandidate cand; + for (const auto& line : sdp.candidates) { + if (parseIceAttributeLine(0, line, cand)) + rem_candidates.emplace_back(cand); + } + + auto status = pj_ice_strans_start_ice(pimpl_->icest_, + pj_strset(&ufrag, + (char*) sdp.ufrag.c_str(), + sdp.ufrag.size()), + pj_strset(&pwd, (char*) sdp.pwd.c_str(), sdp.pwd.size()), + rem_candidates.size(), + rem_candidates.data()); + if (status != PJ_SUCCESS) { + if (pimpl_->logger_) + pimpl_->logger_->error("[ice:{}] start failed: {:s}", fmt::ptr(pimpl_), sip_utils::sip_strerror(status)); + pimpl_->is_stopped_ = true; + return false; + } + + return true; +} + +void +IceTransport::cancelOperations() +{ + pimpl_->cancelOperations(); +} + +IpAddr +IceTransport::getLocalAddress(unsigned comp_id) const +{ + return pimpl_->getLocalAddress(comp_id); +} + +IpAddr +IceTransport::getRemoteAddress(unsigned comp_id) const +{ + // Return the default remote address if set. + // Note that the default remote addresses are the addresses + // set in the 'c=' and 'a=rtcp' lines of the received SDP. + // See pj_ice_strans_sendto2() for more details. + if (auto defaultAddr = pimpl_->getDefaultRemoteAddress(comp_id)) { + return defaultAddr; + } + + return pimpl_->getRemoteAddress(comp_id); +} + +const IceTransport::Attribute +IceTransport::getLocalAttributes() const +{ + return {pimpl_->local_ufrag_, pimpl_->local_pwd_}; +} + +std::vector<std::string> +IceTransport::getLocalCandidates(unsigned comp_id) const +{ + ASSERT_COMP_ID(comp_id, getComponentCount()); + std::vector<std::string> res; + pj_ice_sess_cand cand[MAX_CANDIDATES]; + unsigned cand_cnt = PJ_ARRAY_SIZE(cand); + + if (!isInitialized()) { + return res; + } + + if (pj_ice_strans_enum_cands(pimpl_->icest_, comp_id, &cand_cnt, cand) != PJ_SUCCESS) { + if (pimpl_->logger_) + pimpl_->logger_->error("[ice:{}] pj_ice_strans_enum_cands() failed", fmt::ptr(pimpl_)); + return res; + } + + res.reserve(cand_cnt); + for (unsigned i = 0; i < cand_cnt; ++i) { + /** Section 4.5, RFC 6544 (https://tools.ietf.org/html/rfc6544) + * candidate-attribute = "candidate" ":" foundation SP component-id + * SP "TCP" SP priority SP connection-address SP port SP cand-type [SP + * rel-addr] [SP rel-port] SP tcp-type-ext + * *(SP extension-att-name SP + * extension-att-value) + * + * tcp-type-ext = "tcptype" SP tcp-type + * tcp-type = "active" / "passive" / "so" + */ + char ipaddr[PJ_INET6_ADDRSTRLEN]; + std::string tcp_type; + if (cand[i].transport != PJ_CAND_UDP) { + tcp_type += " tcptype"; + switch (cand[i].transport) { + case PJ_CAND_TCP_ACTIVE: + tcp_type += " active"; + break; + case PJ_CAND_TCP_PASSIVE: + tcp_type += " passive"; + break; + case PJ_CAND_TCP_SO: + default: + tcp_type += " so"; + break; + } + } + res.emplace_back( + fmt::format("{} {} {} {} {} {} typ {}{}", + sip_utils::as_view(cand[i].foundation), + cand[i].comp_id, + (cand[i].transport == PJ_CAND_UDP ? "UDP" : "TCP"), + cand[i].prio, + pj_sockaddr_print(&cand[i].addr, ipaddr, sizeof(ipaddr), 0), + pj_sockaddr_get_port(&cand[i].addr), + pj_ice_get_cand_type_name(cand[i].type), + tcp_type)); + } + + return res; +} +std::vector<std::string> +IceTransport::getLocalCandidates(unsigned streamIdx, unsigned compId) const +{ + ASSERT_COMP_ID(compId, getComponentCount()); + + std::vector<std::string> res; + pj_ice_sess_cand cand[MAX_CANDIDATES]; + unsigned cand_cnt = MAX_CANDIDATES; + + if (not isInitialized()) { + return res; + } + + // In the implementation, the component IDs are enumerated globally + // (per SDP: 1, 2, 3, 4, ...). This is simpler because we create + // only one pj_ice_strans instance. However, the component IDs are + // enumerated per stream in the generated SDP (1, 2, 1, 2, ...) in + // order to be compliant with the spec. + + auto globalCompId = streamIdx * 2 + compId; + if (pj_ice_strans_enum_cands(pimpl_->icest_, globalCompId, &cand_cnt, cand) != PJ_SUCCESS) { + if (pimpl_->logger_) + pimpl_->logger_->error("[ice:{}] pj_ice_strans_enum_cands() failed", fmt::ptr(pimpl_)); + return res; + } + + res.reserve(cand_cnt); + // Build ICE attributes according to RFC 6544, section 4.5. + for (unsigned i = 0; i < cand_cnt; ++i) { + char ipaddr[PJ_INET6_ADDRSTRLEN]; + std::string tcp_type; + if (cand[i].transport != PJ_CAND_UDP) { + tcp_type += " tcptype"; + switch (cand[i].transport) { + case PJ_CAND_TCP_ACTIVE: + tcp_type += " active"; + break; + case PJ_CAND_TCP_PASSIVE: + tcp_type += " passive"; + break; + case PJ_CAND_TCP_SO: + default: + tcp_type += " so"; + break; + } + } + res.emplace_back( + fmt::format("{} {} {} {} {} {} typ {}{}", + sip_utils::as_view(cand[i].foundation), + compId, + (cand[i].transport == PJ_CAND_UDP ? "UDP" : "TCP"), + cand[i].prio, + pj_sockaddr_print(&cand[i].addr, ipaddr, sizeof(ipaddr), 0), + pj_sockaddr_get_port(&cand[i].addr), + pj_ice_get_cand_type_name(cand[i].type), + tcp_type)); + } + + return res; +} + +bool +IceTransport::parseIceAttributeLine(unsigned streamIdx, + const std::string& line, + IceCandidate& cand) const +{ + // Silently ignore empty lines + if (line.empty()) + return false; + + if (streamIdx >= pimpl_->streamsCount_) { + throw std::runtime_error(fmt::format("Stream index {:d} is invalid!", streamIdx)); + } + + int af, cnt; + char foundation[32], transport[12], ipaddr[80], type[32], tcp_type[32]; + pj_str_t tmpaddr; + unsigned comp_id, prio, port; + pj_status_t status; + pj_bool_t is_tcp = PJ_FALSE; + + // Parse ICE attribute line according to RFC-6544 section 4.5. + // TODO/WARNING: There is no fail-safe in case of malformed attributes. + cnt = sscanf(line.c_str(), + "%31s %u %11s %u %79s %u typ %31s tcptype %31s\n", + foundation, + &comp_id, + transport, + &prio, + ipaddr, + &port, + type, + tcp_type); + if (cnt != 7 && cnt != 8) { + if (pimpl_->logger_) + pimpl_->logger_->error("[ice:{}] Invalid ICE candidate line: {:s}", fmt::ptr(pimpl_), line); + return false; + } + + if (strcmp(transport, "TCP") == 0) { + is_tcp = PJ_TRUE; + } + + pj_bzero(&cand, sizeof(IceCandidate)); + + if (strcmp(type, "host") == 0) + cand.type = PJ_ICE_CAND_TYPE_HOST; + else if (strcmp(type, "srflx") == 0) + cand.type = PJ_ICE_CAND_TYPE_SRFLX; + else if (strcmp(type, "prflx") == 0) + cand.type = PJ_ICE_CAND_TYPE_PRFLX; + else if (strcmp(type, "relay") == 0) + cand.type = PJ_ICE_CAND_TYPE_RELAYED; + else { + if (pimpl_->logger_) + pimpl_->logger_->warn("[ice:{}] invalid remote candidate type '{:s}'", fmt::ptr(pimpl_), type); + return false; + } + + if (is_tcp) { + if (strcmp(tcp_type, "active") == 0) + cand.transport = PJ_CAND_TCP_ACTIVE; + else if (strcmp(tcp_type, "passive") == 0) + cand.transport = PJ_CAND_TCP_PASSIVE; + else if (strcmp(tcp_type, "so") == 0) + cand.transport = PJ_CAND_TCP_SO; + else { + if (pimpl_->logger_) + pimpl_->logger_->warn("[ice:{}] invalid transport type type '{:s}'", fmt::ptr(pimpl_), tcp_type); + return false; + } + } else { + cand.transport = PJ_CAND_UDP; + } + + // If the component Id is enumerated relative to media, convert + // it to absolute enumeration. + if (comp_id <= pimpl_->compCountPerStream_) { + comp_id += pimpl_->compCountPerStream_ * streamIdx; + } + cand.comp_id = (pj_uint8_t) comp_id; + + cand.prio = prio; + + if (strchr(ipaddr, ':')) + af = pj_AF_INET6(); + else { + af = pj_AF_INET(); + pimpl_->onlyIPv4Private_ &= IpAddr(ipaddr).isPrivate(); + } + + tmpaddr = pj_str(ipaddr); + pj_sockaddr_init(af, &cand.addr, NULL, 0); + status = pj_sockaddr_set_str_addr(af, &cand.addr, &tmpaddr); + if (status != PJ_SUCCESS) { + if (pimpl_->logger_) + pimpl_->logger_->warn("[ice:{}] invalid IP address '{:s}'", fmt::ptr(pimpl_), ipaddr); + return false; + } + + pj_sockaddr_set_port(&cand.addr, (pj_uint16_t) port); + pj_strdup2(pimpl_->pool_.get(), &cand.foundation, foundation); + + return true; +} + +ssize_t +IceTransport::recv(unsigned compId, unsigned char* buf, size_t len, std::error_code& ec) +{ + ASSERT_COMP_ID(compId, getComponentCount()); + auto& io = pimpl_->compIO_[compId - 1]; + std::lock_guard<std::mutex> lk(io.mutex); + + if (io.queue.empty()) { + ec = std::make_error_code(std::errc::resource_unavailable_try_again); + return -1; + } + + auto& packet = io.queue.front(); + const auto count = std::min(len, packet.data.size()); + std::copy_n(packet.data.begin(), count, buf); + if (count == packet.data.size()) { + io.queue.pop_front(); + } else { + packet.data.erase(packet.data.begin(), packet.data.begin() + count); + } + + ec.clear(); + return count; +} + +ssize_t +IceTransport::recvfrom(unsigned compId, char* buf, size_t len, std::error_code& ec) +{ + ASSERT_COMP_ID(compId, getComponentCount()); + return pimpl_->peerChannels_.at(compId - 1).read(buf, len, ec); +} + +void +IceTransport::setOnRecv(unsigned compId, IceRecvCb cb) +{ + ASSERT_COMP_ID(compId, getComponentCount()); + + auto& io = pimpl_->compIO_[compId - 1]; + std::lock_guard<std::mutex> lk(io.mutex); + io.recvCb = std::move(cb); + + if (io.recvCb) { + // Flush existing queue using the callback + for (const auto& packet : io.queue) + io.recvCb((uint8_t*) packet.data.data(), packet.data.size()); + io.queue.clear(); + } +} + +void +IceTransport::setOnShutdown(onShutdownCb&& cb) +{ + pimpl_->scb = cb; +} + +ssize_t +IceTransport::send(unsigned compId, const unsigned char* buf, size_t len) +{ + ASSERT_COMP_ID(compId, getComponentCount()); + + auto remote = getRemoteAddress(compId); + + if (!remote) { + if (pimpl_->logger_) + pimpl_->logger_->error("[ice:{}] can't find remote address for component {:d}", fmt::ptr(pimpl_), compId); + errno = EINVAL; + return -1; + } + + std::unique_lock dlk(pimpl_->sendDataMutex_, std::defer_lock); + if (isTCPEnabled()) + dlk.lock(); + + jami_tracepoint(ice_transport_send, + reinterpret_cast<uint64_t>(this), + compId, + len, + remote.toString().c_str()); + + auto status = pj_ice_strans_sendto2(pimpl_->icest_, + compId, + buf, + len, + remote.pjPtr(), + remote.getLength()); + + jami_tracepoint(ice_transport_send_status, status); + + if (status == PJ_EPENDING && isTCPEnabled()) { + // NOTE; because we are in TCP, the sent size will count the header (2 + // bytes length). + pimpl_->waitDataCv_.wait(dlk, [&] { + return pimpl_->lastSentLen_ >= static_cast<pj_size_t>(len) or pimpl_->destroying_; + }); + pimpl_->lastSentLen_ = 0; + } else if (status != PJ_SUCCESS && status != PJ_EPENDING) { + if (status == PJ_EBUSY) { + errno = EAGAIN; + } else { + if (pimpl_->logger_) + pimpl_->logger_->error("[ice:{}] ice send failed: {:s}", fmt::ptr(pimpl_), sip_utils::sip_strerror(status)); + errno = EIO; + } + return -1; + } + + return len; +} + +bool +IceTransport::waitForInitialization(std::chrono::milliseconds timeout) +{ + return pimpl_->_waitForInitialization(timeout); +} + +ssize_t +IceTransport::waitForData(unsigned compId, std::chrono::milliseconds timeout, std::error_code& ec) +{ + ASSERT_COMP_ID(compId, getComponentCount()); + return pimpl_->peerChannels_.at(compId - 1).wait(timeout, ec); +} + +bool +IceTransport::isTCPEnabled() +{ + return pimpl_->isTcpEnabled(); +} + +ICESDP +IceTransport::parseIceCandidates(std::string_view sdp_msg) +{ + if (pimpl_->streamsCount_ != 1) { + if (pimpl_->logger_) + pimpl_->logger_->error("Expected exactly one stream per SDP (found %u streams)", pimpl_->streamsCount_); + return {}; + } + + ICESDP res; + int nr = 0; + for (std::string_view line; jami::getline(sdp_msg, line); nr++) { + if (nr == 0) { + res.rem_ufrag = line; + } else if (nr == 1) { + res.rem_pwd = line; + } else { + IceCandidate cand; + if (parseIceAttributeLine(0, std::string(line), cand)) { + if (pimpl_->logger_) + pimpl_->logger_->debug("[ice:{}] Add remote candidate: {}", + fmt::ptr(pimpl_), + line); + res.rem_candidates.emplace_back(cand); + } + } + } + return res; +} + +void +IceTransport::setDefaultRemoteAddress(unsigned comp_id, const IpAddr& addr) +{ + pimpl_->setDefaultRemoteAddress(comp_id, addr); +} + +std::string +IceTransport::link() const +{ + return pimpl_->link(); +} + +//============================================================================== + +IceTransportFactory::IceTransportFactory() + : cp_(new pj_caching_pool(), + [](pj_caching_pool* p) { + pj_caching_pool_destroy(p); + delete p; + }) + , ice_cfg_() +{ + pj_caching_pool_init(cp_.get(), NULL, 0); + + pj_ice_strans_cfg_default(&ice_cfg_); + ice_cfg_.stun_cfg.pf = &cp_->factory; + + // v2.4.5 of PJNATH has a default of 100ms but RFC 5389 since version 14 requires + // a minimum of 500ms on fixed-line links. Our usual case is wireless links. + // This solves too long ICE exchange by DHT. + // Using 500ms with default PJ_STUN_MAX_TRANSMIT_COUNT (7) gives around 33s before timeout. + ice_cfg_.stun_cfg.rto_msec = 500; + + // See https://tools.ietf.org/html/rfc5245#section-8.1.1.2 + // If enabled, it may help speed-up the connectivity, but may cause + // the nomination of sub-optimal pairs. + ice_cfg_.opt.aggressive = PJ_FALSE; +} + +IceTransportFactory::~IceTransportFactory() {} + +std::shared_ptr<IceTransport> +IceTransportFactory::createTransport(std::string_view name) +{ + try { + return std::make_shared<IceTransport>(name); + } catch (const std::exception& e) { + //JAMI_ERR("%s", e.what()); + return nullptr; + } +} + +std::unique_ptr<IceTransport> +IceTransportFactory::createUTransport(std::string_view name) +{ + try { + return std::make_unique<IceTransport>(name); + } catch (const std::exception& e) { + //JAMI_ERR("%s", e.what()); + return nullptr; + } +} + +//============================================================================== + +void +IceSocket::close() +{ + if (ice_transport_) + ice_transport_->setOnRecv(compId_, {}); + ice_transport_.reset(); +} + +ssize_t +IceSocket::send(const unsigned char* buf, size_t len) +{ + if (not ice_transport_) + return -1; + return ice_transport_->send(compId_, buf, len); +} + +ssize_t +IceSocket::waitForData(std::chrono::milliseconds timeout) +{ + if (not ice_transport_) + return -1; + + std::error_code ec; + return ice_transport_->waitForData(compId_, timeout, ec); +} + +void +IceSocket::setOnRecv(IceRecvCb cb) +{ + if (ice_transport_) + ice_transport_->setOnRecv(compId_, cb); +} + +uint16_t +IceSocket::getTransportOverhead() +{ + if (not ice_transport_) + return 0; + + return (ice_transport_->getRemoteAddress(compId_).getFamily() == AF_INET) ? IPV4_HEADER_SIZE + : IPV6_HEADER_SIZE; +} + +void +IceSocket::setDefaultRemoteAddress(const IpAddr& addr) +{ + if (ice_transport_) + ice_transport_->setDefaultRemoteAddress(compId_, addr); +} + +} // namespace jami diff --git a/src/ice_transport.h b/src/ice_transport.h new file mode 100644 index 0000000..0bf6432 --- /dev/null +++ b/src/ice_transport.h @@ -0,0 +1,219 @@ +/* + * Copyright (C) 2004-2023 Savoir-faire Linux Inc. + * + * Author: Guillaume Roguez <guillaume.roguez@savoirfairelinux.com> + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ + +#pragma once + +#include "ice_options.h" +#include "ice_socket.h" +#include "ip_utils.h" + +#include <pjnath.h> +#include <pjlib.h> +#include <pjlib-util.h> + +#include <functional> +#include <memory> +#include <msgpack.hpp> +#include <vector> + +namespace dht { +namespace log { +class Logger; +} +} + +namespace jami { + +using Logger = dht::log::Logger; + +namespace upnp { +class Controller; +} + +class IceTransport; + +using IceRecvCb = std::function<ssize_t(unsigned char* buf, size_t len)>; +using IceCandidate = pj_ice_sess_cand; +using onShutdownCb = std::function<void(void)>; + +struct ICESDP +{ + std::vector<IceCandidate> rem_candidates; + std::string rem_ufrag; + std::string rem_pwd; +}; + +struct SDP +{ + std::string ufrag; + std::string pwd; + + std::vector<std::string> candidates; + MSGPACK_DEFINE(ufrag, pwd, candidates) +}; + +class IceTransport +{ +public: + using Attribute = struct + { + std::string ufrag; + std::string pwd; + }; + + /** + * Constructor + */ + IceTransport(std::string_view name); + ~IceTransport(); + + const std::shared_ptr<Logger>& logger() const; + + void initIceInstance(const IceTransportOptions& options); + + /** + * Get current state + */ + bool isInitiator() const; + + /** + * Start transport negotiation between local candidates and given remote + * to find the right candidate pair. + * This function doesn't block, the callback on_negodone_cb will be called + * with the negotiation result when operation is really done. + * Return false if negotiation cannot be started else true. + */ + bool startIce(const Attribute& rem_attrs, std::vector<IceCandidate>&& rem_candidates); + bool startIce(const SDP& sdp); + + /** + * Cancel operations + */ + void cancelOperations(); + + /** + * Returns true if ICE transport has been initialized + * [mutex protected] + */ + bool isInitialized() const; + + /** + * Returns true if ICE negotiation has been started + * [mutex protected] + */ + bool isStarted() const; + + /** + * Returns true if ICE negotiation has completed with success + * [mutex protected] + */ + bool isRunning() const; + + /** + * Returns true if ICE transport is in failure state + * [mutex protected] + */ + bool isFailed() const; + + IpAddr getLocalAddress(unsigned comp_id) const; + + IpAddr getRemoteAddress(unsigned comp_id) const; + + IpAddr getDefaultLocalAddress() const { return getLocalAddress(1); } + + /** + * Return ICE session attributes + */ + const Attribute getLocalAttributes() const; + + /** + * Return ICE session attributes + */ + std::vector<std::string> getLocalCandidates(unsigned comp_id) const; + + /** + * Return ICE session attributes + */ + std::vector<std::string> getLocalCandidates(unsigned streamIdx, unsigned compId) const; + + bool parseIceAttributeLine(unsigned streamIdx, + const std::string& line, + IceCandidate& cand) const; + + bool getCandidateFromSDP(const std::string& line, IceCandidate& cand) const; + + // I/O methods + + void setOnRecv(unsigned comp_id, IceRecvCb cb); + void setOnShutdown(onShutdownCb&& cb); + + ssize_t recv(unsigned comp_id, unsigned char* buf, size_t len, std::error_code& ec); + ssize_t recvfrom(unsigned comp_id, char* buf, size_t len, std::error_code& ec); + + ssize_t send(unsigned comp_id, const unsigned char* buf, size_t len); + + bool waitForInitialization(std::chrono::milliseconds timeout); + + int waitForNegotiation(std::chrono::milliseconds timeout); + + ssize_t waitForData(unsigned comp_id, std::chrono::milliseconds timeout, std::error_code& ec); + + unsigned getComponentCount() const; + + // Set session state + bool setSlaveSession(); + bool setInitiatorSession(); + + bool isTCPEnabled(); + + ICESDP parseIceCandidates(std::string_view sdp_msg); + + void setDefaultRemoteAddress(unsigned comp_id, const IpAddr& addr); + + std::string link() const; + +private: + class Impl; + std::unique_ptr<Impl> pimpl_; +}; + +class IceTransportFactory +{ +public: + IceTransportFactory(); + ~IceTransportFactory(); + + std::shared_ptr<IceTransport> createTransport(std::string_view name); + + std::unique_ptr<IceTransport> createUTransport(std::string_view name); + + /** + * PJSIP specifics + */ + pj_ice_strans_cfg getIceCfg() const { return ice_cfg_; } + pj_pool_factory* getPoolFactory() { return &cp_->factory; } + std::shared_ptr<pj_caching_pool> getPoolCaching() { return cp_; } + +private: + std::shared_ptr<pj_caching_pool> cp_; + pj_ice_strans_cfg ice_cfg_; +}; + +}; // namespace jami diff --git a/src/ip_utils.cpp b/src/ip_utils.cpp new file mode 100644 index 0000000..494dc9b --- /dev/null +++ b/src/ip_utils.cpp @@ -0,0 +1,501 @@ +/* + * Copyright (C) 2004-2023 Savoir-faire Linux Inc. + * + * Author: Adrien Béraud <adrien.beraud@savoirfairelinux.com> + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ + +#include "ip_utils.h" +#include "logger.h" + +#include "connectivity/sip_utils.h" + +#include <sys/types.h> +#include <unistd.h> +#include <limits.h> + +#ifdef _WIN32 +#define InetPtonA inet_pton +WINSOCK_API_LINKAGE INT WSAAPI InetPtonA(INT Family, LPCSTR pStringBuf, PVOID pAddr); +#else +#include <arpa/inet.h> +#include <arpa/nameser.h> +#include <resolv.h> +#include <netdb.h> +#include <netinet/ip.h> +#include <net/if.h> +#include <ifaddrs.h> +#include <sys/ioctl.h> +#endif + +#ifndef HOST_NAME_MAX +#ifdef MAX_COMPUTERNAME_LENGTH +#define HOST_NAME_MAX MAX_COMPUTERNAME_LENGTH +#else +// Max 255 chars as per RFC 1035 +#define HOST_NAME_MAX 255 +#endif +#endif + +namespace jami { + +std::string_view +sip_strerror(pj_status_t code) +{ + thread_local char err_msg[PJ_ERR_MSG_SIZE]; + return as_view(pj_strerror(code, err_msg, sizeof err_msg)); +} + + +std::string +ip_utils::getHostname() +{ + char hostname[HOST_NAME_MAX]; + if (gethostname(hostname, HOST_NAME_MAX)) + return {}; + return hostname; +} + +int +ip_utils::getHostName(char* out, size_t out_len) +{ + char tempstr[INET_ADDRSTRLEN]; + const char* p = NULL; +#ifdef _WIN32 + struct hostent* h = NULL; + struct sockaddr_in localAddr; + memset(&localAddr, 0, sizeof(localAddr)); + gethostname(out, out_len); + h = gethostbyname(out); + if (h != NULL) { + memcpy(&localAddr.sin_addr, h->h_addr_list[0], 4); + p = inet_ntop(AF_INET, &localAddr.sin_addr, tempstr, sizeof(tempstr)); + if (p) + strncpy(out, p, out_len); + else + return -1; + } else { + return -1; + } +#elif (defined(BSD) && BSD >= 199306) || defined(__FreeBSD_kernel__) + int retVal = 0; + struct ifaddrs* ifap; + struct ifaddrs* ifa; + if (getifaddrs(&ifap) != 0) + return -1; + // Cycle through available interfaces. + for (ifa = ifap; ifa != NULL; ifa = ifa->ifa_next) { + // Skip loopback, point-to-point and down interfaces. + // except don't skip down interfaces if we're trying to get + // a list of configurable interfaces. + if ((ifa->ifa_flags & IFF_LOOPBACK) || (!(ifa->ifa_flags & IFF_UP))) + continue; + if (ifa->ifa_addr->sa_family == AF_INET) { + if (((struct sockaddr_in*) (ifa->ifa_addr))->sin_addr.s_addr == htonl(INADDR_LOOPBACK)) { + // We don't want the loopback interface. Go to next one. + continue; + } + p = inet_ntop(AF_INET, + &((struct sockaddr_in*) (ifa->ifa_addr))->sin_addr, + tempstr, + sizeof(tempstr)); + if (p) + strncpy(out, p, out_len); + else + retVal = -1; + break; + } + } + freeifaddrs(ifap); + retVal = ifa ? 0 : -1; + return retVal; +#else + struct ifconf ifConf; + struct ifreq ifReq; + struct sockaddr_in localAddr; + char szBuffer[MAX_INTERFACE * sizeof(struct ifreq)]; + int nResult; + int localSock; + memset(&ifConf, 0, sizeof(ifConf)); + memset(&ifReq, 0, sizeof(ifReq)); + memset(szBuffer, 0, sizeof(szBuffer)); + memset(&localAddr, 0, sizeof(localAddr)); + // Create an unbound datagram socket to do the SIOCGIFADDR ioctl on. + localSock = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP); + if (localSock == INVALID_SOCKET) + return -1; + /* Get the interface configuration information... */ + ifConf.ifc_len = (int) sizeof szBuffer; + ifConf.ifc_ifcu.ifcu_buf = (caddr_t) szBuffer; + nResult = ioctl(localSock, SIOCGIFCONF, &ifConf); + if (nResult < 0) { + close(localSock); + return -1; + } + unsigned int i; + unsigned int j = 0; + // Cycle through the list of interfaces looking for IP addresses. + for (i = 0u; i < (unsigned int) ifConf.ifc_len && j < MIN_INTERFACE;) { + struct ifreq* pifReq = (struct ifreq*) ((caddr_t) ifConf.ifc_req + i); + i += sizeof *pifReq; + // See if this is the sort of interface we want to deal with. + memset(ifReq.ifr_name, 0, sizeof(ifReq.ifr_name)); + strncpy(ifReq.ifr_name, pifReq->ifr_name, sizeof(ifReq.ifr_name)); + ioctl(localSock, SIOCGIFFLAGS, &ifReq); + // Skip loopback, point-to-point and down interfaces. + // except don't skip down interfaces if we're trying to get + // a list of configurable interfaces. + if ((ifReq.ifr_flags & IFF_LOOPBACK) || (!(ifReq.ifr_flags & IFF_UP))) + continue; + if (pifReq->ifr_addr.sa_family == AF_INET) { + memcpy(&localAddr, &pifReq->ifr_addr, sizeof pifReq->ifr_addr); + if (localAddr.sin_addr.s_addr == htonl(INADDR_LOOPBACK)) { + // We don't want the loopback interface. Go to the next one. + continue; + } + } + j++; // Increment j if we found an address which is not loopback and is up. + } + close(localSock); + p = inet_ntop(AF_INET, &localAddr.sin_addr, tempstr, sizeof(tempstr)); + if (p) + strncpy(out, p, out_len); + else + return -1; +#endif + return 0; +} +std::string +ip_utils::getGateway(char* localHost, ip_utils::subnet_mask prefix) +{ + std::string_view localHostStr(localHost); + if (prefix == ip_utils::subnet_mask::prefix_32bit) + return std::string(localHostStr); + std::string defaultGw {}; + // Make a vector of each individual number in the ip address. + std::vector<std::string_view> tokens = split_string(localHostStr, '.'); + // Build a gateway address from the individual ip components. + for (unsigned i = 0; i <= (unsigned) prefix; i++) + defaultGw += tokens[i] + "."; + for (unsigned i = (unsigned) ip_utils::subnet_mask::prefix_32bit; + i > (unsigned) prefix + 1; + i--) + defaultGw += "0."; + defaultGw += "1"; + return defaultGw; +} + +IpAddr +ip_utils::getLocalGateway() +{ + char localHostBuf[INET_ADDRSTRLEN]; + if (ip_utils::getHostName(localHostBuf, INET_ADDRSTRLEN) < 0) { + JAMI_WARN("Couldn't find local host"); + return {}; + } else { + return IpAddr(ip_utils::getGateway(localHostBuf, ip_utils::subnet_mask::prefix_24bit)); + } +} + +std::vector<IpAddr> +ip_utils::getAddrList(std::string_view name, pj_uint16_t family) +{ + std::vector<IpAddr> ipList; + if (name.empty()) + return ipList; + if (IpAddr::isValid(name, family)) { + ipList.emplace_back(name); + return ipList; + } + + static constexpr unsigned MAX_ADDR_NUM = 128; + pj_addrinfo res[MAX_ADDR_NUM]; + unsigned addr_num = MAX_ADDR_NUM; + const pj_str_t pjname(sip_utils::CONST_PJ_STR(name)); + auto status = pj_getaddrinfo(family, &pjname, &addr_num, res); + if (status != PJ_SUCCESS) { + JAMI_ERR("Error resolving %.*s : %s", + (int) name.size(), + name.data(), + sip_utils::sip_strerror(status).c_str()); + return ipList; + } + + for (unsigned i = 0; i < addr_num; i++) { + bool found = false; + for (const auto& ip : ipList) + if (!pj_sockaddr_cmp(&ip, &res[i].ai_addr)) { + found = true; + break; + } + if (!found) + ipList.emplace_back(res[i].ai_addr); + } + + return ipList; +} + +bool +ip_utils::haveCommonAddr(const std::vector<IpAddr>& a, const std::vector<IpAddr>& b) +{ + for (const auto& i : a) { + for (const auto& j : b) { + if (i == j) + return true; + } + } + return false; +} + +IpAddr +ip_utils::getLocalAddr(pj_uint16_t family) +{ + IpAddr ip_addr {}; + pj_status_t status = pj_gethostip(family, ip_addr.pjPtr()); + if (status == PJ_SUCCESS) { + return ip_addr; + } + JAMI_WARN("Could not get preferred address familly (%s)", + (family == pj_AF_INET6()) ? "IPv6" : "IPv4"); + family = (family == pj_AF_INET()) ? pj_AF_INET6() : pj_AF_INET(); + status = pj_gethostip(family, ip_addr.pjPtr()); + if (status == PJ_SUCCESS) { + return ip_addr; + } + JAMI_ERR("Could not get local IP"); + return ip_addr; +} + +IpAddr +ip_utils::getInterfaceAddr(const std::string& interface, pj_uint16_t family) +{ + if (interface == DEFAULT_INTERFACE) + return getLocalAddr(family); + + IpAddr addr {}; + +#ifndef _WIN32 + const auto unix_family = family == pj_AF_INET() ? AF_INET : AF_INET6; + + int fd = socket(unix_family, SOCK_DGRAM, 0); + if (fd < 0) { + JAMI_ERR("Could not open socket: %m"); + return addr; + } + + if (unix_family == AF_INET6) { + int val = family != pj_AF_UNSPEC(); + if (setsockopt(fd, IPPROTO_IPV6, IPV6_V6ONLY, (void*) &val, sizeof(val)) < 0) { + JAMI_ERR("Could not setsockopt: %m"); + close(fd); + return addr; + } + } + + ifreq ifr; + strncpy(ifr.ifr_name, interface.c_str(), sizeof ifr.ifr_name); + // guarantee that ifr_name is NULL-terminated + ifr.ifr_name[sizeof(ifr.ifr_name) - 1] = '\0'; + + memset(&ifr.ifr_addr, 0, sizeof(ifr.ifr_addr)); + ifr.ifr_addr.sa_family = unix_family; + + ioctl(fd, SIOCGIFADDR, &ifr); + close(fd); + + addr = ifr.ifr_addr; + if (addr.isUnspecified()) + return getLocalAddr(addr.getFamily()); +#else // _WIN32 + struct addrinfo hints; + struct addrinfo* result = NULL; + struct sockaddr_in* sockaddr_ipv4; + struct sockaddr_in6* sockaddr_ipv6; + + ZeroMemory(&hints, sizeof(hints)); + + DWORD dwRetval = getaddrinfo(interface.c_str(), "0", &hints, &result); + if (dwRetval != 0) { + JAMI_ERR("getaddrinfo failed with error: %lu", dwRetval); + return addr; + } + + switch (result->ai_family) { + sockaddr_ipv4 = (struct sockaddr_in*) result->ai_addr; + addr = sockaddr_ipv4->sin_addr; + break; + case AF_INET6: + sockaddr_ipv6 = (struct sockaddr_in6*) result->ai_addr; + addr = sockaddr_ipv6->sin6_addr; + break; + default: + break; + } + + if (addr.isUnspecified()) + return getLocalAddr(addr.getFamily()); +#endif // !_WIN32 + + return addr; +} + +std::vector<std::string> +ip_utils::getAllIpInterfaceByName() +{ + std::vector<std::string> ifaceList; + ifaceList.push_back("default"); +#ifndef _WIN32 + static ifreq ifreqs[20]; + ifconf ifconf; + + ifconf.ifc_buf = (char*) (ifreqs); + ifconf.ifc_len = sizeof(ifreqs); + + int sock = socket(AF_INET6, SOCK_STREAM, 0); + + if (sock >= 0) { + if (ioctl(sock, SIOCGIFCONF, &ifconf) >= 0) + for (unsigned i = 0; i < ifconf.ifc_len / sizeof(ifreq); ++i) + ifaceList.push_back(std::string(ifreqs[i].ifr_name)); + + close(sock); + } + +#else + JAMI_ERR("Not implemented yet. (iphlpapi.h problem)"); +#endif + return ifaceList; +} + +std::vector<std::string> +ip_utils::getAllIpInterface() +{ + pj_sockaddr addrList[16]; + unsigned addrCnt = PJ_ARRAY_SIZE(addrList); + + std::vector<std::string> ifaceList; + + if (pj_enum_ip_interface(pj_AF_UNSPEC(), &addrCnt, addrList) == PJ_SUCCESS) { + for (unsigned i = 0; i < addrCnt; i++) { + char addr[PJ_INET6_ADDRSTRLEN]; + pj_sockaddr_print(&addrList[i], addr, sizeof(addr), 0); + ifaceList.push_back(std::string(addr)); + } + } + + return ifaceList; +} + +std::vector<IpAddr> +ip_utils::getLocalNameservers() +{ + std::vector<IpAddr> res; +#if defined __ANDROID__ || defined _WIN32 || TARGET_OS_IPHONE +#ifdef _MSC_VER +#pragma message(__FILE__ "(" STR2(__LINE__) ") : -NOTE- " \ + "Not implemented") +#else +#warning "Not implemented" +#endif +#else + if (not(_res.options & RES_INIT)) + res_init(); + res.insert(res.end(), _res.nsaddr_list, _res.nsaddr_list + _res.nscount); +#endif + return res; +} + +bool +IpAddr::isValid(std::string_view address, pj_uint16_t family) +{ + const pj_str_t pjstring(sip_utils::CONST_PJ_STR(address)); + pj_str_t ret_str; + pj_uint16_t ret_port; + int ret_family; + auto status = pj_sockaddr_parse2(pj_AF_UNSPEC(), 0, &pjstring, &ret_str, &ret_port, &ret_family); + if (status != PJ_SUCCESS || (family != pj_AF_UNSPEC() && ret_family != family)) + return false; + + char buf[PJ_INET6_ADDRSTRLEN]; + pj_str_t addr_with_null = {buf, 0}; + pj_strncpy_with_null(&addr_with_null, &ret_str, sizeof(buf)); + struct sockaddr sa; + return inet_pton(ret_family == pj_AF_INET6() ? AF_INET6 : AF_INET, buf, &(sa.sa_data)) == 1; +} + +bool +IpAddr::isUnspecified() const +{ + switch (addr.addr.sa_family) { + case AF_INET: + return IN_IS_ADDR_UNSPECIFIED(&addr.ipv4.sin_addr); + case AF_INET6: + return IN6_IS_ADDR_UNSPECIFIED(reinterpret_cast<const in6_addr*>(&addr.ipv6.sin6_addr)); + default: + return true; + } +} + +bool +IpAddr::isLoopback() const +{ + switch (addr.addr.sa_family) { + case AF_INET: { + auto addr_host = ntohl(addr.ipv4.sin_addr.s_addr); + uint8_t b1 = (uint8_t)(addr_host >> 24); + return b1 == 127; + } + case AF_INET6: + return IN6_IS_ADDR_LOOPBACK(reinterpret_cast<const in6_addr*>(&addr.ipv6.sin6_addr)); + default: + return false; + } +} + +bool +IpAddr::isPrivate() const +{ + if (isLoopback()) { + return true; + } + switch (addr.addr.sa_family) { + case AF_INET: { + auto addr_host = ntohl(addr.ipv4.sin_addr.s_addr); + uint8_t b1, b2; + b1 = (uint8_t)(addr_host >> 24); + b2 = (uint8_t)((addr_host >> 16) & 0x0ff); + // 10.x.y.z + if (b1 == 10) + return true; + // 172.16.0.0 - 172.31.255.255 + if ((b1 == 172) && (b2 >= 16) && (b2 <= 31)) + return true; + // 192.168.0.0 - 192.168.255.255 + if ((b1 == 192) && (b2 == 168)) + return true; + return false; + } + case AF_INET6: { + const pj_uint8_t* addr6 = reinterpret_cast<const pj_uint8_t*>(&addr.ipv6.sin6_addr); + if (addr6[0] == 0xfc) + return true; + return false; + } + default: + return false; + } +} +} // namespace jami diff --git a/src/multiplexed_socket.cpp b/src/multiplexed_socket.cpp new file mode 100644 index 0000000..5abf742 --- /dev/null +++ b/src/multiplexed_socket.cpp @@ -0,0 +1,1208 @@ +/* + * Copyright (C) 2019-2023 Savoir-faire Linux Inc. + * Author: Sébastien Blin <sebastien.blin@savoirfairelinux.com> + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see <https://www.gnu.org/licenses/>. + */ + +#include "multiplexed_socket.h" +#include "peer_connection.h" +#include "ice_transport.h" +#include "certstore.h" + +#include <opendht/logger.h> +#include <opendht/thread_pool.h> + +#include <asio/io_context.hpp> +#include <asio/steady_timer.hpp> + +#include <deque> + +static constexpr std::size_t IO_BUFFER_SIZE {8192}; ///< Size of char buffer used by IO operations +static constexpr int MULTIPLEXED_SOCKET_VERSION {1}; + +struct ChanneledMessage +{ + uint16_t channel; + std::vector<uint8_t> data; + MSGPACK_DEFINE(channel, data) +}; + +struct BeaconMsg +{ + bool p; + MSGPACK_DEFINE_MAP(p) +}; + +struct VersionMsg +{ + int v; + MSGPACK_DEFINE_MAP(v) +}; + +namespace jami { + +using clock = std::chrono::steady_clock; +using time_point = clock::time_point; + +class MultiplexedSocket::Impl +{ +public: + Impl(MultiplexedSocket& parent, + std::shared_ptr<asio::io_context> ctx, + const DeviceId& deviceId, + std::unique_ptr<TlsSocketEndpoint> endpoint) + : parent_(parent) + , deviceId(deviceId) + , ctx_(std::move(ctx)) + , beaconTimer_(*ctx_) + , endpoint(std::move(endpoint)) + , eventLoopThread_ {[this] { + try { + eventLoop(); + } catch (const std::exception& e) { + if (logger_) + logger_->error("[CNX] peer connection event loop failure: {}", e.what()); + shutdown(); + } + }} + {} + + ~Impl() {} + + void join() + { + if (!isShutdown_) { + if (endpoint) + endpoint->setOnStateChange({}); + shutdown(); + } else { + clearSockets(); + } + if (eventLoopThread_.joinable()) + eventLoopThread_.join(); + } + + void clearSockets() + { + decltype(sockets) socks; + { + std::lock_guard<std::mutex> lkSockets(socketsMutex); + socks = std::move(sockets); + } + for (auto& socket : socks) { + // Just trigger onShutdown() to make client know + // No need to write the EOF for the channel, the write will fail because endpoint is + // already shutdown + if (socket.second) + socket.second->stop(); + } + } + + void shutdown() + { + if (isShutdown_) + return; + stop.store(true); + isShutdown_ = true; + beaconTimer_.cancel(); + if (onShutdown_) + onShutdown_(); + if (endpoint) { + std::unique_lock<std::mutex> lk(writeMtx); + endpoint->shutdown(); + } + clearSockets(); + } + + std::shared_ptr<ChannelSocket> makeSocket(const std::string& name, + uint16_t channel, + bool isInitiator = false) + { + auto& channelSocket = sockets[channel]; + if (not channelSocket) + channelSocket = std::make_shared<ChannelSocket>( + parent_.weak(), name, channel, isInitiator, [w = parent_.weak(), channel]() { + // Remove socket in another thread to avoid any lock + dht::ThreadPool::io().run([w, channel]() { + if (auto shared = w.lock()) { + shared->eraseChannel(channel); + } + }); + }); + else { + if (logger_) + logger_->warn("A channel is already present on that socket, accepting " + "the request will close the previous one {}", name); + } + return channelSocket; + } + + /** + * Handle packets on the TLS endpoint and parse RTP + */ + void eventLoop(); + /** + * Triggered when a new control packet is received + */ + void handleControlPacket(std::vector<uint8_t>&& pkt); + void handleProtocolPacket(std::vector<uint8_t>&& pkt); + bool handleProtocolMsg(const msgpack::object& o); + /** + * Triggered when a new packet on a channel is received + */ + void handleChannelPacket(uint16_t channel, std::vector<uint8_t>&& pkt); + void onRequest(const std::string& name, uint16_t channel); + void onAccept(const std::string& name, uint16_t channel); + + void setOnReady(OnConnectionReadyCb&& cb) { onChannelReady_ = std::move(cb); } + void setOnRequest(OnConnectionRequestCb&& cb) { onRequest_ = std::move(cb); } + + // Beacon + void sendBeacon(const std::chrono::milliseconds& timeout); + void handleBeaconRequest(); + void handleBeaconResponse(); + std::atomic_int beaconCounter_ {0}; + + bool writeProtocolMessage(const msgpack::sbuffer& buffer); + + msgpack::unpacker pac_ {}; + + MultiplexedSocket& parent_; + + std::shared_ptr<Logger> logger_; + std::shared_ptr<asio::io_context> ctx_; + + OnConnectionReadyCb onChannelReady_ {}; + OnConnectionRequestCb onRequest_ {}; + OnShutdownCb onShutdown_ {}; + + DeviceId deviceId {}; + // Main socket + std::unique_ptr<TlsSocketEndpoint> endpoint {}; + + std::mutex socketsMutex {}; + std::map<uint16_t, std::shared_ptr<ChannelSocket>> sockets {}; + + // Main loop to parse incoming packets + std::atomic_bool stop {false}; + std::thread eventLoopThread_ {}; + + std::atomic_bool isShutdown_ {false}; + + std::mutex writeMtx {}; + + time_point start_ {clock::now()}; + //std::shared_ptr<Task> beaconTask_ {}; + asio::steady_timer beaconTimer_; + + // version related stuff + void sendVersion(); + void onVersion(int version); + std::atomic_bool canSendBeacon_ {false}; + std::atomic_bool answerBeacon_ {true}; + int version_ {MULTIPLEXED_SOCKET_VERSION}; + std::function<void(bool)> onBeaconCb_ {}; + std::function<void(int)> onVersionCb_ {}; +}; + +void +MultiplexedSocket::Impl::eventLoop() +{ + endpoint->setOnStateChange([this](tls::TlsSessionState state) { + if (state == tls::TlsSessionState::SHUTDOWN && !isShutdown_) { + if (logger_) + logger_->debug("Tls endpoint is down, shutdown multiplexed socket"); + shutdown(); + return false; + } + return true; + }); + sendVersion(); + std::error_code ec; + while (!stop) { + if (!endpoint) { + shutdown(); + return; + } + pac_.reserve_buffer(IO_BUFFER_SIZE); + int size = endpoint->read(reinterpret_cast<uint8_t*>(&pac_.buffer()[0]), IO_BUFFER_SIZE, ec); + if (size < 0) { + if (ec && logger_) + logger_->error("Read error detected: {}", ec.message()); + break; + } + if (size == 0) { + // We can close the socket + shutdown(); + break; + } + + pac_.buffer_consumed(size); + msgpack::object_handle oh; + while (pac_.next(oh) && !stop) { + try { + auto msg = oh.get().as<ChanneledMessage>(); + if (msg.channel == CONTROL_CHANNEL) + handleControlPacket(std::move(msg.data)); + else if (msg.channel == PROTOCOL_CHANNEL) + handleProtocolPacket(std::move(msg.data)); + else + handleChannelPacket(msg.channel, std::move(msg.data)); + } catch (const std::exception& e) { + if (logger_) + logger_->warn("Failed to unpacked message of {:d} bytes: {:s}", size, e.what()); + } catch (...) { + if (logger_) + logger_->error("Unknown exception catched while unpacking message of {:d} bytes", size); + } + } + } +} + +void +MultiplexedSocket::Impl::onAccept(const std::string& name, uint16_t channel) +{ + std::lock_guard<std::mutex> lkSockets(socketsMutex); + auto& socket = sockets[channel]; + if (!socket) { + if (logger_) + logger_->error("Receiving an answer for a non existing channel. This is a bug."); + return; + } + + onChannelReady_(deviceId, socket); + socket->ready(); + // Due to the callbacks that can take some time, onAccept can arrive after + // receiving all the data. In this case, the socket should be removed here + // as handle by onChannelReady_ + if (socket->isRemovable()) + sockets.erase(channel); + else + socket->answered(); +} + +void +MultiplexedSocket::Impl::sendBeacon(const std::chrono::milliseconds& timeout) +{ + if (!canSendBeacon_) + return; + beaconCounter_++; + if (logger_) + logger_->debug("Send beacon to peer {}", deviceId); + + msgpack::sbuffer buffer(8); + msgpack::packer<msgpack::sbuffer> pk(&buffer); + pk.pack(BeaconMsg {true}); + if (!writeProtocolMessage(buffer)) + return; + beaconTimer_.expires_after(timeout); + beaconTimer_.async_wait([w = parent_.weak()](const asio::error_code& ec) { + if (ec == asio::error::operation_aborted) + return; + if (auto shared = w.lock()) { + if (shared->pimpl_->beaconCounter_ != 0) { + if (shared->pimpl_->logger_) + shared->pimpl_->logger_->error("Beacon doesn't get any response. Stopping socket"); + shared->shutdown(); + } + } + }); +} + +void +MultiplexedSocket::Impl::handleBeaconRequest() +{ + if (!answerBeacon_) + return; + // Run this on dedicated thread because some callbacks can take time + dht::ThreadPool::io().run([w = parent_.weak()]() { + if (auto shared = w.lock()) { + msgpack::sbuffer buffer(8); + msgpack::packer<msgpack::sbuffer> pk(&buffer); + pk.pack(BeaconMsg {false}); + if (shared->pimpl_->logger_) + shared->pimpl_->logger_->debug("Send beacon response to peer {}", shared->deviceId()); + shared->pimpl_->writeProtocolMessage(buffer); + } + }); +} + +void +MultiplexedSocket::Impl::handleBeaconResponse() +{ + if (logger_) + logger_->debug("Get beacon response from peer {}", deviceId); + beaconCounter_--; +} + +bool +MultiplexedSocket::Impl::writeProtocolMessage(const msgpack::sbuffer& buffer) +{ + std::error_code ec; + int wr = parent_.write(PROTOCOL_CHANNEL, + (const unsigned char*) buffer.data(), + buffer.size(), + ec); + return wr > 0; +} + +void +MultiplexedSocket::Impl::sendVersion() +{ + dht::ThreadPool::io().run([w = parent_.weak()]() { + if (auto shared = w.lock()) { + auto version = shared->pimpl_->version_; + msgpack::sbuffer buffer(8); + msgpack::packer<msgpack::sbuffer> pk(&buffer); + pk.pack(VersionMsg {version}); + shared->pimpl_->writeProtocolMessage(buffer); + } + }); +} + +void +MultiplexedSocket::Impl::onVersion(int version) +{ + // Check if version > 1 + if (version >= 1) { + if (logger_) + logger_->debug("Peer {} supports beacon", deviceId); + canSendBeacon_ = true; + } else { + if (logger_) + logger_->warn("Peer {} uses version {:d} which doesn't support beacon", + deviceId, + version); + canSendBeacon_ = false; + } +} + +void +MultiplexedSocket::Impl::onRequest(const std::string& name, uint16_t channel) +{ + auto accept = onRequest_(endpoint->peerCertificate(), channel, name); + std::shared_ptr<ChannelSocket> channelSocket; + if (accept) { + std::lock_guard<std::mutex> lkSockets(socketsMutex); + channelSocket = makeSocket(name, channel); + } + + // Answer to ChannelRequest if accepted + ChannelRequest val; + val.channel = channel; + val.name = name; + val.state = accept ? ChannelRequestState::ACCEPT : ChannelRequestState::DECLINE; + msgpack::sbuffer buffer(512); + msgpack::pack(buffer, val); + std::error_code ec; + int wr = parent_.write(CONTROL_CHANNEL, + reinterpret_cast<const uint8_t*>(buffer.data()), + buffer.size(), + ec); + if (wr < 0) { + if (ec && logger_) + logger_->error("The write operation failed with error: {:s}", ec.message()); + stop.store(true); + return; + } + + if (accept) { + onChannelReady_(deviceId, channelSocket); + channelSocket->ready(); + } +} + +void +MultiplexedSocket::Impl::handleControlPacket(std::vector<uint8_t>&& pkt) +{ + // Run this on dedicated thread because some callbacks can take time + dht::ThreadPool::io().run([w = parent_.weak(), pkt = std::move(pkt)]() { + auto shared = w.lock(); + if (!shared) + return; + auto& pimpl = *shared->pimpl_; + try { + size_t off = 0; + while (off != pkt.size()) { + msgpack::unpacked result; + msgpack::unpack(result, (const char*) pkt.data(), pkt.size(), off); + auto object = result.get(); + if (pimpl.handleProtocolMsg(object)) + continue; + auto req = object.as<ChannelRequest>(); + if (req.state == ChannelRequestState::ACCEPT) { + pimpl.onAccept(req.name, req.channel); + } else if (req.state == ChannelRequestState::DECLINE) { + std::lock_guard<std::mutex> lkSockets(pimpl.socketsMutex); + auto channel = pimpl.sockets.find(req.channel); + if (channel != pimpl.sockets.end()) { + channel->second->stop(); + pimpl.sockets.erase(channel); + } + } else if (pimpl.onRequest_) { + pimpl.onRequest(req.name, req.channel); + } + } + } catch (const std::exception& e) { + if (pimpl.logger_) + pimpl.logger_->error("Error on the control channel: {}", e.what()); + } + }); +} + +void +MultiplexedSocket::Impl::handleChannelPacket(uint16_t channel, std::vector<uint8_t>&& pkt) +{ + std::lock_guard<std::mutex> lkSockets(socketsMutex); + auto sockIt = sockets.find(channel); + if (channel > 0 && sockIt != sockets.end() && sockIt->second) { + if (pkt.size() == 0) { + sockIt->second->stop(); + if (sockIt->second->isAnswered()) + sockets.erase(sockIt); + else + sockIt->second->removable(); // This means that onAccept didn't happen yet, will be + // removed later. + } else { + sockIt->second->onRecv(std::move(pkt)); + } + } else if (pkt.size() != 0) { + if (logger_) + logger_->warn("Non existing channel: {}", channel); + } +} + +bool +MultiplexedSocket::Impl::handleProtocolMsg(const msgpack::object& o) +{ + try { + if (o.type == msgpack::type::MAP && o.via.map.size > 0) { + auto key = o.via.map.ptr[0].key.as<std::string_view>(); + if (key == "p") { + auto msg = o.as<BeaconMsg>(); + if (msg.p) + handleBeaconRequest(); + else + handleBeaconResponse(); + if (onBeaconCb_) + onBeaconCb_(msg.p); + return true; + } else if (key == "v") { + auto msg = o.as<VersionMsg>(); + onVersion(msg.v); + if (onVersionCb_) + onVersionCb_(msg.v); + return true; + } else { + if (logger_) + logger_->warn("Unknown message type"); + } + } + } catch (const std::exception& e) { + if (logger_) + logger_->error("Error on the protocol channel: {}", e.what()); + } + return false; +} + +void +MultiplexedSocket::Impl::handleProtocolPacket(std::vector<uint8_t>&& pkt) +{ + // Run this on dedicated thread because some callbacks can take time + dht::ThreadPool::io().run([w = parent_.weak(), pkt = std::move(pkt)]() { + auto shared = w.lock(); + if (!shared) + return; + try { + size_t off = 0; + while (off != pkt.size()) { + msgpack::unpacked result; + msgpack::unpack(result, (const char*) pkt.data(), pkt.size(), off); + auto object = result.get(); + if (shared->pimpl_->handleProtocolMsg(object)) + return; + } + } catch (const std::exception& e) { + if (shared->pimpl_->logger_) + shared->pimpl_->logger_->error("Error on the protocol channel: {}", e.what()); + } + }); +} + +MultiplexedSocket::MultiplexedSocket(std::shared_ptr<asio::io_context> ctx, const DeviceId& deviceId, + std::unique_ptr<TlsSocketEndpoint> endpoint) + : pimpl_(std::make_unique<Impl>(*this, ctx, deviceId, std::move(endpoint))) +{} + +MultiplexedSocket::~MultiplexedSocket() {} + +std::shared_ptr<ChannelSocket> +MultiplexedSocket::addChannel(const std::string& name) +{ + // Note: because both sides can request the same channel number at the same time + // it's better to use a random channel number instead of just incrementing the request. + thread_local dht::crypto::random_device rd; + std::uniform_int_distribution<uint16_t> dist; + auto offset = dist(rd); + std::lock_guard<std::mutex> lk(pimpl_->socketsMutex); + for (int i = 1; i < UINT16_MAX; ++i) { + auto c = (offset + i) % UINT16_MAX; + if (c == CONTROL_CHANNEL || c == PROTOCOL_CHANNEL + || pimpl_->sockets.find(c) != pimpl_->sockets.end()) + continue; + auto channel = pimpl_->makeSocket(name, c, true); + return channel; + } + return {}; +} + +DeviceId +MultiplexedSocket::deviceId() const +{ + return pimpl_->deviceId; +} + +void +MultiplexedSocket::setOnReady(OnConnectionReadyCb&& cb) +{ + pimpl_->onChannelReady_ = std::move(cb); +} + +void +MultiplexedSocket::setOnRequest(OnConnectionRequestCb&& cb) +{ + pimpl_->onRequest_ = std::move(cb); +} + +bool +MultiplexedSocket::isReliable() const +{ + return true; +} + +bool +MultiplexedSocket::isInitiator() const +{ + if (!pimpl_->endpoint) { + if (pimpl_->logger_) + pimpl_->logger_->warn("No endpoint found for socket"); + return false; + } + return pimpl_->endpoint->isInitiator(); +} + +int +MultiplexedSocket::maxPayload() const +{ + if (!pimpl_->endpoint) { + if (pimpl_->logger_) + pimpl_->logger_->warn("No endpoint found for socket"); + return 0; + } + return pimpl_->endpoint->maxPayload(); +} + +std::size_t +MultiplexedSocket::write(const uint16_t& channel, + const uint8_t* buf, + std::size_t len, + std::error_code& ec) +{ + assert(nullptr != buf); + + if (pimpl_->isShutdown_) { + ec = std::make_error_code(std::errc::broken_pipe); + return -1; + } + if (len > UINT16_MAX) { + ec = std::make_error_code(std::errc::message_size); + return -1; + } + bool oneShot = len < 8192; + msgpack::sbuffer buffer(oneShot ? 16 + len : 16); + msgpack::packer<msgpack::sbuffer> pk(&buffer); + pk.pack_array(2); + pk.pack(channel); + pk.pack_bin(len); + if (oneShot) + pk.pack_bin_body((const char*) buf, len); + + std::unique_lock<std::mutex> lk(pimpl_->writeMtx); + if (!pimpl_->endpoint) { + if (pimpl_->logger_) + pimpl_->logger_->warn("No endpoint found for socket"); + ec = std::make_error_code(std::errc::broken_pipe); + return -1; + } + int res = pimpl_->endpoint->write((const unsigned char*) buffer.data(), buffer.size(), ec); + if (not oneShot and res >= 0) + res = pimpl_->endpoint->write(buf, len, ec); + lk.unlock(); + if (res < 0) { + if (ec && pimpl_->logger_) + pimpl_->logger_->error("Error when writing on socket: {:s}", ec.message()); + shutdown(); + } + return res; +} + +void +MultiplexedSocket::shutdown() +{ + pimpl_->shutdown(); +} + +void +MultiplexedSocket::join() +{ + pimpl_->join(); +} + +void +MultiplexedSocket::onShutdown(OnShutdownCb&& cb) +{ + pimpl_->onShutdown_ = std::move(cb); + if (pimpl_->isShutdown_) + pimpl_->onShutdown_(); +} + +const std::shared_ptr<Logger>& +MultiplexedSocket::logger() +{ + return pimpl_->logger_; +} + +void +MultiplexedSocket::monitor() const +{ + auto cert = peerCertificate(); + if (!cert || !cert->issuer) + return; + auto now = clock::now(); + if (!pimpl_->logger_) + return; + pimpl_->logger_->debug("- Socket with device: {:s} - account: {:s}", deviceId(), cert->issuer->getId()); + pimpl_->logger_->debug("- Duration: {}", dht::print_duration(now - pimpl_->start_)); + pimpl_->endpoint->monitor(); + std::lock_guard<std::mutex> lk(pimpl_->socketsMutex); + for (const auto& [_, channel] : pimpl_->sockets) { + if (channel) + pimpl_->logger_->debug("\t\t- Channel {} (count: {}) with name {:s} Initiator: {}", + fmt::ptr(channel.get()), + channel.use_count(), + channel->name(), + channel->isInitiator()); + } +} + +void +MultiplexedSocket::sendBeacon(const std::chrono::milliseconds& timeout) +{ + pimpl_->sendBeacon(timeout); +} + +std::shared_ptr<dht::crypto::Certificate> +MultiplexedSocket::peerCertificate() const +{ + return pimpl_->endpoint->peerCertificate(); +} + +#ifdef LIBJAMI_TESTABLE +bool +MultiplexedSocket::canSendBeacon() const +{ + return pimpl_->canSendBeacon_; +} + +void +MultiplexedSocket::answerToBeacon(bool value) +{ + pimpl_->answerBeacon_ = value; +} + +void +MultiplexedSocket::setVersion(int version) +{ + pimpl_->version_ = version; +} + +void +MultiplexedSocket::setOnBeaconCb(const std::function<void(bool)>& cb) +{ + pimpl_->onBeaconCb_ = cb; +} + +void +MultiplexedSocket::setOnVersionCb(const std::function<void(int)>& cb) +{ + pimpl_->onVersionCb_ = cb; +} + +void +MultiplexedSocket::sendVersion() +{ + pimpl_->sendVersion(); +} + +IpAddr +MultiplexedSocket::getLocalAddress() const +{ + return pimpl_->endpoint->getLocalAddress(); +} + +IpAddr +MultiplexedSocket::getRemoteAddress() const +{ + return pimpl_->endpoint->getRemoteAddress(); +} + +#endif + +void +MultiplexedSocket::eraseChannel(uint16_t channel) +{ + std::lock_guard<std::mutex> lkSockets(pimpl_->socketsMutex); + auto itSocket = pimpl_->sockets.find(channel); + if (pimpl_->sockets.find(channel) != pimpl_->sockets.end()) + pimpl_->sockets.erase(itSocket); +} + +//////////////////////////////////////////////////////////////// + +class ChannelSocket::Impl +{ +public: + Impl(std::weak_ptr<MultiplexedSocket> endpoint, + const std::string& name, + const uint16_t& channel, + bool isInitiator, + std::function<void()> rmFromMxSockCb) + : name(name) + , channel(channel) + , endpoint(std::move(endpoint)) + , isInitiator_(isInitiator) + , rmFromMxSockCb_(std::move(rmFromMxSockCb)) + {} + + ~Impl() {} + + ChannelReadyCb readyCb_ {}; + OnShutdownCb shutdownCb_ {}; + std::atomic_bool isShutdown_ {false}; + std::string name {}; + uint16_t channel {}; + std::weak_ptr<MultiplexedSocket> endpoint {}; + bool isInitiator_ {false}; + std::function<void()> rmFromMxSockCb_; + + bool isAnswered_ {false}; + bool isRemovable_ {false}; + + std::vector<uint8_t> buf {}; + std::mutex mutex {}; + std::condition_variable cv {}; + GenericSocket<uint8_t>::RecvCb cb {}; +}; + +ChannelSocketTest::ChannelSocketTest(std::shared_ptr<asio::io_context> ctx, + const DeviceId& deviceId, + const std::string& name, + const uint16_t& channel) + : pimpl_deviceId(deviceId) + , pimpl_name(name) + , pimpl_channel(channel) + , ioCtx_(*ctx) +{} + +ChannelSocketTest::~ChannelSocketTest() {} + +void +ChannelSocketTest::link(const std::shared_ptr<ChannelSocketTest>& socket1, + const std::shared_ptr<ChannelSocketTest>& socket2) +{ + socket1->remote = socket2; + socket2->remote = socket1; +} + +DeviceId +ChannelSocketTest::deviceId() const +{ + return pimpl_deviceId; +} + +std::string +ChannelSocketTest::name() const +{ + return pimpl_name; +} + +uint16_t +ChannelSocketTest::channel() const +{ + return pimpl_channel; +} + +void +ChannelSocketTest::shutdown() +{ + { + std::unique_lock<std::mutex> lk {mutex}; + if (!isShutdown_.exchange(true)) { + lk.unlock(); + shutdownCb_(); + } + cv.notify_all(); + } + + if (auto peer = remote.lock()) { + if (!peer->isShutdown_.exchange(true)) { + peer->shutdownCb_(); + } + peer->cv.notify_all(); + } +} + +std::size_t +ChannelSocketTest::read(ValueType* buf, std::size_t len, std::error_code& ec) +{ + std::size_t size = std::min(len, this->rx_buf.size()); + + for (std::size_t i = 0; i < size; ++i) + buf[i] = this->rx_buf[i]; + + if (size == this->rx_buf.size()) { + this->rx_buf.clear(); + } else + this->rx_buf.erase(this->rx_buf.begin(), this->rx_buf.begin() + size); + return size; +} + +std::size_t +ChannelSocketTest::write(const ValueType* buf, std::size_t len, std::error_code& ec) +{ + if (isShutdown_) { + ec = std::make_error_code(std::errc::broken_pipe); + return -1; + } + ec = {}; + dht::ThreadPool::computation().run( + [r = remote, data = std::vector<uint8_t>(buf, buf + len)]() mutable { + if (auto peer = r.lock()) + peer->onRecv(std::move(data)); + }); + return len; +} + +int +ChannelSocketTest::waitForData(std::chrono::milliseconds timeout, std::error_code& ec) const +{ + std::unique_lock<std::mutex> lk {mutex}; + cv.wait_for(lk, timeout, [&] { return !rx_buf.empty() or isShutdown_; }); + return rx_buf.size(); +} + +void +ChannelSocketTest::setOnRecv(RecvCb&& cb) +{ + std::lock_guard<std::mutex> lkSockets(mutex); + this->cb = std::move(cb); + if (!rx_buf.empty() && this->cb) { + this->cb(rx_buf.data(), rx_buf.size()); + rx_buf.clear(); + } +} + +void +ChannelSocketTest::onRecv(std::vector<uint8_t>&& pkt) +{ + std::lock_guard<std::mutex> lkSockets(mutex); + if (cb) { + cb(pkt.data(), pkt.size()); + return; + } + rx_buf.insert(rx_buf.end(), + std::make_move_iterator(pkt.begin()), + std::make_move_iterator(pkt.end())); + cv.notify_all(); +} + +void +ChannelSocketTest::onReady(ChannelReadyCb&& cb) +{} + +void +ChannelSocketTest::onShutdown(OnShutdownCb&& cb) +{ + std::unique_lock<std::mutex> lk {mutex}; + shutdownCb_ = std::move(cb); + + if (isShutdown_) { + lk.unlock(); + shutdownCb_(); + } +} + +ChannelSocket::ChannelSocket(std::weak_ptr<MultiplexedSocket> endpoint, + const std::string& name, + const uint16_t& channel, + bool isInitiator, + std::function<void()> rmFromMxSockCb) + : pimpl_ { + std::make_unique<Impl>(endpoint, name, channel, isInitiator, std::move(rmFromMxSockCb))} +{} + +ChannelSocket::~ChannelSocket() {} + +DeviceId +ChannelSocket::deviceId() const +{ + if (auto ep = pimpl_->endpoint.lock()) { + return ep->deviceId(); + } + return {}; +} + +std::string +ChannelSocket::name() const +{ + return pimpl_->name; +} + +uint16_t +ChannelSocket::channel() const +{ + return pimpl_->channel; +} + +bool +ChannelSocket::isReliable() const +{ + if (auto ep = pimpl_->endpoint.lock()) { + return ep->isReliable(); + } + return false; +} + +bool +ChannelSocket::isInitiator() const +{ + // Note. Is initiator here as not the same meaning of MultiplexedSocket. + // because a multiplexed socket can have sockets from accepted requests + // or made via connectDevice(). Here, isInitiator_ return if the socket + // is from connectDevice. + return pimpl_->isInitiator_; +} + +int +ChannelSocket::maxPayload() const +{ + if (auto ep = pimpl_->endpoint.lock()) { + return ep->maxPayload(); + } + return -1; +} + +void +ChannelSocket::setOnRecv(RecvCb&& cb) +{ + std::lock_guard<std::mutex> lkSockets(pimpl_->mutex); + pimpl_->cb = std::move(cb); + if (!pimpl_->buf.empty() && pimpl_->cb) { + pimpl_->cb(pimpl_->buf.data(), pimpl_->buf.size()); + pimpl_->buf.clear(); + } +} + +void +ChannelSocket::onRecv(std::vector<uint8_t>&& pkt) +{ + std::lock_guard<std::mutex> lkSockets(pimpl_->mutex); + if (pimpl_->cb) { + pimpl_->cb(&pkt[0], pkt.size()); + return; + } + pimpl_->buf.insert(pimpl_->buf.end(), + std::make_move_iterator(pkt.begin()), + std::make_move_iterator(pkt.end())); + pimpl_->cv.notify_all(); +} + +#ifdef LIBJAMI_TESTABLE +std::shared_ptr<MultiplexedSocket> +ChannelSocket::underlyingSocket() const +{ + if (auto mtx = pimpl_->endpoint.lock()) + return mtx; + return {}; +} +#endif + +void +ChannelSocket::answered() +{ + pimpl_->isAnswered_ = true; +} + +void +ChannelSocket::removable() +{ + pimpl_->isRemovable_ = true; +} + +bool +ChannelSocket::isRemovable() const +{ + return pimpl_->isRemovable_; +} + +bool +ChannelSocket::isAnswered() const +{ + return pimpl_->isAnswered_; +} + +void +ChannelSocket::ready() +{ + if (pimpl_->readyCb_) + pimpl_->readyCb_(); +} + +void +ChannelSocket::stop() +{ + if (pimpl_->isShutdown_) + return; + pimpl_->isShutdown_ = true; + if (pimpl_->shutdownCb_) + pimpl_->shutdownCb_(); + pimpl_->cv.notify_all(); + // stop() can be called by ChannelSocket::shutdown() + // In this case, the eventLoop is not used, but MxSock + // must remove the channel from its list (so that the + // channel can be destroyed and its shared_ptr invalidated). + if (pimpl_->rmFromMxSockCb_) + pimpl_->rmFromMxSockCb_(); +} + +void +ChannelSocket::shutdown() +{ + if (pimpl_->isShutdown_) + return; + stop(); + if (auto ep = pimpl_->endpoint.lock()) { + std::error_code ec; + const uint8_t dummy = '\0'; + ep->write(pimpl_->channel, &dummy, 0, ec); + } +} + +std::size_t +ChannelSocket::read(ValueType* outBuf, std::size_t len, std::error_code& ec) +{ + std::lock_guard<std::mutex> lkSockets(pimpl_->mutex); + std::size_t size = std::min(len, pimpl_->buf.size()); + + for (std::size_t i = 0; i < size; ++i) + outBuf[i] = pimpl_->buf[i]; + + pimpl_->buf.erase(pimpl_->buf.begin(), pimpl_->buf.begin() + size); + return size; +} + +std::size_t +ChannelSocket::write(const ValueType* buf, std::size_t len, std::error_code& ec) +{ + if (pimpl_->isShutdown_) { + ec = std::make_error_code(std::errc::broken_pipe); + return -1; + } + if (auto ep = pimpl_->endpoint.lock()) { + std::size_t sent = 0; + do { + std::size_t toSend = std::min(static_cast<std::size_t>(UINT16_MAX), len - sent); + auto res = ep->write(pimpl_->channel, buf + sent, toSend, ec); + if (ec) { + if (ep->logger()) + ep->logger()->error("Error when writing on channel: {}", ec.message()); + return res; + } + sent += toSend; + } while (sent < len); + return sent; + } + ec = std::make_error_code(std::errc::broken_pipe); + return -1; +} + +int +ChannelSocket::waitForData(std::chrono::milliseconds timeout, std::error_code& ec) const +{ + std::unique_lock<std::mutex> lk {pimpl_->mutex}; + pimpl_->cv.wait_for(lk, timeout, [&] { return !pimpl_->buf.empty() or pimpl_->isShutdown_; }); + return pimpl_->buf.size(); +} + +void +ChannelSocket::onShutdown(OnShutdownCb&& cb) +{ + pimpl_->shutdownCb_ = std::move(cb); + if (pimpl_->isShutdown_) { + pimpl_->shutdownCb_(); + } +} + +void +ChannelSocket::onReady(ChannelReadyCb&& cb) +{ + pimpl_->readyCb_ = std::move(cb); +} + +void +ChannelSocket::sendBeacon(const std::chrono::milliseconds& timeout) +{ + if (auto ep = pimpl_->endpoint.lock()) { + ep->sendBeacon(timeout); + } else { + shutdown(); + } +} + +std::shared_ptr<dht::crypto::Certificate> +ChannelSocket::peerCertificate() const +{ + if (auto ep = pimpl_->endpoint.lock()) + return ep->peerCertificate(); + return {}; +} + +IpAddr +ChannelSocket::getLocalAddress() const +{ + if (auto ep = pimpl_->endpoint.lock()) + return ep->getLocalAddress(); + return {}; +} + +IpAddr +ChannelSocket::getRemoteAddress() const +{ + if (auto ep = pimpl_->endpoint.lock()) + return ep->getRemoteAddress(); + return {}; +} + +} // namespace jami diff --git a/src/peer_connection.cpp b/src/peer_connection.cpp new file mode 100644 index 0000000..0b4ede5 --- /dev/null +++ b/src/peer_connection.cpp @@ -0,0 +1,452 @@ +/* + * Copyright (C) 2004-2023 Savoir-faire Linux Inc. + * + * Author: Guillaume Roguez <guillaume.roguez@savoirfairelinux.com> + * Author: Sébastien Blin <sebastien.blin@savoirfairelinux.com> + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ + +#include "peer_connection.h" +#include "tls_session.h" + +#include <opendht/thread_pool.h> +#include <opendht/logger.h> + +#include <algorithm> +#include <chrono> +#include <future> +#include <vector> +#include <atomic> +#include <stdexcept> +#include <istream> +#include <ostream> +#include <unistd.h> +#include <cstdio> + +#ifdef _WIN32 +#include <winsock2.h> +#include <ws2tcpip.h> +#else +#include <sys/select.h> +#endif + +#ifndef _MSC_VER +#include <sys/time.h> +#endif + +static constexpr int ICE_COMP_ID_SIP_TRANSPORT {1}; + +namespace jami { + +int +init_crt(gnutls_session_t session, dht::crypto::Certificate& crt) +{ + // Support only x509 format + if (gnutls_certificate_type_get(session) != GNUTLS_CRT_X509) { + return GNUTLS_E_CERTIFICATE_ERROR; + } + + // Store verification status + unsigned int status = 0; + auto ret = gnutls_certificate_verify_peers2(session, &status); + if (ret < 0 or (status & GNUTLS_CERT_SIGNATURE_FAILURE) != 0) { + return GNUTLS_E_CERTIFICATE_ERROR; + } + + unsigned int cert_list_size = 0; + auto cert_list = gnutls_certificate_get_peers(session, &cert_list_size); + if (cert_list == nullptr) { + return GNUTLS_E_CERTIFICATE_ERROR; + } + + // Check if received peer certificate is awaited + std::vector<std::pair<uint8_t*, uint8_t*>> crt_data; + crt_data.reserve(cert_list_size); + for (unsigned i = 0; i < cert_list_size; i++) + crt_data.emplace_back(cert_list[i].data, cert_list[i].data + cert_list[i].size); + crt = dht::crypto::Certificate {crt_data}; + + return GNUTLS_E_SUCCESS; +} + +using lock = std::lock_guard<std::mutex>; + +//============================================================================== + +IceSocketEndpoint::IceSocketEndpoint(std::shared_ptr<IceTransport> ice, bool isSender) + : ice_(std::move(ice)) + , iceIsSender(isSender) +{} + +IceSocketEndpoint::~IceSocketEndpoint() +{ + shutdown(); + if (ice_) + dht::ThreadPool::io().run([ice = std::move(ice_)] {}); +} + +void +IceSocketEndpoint::shutdown() +{ + // Sometimes the other peer never send any packet + // So, we cancel pending read to avoid to have + // any blocking operation. + if (ice_) + ice_->cancelOperations(); +} + +int +IceSocketEndpoint::waitForData(std::chrono::milliseconds timeout, std::error_code& ec) const +{ + if (ice_) { + if (!ice_->isRunning()) + return -1; + return ice_->waitForData(compId_, timeout, ec); + } + return -1; +} + +std::size_t +IceSocketEndpoint::read(ValueType* buf, std::size_t len, std::error_code& ec) +{ + if (ice_) { + if (!ice_->isRunning()) + return 0; + try { + auto res = ice_->recvfrom(compId_, reinterpret_cast<char*>(buf), len, ec); + if (res < 0) + shutdown(); + return res; + } catch (const std::exception& e) { + if (auto logger = ice_->logger()) + logger->error("IceSocketEndpoint::read exception: %s", e.what()); + } + return 0; + } + return -1; +} + +std::size_t +IceSocketEndpoint::write(const ValueType* buf, std::size_t len, std::error_code& ec) +{ + if (ice_) { + if (!ice_->isRunning()) + return 0; + auto res = 0; + res = ice_->send(compId_, reinterpret_cast<const unsigned char*>(buf), len); + if (res < 0) { + ec.assign(errno, std::generic_category()); + shutdown(); + } else { + ec.clear(); + } + return res; + } + return -1; +} + +//============================================================================== + +class TlsSocketEndpoint::Impl +{ +public: + static constexpr auto TLS_TIMEOUT = std::chrono::seconds(40); + + Impl(std::unique_ptr<IceSocketEndpoint>&& ep, + tls::CertificateStore& certStore, + const dht::crypto::Certificate& peer_cert, + const Identity& local_identity, + const std::shared_future<tls::DhParams>& dh_params) + : peerCertificate {peer_cert} + , ep_ {ep.get()} + { + tls::TlsSession::TlsSessionCallbacks tls_cbs + = {/*.onStateChange = */ [this](tls::TlsSessionState state) { onTlsStateChange(state); }, + /*.onRxData = */ [this](std::vector<uint8_t>&& buf) { onTlsRxData(std::move(buf)); }, + /*.onCertificatesUpdate = */ + [this](const gnutls_datum_t* l, const gnutls_datum_t* r, unsigned int n) { + onTlsCertificatesUpdate(l, r, n); + }, + /*.verifyCertificate = */ + [this](gnutls_session_t session) { + return verifyCertificate(session); + }}; + tls::TlsParams tls_param = { + /*.ca_list = */ "", + /*.peer_ca = */ nullptr, + /*.cert = */ local_identity.second, + /*.cert_key = */ local_identity.first, + /*.dh_params = */ dh_params, + /*.certStore = */ certStore, + /*.timeout = */ TLS_TIMEOUT, + /*.cert_check = */ nullptr, + }; + tls = std::make_unique<tls::TlsSession>(std::move(ep), tls_param, tls_cbs); + } + + Impl(std::unique_ptr<IceSocketEndpoint>&& ep, + tls::CertificateStore& certStore, + std::function<bool(const dht::crypto::Certificate&)>&& cert_check, + const Identity& local_identity, + const std::shared_future<tls::DhParams>& dh_params) + : peerCertificateCheckFunc {std::move(cert_check)} + , peerCertificate {null_cert} + , ep_ {ep.get()} + { + tls::TlsSession::TlsSessionCallbacks tls_cbs + = {/*.onStateChange = */ [this](tls::TlsSessionState state) { onTlsStateChange(state); }, + /*.onRxData = */ [this](std::vector<uint8_t>&& buf) { onTlsRxData(std::move(buf)); }, + /*.onCertificatesUpdate = */ + [this](const gnutls_datum_t* l, const gnutls_datum_t* r, unsigned int n) { + onTlsCertificatesUpdate(l, r, n); + }, + /*.verifyCertificate = */ + [this](gnutls_session_t session) { + return verifyCertificate(session); + }}; + tls::TlsParams tls_param = { + /*.ca_list = */ "", + /*.peer_ca = */ nullptr, + /*.cert = */ local_identity.second, + /*.cert_key = */ local_identity.first, + /*.dh_params = */ dh_params, + /*.certStore = */ certStore, + /*.timeout = */ std::chrono::duration_cast<decltype(tls::TlsParams::timeout)>(TLS_TIMEOUT), + /*.cert_check = */ nullptr, + }; + tls = std::make_unique<tls::TlsSession>(std::move(ep), tls_param, tls_cbs); + } + + ~Impl() + { + { + std::lock_guard<std::mutex> lk(cbMtx_); + onStateChangeCb_ = {}; + onReadyCb_ = {}; + } + tls.reset(); + } + + std::shared_ptr<IceTransport> underlyingICE() const + { + if (ep_) + if (const auto* iceSocket = reinterpret_cast<const IceSocketEndpoint*>(ep_)) + return iceSocket->underlyingICE(); + return {}; + } + + // TLS callbacks + int verifyCertificate(gnutls_session_t); + void onTlsStateChange(tls::TlsSessionState); + void onTlsRxData(std::vector<uint8_t>&&); + void onTlsCertificatesUpdate(const gnutls_datum_t*, const gnutls_datum_t*, unsigned int); + + std::mutex cbMtx_ {}; + OnStateChangeCb onStateChangeCb_; + dht::crypto::Certificate null_cert; + std::function<bool(const dht::crypto::Certificate&)> peerCertificateCheckFunc; + const dht::crypto::Certificate& peerCertificate; + std::atomic_bool isReady_ {false}; + OnReadyCb onReadyCb_; + std::unique_ptr<tls::TlsSession> tls; + const IceSocketEndpoint* ep_; +}; + +int +TlsSocketEndpoint::Impl::verifyCertificate(gnutls_session_t session) +{ + dht::crypto::Certificate crt; + auto verified = init_crt(session, crt); + if (verified != GNUTLS_E_SUCCESS) + return verified; + if (peerCertificateCheckFunc) { + if (!peerCertificateCheckFunc(crt)) { + if (const auto& logger = tls->logger()) + logger->error("[TLS-SOCKET] Refusing peer certificate"); + return GNUTLS_E_CERTIFICATE_ERROR; + } + + null_cert = std::move(crt); + } else { + if (crt.getPacked() != peerCertificate.getPacked()) { + if (const auto& logger = tls->logger()) + logger->error("[TLS-SOCKET] Unexpected peer certificate"); + return GNUTLS_E_CERTIFICATE_ERROR; + } + } + + return GNUTLS_E_SUCCESS; +} + +void +TlsSocketEndpoint::Impl::onTlsStateChange(tls::TlsSessionState state) +{ + std::lock_guard<std::mutex> lk(cbMtx_); + if ((state == tls::TlsSessionState::SHUTDOWN || state == tls::TlsSessionState::ESTABLISHED) + && !isReady_) { + isReady_ = true; + if (onReadyCb_) + onReadyCb_(state == tls::TlsSessionState::ESTABLISHED); + } + if (onStateChangeCb_ && !onStateChangeCb_(state)) + onStateChangeCb_ = {}; +} + +void +TlsSocketEndpoint::Impl::onTlsRxData([[maybe_unused]] std::vector<uint8_t>&& buf) +{} + +void +TlsSocketEndpoint::Impl::onTlsCertificatesUpdate([[maybe_unused]] const gnutls_datum_t* local_raw, + [[maybe_unused]] const gnutls_datum_t* remote_raw, + [[maybe_unused]] unsigned int remote_count) +{} + +TlsSocketEndpoint::TlsSocketEndpoint(std::unique_ptr<IceSocketEndpoint>&& tr, + tls::CertificateStore& certStore, + const Identity& local_identity, + const std::shared_future<tls::DhParams>& dh_params, + const dht::crypto::Certificate& peer_cert) + : pimpl_ {std::make_unique<Impl>(std::move(tr), certStore, peer_cert, local_identity, dh_params)} +{} + +TlsSocketEndpoint::TlsSocketEndpoint( + std::unique_ptr<IceSocketEndpoint>&& tr, + tls::CertificateStore& certStore, + const Identity& local_identity, + const std::shared_future<tls::DhParams>& dh_params, + std::function<bool(const dht::crypto::Certificate&)>&& cert_check) + : pimpl_ { + std::make_unique<Impl>(std::move(tr), certStore, std::move(cert_check), local_identity, dh_params)} +{} + +TlsSocketEndpoint::~TlsSocketEndpoint() {} + +bool +TlsSocketEndpoint::isInitiator() const +{ + if (!pimpl_->tls) { + return false; + } + return pimpl_->tls->isInitiator(); +} + +int +TlsSocketEndpoint::maxPayload() const +{ + if (!pimpl_->tls) { + return -1; + } + return pimpl_->tls->maxPayload(); +} + +std::size_t +TlsSocketEndpoint::read(ValueType* buf, std::size_t len, std::error_code& ec) +{ + if (!pimpl_->tls) { + ec = std::make_error_code(std::errc::broken_pipe); + return -1; + } + return pimpl_->tls->read(buf, len, ec); +} + +std::size_t +TlsSocketEndpoint::write(const ValueType* buf, std::size_t len, std::error_code& ec) +{ + if (!pimpl_->tls) { + ec = std::make_error_code(std::errc::broken_pipe); + return -1; + } + return pimpl_->tls->write(buf, len, ec); +} + +std::shared_ptr<dht::crypto::Certificate> +TlsSocketEndpoint::peerCertificate() const +{ + if (!pimpl_->tls) + return {}; + return pimpl_->tls->peerCertificate(); +} + +void +TlsSocketEndpoint::waitForReady(const std::chrono::milliseconds& timeout) +{ + if (!pimpl_->tls) { + return; + } + pimpl_->tls->waitForReady(timeout); +} + +int +TlsSocketEndpoint::waitForData(std::chrono::milliseconds timeout, std::error_code& ec) const +{ + if (!pimpl_->tls) { + ec = std::make_error_code(std::errc::broken_pipe); + return -1; + } + return pimpl_->tls->waitForData(timeout, ec); +} + +void +TlsSocketEndpoint::setOnStateChange(std::function<bool(tls::TlsSessionState state)>&& cb) +{ + std::lock_guard<std::mutex> lk(pimpl_->cbMtx_); + pimpl_->onStateChangeCb_ = std::move(cb); +} + +void +TlsSocketEndpoint::setOnReady(std::function<void(bool ok)>&& cb) +{ + std::lock_guard<std::mutex> lk(pimpl_->cbMtx_); + pimpl_->onReadyCb_ = std::move(cb); +} + +void +TlsSocketEndpoint::shutdown() +{ + pimpl_->tls->shutdown(); + if (pimpl_->ep_) { + const auto* iceSocket = reinterpret_cast<const IceSocketEndpoint*>(pimpl_->ep_); + if (iceSocket && iceSocket->underlyingICE()) + iceSocket->underlyingICE()->cancelOperations(); + } +} + +void +TlsSocketEndpoint::monitor() const +{ + if (auto ice = pimpl_->underlyingICE()) + if (auto logger = ice->logger()) + logger->debug("\t- Ice connection: {}", ice->link()); +} + +IpAddr +TlsSocketEndpoint::getLocalAddress() const +{ + if (auto ice = pimpl_->underlyingICE()) + return ice->getLocalAddress(ICE_COMP_ID_SIP_TRANSPORT); + return {}; +} + +IpAddr +TlsSocketEndpoint::getRemoteAddress() const +{ + if (auto ice = pimpl_->underlyingICE()) + return ice->getRemoteAddress(ICE_COMP_ID_SIP_TRANSPORT); + return {}; +} + +} // namespace jami diff --git a/src/peer_connection.h b/src/peer_connection.h new file mode 100644 index 0000000..3798f0c --- /dev/null +++ b/src/peer_connection.h @@ -0,0 +1,142 @@ +/* + * Copyright (C) 2004-2023 Savoir-faire Linux Inc. + * + * Author: Guillaume Roguez <guillaume.roguez@savoirfairelinux.com> + * Author: Sébastien Blin <sebastien.blin@savoirfairelinux.com> + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ + +#pragma once + +#include "ip_utils.h" +#include "certstore.h" +#include "opendht/crypto.h" +#include "ice_transport.h" +#include "tls_session.h" + +#include <functional> +#include <future> +#include <limits> +#include <map> +#include <memory> +#include <string> +#include <utility> +#include <vector> + +namespace dht { +namespace crypto { +struct PrivateKey; +struct Certificate; +} // namespace crypto +} // namespace dht + +namespace jami { +namespace tls { +class DhParams; +} + +using OnStateChangeCb = std::function<bool(tls::TlsSessionState state)>; +using OnReadyCb = std::function<void(bool ok)>; +using onShutdownCb = std::function<void(void)>; + +//============================================================================== + +class IceSocketEndpoint : public GenericSocket<uint8_t> +{ +public: + using SocketType = GenericSocket<uint8_t>; + explicit IceSocketEndpoint(std::shared_ptr<IceTransport> ice, bool isSender); + ~IceSocketEndpoint(); + + void shutdown() override; + bool isReliable() const override { return ice_ ? ice_->isRunning() : false; } + bool isInitiator() const override { return ice_ ? ice_->isInitiator() : true; } + int maxPayload() const override + { + return 65536 /* The max for a RTP packet used to wrap data here */; + } + int waitForData(std::chrono::milliseconds timeout, std::error_code& ec) const override; + std::size_t read(ValueType* buf, std::size_t len, std::error_code& ec) override; + std::size_t write(const ValueType* buf, std::size_t len, std::error_code& ec) override; + + std::shared_ptr<IceTransport> underlyingICE() const { return ice_; } + + void setOnRecv(RecvCb&& cb) override + { + if (ice_) + ice_->setOnRecv(compId_, cb); + } + +private: + std::shared_ptr<IceTransport> ice_ {nullptr}; + std::atomic_bool iceStopped {false}; + std::atomic_bool iceIsSender {false}; + uint8_t compId_ {1}; +}; + +//============================================================================== + +/// Implement a TLS session IO over a system socket +class TlsSocketEndpoint : public GenericSocket<uint8_t> +{ +public: + using SocketType = GenericSocket<uint8_t>; + using Identity = std::pair<std::shared_ptr<dht::crypto::PrivateKey>, + std::shared_ptr<dht::crypto::Certificate>>; + + TlsSocketEndpoint(std::unique_ptr<IceSocketEndpoint>&& tr, + tls::CertificateStore& certStore, + const Identity& local_identity, + const std::shared_future<tls::DhParams>& dh_params, + const dht::crypto::Certificate& peer_cert); + TlsSocketEndpoint(std::unique_ptr<IceSocketEndpoint>&& tr, + tls::CertificateStore& certStore, + const Identity& local_identity, + const std::shared_future<tls::DhParams>& dh_params, + std::function<bool(const dht::crypto::Certificate&)>&& cert_check); + ~TlsSocketEndpoint(); + + bool isReliable() const override { return true; } + bool isInitiator() const override; + int maxPayload() const override; + void shutdown() override; + std::size_t read(ValueType* buf, std::size_t len, std::error_code& ec) override; + std::size_t write(const ValueType* buf, std::size_t len, std::error_code& ec) override; + + std::shared_ptr<dht::crypto::Certificate> peerCertificate() const; + + void setOnRecv(RecvCb&&) override + { + throw std::logic_error("TlsSocketEndpoint::setOnRecv not implemented"); + } + int waitForData(std::chrono::milliseconds timeout, std::error_code&) const override; + + void waitForReady(const std::chrono::milliseconds& timeout = {}); + + void setOnStateChange(OnStateChangeCb&& cb); + void setOnReady(OnReadyCb&& cb); + + IpAddr getLocalAddress() const; + IpAddr getRemoteAddress() const; + + void monitor() const; + +private: + class Impl; + std::unique_ptr<Impl> pimpl_; +}; + +} // namespace jami diff --git a/src/security/certstore.cpp b/src/security/certstore.cpp new file mode 100644 index 0000000..acaa07d --- /dev/null +++ b/src/security/certstore.cpp @@ -0,0 +1,673 @@ +/* + * Copyright (C) 2004-2023 Savoir-faire Linux Inc. + * + * Author: Adrien Béraud <adrien.beraud@savoirfairelinux.com> + * Author: Vsevolod Ivanov <vsevolod.ivanov@savoirfairelinux.com> + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ + +#include "certstore.h" +#include "security_const.h" + +#include "fileutils.h" + +#include <opendht/thread_pool.h> +#include <opendht/logger.h> + +#include <gnutls/ocsp.h> + +#include <thread> +#include <sstream> +#include <fmt/format.h> + +namespace jami { +namespace tls { + +CertificateStore::CertificateStore(const std::string& path, std::shared_ptr<Logger> logger) + : logger_(std::move(logger)) + , certPath_(fmt::format("{}/certificates", path)) + , crlPath_(fmt::format("{}/crls", path)) + , ocspPath_(fmt::format("{}/oscp", path)) +{ + fileutils::check_dir(certPath_.c_str()); + fileutils::check_dir(crlPath_.c_str()); + fileutils::check_dir(ocspPath_.c_str()); + loadLocalCertificates(); +} + +unsigned +CertificateStore::loadLocalCertificates() +{ + std::lock_guard<std::mutex> l(lock_); + + auto dir_content = fileutils::readDirectory(certPath_); + unsigned n = 0; + for (const auto& f : dir_content) { + try { + auto crt = std::make_shared<crypto::Certificate>( + fileutils::loadFile(certPath_ + DIR_SEPARATOR_CH + f)); + auto id = crt->getId().toString(); + auto longId = crt->getLongId().toString(); + if (id != f && longId != f) + throw std::logic_error("Certificate id mismatch"); + while (crt) { + id = crt->getId().toString(); + longId = crt->getLongId().toString(); + certs_.emplace(std::move(id), crt); + certs_.emplace(std::move(longId), crt); + loadRevocations(*crt); + crt = crt->issuer; + ++n; + } + } catch (const std::exception& e) { + if (logger_) + logger_->warn("Remove cert. {}", e.what()); + remove(fmt::format("{}/{}", certPath_, f).c_str()); + } + } + if (logger_) + logger_->debug("CertificateStore: loaded {} local certificates.", n); + return n; +} + +void +CertificateStore::loadRevocations(crypto::Certificate& crt) const +{ + auto dir = fmt::format("{:s}/{:s}", crlPath_, crt.getId().toString()); + for (const auto& crl : fileutils::readDirectory(dir)) { + try { + crt.addRevocationList(std::make_shared<crypto::RevocationList>( + fileutils::loadFile(fmt::format("{}/{}", dir, crl)))); + } catch (const std::exception& e) { + if (logger_) + logger_->warn("Can't load revocation list: %s", e.what()); + } + } + auto ocsp_dir = ocspPath_ + DIR_SEPARATOR_CH + crt.getId().toString(); + for (const auto& ocsp : fileutils::readDirectory(ocsp_dir)) { + try { + auto ocsp_filepath = fmt::format("{}/{}", ocsp_dir, ocsp); + if (logger_) logger_->debug("Found {:s}", ocsp_filepath); + auto serial = crt.getSerialNumber(); + if (dht::toHex(serial.data(), serial.size()) != ocsp) + continue; + // Save the response + auto ocspBlob = fileutils::loadFile(ocsp_filepath); + crt.ocspResponse = std::make_shared<dht::crypto::OcspResponse>(ocspBlob.data(), + ocspBlob.size()); + unsigned int status = crt.ocspResponse->getCertificateStatus(); + if (status == GNUTLS_OCSP_CERT_GOOD) { + if (logger_) logger_->debug("Certificate {:s} has good OCSP status", crt.getId()); + } else if (status == GNUTLS_OCSP_CERT_REVOKED) { + if (logger_) logger_->error("Certificate {:s} has revoked OCSP status", crt.getId()); + } else if (status == GNUTLS_OCSP_CERT_UNKNOWN) { + if (logger_) logger_->error("Certificate {:s} has unknown OCSP status", crt.getId()); + } else { + if (logger_) logger_->error("Certificate {:s} has invalid OCSP status", crt.getId()); + } + } catch (const std::exception& e) { + if (logger_) + logger_->warn("Can't load OCSP revocation status: {:s}", e.what()); + } + } +} + +std::vector<std::string> +CertificateStore::getPinnedCertificates() const +{ + std::lock_guard<std::mutex> l(lock_); + + std::vector<std::string> certIds; + certIds.reserve(certs_.size()); + for (const auto& crt : certs_) + certIds.emplace_back(crt.first); + return certIds; +} + +std::shared_ptr<crypto::Certificate> +CertificateStore::getCertificate(const std::string& k) +{ + auto getCertificate_ = [this](const std::string& k) -> std::shared_ptr<crypto::Certificate> { + auto cit = certs_.find(k); + if (cit == certs_.cend()) + return {}; + return cit->second; + }; + std::unique_lock<std::mutex> l(lock_); + auto crt = getCertificate_(k); + // Check if certificate is complete + // If the certificate has been splitted, reconstruct it + auto top_issuer = crt; + while (top_issuer && top_issuer->getUID() != top_issuer->getIssuerUID()) { + if (top_issuer->issuer) { + top_issuer = top_issuer->issuer; + } else if (auto cert = getCertificate_(top_issuer->getIssuerUID())) { + top_issuer->issuer = cert; + top_issuer = cert; + } else { + // In this case, a certificate was not found + if (logger_) + logger_->warn("Incomplete certificate detected {:s}", k); + break; + } + } + return crt; +} + +std::shared_ptr<crypto::Certificate> +CertificateStore::getCertificateLegacy(const std::string& dataDir, const std::string& k) +{ + auto oldPath = fmt::format("{}/certificates/{}", dataDir, k); + if (fileutils::isFile(oldPath)) { + auto crt = std::make_shared<crypto::Certificate>(oldPath); + pinCertificate(crt, true); + return crt; + } + return {}; +} + +std::shared_ptr<crypto::Certificate> +CertificateStore::findCertificateByName(const std::string& name, crypto::NameType type) const +{ + std::unique_lock<std::mutex> l(lock_); + for (auto& i : certs_) { + if (i.second->getName() == name) + return i.second; + if (type != crypto::NameType::UNKNOWN) { + for (const auto& alt : i.second->getAltNames()) + if (alt.first == type and alt.second == name) + return i.second; + } + } + return {}; +} + +std::shared_ptr<crypto::Certificate> +CertificateStore::findCertificateByUID(const std::string& uid) const +{ + std::unique_lock<std::mutex> l(lock_); + for (auto& i : certs_) { + if (i.second->getUID() == uid) + return i.second; + } + return {}; +} + +std::shared_ptr<crypto::Certificate> +CertificateStore::findIssuer(const std::shared_ptr<crypto::Certificate>& crt) const +{ + std::shared_ptr<crypto::Certificate> ret {}; + auto n = crt->getIssuerUID(); + if (not n.empty()) { + if (crt->issuer and crt->issuer->getUID() == n) + ret = crt->issuer; + else + ret = findCertificateByUID(n); + } + if (not ret) { + n = crt->getIssuerName(); + if (not n.empty()) + ret = findCertificateByName(n); + } + if (not ret) + return ret; + unsigned verify_out = 0; + int err = gnutls_x509_crt_verify(crt->cert, &ret->cert, 1, 0, &verify_out); + if (err != GNUTLS_E_SUCCESS) { + if (logger_) + logger_->warn("gnutls_x509_crt_verify failed: {:s}", gnutls_strerror(err)); + return {}; + } + if (verify_out & GNUTLS_CERT_INVALID) + return {}; + return ret; +} + +static std::vector<crypto::Certificate> +readCertificates(const std::string& path, const std::string& crl_path) +{ + std::vector<crypto::Certificate> ret; + if (fileutils::isDirectory(path)) { + auto files = fileutils::readDirectory(path); + for (const auto& file : files) { + auto certs = readCertificates(fmt::format("{}/{}", path, file), crl_path); + ret.insert(std::end(ret), + std::make_move_iterator(std::begin(certs)), + std::make_move_iterator(std::end(certs))); + } + } else { + try { + auto data = fileutils::loadFile(path); + const gnutls_datum_t dt {data.data(), (unsigned) data.size()}; + gnutls_x509_crt_t* certs {nullptr}; + unsigned cert_num {0}; + gnutls_x509_crt_list_import2(&certs, &cert_num, &dt, GNUTLS_X509_FMT_PEM, 0); + for (unsigned i = 0; i < cert_num; i++) + ret.emplace_back(certs[i]); + gnutls_free(certs); + } catch (const std::exception& e) { + }; + } + return ret; +} + +void +CertificateStore::pinCertificatePath(const std::string& path, + std::function<void(const std::vector<std::string>&)> cb) +{ + dht::ThreadPool::computation().run([&, path, cb]() { + auto certs = readCertificates(path, crlPath_); + std::vector<std::string> ids; + std::vector<std::weak_ptr<crypto::Certificate>> scerts; + ids.reserve(certs.size()); + scerts.reserve(certs.size()); + { + std::lock_guard<std::mutex> l(lock_); + + for (auto& cert : certs) { + auto shared = std::make_shared<crypto::Certificate>(std::move(cert)); + scerts.emplace_back(shared); + auto e = certs_.emplace(shared->getId().toString(), shared); + ids.emplace_back(e.first->first); + e = certs_.emplace(shared->getLongId().toString(), shared); + ids.emplace_back(e.first->first); + } + paths_.emplace(path, std::move(scerts)); + } + if (logger_) logger_->d("CertificateStore: loaded %zu certificates from %s.", certs.size(), path.c_str()); + if (cb) + cb(ids); + //emitSignal<libjami::ConfigurationSignal::CertificatePathPinned>(path, ids); + }); +} + +unsigned +CertificateStore::unpinCertificatePath(const std::string& path) +{ + std::lock_guard<std::mutex> l(lock_); + + auto certs = paths_.find(path); + if (certs == std::end(paths_)) + return 0; + unsigned n = 0; + for (const auto& wcert : certs->second) { + if (auto cert = wcert.lock()) { + certs_.erase(cert->getId().toString()); + ++n; + } + } + paths_.erase(certs); + return n; +} + +std::vector<std::string> +CertificateStore::pinCertificate(const std::vector<uint8_t>& cert, bool local) noexcept +{ + try { + return pinCertificate(crypto::Certificate(cert), local); + } catch (const std::exception& e) { + } + return {}; +} + +std::vector<std::string> +CertificateStore::pinCertificate(crypto::Certificate&& cert, bool local) +{ + return pinCertificate(std::make_shared<crypto::Certificate>(std::move(cert)), local); +} + +std::vector<std::string> +CertificateStore::pinCertificate(const std::shared_ptr<crypto::Certificate>& cert, bool local) +{ + bool sig {false}; + std::vector<std::string> ids {}; + { + auto c = cert; + std::lock_guard<std::mutex> l(lock_); + while (c) { + bool inserted; + auto id = c->getId().toString(); + auto longId = c->getLongId().toString(); + decltype(certs_)::iterator it; + std::tie(it, inserted) = certs_.emplace(id, c); + if (not inserted) + it->second = c; + std::tie(it, inserted) = certs_.emplace(longId, c); + if (not inserted) + it->second = c; + if (local) { + for (const auto& crl : c->getRevocationLists()) + pinRevocationList(id, *crl); + } + ids.emplace_back(longId); + ids.emplace_back(id); + c = c->issuer; + sig |= inserted; + } + if (local) { + if (sig) + fileutils::saveFile(certPath_ + DIR_SEPARATOR_CH + ids.front(), cert->getPacked()); + } + } + //for (const auto& id : ids) + // emitSignal<libjami::ConfigurationSignal::CertificatePinned>(id); + return ids; +} + +bool +CertificateStore::unpinCertificate(const std::string& id) +{ + std::lock_guard<std::mutex> l(lock_); + + certs_.erase(id); + return remove((certPath_ + DIR_SEPARATOR_CH + id).c_str()) == 0; +} + +bool +CertificateStore::setTrustedCertificate(const std::string& id, TrustStatus status) +{ + if (status == TrustStatus::TRUSTED) { + if (auto crt = getCertificate(id)) { + trustedCerts_.emplace_back(crt); + return true; + } + } else { + auto tc = std::find_if(trustedCerts_.begin(), + trustedCerts_.end(), + [&](const std::shared_ptr<crypto::Certificate>& crt) { + return crt->getId().toString() == id; + }); + if (tc != trustedCerts_.end()) { + trustedCerts_.erase(tc); + return true; + } + } + return false; +} + +std::vector<gnutls_x509_crt_t> +CertificateStore::getTrustedCertificates() const +{ + std::vector<gnutls_x509_crt_t> crts; + crts.reserve(trustedCerts_.size()); + for (auto& crt : trustedCerts_) + crts.emplace_back(crt->getCopy()); + return crts; +} + +void +CertificateStore::pinRevocationList(const std::string& id, + const std::shared_ptr<dht::crypto::RevocationList>& crl) +{ + try { + if (auto c = getCertificate(id)) + c->addRevocationList(crl); + pinRevocationList(id, *crl); + } catch (...) { + if (logger_) + logger_->warn("Can't add revocation list"); + } +} + +void +CertificateStore::pinRevocationList(const std::string& id, const dht::crypto::RevocationList& crl) +{ + fileutils::check_dir((crlPath_ + DIR_SEPARATOR_CH + id).c_str()); + fileutils::saveFile(crlPath_ + DIR_SEPARATOR_CH + id + DIR_SEPARATOR_CH + + dht::toHex(crl.getNumber()), + crl.getPacked()); +} + +void +CertificateStore::pinOcspResponse(const dht::crypto::Certificate& cert) +{ + if (not cert.ocspResponse) + return; + try { + cert.ocspResponse->getCertificateStatus(); + } catch (dht::crypto::CryptoException& e) { + if (logger_) logger_->error("Failed to read certificate status of OCSP response: {:s}", e.what()); + return; + } + auto id = cert.getId().toString(); + auto serial = cert.getSerialNumber(); + auto serialhex = dht::toHex(serial); + auto dir = ocspPath_ + DIR_SEPARATOR_CH + id; + + if (auto localCert = getCertificate(id)) { + // Update certificate in the local store if relevant + if (localCert.get() != &cert && serial == localCert->getSerialNumber()) { + if (logger_) logger_->d("Updating OCSP for certificate %s in the local store", id.c_str()); + localCert->ocspResponse = cert.ocspResponse; + } + } + + dht::ThreadPool::io().run([l=logger_, + path = dir + DIR_SEPARATOR_CH + serialhex, + dir = std::move(dir), + id = std::move(id), + serialhex = std::move(serialhex), + ocspResponse = cert.ocspResponse] { + if (l) l->d("Saving OCSP Response of device %s with serial %s", id.c_str(), serialhex.c_str()); + std::lock_guard<std::mutex> lock(fileutils::getFileLock(path)); + fileutils::check_dir(dir.c_str()); + fileutils::saveFile(path, ocspResponse->pack()); + }); +} + +TrustStore::PermissionStatus +TrustStore::statusFromStr(const char* str) +{ + if (!std::strcmp(str, libjami::Certificate::Status::ALLOWED)) + return PermissionStatus::ALLOWED; + if (!std::strcmp(str, libjami::Certificate::Status::BANNED)) + return PermissionStatus::BANNED; + return PermissionStatus::UNDEFINED; +} + +const char* +TrustStore::statusToStr(TrustStore::PermissionStatus s) +{ + switch (s) { + case PermissionStatus::ALLOWED: + return libjami::Certificate::Status::ALLOWED; + case PermissionStatus::BANNED: + return libjami::Certificate::Status::BANNED; + case PermissionStatus::UNDEFINED: + default: + return libjami::Certificate::Status::UNDEFINED; + } +} + +TrustStatus +trustStatusFromStr(const char* str) +{ + if (!std::strcmp(str, libjami::Certificate::TrustStatus::TRUSTED)) + return TrustStatus::TRUSTED; + return TrustStatus::UNTRUSTED; +} + +const char* +statusToStr(TrustStatus s) +{ + switch (s) { + case TrustStatus::TRUSTED: + return libjami::Certificate::TrustStatus::TRUSTED; + case TrustStatus::UNTRUSTED: + default: + return libjami::Certificate::TrustStatus::UNTRUSTED; + } +} + +bool +TrustStore::addRevocationList(dht::crypto::RevocationList&& crl) +{ + allowed_.add(crl); + return true; +} + +bool +TrustStore::setCertificateStatus(const std::string& cert_id, + const TrustStore::PermissionStatus status) +{ + return setCertificateStatus(nullptr, cert_id, status, false); +} + +bool +TrustStore::setCertificateStatus(const std::shared_ptr<crypto::Certificate>& cert, + const TrustStore::PermissionStatus status, + bool local) +{ + return setCertificateStatus(cert, cert->getId().toString(), status, local); +} + +bool +TrustStore::setCertificateStatus(std::shared_ptr<crypto::Certificate> cert, + const std::string& cert_id, + const TrustStore::PermissionStatus status, + bool local) +{ + if (cert) + certStore_.pinCertificate(cert, local); + std::lock_guard<std::recursive_mutex> lk(mutex_); + updateKnownCerts(); + bool dirty {false}; + if (status == PermissionStatus::UNDEFINED) { + unknownCertStatus_.erase(cert_id); + dirty = certStatus_.erase(cert_id); + } else { + bool allowed = (status == PermissionStatus::ALLOWED); + auto s = certStatus_.find(cert_id); + if (s == std::end(certStatus_)) { + // Certificate state is currently undefined + if (not cert) + cert = certStore_.getCertificate(cert_id); + if (cert) { + unknownCertStatus_.erase(cert_id); + auto& crt_status = certStatus_[cert_id]; + if (not crt_status.first) + crt_status.first = cert; + crt_status.second.allowed = allowed; + setStoreCertStatus(*cert, allowed); + } else { + // Can't find certificate + unknownCertStatus_[cert_id].allowed = allowed; + } + } else { + // Certificate is already allowed or banned + if (s->second.second.allowed != allowed) { + s->second.second.allowed = allowed; + if (allowed) // Certificate is re-added after ban, rebuld needed + dirty = true; + else + allowed_.remove(*s->second.first, false); + } + } + } + if (dirty) + rebuildTrust(); + return true; +} + +TrustStore::PermissionStatus +TrustStore::getCertificateStatus(const std::string& cert_id) const +{ + std::lock_guard<std::recursive_mutex> lk(mutex_); + auto s = certStatus_.find(cert_id); + if (s == std::end(certStatus_)) { + auto us = unknownCertStatus_.find(cert_id); + if (us == std::end(unknownCertStatus_)) + return PermissionStatus::UNDEFINED; + return us->second.allowed ? PermissionStatus::ALLOWED : PermissionStatus::BANNED; + } + return s->second.second.allowed ? PermissionStatus::ALLOWED : PermissionStatus::BANNED; +} + +std::vector<std::string> +TrustStore::getCertificatesByStatus(TrustStore::PermissionStatus status) const +{ + std::lock_guard<std::recursive_mutex> lk(mutex_); + std::vector<std::string> ret; + for (const auto& i : certStatus_) + if (i.second.second.allowed == (status == TrustStore::PermissionStatus::ALLOWED)) + ret.emplace_back(i.first); + for (const auto& i : unknownCertStatus_) + if (i.second.allowed == (status == TrustStore::PermissionStatus::ALLOWED)) + ret.emplace_back(i.first); + return ret; +} + +bool +TrustStore::isAllowed(const crypto::Certificate& crt, bool allowPublic) +{ + // Match by certificate pinning + std::lock_guard<std::recursive_mutex> lk(mutex_); + bool allowed {allowPublic}; + for (auto c = &crt; c; c = c->issuer.get()) { + auto status = getCertificateStatus(c->getId().toString()); // lock mutex_ + if (status == PermissionStatus::ALLOWED) + allowed = true; + else if (status == PermissionStatus::BANNED) + return false; + } + + // Match by certificate chain + updateKnownCerts(); + auto ret = allowed_.verify(crt); + // Unknown issuer (only that) are accepted if allowPublic is true + if (not ret + and !(allowPublic and ret.result == (GNUTLS_CERT_INVALID | GNUTLS_CERT_SIGNER_NOT_FOUND))) { + if (certStore_.logger()) + certStore_.logger()->warn("%s", ret.toString().c_str()); + return false; + } + + return allowed; +} + +void +TrustStore::updateKnownCerts() +{ + auto i = std::begin(unknownCertStatus_); + while (i != std::end(unknownCertStatus_)) { + if (auto crt = certStore_.getCertificate(i->first)) { + certStatus_.emplace(i->first, std::make_pair(crt, i->second)); + setStoreCertStatus(*crt, i->second.allowed); + i = unknownCertStatus_.erase(i); + } else + ++i; + } +} + +void +TrustStore::setStoreCertStatus(const crypto::Certificate& crt, bool status) +{ + if (status) + allowed_.add(crt); + else + allowed_.remove(crt, false); +} + +void +TrustStore::rebuildTrust() +{ + allowed_ = {}; + for (const auto& c : certStatus_) + setStoreCertStatus(*c.second.first, c.second.second.allowed); +} + +} // namespace tls +} // namespace jami diff --git a/src/security/diffie-hellman.cpp b/src/security/diffie-hellman.cpp new file mode 100644 index 0000000..bc0a854 --- /dev/null +++ b/src/security/diffie-hellman.cpp @@ -0,0 +1,139 @@ +/* + * Copyright (C) 2004-2023 Savoir-faire Linux Inc. + * + * Author: Guillaume Roguez <guillaume.roguez@savoirfairelinux.com> + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ + +#include "diffie-hellman.h" +#include "logger.h" +#include "fileutils.h" + +#include <chrono> +#include <ciso646> + +namespace jami { +namespace tls { + +DhParams::DhParams(const std::vector<uint8_t>& data) +{ + gnutls_dh_params_t new_params_; + int ret = gnutls_dh_params_init(&new_params_); + if (ret) + throw std::runtime_error(std::string("Error initializing DH params: ") + + gnutls_strerror(ret)); + params_.reset(new_params_); + const gnutls_datum_t dat {(uint8_t*) data.data(), (unsigned) data.size()}; + if (int ret_pem = gnutls_dh_params_import_pkcs3(params_.get(), &dat, GNUTLS_X509_FMT_PEM)) + if (int ret_der = gnutls_dh_params_import_pkcs3(params_.get(), &dat, GNUTLS_X509_FMT_DER)) + throw std::runtime_error(std::string("Error importing DH params: ") + + gnutls_strerror(ret_pem) + " " + gnutls_strerror(ret_der)); +} + +DhParams& +DhParams::operator=(const DhParams& other) +{ + if (not params_) { + // We need a valid DH params pointer for the copy + gnutls_dh_params_t new_params_; + auto err = gnutls_dh_params_init(&new_params_); + if (err != GNUTLS_E_SUCCESS) + throw std::runtime_error(std::string("Error initializing DH params: ") + + gnutls_strerror(err)); + params_.reset(new_params_); + } + + auto err = gnutls_dh_params_cpy(params_.get(), other.get()); + if (err != GNUTLS_E_SUCCESS) + throw std::runtime_error(std::string("Error copying DH params: ") + gnutls_strerror(err)); + + return *this; +} + +std::vector<uint8_t> +DhParams::serialize() const +{ + if (!params_) { + JAMI_WARN("serialize() called on an empty DhParams"); + return {}; + } + gnutls_datum_t out; + if (gnutls_dh_params_export2_pkcs3(params_.get(), GNUTLS_X509_FMT_PEM, &out)) + return {}; + std::vector<uint8_t> ret {out.data, out.data + out.size}; + gnutls_free(out.data); + return ret; +} + +DhParams +DhParams::generate() +{ + using clock = std::chrono::high_resolution_clock; + + auto bits = gnutls_sec_param_to_pk_bits(GNUTLS_PK_DH, + /* GNUTLS_SEC_PARAM_HIGH */ GNUTLS_SEC_PARAM_HIGH); + JAMI_DBG("Generating DH params with %u bits", bits); + auto start = clock::now(); + + gnutls_dh_params_t new_params_; + int ret = gnutls_dh_params_init(&new_params_); + if (ret != GNUTLS_E_SUCCESS) { + JAMI_ERR("Error initializing DH params: %s", gnutls_strerror(ret)); + return {}; + } + DhParams params {new_params_}; + + ret = gnutls_dh_params_generate2(params.get(), bits); + if (ret != GNUTLS_E_SUCCESS) { + JAMI_ERR("Error generating DH params: %s", gnutls_strerror(ret)); + return {}; + } + + std::chrono::duration<double> time_span = clock::now() - start; + JAMI_DBG("Generated DH params with %u bits in %lfs", bits, time_span.count()); + return params; +} + +DhParams +DhParams::loadDhParams(const std::string& path) +{ + std::lock_guard<std::mutex> l(fileutils::getFileLock(path)); + try { + // writeTime throw exception if file doesn't exist + auto duration = std::chrono::system_clock::now() - fileutils::writeTime(path); + if (duration >= std::chrono::hours(24 * 3)) // file is valid only 3 days + throw std::runtime_error("file too old"); + + JAMI_DBG("Loading DhParams from file '%s'", path.c_str()); + return {fileutils::loadFile(path)}; + } catch (const std::exception& e) { + JAMI_DBG("Failed to load DhParams file '%s': %s", path.c_str(), e.what()); + if (auto params = tls::DhParams::generate()) { + try { + fileutils::saveFile(path, params.serialize(), 0600); + JAMI_DBG("Saved DhParams to file '%s'", path.c_str()); + } catch (const std::exception& ex) { + JAMI_WARN("Failed to save DhParams in file '%s': %s", path.c_str(), ex.what()); + } + return params; + } + JAMI_ERR("Can't generate DH params."); + return {}; + } +} + +} // namespace tls +} // namespace jami diff --git a/src/security/security_const.h b/src/security/security_const.h new file mode 100644 index 0000000..fb9541b --- /dev/null +++ b/src/security/security_const.h @@ -0,0 +1,121 @@ +/* + * Copyright (C) 2004-2023 Savoir-faire Linux Inc. + * + * Author: Philippe Proulx <philippe.proulx@savoirfairelinux.com> + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ +#pragma once + +namespace libjami { + +namespace Certificate { + +namespace Status { +constexpr static char UNDEFINED[] = "UNDEFINED"; +constexpr static char ALLOWED[] = "ALLOWED"; +constexpr static char BANNED[] = "BANNED"; +} // namespace Status + +namespace TrustStatus { +constexpr static char UNTRUSTED[] = "UNTRUSTED"; +constexpr static char TRUSTED[] = "TRUSTED"; +} // namespace TrustStatus + +/** + * Those constantes are used by the ConfigurationManager.validateCertificate method + */ +namespace ChecksNames { +constexpr static char HAS_PRIVATE_KEY[] = "HAS_PRIVATE_KEY"; +constexpr static char EXPIRED[] = "EXPIRED"; +constexpr static char STRONG_SIGNING[] = "STRONG_SIGNING"; +constexpr static char NOT_SELF_SIGNED[] = "NOT_SELF_SIGNED"; +constexpr static char KEY_MATCH[] = "KEY_MATCH"; +constexpr static char PRIVATE_KEY_STORAGE_PERMISSION[] = "PRIVATE_KEY_STORAGE_PERMISSION"; +constexpr static char PUBLIC_KEY_STORAGE_PERMISSION[] = "PUBLIC_KEY_STORAGE_PERMISSION"; +constexpr static char PRIVATE_KEY_DIRECTORY_PERMISSIONS[] = "PRIVATEKEY_DIRECTORY_PERMISSIONS"; +constexpr static char PUBLIC_KEY_DIRECTORY_PERMISSIONS[] = "PUBLICKEY_DIRECTORY_PERMISSIONS"; +constexpr static char PRIVATE_KEY_STORAGE_LOCATION[] = "PRIVATE_KEY_STORAGE_LOCATION"; +constexpr static char PUBLIC_KEY_STORAGE_LOCATION[] = "PUBLIC_KEY_STORAGE_LOCATION"; +constexpr static char PRIVATE_KEY_SELINUX_ATTRIBUTES[] = "PRIVATE_KEY_SELINUX_ATTRIBUTES"; +constexpr static char PUBLIC_KEY_SELINUX_ATTRIBUTES[] = "PUBLIC_KEY_SELINUX_ATTRIBUTES"; +constexpr static char EXIST[] = "EXIST"; +constexpr static char VALID[] = "VALID"; +constexpr static char VALID_AUTHORITY[] = "VALID_AUTHORITY"; +constexpr static char KNOWN_AUTHORITY[] = "KNOWN_AUTHORITY"; +constexpr static char NOT_REVOKED[] = "NOT_REVOKED"; +constexpr static char AUTHORITY_MISMATCH[] = "AUTHORITY_MISMATCH"; +constexpr static char UNEXPECTED_OWNER[] = "UNEXPECTED_OWNER"; +constexpr static char NOT_ACTIVATED[] = "NOT_ACTIVATED"; +} // namespace ChecksNames + +/** + * Those constants are used by the ConfigurationManager.getCertificateDetails method + */ +namespace DetailsNames { +constexpr static char EXPIRATION_DATE[] = "EXPIRATION_DATE"; +constexpr static char ACTIVATION_DATE[] = "ACTIVATION_DATE"; +constexpr static char REQUIRE_PRIVATE_KEY_PASSWORD[] = "REQUIRE_PRIVATE_KEY_PASSWORD"; +constexpr static char PUBLIC_SIGNATURE[] = "PUBLIC_SIGNATURE"; +constexpr static char VERSION_NUMBER[] = "VERSION_NUMBER"; +constexpr static char SERIAL_NUMBER[] = "SERIAL_NUMBER"; +constexpr static char ISSUER[] = "ISSUER"; +constexpr static char SUBJECT_KEY_ALGORITHM[] = "SUBJECT_KEY_ALGORITHM"; +constexpr static char CN[] = "CN"; +constexpr static char N[] = "N"; +constexpr static char O[] = "O"; +constexpr static char SIGNATURE_ALGORITHM[] = "SIGNATURE_ALGORITHM"; +constexpr static char MD5_FINGERPRINT[] = "MD5_FINGERPRINT"; +constexpr static char SHA1_FINGERPRINT[] = "SHA1_FINGERPRINT"; +constexpr static char PUBLIC_KEY_ID[] = "PUBLIC_KEY_ID"; +constexpr static char ISSUER_DN[] = "ISSUER_DN"; +constexpr static char NEXT_EXPECTED_UPDATE_DATE[] = "NEXT_EXPECTED_UPDATE_DATE"; +constexpr static char OUTGOING_SERVER[] = "OUTGOING_SERVER"; +constexpr static char IS_CA[] = "IS_CA"; +} // namespace DetailsNames + +/** + * Those constants are used by the ConfigurationManager.getCertificateDetails and + * ConfigurationManager.validateCertificate methods + */ +namespace ChecksValuesTypesNames { +constexpr static char BOOLEAN[] = "BOOLEAN"; +constexpr static char ISO_DATE[] = "ISO_DATE"; +constexpr static char CUSTOM[] = "CUSTOM"; +constexpr static char NUMBER[] = "NUMBER"; +} // namespace ChecksValuesTypesNames + +/** + * Those constantes are used by the ConfigurationManager.validateCertificate method + */ +namespace CheckValuesNames { +constexpr static char PASSED[] = "PASSED"; +constexpr static char FAILED[] = "FAILED"; +constexpr static char UNSUPPORTED[] = "UNSUPPORTED"; +constexpr static char ISO_DATE[] = "ISO_DATE"; +constexpr static char CUSTOM[] = "CUSTOM"; +constexpr static char DATE[] = "DATE"; +} // namespace CheckValuesNames + +} // namespace Certificate + +namespace TlsTransport { +constexpr static char TLS_PEER_CERT[] = "TLS_PEER_CERT"; +constexpr static char TLS_PEER_CA_NUM[] = "TLS_PEER_CA_NUM"; +constexpr static char TLS_PEER_CA_[] = "TLS_PEER_CA_"; +constexpr static char TLS_CIPHER[] = "TLS_CIPHER"; +} // namespace TlsTransport + +} // namespace libjami diff --git a/src/security/threadloop.cpp b/src/security/threadloop.cpp new file mode 100644 index 0000000..88db725 --- /dev/null +++ b/src/security/threadloop.cpp @@ -0,0 +1,135 @@ +/* + * Copyright (C) 2004-2023 Savoir-faire Linux Inc. + * + * Author: Guillaume Roguez <Guillaume.Roguez@savoirfairelinux.com> + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ + +#include "threadloop.h" + +#include <ciso646> // fix windows compiler bug + +namespace jami { + +void +ThreadLoop::mainloop(std::thread::id& tid, + const std::function<bool()> setup, + const std::function<void()> process, + const std::function<void()> cleanup) +{ + tid = std::this_thread::get_id(); + try { + if (setup()) { + while (state_ == ThreadState::RUNNING) + process(); + cleanup(); + } else { + throw std::runtime_error("setup failed"); + } + } catch (const ThreadLoopException& e) { + if (logger_) logger_->e("[threadloop:{}] ThreadLoopException: {}", fmt::ptr(this), e.what()); + } catch (const std::exception& e) { + if (logger_) logger_->e("[threadloop:{}] Unwaited exception: {}", fmt::ptr(this), e.what()); + } + stop(); +} + +ThreadLoop::ThreadLoop(std::shared_ptr<dht::log::Logger> logger, + const std::function<bool()>& setup, + const std::function<void()>& process, + const std::function<void()>& cleanup) + : setup_(setup) + , process_(process) + , cleanup_(cleanup) + , thread_() + , logger_(std::move(logger)) +{} + +ThreadLoop::~ThreadLoop() +{ + if (isRunning()) { + if (logger_) logger_->error("join() should be explicitly called in owner's destructor"); + join(); + } +} + +void +ThreadLoop::start() +{ + const auto s = state_.load(); + + if (s == ThreadState::RUNNING) { + if (logger_) logger_->error("already started"); + return; + } + + // stop pending but not processed by thread yet? + if (s == ThreadState::STOPPING and thread_.joinable()) { + if (logger_) logger_->debug("stop pending"); + thread_.join(); + } + + state_ = ThreadState::RUNNING; + thread_ = std::thread(&ThreadLoop::mainloop, this, std::ref(threadId_), setup_, process_, cleanup_); + threadId_ = thread_.get_id(); +} + +void +ThreadLoop::stop() +{ + if (state_ == ThreadState::RUNNING) + state_ = ThreadState::STOPPING; +} + +void +ThreadLoop::join() +{ + stop(); + if (thread_.joinable()) + thread_.join(); +} + +void +ThreadLoop::waitForCompletion() +{ + if (thread_.joinable()) + thread_.join(); +} + +void +ThreadLoop::exit() +{ + stop(); + throw ThreadLoopException(); +} + +bool +ThreadLoop::isRunning() const noexcept +{ +#ifdef _WIN32 + return state_ == ThreadState::RUNNING; +#else + return thread_.joinable() and state_ == ThreadState::RUNNING; +#endif +} + +void +InterruptedThreadLoop::stop() +{ + ThreadLoop::stop(); + cv_.notify_one(); +} +} // namespace jami diff --git a/src/security/threadloop.h b/src/security/threadloop.h new file mode 100644 index 0000000..8a7a0c6 --- /dev/null +++ b/src/security/threadloop.h @@ -0,0 +1,134 @@ +/* + * Copyright (C) 2004-2023 Savoir-faire Linux Inc. + * + * Author: Guillaume Roguez <Guillaume.Roguez@savoirfairelinux.com> + * Author: Eloi Bail <Eloi.Bail@savoirfairelinux.com> + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ + +#pragma once + +#include <atomic> +#include <thread> +#include <functional> +#include <stdexcept> +#include <condition_variable> +#include <mutex> + +#include <opendht/logger.h> + +namespace jami { + +struct ThreadLoopException : public std::runtime_error +{ + ThreadLoopException() + : std::runtime_error("ThreadLoopException") + {} +}; + +class ThreadLoop +{ +public: + enum class ThreadState { READY, RUNNING, STOPPING }; + + ThreadLoop(std::shared_ptr<dht::log::Logger> logger, + const std::function<bool()>& setup, + const std::function<void()>& process, + const std::function<void()>& cleanup); + virtual ~ThreadLoop(); + + void start(); + void exit(); + virtual void stop(); + void join(); + void waitForCompletion(); // thread will stop itself + + bool isRunning() const noexcept; + bool isStopping() const noexcept { return state_ == ThreadState::STOPPING; } + std::thread::id get_id() const noexcept { return threadId_; } + +private: + ThreadLoop(const ThreadLoop&) = delete; + ThreadLoop(ThreadLoop&&) noexcept = delete; + ThreadLoop& operator=(const ThreadLoop&) = delete; + ThreadLoop& operator=(ThreadLoop&&) noexcept = delete; + + // These must be provided by users of ThreadLoop + std::function<bool()> setup_; + std::function<void()> process_; + std::function<void()> cleanup_; + + void mainloop(std::thread::id& tid, + const std::function<bool()> setup, + const std::function<void()> process, + const std::function<void()> cleanup); + + std::atomic<ThreadState> state_ {ThreadState::READY}; + std::thread::id threadId_; + std::thread thread_; + std::shared_ptr<dht::log::Logger> logger_; +}; + +class InterruptedThreadLoop : public ThreadLoop +{ +public: + InterruptedThreadLoop(std::shared_ptr<dht::log::Logger> logger, + const std::function<bool()>& setup, + const std::function<void()>& process, + const std::function<void()>& cleanup) + : ThreadLoop::ThreadLoop(logger, setup, process, cleanup) + {} + + void stop() override; + + void interrupt() noexcept { cv_.notify_one(); } + + template<typename Rep, typename Period> + void wait_for(const std::chrono::duration<Rep, Period>& rel_time) + { + if (std::this_thread::get_id() != get_id()) + throw std::runtime_error("can not call wait_for outside thread context"); + + std::unique_lock<std::mutex> lk(mutex_); + cv_.wait_for(lk, rel_time, [this]() { return isStopping(); }); + } + + template<typename Rep, typename Period, typename Pred> + bool wait_for(const std::chrono::duration<Rep, Period>& rel_time, Pred&& pred) + { + if (std::this_thread::get_id() != get_id()) + throw std::runtime_error("can not call wait_for outside thread context"); + + std::unique_lock<std::mutex> lk(mutex_); + return cv_.wait_for(lk, rel_time, [this, pred] { return isStopping() || pred(); }); + } + + template<typename Pred> + void wait(Pred&& pred) + { + if (std::this_thread::get_id() != get_id()) + throw std::runtime_error("Can not call wait outside thread context"); + + std::unique_lock<std::mutex> lk(mutex_); + cv_.wait(lk, [this, p = std::forward<Pred>(pred)] { return isStopping() || p(); }); + } + +private: + std::mutex mutex_; + std::condition_variable cv_; +}; + +} // namespace jami diff --git a/src/security/tls_session.cpp b/src/security/tls_session.cpp new file mode 100644 index 0000000..43f623d --- /dev/null +++ b/src/security/tls_session.cpp @@ -0,0 +1,1789 @@ +/* + * Copyright (C) 2004-2023 Savoir-faire Linux Inc. + * + * Author: Adrien Béraud <adrien.beraud@savoirfairelinux.com> + * Author: Guillaume Roguez <guillaume.roguez@savoirfairelinux.com> + * Author: Sébastien Blin <sebastien.blin@savoirfairelinux.com> + * Author: Vsevolod Ivanov <vsevolod.ivanov@savoirfairelinux.com> + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ +#include "tls_session.h" +#include "threadloop.h" +#include "certstore.h" + +#include <gnutls/gnutls.h> +#include <gnutls/dtls.h> +#include <gnutls/abstract.h> + +#include <gnutls/crypto.h> +#include <gnutls/ocsp.h> +#include <opendht/http.h> +#include <opendht/logger.h> + +#include <list> +#include <mutex> +#include <condition_variable> +#include <utility> +#include <map> +#include <atomic> +#include <iterator> +#include <stdexcept> +#include <algorithm> +#include <cstring> // std::memset + +#include <cstdlib> +#include <unistd.h> + +namespace jami { +namespace tls { + +static constexpr const char* DTLS_CERT_PRIORITY_STRING { + "SECURE192:-VERS-TLS-ALL:+VERS-DTLS-ALL:-RSA:%SERVER_PRECEDENCE:%SAFE_RENEGOTIATION"}; +static constexpr const char* DTLS_FULL_PRIORITY_STRING { + "SECURE192:-KX-ALL:+ANON-ECDH:+ANON-DH:+SECURE192:-VERS-TLS-ALL:+VERS-DTLS-ALL:-RSA:%SERVER_" + "PRECEDENCE:%SAFE_RENEGOTIATION"}; +// Note: -GROUP-FFDHE4096:-GROUP-FFDHE6144:-GROUP-FFDHE8192:+GROUP-X25519: +// is added after gnutls 3.6.7, because some safety checks were introduced for FFDHE resulting in a +// performance drop for our usage (2/3s of delay) This performance drop is visible on mobiles devices. + +// Benchmark result (on a computer) +// $gnutls-cli --benchmark-tls-kx +// (TLS1.3)-(DHE-FFDHE3072)-(RSA-PSS-RSAE-SHA256)-(AES-128-GCM) 20.48 transactions/sec +// (avg. handshake time: 48.45 ms, sample variance: 0.68) +// (TLS1.3)-(ECDHE-SECP256R1)-(RSA-PSS-RSAE-SHA256)-(AES-128-GCM) 208.14 transactions/sec +// (avg. handshake time: 4.01 ms, sample variance: 0.01) +// (TLS1.3)-(ECDHE-X25519)-(RSA-PSS-RSAE-SHA256)-(AES-128-GCM) 240.93 transactions/sec +// (avg. handshake time: 4.00 ms, sample variance: 0.00) +static constexpr const char* TLS_CERT_PRIORITY_STRING { + "SECURE192:-RSA:-GROUP-FFDHE4096:-GROUP-FFDHE6144:-GROUP-FFDHE8192:+GROUP-X25519:%SERVER_" + "PRECEDENCE:%SAFE_RENEGOTIATION"}; +static constexpr const char* TLS_FULL_PRIORITY_STRING { + "SECURE192:-KX-ALL:+ANON-ECDH:+ANON-DH:+SECURE192:-RSA:-GROUP-FFDHE4096:-GROUP-FFDHE6144:-" + "GROUP-FFDHE8192:+GROUP-X25519:%SERVER_PRECEDENCE:%SAFE_RENEGOTIATION"}; +static constexpr uint32_t RX_MAX_SIZE {64 * 1024}; // 64k = max size of a UDP packet +static constexpr std::size_t INPUT_MAX_SIZE { + 1000}; // Maximum number of packets to store before dropping (pkt size = DTLS_MTU) +static constexpr ssize_t FLOOD_THRESHOLD {4 * 1024}; +static constexpr auto FLOOD_PAUSE = std::chrono::milliseconds( + 100); // Time to wait after an invalid cookie packet (anti flood attack) +static constexpr size_t HANDSHAKE_MAX_RETRY {64}; +static constexpr auto DTLS_RETRANSMIT_TIMEOUT = std::chrono::milliseconds( + 1000); // Delay between two handshake request on DTLS +static constexpr auto COOKIE_TIMEOUT = std::chrono::seconds( + 10); // Time to wait for a cookie packet from client +static constexpr int MIN_MTU { + 512 - 20 - 8}; // minimal payload size of a DTLS packet carried by an IPv4 packet +static constexpr uint8_t HEARTBEAT_TRIES = 1; // Number of tries at each heartbeat ping send +static constexpr auto HEARTBEAT_RETRANS_TIMEOUT = std::chrono::milliseconds( + 700); // gnutls heartbeat retransmission timeout for each ping (in milliseconds) +static constexpr auto HEARTBEAT_TOTAL_TIMEOUT + = HEARTBEAT_RETRANS_TIMEOUT + * HEARTBEAT_TRIES; // gnutls heartbeat time limit for heartbeat procedure (in milliseconds) +static constexpr int MISS_ORDERING_LIMIT + = 32; // maximal accepted distance of out-of-order packet (note: must be a signed type) +static constexpr auto RX_OOO_TIMEOUT = std::chrono::milliseconds(1500); +static constexpr int ASYMETRIC_TRANSPORT_MTU_OFFSET + = 20; // when client, if your local IP is IPV4 and server is IPV6; you must reduce your MTU to + // avoid packet too big error on server side. the offset is the difference in size of IP headers +static constexpr auto OCSP_REQUEST_TIMEOUT = std::chrono::seconds( + 2); // Time to wait for an ocsp-request + +// Helper to cast any duration into an integer number of milliseconds +template<class Rep, class Period> +static std::chrono::milliseconds::rep +duration2ms(std::chrono::duration<Rep, Period> d) +{ + return std::chrono::duration_cast<std::chrono::milliseconds>(d).count(); +} + +static inline uint64_t +array2uint(const std::array<uint8_t, 8>& a) +{ + uint64_t res = 0; + for (int i = 0; i < 8; ++i) + res = (res << 8) + a[i]; + return res; +} + +//============================================================================== + +namespace { + +class TlsCertificateCredendials +{ + using T = gnutls_certificate_credentials_t; + +public: + TlsCertificateCredendials() + { + int ret = gnutls_certificate_allocate_credentials(&creds_); + if (ret < 0) { + //if (params_.logger) + // params_.logger->e("gnutls_certificate_allocate_credentials() failed with ret=%d", ret); + throw std::bad_alloc(); + } + } + + ~TlsCertificateCredendials() { gnutls_certificate_free_credentials(creds_); } + + operator T() { return creds_; } + +private: + TlsCertificateCredendials(const TlsCertificateCredendials&) = delete; + TlsCertificateCredendials& operator=(const TlsCertificateCredendials&) = delete; + T creds_; +}; + +class TlsAnonymousClientCredendials +{ + using T = gnutls_anon_client_credentials_t; + +public: + TlsAnonymousClientCredendials() + { + int ret = gnutls_anon_allocate_client_credentials(&creds_); + if (ret < 0) { + //if (params_.logger) + // params_.logger->e("gnutls_anon_allocate_client_credentials() failed with ret=%d", ret); + throw std::bad_alloc(); + } + } + + ~TlsAnonymousClientCredendials() { gnutls_anon_free_client_credentials(creds_); } + + operator T() { return creds_; } + +private: + TlsAnonymousClientCredendials(const TlsAnonymousClientCredendials&) = delete; + TlsAnonymousClientCredendials& operator=(const TlsAnonymousClientCredendials&) = delete; + T creds_; +}; + +class TlsAnonymousServerCredendials +{ + using T = gnutls_anon_server_credentials_t; + +public: + TlsAnonymousServerCredendials() + { + int ret = gnutls_anon_allocate_server_credentials(&creds_); + if (ret < 0) { + //if (params_.logger) + // params_.logger->e("gnutls_anon_allocate_server_credentials() failed with ret=%d", ret); + throw std::bad_alloc(); + } + } + + ~TlsAnonymousServerCredendials() { gnutls_anon_free_server_credentials(creds_); } + + operator T() { return creds_; } + +private: + TlsAnonymousServerCredendials(const TlsAnonymousServerCredendials&) = delete; + TlsAnonymousServerCredendials& operator=(const TlsAnonymousServerCredendials&) = delete; + T creds_; +}; + +} // namespace + +//============================================================================== + +class TlsSession::TlsSessionImpl +{ +public: + using clock = std::chrono::steady_clock; + using StateHandler = std::function<TlsSessionState(TlsSessionState state)>; + using OcspVerification = std::function<void(const int status)>; + using HttpResponse = std::function<void(const dht::http::Response& response)>; + + // Constants (ctor init.) + const bool isServer_; + const TlsParams params_; + const TlsSessionCallbacks callbacks_; + const bool anonymous_; + + TlsSessionImpl(std::unique_ptr<SocketType>&& transport, + const TlsParams& params, + const TlsSessionCallbacks& cbs, + bool anonymous); + + ~TlsSessionImpl(); + + const char* typeName() const; + + std::unique_ptr<SocketType> transport_; + + // State protectors + std::mutex stateMutex_; + std::condition_variable stateCondition_; + + // State machine + TlsSessionState handleStateSetup(TlsSessionState state); + TlsSessionState handleStateCookie(TlsSessionState state); + TlsSessionState handleStateHandshake(TlsSessionState state); + TlsSessionState handleStateMtuDiscovery(TlsSessionState state); + TlsSessionState handleStateEstablished(TlsSessionState state); + TlsSessionState handleStateShutdown(TlsSessionState state); + std::map<TlsSessionState, StateHandler> fsmHandlers_ {}; + std::atomic<TlsSessionState> state_ {TlsSessionState::SETUP}; + std::atomic<TlsSessionState> newState_ {TlsSessionState::NONE}; + std::atomic<int> maxPayload_ {-1}; + + // IO GnuTLS <-> ICE + std::mutex rxMutex_ {}; + std::condition_variable rxCv_ {}; + std::list<std::vector<ValueType>> rxQueue_ {}; + + bool flushProcessing_ {false}; ///< protect against recursive call to flushRxQueue + std::vector<ValueType> rawPktBuf_; ///< gnutls incoming packet buffer + uint64_t baseSeq_ {0}; ///< sequence number of first application data packet received + uint64_t lastRxSeq_ {0}; ///< last received and valid packet sequence number + uint64_t gapOffset_ {0}; ///< offset of first byte not received yet + clock::time_point lastReadTime_; + std::map<uint64_t, std::vector<ValueType>> reorderBuffer_ {}; + std::list<clock::time_point> nextFlush_ {}; + + std::size_t send(const ValueType*, std::size_t, std::error_code&); + ssize_t sendRaw(const void*, size_t); + ssize_t sendRawVec(const giovec_t*, int); + ssize_t recvRaw(void*, size_t); + int waitForRawData(std::chrono::milliseconds); + + bool initFromRecordState(int offset = 0); + void handleDataPacket(std::vector<ValueType>&&, uint64_t); + void flushRxQueue(std::unique_lock<std::mutex>&); + + // Statistics + std::atomic<std::size_t> stRxRawPacketCnt_ {0}; + std::atomic<std::size_t> stRxRawBytesCnt_ {0}; + std::atomic<std::size_t> stRxRawPacketDropCnt_ {0}; + std::atomic<std::size_t> stTxRawPacketCnt_ {0}; + std::atomic<std::size_t> stTxRawBytesCnt_ {0}; + void dump_io_stats() const; + + std::unique_ptr<TlsAnonymousClientCredendials> cacred_; // ctor init. + std::unique_ptr<TlsAnonymousServerCredendials> sacred_; // ctor init. + std::unique_ptr<TlsCertificateCredendials> xcred_; // ctor init. + std::mutex sessionReadMutex_; + std::mutex sessionWriteMutex_; + gnutls_session_t session_ {nullptr}; + gnutls_datum_t cookie_key_ {nullptr, 0}; + gnutls_dtls_prestate_st prestate_ {}; + ssize_t cookie_count_ {0}; + + TlsSessionState setupClient(); + TlsSessionState setupServer(); + void initAnonymous(); + void initCredentials(); + bool commonSessionInit(); + + std::shared_ptr<dht::crypto::Certificate> peerCertificate(gnutls_session_t session) const; + + /* + * Implicit certificate validations. + */ + int verifyCertificateWrapper(gnutls_session_t session); + /* + * Verify OCSP (Online Certificate Service Protocol): + */ + void verifyOcsp(const std::string& url, + dht::crypto::Certificate& cert, + gnutls_x509_crt_t issuer, + OcspVerification cb); + /* + * Send OCSP Request to the specified URI. + */ + void sendOcspRequest(const std::string& uri, + std::string body, + std::chrono::seconds timeout, + HttpResponse cb = {}); + + // FSM thread (TLS states) + ThreadLoop thread_; // ctor init. + bool setup(); + void process(); + void cleanup(); + + // Path mtu discovery + std::array<int, 3> MTUS_; + int mtuProbe_; + int hbPingRecved_ {0}; + bool pmtudOver_ {false}; + void pathMtuHeartbeat(); + + std::mutex requestsMtx_; + std::set<std::shared_ptr<dht::http::Request>> requests_; + std::shared_ptr<dht::crypto::Certificate> pCert_ {}; +}; + +TlsSession::TlsSessionImpl::TlsSessionImpl(std::unique_ptr<SocketType>&& transport, + const TlsParams& params, + const TlsSessionCallbacks& cbs, + bool anonymous) + : isServer_(not transport->isInitiator()) + , params_(params) + , callbacks_(cbs) + , anonymous_(anonymous) + , transport_ {std::move(transport)} + , cacred_(nullptr) + , sacred_(nullptr) + , xcred_(nullptr) + , thread_(params.logger, [this] { return setup(); }, [this] { process(); }, [this] { cleanup(); }) +{ + if (not transport_->isReliable()) { + transport_->setOnRecv([this](const ValueType* buf, size_t len) { + std::lock_guard<std::mutex> lk {rxMutex_}; + if (rxQueue_.size() == INPUT_MAX_SIZE) { + rxQueue_.pop_front(); // drop oldest packet if input buffer is full + ++stRxRawPacketDropCnt_; + } + rxQueue_.emplace_back(buf, buf + len); + ++stRxRawPacketCnt_; + stRxRawBytesCnt_ += len; + rxCv_.notify_one(); + return len; + }); + } + + // Run FSM into dedicated thread + thread_.start(); +} + +TlsSession::TlsSessionImpl::~TlsSessionImpl() +{ + state_ = TlsSessionState::SHUTDOWN; + stateCondition_.notify_all(); + rxCv_.notify_all(); + { + std::lock_guard<std::mutex> lock(requestsMtx_); + // requests_ store a shared_ptr, so we need to cancel requests + // to not be stuck in verifyCertificateWrapper + for (auto& request : requests_) + request->cancel(); + requests_.clear(); + } + thread_.join(); + if (not transport_->isReliable()) + transport_->setOnRecv(nullptr); +} + +const char* +TlsSession::TlsSessionImpl::typeName() const +{ + return isServer_ ? "server" : "client"; +} + +void +TlsSession::TlsSessionImpl::dump_io_stats() const +{ + if (params_.logger) + params_.logger->debug("[TLS] RxRawPkt={:d} ({:d} bytes) - TxRawPkt={:d} ({:d} bytes)", + stRxRawPacketCnt_.load(), + stRxRawBytesCnt_.load(), + stTxRawPacketCnt_.load(), + stTxRawBytesCnt_.load()); +} + +TlsSessionState +TlsSession::TlsSessionImpl::setupClient() +{ + int ret; + + if (not transport_->isReliable()) { + ret = gnutls_init(&session_, GNUTLS_CLIENT | GNUTLS_DATAGRAM); + // uncoment to reactivate PMTUD + // if (params_.logger) + params_.logger->d("[TLS] set heartbeat reception for retrocompatibility check on server"); + // gnutls_heartbeat_enable(session_,GNUTLS_HB_PEER_ALLOWED_TO_SEND); + } else { + ret = gnutls_init(&session_, GNUTLS_CLIENT); + } + + if (ret != GNUTLS_E_SUCCESS) { + if (params_.logger) + params_.logger->e("[TLS] session init failed: %s", gnutls_strerror(ret)); + return TlsSessionState::SHUTDOWN; + } + + if (not commonSessionInit()) { + return TlsSessionState::SHUTDOWN; + } + + return TlsSessionState::HANDSHAKE; +} + +TlsSessionState +TlsSession::TlsSessionImpl::setupServer() +{ + int ret; + + if (not transport_->isReliable()) { + ret = gnutls_init(&session_, GNUTLS_SERVER | GNUTLS_DATAGRAM); + + // uncoment to reactivate PMTUD + // if (params_.logger) + params_.logger->d("[TLS] set heartbeat reception"); + // gnutls_heartbeat_enable(session_, GNUTLS_HB_PEER_ALLOWED_TO_SEND); + + gnutls_dtls_prestate_set(session_, &prestate_); + } else { + ret = gnutls_init(&session_, GNUTLS_SERVER); + } + + if (ret != GNUTLS_E_SUCCESS) { + if (params_.logger) + params_.logger->e("[TLS] session init failed: %s", gnutls_strerror(ret)); + return TlsSessionState::SHUTDOWN; + } + + gnutls_certificate_server_set_request(session_, GNUTLS_CERT_REQUIRE); + + if (not commonSessionInit()) + return TlsSessionState::SHUTDOWN; + + return TlsSessionState::HANDSHAKE; +} + +void +TlsSession::TlsSessionImpl::initAnonymous() +{ + // credentials for handshaking and transmission + if (isServer_) + sacred_.reset(new TlsAnonymousServerCredendials()); + else + cacred_.reset(new TlsAnonymousClientCredendials()); + + // Setup DH-params for anonymous authentification + if (isServer_) { + if (const auto& dh_params = params_.dh_params.get().get()) + gnutls_anon_set_server_dh_params(*sacred_, dh_params); + else + if (params_.logger) + params_.logger->w("[TLS] DH params unavailable"); + } +} + +void +TlsSession::TlsSessionImpl::initCredentials() +{ + int ret; + + // credentials for handshaking and transmission + xcred_.reset(new TlsCertificateCredendials()); + + gnutls_certificate_set_verify_function(*xcred_, [](gnutls_session_t session) -> int { + auto this_ = reinterpret_cast<TlsSessionImpl*>(gnutls_session_get_ptr(session)); + return this_->verifyCertificateWrapper(session); + }); + + // Load user-given CA list + if (not params_.ca_list.empty()) { + // Try PEM format first + ret = gnutls_certificate_set_x509_trust_file(*xcred_, + params_.ca_list.c_str(), + GNUTLS_X509_FMT_PEM); + + // Then DER format + if (ret < 0) + ret = gnutls_certificate_set_x509_trust_file(*xcred_, + params_.ca_list.c_str(), + GNUTLS_X509_FMT_DER); + if (ret < 0) + throw std::runtime_error("can't load CA " + params_.ca_list + ": " + + std::string(gnutls_strerror(ret))); + + if (params_.logger) + params_.logger->d("[TLS] CA list %s loadev", params_.ca_list.c_str()); + } + if (params_.peer_ca) { + auto chain = params_.peer_ca->getChainWithRevocations(); + auto ret = gnutls_certificate_set_x509_trust(*xcred_, + chain.first.data(), + chain.first.size()); + if (not chain.second.empty()) + gnutls_certificate_set_x509_crl(*xcred_, chain.second.data(), chain.second.size()); + if (params_.logger) + params_.logger->debug("[TLS] Peer CA list {:d} ({:d} CRLs): {:d}", + chain.first.size(), + chain.second.size(), + ret); + } + + // Load user-given identity (key and passwd) + if (params_.cert) { + std::vector<gnutls_x509_crt_t> certs; + certs.reserve(3); + auto crt = params_.cert; + while (crt) { + certs.emplace_back(crt->cert); + crt = crt->issuer; + } + + ret = gnutls_certificate_set_x509_key(*xcred_, + certs.data(), + certs.size(), + params_.cert_key->x509_key); + if (ret < 0) + throw std::runtime_error("can't load certificate: " + std::string(gnutls_strerror(ret))); + + if (params_.logger) + params_.logger->d("[TLS] User identity loaded"); + } + + // Setup DH-params (server only, may block on dh_params.get()) + if (isServer_) { + if (const auto& dh_params = params_.dh_params.get().get()) + gnutls_certificate_set_dh_params(*xcred_, dh_params); + else + if (params_.logger) + params_.logger->w("[TLS] DH params unavailable"); // YOMGUI: need to stop? + } +} + +bool +TlsSession::TlsSessionImpl::commonSessionInit() +{ + int ret; + + if (anonymous_) { + // Force anonymous connection, see handleStateHandshake how we handle failures + ret = gnutls_priority_set_direct(session_, + transport_->isReliable() ? TLS_FULL_PRIORITY_STRING + : DTLS_FULL_PRIORITY_STRING, + nullptr); + if (ret != GNUTLS_E_SUCCESS) { + if (params_.logger) + params_.logger->e("[TLS] TLS priority set failed: %s", gnutls_strerror(ret)); + return false; + } + + // Add anonymous credentials + if (isServer_) + ret = gnutls_credentials_set(session_, GNUTLS_CRD_ANON, *sacred_); + else + ret = gnutls_credentials_set(session_, GNUTLS_CRD_ANON, *cacred_); + + if (ret != GNUTLS_E_SUCCESS) { + if (params_.logger) + params_.logger->e("[TLS] anonymous credential set failed: %s", gnutls_strerror(ret)); + return false; + } + } else { + // Use a classic non-encrypted CERTIFICATE exchange method (less anonymous) + ret = gnutls_priority_set_direct(session_, + transport_->isReliable() ? TLS_CERT_PRIORITY_STRING + : DTLS_CERT_PRIORITY_STRING, + nullptr); + if (ret != GNUTLS_E_SUCCESS) { + if (params_.logger) + params_.logger->e("[TLS] TLS priority set failed: %s", gnutls_strerror(ret)); + return false; + } + } + + // Add certificate credentials + ret = gnutls_credentials_set(session_, GNUTLS_CRD_CERTIFICATE, *xcred_); + if (ret != GNUTLS_E_SUCCESS) { + if (params_.logger) + params_.logger->e("[TLS] certificate credential set failed: %s", gnutls_strerror(ret)); + return false; + } + gnutls_certificate_send_x509_rdn_sequence(session_, 0); + + if (not transport_->isReliable()) { + // DTLS hanshake timeouts + auto re_tx_timeout = duration2ms(DTLS_RETRANSMIT_TIMEOUT); + gnutls_dtls_set_timeouts(session_, + re_tx_timeout, + std::max(duration2ms(params_.timeout), re_tx_timeout)); + + // gnutls DTLS mtu = maximum payload size given by transport + gnutls_dtls_set_mtu(session_, transport_->maxPayload()); + } + + // Stuff for transport callbacks + gnutls_session_set_ptr(session_, this); + gnutls_transport_set_ptr(session_, this); + gnutls_transport_set_vec_push_function(session_, + [](gnutls_transport_ptr_t t, + const giovec_t* iov, + int iovcnt) -> ssize_t { + auto this_ = reinterpret_cast<TlsSessionImpl*>(t); + return this_->sendRawVec(iov, iovcnt); + }); + gnutls_transport_set_pull_function(session_, + [](gnutls_transport_ptr_t t, void* d, size_t s) -> ssize_t { + auto this_ = reinterpret_cast<TlsSessionImpl*>(t); + return this_->recvRaw(d, s); + }); + gnutls_transport_set_pull_timeout_function(session_, + [](gnutls_transport_ptr_t t, unsigned ms) -> int { + auto this_ = reinterpret_cast<TlsSessionImpl*>(t); + return this_->waitForRawData( + std::chrono::milliseconds(ms)); + }); + // TODO -1 = default else set value + if (transport_->isReliable()) + gnutls_handshake_set_timeout(session_, duration2ms(params_.timeout)); + return true; +} + +std::string +getOcspUrl(gnutls_x509_crt_t cert) +{ + int ret; + gnutls_datum_t aia; + unsigned int seq = 0; + do { + // Extracts the Authority Information Access (AIA) extension, see RFC 5280 section 4.2.2.1 + ret = gnutls_x509_crt_get_authority_info_access(cert, seq++, GNUTLS_IA_OCSP_URI, &aia, NULL); + } while (ret < 0 && ret != GNUTLS_E_REQUESTED_DATA_NOT_AVAILABLE); + // could also try the issuer if we include ocsp uri into there + if (ret < 0) { + return {}; + } + std::string url((const char*) aia.data, (size_t) aia.size); + gnutls_free(aia.data); + return url; +} + +int +TlsSession::TlsSessionImpl::verifyCertificateWrapper(gnutls_session_t session) +{ + // Perform user-set verification first to avoid flooding with ocsp-requests if peer is denied + int verified; + if (callbacks_.verifyCertificate) { + auto this_ = reinterpret_cast<TlsSessionImpl*>(gnutls_session_get_ptr(session)); + verified = this_->callbacks_.verifyCertificate(session); + if (verified != GNUTLS_E_SUCCESS) + return verified; + } else { + verified = GNUTLS_E_SUCCESS; + } + /* + * Support only x509 format + */ + if (gnutls_certificate_type_get(session) != GNUTLS_CRT_X509) + return GNUTLS_E_CERTIFICATE_ERROR; + + pCert_ = peerCertificate(session); + if (!pCert_) + return GNUTLS_E_CERTIFICATE_ERROR; + + std::string ocspUrl = getOcspUrl(pCert_->cert); + if (ocspUrl.empty()) { + // Skipping OCSP verification: AIA not found + return verified; + } + + // OCSP (Online Certificate Service Protocol) { + std::promise<int> v; + std::future<int> f = v.get_future(); + + gnutls_x509_crt_t issuer_crt = pCert_->issuer ? pCert_->issuer->cert : nullptr; + verifyOcsp(ocspUrl, *pCert_, issuer_crt, [&](const int status) { + if (status == GNUTLS_E_REQUESTED_DATA_NOT_AVAILABLE) { + // OCSP URI is absent, don't fail the verification by overwritting the user-set one. + if (params_.logger) + params_.logger->w("Skipping OCSP verification %s: request failed", pCert_->getUID().c_str()); + v.set_value(verified); + } else { + if (status != GNUTLS_E_SUCCESS) { + if (params_.logger) + params_.logger->e("OCSP verification failed for %s: %s (%i)", + pCert_->getUID().c_str(), + gnutls_strerror(status), + status); + } + v.set_value(status); + } + }); + f.wait(); + + return f.get(); +} + +void +TlsSession::TlsSessionImpl::verifyOcsp(const std::string& aia_uri, + dht::crypto::Certificate& cert, + gnutls_x509_crt_t issuer, + OcspVerification cb) +{ + if (params_.logger) + params_.logger->d("Certificate's AIA URI: %s", aia_uri.c_str()); + + // Generate OCSP request + std::pair<std::string, dht::Blob> ocsp_req; + try { + ocsp_req = cert.generateOcspRequest(issuer); + } catch (dht::crypto::CryptoException& e) { + if (params_.logger) + params_.logger->e("Failed to generate OCSP request: %s", e.what()); + if (cb) + cb(GNUTLS_E_INVALID_REQUEST); + return; + } + + sendOcspRequest(aia_uri, + std::move(ocsp_req.first), + OCSP_REQUEST_TIMEOUT, + [cb = std::move(cb), &cert, nonce = std::move(ocsp_req.second), this]( + const dht::http::Response& r) { + // Prepare response data + // Verify response validity + if (r.status_code != 200) { + if (params_.logger) + params_.logger->w("HTTP OCSP Request Failed with code %i", r.status_code); + if (cb) + cb(GNUTLS_E_REQUESTED_DATA_NOT_AVAILABLE); + return; + } + if (params_.logger) + params_.logger->d("HTTP OCSP Request done!"); + gnutls_ocsp_cert_status_t verify = GNUTLS_OCSP_CERT_UNKNOWN; + try { + cert.ocspResponse = std::make_shared<dht::crypto::OcspResponse>( + (const uint8_t*) r.body.data(), r.body.size()); + if (params_.logger) + params_.logger->d("%s", cert.ocspResponse->toString().c_str()); + verify = cert.ocspResponse->verifyDirect(cert, nonce); + } catch (dht::crypto::CryptoException& e) { + if (params_.logger) + params_.logger->e("Failed to verify OCSP response: %s", e.what()); + } + if (verify == GNUTLS_OCSP_CERT_UNKNOWN) { + // Soft-fail + if (cb) + cb(GNUTLS_E_REQUESTED_DATA_NOT_AVAILABLE); + return; + } + int status = GNUTLS_E_SUCCESS; + if (verify == GNUTLS_OCSP_CERT_GOOD) { + if (params_.logger) + params_.logger->d("OCSP verification success!"); + } else { + status = GNUTLS_E_CERTIFICATE_ERROR; + if (params_.logger) + params_.logger->e("OCSP verification: certificate is revoked!"); + } + // Save response into the certificate store + try { + params_.certStore.pinOcspResponse(cert); + } catch (std::exception& e) { + if (params_.logger) + params_.logger->error("{}", e.what()); + } + if (cb) + cb(status); + }); +} + +void +TlsSession::TlsSessionImpl::sendOcspRequest(const std::string& uri, + std::string body, + std::chrono::seconds timeout, + HttpResponse cb) +{ + using namespace dht; + auto request = std::make_shared<http::Request>(*params_.io_context, + uri); //, logger); + request->set_method(restinio::http_method_post()); + request->set_header_field(restinio::http_field_t::user_agent, "Jami"); + request->set_header_field(restinio::http_field_t::accept, "*/*"); + request->set_header_field(restinio::http_field_t::content_type, "application/ocsp-request"); + request->set_body(std::move(body)); + request->set_connection_type(restinio::http_connection_header_t::close); + request->timeout(timeout, [request,l=params_.logger](const asio::error_code& ec) { + if (ec and ec != asio::error::operation_aborted) + if (l) l->error("HTTP OCSP Request timeout with error: {:s}", ec.message()); + request->cancel(); + }); + request->add_on_state_change_callback([this, cb = std::move(cb)](const http::Request::State state, + const http::Response response) { + if (params_.logger) + params_.logger->d("HTTP OCSP Request state=%i status_code=%i", + (unsigned int) state, + response.status_code); + if (state != http::Request::State::DONE) + return; + if (cb) + cb(response); + if (auto request = response.request.lock()) { + std::lock_guard<std::mutex> lock(requestsMtx_); + requests_.erase(request); + } + }); + { + std::lock_guard<std::mutex> lock(requestsMtx_); + requests_.emplace(request); + } + request->send(); +} + +std::shared_ptr<dht::crypto::Certificate> +TlsSession::TlsSessionImpl::peerCertificate(gnutls_session_t session) const +{ + if (!session) + return {}; + /* + * Get the peer's raw certificate (chain) as sent by the peer. + * The first certificate in the list is the peer's certificate, following the issuer's cert. etc. + */ + unsigned int cert_list_size = 0; + auto cert_list = gnutls_certificate_get_peers(session, &cert_list_size); + + if (cert_list == nullptr) + return {}; + std::vector<std::pair<uint8_t*, uint8_t*>> crt_data; + crt_data.reserve(cert_list_size); + for (unsigned i = 0; i < cert_list_size; i++) + crt_data.emplace_back(cert_list[i].data, cert_list[i].data + cert_list[i].size); + return std::make_shared<dht::crypto::Certificate>(crt_data); +} + +std::size_t +TlsSession::TlsSessionImpl::send(const ValueType* tx_data, std::size_t tx_size, std::error_code& ec) +{ + std::lock_guard<std::mutex> lk(sessionWriteMutex_); + if (state_ != TlsSessionState::ESTABLISHED) { + ec = std::error_code(GNUTLS_E_INVALID_SESSION, std::system_category()); + return 0; + } + + std::size_t total_written = 0; + std::size_t max_tx_sz; + + if (transport_->isReliable()) + max_tx_sz = tx_size; + else + max_tx_sz = gnutls_dtls_get_data_mtu(session_); + + // Split incoming data into chunck suitable for the underlying transport + while (total_written < tx_size) { + auto chunck_sz = std::min(max_tx_sz, tx_size - total_written); + auto data_seq = tx_data + total_written; + ssize_t nwritten; + do { + nwritten = gnutls_record_send(session_, data_seq, chunck_sz); + } while ((nwritten == GNUTLS_E_INTERRUPTED and state_ != TlsSessionState::SHUTDOWN) + or nwritten == GNUTLS_E_AGAIN); + if (nwritten < 0) { + /* Normally we would have to retry record_send but our internal + * state has not changed, so we have to ask for more data first. + * We will just try again later, although this should never happen. + */ + if (params_.logger) + params_.logger->error("[TLS] send failed (only {} bytes sent): {}", total_written, gnutls_strerror(nwritten)); + ec = std::error_code(nwritten, std::system_category()); + return 0; + } + + total_written += nwritten; + } + + ec.clear(); + return total_written; +} + +// Called by GNUTLS to send encrypted packet to low-level transport. +// Should return a positive number indicating the bytes sent, and -1 on error. +ssize_t +TlsSession::TlsSessionImpl::sendRaw(const void* buf, size_t size) +{ + std::error_code ec; + unsigned retry_count = 0; + do { + auto n = transport_->write(reinterpret_cast<const ValueType*>(buf), size, ec); + if (!ec) { + // log only on success + ++stTxRawPacketCnt_; + stTxRawBytesCnt_ += n; + return n; + } + + if (ec.value() == EAGAIN) { + if (params_.logger) + params_.logger->w("[TLS] EAGAIN from transport, retry#", ++retry_count); + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + if (retry_count == 100) { + if (params_.logger) + params_.logger->e("[TLS] excessive retry detected, aborting"); + ec.assign(EIO, std::system_category()); + } + } + } while (ec.value() == EAGAIN); + + // Must be called to pass errno value to GnuTLS on Windows (cf. GnuTLS doc) + gnutls_transport_set_errno(session_, ec.value()); + if (params_.logger) + params_.logger->e("[TLS] transport failure on tx: errno = {}", ec.value()); + return -1; +} + +// Called by GNUTLS to send encrypted packet to low-level transport. +// Should return a positive number indicating the bytes sent, and -1 on error. +ssize_t +TlsSession::TlsSessionImpl::sendRawVec(const giovec_t* iov, int iovcnt) +{ + ssize_t sent = 0; + for (int i = 0; i < iovcnt; ++i) { + const giovec_t& dat = iov[i]; + ssize_t ret = sendRaw(dat.iov_base, dat.iov_len); + if (ret < 0) + return -1; + sent += ret; + } + return sent; +} + +// Called by GNUTLS to receive encrypted packet from low-level transport. +// Should return 0 on connection termination, +// a positive number indicating the number of bytes received, +// and -1 on error. +ssize_t +TlsSession::TlsSessionImpl::recvRaw(void* buf, size_t size) +{ + if (transport_->isReliable()) { + std::error_code ec; + auto count = transport_->read(reinterpret_cast<ValueType*>(buf), size, ec); + if (!ec) + return count; + gnutls_transport_set_errno(session_, ec.value()); + return -1; + } + + std::lock_guard<std::mutex> lk {rxMutex_}; + if (rxQueue_.empty()) { + gnutls_transport_set_errno(session_, EAGAIN); + return -1; + } + + const auto& pkt = rxQueue_.front(); + const std::size_t count = std::min(pkt.size(), size); + std::copy_n(pkt.begin(), count, reinterpret_cast<ValueType*>(buf)); + rxQueue_.pop_front(); + return count; +} + +// Called by GNUTLS to wait for encrypted packet from low-level transport. +// 'timeout' is in milliseconds. +// Should return 0 on timeout, a positive number if data are available for read, or -1 on error. +int +TlsSession::TlsSessionImpl::waitForRawData(std::chrono::milliseconds timeout) +{ + if (transport_->isReliable()) { + std::error_code ec; + auto err = transport_->waitForData(timeout, ec); + if (err <= 0) { + // shutdown? + if (state_ == TlsSessionState::SHUTDOWN) { + gnutls_transport_set_errno(session_, EINTR); + return -1; + } + if (ec) { + gnutls_transport_set_errno(session_, ec.value()); + return -1; + } + return 0; + } + return 1; + } + + // non-reliable uses callback installed with setOnRecv() + std::unique_lock<std::mutex> lk {rxMutex_}; + rxCv_.wait_for(lk, timeout, [this] { + return !rxQueue_.empty() or state_ == TlsSessionState::SHUTDOWN; + }); + if (state_ == TlsSessionState::SHUTDOWN) { + gnutls_transport_set_errno(session_, EINTR); + return -1; + } + if (rxQueue_.empty()) { + if (params_.logger) + params_.logger->error("[TLS] waitForRawData: timeout after {}", timeout); + return 0; + } + return 1; +} + +bool +TlsSession::TlsSessionImpl::initFromRecordState(int offset) +{ + std::array<uint8_t, 8> seq; + if (gnutls_record_get_state(session_, 1, nullptr, nullptr, nullptr, &seq[0]) + != GNUTLS_E_SUCCESS) { + if (params_.logger) + params_.logger->e("[TLS] Fatal-error Unable to read initial state"); + return false; + } + + baseSeq_ = array2uint(seq) + offset; + gapOffset_ = baseSeq_; + lastRxSeq_ = baseSeq_ - 1; + if (params_.logger) + params_.logger->debug("[TLS] Initial sequence number: {:d}", baseSeq_); + return true; +} + +bool +TlsSession::TlsSessionImpl::setup() +{ + // Setup FSM + fsmHandlers_[TlsSessionState::SETUP] = [this](TlsSessionState s) { + return handleStateSetup(s); + }; + fsmHandlers_[TlsSessionState::COOKIE] = [this](TlsSessionState s) { + return handleStateCookie(s); + }; + fsmHandlers_[TlsSessionState::HANDSHAKE] = [this](TlsSessionState s) { + return handleStateHandshake(s); + }; + fsmHandlers_[TlsSessionState::MTU_DISCOVERY] = [this](TlsSessionState s) { + return handleStateMtuDiscovery(s); + }; + fsmHandlers_[TlsSessionState::ESTABLISHED] = [this](TlsSessionState s) { + return handleStateEstablished(s); + }; + fsmHandlers_[TlsSessionState::SHUTDOWN] = [this](TlsSessionState s) { + return handleStateShutdown(s); + }; + + return true; +} + +void +TlsSession::TlsSessionImpl::cleanup() +{ + state_ = TlsSessionState::SHUTDOWN; // be sure to block any user operations + stateCondition_.notify_all(); + + { + std::lock_guard<std::mutex> lk1(sessionReadMutex_); + std::lock_guard<std::mutex> lk2(sessionWriteMutex_); + if (session_) { + if (transport_->isReliable()) + gnutls_bye(session_, GNUTLS_SHUT_RDWR); + else + gnutls_bye(session_, GNUTLS_SHUT_WR); // not wait for a peer answer + gnutls_deinit(session_); + session_ = nullptr; + } + } + + if (cookie_key_.data) + gnutls_free(cookie_key_.data); + + transport_->shutdown(); +} + +TlsSessionState +TlsSession::TlsSessionImpl::handleStateSetup([[maybe_unused]] TlsSessionState state) +{ + if (params_.logger) + params_.logger->d("[TLS] Start %s session", typeName()); + + try { + if (anonymous_) + initAnonymous(); + initCredentials(); + } catch (const std::exception& e) { + if (params_.logger) + params_.logger->e("[TLS] authentifications init failed: %s", e.what()); + return TlsSessionState::SHUTDOWN; + } + + if (not isServer_) + return setupClient(); + + // Extra step for DTLS-like transports + if (transport_ and not transport_->isReliable()) { + gnutls_key_generate(&cookie_key_, GNUTLS_COOKIE_KEY_SIZE); + return TlsSessionState::COOKIE; + } + return setupServer(); +} + +TlsSessionState +TlsSession::TlsSessionImpl::handleStateCookie(TlsSessionState state) +{ + if (params_.logger) + params_.logger->d("[TLS] SYN cookie"); + + std::size_t count; + { + // block until rx packet or shutdown + std::unique_lock<std::mutex> lk {rxMutex_}; + if (!rxCv_.wait_for(lk, COOKIE_TIMEOUT, [this] { + return !rxQueue_.empty() or state_ == TlsSessionState::SHUTDOWN; + })) { + if (params_.logger) + params_.logger->e("[TLS] SYN cookie failed: timeout"); + return TlsSessionState::SHUTDOWN; + } + // Shutdown state? + if (rxQueue_.empty()) + return TlsSessionState::SHUTDOWN; + count = rxQueue_.front().size(); + } + + // Total bytes rx during cookie checking (see flood protection below) + cookie_count_ += count; + + int ret; + + // Peek and verify front packet + { + std::lock_guard<std::mutex> lk {rxMutex_}; + auto& pkt = rxQueue_.front(); + std::memset(&prestate_, 0, sizeof(prestate_)); + ret = gnutls_dtls_cookie_verify(&cookie_key_, nullptr, 0, pkt.data(), pkt.size(), &prestate_); + } + + if (ret < 0) { + gnutls_dtls_cookie_send(&cookie_key_, + nullptr, + 0, + &prestate_, + this, + [](gnutls_transport_ptr_t t, const void* d, size_t s) -> ssize_t { + auto this_ = reinterpret_cast<TlsSessionImpl*>(t); + return this_->sendRaw(d, s); + }); + + // Drop front packet + { + std::lock_guard<std::mutex> lk {rxMutex_}; + rxQueue_.pop_front(); + } + + // Cookie may be sent on multiple network packets + // So we retry until we get a valid cookie. + // To protect against a flood attack we delay each retry after FLOOD_THRESHOLD rx bytes. + if (cookie_count_ >= FLOOD_THRESHOLD) { + if (params_.logger) + params_.logger->warn("[TLS] flood threshold reach (retry in {})", FLOOD_PAUSE); + dump_io_stats(); + std::this_thread::sleep_for(FLOOD_PAUSE); // flood attack protection + } + return state; + } + + if (params_.logger) + params_.logger->d("[TLS] cookie ok"); + + return setupServer(); +} + +TlsSessionState +TlsSession::TlsSessionImpl::handleStateHandshake(TlsSessionState state) +{ + int ret; + size_t retry_count = 0; + if (params_.logger) + params_.logger->debug("[TLS] handshake"); + do { + ret = gnutls_handshake(session_); + } while ((ret == GNUTLS_E_INTERRUPTED or ret == GNUTLS_E_AGAIN) + and ++retry_count < HANDSHAKE_MAX_RETRY + and state_.load() != TlsSessionState::SHUTDOWN); + if (retry_count > 0) { + if (params_.logger) + params_.logger->error("[TLS] handshake retried count: {}", retry_count); + } + + // Stop on fatal error + if (gnutls_error_is_fatal(ret) || state_.load() == TlsSessionState::SHUTDOWN) { + if (params_.logger) + params_.logger->error("[TLS] handshake failed: {:s}", gnutls_strerror(ret)); + return TlsSessionState::SHUTDOWN; + } + + // Continue handshaking on non-fatal error + if (ret != GNUTLS_E_SUCCESS) { + // TODO: handle GNUTLS_E_LARGE_PACKET (MTU must be lowered) + if (ret != GNUTLS_E_AGAIN) + if (params_.logger) + params_.logger->debug("[TLS] non-fatal handshake error: {:s}", gnutls_strerror(ret)); + return state; + } + + // Safe-Renegotiation status shall always be true to prevent MiM attack + // Following https://www.gnutls.org/manual/html_node/Safe-renegotiation.html + // "Unlike TLS 1.2, the server is not allowed to change identities" + // So, we don't have to check the status if we are the client + bool isTLS1_3 = gnutls_protocol_get_version(session_) == GNUTLS_TLS1_3; + if (!isTLS1_3 || (isTLS1_3 && isServer_)) { + if (!gnutls_safe_renegotiation_status(session_)) { + if (params_.logger) + params_.logger->error("[TLS] server identity changed! MiM attack?"); + return TlsSessionState::SHUTDOWN; + } + } + + auto desc = gnutls_session_get_desc(session_); + if (params_.logger) + params_.logger->debug("[TLS] session established: {:s}", desc); + gnutls_free(desc); + + // Anonymous connection? rehandshake immediately with certificate authentification forced + auto cred = gnutls_auth_get_type(session_); + if (cred == GNUTLS_CRD_ANON) { + if (params_.logger) + params_.logger->debug("[TLS] renogotiate with certificate authentification"); + + // Re-setup TLS algorithms priority list with only certificate based cipher suites + ret = gnutls_priority_set_direct(session_, + transport_ and transport_->isReliable() + ? TLS_CERT_PRIORITY_STRING + : DTLS_CERT_PRIORITY_STRING, + nullptr); + if (ret != GNUTLS_E_SUCCESS) { + if (params_.logger) + params_.logger->error("[TLS] session TLS cert-only priority set failed: {:s}", gnutls_strerror(ret)); + return TlsSessionState::SHUTDOWN; + } + + // remove anon credentials and re-enable certificate ones + gnutls_credentials_clear(session_); + ret = gnutls_credentials_set(session_, GNUTLS_CRD_CERTIFICATE, *xcred_); + if (ret != GNUTLS_E_SUCCESS) { + if (params_.logger) + params_.logger->error("[TLS] session credential set failed: {:s}", gnutls_strerror(ret)); + return TlsSessionState::SHUTDOWN; + } + + return state; // handshake + + } else if (cred != GNUTLS_CRD_CERTIFICATE) { + if (params_.logger) + params_.logger->error("[TLS] spurious session credential ({})", cred); + return TlsSessionState::SHUTDOWN; + } + + // Aware about certificates updates + if (callbacks_.onCertificatesUpdate) { + unsigned int remote_count; + auto local = gnutls_certificate_get_ours(session_); + auto remote = gnutls_certificate_get_peers(session_, &remote_count); + callbacks_.onCertificatesUpdate(local, remote, remote_count); + } + + return transport_ and transport_->isReliable() ? TlsSessionState::ESTABLISHED + : TlsSessionState::MTU_DISCOVERY; +} + +TlsSessionState +TlsSession::TlsSessionImpl::handleStateMtuDiscovery([[maybe_unused]] TlsSessionState state) +{ + if (!transport_) { + if (params_.logger) + params_.logger->w("No transport available when discovering the MTU"); + return TlsSessionState::SHUTDOWN; + } + mtuProbe_ = transport_->maxPayload(); + assert(mtuProbe_ >= MIN_MTU); + MTUS_ = {MIN_MTU, std::max((mtuProbe_ + MIN_MTU) / 2, MIN_MTU), mtuProbe_}; + + // retrocompatibility check + if (gnutls_heartbeat_allowed(session_, GNUTLS_HB_LOCAL_ALLOWED_TO_SEND) == 1) { + if (!isServer_) { + pathMtuHeartbeat(); + if (state_ == TlsSessionState::SHUTDOWN) { + if (params_.logger) + params_.logger->e("[TLS] session destroyed while performing PMTUD, shuting down"); + return TlsSessionState::SHUTDOWN; + } + pmtudOver_ = true; + } + } else { + if (params_.logger) + params_.logger->e("[TLS] PEER HEARTBEAT DISABLED: using transport MTU value ", mtuProbe_); + pmtudOver_ = true; + } + + gnutls_dtls_set_mtu(session_, mtuProbe_); + maxPayload_ = gnutls_dtls_get_data_mtu(session_); + + if (pmtudOver_) { + if (params_.logger) + params_.logger->d("[TLS] maxPayload: ", maxPayload_.load()); + if (!initFromRecordState()) + return TlsSessionState::SHUTDOWN; + } + + return TlsSessionState::ESTABLISHED; +} + +/* + * Path MTU discovery heuristic + * heuristic description: + * The two members of the current tls connection will exchange dtls heartbeat messages + * of increasing size until the heartbeat times out which will be considered as a packet + * drop from the network due to the size of the packet. (one retry to test for a buffer issue) + * when timeout happens or all the values have been tested, the mtu will be returned. + * In case of unexpected error the first (and minimal) value of the mtu array + */ +void +TlsSession::TlsSessionImpl::pathMtuHeartbeat() +{ + if (params_.logger) + params_.logger->debug("[TLS] PMTUD: starting probing with {} of retransmission timeout", HEARTBEAT_RETRANS_TIMEOUT); + + gnutls_heartbeat_set_timeouts(session_, + HEARTBEAT_RETRANS_TIMEOUT.count(), + HEARTBEAT_TOTAL_TIMEOUT.count()); + + int errno_send = GNUTLS_E_SUCCESS; + int mtuOffset = 0; + + // when the remote (server) has a IPV6 interface selected by ICE, and local (client) has a IPV4 + // selected, the path MTU discovery triggers errors for packets too big on server side because + // of different IP headers overhead. Hence we have to signal to the TLS session to reduce the + // MTU on client size accordingly. + if (transport_ and transport_->localAddr().isIpv4() and transport_->remoteAddr().isIpv6()) { + mtuOffset = ASYMETRIC_TRANSPORT_MTU_OFFSET; + if (params_.logger) + params_.logger->w("[TLS] local/remote IP protocol version not alike, use an MTU offset of {} bytes to compensate", ASYMETRIC_TRANSPORT_MTU_OFFSET); + } + + mtuProbe_ = MTUS_[0]; + + for (auto mtu : MTUS_) { + gnutls_dtls_set_mtu(session_, mtu); + auto data_mtu = gnutls_dtls_get_data_mtu(session_); + if (params_.logger) + params_.logger->debug("[TLS] PMTUD: mtu {}, payload {}", mtu, data_mtu); + auto bytesToSend = data_mtu - mtuOffset - 3; // want to know why -3? ask gnutls! + + do { + errno_send = gnutls_heartbeat_ping(session_, + bytesToSend, + HEARTBEAT_TRIES, + GNUTLS_HEARTBEAT_WAIT); + } while (errno_send == GNUTLS_E_AGAIN + || (errno_send == GNUTLS_E_INTERRUPTED && state_ != TlsSessionState::SHUTDOWN)); + + if (errno_send != GNUTLS_E_SUCCESS) { + if (params_.logger) + params_.logger->debug("[TLS] PMTUD: mtu {} [FAILED]", mtu); + break; + } + + mtuProbe_ = mtu; + if (params_.logger) + params_.logger->debug("[TLS] PMTUD: mtu {} [OK]", mtu); + } + + if (errno_send == GNUTLS_E_TIMEDOUT) { // timeout is considered as a packet loss, then the good + // mtu is the precedent + if (mtuProbe_ == MTUS_[0]) { + if (params_.logger) + params_.logger->warn("[TLS] PMTUD: no response on first ping, using minimal MTU value {}", mtuProbe_); + } else { + if (params_.logger) + params_.logger->warn("[TLS] PMTUD: timed out, using last working mtu {}", mtuProbe_); + } + } else if (errno_send != GNUTLS_E_SUCCESS) { + if (params_.logger) + params_.logger->error("[TLS] PMTUD: failed with gnutls error '{}'", gnutls_strerror(errno_send)); + } else { + if (params_.logger) + params_.logger->debug("[TLS] PMTUD: reached maximal value"); + } +} + +void +TlsSession::TlsSessionImpl::handleDataPacket(std::vector<ValueType>&& buf, uint64_t pkt_seq) +{ + // Check for a valid seq. num. delta + int64_t seq_delta = pkt_seq - lastRxSeq_; + if (seq_delta > 0) { + lastRxSeq_ = pkt_seq; + } else { + // too old? + if (seq_delta <= -MISS_ORDERING_LIMIT) { + if (params_.logger) + params_.logger->warn("[TLS] drop old pkt: 0x{:x}", pkt_seq); + return; + } + + // No duplicate check as DTLS prevents that for us (replay protection) + + // accept Out-Of-Order pkt - will be reordered by queue flush operation + if (params_.logger) + params_.logger->warn("[TLS] OOO pkt: 0x{:x}", pkt_seq); + } + + std::unique_lock<std::mutex> lk {rxMutex_}; + auto now = clock::now(); + if (reorderBuffer_.empty()) + lastReadTime_ = now; + reorderBuffer_.emplace(pkt_seq, std::move(buf)); + nextFlush_.emplace_back(now + RX_OOO_TIMEOUT); + rxCv_.notify_one(); + // Try to flush right now as a new packet is available + flushRxQueue(lk); +} + +/// +/// Reorder and push received packet to upper layer +/// +/// \note This method must be called continuously, faster than RX_OOO_TIMEOUT +/// +void +TlsSession::TlsSessionImpl::flushRxQueue(std::unique_lock<std::mutex>& lk) +{ + // RAII bool swap + class GuardedBoolSwap + { + public: + explicit GuardedBoolSwap(bool& var) + : var_ {var} + { + var_ = !var_; + } + ~GuardedBoolSwap() { var_ = !var_; } + + private: + bool& var_; + }; + + if (reorderBuffer_.empty()) + return; + + // Prevent re-entrant access as the callbacks_.onRxData() is called in unprotected region + if (flushProcessing_) + return; + + GuardedBoolSwap swap_flush_processing {flushProcessing_}; + + auto now = clock::now(); + + auto item = std::begin(reorderBuffer_); + auto next_offset = item->first; + + // Wait for next continuous packet until timeout + if ((now - lastReadTime_) >= RX_OOO_TIMEOUT) { + // OOO packet timeout - consider waited packets as lost + if (auto lost = next_offset - gapOffset_) { + if (params_.logger) + params_.logger->warn("[TLS] {:d} lost since 0x{:x}", lost, gapOffset_); + } else if (params_.logger) + params_.logger->warn("[TLS] slow flush"); + } else if (next_offset != gapOffset_) + return; + + // Loop on offset-ordered received packet until a discontinuity in sequence number + while (item != std::end(reorderBuffer_) and item->first <= next_offset) { + auto pkt_offset = item->first; + auto pkt = std::move(item->second); + + // Remove item before unlocking to not trash the item' relationship + next_offset = pkt_offset + 1; + item = reorderBuffer_.erase(item); + + if (callbacks_.onRxData) { + lk.unlock(); + callbacks_.onRxData(std::move(pkt)); + lk.lock(); + } + } + + gapOffset_ = std::max(gapOffset_, next_offset); + lastReadTime_ = now; +} + +TlsSessionState +TlsSession::TlsSessionImpl::handleStateEstablished(TlsSessionState state) +{ + // Nothing to do in reliable mode, so just wait for state change + if (transport_ and transport_->isReliable()) { + auto disconnected = [this]() -> bool { + return state_.load() != TlsSessionState::ESTABLISHED + or newState_.load() != TlsSessionState::NONE; + }; + std::unique_lock<std::mutex> lk(stateMutex_); + stateCondition_.wait(lk, disconnected); + auto oldState = state_.load(); + if (oldState == TlsSessionState::ESTABLISHED) { + auto newState = newState_.load(); + if (newState != TlsSessionState::NONE) { + newState_ = TlsSessionState::NONE; + return newState; + } + } + return oldState; + } + + // block until rx packet or state change + { + std::unique_lock<std::mutex> lk {rxMutex_}; + if (nextFlush_.empty()) + rxCv_.wait(lk, [this] { + return state_ != TlsSessionState::ESTABLISHED or not rxQueue_.empty() + or not nextFlush_.empty(); + }); + else + rxCv_.wait_until(lk, nextFlush_.front(), [this] { + return state_ != TlsSessionState::ESTABLISHED or !rxQueue_.empty(); + }); + state = state_.load(); + if (state != TlsSessionState::ESTABLISHED) + return state; + + if (not nextFlush_.empty()) { + auto now = clock::now(); + if (nextFlush_.front() <= now) { + while (nextFlush_.front() <= now) + nextFlush_.pop_front(); + flushRxQueue(lk); + return state; + } + } + } + + std::array<uint8_t, 8> seq; + rawPktBuf_.resize(RX_MAX_SIZE); + auto ret = gnutls_record_recv_seq(session_, rawPktBuf_.data(), rawPktBuf_.size(), &seq[0]); + + if (ret > 0) { + // Are we in PMTUD phase? + if (!pmtudOver_) { + mtuProbe_ = MTUS_[std::max(0, hbPingRecved_ - 1)]; + gnutls_dtls_set_mtu(session_, mtuProbe_); + maxPayload_ = gnutls_dtls_get_data_mtu(session_); + pmtudOver_ = true; + if (params_.logger) + params_.logger->debug("[TLS] maxPayload: {}", maxPayload_.load()); + + if (!initFromRecordState(-1)) + return TlsSessionState::SHUTDOWN; + } + + rawPktBuf_.resize(ret); + handleDataPacket(std::move(rawPktBuf_), array2uint(seq)); + // no state change + } else if (ret == GNUTLS_E_HEARTBEAT_PING_RECEIVED) { + if (params_.logger) + params_.logger->d("[TLS] PMTUD: ping received sending pong"); + auto errno_send = gnutls_heartbeat_pong(session_, 0); + + if (errno_send != GNUTLS_E_SUCCESS) { + if (params_.logger) + params_.logger->e("[TLS] PMTUD: failed on pong with error %d: %s", + errno_send, + gnutls_strerror(errno_send)); + } else { + ++hbPingRecved_; + } + // no state change + } else if (ret == 0) { + if (params_.logger) + params_.logger->d("[TLS] eof"); + state = TlsSessionState::SHUTDOWN; + } else if (ret == GNUTLS_E_REHANDSHAKE) { + if (params_.logger) + params_.logger->d("[TLS] re-handshake"); + state = TlsSessionState::HANDSHAKE; + } else if (gnutls_error_is_fatal(ret)) { + if (params_.logger) + params_.logger->e("[TLS] fatal error in recv: %s", gnutls_strerror(ret)); + state = TlsSessionState::SHUTDOWN; + } // else non-fatal error... let's continue + + return state; +} + +TlsSessionState +TlsSession::TlsSessionImpl::handleStateShutdown(TlsSessionState state) +{ + if (params_.logger) + params_.logger->d("[TLS] shutdown"); + + // Stop ourself + thread_.stop(); + return state; +} + +void +TlsSession::TlsSessionImpl::process() +{ + auto old_state = state_.load(); + auto new_state = fsmHandlers_[old_state](old_state); + + // update state_ with taking care for external state change + if (not std::atomic_compare_exchange_strong(&state_, &old_state, new_state)) + new_state = old_state; + + if (old_state != new_state) + stateCondition_.notify_all(); + + if (old_state != new_state and callbacks_.onStateChange) + callbacks_.onStateChange(new_state); +} + +//============================================================================== + +TlsSession::TlsSession(std::unique_ptr<SocketType>&& transport, + const TlsParams& params, + const TlsSessionCallbacks& cbs, + bool anonymous) + + : pimpl_ {std::make_unique<TlsSessionImpl>(std::move(transport), params, cbs, anonymous)} +{} + +TlsSession::~TlsSession() {} + +bool +TlsSession::isInitiator() const +{ + return !pimpl_->isServer_; +} + +bool +TlsSession::isReliable() const +{ + if (!pimpl_->transport_) + return false; + return pimpl_->transport_->isReliable(); +} + +int +TlsSession::maxPayload() const +{ + if (pimpl_->state_ == TlsSessionState::SHUTDOWN) + throw std::runtime_error("Getting maxPayload from non-valid TLS session"); + if (!pimpl_->transport_) + return 0; + return pimpl_->transport_->maxPayload(); +} + +// Called by anyone to stop the connection and the FSM thread +void +TlsSession::shutdown() +{ + pimpl_->newState_ = TlsSessionState::SHUTDOWN; + pimpl_->stateCondition_.notify_all(); + pimpl_->rxCv_.notify_one(); // unblock waiting FSM +} + +std::size_t +TlsSession::write(const ValueType* data, std::size_t size, std::error_code& ec) +{ + return pimpl_->send(data, size, ec); +} + +std::size_t +TlsSession::read(ValueType* data, std::size_t size, std::error_code& ec) +{ + std::errc error; + + if (pimpl_->state_ != TlsSessionState::ESTABLISHED) { + ec = std::make_error_code(std::errc::broken_pipe); + return 0; + } + + while (true) { + ssize_t ret; + { + std::lock_guard<std::mutex> lk(pimpl_->sessionReadMutex_); + if (!pimpl_->session_) + return 0; + ret = gnutls_record_recv(pimpl_->session_, data, size); + } + if (ret > 0) { + ec.clear(); + return ret; + } + + std::lock_guard<std::mutex> lk(pimpl_->stateMutex_); + if (ret == 0) { + if (pimpl_) { + if (pimpl_->params_.logger) + pimpl_->params_.logger->d("[TLS] eof"); + pimpl_->newState_ = TlsSessionState::SHUTDOWN; + pimpl_->stateCondition_.notify_all(); + pimpl_->rxCv_.notify_one(); // unblock waiting FSM + } + error = std::errc::broken_pipe; + break; + } else if (ret == GNUTLS_E_REHANDSHAKE) { + if (pimpl_->params_.logger) + pimpl_->params_.logger->d("[TLS] re-handshake"); + pimpl_->newState_ = TlsSessionState::HANDSHAKE; + pimpl_->rxCv_.notify_one(); // unblock waiting FSM + pimpl_->stateCondition_.notify_all(); + } else if (gnutls_error_is_fatal(ret)) { + if (pimpl_ && pimpl_->state_ != TlsSessionState::SHUTDOWN) { + if (pimpl_->params_.logger) + pimpl_->params_.logger->e("[TLS] fatal error in recv: %s", gnutls_strerror(ret)); + pimpl_->newState_ = TlsSessionState::SHUTDOWN; + pimpl_->stateCondition_.notify_all(); + pimpl_->rxCv_.notify_one(); // unblock waiting FSM + } + error = std::errc::io_error; + break; + } + } + + ec = std::make_error_code(error); + return 0; +} + +void +TlsSession::waitForReady(const duration& timeout) +{ + auto ready = [this]() -> bool { + auto state = pimpl_->state_.load(); + return state == TlsSessionState::ESTABLISHED or state == TlsSessionState::SHUTDOWN; + }; + std::unique_lock<std::mutex> lk(pimpl_->stateMutex_); + if (timeout == duration::zero()) + pimpl_->stateCondition_.wait(lk, ready); + else + pimpl_->stateCondition_.wait_for(lk, timeout, ready); + + if (!ready()) + throw std::logic_error("Invalid state in TlsSession::waitForReady: " + + std::to_string((int) pimpl_->state_.load())); +} + +int +TlsSession::waitForData(std::chrono::milliseconds timeout, std::error_code& ec) const +{ + if (!pimpl_->transport_) { + ec = std::make_error_code(std::errc::broken_pipe); + return -1; + } + if (!pimpl_->transport_->waitForData(timeout, ec)) + return 0; + return 1; +} + +std::shared_ptr<dht::crypto::Certificate> +TlsSession::peerCertificate() const +{ + return pimpl_->pCert_; +} + +const std::shared_ptr<dht::log::Logger>& +TlsSession::logger() const +{ + return pimpl_->params_.logger; +} + +} // namespace tls +} // namespace jami diff --git a/src/sip_utils.h b/src/sip_utils.h new file mode 100644 index 0000000..6460b70 --- /dev/null +++ b/src/sip_utils.h @@ -0,0 +1,173 @@ +/* + * Copyright (C) 2004-2023 Savoir-faire Linux Inc. + * + * Author: Tristan Matthews <tristan.matthews@savoirfairelinux.com> + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ + +#pragma once + +#include "ip_utils.h" + +#include <utility> +#include <string> +#include <vector> +#include <cstring> // strcmp + +#include <pjsip/sip_msg.h> +#include <pjlib.h> +#include <pj/pool.h> +#include <pjsip/sip_endpoint.h> +#include <pjsip/sip_dialog.h> + +namespace jami { +namespace sip_utils { + +using namespace std::literals; + +// SIP methods. Only list methods that need to be explicitly +// handled + +namespace SIP_METHODS { +constexpr std::string_view MESSAGE = "MESSAGE"sv; +constexpr std::string_view INFO = "INFO"sv; +constexpr std::string_view OPTIONS = "OPTIONS"sv; +constexpr std::string_view PUBLISH = "PUBLISH"sv; +constexpr std::string_view REFER = "REFER"sv; +constexpr std::string_view NOTIFY = "NOTIFY"sv; +} // namespace SIP_METHODS + +static constexpr int DEFAULT_SIP_PORT {5060}; +static constexpr int DEFAULT_SIP_TLS_PORT {5061}; +static constexpr int DEFAULT_AUTO_SELECT_PORT {0}; + +/// PjsipErrorCategory - a PJSIP error category for std::error_code +class PjsipErrorCategory final : public std::error_category +{ +public: + const char* name() const noexcept override { return "pjsip"; } + std::string message(int condition) const override; +}; + +/// PJSIP related exception +/// Based on std::system_error with code() returning std::error_code with PjsipErrorCategory category +class PjsipFailure : public std::system_error +{ +private: + static constexpr const char* what_ = "PJSIP call failed"; + +public: + PjsipFailure() + : std::system_error(std::error_code(PJ_EUNKNOWN, PjsipErrorCategory()), what_) + {} + + explicit PjsipFailure(pj_status_t status) + : std::system_error(std::error_code(status, PjsipErrorCategory()), what_) + {} +}; + + +/** + * Helper function to parser header from incoming sip messages + * @return Header from SIP message + */ +/*std::string fetchHeaderValue(pjsip_msg* msg, const std::string& field); + +pjsip_route_hdr* createRouteSet(const std::string& route, pj_pool_t* hdr_pool); + +std::string_view stripSipUriPrefix(std::string_view sipUri); + +std::string parseDisplayName(const pjsip_name_addr* sip_name_addr); +std::string parseDisplayName(const pjsip_from_hdr* header); +std::string parseDisplayName(const pjsip_contact_hdr* header); + +std::string_view getHostFromUri(std::string_view sipUri); + +void addContactHeader(const std::string& contact, pjsip_tx_data* tdata); +void addUserAgentHeader(const std::string& userAgent, pjsip_tx_data* tdata); +std::string_view getPeerUserAgent(const pjsip_rx_data* rdata); +std::vector<std::string> getPeerAllowMethods(const pjsip_rx_data* rdata); +void logMessageHeaders(const pjsip_hdr* hdr_list);*/ + +std::string_view sip_strerror(pj_status_t code); + +// Helper function that return a constant pj_str_t from an array of any types +// that may be statically casted into char pointer. +// Per convention, the input array is supposed to be null terminated. +template<typename T, std::size_t N> +constexpr const pj_str_t +CONST_PJ_STR(T (&a)[N]) noexcept +{ + return {const_cast<char*>(a), N - 1}; +} + +inline const pj_str_t +CONST_PJ_STR(const std::string& str) noexcept +{ + return {const_cast<char*>(str.c_str()), (pj_ssize_t) str.size()}; +} + +inline constexpr pj_str_t +CONST_PJ_STR(const std::string_view& str) noexcept +{ + return {const_cast<char*>(str.data()), (pj_ssize_t) str.size()}; +} + +inline constexpr std::string_view +as_view(const pj_str_t& str) noexcept +{ + return {str.ptr, (size_t) str.slen}; +} + +// PJSIP dialog locking in RAII way +// Usage: declare local variable like this: sip_utils::PJDialogLock lock {dialog}; +// The lock is kept until the local variable is deleted +class PJDialogLock +{ +public: + explicit PJDialogLock(pjsip_dialog* dialog) + : dialog_(dialog) + { + pjsip_dlg_inc_lock(dialog_); + } + + ~PJDialogLock() { pjsip_dlg_dec_lock(dialog_); } + +private: + PJDialogLock(const PJDialogLock&) = delete; + PJDialogLock& operator=(const PJDialogLock&) = delete; + pjsip_dialog* dialog_ {nullptr}; +}; + +// Helper on PJSIP memory pool allocation from endpoint +// This encapsulate the allocated memory pool inside a unique_ptr +static inline std::unique_ptr<pj_pool_t, decltype(pj_pool_release)&> +smart_alloc_pool(pjsip_endpoint* endpt, const char* const name, pj_size_t initial, pj_size_t inc) +{ + auto pool = pjsip_endpt_create_pool(endpt, name, initial, inc); + if (not pool) + throw std::bad_alloc(); + return std::unique_ptr<pj_pool_t, decltype(pj_pool_release)&>(pool, pj_pool_release); +} + +void sockaddr_to_host_port(pj_pool_t* pool, pjsip_host_port* host_port, const pj_sockaddr* addr); + +static constexpr int POOL_TP_INIT {512}; +static constexpr int POOL_TP_INC {512}; +static constexpr int TRANSPORT_INFO_LENGTH {64}; + +} // namespace sip_utils +} // namespace jami diff --git a/src/string_utils.cpp b/src/string_utils.cpp new file mode 100644 index 0000000..934ff23 --- /dev/null +++ b/src/string_utils.cpp @@ -0,0 +1,167 @@ +/* + * Copyright (C) 2004-2023 Savoir-faire Linux Inc. + * + * Author: Tristan Matthews <tristan.matthews@savoirfairelinux.com> + * Author: Adrien Béraud <adrien.beraud@savoirfairelinux.com> + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ + +#include "string_utils.h" + +#include <fmt/core.h> +#include <fmt/ranges.h> + +#include <sstream> +#include <cctype> +#include <algorithm> +#include <ostream> +#include <iomanip> +#include <stdexcept> +#include <ios> +#include <charconv> +#include <string_view> +#ifdef _WIN32 +#include <windows.h> +#include <oleauto.h> +#endif + +#include <ciso646> // fix windows compiler bug + +namespace jami { + +#ifdef _WIN32 +std::wstring +to_wstring(const std::string& str, int codePage) +{ + int srcLength = (int) str.length(); + int requiredSize = MultiByteToWideChar(codePage, 0, str.c_str(), srcLength, nullptr, 0); + if (!requiredSize) { + throw std::runtime_error("Can't convert string to wstring"); + } + std::wstring result((size_t) requiredSize, 0); + if (!MultiByteToWideChar(codePage, 0, str.c_str(), srcLength, &(*result.begin()), requiredSize)) { + throw std::runtime_error("Can't convert string to wstring"); + } + return result; +} + +std::string +to_string(const std::wstring& wstr, int codePage) +{ + int srcLength = (int) wstr.length(); + int requiredSize = WideCharToMultiByte(codePage, 0, wstr.c_str(), srcLength, nullptr, 0, 0, 0); + if (!requiredSize) { + throw std::runtime_error("Can't convert wstring to string"); + } + std::string result((size_t) requiredSize, 0); + if (!WideCharToMultiByte( + codePage, 0, wstr.c_str(), srcLength, &(*result.begin()), requiredSize, 0, 0)) { + throw std::runtime_error("Can't convert wstring to string"); + } + return result; +} +#endif + +std::string +to_string(double value) +{ + char buf[64]; + int len = snprintf(buf, sizeof(buf), "%-.*G", 16, value); + if (len <= 0) + throw std::invalid_argument {"can't parse double"}; + return {buf, (size_t) len}; +} + +std::string +to_hex_string(uint64_t id) +{ + return fmt::format("{:016x}", id); +} + +uint64_t +from_hex_string(const std::string& str) +{ + uint64_t id; + if (auto [p, ec] = std::from_chars(str.data(), str.data()+str.size(), id, 16); ec != std::errc()) { + throw std::invalid_argument("Can't parse id: " + str); + } + return id; +} + +std::string_view +trim(std::string_view s) +{ + auto wsfront = std::find_if_not(s.cbegin(), s.cend(), [](int c) { return std::isspace(c); }); + return std::string_view(&*wsfront, std::find_if_not(s.rbegin(), + std::string_view::const_reverse_iterator(wsfront), + [](int c) { return std::isspace(c); }) + .base() - wsfront); +} + +std::vector<unsigned> +split_string_to_unsigned(std::string_view str, char delim) +{ + std::vector<unsigned> output; + for (auto first = str.data(), second = str.data(), last = first + str.size(); second != last && first != last; first = second + 1) { + second = std::find(first, last, delim); + if (first != second) { + unsigned result; + auto [p, ec] = std::from_chars(first, second, result); + if (ec == std::errc()) + output.emplace_back(result); + } + } + return output; +} + +void +string_replace(std::string& str, const std::string& from, const std::string& to) +{ + size_t start_pos = 0; + while ((start_pos = str.find(from, start_pos)) != std::string::npos) { + str.replace(start_pos, from.length(), to); + start_pos += to.length(); // Handles case where 'to' is a substring of 'from' + } +} + +std::string_view +string_remove_suffix(std::string_view str, char separator) +{ + auto it = str.find(separator); + if (it != std::string_view::npos) + str = str.substr(0, it); + return str; +} + +std::string +string_join(const std::set<std::string>& set, std::string_view separator) +{ + return fmt::format("{}", fmt::join(set, separator)); +} + +std::set<std::string> +string_split_set(std::string& str, std::string_view separator) +{ + std::set<std::string> output; + for (auto first = str.data(), second = str.data(), last = first + str.size(); second != last && first != last; first = second + 1) { + second = std::find_first_of(first, last, std::cbegin(separator), std::cend(separator)); + if (first != second) + output.emplace(first, second - first); + } + return output; +} + +} // namespace jami diff --git a/src/tracepoint/trace-tools.h b/src/tracepoint/trace-tools.h new file mode 100644 index 0000000..ccd65cd --- /dev/null +++ b/src/tracepoint/trace-tools.h @@ -0,0 +1,65 @@ +/* + * Copyright (C) 2022-2023 Savoir-faire Linux Inc. + * + * Author: Olivier Dion <olivier.dion@savoirfairelinux.com> + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ + +#pragma once + +#ifdef ENABLE_TRACEPOINTS +/* + * GCC Only. We use these instead of classic __FILE__ and __LINE__ because + * these are evaluated where invoked and not at expansion time. See GCC manual. + */ +# define CURRENT_FILENAME() __builtin_FILE() +# define CURRENT_LINE() __builtin_LINE() +#else +# define CURRENT_FILENAME() "" +# define CURRENT_LINE() 0 +#endif + +#ifdef HAVE_CXXABI_H +#include <cxxabi.h> +#include <string> + +template<typename T> +std::string demangle() +{ + int err; + char *raw; + std::string ret; + + raw = abi::__cxa_demangle(typeid(T).name(), 0, 0, &err); + + if (0 == err) { + ret = raw; + } else { + ret = typeid(T).name(); + } + + std::free(raw); + + return ret; +} + +#else +template<typename T> +std::string demangle() +{ + return typeid(T).name(); +} +#endif diff --git a/src/tracepoint/tracepoint-def.h b/src/tracepoint/tracepoint-def.h new file mode 100644 index 0000000..ed584c2 --- /dev/null +++ b/src/tracepoint/tracepoint-def.h @@ -0,0 +1,237 @@ +#ifdef ENABLE_TRACEPOINTS +#ifdef HAVE_CONFIG_H +#include "config.h" +#endif + +#undef LTTNG_UST_TRACEPOINT_PROVIDER +#define LTTNG_UST_TRACEPOINT_PROVIDER jami + +#undef LTTNG_UST_TRACEPOINT_INCLUDE +#define LTTNG_UST_TRACEPOINT_INCLUDE "src/jami/tracepoint-def.h" + +#if !defined(TRACEPOINT_DEF_H) || defined(LTTNG_UST_TRACEPOINT_HEADER_MULTI_READ) +#define TRACEPOINT_DEF_H + +#include <lttng/tracepoint.h> + + +/* + * Use LTTNG_UST_TRACEPOINT_EVENT(), LTTNG_UST_TRACEPOINT_EVENT_CLASS(), + * LTTNG_UST_TRACEPOINT_EVENT_INSTANCE(), and LTTNG_UST_TRACEPOINT_LOGLEVEL() + * here. + */ + +LTTNG_UST_TRACEPOINT_EVENT( + jami, + scheduled_executor_task_begin, + LTTNG_UST_TP_ARGS( + const char *, executor_name, + const char *, filename, + uint32_t, linum, + uint64_t, cookie + ), + LTTNG_UST_TP_FIELDS( + lttng_ust_field_string(executor, executor_name) + lttng_ust_field_string(source_filename, filename) + lttng_ust_field_integer(uint32_t, source_line, linum) + lttng_ust_field_integer(uint64_t, cookie, cookie) + ) +) + +LTTNG_UST_TRACEPOINT_EVENT( + jami, + scheduled_executor_task_end, + LTTNG_UST_TP_ARGS(uint64_t, cookie), + LTTNG_UST_TP_FIELDS(lttng_ust_field_integer(uint64_t, cookie, cookie)) +) + +LTTNG_UST_TRACEPOINT_EVENT( + jami, + ice_transport_context, + LTTNG_UST_TP_ARGS( + uint64_t, context + ), + LTTNG_UST_TP_FIELDS( + lttng_ust_field_integer(uint64_t, ice_context, context) + ) +) + +LTTNG_UST_TRACEPOINT_EVENT( + jami, + ice_transport_send, + LTTNG_UST_TP_ARGS( + uint64_t, context, + unsigned, component, + size_t, len, + const char*, remote_addr + ), + LTTNG_UST_TP_FIELDS( + lttng_ust_field_integer(uint64_t, ice_context, context) + lttng_ust_field_integer(unsigned, component, component) + lttng_ust_field_integer(size_t, packet_length, len) + lttng_ust_field_string(remote_addr, remote_addr) + ) +) + +LTTNG_UST_TRACEPOINT_EVENT( + jami, + ice_transport_send_status, + LTTNG_UST_TP_ARGS( + int, status + ), + LTTNG_UST_TP_FIELDS( + lttng_ust_field_integer(int, pj_status, status) + ) +) + +LTTNG_UST_TRACEPOINT_EVENT( + jami, + ice_transport_recv, + LTTNG_UST_TP_ARGS( + uint64_t, context, + unsigned, component, + size_t, len, + const char*, remote_addr + ), + LTTNG_UST_TP_FIELDS( + lttng_ust_field_integer(uint64_t, ice_context, context) + lttng_ust_field_integer(unsigned, component, component) + lttng_ust_field_integer(size_t, packet_length, len) + lttng_ust_field_string(remote_addr, remote_addr) + ) +) + +LTTNG_UST_TRACEPOINT_EVENT( + jami, + emit_signal, + LTTNG_UST_TP_ARGS( + const char*, signal_type + ), + LTTNG_UST_TP_FIELDS( + lttng_ust_field_string(signal_type, signal_type) + + ) +) + +LTTNG_UST_TRACEPOINT_EVENT( + jami, + emit_signal_end, + LTTNG_UST_TP_ARGS( + ), + LTTNG_UST_TP_FIELDS( + ) +) + +LTTNG_UST_TRACEPOINT_EVENT( + jami, + emit_signal_begin_callback, + LTTNG_UST_TP_ARGS( + const char*, filename, + uint32_t, linum + ), + LTTNG_UST_TP_FIELDS( + lttng_ust_field_string(source_filename, filename) + lttng_ust_field_integer(uint32_t, source_line, linum) + ) +) + +LTTNG_UST_TRACEPOINT_EVENT( + jami, + emit_signal_end_callback, + LTTNG_UST_TP_ARGS( + ), + LTTNG_UST_TP_FIELDS( + ) +) + +LTTNG_UST_TRACEPOINT_EVENT( + jami, + audio_input_read_from_device_end, + LTTNG_UST_TP_ARGS( + const char*, id + ), + LTTNG_UST_TP_FIELDS( + lttng_ust_field_integer(uint64_t, id, strtoull(id, NULL, 16)) + ) +) + +LTTNG_UST_TRACEPOINT_EVENT( + jami, + audio_layer_put_recorded_end, + LTTNG_UST_TP_ARGS( + ), + LTTNG_UST_TP_FIELDS( + ) +) + +LTTNG_UST_TRACEPOINT_EVENT( + jami, + audio_layer_get_to_play_end, + LTTNG_UST_TP_ARGS( + ), + LTTNG_UST_TP_FIELDS( + ) +) + +LTTNG_UST_TRACEPOINT_EVENT( + jami, + call_start, + LTTNG_UST_TP_ARGS( + const char*, id + ), + LTTNG_UST_TP_FIELDS( + lttng_ust_field_integer(uint64_t, id, strtoull(id, NULL, 16)) + ) +) + +LTTNG_UST_TRACEPOINT_EVENT( + jami, + call_end, + LTTNG_UST_TP_ARGS( + const char*, id + ), + LTTNG_UST_TP_FIELDS( + lttng_ust_field_integer(uint64_t, id, strtoull(id, NULL, 16)) + ) +) + +LTTNG_UST_TRACEPOINT_EVENT( + jami, + conference_begin, + LTTNG_UST_TP_ARGS( + const char*, id + ), + LTTNG_UST_TP_FIELDS( + lttng_ust_field_integer(uint64_t, id, strtoull(id, NULL, 16)) + ) +) + +LTTNG_UST_TRACEPOINT_EVENT( + jami, + conference_end, + LTTNG_UST_TP_ARGS( + const char*, id + ), + LTTNG_UST_TP_FIELDS( + lttng_ust_field_integer(uint64_t, id, strtoull(id, NULL, 16)) + ) +) + +LTTNG_UST_TRACEPOINT_EVENT( + jami, + conference_add_participant, + LTTNG_UST_TP_ARGS( + const char*, conference_id, + const char*, participant_id + ), + LTTNG_UST_TP_FIELDS( + lttng_ust_field_integer(uint64_t, id, strtoull(conference_id, NULL, 16)) + lttng_ust_field_integer(uint64_t, participant_id, strtoull(participant_id, NULL, 16)) + ) +) + +#endif /* TRACEPOINT_DEF_H */ + +#include <lttng/tracepoint-event.h> + +#endif diff --git a/src/tracepoint/tracepoint.c b/src/tracepoint/tracepoint.c new file mode 100644 index 0000000..392fb0e --- /dev/null +++ b/src/tracepoint/tracepoint.c @@ -0,0 +1,3 @@ +#define LTTNG_UST_TRACEPOINT_CREATE_PROBES +#define LTTNG_UST_TRACEPOINT_DEFINE +#include "./tracepoint.h" diff --git a/src/tracepoint/tracepoint.h b/src/tracepoint/tracepoint.h new file mode 100644 index 0000000..1e7f9a3 --- /dev/null +++ b/src/tracepoint/tracepoint.h @@ -0,0 +1,60 @@ +/* + * Copyright (C) 2022-2023 Savoir-faire Linux Inc. + * + * Author: Olivier Dion <olivier.dion@savoirfairelinux.com> + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ + +#pragma once +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wmissing-field-initializers" +#include "tracepoint-def.h" +#pragma GCC diagnostic pop + +#ifdef ENABLE_TRACEPOINTS + +# ifndef lttng_ust_tracepoint +# define lttng_ust_tracepoint(...) tracepoint(__VA_ARGS__) +# endif + +# ifndef lttng_ust_do_tracepoint +# define lttng_ust_do_tracepoint(...) do_tracepoint(__VA_ARGS__) +# endif + +# ifndef lttng_ust_tracepoint_enabled +# define lttng_ust_tracepoint_enabled(...) tracepoint_enabled(__VA_ARGS__) +# endif + +# define jami_tracepoint(tp_name, ...) \ + lttng_ust_tracepoint(jami, tp_name __VA_OPT__(,) __VA_ARGS__) + +# define jami_tracepoint_if_enabled(tp_name, ...) \ + do { \ + if (lttng_ust_tracepoint_enabled(jami, tp_name)) { \ + lttng_ust_do_tracepoint(jami, \ + tp_name \ + __VA_OPT__(,) \ + __VA_ARGS__); \ + } \ + } \ + while (0) + +#else + +# define jami_tracepoint(...) static_assert(true) +# define jami_tracepoint_if_enabled(...) static_assert(true) + +#endif diff --git a/src/transport/peer_channel.h b/src/transport/peer_channel.h new file mode 100644 index 0000000..5f25123 --- /dev/null +++ b/src/transport/peer_channel.h @@ -0,0 +1,109 @@ +/* + * Copyright (C) 2004-2023 Savoir-faire Linux Inc. + * Authors: Adrien Béraud <adrien.beraud@savoirfairelinux.com> + * Guillaume Roguez <guillaume.roguez@savoirfairelinux.com> + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see <https://www.gnu.org/licenses/>. + */ +#pragma once + +#include <mutex> +#include <condition_variable> +#include <deque> +#include <algorithm> + +namespace jami { + +class PeerChannel +{ +public: + PeerChannel() {} + ~PeerChannel() { stop(); } + PeerChannel(PeerChannel&& o) + { + std::lock_guard<std::mutex> lk(o.mutex_); + stream_ = std::move(o.stream_); + stop_ = o.stop_; + o.cv_.notify_all(); + } + + template<typename Duration> + ssize_t wait(Duration timeout, std::error_code& ec) + { + std::unique_lock<std::mutex> lk {mutex_}; + cv_.wait_for(lk, timeout, [this] { return stop_ or not stream_.empty(); }); + if (stop_) { + ec = std::make_error_code(std::errc::broken_pipe); + return -1; + } + ec.clear(); + return stream_.size(); + } + + ssize_t read(char* output, std::size_t size, std::error_code& ec) + { + std::unique_lock<std::mutex> lk {mutex_}; + cv_.wait(lk, [this] { return stop_ or not stream_.empty(); }); + if (stream_.size()) { + auto toRead = std::min(size, stream_.size()); + if (toRead) { + auto endIt = stream_.begin() + toRead; + std::copy(stream_.begin(), endIt, output); + stream_.erase(stream_.begin(), endIt); + } + ec.clear(); + return toRead; + } + if (stop_) { + ec.clear(); + return 0; + } + ec = std::make_error_code(std::errc::resource_unavailable_try_again); + return -1; + } + + ssize_t write(const char* data, std::size_t size, std::error_code& ec) + { + std::lock_guard<std::mutex> lk {mutex_}; + if (stop_) { + ec = std::make_error_code(std::errc::broken_pipe); + return -1; + } + stream_.insert(stream_.end(), data, data + size); + cv_.notify_all(); + ec.clear(); + return size; + } + + void stop() noexcept + { + std::lock_guard<std::mutex> lk {mutex_}; + if (stop_) + return; + stop_ = true; + cv_.notify_all(); + } + +private: + PeerChannel(const PeerChannel& o) = delete; + PeerChannel& operator=(const PeerChannel& o) = delete; + PeerChannel& operator=(PeerChannel&& o) = delete; + + std::mutex mutex_ {}; + std::condition_variable cv_ {}; + std::deque<char> stream_; + bool stop_ {false}; +}; + +} // namespace jami diff --git a/src/upnp/protocol/igd.cpp b/src/upnp/protocol/igd.cpp new file mode 100644 index 0000000..0e2ac90 --- /dev/null +++ b/src/upnp/protocol/igd.cpp @@ -0,0 +1,76 @@ +/* + * Copyright (C) 2004-2023 Savoir-faire Linux Inc. + * + * Author: Stepan Salenikovich <stepan.salenikovich@savoirfairelinux.com> + * Author: Eden Abitbol <eden.abitbol@savoirfairelinux.com> + * Author: Mohamed Chibani <mohamed.chibani@savoirfairelinux.com> + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ + +#include "igd.h" +#include "logger.h" + +namespace jami { +namespace upnp { + +IGD::IGD(NatProtocolType proto) + : protocol_(proto) +{} + +bool +IGD::operator==(IGD& other) const +{ + return localIp_ == other.localIp_ and publicIp_ == other.publicIp_ and uid_ == other.uid_; +} + +void +IGD::setValid(bool valid) +{ + valid_ = valid; + + if (valid) { + // Reset errors counter. + errorsCounter_ = 0; + } else { + JAMI_WARN("IGD %s [%s] was disabled", toString().c_str(), getProtocolName()); + } +} + +bool +IGD::incrementErrorsCounter() +{ + if (not valid_) + return false; + + if (++errorsCounter_ >= MAX_ERRORS_COUNT) { + JAMI_WARN("IGD %s [%s] has too many errors, it will be disabled", + toString().c_str(), + getProtocolName()); + setValid(false); + return false; + } + + return true; +} + +int +IGD::getErrorsCount() const +{ + return errorsCounter_.load(); +} + +} // namespace upnp +} // namespace jami \ No newline at end of file diff --git a/src/upnp/protocol/igd.h b/src/upnp/protocol/igd.h new file mode 100644 index 0000000..33810f8 --- /dev/null +++ b/src/upnp/protocol/igd.h @@ -0,0 +1,110 @@ +/* + * Copyright (C) 2004-2023 Savoir-faire Linux Inc. + * + * Author: Eden Abitbol <eden.abitbol@savoirfairelinux.com> + * Author: Mohamed Chibani <mohamed.chibani@savoirfairelinux.com> + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ +#pragma once + +#include <mutex> + +#include "ip_utils.h" +#include "mapping.h" + +#ifdef _MSC_VER +typedef uint16_t in_port_t; +#endif + +namespace jami { +namespace upnp { + +enum class NatProtocolType { UNKNOWN, PUPNP, NAT_PMP }; + +class IGD +{ +public: + // Max error before moving the IGD to invalid state. + constexpr static int MAX_ERRORS_COUNT = 10; + + IGD(NatProtocolType prot); + virtual ~IGD() = default; + bool operator==(IGD& other) const; + + NatProtocolType getProtocol() const { return protocol_; } + + char const* getProtocolName() const + { + return protocol_ == NatProtocolType::NAT_PMP ? "NAT-PMP" : "UPNP"; + }; + + IpAddr getLocalIp() const + { + std::lock_guard<std::mutex> lock(mutex_); + return localIp_; + } + IpAddr getPublicIp() const + { + std::lock_guard<std::mutex> lock(mutex_); + return publicIp_; + } + void setLocalIp(const IpAddr& addr) + { + std::lock_guard<std::mutex> lock(mutex_); + localIp_ = addr; + } + void setPublicIp(const IpAddr& addr) + { + std::lock_guard<std::mutex> lock(mutex_); + publicIp_ = addr; + } + void setUID(const std::string& uid) + { + std::lock_guard<std::mutex> lock(mutex_); + uid_ = uid; + } + std::string getUID() const + { + std::lock_guard<std::mutex> lock(mutex_); + return uid_; + } + + void setValid(bool valid); + bool isValid() const { return valid_; } + bool incrementErrorsCounter(); + int getErrorsCount() const; + + virtual const std::string toString() const = 0; + +protected: + const NatProtocolType protocol_ {NatProtocolType::UNKNOWN}; + std::atomic_bool valid_ {false}; + std::atomic<int> errorsCounter_ {0}; + + mutable std::mutex mutex_; + IpAddr localIp_ {}; // Local IP of the IGD (typically the same as the gateway). + IpAddr publicIp_ {}; // External/public IP of IGD. + std::string uid_ {}; + +private: + IGD(IGD&& other) = delete; + IGD(IGD& other) = delete; + IGD& operator=(IGD&& other) = delete; + IGD& operator=(IGD& other) = delete; +}; + +} // namespace upnp +} // namespace jami diff --git a/src/upnp/protocol/mapping.cpp b/src/upnp/protocol/mapping.cpp new file mode 100644 index 0000000..9b38831 --- /dev/null +++ b/src/upnp/protocol/mapping.cpp @@ -0,0 +1,347 @@ +/* + * Copyright (C) 2004-2023 Savoir-faire Linux Inc. + * + * Author: Stepan Salenikovich <stepan.salenikovich@savoirfairelinux.com> + * Author: Eden Abitbol <eden.abitbol@savoirfairelinux.com> + * Author: Mohamed Chibani <mohamed.chibani@savoirfairelinux.com> + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ + +#include "mapping.h" +#include "logger.h" + +namespace jami { +namespace upnp { + +Mapping::Mapping(PortType type, uint16_t portExternal, uint16_t portInternal, bool available) + : type_(type) + , externalPort_(portExternal) + , internalPort_(portInternal) + , internalAddr_() + , igd_() + , available_(available) + , state_(MappingState::PENDING) + , notifyCb_(nullptr) + , autoUpdate_(false) +#if HAVE_LIBNATPMP + , renewalTime_(sys_clock::now()) +#endif +{} + +Mapping::Mapping(const Mapping& other) +{ + std::lock_guard<std::mutex> lock(other.mutex_); + + internalAddr_ = other.internalAddr_; + internalPort_ = other.internalPort_; + externalPort_ = other.externalPort_; + type_ = other.type_; + igd_ = other.igd_; + available_ = other.available_; + state_ = other.state_; + notifyCb_ = other.notifyCb_; + autoUpdate_ = other.autoUpdate_; +#if HAVE_LIBNATPMP + renewalTime_ = other.renewalTime_; +#endif +} + +void +Mapping::updateFrom(const Mapping::sharedPtr_t& other) +{ + updateFrom(*other); +} + +void +Mapping::updateFrom(const Mapping& other) +{ + if (type_ != other.type_) { + JAMI_ERR("The source and destination types must match"); + return; + } + + internalAddr_ = std::move(other.internalAddr_); + internalPort_ = other.internalPort_; + externalPort_ = other.externalPort_; + igd_ = other.igd_; + state_ = other.state_; +} + +void +Mapping::setAvailable(bool val) +{ + JAMI_DBG("Changing mapping %s state from %s to %s", + toString().c_str(), + available_ ? "AVAILABLE" : "UNAVAILABLE", + val ? "AVAILABLE" : "UNAVAILABLE"); + + std::lock_guard<std::mutex> lock(mutex_); + available_ = val; +} + +void +Mapping::setState(const MappingState& state) +{ + std::lock_guard<std::mutex> lock(mutex_); + state_ = state; +} + +const char* +Mapping::getStateStr() const +{ + std::lock_guard<std::mutex> lock(mutex_); + return getStateStr(state_); +} + +std::string +Mapping::toString(bool extraInfo) const +{ + std::lock_guard<std::mutex> lock(mutex_); + std::ostringstream descr; + descr << UPNP_MAPPING_DESCRIPTION_PREFIX << "-" << getTypeStr(type_); + descr << ":" << std::to_string(internalPort_); + + if (extraInfo) { + descr << " (state=" << getStateStr(state_) + << ", auto-update=" << (autoUpdate_ ? "YES" : "NO") << ")"; + } + + return descr.str(); +} + +bool +Mapping::isValid() const +{ + std::lock_guard<std::mutex> lock(mutex_); + if (state_ == MappingState::FAILED) + return false; + if (internalPort_ == 0) + return false; + if (externalPort_ == 0) + return false; + if (not igd_ or not igd_->isValid()) + return false; + IpAddr intAddr(internalAddr_); + return intAddr and not intAddr.isLoopback(); +} + +bool +Mapping::hasValidHostAddress() const +{ + std::lock_guard<std::mutex> lock(mutex_); + + IpAddr intAddr(internalAddr_); + return intAddr and not intAddr.isLoopback(); +} + +bool +Mapping::hasPublicAddress() const +{ + std::lock_guard<std::mutex> lock(mutex_); + + return igd_ and igd_->getPublicIp() and not igd_->getPublicIp().isPrivate(); +} + +Mapping::key_t +Mapping::getMapKey() const +{ + std::lock_guard<std::mutex> lock(mutex_); + + key_t mapKey = internalPort_; + if (type_ == PortType::UDP) + mapKey |= 1 << (sizeof(uint16_t) * 8); + return mapKey; +} + +PortType +Mapping::getTypeFromMapKey(key_t key) +{ + return (key >> (sizeof(uint16_t) * 8)) ? PortType::UDP : PortType::TCP; +} + +std::string +Mapping::getExternalAddress() const +{ + std::lock_guard<std::mutex> lock(mutex_); + if (igd_) + return igd_->getPublicIp().toString(); + return {}; +} + +void +Mapping::setExternalPort(uint16_t port) +{ + std::lock_guard<std::mutex> lock(mutex_); + externalPort_ = port; +} + +uint16_t +Mapping::getExternalPort() const +{ + std::lock_guard<std::mutex> lock(mutex_); + return externalPort_; +} + +std::string +Mapping::getExternalPortStr() const +{ + std::lock_guard<std::mutex> lock(mutex_); + return std::to_string(externalPort_); +} + +void +Mapping::setInternalAddress(const std::string& addr) +{ + std::lock_guard<std::mutex> lock(mutex_); + internalAddr_ = addr; +} + +std::string +Mapping::getInternalAddress() const +{ + std::lock_guard<std::mutex> lock(mutex_); + return internalAddr_; +} + +void +Mapping::setInternalPort(uint16_t port) +{ + std::lock_guard<std::mutex> lock(mutex_); + internalPort_ = port; +} + +uint16_t +Mapping::getInternalPort() const +{ + std::lock_guard<std::mutex> lock(mutex_); + return internalPort_; +} + +std::string +Mapping::getInternalPortStr() const +{ + std::lock_guard<std::mutex> lock(mutex_); + return std::to_string(internalPort_); +} + +PortType +Mapping::getType() const +{ + std::lock_guard<std::mutex> lock(mutex_); + return type_; +} + +const char* +Mapping::getTypeStr() const +{ + std::lock_guard<std::mutex> lock(mutex_); + return getTypeStr(type_); +} + +bool +Mapping::isAvailable() const +{ + std::lock_guard<std::mutex> lock(mutex_); + return available_; +} + +std::shared_ptr<IGD> +Mapping::getIgd() const +{ + std::lock_guard<std::mutex> lock(mutex_); + return igd_; +} + +NatProtocolType +Mapping::getProtocol() const +{ + std::lock_guard<std::mutex> lock(mutex_); + if (igd_) + return igd_->getProtocol(); + return NatProtocolType::UNKNOWN; +} +const char* +Mapping::getProtocolName() const +{ + if (igd_) { + if (igd_->getProtocol() == NatProtocolType::NAT_PMP) + return "NAT-PMP"; + if (igd_->getProtocol() == NatProtocolType::PUPNP) + return "PUPNP"; + } + return "UNKNOWN"; +} + +void +Mapping::setIgd(const std::shared_ptr<IGD>& igd) +{ + std::lock_guard<std::mutex> lock(mutex_); + igd_ = igd; +} + +MappingState +Mapping::getState() const +{ + std::lock_guard<std::mutex> lock(mutex_); + return state_; +} + +Mapping::NotifyCallback +Mapping::getNotifyCallback() const +{ + std::lock_guard<std::mutex> lock(mutex_); + return notifyCb_; +} + +void +Mapping::setNotifyCallback(NotifyCallback cb) +{ + std::lock_guard<std::mutex> lock(mutex_); + notifyCb_ = std::move(cb); +} + +void +Mapping::enableAutoUpdate(bool enable) +{ + std::lock_guard<std::mutex> lock(mutex_); + autoUpdate_ = enable; +} + +bool +Mapping::getAutoUpdate() const +{ + std::lock_guard<std::mutex> lock(mutex_); + return autoUpdate_; +} + +#if HAVE_LIBNATPMP +sys_clock::time_point +Mapping::getRenewalTime() const +{ + std::lock_guard<std::mutex> lock(mutex_); + return renewalTime_; +} + +void +Mapping::setRenewalTime(sys_clock::time_point time) +{ + std::lock_guard<std::mutex> lock(mutex_); + renewalTime_ = time; +} +#endif + +} // namespace upnp +} // namespace jami diff --git a/src/upnp/protocol/mapping.h b/src/upnp/protocol/mapping.h new file mode 100644 index 0000000..89e46b0 --- /dev/null +++ b/src/upnp/protocol/mapping.h @@ -0,0 +1,146 @@ +/* + * Copyright (C) 2004-2023 Savoir-faire Linux Inc. + * + * Author: Eden Abitbol <eden.abitbol@savoirfairelinux.com> + * Author: Mohamed Chibani <mohamed.chibani@savoirfairelinux.com> + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ + +#pragma once + +#include "ip_utils.h" +#include "igd.h" + +#include <map> +#include <string> +#include <chrono> +#include <functional> +#include <mutex> + +namespace jami { +namespace upnp { + +using sys_clock = std::chrono::system_clock; + +enum class PortType { TCP, UDP }; +enum class MappingState { PENDING, IN_PROGRESS, FAILED, OPEN }; + +enum class NatProtocolType; +class IGD; + +class Mapping +{ + friend class UPnPContext; + friend class NatPmp; + friend class PUPnP; + +public: + using key_t = uint64_t; + using sharedPtr_t = std::shared_ptr<Mapping>; + using NotifyCallback = std::function<void(sharedPtr_t)>; + + static constexpr char const* MAPPING_STATE_STR[4] {"PENDING", "IN_PROGRESS", "FAILED", "OPEN"}; + static constexpr char const* UPNP_MAPPING_DESCRIPTION_PREFIX {"JAMI"}; + + Mapping(PortType type, + uint16_t portExternal = 0, + uint16_t portInternal = 0, + bool available = true); + Mapping(const Mapping& other); + Mapping(Mapping&& other) = delete; + ~Mapping() = default; + + // Delete operators with confusing semantic. + Mapping& operator=(Mapping&& other) = delete; + bool operator==(const Mapping& other) = delete; + bool operator!=(const Mapping& other) = delete; + bool operator<(const Mapping& other) = delete; + bool operator>(const Mapping& other) = delete; + bool operator<=(const Mapping& other) = delete; + bool operator>=(const Mapping& other) = delete; + + inline explicit operator bool() const { return isValid(); } + + void updateFrom(const Mapping& other); + void updateFrom(const Mapping::sharedPtr_t& other); + std::string getExternalAddress() const; + uint16_t getExternalPort() const; + std::string getExternalPortStr() const; + std::string getInternalAddress() const; + uint16_t getInternalPort() const; + std::string getInternalPortStr() const; + PortType getType() const; + const char* getTypeStr() const; + static const char* getTypeStr(PortType type) { return type == PortType::UDP ? "UDP" : "TCP"; } + std::shared_ptr<IGD> getIgd() const; + NatProtocolType getProtocol() const; + const char* getProtocolName() const; + bool isAvailable() const; + MappingState getState() const; + const char* getStateStr() const; + static const char* getStateStr(MappingState state) + { + return MAPPING_STATE_STR[static_cast<int>(state)]; + } + std::string toString(bool extraInfo = false) const; + bool isValid() const; + bool hasValidHostAddress() const; + bool hasPublicAddress() const; + void setNotifyCallback(NotifyCallback cb); + void enableAutoUpdate(bool enable); + bool getAutoUpdate() const; + key_t getMapKey() const; + static PortType getTypeFromMapKey(key_t key); +#if HAVE_LIBNATPMP + sys_clock::time_point getRenewalTime() const; +#endif + +private: + NotifyCallback getNotifyCallback() const; + void setInternalAddress(const std::string& addr); + void setExternalPort(uint16_t port); + void setInternalPort(uint16_t port); + + void setIgd(const std::shared_ptr<IGD>& igd); + void setAvailable(bool val); + void setState(const MappingState& state); + void updateDescription(); +#if HAVE_LIBNATPMP + void setRenewalTime(sys_clock::time_point time); +#endif + + mutable std::mutex mutex_; + PortType type_ {PortType::UDP}; + uint16_t externalPort_ {0}; + uint16_t internalPort_ {0}; + std::string internalAddr_; + // Protocol and + std::shared_ptr<IGD> igd_; + // Track if the mapping is available to use. + bool available_; + // Track the state of the mapping + MappingState state_; + NotifyCallback notifyCb_; + // If true, a new mapping will be requested on behave of the mapping + // owner when the mapping state changes from "OPEN" to "FAILED". + bool autoUpdate_; +#if HAVE_LIBNATPMP + sys_clock::time_point renewalTime_; +#endif +}; + +} // namespace upnp +} // namespace jami diff --git a/src/upnp/protocol/natpmp/nat_pmp.cpp b/src/upnp/protocol/natpmp/nat_pmp.cpp new file mode 100644 index 0000000..21f11ee --- /dev/null +++ b/src/upnp/protocol/natpmp/nat_pmp.cpp @@ -0,0 +1,775 @@ +/* + * Copyright (C) 2004-2023 Savoir-faire Linux Inc. + * + * Author: Eden Abitbol <eden.abitbol@savoirfairelinux.com> + * Author: Mohamed Chibani <mohamed.chibani@savoirfairelinux.com> + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ + +#include "nat_pmp.h" + +#if HAVE_LIBNATPMP + +namespace jami { +namespace upnp { + +NatPmp::NatPmp() +{ + JAMI_DBG("NAT-PMP: Instance [%p] created", this); + runOnNatPmpQueue([this] { + threadId_ = getCurrentThread(); + igd_ = std::make_shared<PMPIGD>(); + }); +} + +NatPmp::~NatPmp() +{ + JAMI_DBG("NAT-PMP: Instance [%p] destroyed", this); +} + +void +NatPmp::initNatPmp() +{ + if (not isValidThread()) { + runOnNatPmpQueue([w = weak()] { + if (auto pmpThis = w.lock()) { + pmpThis->initNatPmp(); + } + }); + return; + } + + initialized_ = false; + + { + std::lock_guard<std::mutex> lock(natpmpMutex_); + hostAddress_ = ip_utils::getLocalAddr(AF_INET); + } + + // Local address must be valid. + if (not getHostAddress() or getHostAddress().isLoopback()) { + JAMI_WARN("NAT-PMP: Does not have a valid local address!"); + return; + } + + assert(igd_); + if (igd_->isValid()) { + igd_->setValid(false); + processIgdUpdate(UpnpIgdEvent::REMOVED); + } + + igd_->setLocalIp(IpAddr()); + igd_->setPublicIp(IpAddr()); + igd_->setUID(""); + + JAMI_DBG("NAT-PMP: Trying to initialize IGD"); + + int err = initnatpmp(&natpmpHdl_, 0, 0); + + if (err < 0) { + JAMI_WARN("NAT-PMP: Initializing IGD using default gateway failed!"); + const auto& localGw = ip_utils::getLocalGateway(); + if (not localGw) { + JAMI_WARN("NAT-PMP: Couldn't find valid gateway on local host"); + err = NATPMP_ERR_CANNOTGETGATEWAY; + } else { + JAMI_WARN("NAT-PMP: Trying to initialize using detected gateway %s", + localGw.toString().c_str()); + + struct in_addr inaddr; + inet_pton(AF_INET, localGw.toString().c_str(), &inaddr); + err = initnatpmp(&natpmpHdl_, 1, inaddr.s_addr); + } + } + + if (err < 0) { + JAMI_ERR("NAT-PMP: Can't initialize libnatpmp -> %s", getNatPmpErrorStr(err)); + return; + } + + char addrbuf[INET_ADDRSTRLEN]; + inet_ntop(AF_INET, &natpmpHdl_.gateway, addrbuf, sizeof(addrbuf)); + IpAddr igdAddr(addrbuf); + JAMI_DBG("NAT-PMP: Initialized on gateway %s", igdAddr.toString().c_str()); + + // Set the local (gateway) address. + igd_->setLocalIp(igdAddr); + // NAT-PMP protocol does not have UID, but we will set generic + // one debugging purposes. + igd_->setUID("NAT-PMP Gateway"); + + // Search and set the public address. + getIgdPublicAddress(); + + // Update and notify. + if (igd_->isValid()) { + initialized_ = true; + processIgdUpdate(UpnpIgdEvent::ADDED); + }; +} + +void +NatPmp::setObserver(UpnpMappingObserver* obs) +{ + if (not isValidThread()) { + runOnNatPmpQueue([w = weak(), obs] { + if (auto pmpThis = w.lock()) { + pmpThis->setObserver(obs); + } + }); + return; + } + + JAMI_DBG("NAT-PMP: Setting observer to %p", obs); + + observer_ = obs; +} + +void +NatPmp::terminate(std::condition_variable& cv) +{ + initialized_ = false; + observer_ = nullptr; + + { + std::lock_guard<std::mutex> lock(natpmpMutex_); + shutdownComplete_ = true; + cv.notify_one(); + } +} + +void +NatPmp::terminate() +{ + std::unique_lock<std::mutex> lk(natpmpMutex_); + std::condition_variable cv {}; + + runOnNatPmpQueue([w = weak(), &cv = cv] { + if (auto pmpThis = w.lock()) { + pmpThis->terminate(cv); + } + }); + + if (cv.wait_for(lk, std::chrono::seconds(10), [this] { return shutdownComplete_; })) { + JAMI_DBG("NAT-PMP: Shutdown completed"); + } else { + JAMI_ERR("NAT-PMP: Shutdown timed-out"); + } +} + +const IpAddr +NatPmp::getHostAddress() const +{ + std::lock_guard<std::mutex> lock(natpmpMutex_); + return hostAddress_; +} + +void +NatPmp::clearIgds() +{ + if (not isValidThread()) { + runOnNatPmpQueue([w = weak()] { + if (auto pmpThis = w.lock()) { + pmpThis->clearIgds(); + } + }); + return; + } + + bool do_close = false; + + if (igd_) { + if (igd_->isValid()) { + do_close = true; + } + igd_->setValid(false); + } + + initialized_ = false; + if (searchForIgdTimer_) + searchForIgdTimer_->cancel(); + + igdSearchCounter_ = 0; + + if (do_close) { + closenatpmp(&natpmpHdl_); + memset(&natpmpHdl_, 0, sizeof(natpmpHdl_)); + } +} + +void +NatPmp::searchForIgd() +{ + if (not isValidThread()) { + runOnNatPmpQueue([w = weak()] { + if (auto pmpThis = w.lock()) { + pmpThis->searchForIgd(); + } + }); + return; + } + + if (not initialized_) { + initNatPmp(); + } + + // Schedule a retry in case init failed. + if (not initialized_) { + if (igdSearchCounter_++ < MAX_RESTART_SEARCH_RETRIES) { + JAMI_DBG("NAT-PMP: Start search for IGDs. Attempt %i", igdSearchCounter_); + + // Cancel the current timer (if any) and re-schedule. + if (searchForIgdTimer_) + searchForIgdTimer_->cancel(); + + searchForIgdTimer_ = getNatpmpScheduler()->scheduleIn([this] { searchForIgd(); }, + NATPMP_SEARCH_RETRY_UNIT + * igdSearchCounter_); + } else { + JAMI_WARN("NAT-PMP: Setup failed after %u trials. NAT-PMP will be disabled!", + MAX_RESTART_SEARCH_RETRIES); + } + } +} + +std::list<std::shared_ptr<IGD>> +NatPmp::getIgdList() const +{ + std::lock_guard<std::mutex> lock(natpmpMutex_); + std::list<std::shared_ptr<IGD>> igdList; + if (igd_->isValid()) + igdList.emplace_back(igd_); + return igdList; +} + +bool +NatPmp::isReady() const +{ + if (observer_ == nullptr) { + JAMI_ERR("NAT-PMP: the observer is not set!"); + return false; + } + + // Must at least have a valid local address. + if (not getHostAddress() or getHostAddress().isLoopback()) + return false; + + return igd_ and igd_->isValid(); +} + +void +NatPmp::incrementErrorsCounter(const std::shared_ptr<IGD>& igdIn) +{ + if (not validIgdInstance(igdIn)) { + return; + } + + if (not igd_->isValid()) { + // Already invalid. Nothing to do. + return; + } + + if (not igd_->incrementErrorsCounter()) { + // Disable this IGD. + igd_->setValid(false); + // Notify the listener. + JAMI_WARN("NAT-PMP: No more valid IGD!"); + + processIgdUpdate(UpnpIgdEvent::INVALID_STATE); + } +} + +void +NatPmp::requestMappingAdd(const Mapping& mapping) +{ + // Process on nat-pmp thread. + if (not isValidThread()) { + runOnNatPmpQueue([w = weak(), mapping] { + if (auto pmpThis = w.lock()) { + pmpThis->requestMappingAdd(mapping); + } + }); + return; + } + + Mapping map(mapping); + assert(map.getIgd()); + auto err = addPortMapping(map); + if (err < 0) { + JAMI_WARN("NAT-PMP: Request for mapping %s on %s failed with error %i: %s", + map.toString().c_str(), + igd_->toString().c_str(), + err, + getNatPmpErrorStr(err)); + + if (isErrorFatal(err)) { + // Fatal error, increment the counter. + incrementErrorsCounter(igd_); + } + // Notify the listener. + processMappingRequestFailed(std::move(map)); + } else { + JAMI_DBG("NAT-PMP: Request for mapping %s on %s succeeded", + map.toString().c_str(), + igd_->toString().c_str()); + // Notify the listener. + processMappingAdded(std::move(map)); + } +} + +void +NatPmp::requestMappingRenew(const Mapping& mapping) +{ + // Process on nat-pmp thread. + if (not isValidThread()) { + runOnNatPmpQueue([w = weak(), mapping] { + if (auto pmpThis = w.lock()) { + pmpThis->requestMappingRenew(mapping); + } + }); + return; + } + + Mapping map(mapping); + auto err = addPortMapping(map); + if (err < 0) { + JAMI_WARN("NAT-PMP: Renewal request for mapping %s on %s failed with error %i: %s", + map.toString().c_str(), + igd_->toString().c_str(), + err, + getNatPmpErrorStr(err)); + // Notify the listener. + processMappingRequestFailed(std::move(map)); + + if (isErrorFatal(err)) { + // Fatal error, increment the counter. + incrementErrorsCounter(igd_); + } + } else { + JAMI_DBG("NAT-PMP: Renewal request for mapping %s on %s succeeded", + map.toString().c_str(), + igd_->toString().c_str()); + // Notify the listener. + processMappingRenewed(map); + } +} + +int +NatPmp::readResponse(natpmp_t& handle, natpmpresp_t& response) +{ + int err = 0; + unsigned readRetriesCounter = 0; + + while (true) { + if (readRetriesCounter++ > MAX_READ_RETRIES) { + err = NATPMP_ERR_SOCKETERROR; + break; + } + + fd_set fds; + struct timeval timeout; + FD_ZERO(&fds); + FD_SET(handle.s, &fds); + getnatpmprequesttimeout(&handle, &timeout); + // Wait for data. + if (select(FD_SETSIZE, &fds, NULL, NULL, &timeout) == -1) { + err = NATPMP_ERR_SOCKETERROR; + break; + } + + // Read the data. + err = readnatpmpresponseorretry(&handle, &response); + + if (err == NATPMP_TRYAGAIN) { + std::this_thread::sleep_for(std::chrono::milliseconds(TIMEOUT_BEFORE_READ_RETRY)); + } else { + break; + } + } + + return err; +} + +int +NatPmp::sendMappingRequest(const Mapping& mapping, uint32_t& lifetime) +{ + CHECK_VALID_THREAD(); + + int err = sendnewportmappingrequest(&natpmpHdl_, + mapping.getType() == PortType::UDP ? NATPMP_PROTOCOL_UDP + : NATPMP_PROTOCOL_TCP, + mapping.getInternalPort(), + mapping.getExternalPort(), + lifetime); + + if (err < 0) { + JAMI_ERR("NAT-PMP: Send mapping request failed with error %s %i", + getNatPmpErrorStr(err), + errno); + return err; + } + + unsigned readRetriesCounter = 0; + + while (readRetriesCounter++ < MAX_READ_RETRIES) { + // Read the response + natpmpresp_t response; + err = readResponse(natpmpHdl_, response); + + if (err < 0) { + JAMI_WARN("NAT-PMP: Read response on IGD %s failed with error %s", + igd_->toString().c_str(), + getNatPmpErrorStr(err)); + } else if (response.type != NATPMP_RESPTYPE_TCPPORTMAPPING + and response.type != NATPMP_RESPTYPE_UDPPORTMAPPING) { + JAMI_ERR("NAT-PMP: Unexpected response type (%i) for mapping %s from IGD %s.", + response.type, + mapping.toString().c_str(), + igd_->toString().c_str()); + // Try to read again. + continue; + } + + lifetime = response.pnu.newportmapping.lifetime; + // Done. + break; + } + + return err; +} + +int +NatPmp::addPortMapping(Mapping& mapping) +{ + auto const& igdIn = mapping.getIgd(); + assert(igdIn); + assert(igdIn->getProtocol() == NatProtocolType::NAT_PMP); + + if (not igdIn->isValid() or not validIgdInstance(igdIn)) { + mapping.setState(MappingState::FAILED); + return NATPMP_ERR_INVALIDARGS; + } + + mapping.setInternalAddress(getHostAddress().toString()); + + uint32_t lifetime = MAPPING_ALLOCATION_LIFETIME; + int err = sendMappingRequest(mapping, lifetime); + + if (err < 0) { + mapping.setState(MappingState::FAILED); + return err; + } + + // Set the renewal time and update. + mapping.setRenewalTime(sys_clock::now() + std::chrono::seconds(lifetime * 4 / 5)); + mapping.setState(MappingState::OPEN); + + return 0; +} + +void +NatPmp::requestMappingRemove(const Mapping& mapping) +{ + // Process on nat-pmp thread. + if (not isValidThread()) { + runOnNatPmpQueue([w = weak(), mapping] { + if (auto pmpThis = w.lock()) { + Mapping map {mapping}; + pmpThis->removePortMapping(map); + } + }); + return; + } +} + +void +NatPmp::removePortMapping(Mapping& mapping) +{ + auto igdIn = mapping.getIgd(); + assert(igdIn); + if (not igdIn->isValid()) { + return; + } + + if (not validIgdInstance(igdIn)) { + return; + } + + Mapping mapToRemove(mapping); + + uint32_t lifetime = 0; + int err = sendMappingRequest(mapping, lifetime); + + if (err < 0) { + // Nothing to do if the request fails, just log the error. + JAMI_WARN("NAT-PMP: Send remove request failed with error %s. Ignoring", + getNatPmpErrorStr(err)); + } + + // Update and notify the listener. + mapToRemove.setState(MappingState::FAILED); + processMappingRemoved(std::move(mapToRemove)); +} + +void +NatPmp::getIgdPublicAddress() +{ + CHECK_VALID_THREAD(); + + // Set the public address for this IGD if it does not + // have one already. + if (igd_->getPublicIp()) { + JAMI_WARN("NAT-PMP: IGD %s already have a public address (%s)", + igd_->toString().c_str(), + igd_->getPublicIp().toString().c_str()); + return; + } + assert(igd_->getProtocol() == NatProtocolType::NAT_PMP); + + int err = sendpublicaddressrequest(&natpmpHdl_); + + if (err < 0) { + JAMI_ERR("NAT-PMP: send public address request on IGD %s failed with error: %s", + igd_->toString().c_str(), + getNatPmpErrorStr(err)); + + if (isErrorFatal(err)) { + // Fatal error, increment the counter. + incrementErrorsCounter(igd_); + } + return; + } + + natpmpresp_t response; + err = readResponse(natpmpHdl_, response); + + if (err < 0) { + JAMI_WARN("NAT-PMP: Read response on IGD %s failed - %s", + igd_->toString().c_str(), + getNatPmpErrorStr(err)); + return; + } + + if (response.type != NATPMP_RESPTYPE_PUBLICADDRESS) { + JAMI_ERR("NAT-PMP: Unexpected response type (%i) for public address request from IGD %s.", + response.type, + igd_->toString().c_str()); + return; + } + + IpAddr publicAddr(response.pnu.publicaddress.addr); + + if (not publicAddr) { + JAMI_ERR("NAT-PMP: IGD %s returned an invalid public address %s", + igd_->toString().c_str(), + publicAddr.toString().c_str()); + } + + // Update. + igd_->setPublicIp(publicAddr); + igd_->setValid(true); + + JAMI_DBG("NAT-PMP: Setting IGD %s public address to %s", + igd_->toString().c_str(), + igd_->getPublicIp().toString().c_str()); +} + +void +NatPmp::removeAllMappings() +{ + CHECK_VALID_THREAD(); + + JAMI_WARN("NAT-PMP: Send request to close all existing mappings to IGD %s", + igd_->toString().c_str()); + + int err = sendnewportmappingrequest(&natpmpHdl_, NATPMP_PROTOCOL_TCP, 0, 0, 0); + if (err < 0) { + JAMI_WARN("NAT-PMP: Send close all TCP mappings request failed with error %s", + getNatPmpErrorStr(err)); + } + err = sendnewportmappingrequest(&natpmpHdl_, NATPMP_PROTOCOL_UDP, 0, 0, 0); + if (err < 0) { + JAMI_WARN("NAT-PMP: Send close all UDP mappings request failed with error %s", + getNatPmpErrorStr(err)); + } +} + +const char* +NatPmp::getNatPmpErrorStr(int errorCode) const +{ +#ifdef ENABLE_STRNATPMPERR + return strnatpmperr(errorCode); +#else + switch (errorCode) { + case NATPMP_ERR_INVALIDARGS: + return "INVALIDARGS"; + break; + case NATPMP_ERR_SOCKETERROR: + return "SOCKETERROR"; + break; + case NATPMP_ERR_CANNOTGETGATEWAY: + return "CANNOTGETGATEWAY"; + break; + case NATPMP_ERR_CLOSEERR: + return "CLOSEERR"; + break; + case NATPMP_ERR_RECVFROM: + return "RECVFROM"; + break; + case NATPMP_ERR_NOPENDINGREQ: + return "NOPENDINGREQ"; + break; + case NATPMP_ERR_NOGATEWAYSUPPORT: + return "NOGATEWAYSUPPORT"; + break; + case NATPMP_ERR_CONNECTERR: + return "CONNECTERR"; + break; + case NATPMP_ERR_WRONGPACKETSOURCE: + return "WRONGPACKETSOURCE"; + break; + case NATPMP_ERR_SENDERR: + return "SENDERR"; + break; + case NATPMP_ERR_FCNTLERROR: + return "FCNTLERROR"; + break; + case NATPMP_ERR_GETTIMEOFDAYERR: + return "GETTIMEOFDAYERR"; + break; + case NATPMP_ERR_UNSUPPORTEDVERSION: + return "UNSUPPORTEDVERSION"; + break; + case NATPMP_ERR_UNSUPPORTEDOPCODE: + return "UNSUPPORTEDOPCODE"; + break; + case NATPMP_ERR_UNDEFINEDERROR: + return "UNDEFINEDERROR"; + break; + case NATPMP_ERR_NOTAUTHORIZED: + return "NOTAUTHORIZED"; + break; + case NATPMP_ERR_NETWORKFAILURE: + return "NETWORKFAILURE"; + break; + case NATPMP_ERR_OUTOFRESOURCES: + return "OUTOFRESOURCES"; + break; + case NATPMP_TRYAGAIN: + return "TRYAGAIN"; + break; + default: + return "UNKNOWNERR"; + break; + } +#endif +} + +bool +NatPmp::isErrorFatal(int error) +{ + switch (error) { + case NATPMP_ERR_INVALIDARGS: + case NATPMP_ERR_SOCKETERROR: + case NATPMP_ERR_CANNOTGETGATEWAY: + case NATPMP_ERR_CLOSEERR: + case NATPMP_ERR_RECVFROM: + case NATPMP_ERR_NOGATEWAYSUPPORT: + case NATPMP_ERR_CONNECTERR: + case NATPMP_ERR_SENDERR: + case NATPMP_ERR_UNDEFINEDERROR: + case NATPMP_ERR_UNSUPPORTEDVERSION: + case NATPMP_ERR_UNSUPPORTEDOPCODE: + case NATPMP_ERR_NOTAUTHORIZED: + case NATPMP_ERR_NETWORKFAILURE: + case NATPMP_ERR_OUTOFRESOURCES: + return true; + default: + return false; + } +} + +bool +NatPmp::validIgdInstance(const std::shared_ptr<IGD>& igdIn) +{ + if (igd_.get() != igdIn.get()) { + JAMI_ERR("NAT-PMP: IGD (%s) does not match local instance (%s)", + igdIn->toString().c_str(), + igd_->toString().c_str()); + return false; + } + + return true; +} + +void +NatPmp::processIgdUpdate(UpnpIgdEvent event) +{ + if (igd_->isValid()) { + // Remove all current mappings if any. + removeAllMappings(); + } + + if (observer_ == nullptr) + return; + // Process the response on the context thread. + runOnUpnpContextQueue([obs = observer_, igd = igd_, event] { obs->onIgdUpdated(igd, event); }); +} + +void +NatPmp::processMappingAdded(const Mapping& map) +{ + if (observer_ == nullptr) + return; + + // Process the response on the context thread. + runOnUpnpContextQueue([obs = observer_, igd = igd_, map] { obs->onMappingAdded(igd, map); }); +} + +void +NatPmp::processMappingRequestFailed(const Mapping& map) +{ + if (observer_ == nullptr) + return; + + // Process the response on the context thread. + runOnUpnpContextQueue([obs = observer_, igd = igd_, map] { obs->onMappingRequestFailed(map); }); +} + +void +NatPmp::processMappingRenewed(const Mapping& map) +{ + if (observer_ == nullptr) + return; + + // Process the response on the context thread. + runOnUpnpContextQueue([obs = observer_, igd = igd_, map] { obs->onMappingRenewed(igd, map); }); +} + +void +NatPmp::processMappingRemoved(const Mapping& map) +{ + if (observer_ == nullptr) + return; + + // Process the response on the context thread. + runOnUpnpContextQueue([obs = observer_, igd = igd_, map] { obs->onMappingRemoved(igd, map); }); +} + +} // namespace upnp +} // namespace jami + +#endif //-- #if HAVE_LIBNATPMP diff --git a/src/upnp/protocol/natpmp/nat_pmp.h b/src/upnp/protocol/natpmp/nat_pmp.h new file mode 100644 index 0000000..68fd28b --- /dev/null +++ b/src/upnp/protocol/natpmp/nat_pmp.h @@ -0,0 +1,174 @@ +/* + * Copyright (C) 2004-2023 Savoir-faire Linux Inc. + * + * Author: Eden Abitbol <eden.abitbol@savoirfairelinux.com> + * Author: Mohamed Chibani <mohamed.chibani@savoirfairelinux.com> + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ + +#pragma once + +#include "connectivity/upnp/protocol/upnp_protocol.h" +#include "connectivity/upnp/protocol/igd.h" +#include "pmp_igd.h" + +#include "logger.h" +#include "connectivity/ip_utils.h" +#include "noncopyable.h" +#include "compiler_intrinsics.h" + +// uncomment to enable native natpmp error messages +//#define ENABLE_STRNATPMPERR 1 +#include <natpmp.h> + +#include <atomic> +#include <thread> + +namespace jami { +class IpAddr; +} + +namespace jami { +namespace upnp { + +// Requested lifetime in seconds. The actual lifetime might be different. +constexpr static unsigned int MAPPING_ALLOCATION_LIFETIME {60 * 60}; +// Max number of IGD search attempts before failure. +constexpr static unsigned int MAX_RESTART_SEARCH_RETRIES {3}; +// Time-out between two successive read response. +constexpr static auto TIMEOUT_BEFORE_READ_RETRY {std::chrono::milliseconds(300)}; +// Max number of read attempts before failure. +constexpr static unsigned int MAX_READ_RETRIES {3}; +// Base unit for the timeout between two successive IGD search. +constexpr static auto NATPMP_SEARCH_RETRY_UNIT {std::chrono::seconds(10)}; + +class NatPmp : public UPnPProtocol +{ +public: + NatPmp(); + ~NatPmp(); + + // Set the observer. + void setObserver(UpnpMappingObserver* obs) override; + + // Returns the protocol type. + NatProtocolType getProtocol() const override { return NatProtocolType::NAT_PMP; } + + // Get protocol type as string. + char const* getProtocolName() const override { return "NAT-PMP"; } + + // Notifies a change in network. + void clearIgds() override; + + // Renew pmp_igd. + void searchForIgd() override; + + // Get the IGD list. + std::list<std::shared_ptr<IGD>> getIgdList() const override; + + // Return true if it has at least one valid IGD. + bool isReady() const override; + + // Request a new mapping. + void requestMappingAdd(const Mapping& mapping) override; + + // Renew an allocated mapping. + void requestMappingRenew(const Mapping& mapping) override; + + // Removes a mapping. + void requestMappingRemove(const Mapping& mapping) override; + + // Get the host (local) address. + const IpAddr getHostAddress() const override; + + // Terminate. Nothing to do here, the clean-up is done when + // the IGD is cleared. + void terminate() override; + +private: + NON_COPYABLE(NatPmp); + + std::weak_ptr<NatPmp> weak() { return std::static_pointer_cast<NatPmp>(shared_from_this()); } + + // Helpers to run tasks on NAT-PMP internal execution queue. + ScheduledExecutor* getNatpmpScheduler() { return &natpmpScheduler_; } + template<typename Callback> + void runOnNatPmpQueue(Callback&& cb) + { + natpmpScheduler_.run([cb = std::forward<Callback>(cb)]() mutable { cb(); }); + } + + // Helpers to run tasks on UPNP context execution queue. + ScheduledExecutor* getUpnContextScheduler() { return UpnpThreadUtil::getScheduler(); } + + void terminate(std::condition_variable& cv); + + void initNatPmp(); + void getIgdPublicAddress(); + void removeAllMappings(); + int readResponse(natpmp_t& handle, natpmpresp_t& response); + int sendMappingRequest(const Mapping& mapping, uint32_t& lifetime); + + // Adds a port mapping. + int addPortMapping(Mapping& mapping); + // Removes a port mapping. + void removePortMapping(Mapping& mapping); + + // True if the error is fatal. + bool isErrorFatal(int error); + // Gets NAT-PMP error code string. + const char* getNatPmpErrorStr(int errorCode) const; + // Get local getaway. + std::unique_ptr<IpAddr> getLocalGateway() const; + + // Helpers to process user's callbacks + void processIgdUpdate(UpnpIgdEvent event); + void processMappingAdded(const Mapping& map); + void processMappingRequestFailed(const Mapping& map); + void processMappingRenewed(const Mapping& map); + void processMappingRemoved(const Mapping& map); + + // Check if the IGD has a local match + bool validIgdInstance(const std::shared_ptr<IGD>& igdIn); + + // Increment errors counter. + void incrementErrorsCounter(const std::shared_ptr<IGD>& igd); + + std::atomic_bool initialized_ {false}; + + // Data members + std::shared_ptr<PMPIGD> igd_; + natpmp_t natpmpHdl_; + ScheduledExecutor natpmpScheduler_ {"natpmp"}; + std::shared_ptr<Task> searchForIgdTimer_ {}; + unsigned int igdSearchCounter_ {0}; + UpnpMappingObserver* observer_ {nullptr}; + IpAddr hostAddress_ {}; + + // Calls from other threads that does not need synchronous access are + // rescheduled on the NatPmp private queue. This will avoid the need to + // protect most of the data members of this class. + // For some internal members (such as the igd instance and the host + // address) that need to be synchronously accessed, are protected by + // this mutex. + mutable std::mutex natpmpMutex_; + + // Shutdown synchronization + bool shutdownComplete_ {false}; +}; + +} // namespace upnp +} // namespace jami diff --git a/src/upnp/protocol/natpmp/pmp_igd.cpp b/src/upnp/protocol/natpmp/pmp_igd.cpp new file mode 100644 index 0000000..ac8b698 --- /dev/null +++ b/src/upnp/protocol/natpmp/pmp_igd.cpp @@ -0,0 +1,63 @@ +/* + * Copyright (C) 2004-2023 Savoir-faire Linux Inc. + * + * Author: Stepan Salenikovich <stepan.salenikovich@savoirfairelinux.com> + * Author: Eden Abitbol <eden.abitbol@savoirfairelinux.com> + * Author: Mohamed Chibani <mohamed.chibani@savoirfairelinux.com> + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ + +#include "pmp_igd.h" + +#include <algorithm> + +namespace jami { +namespace upnp { + +PMPIGD::PMPIGD() + : IGD(NatProtocolType::NAT_PMP) +{} + +PMPIGD::PMPIGD(const PMPIGD& other) + : PMPIGD() +{ + assert(protocol_ == NatProtocolType::NAT_PMP); + // protocol_ = other.protocol_; + localIp_ = other.localIp_; + publicIp_ = other.publicIp_; + uid_ = other.uid_; +} + +bool +PMPIGD::operator==(IGD& other) const +{ + return getPublicIp() == other.getPublicIp() and getLocalIp() == other.getLocalIp(); +} + +bool +PMPIGD::operator==(PMPIGD& other) const +{ + return getPublicIp() == other.getPublicIp() and getLocalIp() == other.getLocalIp(); +} + +const std::string +PMPIGD::toString() const +{ + return getLocalIp().toString(); +} + +} // namespace upnp +} // namespace jami diff --git a/src/upnp/protocol/natpmp/pmp_igd.h b/src/upnp/protocol/natpmp/pmp_igd.h new file mode 100644 index 0000000..a70e7ee --- /dev/null +++ b/src/upnp/protocol/natpmp/pmp_igd.h @@ -0,0 +1,54 @@ +/* + * Copyright (C) 2004-2023 Savoir-faire Linux Inc. + * + * Author: Stepan Salenikovich <stepan.salenikovich@savoirfairelinux.com> + * Author: Eden Abitbol <eden.abitbol@savoirfairelinux.com> + * Author: Mohamed Chibani <mohamed.chibani@savoirfairelinux.com> + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ +#pragma once + +#include "../igd.h" +#include "noncopyable.h" +#include "connectivity/ip_utils.h" + +#include <map> +#include <atomic> +#include <string> +#include <chrono> +#include <functional> + +namespace jami { +namespace upnp { + +class PMPIGD : public IGD +{ +public: + PMPIGD(); + PMPIGD(const PMPIGD&); + ~PMPIGD() = default; + + PMPIGD& operator=(PMPIGD&& other) = delete; + PMPIGD& operator=(PMPIGD& other) = delete; + + bool operator==(IGD& other) const; + bool operator==(PMPIGD& other) const; + + const std::string toString() const override; +}; + +} // namespace upnp +} // namespace jami diff --git a/src/upnp/protocol/pupnp/pupnp.cpp b/src/upnp/protocol/pupnp/pupnp.cpp new file mode 100644 index 0000000..cc63347 --- /dev/null +++ b/src/upnp/protocol/pupnp/pupnp.cpp @@ -0,0 +1,1599 @@ +/* + * Copyright (C) 2004-2023 Savoir-faire Linux Inc. + * + * Author: Stepan Salenikovich <stepan.salenikovich@savoirfairelinux.com> + * Author: Eden Abitbol <eden.abitbol@savoirfairelinux.com> + * Author: Adrien Béraud <adrien.beraud@savoirfairelinux.com> + * Author: Mohamed Chibani <mohamed.chibani@savoirfairelinux.com> + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ + +#include "pupnp.h" + +#include <opendht/thread_pool.h> +#include <opendht/http.h> + +namespace jami { +namespace upnp { + +// Action identifiers. +constexpr static const char* ACTION_ADD_PORT_MAPPING {"AddPortMapping"}; +constexpr static const char* ACTION_DELETE_PORT_MAPPING {"DeletePortMapping"}; +constexpr static const char* ACTION_GET_GENERIC_PORT_MAPPING_ENTRY {"GetGenericPortMappingEntry"}; +constexpr static const char* ACTION_GET_STATUS_INFO {"GetStatusInfo"}; +constexpr static const char* ACTION_GET_EXTERNAL_IP_ADDRESS {"GetExternalIPAddress"}; + +// Error codes returned by router when trying to remove ports. +constexpr static int ARRAY_IDX_INVALID = 713; +constexpr static int CONFLICT_IN_MAPPING = 718; + +// Max number of IGD search attempts before failure. +constexpr static unsigned int PUPNP_MAX_RESTART_SEARCH_RETRIES {3}; +// IGD search timeout (in seconds). +constexpr static unsigned int SEARCH_TIMEOUT {60}; +// Base unit for the timeout between two successive IGD search. +constexpr static auto PUPNP_SEARCH_RETRY_UNIT {std::chrono::seconds(10)}; + +// Helper functions for xml parsing. +static std::string_view +getElementText(IXML_Node* node) +{ + if (node) { + if (IXML_Node* textNode = ixmlNode_getFirstChild(node)) + if (const char* value = ixmlNode_getNodeValue(textNode)) + return std::string_view(value); + } + return {}; +} + +static std::string_view +getFirstDocItem(IXML_Document* doc, const char* item) +{ + std::unique_ptr<IXML_NodeList, decltype(ixmlNodeList_free)&> + nodeList(ixmlDocument_getElementsByTagName(doc, item), ixmlNodeList_free); + if (nodeList) { + // If there are several nodes which match the tag, we only want the first one. + return getElementText(ixmlNodeList_item(nodeList.get(), 0)); + } + return {}; +} + +static std::string_view +getFirstElementItem(IXML_Element* element, const char* item) +{ + std::unique_ptr<IXML_NodeList, decltype(ixmlNodeList_free)&> + nodeList(ixmlElement_getElementsByTagName(element, item), ixmlNodeList_free); + if (nodeList) { + // If there are several nodes which match the tag, we only want the first one. + return getElementText(ixmlNodeList_item(nodeList.get(), 0)); + } + return {}; +} + +static bool +errorOnResponse(IXML_Document* doc) +{ + if (not doc) + return true; + + auto errorCode = getFirstDocItem(doc, "errorCode"); + if (not errorCode.empty()) { + auto errorDescription = getFirstDocItem(doc, "errorDescription"); + JAMI_WARNING("PUPnP: Response contains error: {:s}: {:s}", + errorCode, + errorDescription); + return true; + } + return false; +} + +// UPNP class implementation + +PUPnP::PUPnP() +{ + JAMI_DBG("PUPnP: Creating instance [%p] ...", this); + runOnPUPnPQueue([this] { + threadId_ = getCurrentThread(); + JAMI_DBG("PUPnP: Instance [%p] created", this); + }); +} + +PUPnP::~PUPnP() +{ + JAMI_DBG("PUPnP: Instance [%p] destroyed", this); +} + +void +PUPnP::initUpnpLib() +{ + assert(not initialized_); + + int upnp_err = UpnpInit2(nullptr, 0); + + if (upnp_err != UPNP_E_SUCCESS) { + JAMI_ERR("PUPnP: Can't initialize libupnp: %s", UpnpGetErrorMessage(upnp_err)); + UpnpFinish(); + initialized_ = false; + return; + } + + // Disable embedded WebServer if any. + if (UpnpIsWebserverEnabled() == 1) { + JAMI_WARN("PUPnP: Web-server is enabled. Disabling"); + UpnpEnableWebserver(0); + if (UpnpIsWebserverEnabled() == 1) { + JAMI_ERR("PUPnP: Could not disable Web-server!"); + } else { + JAMI_DBG("PUPnP: Web-server successfully disabled"); + } + } + + char* ip_address = UpnpGetServerIpAddress(); + char* ip_address6 = nullptr; + unsigned short port = UpnpGetServerPort(); + unsigned short port6 = 0; +#if UPNP_ENABLE_IPV6 + ip_address6 = UpnpGetServerIp6Address(); + port6 = UpnpGetServerPort6(); +#endif + if (ip_address6 and port6) + JAMI_DBG("PUPnP: Initialized on %s:%u | %s:%u", ip_address, port, ip_address6, port6); + else + JAMI_DBG("PUPnP: Initialized on %s:%u", ip_address, port); + + // Relax the parser to allow malformed XML text. + ixmlRelaxParser(1); + + initialized_ = true; +} + +bool +PUPnP::isRunning() const +{ + std::unique_lock<std::mutex> lk(pupnpMutex_); + return not shutdownComplete_; +} + +void +PUPnP::registerClient() +{ + assert(not clientRegistered_); + + CHECK_VALID_THREAD(); + + // Register Upnp control point. + int upnp_err = UpnpRegisterClient(ctrlPtCallback, this, &ctrlptHandle_); + if (upnp_err != UPNP_E_SUCCESS) { + JAMI_ERR("PUPnP: Can't register client: %s", UpnpGetErrorMessage(upnp_err)); + } else { + JAMI_DBG("PUPnP: Successfully registered client"); + clientRegistered_ = true; + } +} + +void +PUPnP::setObserver(UpnpMappingObserver* obs) +{ + if (not isValidThread()) { + runOnPUPnPQueue([w = weak(), obs] { + if (auto upnpThis = w.lock()) { + upnpThis->setObserver(obs); + } + }); + return; + } + + JAMI_DBG("PUPnP: Setting observer to %p", obs); + + observer_ = obs; +} + +const IpAddr +PUPnP::getHostAddress() const +{ + std::lock_guard<std::mutex> lock(pupnpMutex_); + return hostAddress_; +} + +void +PUPnP::terminate(std::condition_variable& cv) +{ + JAMI_DBG("PUPnP: Terminate instance %p", this); + + clientRegistered_ = false; + observer_ = nullptr; + + UpnpUnRegisterClient(ctrlptHandle_); + + if (initialized_) { + if (UpnpFinish() != UPNP_E_SUCCESS) { + JAMI_ERR("PUPnP: Failed to properly close lib-upnp"); + } + + initialized_ = false; + } + + // Clear all the lists. + discoveredIgdList_.clear(); + + { + std::lock_guard<std::mutex> lock(pupnpMutex_); + validIgdList_.clear(); + shutdownComplete_ = true; + cv.notify_one(); + } +} + +void +PUPnP::terminate() +{ + std::unique_lock<std::mutex> lk(pupnpMutex_); + std::condition_variable cv {}; + + runOnPUPnPQueue([w = weak(), &cv = cv] { + if (auto upnpThis = w.lock()) { + upnpThis->terminate(cv); + } + }); + + if (cv.wait_for(lk, std::chrono::seconds(10), [this] { return shutdownComplete_; })) { + JAMI_DBG("PUPnP: Shutdown completed"); + } else { + JAMI_ERR("PUPnP: Shutdown timed-out"); + // Force stop if the shutdown take too much time. + shutdownComplete_ = true; + } +} + +void +PUPnP::searchForDevices() +{ + CHECK_VALID_THREAD(); + + JAMI_DBG("PUPnP: Send IGD search request"); + + // Send out search for multiple types of devices, as some routers may possibly + // only reply to one. + + auto err = UpnpSearchAsync(ctrlptHandle_, SEARCH_TIMEOUT, UPNP_ROOT_DEVICE, this); + if (err != UPNP_E_SUCCESS) { + JAMI_WARN("PUPnP: Send search for UPNP_ROOT_DEVICE failed. Error %d: %s", + err, + UpnpGetErrorMessage(err)); + } + + err = UpnpSearchAsync(ctrlptHandle_, SEARCH_TIMEOUT, UPNP_IGD_DEVICE, this); + if (err != UPNP_E_SUCCESS) { + JAMI_WARN("PUPnP: Send search for UPNP_IGD_DEVICE failed. Error %d: %s", + err, + UpnpGetErrorMessage(err)); + } + + err = UpnpSearchAsync(ctrlptHandle_, SEARCH_TIMEOUT, UPNP_WANIP_SERVICE, this); + if (err != UPNP_E_SUCCESS) { + JAMI_WARN("PUPnP: Send search for UPNP_WANIP_SERVICE failed. Error %d: %s", + err, + UpnpGetErrorMessage(err)); + } + + err = UpnpSearchAsync(ctrlptHandle_, SEARCH_TIMEOUT, UPNP_WANPPP_SERVICE, this); + if (err != UPNP_E_SUCCESS) { + JAMI_WARN("PUPnP: Send search for UPNP_WANPPP_SERVICE failed. Error %d: %s", + err, + UpnpGetErrorMessage(err)); + } +} + +void +PUPnP::clearIgds() +{ + if (not isValidThread()) { + runOnPUPnPQueue([w = weak()] { + if (auto upnpThis = w.lock()) { + upnpThis->clearIgds(); + } + }); + return; + } + + JAMI_DBG("PUPnP: clearing IGDs and devices lists"); + + if (searchForIgdTimer_) + searchForIgdTimer_->cancel(); + + igdSearchCounter_ = 0; + + { + std::lock_guard<std::mutex> lock(pupnpMutex_); + for (auto const& igd : validIgdList_) { + igd->setValid(false); + } + validIgdList_.clear(); + hostAddress_ = {}; + } + + discoveredIgdList_.clear(); +} + +void +PUPnP::searchForIgd() +{ + if (not isValidThread()) { + runOnPUPnPQueue([w = weak()] { + if (auto upnpThis = w.lock()) { + upnpThis->searchForIgd(); + } + }); + return; + } + + // Update local address before searching. + updateHostAddress(); + + if (isReady()) { + JAMI_DBG("PUPnP: Already have a valid IGD. Skip the search request"); + return; + } + + if (igdSearchCounter_++ >= PUPNP_MAX_RESTART_SEARCH_RETRIES) { + JAMI_WARN("PUPnP: Setup failed after %u trials. PUPnP will be disabled!", + PUPNP_MAX_RESTART_SEARCH_RETRIES); + return; + } + + JAMI_DBG("PUPnP: Start search for IGD: attempt %u", igdSearchCounter_); + + // Do not init if the host is not valid. Otherwise, the init will fail + // anyway and may put libupnp in an unstable state (mainly deadlocks) + // even if the UpnpFinish() method is called. + if (not hasValidHostAddress()) { + JAMI_WARN("PUPnP: Host address is invalid. Skipping the IGD search"); + } else { + // Init and register if needed + if (not initialized_) { + initUpnpLib(); + } + if (initialized_ and not clientRegistered_) { + registerClient(); + } + // Start searching + if (clientRegistered_) { + assert(initialized_); + searchForDevices(); + } else { + JAMI_WARN("PUPnP: PUPNP not fully setup. Skipping the IGD search"); + } + } + + // Cancel the current timer (if any) and re-schedule. + // The connectivity change may be received while the the local + // interface is not fully setup. The rescheduling typically + // usefull to mitigate this race. + if (searchForIgdTimer_) + searchForIgdTimer_->cancel(); + + searchForIgdTimer_ = getUpnContextScheduler()->scheduleIn( + [w = weak()] { + if (auto upnpThis = w.lock()) + upnpThis->searchForIgd(); + }, + PUPNP_SEARCH_RETRY_UNIT * igdSearchCounter_); +} + +std::list<std::shared_ptr<IGD>> +PUPnP::getIgdList() const +{ + std::lock_guard<std::mutex> lock(pupnpMutex_); + std::list<std::shared_ptr<IGD>> igdList; + for (auto& it : validIgdList_) { + // Return only active IGDs. + if (it->isValid()) { + igdList.emplace_back(it); + } + } + return igdList; +} + +bool +PUPnP::isReady() const +{ + // Must at least have a valid local address. + if (not getHostAddress() or getHostAddress().isLoopback()) + return false; + + return hasValidIgd(); +} + +bool +PUPnP::hasValidIgd() const +{ + std::lock_guard<std::mutex> lock(pupnpMutex_); + for (auto& it : validIgdList_) { + if (it->isValid()) { + return true; + } + } + return false; +} + +void +PUPnP::updateHostAddress() +{ + std::lock_guard<std::mutex> lock(pupnpMutex_); + hostAddress_ = ip_utils::getLocalAddr(AF_INET); +} + +bool +PUPnP::hasValidHostAddress() +{ + std::lock_guard<std::mutex> lock(pupnpMutex_); + return hostAddress_ and not hostAddress_.isLoopback(); +} + +void +PUPnP::incrementErrorsCounter(const std::shared_ptr<IGD>& igd) +{ + if (not igd or not igd->isValid()) + return; + if (not igd->incrementErrorsCounter()) { + // Disable this IGD. + igd->setValid(false); + // Notify the listener. + if (observer_) + observer_->onIgdUpdated(igd, UpnpIgdEvent::INVALID_STATE); + } +} + +bool +PUPnP::validateIgd(const std::string& location, IXML_Document* doc_container_ptr) +{ + CHECK_VALID_THREAD(); + + assert(doc_container_ptr != nullptr); + + XMLDocument document(doc_container_ptr, ixmlDocument_free); + auto descDoc = document.get(); + // Check device type. + auto deviceType = getFirstDocItem(descDoc, "deviceType"); + if (deviceType != UPNP_IGD_DEVICE) { + // Device type not IGD. + return false; + } + + std::shared_ptr<UPnPIGD> igd_candidate = parseIgd(descDoc, location); + if (not igd_candidate) { + // No valid IGD candidate. + return false; + } + + JAMI_DBG("PUPnP: Validating the IGD candidate [UDN: %s]\n" + " Name : %s\n" + " Service Type : %s\n" + " Service ID : %s\n" + " Base URL : %s\n" + " Location URL : %s\n" + " control URL : %s\n" + " Event URL : %s", + igd_candidate->getUID().c_str(), + igd_candidate->getFriendlyName().c_str(), + igd_candidate->getServiceType().c_str(), + igd_candidate->getServiceId().c_str(), + igd_candidate->getBaseURL().c_str(), + igd_candidate->getLocationURL().c_str(), + igd_candidate->getControlURL().c_str(), + igd_candidate->getEventSubURL().c_str()); + + // Check if IGD is connected. + if (not actionIsIgdConnected(*igd_candidate)) { + JAMI_WARN("PUPnP: IGD candidate %s is not connected", igd_candidate->getUID().c_str()); + return false; + } + + // Validate external Ip. + igd_candidate->setPublicIp(actionGetExternalIP(*igd_candidate)); + if (igd_candidate->getPublicIp().toString().empty()) { + JAMI_WARN("PUPnP: IGD candidate %s has no valid external Ip", + igd_candidate->getUID().c_str()); + return false; + } + + // Validate internal Ip. + if (igd_candidate->getBaseURL().empty()) { + JAMI_WARN("PUPnP: IGD candidate %s has no valid internal Ip", + igd_candidate->getUID().c_str()); + return false; + } + + // Typically the IGD local address should be extracted from the XML + // document (e.g. parsing the base URL). For simplicity, we assume + // that it matches the gateway as seen by the local interface. + if (const auto& localGw = ip_utils::getLocalGateway()) { + igd_candidate->setLocalIp(localGw); + } else { + JAMI_WARN("PUPnP: Could not set internal address for IGD candidate %s", + igd_candidate->getUID().c_str()); + return false; + } + + // Store info for subscription. + std::string eventSub = igd_candidate->getEventSubURL(); + + { + // Add the IGD if not already present in the list. + std::lock_guard<std::mutex> lock(pupnpMutex_); + for (auto& igd : validIgdList_) { + // Must not be a null pointer + assert(igd.get() != nullptr); + if (*igd == *igd_candidate) { + JAMI_DBG("PUPnP: Device [%s] with int/ext addresses [%s:%s] is already in the list " + "of valid IGDs", + igd_candidate->getUID().c_str(), + igd_candidate->toString().c_str(), + igd_candidate->getPublicIp().toString().c_str()); + return true; + } + } + } + + // We have a valid IGD + igd_candidate->setValid(true); + + JAMI_DBG("PUPnP: Added a new IGD [%s] to the list of valid IGDs", + igd_candidate->getUID().c_str()); + + JAMI_DBG("PUPnP: New IGD addresses [int: %s - ext: %s]", + igd_candidate->toString().c_str(), + igd_candidate->getPublicIp().toString().c_str()); + + // Subscribe to IGD events. + int upnp_err = UpnpSubscribeAsync(ctrlptHandle_, + eventSub.c_str(), + UPNP_INFINITE, + subEventCallback, + this); + if (upnp_err != UPNP_E_SUCCESS) { + JAMI_WARN("PUPnP: Failed to send subscribe request to %s: error %i - %s", + igd_candidate->getUID().c_str(), + upnp_err, + UpnpGetErrorMessage(upnp_err)); + // return false; + } else { + JAMI_DBG("PUPnP: Successfully subscribed to IGD %s", igd_candidate->getUID().c_str()); + } + + { + // This is a new (and hopefully valid) IGD. + std::lock_guard<std::mutex> lock(pupnpMutex_); + validIgdList_.emplace_back(igd_candidate); + } + + // Report to the listener. + runOnUpnpContextQueue([w = weak(), igd_candidate] { + if (auto upnpThis = w.lock()) { + if (upnpThis->observer_) + upnpThis->observer_->onIgdUpdated(igd_candidate, UpnpIgdEvent::ADDED); + } + }); + + return true; +} + +void +PUPnP::requestMappingAdd(const Mapping& mapping) +{ + runOnPUPnPQueue([w = weak(), mapping] { + if (auto upnpThis = w.lock()) { + if (not upnpThis->isRunning()) + return; + Mapping mapRes(mapping); + if (upnpThis->actionAddPortMapping(mapRes)) { + mapRes.setState(MappingState::OPEN); + mapRes.setInternalAddress(upnpThis->getHostAddress().toString()); + upnpThis->processAddMapAction(mapRes); + } else { + upnpThis->incrementErrorsCounter(mapRes.getIgd()); + mapRes.setState(MappingState::FAILED); + upnpThis->processRequestMappingFailure(mapRes); + } + } + }); +} + +void +PUPnP::requestMappingRemove(const Mapping& mapping) +{ + // Send remove request using the matching IGD + runOnPUPnPQueue([w = weak(), mapping] { + if (auto upnpThis = w.lock()) { + // Abort if we are shutting down. + if (not upnpThis->isRunning()) + return; + if (upnpThis->actionDeletePortMapping(mapping)) { + upnpThis->processRemoveMapAction(mapping); + } else { + assert(mapping.getIgd()); + // Dont need to report in case of failure. + upnpThis->incrementErrorsCounter(mapping.getIgd()); + } + } + }); +} + +std::shared_ptr<UPnPIGD> +PUPnP::findMatchingIgd(const std::string& ctrlURL) const +{ + std::lock_guard<std::mutex> lock(pupnpMutex_); + + auto iter = std::find_if(validIgdList_.begin(), + validIgdList_.end(), + [&ctrlURL](const std::shared_ptr<IGD>& igd) { + if (auto upnpIgd = std::dynamic_pointer_cast<UPnPIGD>(igd)) { + return upnpIgd->getControlURL() == ctrlURL; + } + return false; + }); + + if (iter == validIgdList_.end()) { + JAMI_WARN("PUPnP: Did not find the IGD matching ctrl URL [%s]", ctrlURL.c_str()); + return {}; + } + + return std::dynamic_pointer_cast<UPnPIGD>(*iter); +} + +void +PUPnP::processAddMapAction(const Mapping& map) +{ + CHECK_VALID_THREAD(); + + if (observer_ == nullptr) + return; + + runOnUpnpContextQueue([w = weak(), map] { + if (auto upnpThis = w.lock()) { + if (upnpThis->observer_) + upnpThis->observer_->onMappingAdded(map.getIgd(), std::move(map)); + } + }); +} + +void +PUPnP::processRequestMappingFailure(const Mapping& map) +{ + CHECK_VALID_THREAD(); + + if (observer_ == nullptr) + return; + + runOnUpnpContextQueue([w = weak(), map] { + if (auto upnpThis = w.lock()) { + JAMI_DBG("PUPnP: Failed to request mapping %s", map.toString().c_str()); + if (upnpThis->observer_) + upnpThis->observer_->onMappingRequestFailed(map); + } + }); +} + +void +PUPnP::processRemoveMapAction(const Mapping& map) +{ + CHECK_VALID_THREAD(); + + if (observer_ == nullptr) + return; + + runOnUpnpContextQueue([map, obs = observer_] { + JAMI_DBG("PUPnP: Closed mapping %s", map.toString().c_str()); + obs->onMappingRemoved(map.getIgd(), std::move(map)); + }); +} + +const char* +PUPnP::eventTypeToString(Upnp_EventType eventType) +{ + switch (eventType) { + case UPNP_CONTROL_ACTION_REQUEST: + return "UPNP_CONTROL_ACTION_REQUEST"; + case UPNP_CONTROL_ACTION_COMPLETE: + return "UPNP_CONTROL_ACTION_COMPLETE"; + case UPNP_CONTROL_GET_VAR_REQUEST: + return "UPNP_CONTROL_GET_VAR_REQUEST"; + case UPNP_CONTROL_GET_VAR_COMPLETE: + return "UPNP_CONTROL_GET_VAR_COMPLETE"; + case UPNP_DISCOVERY_ADVERTISEMENT_ALIVE: + return "UPNP_DISCOVERY_ADVERTISEMENT_ALIVE"; + case UPNP_DISCOVERY_ADVERTISEMENT_BYEBYE: + return "UPNP_DISCOVERY_ADVERTISEMENT_BYEBYE"; + case UPNP_DISCOVERY_SEARCH_RESULT: + return "UPNP_DISCOVERY_SEARCH_RESULT"; + case UPNP_DISCOVERY_SEARCH_TIMEOUT: + return "UPNP_DISCOVERY_SEARCH_TIMEOUT"; + case UPNP_EVENT_SUBSCRIPTION_REQUEST: + return "UPNP_EVENT_SUBSCRIPTION_REQUEST"; + case UPNP_EVENT_RECEIVED: + return "UPNP_EVENT_RECEIVED"; + case UPNP_EVENT_RENEWAL_COMPLETE: + return "UPNP_EVENT_RENEWAL_COMPLETE"; + case UPNP_EVENT_SUBSCRIBE_COMPLETE: + return "UPNP_EVENT_SUBSCRIBE_COMPLETE"; + case UPNP_EVENT_UNSUBSCRIBE_COMPLETE: + return "UPNP_EVENT_UNSUBSCRIBE_COMPLETE"; + case UPNP_EVENT_AUTORENEWAL_FAILED: + return "UPNP_EVENT_AUTORENEWAL_FAILED"; + case UPNP_EVENT_SUBSCRIPTION_EXPIRED: + return "UPNP_EVENT_SUBSCRIPTION_EXPIRED"; + default: + return "Unknown UPNP Event"; + } +} + +int +PUPnP::ctrlPtCallback(Upnp_EventType event_type, const void* event, void* user_data) +{ + auto pupnp = static_cast<PUPnP*>(user_data); + + if (pupnp == nullptr) { + JAMI_WARN("PUPnP: Control point callback without PUPnP"); + return UPNP_E_SUCCESS; + } + + auto upnpThis = pupnp->weak().lock(); + + if (not upnpThis) + return UPNP_E_SUCCESS; + + // Ignore if already unregistered. + if (not upnpThis->clientRegistered_) + return UPNP_E_SUCCESS; + + // Process the callback. + return upnpThis->handleCtrlPtUPnPEvents(event_type, event); +} + +PUPnP::CtrlAction +PUPnP::getAction(const char* xmlNode) +{ + if (strstr(xmlNode, ACTION_ADD_PORT_MAPPING)) { + return CtrlAction::ADD_PORT_MAPPING; + } else if (strstr(xmlNode, ACTION_DELETE_PORT_MAPPING)) { + return CtrlAction::DELETE_PORT_MAPPING; + } else if (strstr(xmlNode, ACTION_GET_GENERIC_PORT_MAPPING_ENTRY)) { + return CtrlAction::GET_GENERIC_PORT_MAPPING_ENTRY; + } else if (strstr(xmlNode, ACTION_GET_STATUS_INFO)) { + return CtrlAction::GET_STATUS_INFO; + } else if (strstr(xmlNode, ACTION_GET_EXTERNAL_IP_ADDRESS)) { + return CtrlAction::GET_EXTERNAL_IP_ADDRESS; + } else { + return CtrlAction::UNKNOWN; + } +} + +void +PUPnP::processDiscoverySearchResult(const std::string& cpDeviceId, + const std::string& igdLocationUrl, + const IpAddr& dstAddr) +{ + CHECK_VALID_THREAD(); + + // Update host address if needed. + if (not hasValidHostAddress()) + updateHostAddress(); + + // The host address must be valid to proceed. + if (not hasValidHostAddress()) { + JAMI_WARN("PUPnP: Local address is invalid. Ignore search result for now!"); + return; + } + + // Use the device ID and the URL as ID. This is necessary as some + // IGDs may have the same device ID but different URLs. + + auto igdId = cpDeviceId + " url: " + igdLocationUrl; + + if (not discoveredIgdList_.emplace(igdId).second) { + // JAMI_WARN("PUPnP: IGD [%s] already in the list", igdId.c_str()); + return; + } + + JAMI_DBG("PUPnP: Discovered a new IGD [%s]", igdId.c_str()); + + // NOTE: here, we check if the location given is related to the source address. + // If it's not the case, it's certainly a router plugged in the network, but not + // related to this network. So the given location will be unreachable and this + // will cause some timeout. + + // Only check the IP address (ignore the port number). + dht::http::Url url(igdLocationUrl); + if (IpAddr(url.host).toString(false) != dstAddr.toString(false)) { + JAMI_DBG("PUPnP: Returned location %s does not match the source address %s", + IpAddr(url.host).toString(true, true).c_str(), + dstAddr.toString(true, true).c_str()); + return; + } + + // Run a separate thread to prevent blocking this thread + // if the IGD HTTP server is not responsive. + dht::ThreadPool::io().run([w = weak(), igdLocationUrl] { + if (auto upnpThis = w.lock()) { + upnpThis->downLoadIgdDescription(igdLocationUrl); + } + }); +} + +void +PUPnP::downLoadIgdDescription(const std::string& locationUrl) +{ + IXML_Document* doc_container_ptr = nullptr; + int upnp_err = UpnpDownloadXmlDoc(locationUrl.c_str(), &doc_container_ptr); + + if (upnp_err != UPNP_E_SUCCESS or not doc_container_ptr) { + JAMI_WARN("PUPnP: Error downloading device XML document from %s -> %s", + locationUrl.c_str(), + UpnpGetErrorMessage(upnp_err)); + } else { + JAMI_DBG("PUPnP: Succeeded to download device XML document from %s", locationUrl.c_str()); + runOnPUPnPQueue([w = weak(), url = locationUrl, doc_container_ptr] { + if (auto upnpThis = w.lock()) { + upnpThis->validateIgd(url, doc_container_ptr); + } + }); + } +} + +void +PUPnP::processDiscoveryAdvertisementByebye(const std::string& cpDeviceId) +{ + CHECK_VALID_THREAD(); + + discoveredIgdList_.erase(cpDeviceId); + + std::shared_ptr<IGD> igd; + { + std::lock_guard<std::mutex> lk(pupnpMutex_); + for (auto it = validIgdList_.begin(); it != validIgdList_.end();) { + if ((*it)->getUID() == cpDeviceId) { + igd = *it; + JAMI_DBG("PUPnP: Received [%s] for IGD [%s] %s. Will be removed.", + PUPnP::eventTypeToString(UPNP_DISCOVERY_ADVERTISEMENT_BYEBYE), + igd->getUID().c_str(), + igd->toString().c_str()); + igd->setValid(false); + // Remove the IGD. + it = validIgdList_.erase(it); + break; + } else { + it++; + } + } + } + + // Notify the listener. + if (observer_ and igd) { + observer_->onIgdUpdated(igd, UpnpIgdEvent::REMOVED); + } +} + +void +PUPnP::processDiscoverySubscriptionExpired(Upnp_EventType event_type, const std::string& eventSubUrl) +{ + CHECK_VALID_THREAD(); + + std::lock_guard<std::mutex> lk(pupnpMutex_); + for (auto& it : validIgdList_) { + if (auto igd = std::dynamic_pointer_cast<UPnPIGD>(it)) { + if (igd->getEventSubURL() == eventSubUrl) { + JAMI_DBG("PUPnP: Received [%s] event for IGD [%s] %s. Request a new subscribe.", + PUPnP::eventTypeToString(event_type), + igd->getUID().c_str(), + igd->toString().c_str()); + UpnpSubscribeAsync(ctrlptHandle_, + eventSubUrl.c_str(), + UPNP_INFINITE, + subEventCallback, + this); + break; + } + } + } +} + +int +PUPnP::handleCtrlPtUPnPEvents(Upnp_EventType event_type, const void* event) +{ + switch (event_type) { + // "ALIVE" events are processed as "SEARCH RESULT". It might be usefull + // if "SEARCH RESULT" was missed. + case UPNP_DISCOVERY_ADVERTISEMENT_ALIVE: + case UPNP_DISCOVERY_SEARCH_RESULT: { + const UpnpDiscovery* d_event = (const UpnpDiscovery*) event; + + // First check the error code. + auto upnp_status = UpnpDiscovery_get_ErrCode(d_event); + if (upnp_status != UPNP_E_SUCCESS) { + JAMI_ERR("PUPnP: UPNP discovery is in erroneous state: %s", + UpnpGetErrorMessage(upnp_status)); + break; + } + + // Parse the event's data. + std::string deviceId {UpnpDiscovery_get_DeviceID_cstr(d_event)}; + std::string location {UpnpDiscovery_get_Location_cstr(d_event)}; + IpAddr dstAddr(*(const pj_sockaddr*) (UpnpDiscovery_get_DestAddr(d_event))); + runOnPUPnPQueue([w = weak(), + deviceId = std::move(deviceId), + location = std::move(location), + dstAddr = std::move(dstAddr)] { + if (auto upnpThis = w.lock()) { + upnpThis->processDiscoverySearchResult(deviceId, location, dstAddr); + } + }); + break; + } + case UPNP_DISCOVERY_ADVERTISEMENT_BYEBYE: { + const UpnpDiscovery* d_event = (const UpnpDiscovery*) event; + + std::string deviceId(UpnpDiscovery_get_DeviceID_cstr(d_event)); + + // Process the response on the main thread. + runOnPUPnPQueue([w = weak(), deviceId = std::move(deviceId)] { + if (auto upnpThis = w.lock()) { + upnpThis->processDiscoveryAdvertisementByebye(deviceId); + } + }); + break; + } + case UPNP_DISCOVERY_SEARCH_TIMEOUT: { + // Even if the discovery search is successful, it's normal to receive + // time-out events. This because we send search requests using various + // device types, which some of them may not return a response. + break; + } + case UPNP_EVENT_RECEIVED: { + // Nothing to do. + break; + } + // Treat failed autorenewal like an expired subscription. + case UPNP_EVENT_AUTORENEWAL_FAILED: + case UPNP_EVENT_SUBSCRIPTION_EXPIRED: // This event will occur only if autorenewal is disabled. + { + JAMI_WARN("PUPnP: Received Subscription Event %s", eventTypeToString(event_type)); + const UpnpEventSubscribe* es_event = (const UpnpEventSubscribe*) event; + if (es_event == nullptr) { + JAMI_WARN("PUPnP: Received Subscription Event with null pointer"); + break; + } + std::string publisherUrl(UpnpEventSubscribe_get_PublisherUrl_cstr(es_event)); + + // Process the response on the main thread. + runOnPUPnPQueue([w = weak(), event_type, publisherUrl = std::move(publisherUrl)] { + if (auto upnpThis = w.lock()) { + upnpThis->processDiscoverySubscriptionExpired(event_type, publisherUrl); + } + }); + break; + } + case UPNP_EVENT_SUBSCRIBE_COMPLETE: + case UPNP_EVENT_UNSUBSCRIBE_COMPLETE: { + UpnpEventSubscribe* es_event = (UpnpEventSubscribe*) event; + if (es_event == nullptr) { + JAMI_WARN("PUPnP: Received Subscription Event with null pointer"); + } else { + UpnpEventSubscribe_delete(es_event); + } + break; + } + case UPNP_CONTROL_ACTION_COMPLETE: { + const UpnpActionComplete* a_event = (const UpnpActionComplete*) event; + if (a_event == nullptr) { + JAMI_WARN("PUPnP: Received Action Complete Event with null pointer"); + break; + } + auto res = UpnpActionComplete_get_ErrCode(a_event); + if (res != UPNP_E_SUCCESS and res != UPNP_E_TIMEDOUT) { + auto err = UpnpActionComplete_get_ErrCode(a_event); + JAMI_WARN("PUPnP: Received Action Complete error %i %s", err, UpnpGetErrorMessage(err)); + } else { + auto actionRequest = UpnpActionComplete_get_ActionRequest(a_event); + // Abort if there is no action to process. + if (actionRequest == nullptr) { + JAMI_WARN("PUPnP: Can't get the Action Request data from the event"); + break; + } + + auto actionResult = UpnpActionComplete_get_ActionResult(a_event); + if (actionResult != nullptr) { + ixmlDocument_free(actionResult); + } else { + JAMI_WARN("PUPnP: Action Result document not found"); + } + } + break; + } + default: { + JAMI_WARN("PUPnP: Unhandled Control Point event"); + break; + } + } + + return UPNP_E_SUCCESS; +} + +int +PUPnP::subEventCallback(Upnp_EventType event_type, const void* event, void* user_data) +{ + if (auto pupnp = static_cast<PUPnP*>(user_data)) + return pupnp->handleSubscriptionUPnPEvent(event_type, event); + JAMI_WARN("PUPnP: Subscription callback without service Id string"); + return 0; +} + +int +PUPnP::handleSubscriptionUPnPEvent(Upnp_EventType, const void* event) +{ + UpnpEventSubscribe* es_event = static_cast<UpnpEventSubscribe*>(const_cast<void*>(event)); + + if (es_event == nullptr) { + JAMI_ERR("PUPnP: Unexpected null pointer!"); + return UPNP_E_INVALID_ARGUMENT; + } + std::string publisherUrl(UpnpEventSubscribe_get_PublisherUrl_cstr(es_event)); + int upnp_err = UpnpEventSubscribe_get_ErrCode(es_event); + if (upnp_err != UPNP_E_SUCCESS) { + JAMI_WARN("PUPnP: Subscription error %s from %s", + UpnpGetErrorMessage(upnp_err), + publisherUrl.c_str()); + return upnp_err; + } + + return UPNP_E_SUCCESS; +} + +std::unique_ptr<UPnPIGD> +PUPnP::parseIgd(IXML_Document* doc, std::string locationUrl) +{ + if (not(doc and locationUrl.c_str())) + return nullptr; + + // Check the UDN to see if its already in our device list. + std::string UDN(getFirstDocItem(doc, "UDN")); + if (UDN.empty()) { + JAMI_WARN("PUPnP: could not find UDN in description document of device"); + return nullptr; + } else { + std::lock_guard<std::mutex> lk(pupnpMutex_); + for (auto& it : validIgdList_) { + if (it->getUID() == UDN) { + // We already have this device in our list. + return nullptr; + } + } + } + + JAMI_DBG("PUPnP: Found new device [%s]", UDN.c_str()); + + std::unique_ptr<UPnPIGD> new_igd; + int upnp_err; + + // Get friendly name. + std::string friendlyName(getFirstDocItem(doc, "friendlyName")); + + // Get base URL. + std::string baseURL(getFirstDocItem(doc, "URLBase")); + if (baseURL.empty()) + baseURL = locationUrl; + + // Get list of services defined by serviceType. + std::unique_ptr<IXML_NodeList, decltype(ixmlNodeList_free)&> serviceList(nullptr, + ixmlNodeList_free); + serviceList.reset(ixmlDocument_getElementsByTagName(doc, "serviceType")); + unsigned long list_length = ixmlNodeList_length(serviceList.get()); + + // Go through the "serviceType" nodes until we find the the correct service type. + for (unsigned long node_idx = 0; node_idx < list_length; node_idx++) { + IXML_Node* serviceType_node = ixmlNodeList_item(serviceList.get(), node_idx); + std::string serviceType(getElementText(serviceType_node)); + + // Only check serviceType of WANIPConnection or WANPPPConnection. + if (serviceType != UPNP_WANIP_SERVICE + && serviceType != UPNP_WANPPP_SERVICE) { + // IGD is not WANIP or WANPPP service. Going to next node. + continue; + } + + // Get parent node. + IXML_Node* service_node = ixmlNode_getParentNode(serviceType_node); + if (not service_node) { + // IGD serviceType has no parent node. Going to next node. + continue; + } + + // Perform sanity check. The parent node should be called "service". + if (strcmp(ixmlNode_getNodeName(service_node), "service") != 0) { + // IGD "serviceType" parent node is not called "service". Going to next node. + continue; + } + + // Get serviceId. + IXML_Element* service_element = (IXML_Element*) service_node; + std::string serviceId(getFirstElementItem(service_element, "serviceId")); + if (serviceId.empty()) { + // IGD "serviceId" is empty. Going to next node. + continue; + } + + // Get the relative controlURL and turn it into absolute address using the URLBase. + std::string controlURL(getFirstElementItem(service_element, "controlURL")); + if (controlURL.empty()) { + // IGD control URL is empty. Going to next node. + continue; + } + + char* absolute_control_url = nullptr; + upnp_err = UpnpResolveURL2(baseURL.c_str(), controlURL.c_str(), &absolute_control_url); + if (upnp_err == UPNP_E_SUCCESS) + controlURL = absolute_control_url; + else + JAMI_WARN("PUPnP: Error resolving absolute controlURL -> %s", + UpnpGetErrorMessage(upnp_err)); + + std::free(absolute_control_url); + + // Get the relative eventSubURL and turn it into absolute address using the URLBase. + std::string eventSubURL(getFirstElementItem(service_element, "eventSubURL")); + if (eventSubURL.empty()) { + JAMI_WARN("PUPnP: IGD event sub URL is empty. Going to next node"); + continue; + } + + char* absolute_event_sub_url = nullptr; + upnp_err = UpnpResolveURL2(baseURL.c_str(), eventSubURL.c_str(), &absolute_event_sub_url); + if (upnp_err == UPNP_E_SUCCESS) + eventSubURL = absolute_event_sub_url; + else + JAMI_WARN("PUPnP: Error resolving absolute eventSubURL -> %s", + UpnpGetErrorMessage(upnp_err)); + + std::free(absolute_event_sub_url); + + new_igd.reset(new UPnPIGD(std::move(UDN), + std::move(baseURL), + std::move(friendlyName), + std::move(serviceType), + std::move(serviceId), + std::move(locationUrl), + std::move(controlURL), + std::move(eventSubURL))); + + return new_igd; + } + + return nullptr; +} + +bool +PUPnP::actionIsIgdConnected(const UPnPIGD& igd) +{ + if (not clientRegistered_) + return false; + + // Set action name. + IXML_Document* action_container_ptr = UpnpMakeAction("GetStatusInfo", + igd.getServiceType().c_str(), + 0, + nullptr); + if (not action_container_ptr) { + JAMI_WARN("PUPnP: Failed to make GetStatusInfo action"); + return false; + } + XMLDocument action(action_container_ptr, ixmlDocument_free); // Action pointer. + + IXML_Document* response_container_ptr = nullptr; + int upnp_err = UpnpSendAction(ctrlptHandle_, + igd.getControlURL().c_str(), + igd.getServiceType().c_str(), + nullptr, + action.get(), + &response_container_ptr); + if (not response_container_ptr or upnp_err != UPNP_E_SUCCESS) { + JAMI_WARN("PUPnP: Failed to send GetStatusInfo action -> %s", UpnpGetErrorMessage(upnp_err)); + return false; + } + XMLDocument response(response_container_ptr, ixmlDocument_free); + + if (errorOnResponse(response.get())) { + JAMI_WARN("PUPnP: Failed to get GetStatusInfo from %s -> %d: %s", + igd.getServiceType().c_str(), + upnp_err, + UpnpGetErrorMessage(upnp_err)); + return false; + } + + // Parse response. + auto status = getFirstDocItem(response.get(), "NewConnectionStatus"); + return status == "Connected"; +} + +IpAddr +PUPnP::actionGetExternalIP(const UPnPIGD& igd) +{ + if (not clientRegistered_) + return {}; + + // Action and response pointers. + std::unique_ptr<IXML_Document, decltype(ixmlDocument_free)&> + action(nullptr, ixmlDocument_free); // Action pointer. + std::unique_ptr<IXML_Document, decltype(ixmlDocument_free)&> + response(nullptr, ixmlDocument_free); // Response pointer. + + // Set action name. + static constexpr const char* action_name {"GetExternalIPAddress"}; + + IXML_Document* action_container_ptr = nullptr; + action_container_ptr = UpnpMakeAction(action_name, igd.getServiceType().c_str(), 0, nullptr); + action.reset(action_container_ptr); + + if (not action) { + JAMI_WARN("PUPnP: Failed to make GetExternalIPAddress action"); + return {}; + } + + IXML_Document* response_container_ptr = nullptr; + int upnp_err = UpnpSendAction(ctrlptHandle_, + igd.getControlURL().c_str(), + igd.getServiceType().c_str(), + nullptr, + action.get(), + &response_container_ptr); + response.reset(response_container_ptr); + + if (not response or upnp_err != UPNP_E_SUCCESS) { + JAMI_WARN("PUPnP: Failed to send GetExternalIPAddress action -> %s", + UpnpGetErrorMessage(upnp_err)); + return {}; + } + + if (errorOnResponse(response.get())) { + JAMI_WARN("PUPnP: Failed to get GetExternalIPAddress from %s -> %d: %s", + igd.getServiceType().c_str(), + upnp_err, + UpnpGetErrorMessage(upnp_err)); + return {}; + } + + return {getFirstDocItem(response.get(), "NewExternalIPAddress")}; +} + +std::map<Mapping::key_t, Mapping> +PUPnP::getMappingsListByDescr(const std::shared_ptr<IGD>& igd, const std::string& description) const +{ + auto upnpIgd = std::dynamic_pointer_cast<UPnPIGD>(igd); + assert(upnpIgd); + + std::map<Mapping::key_t, Mapping> mapList; + + if (not clientRegistered_ or not upnpIgd->isValid() or not upnpIgd->getLocalIp()) + return mapList; + + // Set action name. + static constexpr const char* action_name {"GetGenericPortMappingEntry"}; + + for (int entry_idx = 0;; entry_idx++) { + std::unique_ptr<IXML_Document, decltype(ixmlDocument_free)&> + action(nullptr, ixmlDocument_free); // Action pointer. + IXML_Document* action_container_ptr = nullptr; + + std::unique_ptr<IXML_Document, decltype(ixmlDocument_free)&> + response(nullptr, ixmlDocument_free); // Response pointer. + IXML_Document* response_container_ptr = nullptr; + + UpnpAddToAction(&action_container_ptr, + action_name, + upnpIgd->getServiceType().c_str(), + "NewPortMappingIndex", + std::to_string(entry_idx).c_str()); + action.reset(action_container_ptr); + + if (not action) { + JAMI_WARN("PUPnP: Failed to add NewPortMappingIndex action"); + break; + } + + int upnp_err = UpnpSendAction(ctrlptHandle_, + upnpIgd->getControlURL().c_str(), + upnpIgd->getServiceType().c_str(), + nullptr, + action.get(), + &response_container_ptr); + response.reset(response_container_ptr); + + if (not response) { + // No existing mapping. Abort silently. + break; + } + + if (upnp_err != UPNP_E_SUCCESS) { + JAMI_ERR("PUPnP: GetGenericPortMappingEntry returned with error: %i", upnp_err); + break; + } + + // Check error code. + auto errorCode = getFirstDocItem(response.get(), "errorCode"); + if (not errorCode.empty()) { + auto error = to_int<int>(errorCode); + if (error == ARRAY_IDX_INVALID or error == CONFLICT_IN_MAPPING) { + // No more port mapping entries in the response. + JAMI_DBG("PUPnP: No more mappings (found a total of %i mappings", entry_idx); + break; + } else { + auto errorDescription = getFirstDocItem(response.get(), "errorDescription"); + JAMI_ERROR("PUPnP: GetGenericPortMappingEntry returned with error: {:s}: {:s}", + errorCode, + errorDescription); + break; + } + } + + // Parse the response. + auto desc_actual = getFirstDocItem(response.get(), "NewPortMappingDescription"); + auto client_ip = getFirstDocItem(response.get(), "NewInternalClient"); + + if (client_ip != getHostAddress().toString()) { + // Silently ignore un-matching addresses. + continue; + } + + if (desc_actual.find(description) == std::string::npos) + continue; + + auto port_internal = getFirstDocItem(response.get(), "NewInternalPort"); + auto port_external = getFirstDocItem(response.get(), "NewExternalPort"); + std::string transport(getFirstDocItem(response.get(), "NewProtocol")); + + if (port_internal.empty() || port_external.empty() || transport.empty()) { + JAMI_ERR("PUPnP: GetGenericPortMappingEntry returned an invalid entry at index %i", + entry_idx); + continue; + } + + std::transform(transport.begin(), transport.end(), transport.begin(), ::toupper); + PortType type = transport.find("TCP") != std::string::npos ? PortType::TCP : PortType::UDP; + auto ePort = to_int<uint16_t>(port_external); + auto iPort = to_int<uint16_t>(port_internal); + + Mapping map(type, ePort, iPort); + map.setIgd(igd); + + mapList.emplace(map.getMapKey(), std::move(map)); + } + + JAMI_DEBUG("PUPnP: Found {:d} allocated mappings on IGD {:s}", + mapList.size(), + upnpIgd->toString()); + + return mapList; +} + +void +PUPnP::deleteMappingsByDescription(const std::shared_ptr<IGD>& igd, const std::string& description) +{ + if (not(clientRegistered_ and igd->getLocalIp())) + return; + + JAMI_DBG("PUPnP: Remove all mappings (if any) on IGD %s matching descr prefix %s", + igd->toString().c_str(), + Mapping::UPNP_MAPPING_DESCRIPTION_PREFIX); + + auto mapList = getMappingsListByDescr(igd, description); + + for (auto const& [_, map] : mapList) { + requestMappingRemove(map); + } +} + +bool +PUPnP::actionAddPortMapping(const Mapping& mapping) +{ + CHECK_VALID_THREAD(); + + if (not clientRegistered_) + return false; + + auto igdIn = std::dynamic_pointer_cast<UPnPIGD>(mapping.getIgd()); + if (not igdIn) + return false; + + // The requested IGD must be present in the list of local valid IGDs. + auto igd = findMatchingIgd(igdIn->getControlURL()); + + if (not igd or not igd->isValid()) + return false; + + // Action and response pointers. + XMLDocument action(nullptr, ixmlDocument_free); + IXML_Document* action_container_ptr = nullptr; + XMLDocument response(nullptr, ixmlDocument_free); + IXML_Document* response_container_ptr = nullptr; + + // Set action sequence. + UpnpAddToAction(&action_container_ptr, + ACTION_ADD_PORT_MAPPING, + igd->getServiceType().c_str(), + "NewRemoteHost", + ""); + UpnpAddToAction(&action_container_ptr, + ACTION_ADD_PORT_MAPPING, + igd->getServiceType().c_str(), + "NewExternalPort", + mapping.getExternalPortStr().c_str()); + UpnpAddToAction(&action_container_ptr, + ACTION_ADD_PORT_MAPPING, + igd->getServiceType().c_str(), + "NewProtocol", + mapping.getTypeStr()); + UpnpAddToAction(&action_container_ptr, + ACTION_ADD_PORT_MAPPING, + igd->getServiceType().c_str(), + "NewInternalPort", + mapping.getInternalPortStr().c_str()); + UpnpAddToAction(&action_container_ptr, + ACTION_ADD_PORT_MAPPING, + igd->getServiceType().c_str(), + "NewInternalClient", + getHostAddress().toString().c_str()); + UpnpAddToAction(&action_container_ptr, + ACTION_ADD_PORT_MAPPING, + igd->getServiceType().c_str(), + "NewEnabled", + "1"); + UpnpAddToAction(&action_container_ptr, + ACTION_ADD_PORT_MAPPING, + igd->getServiceType().c_str(), + "NewPortMappingDescription", + mapping.toString().c_str()); + UpnpAddToAction(&action_container_ptr, + ACTION_ADD_PORT_MAPPING, + igd->getServiceType().c_str(), + "NewLeaseDuration", + "0"); + + action.reset(action_container_ptr); + + int upnp_err = UpnpSendAction(ctrlptHandle_, + igd->getControlURL().c_str(), + igd->getServiceType().c_str(), + nullptr, + action.get(), + &response_container_ptr); + response.reset(response_container_ptr); + + bool success = true; + + if (upnp_err != UPNP_E_SUCCESS) { + JAMI_WARN("PUPnP: Failed to send action %s for mapping %s. %d: %s", + ACTION_ADD_PORT_MAPPING, + mapping.toString().c_str(), + upnp_err, + UpnpGetErrorMessage(upnp_err)); + JAMI_WARN("PUPnP: IGD ctrlUrl %s", igd->getControlURL().c_str()); + JAMI_WARN("PUPnP: IGD service type %s", igd->getServiceType().c_str()); + + success = false; + } + + // Check if an error has occurred. + auto errorCode = getFirstDocItem(response.get(), "errorCode"); + if (not errorCode.empty()) { + success = false; + // Try to get the error description. + std::string errorDescription; + if (response) { + errorDescription = getFirstDocItem(response.get(), "errorDescription"); + } + + JAMI_WARNING("PUPnP: {:s} returned with error: {:s} {:s}", + ACTION_ADD_PORT_MAPPING, + errorCode, + errorDescription); + } + return success; +} + +bool +PUPnP::actionDeletePortMapping(const Mapping& mapping) +{ + CHECK_VALID_THREAD(); + + if (not clientRegistered_) + return false; + + auto igdIn = std::dynamic_pointer_cast<UPnPIGD>(mapping.getIgd()); + if (not igdIn) + return false; + + // The requested IGD must be present in the list of local valid IGDs. + auto igd = findMatchingIgd(igdIn->getControlURL()); + + if (not igd or not igd->isValid()) + return false; + + // Action and response pointers. + XMLDocument action(nullptr, ixmlDocument_free); + IXML_Document* action_container_ptr = nullptr; + XMLDocument response(nullptr, ixmlDocument_free); + IXML_Document* response_container_ptr = nullptr; + + // Set action sequence. + UpnpAddToAction(&action_container_ptr, + ACTION_DELETE_PORT_MAPPING, + igd->getServiceType().c_str(), + "NewRemoteHost", + ""); + UpnpAddToAction(&action_container_ptr, + ACTION_DELETE_PORT_MAPPING, + igd->getServiceType().c_str(), + "NewExternalPort", + mapping.getExternalPortStr().c_str()); + UpnpAddToAction(&action_container_ptr, + ACTION_DELETE_PORT_MAPPING, + igd->getServiceType().c_str(), + "NewProtocol", + mapping.getTypeStr()); + + action.reset(action_container_ptr); + + int upnp_err = UpnpSendAction(ctrlptHandle_, + igd->getControlURL().c_str(), + igd->getServiceType().c_str(), + nullptr, + action.get(), + &response_container_ptr); + response.reset(response_container_ptr); + + bool success = true; + + if (upnp_err != UPNP_E_SUCCESS) { + JAMI_WARN("PUPnP: Failed to send action %s for mapping from %s. %d: %s", + ACTION_DELETE_PORT_MAPPING, + mapping.toString().c_str(), + upnp_err, + UpnpGetErrorMessage(upnp_err)); + JAMI_WARN("PUPnP: IGD ctrlUrl %s", igd->getControlURL().c_str()); + JAMI_WARN("PUPnP: IGD service type %s", igd->getServiceType().c_str()); + + success = false; + } + + if (not response) { + JAMI_WARN("PUPnP: Failed to get response for %s", ACTION_DELETE_PORT_MAPPING); + success = false; + } + + // Check if there is an error code. + auto errorCode = getFirstDocItem(response.get(), "errorCode"); + if (not errorCode.empty()) { + auto errorDescription = getFirstDocItem(response.get(), "errorDescription"); + JAMI_WARNING("PUPnP: {:s} returned with error: {:s}: {:s}", + ACTION_DELETE_PORT_MAPPING, + errorCode, + errorDescription); + success = false; + } + + return success; +} + +} // namespace upnp +} // namespace jami diff --git a/src/upnp/protocol/pupnp/pupnp.h b/src/upnp/protocol/pupnp/pupnp.h new file mode 100644 index 0000000..a77f30f --- /dev/null +++ b/src/upnp/protocol/pupnp/pupnp.h @@ -0,0 +1,271 @@ +/* + * Copyright (C) 2004-2023 Savoir-faire Linux Inc. + * + * Author: Stepan Salenikovich <stepan.salenikovich@savoirfairelinux.com> + * Author: Eden Abitbol <eden.abitbol@savoirfairelinux.com> + * Author: Mohamed Chibani <mohamed.chibani@savoirfairelinux.com> + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ + +#pragma once + +#ifdef _WIN32 +#define UPNP_USE_MSVCPP +#define UPNP_STATIC_LIB +#endif + +#include "../upnp_protocol.h" +#include "../igd.h" +#include "upnp_igd.h" + +#include "logger.h" +#include "connectivity/ip_utils.h" +#include "noncopyable.h" +#include "compiler_intrinsics.h" + +#include <upnp/upnp.h> +#include <upnp/upnptools.h> + +#ifdef _WIN32 +#include <windows.h> +#include <wincrypt.h> +#endif + +#include <atomic> +#include <thread> +#include <list> +#include <map> +#include <set> +#include <string> +#include <memory> +#include <future> + +namespace jami { +class IpAddr; +} + +namespace jami { +namespace upnp { + +class PUPnP : public UPnPProtocol +{ +public: + using XMLDocument = std::unique_ptr<IXML_Document, decltype(ixmlDocument_free)&>; + + enum class CtrlAction { + UNKNOWN, + ADD_PORT_MAPPING, + DELETE_PORT_MAPPING, + GET_GENERIC_PORT_MAPPING_ENTRY, + GET_STATUS_INFO, + GET_EXTERNAL_IP_ADDRESS + }; + + PUPnP(); + ~PUPnP(); + + // Set the observer + void setObserver(UpnpMappingObserver* obs) override; + + // Returns the protocol type. + NatProtocolType getProtocol() const override { return NatProtocolType::PUPNP; } + + // Get protocol type as string. + char const* getProtocolName() const override { return "PUPNP"; } + + // Notifies a change in network. + void clearIgds() override; + + // Sends out async search for IGD. + void searchForIgd() override; + + // Get the IGD list. + std::list<std::shared_ptr<IGD>> getIgdList() const override; + + // Return true if the it's fully setup. + bool isReady() const override; + + // Get from the IGD the list of already allocated mappings if any. + std::map<Mapping::key_t, Mapping> getMappingsListByDescr( + const std::shared_ptr<IGD>& igd, const std::string& descr) const override; + + // Request a new mapping. + void requestMappingAdd(const Mapping& mapping) override; + + // Renew an allocated mapping. + // Not implemented. Currently, UPNP allocations do not have expiration time. + void requestMappingRenew([[maybe_unused]] const Mapping& mapping) override { assert(false); }; + + // Removes a mapping. + void requestMappingRemove(const Mapping& igdMapping) override; + + // Get the host (local) address. + const IpAddr getHostAddress() const override; + + // Terminate the instance. + void terminate() override; + +private: + NON_COPYABLE(PUPnP); + + // Helpers to run tasks on PUPNP private execution queue. + ScheduledExecutor* getPUPnPScheduler() { return &pupnpScheduler_; } + template<typename Callback> + void runOnPUPnPQueue(Callback&& cb) + { + pupnpScheduler_.run([cb = std::forward<Callback>(cb)]() mutable { cb(); }); + } + + // Helper to run tasks on UPNP context execution queue. + ScheduledExecutor* getUpnContextScheduler() { return UpnpThreadUtil::getScheduler(); } + + void terminate(std::condition_variable& cv); + + // Init lib-upnp + void initUpnpLib(); + + // Return true if running. + bool isRunning() const; + + // Register the client + void registerClient(); + + // Start search for UPNP devices + void searchForDevices(); + + // Return true if it has at least one valid IGD. + bool hasValidIgd() const; + + // Update the host (local) address. + void updateHostAddress(); + + // Check the host (local) address. + // Returns true if the address is valid. + bool hasValidHostAddress(); + + // Delete mappings matching the description + void deleteMappingsByDescription(const std::shared_ptr<IGD>& igd, + const std::string& description); + + // Search for the IGD in the local list of known IGDs. + std::shared_ptr<UPnPIGD> findMatchingIgd(const std::string& ctrlURL) const; + + // Process the reception of an add mapping action answer. + void processAddMapAction(const Mapping& map); + + // Process the a mapping request failure. + void processRequestMappingFailure(const Mapping& map); + + // Process the reception of a remove mapping action answer. + void processRemoveMapAction(const Mapping& map); + + // Increment IGD errors counter. + void incrementErrorsCounter(const std::shared_ptr<IGD>& igd); + + // Download XML document. + void downLoadIgdDescription(const std::string& url); + + // Validate IGD from the xml document received from the router. + bool validateIgd(const std::string& location, IXML_Document* doc_container_ptr); + + // Returns control point action callback based on xml node. + static CtrlAction getAction(const char* xmlNode); + + // Control point callback. + static int ctrlPtCallback(Upnp_EventType event_type, const void* event, void* user_data); +#if UPNP_VERSION < 10800 + static inline int ctrlPtCallback(Upnp_EventType event_type, void* event, void* user_data) + { + return ctrlPtCallback(event_type, (const void*) event, user_data); + }; +#endif + // Process IGD responses. + void processDiscoverySearchResult(const std::string& deviceId, + const std::string& igdUrl, + const IpAddr& dstAddr); + void processDiscoveryAdvertisementByebye(const std::string& deviceId); + void processDiscoverySubscriptionExpired(Upnp_EventType event_type, + const std::string& eventSubUrl); + + // Callback event handler function for the UPnP client (control point). + int handleCtrlPtUPnPEvents(Upnp_EventType event_type, const void* event); + + // Subscription event callback. + static int subEventCallback(Upnp_EventType event_type, const void* event, void* user_data); +#if UPNP_VERSION < 10800 + static inline int subEventCallback(Upnp_EventType event_type, void* event, void* user_data) + { + return subEventCallback(event_type, (const void*) event, user_data); + }; +#endif + + // Callback subscription event function for handling subscription request. + int handleSubscriptionUPnPEvent(Upnp_EventType event_type, const void* event); + + // Parses the IGD candidate. + std::unique_ptr<UPnPIGD> parseIgd(IXML_Document* doc, std::string locationUrl); + + // These functions directly create UPnP actions and make synchronous UPnP + // control point calls. Must be run on the PUPNP internal execution queue. + bool actionIsIgdConnected(const UPnPIGD& igd); + IpAddr actionGetExternalIP(const UPnPIGD& igd); + bool actionAddPortMapping(const Mapping& mapping); + bool actionDeletePortMapping(const Mapping& mapping); + + // Event type to string + static const char* eventTypeToString(Upnp_EventType eventType); + + std::weak_ptr<PUPnP> weak() { return std::static_pointer_cast<PUPnP>(shared_from_this()); } + + // Execution queue to run lib upnp actions + ScheduledExecutor pupnpScheduler_ {"pupnp"}; + + // Initialization status. + std::atomic_bool initialized_ {false}; + // Client registration status. + std::atomic_bool clientRegistered_ {false}; + + std::shared_ptr<Task> searchForIgdTimer_ {}; + unsigned int igdSearchCounter_ {0}; + + // List of discovered IGDs. + std::set<std::string> discoveredIgdList_; + + // Control point handle. + UpnpClient_Handle ctrlptHandle_ {-1}; + + // Observer to report the results. + UpnpMappingObserver* observer_ {nullptr}; + + // List of valid IGDs. + std::list<std::shared_ptr<IGD>> validIgdList_; + + // Current host address. + IpAddr hostAddress_ {}; + + // Calls from other threads that does not need synchronous access are + // rescheduled on the UPNP private queue. This will avoid the need to + // protect most of the data members of this class. + // For some internal members (namely the validIgdList and the hostAddress) + // that need to be synchronously accessed, are protected by this mutex. + mutable std::mutex pupnpMutex_; + + // Shutdown synchronization + bool shutdownComplete_ {false}; +}; + +} // namespace upnp +} // namespace jami diff --git a/src/upnp/protocol/pupnp/upnp_igd.cpp b/src/upnp/protocol/pupnp/upnp_igd.cpp new file mode 100644 index 0000000..2f8a332 --- /dev/null +++ b/src/upnp/protocol/pupnp/upnp_igd.cpp @@ -0,0 +1,74 @@ +/* + * Copyright (C) 2004-2023 Savoir-faire Linux Inc. + * + * Author: Stepan Salenikovich <stepan.salenikovich@savoirfairelinux.com> + * Author: Eden Abitbol <eden.abitbol@savoirfairelinux.com> + * Author: Mohamed Chibani <mohamed.chibani@savoirfairelinux.com> + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ + +#include "upnp_igd.h" + +namespace jami { +namespace upnp { + +UPnPIGD::UPnPIGD(std::string&& UDN, + std::string&& baseURL, + std::string&& friendlyName, + std::string&& serviceType, + std::string&& serviceId, + std::string&& locationURL, + std::string&& controlURL, + std::string&& eventSubURL, + IpAddr&& localIp, + IpAddr&& publicIp) + : IGD(NatProtocolType::PUPNP) +{ + uid_ = std::move(UDN); + baseURL_ = std::move(baseURL); + friendlyName_ = std::move(friendlyName); + serviceType_ = std::move(serviceType); + serviceId_ = std::move(serviceId); + locationURL_ = std::move(locationURL); + controlURL_ = std::move(controlURL); + eventSubURL_ = std::move(eventSubURL); + localIp_ = std::move(localIp); + publicIp_ = std::move(publicIp); +} + +bool +UPnPIGD::operator==(IGD& other) const +{ + return localIp_ == other.getLocalIp() and publicIp_ == other.getPublicIp(); +} + +bool +UPnPIGD::operator==(UPnPIGD& other) const +{ + if (localIp_ and publicIp_) { + if (localIp_ != other.localIp_ or publicIp_ != other.publicIp_) { + return false; + } + } + + return uid_ == other.uid_ and baseURL_ == other.baseURL_ + and friendlyName_ == other.friendlyName_ and serviceType_ == other.serviceType_ + and serviceId_ == other.serviceId_ and locationURL_ == other.locationURL_ + and controlURL_ == other.controlURL_ and eventSubURL_ == other.eventSubURL_; +} + +} // namespace upnp +} // namespace jami \ No newline at end of file diff --git a/src/upnp/protocol/pupnp/upnp_igd.h b/src/upnp/protocol/pupnp/upnp_igd.h new file mode 100644 index 0000000..2ad213b --- /dev/null +++ b/src/upnp/protocol/pupnp/upnp_igd.h @@ -0,0 +1,106 @@ +/* + * Copyright (C) 2004-2023 Savoir-faire Linux Inc. + * + * Author: Stepan Salenikovich <stepan.salenikovich@savoirfairelinux.com> + * Author: Eden Abitbol <eden.abitbol@savoirfairelinux.com> + * Author: Mohamed Chibani <mohamed.chibani@savoirfairelinux.com> + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ + +#pragma once + +#include "connectivity/upnp/protocol/igd.h" + +#include "noncopyable.h" +#include "connectivity/ip_utils.h" + +#include <map> +#include <string> +#include <chrono> +#include <functional> + +namespace jami { +namespace upnp { + +class UPnPIGD : public IGD +{ +public: + UPnPIGD(std::string&& UDN, + std::string&& baseURL, + std::string&& friendlyName, + std::string&& serviceType, + std::string&& serviceId, + std::string&& locationURL, + std::string&& controlURL, + std::string&& eventSubURL, + IpAddr&& localIp = {}, + IpAddr&& publicIp = {}); + + ~UPnPIGD() {} + + bool operator==(IGD& other) const; + bool operator==(UPnPIGD& other) const; + + const std::string& getBaseURL() const + { + std::lock_guard<std::mutex> lock(mutex_); + return baseURL_; + } + const std::string& getFriendlyName() const + { + std::lock_guard<std::mutex> lock(mutex_); + return friendlyName_; + } + const std::string& getServiceType() const + { + std::lock_guard<std::mutex> lock(mutex_); + return serviceType_; + } + const std::string& getServiceId() const + { + std::lock_guard<std::mutex> lock(mutex_); + return serviceId_; + } + const std::string& getLocationURL() const + { + std::lock_guard<std::mutex> lock(mutex_); + return locationURL_; + } + const std::string& getControlURL() const + { + std::lock_guard<std::mutex> lock(mutex_); + return controlURL_; + } + const std::string& getEventSubURL() const + { + std::lock_guard<std::mutex> lock(mutex_); + return eventSubURL_; + } + + const std::string toString() const override { return controlURL_; } + +private: + std::string baseURL_ {}; + std::string friendlyName_ {}; + std::string serviceType_ {}; + std::string serviceId_ {}; + std::string locationURL_ {}; + std::string controlURL_ {}; + std::string eventSubURL_ {}; +}; + +} // namespace upnp +} // namespace jami diff --git a/src/upnp/protocol/upnp_protocol.h b/src/upnp/protocol/upnp_protocol.h new file mode 100644 index 0000000..b38a4dd --- /dev/null +++ b/src/upnp/protocol/upnp_protocol.h @@ -0,0 +1,126 @@ +/* + * Copyright (C) 2004-2023 Savoir-faire Linux Inc. + * + * Author: Eden Abitbol <eden.abitbol@savoirfairelinux.com> + * Author: Mohamed Chibani <mohamed.chibani@savoirfairelinux.com> + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ + +#pragma once + +#include "igd.h" +#include "mapping.h" +#include "ip_utils.h" +//#include "upnp/upnp_thread_util.h" + +#include <map> +#include <string> +#include <chrono> +#include <functional> +#include <condition_variable> +#include <list> + +namespace jami { +namespace upnp { + +// UPnP device descriptions. +constexpr static const char* UPNP_ROOT_DEVICE = "upnp:rootdevice"; +constexpr static const char* UPNP_IGD_DEVICE + = "urn:schemas-upnp-org:device:InternetGatewayDevice:1"; +constexpr static const char* UPNP_WAN_DEVICE = "urn:schemas-upnp-org:device:WANDevice:1"; +constexpr static const char* UPNP_WANCON_DEVICE + = "urn:schemas-upnp-org:device:WANConnectionDevice:1"; +constexpr static const char* UPNP_WANIP_SERVICE = "urn:schemas-upnp-org:service:WANIPConnection:1"; +constexpr static const char* UPNP_WANPPP_SERVICE + = "urn:schemas-upnp-org:service:WANPPPConnection:1"; + +enum class UpnpIgdEvent { ADDED, REMOVED, INVALID_STATE }; + +// Interface used to report mapping event from the protocol implementations. +// This interface is meant to be implemented only by UPnPConext class. Sincce +// this class is a singleton, it's assumed that it out-lives the protocol +// implementations. In other words, the observer is always assumed to point to a +// valid instance. +class UpnpMappingObserver +{ +public: + UpnpMappingObserver() {}; + virtual ~UpnpMappingObserver() {}; + + virtual void onIgdUpdated(const std::shared_ptr<IGD>& igd, UpnpIgdEvent event) = 0; + virtual void onMappingAdded(const std::shared_ptr<IGD>& igd, const Mapping& map) = 0; + virtual void onMappingRequestFailed(const Mapping& map) = 0; +#if HAVE_LIBNATPMP + virtual void onMappingRenewed(const std::shared_ptr<IGD>& igd, const Mapping& map) = 0; +#endif + virtual void onMappingRemoved(const std::shared_ptr<IGD>& igd, const Mapping& map) = 0; +}; + +// Pure virtual interface class that UPnPContext uses to call protocol functions. +class UPnPProtocol : public std::enable_shared_from_this<UPnPProtocol>//, protected UpnpThreadUtil +{ +public: + enum class UpnpError : int { INVALID_ERR = -1, ERROR_OK, CONFLICT_IN_MAPPING }; + + UPnPProtocol() {}; + virtual ~UPnPProtocol() {}; + + // Get protocol type. + virtual NatProtocolType getProtocol() const = 0; + + // Get protocol type as string. + virtual char const* getProtocolName() const = 0; + + // Clear all known IGDs. + virtual void clearIgds() = 0; + + // Search for IGD. + virtual void searchForIgd() = 0; + + // Get the IGD instance. + virtual std::list<std::shared_ptr<IGD>> getIgdList() const = 0; + + // Return true if it has at least one valid IGD. + virtual bool isReady() const = 0; + + // Get the list of already allocated mappings if any. + virtual std::map<Mapping::key_t, Mapping> getMappingsListByDescr(const std::shared_ptr<IGD>&, + const std::string&) const + { + return {}; + } + + // Sends a request to add a mapping. + virtual void requestMappingAdd(const Mapping& map) = 0; + + // Renew an allocated mapping. + virtual void requestMappingRenew(const Mapping& mapping) = 0; + + // Sends a request to remove a mapping. + virtual void requestMappingRemove(const Mapping& igdMapping) = 0; + + // Set the user callbacks. + virtual void setObserver(UpnpMappingObserver* obs) = 0; + + // Get the current host (local) address + virtual const IpAddr getHostAddress() const = 0; + + // Terminate + virtual void terminate() = 0; +}; + +} // namespace upnp +} // namespace jami diff --git a/src/upnp/upnp_context.cpp b/src/upnp/upnp_context.cpp new file mode 100644 index 0000000..ef556f1 --- /dev/null +++ b/src/upnp/upnp_context.cpp @@ -0,0 +1,1339 @@ +/* + * Copyright (C) 2004-2023 Savoir-faire Linux Inc. + * + * Author: Stepan Salenikovich <stepan.salenikovich@savoirfairelinux.com> + * Author: Eden Abitbol <eden.abitbol@savoirfairelinux.com> + * Author: Mohamed Chibani <mohamed.chibani@savoirfairelinux.com> + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ + +#include "upnp_context.h" + +namespace jami { +namespace upnp { + +constexpr static auto MAP_UPDATE_INTERVAL = std::chrono::seconds(30); +constexpr static int MAX_REQUEST_RETRIES = 20; +constexpr static int MAX_REQUEST_REMOVE_COUNT = 5; + +constexpr static uint16_t UPNP_TCP_PORT_MIN {10000}; +constexpr static uint16_t UPNP_TCP_PORT_MAX {UPNP_TCP_PORT_MIN + 5000}; +constexpr static uint16_t UPNP_UDP_PORT_MIN {20000}; +constexpr static uint16_t UPNP_UDP_PORT_MAX {UPNP_UDP_PORT_MIN + 5000}; + +UPnPContext::UPnPContext() +{ + JAMI_DBG("Creating UPnPContext instance [%p]", this); + + // Set port ranges + portRange_.emplace(PortType::TCP, std::make_pair(UPNP_TCP_PORT_MIN, UPNP_TCP_PORT_MAX)); + portRange_.emplace(PortType::UDP, std::make_pair(UPNP_UDP_PORT_MIN, UPNP_UDP_PORT_MAX)); + + if (not isValidThread()) { + runOnUpnpContextQueue([this] { init(); }); + return; + } +} + +std::shared_ptr<UPnPContext> +UPnPContext::getUPnPContext() +{ + // This is the unique shared instance (singleton) of UPnPContext class. + static auto context = std::make_shared<UPnPContext>(); + return context; +} + +void +UPnPContext::shutdown(std::condition_variable& cv) +{ + JAMI_DBG("Shutdown UPnPContext instance [%p]", this); + + stopUpnp(true); + + for (auto const& [_, proto] : protocolList_) { + proto->terminate(); + } + + { + std::lock_guard<std::mutex> lock(mappingMutex_); + mappingList_->clear(); + if (mappingListUpdateTimer_) + mappingListUpdateTimer_->cancel(); + controllerList_.clear(); + protocolList_.clear(); + shutdownComplete_ = true; + cv.notify_one(); + } +} + +void +UPnPContext::shutdown() +{ + std::unique_lock<std::mutex> lk(mappingMutex_); + std::condition_variable cv; + + runOnUpnpContextQueue([&, this] { shutdown(cv); }); + + JAMI_DBG("Waiting for shutdown ..."); + + if (cv.wait_for(lk, std::chrono::seconds(30), [this] { return shutdownComplete_; })) { + JAMI_DBG("Shutdown completed"); + } else { + JAMI_ERR("Shutdown timed-out"); + } +} + +UPnPContext::~UPnPContext() +{ + JAMI_DBG("UPnPContext instance [%p] destroyed", this); +} + +void +UPnPContext::init() +{ + threadId_ = getCurrentThread(); + CHECK_VALID_THREAD(); + +#if HAVE_LIBNATPMP + auto natPmp = std::make_shared<NatPmp>(); + natPmp->setObserver(this); + protocolList_.emplace(NatProtocolType::NAT_PMP, std::move(natPmp)); +#endif + +#if HAVE_LIBUPNP + auto pupnp = std::make_shared<PUPnP>(); + pupnp->setObserver(this); + protocolList_.emplace(NatProtocolType::PUPNP, std::move(pupnp)); +#endif +} + +void +UPnPContext::startUpnp() +{ + assert(not controllerList_.empty()); + + CHECK_VALID_THREAD(); + + JAMI_DBG("Starting UPNP context"); + + // Request a new IGD search. + for (auto const& [_, protocol] : protocolList_) { + protocol->searchForIgd(); + } + + started_ = true; +} + +void +UPnPContext::stopUpnp(bool forceRelease) +{ + if (not isValidThread()) { + runOnUpnpContextQueue([this, forceRelease] { stopUpnp(forceRelease); }); + return; + } + + JAMI_DBG("Stopping UPNP context"); + + // Clear all current mappings if any. + + // Use a temporary list to avoid processing the mapping + // list while holding the lock. + std::list<Mapping::sharedPtr_t> toRemoveList; + { + std::lock_guard<std::mutex> lock(mappingMutex_); + + PortType types[2] {PortType::TCP, PortType::UDP}; + for (auto& type : types) { + auto& mappingList = getMappingList(type); + for (auto const& [_, map] : mappingList) { + toRemoveList.emplace_back(map); + } + } + // Invalidate the current IGDs. + preferredIgd_.reset(); + validIgdList_.clear(); + } + for (auto const& map : toRemoveList) { + requestRemoveMapping(map); + + // Notify is not needed in updateMappingState when + // shutting down (hence set it to false). NotifyCallback + // would trigger a new SIP registration and create a + // false registered state upon program close. + // It's handled by upper layers. + + updateMappingState(map, MappingState::FAILED, false); + // We dont remove mappings with auto-update enabled, + // unless forceRelease is true. + if (not map->getAutoUpdate() or forceRelease) { + map->enableAutoUpdate(false); + unregisterMapping(map); + } + } + + // Clear all current IGDs. + for (auto const& [_, protocol] : protocolList_) { + protocol->clearIgds(); + } + + started_ = false; +} + +uint16_t +UPnPContext::generateRandomPort(PortType type, bool mustBeEven) +{ + auto minPort = type == PortType::TCP ? UPNP_TCP_PORT_MIN : UPNP_UDP_PORT_MIN; + auto maxPort = type == PortType::TCP ? UPNP_TCP_PORT_MAX : UPNP_UDP_PORT_MAX; + + if (minPort >= maxPort) { + JAMI_ERR("Max port number (%i) must be greater than min port number (%i)", maxPort, minPort); + // Must be called with valid range. + assert(false); + } + + int fact = mustBeEven ? 2 : 1; + if (mustBeEven) { + minPort /= fact; + maxPort /= fact; + } + + // Seed the generator. + static std::mt19937 gen(dht::crypto::getSeededRandomEngine()); + // Define the range. + std::uniform_int_distribution<uint16_t> dist(minPort, maxPort); + return dist(gen) * fact; +} + +void +UPnPContext::connectivityChanged() +{ + if (not isValidThread()) { + runOnUpnpContextQueue([this] { connectivityChanged(); }); + return; + } + + auto hostAddr = ip_utils::getLocalAddr(AF_INET); + + JAMI_DBG("Connectivity change check: host address %s", hostAddr.toString().c_str()); + + auto restartUpnp = false; + + // On reception of "connectivity change" notification, the UPNP search + // will be restarted if either there is no valid IGD, or the IGD address + // changed. + + if (not isReady()) { + restartUpnp = true; + } else { + // Check if the host address changed. + for (auto const& [_, protocol] : protocolList_) { + if (protocol->isReady() and hostAddr != protocol->getHostAddress()) { + JAMI_WARN("Host address changed from %s to %s", + protocol->getHostAddress().toString().c_str(), + hostAddr.toString().c_str()); + protocol->clearIgds(); + restartUpnp = true; + break; + } + } + } + + // We have at least one valid IGD and the host address did + // not change, so no need to restart. + if (not restartUpnp) { + return; + } + + // No registered controller. A new search will be performed when + // a controller is registered. + if (controllerList_.empty()) + return; + + JAMI_DBG("Connectivity changed. Clear the IGDs and restart"); + + stopUpnp(); + startUpnp(); + + // Mapping with auto update enabled must be processed first. + processMappingWithAutoUpdate(); +} + +void +UPnPContext::setPublicAddress(const IpAddr& addr) +{ + if (not addr) + return; + + std::lock_guard<std::mutex> lock(mappingMutex_); + if (knownPublicAddress_ != addr) { + knownPublicAddress_ = std::move(addr); + JAMI_DBG("Setting the known public address to %s", addr.toString().c_str()); + } +} + +bool +UPnPContext::isReady() const +{ + std::lock_guard<std::mutex> lock(mappingMutex_); + return not validIgdList_.empty(); +} + +IpAddr +UPnPContext::getExternalIP() const +{ + std::lock_guard<std::mutex> lock(mappingMutex_); + // Return the first IGD Ip available. + if (not validIgdList_.empty()) { + return (*validIgdList_.begin())->getPublicIp(); + } + return {}; +} + +Mapping::sharedPtr_t +UPnPContext::reserveMapping(Mapping& requestedMap) +{ + auto desiredPort = requestedMap.getExternalPort(); + + if (desiredPort == 0) { + JAMI_DBG("Desired port is not set, will provide the first available port for [%s]", + requestedMap.getTypeStr()); + } else { + JAMI_DBG("Try to find mapping for port %i [%s]", desiredPort, requestedMap.getTypeStr()); + } + + Mapping::sharedPtr_t mapRes; + + { + std::lock_guard<std::mutex> lock(mappingMutex_); + auto& mappingList = getMappingList(requestedMap.getType()); + + // We try to provide a mapping in "OPEN" state. If not found, + // we provide any available mapping. In this case, it's up to + // the caller to use it or not. + for (auto const& [_, map] : mappingList) { + // If the desired port is null, we pick the first available port. + if (map->isValid() and (desiredPort == 0 or map->getExternalPort() == desiredPort) + and map->isAvailable()) { + // Considere the first available mapping regardless of its + // state. A mapping with OPEN state will be used if found. + if (not mapRes) + mapRes = map; + + if (map->getState() == MappingState::OPEN) { + // Found an "OPEN" mapping. We are done. + mapRes = map; + break; + } + } + } + } + + // Create a mapping if none was available. + if (not mapRes) { + JAMI_WARN("Did not find any available mapping. Will request one now"); + mapRes = registerMapping(requestedMap); + } + + if (mapRes) { + // Make the mapping unavailable + mapRes->setAvailable(false); + // Copy attributes. + mapRes->setNotifyCallback(requestedMap.getNotifyCallback()); + mapRes->enableAutoUpdate(requestedMap.getAutoUpdate()); + // Notify the listener. + if (auto cb = mapRes->getNotifyCallback()) + cb(mapRes); + } + + updateMappingList(true); + + return mapRes; +} + +void +UPnPContext::releaseMapping(const Mapping& map) +{ + if (not isValidThread()) { + runOnUpnpContextQueue([this, map] { releaseMapping(map); }); + return; + } + + auto mapPtr = getMappingWithKey(map.getMapKey()); + + if (not mapPtr) { + // Might happen if the mapping failed or was never granted. + JAMI_DBG("Mapping %s does not exist or was already removed", map.toString().c_str()); + return; + } + + if (mapPtr->isAvailable()) { + JAMI_WARN("Trying to release an unused mapping %s", mapPtr->toString().c_str()); + return; + } + + // Remove it. + requestRemoveMapping(mapPtr); + unregisterMapping(mapPtr); +} + +void +UPnPContext::registerController(void* controller) +{ + { + std::lock_guard<std::mutex> lock(mappingMutex_); + if (shutdownComplete_) { + JAMI_WARN("UPnPContext already shut down"); + return; + } + } + + if (not isValidThread()) { + runOnUpnpContextQueue([this, controller] { registerController(controller); }); + return; + } + + auto ret = controllerList_.emplace(controller); + if (not ret.second) { + JAMI_WARN("Controller %p is already registered", controller); + return; + } + + JAMI_DBG("Successfully registered controller %p", controller); + if (not started_) + startUpnp(); +} + +void +UPnPContext::unregisterController(void* controller) +{ + if (not isValidThread()) { + runOnUpnpContextQueue([this, controller] { unregisterController(controller); }); + return; + } + + if (controllerList_.erase(controller) == 1) { + JAMI_DBG("Successfully unregistered controller %p", controller); + } else { + JAMI_DBG("Controller %p was already removed", controller); + } + + if (controllerList_.empty()) { + stopUpnp(); + } +} + +uint16_t +UPnPContext::getAvailablePortNumber(PortType type) +{ + // Only return an availalable random port. No actual + // reservation is made here. + + std::lock_guard<std::mutex> lock(mappingMutex_); + auto& mappingList = getMappingList(type); + int tryCount = 0; + while (tryCount++ < MAX_REQUEST_RETRIES) { + uint16_t port = generateRandomPort(type); + Mapping map(type, port, port); + if (mappingList.find(map.getMapKey()) == mappingList.end()) + return port; + } + + // Very unlikely to get here. + JAMI_ERR("Could not find an available port after %i trials", MAX_REQUEST_RETRIES); + return 0; +} + +void +UPnPContext::requestMapping(const Mapping::sharedPtr_t& map) +{ + assert(map); + + if (not isValidThread()) { + runOnUpnpContextQueue([this, map] { requestMapping(map); }); + return; + } + + auto const& igd = getPreferredIgd(); + // We must have at least a valid IGD pointer if we get here. + // Not this method is called only if there were a valid IGD, however, + // because the processing is asynchronous, it's possible that the IGD + // was invalidated when the this code executed. + if (not igd) { + JAMI_DBG("No valid IGDs available"); + return; + } + + map->setIgd(igd); + + JAMI_DBG("Request mapping %s using protocol [%s] IGD [%s]", + map->toString().c_str(), + igd->getProtocolName(), + igd->toString().c_str()); + + if (map->getState() != MappingState::IN_PROGRESS) + updateMappingState(map, MappingState::IN_PROGRESS); + + auto const& protocol = protocolList_.at(igd->getProtocol()); + protocol->requestMappingAdd(*map); +} + +bool +UPnPContext::provisionNewMappings(PortType type, int portCount) +{ + JAMI_DBG("Provision %i new mappings of type [%s]", portCount, Mapping::getTypeStr(type)); + + assert(portCount > 0); + + while (portCount > 0) { + auto port = getAvailablePortNumber(type); + if (port > 0) { + // Found an available port number + portCount--; + Mapping map(type, port, port, true); + registerMapping(map); + } else { + // Very unlikely to get here! + JAMI_ERR("Can not find any available port to provision!"); + return false; + } + } + + return true; +} + +bool +UPnPContext::deleteUnneededMappings(PortType type, int portCount) +{ + JAMI_DBG("Remove %i unneeded mapping of type [%s]", portCount, Mapping::getTypeStr(type)); + + assert(portCount > 0); + + CHECK_VALID_THREAD(); + + std::lock_guard<std::mutex> lock(mappingMutex_); + auto& mappingList = getMappingList(type); + + for (auto it = mappingList.begin(); it != mappingList.end();) { + auto map = it->second; + assert(map); + + if (not map->isAvailable()) { + it++; + continue; + } + + if (map->getState() == MappingState::OPEN and portCount > 0) { + // Close portCount mappings in "OPEN" state. + requestRemoveMapping(map); + it = unregisterMapping(it); + portCount--; + } else if (map->getState() != MappingState::OPEN) { + // If this methods is called, it means there are more open + // mappings than required. So, all mappings in a state other + // than "OPEN" state (typically in in-progress state) will + // be deleted as well. + it = unregisterMapping(it); + } else { + it++; + } + } + + return true; +} + +void +UPnPContext::updatePreferredIgd() +{ + CHECK_VALID_THREAD(); + + if (preferredIgd_ and preferredIgd_->isValid()) + return; + + // Reset and search for the best IGD. + preferredIgd_.reset(); + + for (auto const& [_, protocol] : protocolList_) { + if (protocol->isReady()) { + auto igdList = protocol->getIgdList(); + assert(not igdList.empty()); + auto const& igd = igdList.front(); + if (not igd->isValid()) + continue; + + // Prefer NAT-PMP over PUPNP. + if (preferredIgd_ and igd->getProtocol() != NatProtocolType::NAT_PMP) + continue; + + // Update. + preferredIgd_ = igd; + } + } + + if (preferredIgd_ and preferredIgd_->isValid()) { + JAMI_DBG("Preferred IGD updated to [%s] IGD [%s %s] ", + preferredIgd_->getProtocolName(), + preferredIgd_->getUID().c_str(), + preferredIgd_->toString().c_str()); + } +} + +std::shared_ptr<IGD> +UPnPContext::getPreferredIgd() const +{ + CHECK_VALID_THREAD(); + + return preferredIgd_; +} + +void +UPnPContext::updateMappingList(bool async) +{ + // Run async if requested. + if (async) { + runOnUpnpContextQueue([this] { updateMappingList(false); }); + return; + } + + CHECK_VALID_THREAD(); + + // Update the preferred IGD. + updatePreferredIgd(); + + if (mappingListUpdateTimer_) { + mappingListUpdateTimer_->cancel(); + mappingListUpdateTimer_ = {}; + } + + // Skip if no controller registered. + if (controllerList_.empty()) + return; + + // Cancel the current timer (if any) and re-schedule. + std::shared_ptr<IGD> prefIgd = getPreferredIgd(); + if (not prefIgd) { + JAMI_DBG("UPNP/NAT-PMP enabled, but no valid IGDs available"); + // No valid IGD. Nothing to do. + return; + } + + mappingListUpdateTimer_ = getScheduler()->scheduleIn([this] { updateMappingList(false); }, + MAP_UPDATE_INTERVAL); + + // Process pending requests if any. + processPendingRequests(prefIgd); + + // Make new requests for mappings that failed and have + // the auto-update option enabled. + processMappingWithAutoUpdate(); + + PortType typeArray[2] = {PortType::TCP, PortType::UDP}; + + for (auto idx : {0, 1}) { + auto type = typeArray[idx]; + + MappingStatus status; + getMappingStatus(type, status); + + JAMI_DBG("Mapping status [%s] - overall %i: %i open (%i ready + %i in use), %i pending, %i " + "in-progress, %i failed", + Mapping::getTypeStr(type), + status.sum(), + status.openCount_, + status.readyCount_, + status.openCount_ - status.readyCount_, + status.pendingCount_, + status.inProgressCount_, + status.failedCount_); + + if (status.failedCount_ > 0) { + std::lock_guard<std::mutex> lock(mappingMutex_); + auto const& mappingList = getMappingList(type); + for (auto const& [_, map] : mappingList) { + if (map->getState() == MappingState::FAILED) { + JAMI_DBG("Mapping status [%s] - Available [%s]", + map->toString(true).c_str(), + map->isAvailable() ? "YES" : "NO"); + } + } + } + + int toRequestCount = (int) minOpenPortLimit_[idx] + - (int) (status.readyCount_ + status.inProgressCount_ + + status.pendingCount_); + + // Provision/release mappings accordingly. + if (toRequestCount > 0) { + // Take into account the request in-progress when making + // requests for new mappings. + provisionNewMappings(type, toRequestCount); + } else if (status.readyCount_ > maxOpenPortLimit_[idx]) { + deleteUnneededMappings(type, status.readyCount_ - maxOpenPortLimit_[idx]); + } + } + + // Prune the mapping list if needed + if (protocolList_.at(NatProtocolType::PUPNP)->isReady()) { +#if HAVE_LIBNATPMP + // Dont perform if NAT-PMP is valid. + if (not protocolList_.at(NatProtocolType::NAT_PMP)->isReady()) +#endif + { + pruneMappingList(); + } + } + +#if HAVE_LIBNATPMP + // Renew nat-pmp allocations + if (protocolList_.at(NatProtocolType::NAT_PMP)->isReady()) + renewAllocations(); +#endif +} + +void +UPnPContext::pruneMappingList() +{ + CHECK_VALID_THREAD(); + + MappingStatus status; + getMappingStatus(status); + + // Do not prune the list if there are pending/in-progress requests. + if (status.inProgressCount_ != 0 or status.pendingCount_ != 0) { + return; + } + + auto const& igd = getPreferredIgd(); + if (not igd or igd->getProtocol() != NatProtocolType::PUPNP) { + return; + } + auto protocol = protocolList_.at(NatProtocolType::PUPNP); + + auto remoteMapList = protocol->getMappingsListByDescr(igd, + Mapping::UPNP_MAPPING_DESCRIPTION_PREFIX); + if (remoteMapList.empty()) { + std::lock_guard<std::mutex> lock(mappingMutex_); + if (not getMappingList(PortType::TCP).empty() or getMappingList(PortType::TCP).empty()) { + JAMI_WARN("We have provisionned mappings but the PUPNP IGD returned an empty list!"); + } + } + + pruneUnMatchedMappings(igd, remoteMapList); + pruneUnTrackedMappings(igd, remoteMapList); +} + +void +UPnPContext::pruneUnMatchedMappings(const std::shared_ptr<IGD>& igd, + const std::map<Mapping::key_t, Mapping>& remoteMapList) +{ + // Check/synchronize local mapping list with the list + // returned by the IGD. + + PortType types[2] {PortType::TCP, PortType::UDP}; + + for (auto& type : types) { + // Use a temporary list to avoid processing mappings while holding the lock. + std::list<Mapping::sharedPtr_t> toRemoveList; + { + std::lock_guard<std::mutex> lock(mappingMutex_); + auto& mappingList = getMappingList(type); + for (auto const& [_, map] : mappingList) { + // Only check mappings allocated by UPNP protocol. + if (map->getProtocol() != NatProtocolType::PUPNP) { + continue; + } + // Set mapping as failed if not found in the list + // returned by the IGD. + if (map->getState() == MappingState::OPEN + and remoteMapList.find(map->getMapKey()) == remoteMapList.end()) { + toRemoveList.emplace_back(map); + + JAMI_WARN("Mapping %s (IGD %s) marked as \"OPEN\" but not found in the " + "remote list. Mark as failed!", + map->toString().c_str(), + igd->toString().c_str()); + } + } + } + + for (auto const& map : toRemoveList) { + updateMappingState(map, MappingState::FAILED); + unregisterMapping(map); + } + } +} + +void +UPnPContext::pruneUnTrackedMappings(const std::shared_ptr<IGD>& igd, + const std::map<Mapping::key_t, Mapping>& remoteMapList) +{ + // Use a temporary list to avoid processing mappings while holding the lock. + std::list<Mapping> toRemoveList; + { + std::lock_guard<std::mutex> lock(mappingMutex_); + + for (auto const& [_, map] : remoteMapList) { + // Must has valid IGD pointer and use UPNP protocol. + assert(map.getIgd()); + assert(map.getIgd()->getProtocol() == NatProtocolType::PUPNP); + auto& mappingList = getMappingList(map.getType()); + auto it = mappingList.find(map.getMapKey()); + if (it == mappingList.end()) { + // Not present, request mapping remove. + toRemoveList.emplace_back(std::move(map)); + // Make only few remove requests at once. + if (toRemoveList.size() >= MAX_REQUEST_REMOVE_COUNT) + break; + } + } + } + + // Remove un-tracked mappings. + auto protocol = protocolList_.at(NatProtocolType::PUPNP); + for (auto const& map : toRemoveList) { + protocol->requestMappingRemove(map); + } +} + +void +UPnPContext::pruneMappingsWithInvalidIgds(const std::shared_ptr<IGD>& igd) +{ + CHECK_VALID_THREAD(); + + // Use temporary list to avoid holding the lock while + // processing the mapping list. + std::list<Mapping::sharedPtr_t> toRemoveList; + { + std::lock_guard<std::mutex> lock(mappingMutex_); + + PortType types[2] {PortType::TCP, PortType::UDP}; + for (auto& type : types) { + auto& mappingList = getMappingList(type); + for (auto const& [_, map] : mappingList) { + if (map->getIgd() == igd) + toRemoveList.emplace_back(map); + } + } + } + + for (auto const& map : toRemoveList) { + JAMI_DBG("Remove mapping %s (has an invalid IGD %s [%s])", + map->toString().c_str(), + igd->toString().c_str(), + igd->getProtocolName()); + updateMappingState(map, MappingState::FAILED); + unregisterMapping(map); + } +} + +void +UPnPContext::processPendingRequests(const std::shared_ptr<IGD>& igd) +{ + // This list holds the mappings to be requested. This is + // needed to avoid performing the requests while holding + // the lock. + std::list<Mapping::sharedPtr_t> requestsList; + + // Populate the list of requests to perform. + { + std::lock_guard<std::mutex> lock(mappingMutex_); + PortType typeArray[2] {PortType::TCP, PortType::UDP}; + + for (auto type : typeArray) { + auto& mappingList = getMappingList(type); + for (auto& [_, map] : mappingList) { + if (map->getState() == MappingState::PENDING) { + JAMI_DBG("Send pending request for mapping %s to IGD %s", + map->toString().c_str(), + igd->toString().c_str()); + requestsList.emplace_back(map); + } + } + } + } + + // Process the pending requests. + for (auto const& map : requestsList) { + requestMapping(map); + } +} + +void +UPnPContext::processMappingWithAutoUpdate() +{ + // This list holds the mappings to be requested. This is + // needed to avoid performing the requests while holding + // the lock. + std::list<Mapping::sharedPtr_t> requestsList; + + // Populate the list of requests for mappings with auto-update enabled. + { + std::lock_guard<std::mutex> lock(mappingMutex_); + PortType typeArray[2] {PortType::TCP, PortType::UDP}; + + for (auto type : typeArray) { + auto& mappingList = getMappingList(type); + for (auto const& [_, map] : mappingList) { + if (map->getState() == MappingState::FAILED and map->getAutoUpdate()) { + requestsList.emplace_back(map); + } + } + } + } + + for (auto const& oldMap : requestsList) { + // Request a new mapping if auto-update is enabled. + JAMI_DBG("Mapping %s has auto-update enabled, a new mapping will be requested", + oldMap->toString().c_str()); + + // Reserve a new mapping. + Mapping newMapping(oldMap->getType()); + newMapping.enableAutoUpdate(true); + newMapping.setNotifyCallback(oldMap->getNotifyCallback()); + + auto const& mapPtr = reserveMapping(newMapping); + assert(mapPtr); + + // Release the old one. + oldMap->setAvailable(true); + oldMap->enableAutoUpdate(false); + oldMap->setNotifyCallback(nullptr); + unregisterMapping(oldMap); + } +} + +void +UPnPContext::onIgdUpdated(const std::shared_ptr<IGD>& igd, UpnpIgdEvent event) +{ + assert(igd); + + if (not isValidThread()) { + runOnUpnpContextQueue([this, igd, event] { onIgdUpdated(igd, event); }); + return; + } + + // Reset to start search for a new best IGD. + preferredIgd_.reset(); + + char const* IgdState = event == UpnpIgdEvent::ADDED ? "ADDED" + : event == UpnpIgdEvent::REMOVED ? "REMOVED" + : "INVALID"; + + auto const& igdLocalAddr = igd->getLocalIp(); + auto protocolName = igd->getProtocolName(); + + JAMI_DBG("New event for IGD [%s %s] [%s]: [%s]", + igd->getUID().c_str(), + igd->toString().c_str(), + protocolName, + IgdState); + + // Check if the IGD has valid addresses. + if (not igdLocalAddr) { + JAMI_WARN("[%s] IGD has an invalid local address", protocolName); + return; + } + + if (not igd->getPublicIp()) { + JAMI_WARN("[%s] IGD has an invalid public address", protocolName); + return; + } + + if (knownPublicAddress_ and igd->getPublicIp() != knownPublicAddress_) { + JAMI_WARN("[%s] IGD external address [%s] does not match known public address [%s]." + " The mapped addresses might not be reachable", + protocolName, + igd->getPublicIp().toString().c_str(), + knownPublicAddress_.toString().c_str()); + } + + // The IGD was removed or is invalid. + if (event == UpnpIgdEvent::REMOVED or event == UpnpIgdEvent::INVALID_STATE) { + JAMI_WARN("State of IGD [%s %s] [%s] changed to [%s]. Pruning the mapping list", + igd->getUID().c_str(), + igd->toString().c_str(), + protocolName, + IgdState); + + pruneMappingsWithInvalidIgds(igd); + + std::lock_guard<std::mutex> lock(mappingMutex_); + validIgdList_.erase(igd); + return; + } + + // Update the IGD list. + { + std::lock_guard<std::mutex> lock(mappingMutex_); + auto ret = validIgdList_.emplace(igd); + if (ret.second) { + JAMI_DBG("IGD [%s] on address %s was added. Will process any pending requests", + protocolName, + igdLocalAddr.toString(true, true).c_str()); + } else { + // Already in the list. + JAMI_ERR("IGD [%s] on address %s already in the list", + protocolName, + igdLocalAddr.toString(true, true).c_str()); + return; + } + } + + // Update the provisionned mappings. + updateMappingList(false); +} + +void +UPnPContext::onMappingAdded(const std::shared_ptr<IGD>& igd, const Mapping& mapRes) +{ + CHECK_VALID_THREAD(); + + // Check if we have a pending request for this response. + auto map = getMappingWithKey(mapRes.getMapKey()); + if (not map) { + // We may receive a response for a canceled request. Just ignore it. + JAMI_DBG("Response for mapping %s [IGD %s] [%s] does not have a local match", + mapRes.toString().c_str(), + igd->toString().c_str(), + mapRes.getProtocolName()); + return; + } + + // The mapping request is new and successful. Update. + map->setIgd(igd); + map->setInternalAddress(mapRes.getInternalAddress()); + map->setExternalPort(mapRes.getExternalPort()); + + // Update the state and report to the owner. + updateMappingState(map, MappingState::OPEN); + + JAMI_DBG("Mapping %s (on IGD %s [%s]) successfully performed", + map->toString().c_str(), + igd->toString().c_str(), + map->getProtocolName()); + + // Call setValid() to reset the errors counter. We need + // to reset the counter on each successful response. + igd->setValid(true); +} + +#if HAVE_LIBNATPMP +void +UPnPContext::onMappingRenewed(const std::shared_ptr<IGD>& igd, const Mapping& map) +{ + auto mapPtr = getMappingWithKey(map.getMapKey()); + + if (not mapPtr) { + // We may receive a notification for a canceled request. Ignore it. + JAMI_WARN("Renewed mapping %s from IGD %s [%s] does not have a match in local list", + map.toString().c_str(), + igd->toString().c_str(), + map.getProtocolName()); + return; + } + if (mapPtr->getProtocol() != NatProtocolType::NAT_PMP or not mapPtr->isValid() + or mapPtr->getState() != MappingState::OPEN) { + JAMI_WARN("Renewed mapping %s from IGD %s [%s] is in unexpected state", + mapPtr->toString().c_str(), + igd->toString().c_str(), + mapPtr->getProtocolName()); + return; + } + + mapPtr->setRenewalTime(map.getRenewalTime()); +} +#endif + +void +UPnPContext::requestRemoveMapping(const Mapping::sharedPtr_t& map) +{ + CHECK_VALID_THREAD(); + + if (not map) { + JAMI_ERR("Mapping shared pointer is null!"); + return; + } + + if (not map->isValid()) { + // Silently ignore if the mapping is invalid + return; + } + + auto protocol = protocolList_.at(map->getIgd()->getProtocol()); + protocol->requestMappingRemove(*map); +} + +void +UPnPContext::deleteAllMappings(PortType type) +{ + if (not isValidThread()) { + runOnUpnpContextQueue([this, type] { deleteAllMappings(type); }); + return; + } + + std::lock_guard<std::mutex> lock(mappingMutex_); + auto& mappingList = getMappingList(type); + + for (auto const& [_, map] : mappingList) { + requestRemoveMapping(map); + } +} + +void +UPnPContext::onMappingRemoved(const std::shared_ptr<IGD>& igd, const Mapping& mapRes) +{ + if (not mapRes.isValid()) + return; + + if (not isValidThread()) { + runOnUpnpContextQueue([this, igd, mapRes] { onMappingRemoved(igd, mapRes); }); + return; + } + + auto map = getMappingWithKey(mapRes.getMapKey()); + // Notify the listener. + if (map and map->getNotifyCallback()) + map->getNotifyCallback()(map); +} + +Mapping::sharedPtr_t +UPnPContext::registerMapping(Mapping& map) +{ + if (map.getExternalPort() == 0) { + JAMI_DBG("Port number not set. Will set a random port number"); + auto port = getAvailablePortNumber(map.getType()); + map.setExternalPort(port); + map.setInternalPort(port); + } + + // Newly added mapping must be in pending state by default. + map.setState(MappingState::PENDING); + + Mapping::sharedPtr_t mapPtr; + + { + std::lock_guard<std::mutex> lock(mappingMutex_); + auto& mappingList = getMappingList(map.getType()); + + auto ret = mappingList.emplace(map.getMapKey(), std::make_shared<Mapping>(map)); + if (not ret.second) { + JAMI_WARN("Mapping request for %s already added!", map.toString().c_str()); + return {}; + } + mapPtr = ret.first->second; + assert(mapPtr); + } + + // No available IGD. The pending mapping requests will be processed + // when a IGD becomes available (in onIgdAdded() method). + if (not isReady()) { + JAMI_WARN("No IGD available. Mapping will be requested when an IGD becomes available"); + } else { + requestMapping(mapPtr); + } + + return mapPtr; +} + +std::map<Mapping::key_t, Mapping::sharedPtr_t>::iterator +UPnPContext::unregisterMapping(std::map<Mapping::key_t, Mapping::sharedPtr_t>::iterator it) +{ + assert(it->second); + + CHECK_VALID_THREAD(); + auto descr = it->second->toString(); + auto& mappingList = getMappingList(it->second->getType()); + auto ret = mappingList.erase(it); + + return ret; +} + +void +UPnPContext::unregisterMapping(const Mapping::sharedPtr_t& map) +{ + CHECK_VALID_THREAD(); + + if (not map) { + JAMI_ERR("Mapping pointer is null"); + return; + } + + if (map->getAutoUpdate()) { + // Dont unregister mappings with auto-update enabled. + return; + } + auto& mappingList = getMappingList(map->getType()); + + if (mappingList.erase(map->getMapKey()) == 1) { + JAMI_DBG("Unregistered mapping %s", map->toString().c_str()); + } else { + // The mapping may already be un-registered. Just ignore it. + JAMI_DBG("Mapping %s [%s] does not have a local match", + map->toString().c_str(), + map->getProtocolName()); + } +} + +std::map<Mapping::key_t, Mapping::sharedPtr_t>& +UPnPContext::getMappingList(PortType type) +{ + unsigned typeIdx = type == PortType::TCP ? 0 : 1; + return mappingList_[typeIdx]; +} + +Mapping::sharedPtr_t +UPnPContext::getMappingWithKey(Mapping::key_t key) +{ + std::lock_guard<std::mutex> lock(mappingMutex_); + auto const& mappingList = getMappingList(Mapping::getTypeFromMapKey(key)); + auto it = mappingList.find(key); + if (it == mappingList.end()) + return nullptr; + return it->second; +} + +void +UPnPContext::getMappingStatus(PortType type, MappingStatus& status) +{ + std::lock_guard<std::mutex> lock(mappingMutex_); + auto& mappingList = getMappingList(type); + + for (auto const& [_, map] : mappingList) { + switch (map->getState()) { + case MappingState::PENDING: { + status.pendingCount_++; + break; + } + case MappingState::IN_PROGRESS: { + status.inProgressCount_++; + break; + } + case MappingState::FAILED: { + status.failedCount_++; + break; + } + case MappingState::OPEN: { + status.openCount_++; + if (map->isAvailable()) + status.readyCount_++; + break; + } + + default: + // Must not get here. + assert(false); + break; + } + } +} + +void +UPnPContext::getMappingStatus(MappingStatus& status) +{ + getMappingStatus(PortType::TCP, status); + getMappingStatus(PortType::UDP, status); +} + +void +UPnPContext::onMappingRequestFailed(const Mapping& mapRes) +{ + CHECK_VALID_THREAD(); + + auto const& map = getMappingWithKey(mapRes.getMapKey()); + if (not map) { + // We may receive a response for a removed request. Just ignore it. + JAMI_DBG("Mapping %s [IGD %s] does not have a local match", + mapRes.toString().c_str(), + mapRes.getProtocolName()); + return; + } + + auto igd = map->getIgd(); + if (not igd) { + JAMI_ERR("IGD pointer is null"); + return; + } + + updateMappingState(map, MappingState::FAILED); + unregisterMapping(map); + + JAMI_WARN("Mapping request for %s failed on IGD %s [%s]", + map->toString().c_str(), + igd->toString().c_str(), + igd->getProtocolName()); +} + +void +UPnPContext::updateMappingState(const Mapping::sharedPtr_t& map, MappingState newState, bool notify) +{ + CHECK_VALID_THREAD(); + + assert(map); + + // Ignore if the state did not change. + if (newState == map->getState()) { + JAMI_DBG("Mapping %s already in state %s", map->toString().c_str(), map->getStateStr()); + return; + } + + // Update the state. + map->setState(newState); + + // Notify the listener if set. + if (notify and map->getNotifyCallback()) + map->getNotifyCallback()(map); +} + +#if HAVE_LIBNATPMP +void +UPnPContext::renewAllocations() +{ + CHECK_VALID_THREAD(); + + // Check if the we have valid PMP IGD. + auto pmpProto = protocolList_.at(NatProtocolType::NAT_PMP); + + auto now = sys_clock::now(); + std::vector<Mapping::sharedPtr_t> toRenew; + + for (auto type : {PortType::TCP, PortType::UDP}) { + std::lock_guard<std::mutex> lock(mappingMutex_); + auto mappingList = getMappingList(type); + for (auto const& [_, map] : mappingList) { + if (not map->isValid()) + continue; + if (map->getProtocol() != NatProtocolType::NAT_PMP) + continue; + if (map->getState() != MappingState::OPEN) + continue; + if (now < map->getRenewalTime()) + continue; + + toRenew.emplace_back(map); + } + } + + // Quit if there are no mapping to renew + if (toRenew.empty()) + return; + + for (auto const& map : toRenew) { + pmpProto->requestMappingRenew(*map); + } +} +#endif + +} // namespace upnp +} // namespace jami diff --git a/src/upnp/upnp_context.h b/src/upnp/upnp_context.h new file mode 100644 index 0000000..30d50c0 --- /dev/null +++ b/src/upnp/upnp_context.h @@ -0,0 +1,294 @@ +/* + * Copyright (C) 2004-2023 Savoir-faire Linux Inc. + * + * Author: Stepan Salenikovich <stepan.salenikovich@savoirfairelinux.com> + * Author: Eden Abitbol <eden.abitbol@savoirfairelinux.com> + * Author: Mohamed Chibani <mohamed.chibani@savoirfairelinux.com> + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ + +#pragma once + +#include "protocol/upnp_protocol.h" +#if HAVE_LIBNATPMP +#include "protocol/natpmp/nat_pmp.h" +#endif +#if HAVE_LIBUPNP +#include "protocol/pupnp/pupnp.h" +#endif +#include "protocol/igd.h" + +#include "ip_utils.h" + +#include <opendht/rng.h> +#include <asio/steady_timer.hpp> + +#include <set> +#include <map> +#include <mutex> +#include <memory> +#include <string> +#include <chrono> +#include <random> +#include <atomic> +#include <cstdlib> + +//#include "upnp_thread_util.h" + +using random_device = dht::crypto::random_device; + +using IgdFoundCallback = std::function<void()>; + +namespace jami { +class IpAddr; +} + +namespace jami { +namespace upnp { + +class UPnPContext : public UpnpMappingObserver//, protected UpnpThreadUtil +{ +private: + struct MappingStatus + { + int openCount_ {0}; + int readyCount_ {0}; + int pendingCount_ {0}; + int inProgressCount_ {0}; + int failedCount_ {0}; + + void reset() + { + openCount_ = 0; + readyCount_ = 0; + pendingCount_ = 0; + inProgressCount_ = 0; + failedCount_ = 0; + }; + int sum() { return openCount_ + pendingCount_ + inProgressCount_ + failedCount_; } + }; + +public: + UPnPContext(); + ~UPnPContext(); + + // Retrieve the UPnPContext singleton. + static std::shared_ptr<UPnPContext> getUPnPContext(); + + // Terminate the instance. + void shutdown(); + + // Set the known public address + void setPublicAddress(const IpAddr& addr); + + // Check if there is a valid IGD in the IGD list. + bool isReady() const; + + // Get external Ip of a chosen IGD. + IpAddr getExternalIP() const; + + // Inform the UPnP context that the network status has changed. This clears the list of known + void connectivityChanged(); + + // Returns a shared pointer of the mapping. + Mapping::sharedPtr_t reserveMapping(Mapping& requestedMap); + + // Release an used mapping (make it available for future use). + void releaseMapping(const Mapping& map); + + // Register a controller + void registerController(void* controller); + // Unregister a controller + void unregisterController(void* controller); + + // Generate random port numbers + static uint16_t generateRandomPort(PortType type, bool mustBeEven = false); + +private: + // Initialization + void init(); + + /** + * @brief start the search for IGDs activate the mapping + * list update. + * + */ + void startUpnp(); + + /** + * @brief Clear all IGDs and release/delete current mappings + * + * @param forceRelease If true, also delete mappings with enabled + * auto-update feature. + * + */ + void stopUpnp(bool forceRelease = false); + + void shutdown(std::condition_variable& cv); + + // Create and register a new mapping. + Mapping::sharedPtr_t registerMapping(Mapping& map); + + // Removes the mapping from the list. + std::map<Mapping::key_t, Mapping::sharedPtr_t>::iterator unregisterMapping( + std::map<Mapping::key_t, Mapping::sharedPtr_t>::iterator it); + void unregisterMapping(const Mapping::sharedPtr_t& map); + + // Perform the request on the provided IGD. + void requestMapping(const Mapping::sharedPtr_t& map); + + // Request a mapping remove from the IGD. + void requestRemoveMapping(const Mapping::sharedPtr_t& map); + + // Remove all mappings of the given type. + void deleteAllMappings(PortType type); + + // Update the state and notify the listener + void updateMappingState(const Mapping::sharedPtr_t& map, + MappingState newState, + bool notify = true); + + // Provision ports. + uint16_t getAvailablePortNumber(PortType type); + + // Update preferred IGD + void updatePreferredIgd(); + + // Get preferred IGD + std::shared_ptr<IGD> getPreferredIgd() const; + + // Check and prune the mapping list. Called periodically. + void updateMappingList(bool async); + + // Provision (pre-allocate) the requested number of mappings. + bool provisionNewMappings(PortType type, int portCount); + + // Close unused mappings. + bool deleteUnneededMappings(PortType type, int portCount); + + /** + * Prune the mapping list.To avoid competing with allocation + * requests, the pruning is performed only if there are no + * requests in progress. + */ + void pruneMappingList(); + + /** + * Check if there are allocated mappings from previous instances, + * and try to close them. + * Only done for UPNP protocol. NAT-PMP allocations will expire + * anyway if not renewed. + */ + void pruneUnMatchedMappings(const std::shared_ptr<IGD>& igd, + const std::map<Mapping::key_t, Mapping>& remoteMapList); + + /** + * Check the local mapping list against the list returned by the + * IGD and remove all mappings which do not have a match. + * Only done for UPNP protocol. + */ + void pruneUnTrackedMappings(const std::shared_ptr<IGD>& igd, + const std::map<Mapping::key_t, Mapping>& remoteMapList); + + void pruneMappingsWithInvalidIgds(const std::shared_ptr<IGD>& igd); + + /** + * @brief Get the mapping list + * + * @param type transport type (TCP/UDP) + * @return a reference on the map + * @warning concurrency protection done by the caller + */ + std::map<Mapping::key_t, Mapping::sharedPtr_t>& getMappingList(PortType type); + + // Get the mapping from the key. + Mapping::sharedPtr_t getMappingWithKey(Mapping::key_t key); + + // Get the number of mappings per state. + void getMappingStatus(PortType type, MappingStatus& status); + void getMappingStatus(MappingStatus& status); + +#if HAVE_LIBNATPMP + void renewAllocations(); +#endif + + // Process requests with pending status. + void processPendingRequests(const std::shared_ptr<IGD>& igd); + + // Process mapping with auto-update flag enabled. + void processMappingWithAutoUpdate(); + + // Implementation of UpnpMappingObserver interface. + + // Callback used to report changes in IGD status. + void onIgdUpdated(const std::shared_ptr<IGD>& igd, UpnpIgdEvent event) override; + // Callback used to report add request status. + void onMappingAdded(const std::shared_ptr<IGD>& igd, const Mapping& map) override; + // Callback invoked when a request fails. Reported on failures for both + // new requests and renewal requests (if supported by the the protocol). + void onMappingRequestFailed(const Mapping& map) override; +#if HAVE_LIBNATPMP + // Callback used to report renew request status. + void onMappingRenewed(const std::shared_ptr<IGD>& igd, const Mapping& map) override; +#endif + // Callback used to report remove request status. + void onMappingRemoved(const std::shared_ptr<IGD>& igd, const Mapping& map) override; + +private: + UPnPContext(const UPnPContext&) = delete; + UPnPContext(UPnPContext&&) = delete; + UPnPContext& operator=(UPnPContext&&) = delete; + UPnPContext& operator=(const UPnPContext&) = delete; + + bool started_ {false}; + + // The known public address. The external addresses returned by + // the IGDs will be checked against this address. + IpAddr knownPublicAddress_ {}; + + // Set of registered controllers + std::set<void*> controllerList_; + + // Map of available protocols. + std::map<NatProtocolType, std::shared_ptr<UPnPProtocol>> protocolList_; + + // Port ranges for TCP and UDP (in that order). + std::map<PortType, std::pair<uint16_t, uint16_t>> portRange_ {}; + + // Min open ports limit + int minOpenPortLimit_[2] {4, 8}; + // Max open ports limit + int maxOpenPortLimit_[2] {8, 12}; + + //std::shared_ptr<Task> mappingListUpdateTimer_ {}; + asio::steady_timer mappingListUpdateTimer_;// {}; + + // Current preferred IGD. Can be null if there is no valid IGD. + std::shared_ptr<IGD> preferredIgd_; + + // This mutex must lock only these two members. All other + // members must be accessed only from the UPNP context thread. + std::mutex mutable mappingMutex_; + // List of mappings. + std::map<Mapping::key_t, Mapping::sharedPtr_t> mappingList_[2] {}; + std::set<std::shared_ptr<IGD>> validIgdList_ {}; + + // Shutdown synchronization + bool shutdownComplete_ {false}; +}; + +} // namespace upnp +} // namespace jami diff --git a/src/upnp/upnp_control.cpp b/src/upnp/upnp_control.cpp new file mode 100644 index 0000000..b255617 --- /dev/null +++ b/src/upnp/upnp_control.cpp @@ -0,0 +1,150 @@ +/* + * Copyright (C) 2004-2023 Savoir-faire Linux Inc. + * + * Author: Stepan Salenikovich <stepan.salenikovich@savoirfairelinux.com> + * Author: Eden Abitbol <eden.abitbol@savoirfairelinux.com> + * Author: Mohamed Chibani <mohamed.chibani@savoirfairelinux.com> + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ + +#include "upnp_control.h" + +namespace jami { +namespace upnp { + +Controller::Controller() +{ + try { + upnpContext_ = UPnPContext::getUPnPContext(); + } catch (std::runtime_error& e) { + JAMI_ERR("UPnP context error: %s", e.what()); + } + + assert(upnpContext_); + upnpContext_->registerController(this); + + JAMI_DBG("Controller@%p: Created UPnP Controller session", this); +} + +Controller::~Controller() +{ + JAMI_DBG("Controller@%p: Destroying UPnP Controller session", this); + + releaseAllMappings(); + upnpContext_->unregisterController(this); +} + +void +Controller::setPublicAddress(const IpAddr& addr) +{ + assert(upnpContext_); + + if (addr and addr.getFamily() == AF_INET) { + upnpContext_->setPublicAddress(addr); + } +} + +bool +Controller::isReady() const +{ + assert(upnpContext_); + return upnpContext_->isReady(); +} + +IpAddr +Controller::getExternalIP() const +{ + assert(upnpContext_); + if (upnpContext_->isReady()) { + return upnpContext_->getExternalIP(); + } + return {}; +} + +Mapping::sharedPtr_t +Controller::reserveMapping(uint16_t port, PortType type) +{ + Mapping map(type, port, port); + return reserveMapping(map); +} + +Mapping::sharedPtr_t +Controller::reserveMapping(Mapping& requestedMap) +{ + assert(upnpContext_); + + // Try to get a provisioned port + auto mapRes = upnpContext_->reserveMapping(requestedMap); + if (mapRes) + addLocalMap(*mapRes); + return mapRes; +} + +void +Controller::releaseMapping(const Mapping& map) +{ + assert(upnpContext_); + + removeLocalMap(map); + return upnpContext_->releaseMapping(map); +} + +void +Controller::releaseAllMappings() +{ + assert(upnpContext_); + + std::lock_guard<std::mutex> lk(mapListMutex_); + for (auto const& [_, map] : mappingList_) { + upnpContext_->releaseMapping(map); + } + mappingList_.clear(); +} + +void +Controller::addLocalMap(const Mapping& map) +{ + if (map.getMapKey()) { + std::lock_guard<std::mutex> lock(mapListMutex_); + auto ret = mappingList_.emplace(map.getMapKey(), map); + if (not ret.second) { + JAMI_WARN("Mapping request for %s already in the list!", map.toString().c_str()); + } + } +} + +bool +Controller::removeLocalMap(const Mapping& map) +{ + assert(upnpContext_); + + std::lock_guard<std::mutex> lk(mapListMutex_); + if (mappingList_.erase(map.getMapKey()) != 1) { + JAMI_ERR("Failed to remove mapping %s from local list", map.getTypeStr()); + return false; + } + + return true; +} + +uint16_t +Controller::generateRandomPort(PortType type) +{ + return UPnPContext::generateRandomPort(type); +} + +} // namespace upnp +} // namespace jami diff --git a/src/upnp/upnp_control.h b/src/upnp/upnp_control.h new file mode 100644 index 0000000..183b4fb --- /dev/null +++ b/src/upnp/upnp_control.h @@ -0,0 +1,78 @@ +/* + * Copyright (C) 2004-2023 Savoir-faire Linux Inc. + * + * Author: Stepan Salenikovich <stepan.salenikovich@savoirfairelinux.com> + * Author: Eden Abitbol <eden.abitbol@savoirfairelinux.com> + * Author: Mohamed Chibani <mohamed.chibani@savoirfairelinux.com> + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ + +#pragma once + +#include "upnp_context.h" +#include "ip_utils.h" + +#include <memory> +#include <chrono> + +namespace jami { +class IpAddr; +} + +namespace jami { +namespace upnp { + +class UPnPContext; + +class Controller +{ +public: + Controller(); + ~Controller(); + + // Set known public address + void setPublicAddress(const IpAddr& addr); + // Checks if a valid IGD is available. + bool isReady() const; + // Gets the external ip of the first valid IGD in the list. + IpAddr getExternalIP() const; + + // Request port mapping. + // Returns a shared pointer on the allocated mapping. The shared + // pointer may point to nothing on failure. + Mapping::sharedPtr_t reserveMapping(Mapping& map); + Mapping::sharedPtr_t reserveMapping(uint16_t port, PortType type); + + // Remove port mapping. + void releaseMapping(const Mapping& map); + static uint16_t generateRandomPort(PortType); + +private: + // Adds a mapping locally to the list. + void addLocalMap(const Mapping& map); + // Removes a mapping from the local list. + bool removeLocalMap(const Mapping& map); + // Removes all mappings of the given type. + void releaseAllMappings(); + + std::shared_ptr<UPnPContext> upnpContext_; + + mutable std::mutex mapListMutex_; + std::map<Mapping::key_t, Mapping> mappingList_; +}; + +} // namespace upnp +} // namespace jami diff --git a/src/upnp/upnp_thread_util.h b/src/upnp/upnp_thread_util.h new file mode 100644 index 0000000..10d454a --- /dev/null +++ b/src/upnp/upnp_thread_util.h @@ -0,0 +1,35 @@ +#pragma once + +#include <thread> + +// This macro is used to validate that a code is executed from the expected +// thread. It's useful to detect unexpected race on data members. +#define CHECK_VALID_THREAD() \ + if (not isValidThread()) \ + JAMI_ERR() << "The calling thread " << getCurrentThread() \ + << " is not the expected thread: " << threadId_; + +namespace jami { +namespace upnp { + +class UpnpThreadUtil +{ +protected: + std::thread::id getCurrentThread() const { return std::this_thread::get_id(); } + + bool isValidThread() const { return threadId_ == getCurrentThread(); } + + // Upnp context execution queue (same as manager's scheduler) + // Helpers to run tasks on upnp context queue. + static ScheduledExecutor* getScheduler() { return &Manager::instance().scheduler(); } + template<typename Callback> + static void runOnUpnpContextQueue(Callback&& cb) + { + getScheduler()->run([cb = std::forward<Callback>(cb)]() mutable { cb(); }); + } + + std::thread::id threadId_; +}; + +} // namespace upnp +} // namespace jami -- GitLab