From bbfcd57451d5c672dbc125266a6968f3db39ded0 Mon Sep 17 00:00:00 2001
From: Guillaume Roguez <guillaume.roguez@savoirfairelinux.com>
Date: Mon, 20 Nov 2017 12:44:52 -0500
Subject: [PATCH] TurnTransport: better IO api and fixes

* API additions:
  - peerAddresses
  - readlinefrom
  - writelineto
* API changes:
  - recvfrom: char*/length version
  - sendto: per-peer call, not longer a map
* Add more API documentation
* Max TURN buffer size changed to 4096 bytes
* Better IO buffer management with peers
* Fix auth data (was temporary buffer pushing garbage values to server)
* Turn tests modified for API changes

Change-Id: I0bffe114301e8cb1e2f2e37d7a0eb5ba67f38c61
Reviewed-by: Olivier Soldano <olivier.soldano@savoirfairelinux.com>
---
 src/turn_transport.cpp  | 179 ++++++++++++++++++++++++++++++++++------
 src/turn_transport.h    |  48 +++++++++--
 test/turn/test_TURN.cpp |  44 ++++++----
 3 files changed, 222 insertions(+), 49 deletions(-)

diff --git a/src/turn_transport.cpp b/src/turn_transport.cpp
index 42f0b7d4fb..fbe666c4d3 100644
--- a/src/turn_transport.cpp
+++ b/src/turn_transport.cpp
@@ -23,6 +23,7 @@
 #include "logger.h"
 #include "ip_utils.h"
 #include "sip/sip_utils.h"
+#include "map_utils.h"
 
 #include <pjnath.h>
 #include <pjlib-util.h>
@@ -35,25 +36,105 @@
 #include <vector>
 #include <iterator>
 #include <mutex>
+#include <sstream>
+#include <limits>
+#include <map>
+#include <condition_variable>
 
 namespace ring {
 
-enum class RelayState {
+using MutexGuard = std::lock_guard<std::mutex>;
+using MutexLock = std::unique_lock<std::mutex>;
+
+enum class RelayState
+{
     NONE,
     READY,
     DOWN,
 };
 
-class TurnTransportPimpl {
+class PeerChannel
+{
+public:
+    PeerChannel() {}
+    ~PeerChannel() {
+        MutexGuard lk {mutex_};
+        stop_ = true;
+        cv_.notify_all();
+    }
+
+    PeerChannel(PeerChannel&&o) {
+        MutexGuard lk {o.mutex_};
+        stream_ = std::move(o.stream_);
+    }
+    PeerChannel& operator =(PeerChannel&& o) {
+        std::lock(mutex_, o.mutex_);
+        MutexGuard lk1 {mutex_, std::adopt_lock};
+        MutexGuard lk2 {o.mutex_, std::adopt_lock};
+        stream_  = std::move(o.stream_);
+        return *this;
+    }
+
+    void operator <<(const std::string& data) {
+        MutexGuard lk {mutex_};
+        stream_.clear();
+        stream_ << data;
+        cv_.notify_one();
+    }
+
+    void read(std::vector<char>& output) {
+        MutexLock lk {mutex_};
+        cv_.wait(lk, [&, this]{
+                stream_.read(&output[0], output.size());
+                return stream_.gcount() > 0 or stop_;
+            });
+        output.resize(stop_ ? 0 : stream_.gcount());
+    }
+
+    std::vector<char> readline() {
+        MutexLock lk {mutex_};
+        std::vector<char> result(3000);
+        cv_.wait(lk, [&, this] {
+                if (stop_)
+                    return true;
+                stream_.getline(&result[0], 3000);
+                if (stream_) {
+                    result.resize(stream_.gcount());
+                    return result.size() > 0;
+                }
+                return false;
+            });
+        if (stop_)
+            return {};
+        return result;
+    }
+
+private:
+    PeerChannel(const PeerChannel&o) = delete;
+    PeerChannel& operator =(const PeerChannel& o) = delete;
+    std::mutex mutex_ {};
+    std::condition_variable cv_ {};
+    std::stringstream stream_ {};
+    bool stop_ {false};
+
+    friend void operator <<(std::vector<char>&, PeerChannel&);
+};
+
+class TurnTransportPimpl
+{
 public:
     TurnTransportPimpl() = default;
     ~TurnTransportPimpl();
 
     void onTurnState(pj_turn_state_t old_state, pj_turn_state_t new_state);
-    void onRxData(uint8_t* pkt, unsigned pkt_len, const pj_sockaddr_t* peer_addr, unsigned addr_len);
+    void onRxData(const uint8_t* pkt, unsigned pkt_len, const pj_sockaddr_t* peer_addr, unsigned addr_len);
     void onPeerConnection(pj_uint32_t conn_id, const pj_sockaddr_t* peer_addr, unsigned addr_len, pj_status_t status);
     void ioJob();
 
+    std::mutex apiMutex_;
+
+    std::map<IpAddr, PeerChannel> peerChannels_;
+
     TurnTransportParams settings;
     pj_caching_pool poolCache {};
     pj_pool_t* pool {nullptr};
@@ -63,9 +144,6 @@ public:
     IpAddr peerRelayAddr; // address where peers should connect to
     IpAddr mappedAddr;
 
-    std::map<IpAddr, std::vector<char>> streams;
-    std::mutex streamsMutex;
-
     std::atomic<RelayState> state {RelayState::NONE};
     std::atomic_bool ioJobQuit {false};
     std::thread ioWorker;
@@ -96,18 +174,26 @@ TurnTransportPimpl::onTurnState(pj_turn_state_t old_state, pj_turn_state_t new_s
     } else if (old_state <= PJ_TURN_STATE_READY and new_state > PJ_TURN_STATE_READY) {
         RING_WARN("TURN server disconnected (%s)", pj_turn_state_name(new_state));
         state = RelayState::DOWN;
+        MutexGuard lk {apiMutex_};
+        peerChannels_.clear();
     }
 }
 
 void
-TurnTransportPimpl::onRxData(uint8_t* pkt, unsigned pkt_len,
+TurnTransportPimpl::onRxData(const uint8_t* pkt, unsigned pkt_len,
                              const pj_sockaddr_t* addr, unsigned addr_len)
 {
-    IpAddr peer_addr ( *static_cast<const pj_sockaddr*>(addr), addr_len );
+    IpAddr peer_addr (*static_cast<const pj_sockaddr*>(addr), addr_len);
+
+    decltype(peerChannels_)::iterator channel_it;
+    {
+        MutexGuard lk {apiMutex_};
+        channel_it = peerChannels_.find(peer_addr);
+        if (channel_it == std::end(peerChannels_))
+            return;
+    }
 
-    std::lock_guard<std::mutex> lk {streamsMutex};
-    auto& vec = streams[peer_addr];
-    vec.insert(vec.cend(), pkt, pkt + pkt_len);
+    (channel_it->second) << std::string(reinterpret_cast<const char*>(pkt), pkt_len);
 }
 
 void
@@ -120,8 +206,11 @@ TurnTransportPimpl::onPeerConnection(pj_uint32_t conn_id,
         RING_DBG() << "Received connection attempt from " << peer_addr.toString(true, true)
                    << ", id=" << std::hex << conn_id;
         pj_turn_connect_peer(relay, conn_id, addr, addr_len);
-        std::lock_guard<std::mutex> lk {streamsMutex};
-        streams[peer_addr].clear();
+
+        {
+            MutexGuard lk {apiMutex_};
+            peerChannels_.emplace(peer_addr, PeerChannel {});
+        }
     }
 
     if (settings.onPeerConnection)
@@ -157,7 +246,8 @@ private:
 };
 
 template <class Callable, class... Args>
-inline void PjsipCall(Callable& func, Args... args)
+inline void
+PjsipCall(Callable& func, Args... args)
 {
     auto status = func(args...);
     if (status != PJ_SUCCESS)
@@ -165,7 +255,8 @@ inline void PjsipCall(Callable& func, Args... args)
 }
 
 template <class Callable, class... Args>
-inline auto PjsipCallReturn(const Callable& func, Args... args) -> decltype(func(args...))
+inline auto
+PjsipCallReturn(const Callable& func, Args... args) -> decltype(func(args...))
 {
     auto res = func(args...);
     if (!res)
@@ -178,6 +269,8 @@ inline auto PjsipCallReturn(const Callable& func, Args... args) -> decltype(func
 TurnTransport::TurnTransport(const TurnTransportParams& params)
     : pimpl_ {new TurnTransportPimpl}
 {
+    sip_utils::register_thread();
+
     auto server = params.server;
     if (!server.getPort())
         server.setPort(PJ_STUN_PORT);
@@ -244,10 +337,10 @@ TurnTransport::TurnTransport(const TurnTransportParams& params)
     pj_stun_auth_cred cred;
     pj_bzero(&cred, sizeof(cred));
     cred.type = PJ_STUN_AUTH_CRED_STATIC;
-    pj_cstr(&cred.data.static_cred.realm, params.realm.c_str());
-    pj_cstr(&cred.data.static_cred.username, params.username.c_str());
+    pj_cstr(&cred.data.static_cred.realm, pimpl_->settings.realm.c_str());
+    pj_cstr(&cred.data.static_cred.username, pimpl_->settings.username.c_str());
     cred.data.static_cred.data_type = PJ_STUN_PASSWD_PLAIN;
-    pj_cstr(&cred.data.static_cred.data, params.password.c_str());
+    pj_cstr(&cred.data.static_cred.data, pimpl_->settings.password.c_str());
 
     pimpl_->relayAddr = pj_strdup3(pimpl_->pool, server.toString().c_str());
 
@@ -267,6 +360,9 @@ TurnTransport::permitPeer(const IpAddr& addr)
     if (addr.isUnspecified())
         throw std::invalid_argument("invalid peer address");
 
+    if (addr.getFamily() != pimpl_->peerRelayAddr.getFamily())
+        throw std::invalid_argument("mismatching peer address family");
+
     PjsipCall(pj_turn_sock_set_perm, pimpl_->relay, 1, addr.pjPtr(), 1);
 }
 
@@ -297,10 +393,11 @@ TurnTransport::mappedAddr() const
 }
 
 bool
-TurnTransport::sendto(const IpAddr& peer, const std::vector<char>& buffer)
+TurnTransport::sendto(const IpAddr& peer, const char* const buffer, std::size_t length)
 {
+    sip_utils::register_thread();
     auto status = pj_turn_sock_sendto(pimpl_->relay,
-                                      reinterpret_cast<const pj_uint8_t*>(buffer.data()), buffer.size(),
+                                      reinterpret_cast<const pj_uint8_t*>(buffer), length,
                                       peer.pjPtr(), peer.getLength());
     if (status != PJ_SUCCESS && status != PJ_EPENDING)
         throw PjsipError(status);
@@ -308,12 +405,46 @@ TurnTransport::sendto(const IpAddr& peer, const std::vector<char>& buffer)
     return status == PJ_SUCCESS;
 }
 
+bool
+TurnTransport::sendto(const IpAddr& peer, const std::vector<char>& buffer)
+{
+    return sendto(peer, &buffer[0], buffer.size());
+}
+
+bool
+TurnTransport::writelineto(const IpAddr& peer, const char* const buffer, std::size_t length)
+{
+    if (sendto(peer, buffer, length))
+        return sendto(peer, "\n", 1);
+    return false;
+}
+
+void
+TurnTransport::recvfrom(const IpAddr& peer, std::vector<char>& result)
+{
+    if (result.empty())
+        throw std::runtime_error("TurnTransport::recvfrom() called with an empty output buffer");
+
+    MutexLock lk {pimpl_->apiMutex_};
+    auto& channel = pimpl_->peerChannels_.at(peer);
+    lk.unlock();
+    channel.read(result);
+}
+
 void
-TurnTransport::recvfrom(std::map<IpAddr, std::vector<char>>& streams)
+TurnTransport::readlinefrom(const IpAddr& peer, std::vector<char>& result)
+{
+    MutexLock lk {pimpl_->apiMutex_};
+    auto& channel = pimpl_->peerChannels_.at(peer);
+    lk.unlock();
+    result = channel.readline();
+}
+
+std::vector<IpAddr>
+TurnTransport::peerAddresses() const
 {
-    std::lock_guard<std::mutex> lk {pimpl_->streamsMutex};
-    streams = std::move(pimpl_->streams);
-    pimpl_->streams.clear();
+    MutexLock lk {pimpl_->apiMutex_};
+    return map_utils::extractKeys(pimpl_->peerChannels_);
 }
 
 } // namespace ring
diff --git a/src/turn_transport.h b/src/turn_transport.h
index 044dc7dece..4b59568eb3 100644
--- a/src/turn_transport.h
+++ b/src/turn_transport.h
@@ -43,12 +43,11 @@ struct TurnTransportParams {
     uint32_t connectionId {0};
     std::function<void(uint32_t conn_id, const IpAddr& peer_addr, bool success)> onPeerConnection;
 
-    std::size_t maxPacketSize {3000}; ///< size of one "logical" packet
+    std::size_t maxPacketSize {4096}; ///< size of one "logical" packet
 };
 
 class TurnTransport {
 public:
-    ///
     /// Constructs a TurnTransport connected by TCP to given server.
     ///
     /// Throw std::invalid_argument of peer address is invalid.
@@ -61,7 +60,6 @@ public:
 
     ~TurnTransport();
 
-    ///
     /// Wait for successful connection on the TURN server.
     ///
     /// TurnTransport constructor connects asynchronously on the TURN server.
@@ -69,12 +67,18 @@ public:
     ///
     void waitServerReady();
 
+    /// \return true if the TURN server is connected and ready to accept peers.
     bool isReady() const;
 
+    /// \return socket address (IP/port) where peers should connect to before doing IO with this client.
     const IpAddr& peerRelayAddr() const;
+
+    /// \return public address of this client as seen by the TURN server.
     const IpAddr& mappedAddr() const;
 
-    ///
+    /// \return a vector of connected peer addresses
+    std::vector<IpAddr> peerAddresses() const;
+
     /// Gives server access permission to given peer by its address.
     ///
     /// Throw std::invalid_argument of peer address is invalid.
@@ -88,18 +92,46 @@ public:
     ///
     void permitPeer(const IpAddr& addr);
 
+    /// Collect pending data from a given peer.
+    ///
+    /// Data are read from given /a peer incoming buffer until EOF or /a data size() is reached.
+    /// /a data is resized with exact number of characters read.
+    /// If /a peer is not connected this function raise an exception.
+    /// If /a peer exists but no data are available this method blocks until TURN deconnection
+    /// or at first incoming character.
     ///
-    /// Collect pending data.
+    /// \param [in] peer target peer address where data are read
+    /// \param [in,out] pre-dimensionned character vector to write incoming data
+    /// \exception std::out_of_range /a peer is not connected yet
     ///
-    void recvfrom(std::map<IpAddr, std::vector<char>>& streams);
+    void recvfrom(const IpAddr& peer, std::vector<char>& data);
 
+    /// Work as recvfrom but stop on first '\n' character found.
+    /// If such character isn't found, stop at /a data vector size.
     ///
-    /// Send data to a given peer through the TURN tunnel.
+    void readlinefrom(const IpAddr& peer, std::vector<char>& data);
+
+    /// Send data to given peer through the TURN tunnel.
+    ///
+    /// This method blocks until all given characters in /a data are sent to the given /a peer.
+    /// If /a peer is not connected this function raise an exception.
+    ///
+    /// \param [in] peer target peer address where data are read
+    /// \param [in,out] pre-dimensionned character vector to write incoming data
+    /// \exception std::out_of_range /a peer is not connected yet
     ///
     bool sendto(const IpAddr& peer, const std::vector<char>& data);
 
+    /// Works as sendto() vector version but accept a simple char array.
+    ///
+    bool sendto(const IpAddr& peer, const char* const buffer, std::size_t length);
+
+    /// Works as sendto() char array but happend a '\n' character at the end of sent data.
+    ///
+    bool writelineto(const IpAddr& peer, const char* const buffer, std::size_t length);
+
 public:
-    // Move semantic
+    // Move semantic only, not copiable
     TurnTransport(TurnTransport&&) = default;
     TurnTransport& operator=(TurnTransport&&) = default;
 
diff --git a/test/turn/test_TURN.cpp b/test/turn/test_TURN.cpp
index 8f7494f3e9..1d7c071cfe 100644
--- a/test/turn/test_TURN.cpp
+++ b/test/turn/test_TURN.cpp
@@ -71,6 +71,14 @@ public:
         return pkt;
     }
 
+    IpAddr address() const {
+        struct sockaddr addr;
+        socklen_t addrlen;
+        if (::getsockname(sock_, &addr, &addrlen) < 0)
+            throw std::system_error(errno, std::system_category());
+        return IpAddr {addr};
+    }
+
 private:
     int sock_ {-1};
 };
@@ -92,26 +100,25 @@ test_TURN::testSimpleConnection()
 
     // Permit myself
     turn.permitPeer(turn.mappedAddr());
-    sock.connect(turn.peerRelayAddr());
+    std::this_thread::sleep_for(std::chrono::seconds(2));
 
-    std::string test_data = "Hello, World!";
-    sock.send(test_data);
+    sock.connect(turn.peerRelayAddr());
 
     std::this_thread::sleep_for(std::chrono::seconds(1));
+    auto peers = turn.peerAddresses();
+    CPPUNIT_ASSERT(peers.size() == 1);
+    auto remotePeer = peers[0];
 
-    std::map<IpAddr, std::vector<char>> streams;
-    turn.recvfrom(streams);
-    CPPUNIT_ASSERT(streams.size() == 1);
-
-    auto peer_addr = std::begin(streams)->first;
-    const auto& vec = std::begin(streams)->second;
-    CPPUNIT_ASSERT(std::string(std::begin(vec), std::end(vec)) == test_data);
+    // Peer send data
+    std::string test_data = "Hello, World!";
+    sock.send(test_data);
 
-    turn.recvfrom(streams);
-    CPPUNIT_ASSERT(streams.size() == 0);
+    // Client read
+    std::vector<char> data(1000);
+    turn.recvfrom(remotePeer, data);
+    CPPUNIT_ASSERT(std::string(std::begin(data), std::end(data)) == test_data);
 
-    turn.sendto(peer_addr, std::vector<char>{1, 2, 3, 4});
-    std::this_thread::sleep_for(std::chrono::seconds(1));
+    turn.sendto(remotePeer, std::vector<char>{1, 2, 3, 4});
 
     auto res = sock.recv(1000);
     CPPUNIT_ASSERT(res.size() == 4);
@@ -121,16 +128,19 @@ test_TURN::testSimpleConnection()
     // This code higly load the network and can be long to execute.
     // Only kept for manual testing purpose.
     std::vector<char> big(100000);
+    std::size_t count = 1000;
     using clock = std::chrono::high_resolution_clock;
 
     auto t1 = clock::now();
-    sock.send(big);
+    auto i = count;
+    while (i--)
+        sock.send(big);
     auto t2 = clock::now();
 
     auto duration = std::chrono::duration_cast<std::chrono::nanoseconds>(t2-t1).count();
     std::cout << "T= " << duration << "ns"
-              << ", V= " << (8000. * big.size() / duration) << "Mb/s"
-              << " / " << (1000. * big.size() / duration) << "MB/s"
+              << ", V= " << (8000. * count * big.size() / duration) << "Mb/s"
+              << " / " << (1000. * count * big.size() / duration) << "MB/s"
               << '\n';
     std::this_thread::sleep_for(std::chrono::seconds(5));
 #endif
-- 
GitLab