/*
 *  Copyright (C) 2017-2019 Savoir-faire Linux Inc.
 *
 *  Author: Guillaume Roguez <guillaume.roguez@savoirfairelinux.com>
 *  Author: Sébastien Blin <sebastien.blin@savoirfairelinux.com>
 *
 *  This program is free software; you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation; either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with this program; if not, write to the Free Software
 *  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301 USA.
 */

#include "peer_connection.h"

#include "data_transfer.h"
#include "manager.h"
#include "jamidht/jamiaccount.h"
#include "string_utils.h"
#include "channel.h"
#include "turn_transport.h"
#include "security/tls_session.h"

#include <algorithm>
#include <atomic>
#include <cerrno>
#include <chrono>
#include <cstdio>
#include <future>
#include <istream>
#include <ostream>
#include <stdexcept>
#include <system_error>
#include <thread>
#include <vector>

#include <unistd.h>

#ifdef _WIN32
#include <winsock2.h>
#include <ws2tcpip.h>
#else
#include <sys/select.h>
#endif

#ifndef _MSC_VER
#include <sys/time.h>
#endif

Adrien Béraud's avatar
Adrien Béraud committed
53
namespace jami {
54

55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86
/**
 * Check the peer certificate chain presented during a GnuTLS handshake and
 * load it into @p crt.
 *
 * Only X.509 certificates are accepted. The chain is rejected when
 * gnutls_certificate_verify_peers2() fails, when it reports a signature
 * failure, or when the peer sent no certificate at all.
 *
 * @param session GnuTLS session currently performing the handshake.
 * @param crt     Output: receives the peer certificate chain on success.
 * @return GNUTLS_E_SUCCESS on success, GNUTLS_E_CERTIFICATE_ERROR otherwise.
 */
int
init_crt(gnutls_session_t session, dht::crypto::Certificate& crt)
{
    // Support only x509 format
    if (gnutls_certificate_type_get(session) != GNUTLS_CRT_X509)
        return GNUTLS_E_CERTIFICATE_ERROR;

    // Only GNUTLS_CERT_SIGNATURE_FAILURE is treated as fatal here; the other
    // status bits are not inspected by this function.
    unsigned int status = 0;
    auto ret = gnutls_certificate_verify_peers2(session, &status);
    if (ret < 0 or (status & GNUTLS_CERT_SIGNATURE_FAILURE) != 0)
        return GNUTLS_E_CERTIFICATE_ERROR;

    unsigned int cert_list_size = 0;
    auto cert_list = gnutls_certificate_get_peers(session, &cert_list_size);
    // An empty chain gives nothing to identify the peer with: reject it
    // instead of building an empty certificate object.
    if (cert_list == nullptr or cert_list_size == 0)
        return GNUTLS_E_CERTIFICATE_ERROR;

    // Wrap each raw certificate buffer as a [begin, end) byte range.
    std::vector<std::pair<uint8_t*, uint8_t*>> crt_data;
    crt_data.reserve(cert_list_size);
    for (unsigned i = 0; i < cert_list_size; i++)
        crt_data.emplace_back(cert_list[i].data, cert_list[i].data + cert_list[i].size);
    crt = dht::crypto::Certificate {crt_data};

    return GNUTLS_E_SUCCESS;
}

87 88 89 90 91 92 93 94 95 96 97 98
/// Shorthand for a scoped std::mutex guard.
using lock = std::lock_guard<std::mutex>;

/// Size, in bytes, of the scratch buffers used by IO operations.
static constexpr std::size_t IO_BUFFER_SIZE = 3000;

//==============================================================================

class TlsTurnEndpoint::Impl
{
public:
    static constexpr auto TLS_TIMEOUT = std::chrono::seconds(20);

    Impl(ConnectedTurnTransport& tr,
99 100
         std::function<bool(const dht::crypto::Certificate&)>&& cert_check)
        : turn {tr}, peerCertificateCheckFunc {std::move(cert_check)} {}
101 102 103 104 105 106 107 108 109 110 111

    ~Impl();

    // TLS callbacks
    int verifyCertificate(gnutls_session_t);
    void onTlsStateChange(tls::TlsSessionState);
    void onTlsRxData(std::vector<uint8_t>&&);
    void onTlsCertificatesUpdate(const gnutls_datum_t*, const gnutls_datum_t*, unsigned int);

    std::unique_ptr<tls::TlsSession> tls;
    ConnectedTurnTransport& turn;
112
    std::function<bool(const dht::crypto::Certificate&)> peerCertificateCheckFunc;
113 114 115
    dht::crypto::Certificate peerCertificate;
};

116 117 118
// Declaration at namespace scope is necessary (until C++17)
constexpr std::chrono::seconds TlsTurnEndpoint::Impl::TLS_TIMEOUT;

119
// Nothing to release by hand: every member cleans up after itself.
TlsTurnEndpoint::Impl::~Impl() = default;
121 122 123 124

/// GnuTLS verification hook: accept the peer only if init_crt() accepts its
/// chain and peerCertificateCheckFunc approves the resulting certificate,
/// which is then stored in peerCertificate.
int
TlsTurnEndpoint::Impl::verifyCertificate(gnutls_session_t session)
{
    dht::crypto::Certificate candidate;
    const auto status = init_crt(session, candidate);
    if (status != GNUTLS_E_SUCCESS)
        return status;

    if (not peerCertificateCheckFunc(candidate))
        return GNUTLS_E_CERTIFICATE_ERROR;

    peerCertificate = std::move(candidate);
    return GNUTLS_E_SUCCESS;
}

void
TlsTurnEndpoint::Impl::onTlsStateChange(tls::TlsSessionState /*state*/)
{
    // State transitions are not acted upon by this endpoint.
}
140 141 142 143

void
TlsTurnEndpoint::Impl::onTlsRxData(UNUSED std::vector<uint8_t>&& buf)
{
Adrien Béraud's avatar
Adrien Béraud committed
144
    JAMI_ERR() << "[TLS-TURN] rx " << buf.size() << " (but not implemented)";
145 146 147 148 149 150 151 152 153 154 155
}

void
TlsTurnEndpoint::Impl::onTlsCertificatesUpdate(UNUSED const gnutls_datum_t* local_raw,
                                               UNUSED const gnutls_datum_t* remote_raw,
                                               UNUSED unsigned int remote_count)
{}

TlsTurnEndpoint::TlsTurnEndpoint(ConnectedTurnTransport& turn_ep,
                                 const Identity& local_identity,
                                 const std::shared_future<tls::DhParams>& dh_params,
156 157
                                 std::function<bool(const dht::crypto::Certificate&)>&& cert_check)
    : pimpl_ { std::make_unique<Impl>(turn_ep, std::move(cert_check)) }
158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180
{
    // Add TLS over TURN
    tls::TlsSession::TlsSessionCallbacks tls_cbs = {
        /*.onStateChange = */[this](tls::TlsSessionState state){ pimpl_->onTlsStateChange(state); },
        /*.onRxData = */[this](std::vector<uint8_t>&& buf){ pimpl_->onTlsRxData(std::move(buf)); },
        /*.onCertificatesUpdate = */[this](const gnutls_datum_t* l, const gnutls_datum_t* r,
                                           unsigned int n){ pimpl_->onTlsCertificatesUpdate(l, r, n); },
        /*.verifyCertificate = */[this](gnutls_session_t session){ return pimpl_->verifyCertificate(session); }
    };
    tls::TlsParams tls_param = {
        /*.ca_list = */     "",
        /*.peer_ca = */     nullptr,
        /*.cert = */        local_identity.second,
        /*.cert_key = */    local_identity.first,
        /*.dh_params = */   dh_params,
        /*.timeout = */     Impl::TLS_TIMEOUT,
        /*.cert_check = */  nullptr,
    };
    pimpl_->tls = std::make_unique<tls::TlsSession>(turn_ep, tls_param, tls_cbs);
}

// Defined here, where Impl is a complete type, so that pimpl_ can be destroyed.
TlsTurnEndpoint::~TlsTurnEndpoint() {}

181 182 183 184 185 186
/// Forward the shutdown request to the TLS session.
void
TlsTurnEndpoint::shutdown()
{
    auto& session = *pimpl_->tls;
    session.shutdown();
}

187 188 189 190 191 192 193
/// @return the TLS session's own answer to isInitiator().
bool
TlsTurnEndpoint::isInitiator() const
{
    auto& session = *pimpl_->tls;
    return session.isInitiator();
}

void
194
TlsTurnEndpoint::waitForReady(const std::chrono::steady_clock::duration& timeout)
195
{
196
    pimpl_->tls->waitForReady(timeout);
197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216
}

/// @return the maximum payload size reported by the TLS session.
int
TlsTurnEndpoint::maxPayload() const
{
    auto& session = *pimpl_->tls;
    return session.maxPayload();
}

/// Forward a read of up to @p len bytes into @p buf to the TLS session.
std::size_t
TlsTurnEndpoint::read(ValueType* buf, std::size_t len, std::error_code& ec)
{
    auto& session = *pimpl_->tls;
    return session.read(buf, len, ec);
}

/// Forward a write of @p len bytes from @p buf to the TLS session.
std::size_t
TlsTurnEndpoint::write(const ValueType* buf, std::size_t len, std::error_code& ec)
{
    auto& session = *pimpl_->tls;
    return session.write(buf, len, ec);
}

217
/// @return the certificate accepted by verifyCertificate (default-constructed
///         until verification has succeeded).
const dht::crypto::Certificate&
TlsTurnEndpoint::peerCertificate() const
{
    const auto& impl = *pimpl_;
    return impl.peerCertificate;
}

223 224 225 226 227 228
/// Delegate waiting for incoming data to the TLS session.
int
TlsTurnEndpoint::waitForData(unsigned ms_timeout, std::error_code& ec) const
{
    auto& session = *pimpl_->tls;
    return session.waitForData(ms_timeout, ec);
}

229 230 231 232
//==============================================================================

TcpSocketEndpoint::TcpSocketEndpoint(const IpAddr& addr)
    : addr_ {addr}
233
    , sock_{ static_cast<int>(::socket(addr.getFamily(), SOCK_STREAM, 0)) }
234 235 236 237 238 239 240 241 242 243
{
    if (sock_ < 0)
        std::system_error(errno, std::generic_category());
    auto bound = ip_utils::getAnyHostAddr(addr.getFamily());
    if (::bind(sock_, bound, bound.getLength()) < 0)
        std::system_error(errno, std::generic_category());
}

/// Release the socket descriptor with the platform-specific call.
TcpSocketEndpoint::~TcpSocketEndpoint()
{
#ifdef _MSC_VER
    ::closesocket(sock_);
#else
    ::close(sock_);
#endif
}

void
252
TcpSocketEndpoint::connect(const std::chrono::steady_clock::duration& timeout)
253
{
254 255 256 257 258
    int ms =  std::chrono::duration_cast<std::chrono::milliseconds>(timeout).count();
    setsockopt(sock_, SOL_SOCKET, SO_RCVTIMEO, (const char *)&ms, sizeof(ms));
    setsockopt(sock_, SOL_SOCKET, SO_SNDTIMEO, (const char *)&ms, sizeof(ms));

    if ((::connect(sock_, addr_, addr_.getLength())) < 0)
259 260 261
        throw std::system_error(errno, std::generic_category());
}

262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278
/**
 * Wait up to @p ms_timeout milliseconds for the socket to become readable.
 *
 * @return 1 if data can be read, 0 on timeout, -1 on error (@p ec is set from errno).
 */
int
TcpSocketEndpoint::waitForData(unsigned ms_timeout, std::error_code& ec) const
{
    for (;;) {
        // select() may modify the timeout, so rebuild it on every attempt.
        struct timeval tv;
        tv.tv_sec = ms_timeout / 1000;
        tv.tv_usec = (ms_timeout % 1000) * 1000;

        fd_set read_fds;
        FD_ZERO(&read_fds);
        FD_SET(sock_, &read_fds);

        auto res = ::select(sock_ + 1, &read_fds, nullptr, nullptr, &tv);
        if (res < 0) {
            // A signal interrupting the wait is not a socket failure: retry.
            if (errno == EINTR)
                continue;
            break;
        }
        if (res == 0)
            return 0; // timeout
        if (FD_ISSET(sock_, &read_fds))
            return 1;
    }

    ec.assign(errno, std::generic_category());
    return -1;
}

/**
 * Receive up to @p len bytes from the socket into @p buf.
 *
 * @return number of bytes received; 0 on error, with @p ec set from errno.
 */
std::size_t
TcpSocketEndpoint::read(ValueType* buf, std::size_t len, std::error_code& ec)
{
    // recv() takes a void* on POSIX systems but a char* on mingw, hence the cast.
    const auto received = ::recv(sock_, reinterpret_cast<char*>(buf), len, 0);
    if (received >= 0) {
        ec.clear();
        return received;
    }
    ec.assign(errno, std::generic_category());
    return 0;
}

/**
 * Send up to @p len bytes from @p buf on the socket.
 *
 * @return number of bytes sent; 0 on error, with @p ec set from errno.
 */
std::size_t
TcpSocketEndpoint::write(const ValueType* buf, std::size_t len, std::error_code& ec)
{
    int flags = 0;
#ifdef MSG_NOSIGNAL
    // A connection closed by the peer must be reported as EPIPE through `ec`
    // rather than by raising SIGPIPE, whose default action ends the process.
    flags |= MSG_NOSIGNAL;
#endif
    // send() takes a const void* on POSIX systems but a const char* on mingw, hence the cast.
    auto res = ::send(sock_, reinterpret_cast<const char*>(buf), len, flags);
    if (res < 0)
        ec.assign(errno, std::generic_category());
    else
        ec.clear();
    return (res >= 0) ? res : 0;
}

//==============================================================================

313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336
IceSocketEndpoint::IceSocketEndpoint(std::shared_ptr<IceTransport> ice, bool isSender)
    : ice_(std::move(ice)), iceIsSender(isSender)
{}

/// Stop the ICE transport, if one is held, when the endpoint goes away.
IceSocketEndpoint::~IceSocketEndpoint()
{
    if (ice_)
        ice_->stop();
}

/// Stop the ICE transport; no-op when none is held.
void
IceSocketEndpoint::shutdown()
{
    if (not ice_)
        return;
    ice_->stop();
}

int
IceSocketEndpoint::waitForData(unsigned ms_timeout, std::error_code& ec) const
{
    if (ice_) {
        if (!ice_->isRunning()) return -1;
337
        return iceIsSender ? ice_->isDataAvailable(1) : ice_->waitForData(1, ms_timeout, ec);
338 339 340 341 342 343 344 345 346 347
    }
    return -1;
}

/**
 * Read up to @p len bytes from the ICE transport into @p buf.
 *
 * @return number of bytes read. Returns 0 when the transport is not running,
 *         on a receive error (@p ec set from errno), after an exception
 *         (logged), or when there is no transport (@p ec = not_connected).
 */
std::size_t
IceSocketEndpoint::read(ValueType* buf, std::size_t len, std::error_code& ec)
{
    if (!ice_) {
        // -1 through std::size_t would wrap to SIZE_MAX and look like a huge
        // successful read: report the failure through ec instead.
        ec = std::make_error_code(std::errc::not_connected);
        return 0;
    }
    if (!ice_->isRunning())
        return 0;
    try {
        auto res = ice_->recvfrom(1, reinterpret_cast<char*>(buf), len);
        if (res < 0)
            ec.assign(errno, std::generic_category());
        else
            ec.clear();
        return (res >= 0) ? res : 0;
    } catch (const std::exception& e) {
        JAMI_ERR("IceSocketEndpoint::read exception: %s", e.what());
    }
    return 0;
}

/**
 * Send @p len bytes from @p buf over the ICE transport.
 *
 * @return number of bytes sent. Returns 0 when the transport is not running,
 *         on a send error (@p ec set from errno), or when there is no
 *         transport (@p ec = not_connected).
 */
std::size_t
IceSocketEndpoint::write(const ValueType* buf, std::size_t len, std::error_code& ec)
{
    if (!ice_) {
        // -1 through std::size_t would wrap to SIZE_MAX: report via ec instead.
        ec = std::make_error_code(std::errc::not_connected);
        return 0;
    }
    if (!ice_->isRunning())
        return 0;
    // Keep send()'s own return type rather than narrowing it to int.
    const auto res = ice_->send(0, reinterpret_cast<const unsigned char*>(buf), len);
    if (res < 0)
        ec.assign(errno, std::generic_category());
    else
        ec.clear();
    return (res >= 0) ? res : 0;
}

//==============================================================================

381 382 383 384 385
class TlsSocketEndpoint::Impl
{
public:
    static constexpr auto TLS_TIMEOUT = std::chrono::seconds(20);

386
    Impl(AbstractSocketEndpoint& ep, const dht::crypto::Certificate& peer_cert)
387 388
        : tr {ep}, peerCertificate {peer_cert} {}

389 390 391 392
    Impl(AbstractSocketEndpoint &ep,
         std::function<bool(const dht::crypto::Certificate &)> &&cert_check)
        : tr{ep}, peerCertificateCheckFunc{std::make_unique<std::function<bool(const dht::crypto::Certificate &)>>(std::move(cert_check))}, peerCertificate {null_cert} {}

393 394 395 396 397 398 399
    // TLS callbacks
    int verifyCertificate(gnutls_session_t);
    void onTlsStateChange(tls::TlsSessionState);
    void onTlsRxData(std::vector<uint8_t>&&);
    void onTlsCertificatesUpdate(const gnutls_datum_t*, const gnutls_datum_t*, unsigned int);

    std::unique_ptr<tls::TlsSession> tls;
400
    AbstractSocketEndpoint& tr;
401
    const dht::crypto::Certificate& peerCertificate;
402 403
    dht::crypto::Certificate null_cert;
    std::unique_ptr<std::function<bool(const dht::crypto::Certificate &)>> peerCertificateCheckFunc;
404 405
};

406 407 408
// Declaration at namespace scope is necessary (until C++17)
constexpr std::chrono::seconds TlsSocketEndpoint::Impl::TLS_TIMEOUT;

409 410 411
/// GnuTLS verification hook. The chain must first pass init_crt(). Then, if a
/// predicate was given, it must approve the certificate, which is kept in
/// null_cert. Otherwise the packed certificate must equal the expected
/// peerCertificate.
int
TlsSocketEndpoint::Impl::verifyCertificate(gnutls_session_t session)
{
    dht::crypto::Certificate candidate;
    const auto status = init_crt(session, candidate);
    if (status != GNUTLS_E_SUCCESS)
        return status;

    const bool accepted = peerCertificateCheckFunc
        ? (*peerCertificateCheckFunc)(candidate)
        : candidate.getPacked() == peerCertificate.getPacked();
    if (not accepted) {
        JAMI_ERR() << "[TLS-SOCKET] Unexpected peer certificate";
        return GNUTLS_E_CERTIFICATE_ERROR;
    }

    if (peerCertificateCheckFunc)
        null_cert = std::move(candidate);
    return GNUTLS_E_SUCCESS;
}

void
TlsSocketEndpoint::Impl::onTlsStateChange(UNUSED tls::TlsSessionState state)
{}

void
TlsSocketEndpoint::Impl::onTlsRxData(std::vector<uint8_t>&& /*buf*/)
{
    // Intentionally empty: data delivered through this callback is ignored.
}

void
TlsSocketEndpoint::Impl::onTlsCertificatesUpdate(UNUSED const gnutls_datum_t* local_raw,
                                                UNUSED const gnutls_datum_t* remote_raw,
                                                UNUSED unsigned int remote_count)
{}

446
/**
 * Build a TLS session over @p tr that accepts only a peer presenting exactly
 * @p peer_cert.
 *
 * @param tr              endpoint carrying the TLS records
 * @param local_identity  local key (first) and certificate (second)
 * @param dh_params       Diffie-Hellman parameters for the session
 * @param peer_cert       the only peer certificate that will be accepted
 */
TlsSocketEndpoint::TlsSocketEndpoint(AbstractSocketEndpoint& tr,
                                     const Identity& local_identity,
                                     const std::shared_future<tls::DhParams>& dh_params,
                                     const dht::crypto::Certificate& peer_cert)
    : pimpl_ {std::make_unique<Impl>(tr, peer_cert)}
{
    // Every TLS session event is routed to the private implementation.
    tls::TlsSession::TlsSessionCallbacks callbacks = {
        /*.onStateChange = */ [this](tls::TlsSessionState state) { pimpl_->onTlsStateChange(state); },
        /*.onRxData = */ [this](std::vector<uint8_t>&& buf) { pimpl_->onTlsRxData(std::move(buf)); },
        /*.onCertificatesUpdate = */
        [this](const gnutls_datum_t* l, const gnutls_datum_t* r, unsigned int n) {
            pimpl_->onTlsCertificatesUpdate(l, r, n);
        },
        /*.verifyCertificate = */
        [this](gnutls_session_t session) { return pimpl_->verifyCertificate(session); }
    };

    // Peer authentication is done by the verifyCertificate callback, so no
    // cert_check is configured here.
    tls::TlsParams params = {
        /*.ca_list = */     "",
        /*.peer_ca = */     nullptr,
        /*.cert = */        local_identity.second,
        /*.cert_key = */    local_identity.first,
        /*.dh_params = */   dh_params,
        /*.timeout = */     Impl::TLS_TIMEOUT,
        /*.cert_check = */  nullptr,
    };

    pimpl_->tls = std::make_unique<tls::TlsSession>(tr, params, callbacks);
}

472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498
/**
 * Build a TLS session over @p tr that accepts any peer whose certificate
 * satisfies @p cert_check.
 *
 * @param tr              endpoint carrying the TLS records
 * @param local_identity  local key (first) and certificate (second)
 * @param dh_params       Diffie-Hellman parameters for the session
 * @param cert_check      predicate deciding whether the peer certificate is acceptable
 */
TlsSocketEndpoint::TlsSocketEndpoint(AbstractSocketEndpoint& tr,
                                     const Identity& local_identity,
                                     const std::shared_future<tls::DhParams>& dh_params,
                                     std::function<bool(const dht::crypto::Certificate&)>&& cert_check)
    : pimpl_ {std::make_unique<Impl>(tr, std::move(cert_check))}
{
    // Every TLS session event is routed to the private implementation.
    tls::TlsSession::TlsSessionCallbacks callbacks = {
        /*.onStateChange = */ [this](tls::TlsSessionState state) { pimpl_->onTlsStateChange(state); },
        /*.onRxData = */ [this](std::vector<uint8_t>&& buf) { pimpl_->onTlsRxData(std::move(buf)); },
        /*.onCertificatesUpdate = */
        [this](const gnutls_datum_t* l, const gnutls_datum_t* r, unsigned int n) {
            pimpl_->onTlsCertificatesUpdate(l, r, n);
        },
        /*.verifyCertificate = */
        [this](gnutls_session_t session) { return pimpl_->verifyCertificate(session); }
    };

    // Peer authentication is done by the verifyCertificate callback, so no
    // cert_check is configured here.
    tls::TlsParams params = {
        /*.ca_list = */     "",
        /*.peer_ca = */     nullptr,
        /*.cert = */        local_identity.second,
        /*.cert_key = */    local_identity.first,
        /*.dh_params = */   dh_params,
        /*.timeout = */     Impl::TLS_TIMEOUT,
        /*.cert_check = */  nullptr,
    };

    pimpl_->tls = std::make_unique<tls::TlsSession>(tr, params, callbacks);
}


499 500
// Defined here, where Impl is a complete type, so that pimpl_ can be destroyed.
TlsSocketEndpoint::~TlsSocketEndpoint() {}

501 502 503 504 505 506 507 508 509 510 511 512
/// @return the TLS session's own answer to isInitiator().
bool
TlsSocketEndpoint::isInitiator() const
{
    auto& session = *pimpl_->tls;
    return session.isInitiator();
}

/// @return the maximum payload size reported by the TLS session.
int
TlsSocketEndpoint::maxPayload() const
{
    auto& session = *pimpl_->tls;
    return session.maxPayload();
}

513 514 515 516 517 518 519 520 521 522 523 524 525
/// Forward a read of up to @p len bytes into @p buf to the TLS session.
std::size_t
TlsSocketEndpoint::read(ValueType* buf, std::size_t len, std::error_code& ec)
{
    auto& session = *pimpl_->tls;
    return session.read(buf, len, ec);
}

/// Forward a write of @p len bytes from @p buf to the TLS session.
std::size_t
TlsSocketEndpoint::write(const ValueType* buf, std::size_t len, std::error_code& ec)
{
    auto& session = *pimpl_->tls;
    return session.write(buf, len, ec);
}

void
526
TlsSocketEndpoint::waitForReady(const std::chrono::steady_clock::duration& timeout)
527
{
528
    pimpl_->tls->waitForReady(timeout);
529 530
}

531 532 533 534 535 536
/// Delegate waiting for incoming data to the TLS session.
int
TlsSocketEndpoint::waitForData(unsigned ms_timeout, std::error_code& ec) const
{
    auto& session = *pimpl_->tls;
    return session.waitForData(ms_timeout, ec);
}

537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584
//==============================================================================

// following namespace prevents an ODR violation with definitions in p2p.cpp
namespace
{

/// Kinds of control orders sent to the peer connection event loop.
enum class CtrlMsgType
{
    STOP,
    ATTACH_INPUT,
    ATTACH_OUTPUT,
};

/// Base of the control orders queued on PeerConnectionImpl::ctrlChannel.
struct CtrlMsg
{
    virtual CtrlMsgType type() const = 0;
    virtual ~CtrlMsg() = default;
};

/// Makes the event loop return.
struct StopCtrlMsg final : CtrlMsg
{
    explicit StopCtrlMsg() {}
    CtrlMsgType type() const override { return CtrlMsgType::STOP; }
};

/// Attaches a stream to the connection's input list.
struct AttachInputCtrlMsg final : CtrlMsg
{
    explicit AttachInputCtrlMsg(const std::shared_ptr<Stream>& stream)
        : stream {stream} {}
    CtrlMsgType type() const override { return CtrlMsgType::ATTACH_INPUT; }
    // Not const, so the receiver's std::move() really transfers the pointer
    // instead of silently copying it (an extra atomic refcount round trip).
    std::shared_ptr<Stream> stream;
};

/// Attaches a stream to the connection's output list.
struct AttachOutputCtrlMsg final : CtrlMsg
{
    explicit AttachOutputCtrlMsg(const std::shared_ptr<Stream>& stream)
        : stream {stream} {}
    CtrlMsgType type() const override { return CtrlMsgType::ATTACH_OUTPUT; }
    // Not const, for the same reason as AttachInputCtrlMsg::stream.
    std::shared_ptr<Stream> stream;
};

} // namespace <anonymous>

//==============================================================================

class PeerConnection::PeerConnectionImpl
{
public:
    /**
     * @param done     called if the event loop ends because of an exception
     * @param peer_uri remote peer identifier (logged, returned by getPeerUri())
     * @param endpoint socket the attached streams are exchanged over
     *
     * The event loop is launched with std::async while eventLoopFut_ is
     * initialized, i.e. as the last step of member initialization.
     */
    PeerConnectionImpl(std::function<void()>&& done,
                       const std::string& peer_uri,
                       std::unique_ptr<SocketType> endpoint)
        : peer_uri {peer_uri}
        , endpoint_ {std::move(endpoint)}
        , eventLoopFut_ {std::async(std::launch::async, [this, done = std::move(done)] {
            try {
                eventLoop();
            } catch (const std::exception& e) {
                JAMI_ERR() << "[CNX] peer connection event loop failure: " << e.what();
                done();
            }
        })}
    {}

    /// Ask the event loop to stop and shut the endpoint down. Destroying
    /// eventLoopFut_ (a std::async future) then waits for the loop thread.
    ~PeerConnectionImpl()
    {
        ctrlChannel << std::make_unique<StopCtrlMsg>();
        endpoint_->shutdown();
    }

    /// @return true if a stream with @p id is attached as input or output.
    /// NOTE(review): inputs_/outputs_ are modified by the event loop thread
    /// without a lock; calling this from another thread races with it.
    bool hasStreamWithId(const DRing::DataTransferId& id)
    {
        const auto hasId = [&id](const std::shared_ptr<Stream>& str) {
            return str && str->getId() == id;
        };
        return std::any_of(inputs_.begin(), inputs_.end(), hasId)
            or std::any_of(outputs_.begin(), outputs_.end(), hasId);
    }

    const std::string peer_uri;
    Channel<std::unique_ptr<CtrlMsg>> ctrlChannel;

private:
    std::unique_ptr<SocketType> endpoint_;
    std::vector<std::shared_ptr<Stream>> inputs_;
    std::vector<std::shared_ptr<Stream>> outputs_;
    std::vector<uint8_t> bufferPool_; // data received before any stream was attached
    // Must remain the LAST data member. Its initializer starts the event-loop
    // thread, which uses every member above, and its destructor (run first)
    // waits for that thread before those members are destroyed.
    std::future<void> eventLoopFut_;

    void eventLoop();

    /// Run @p callable on the first stream of @p stream_list. When it returns
    /// false (EOF) or throws, the stream is closed and removed from the list.
    template <typename L, typename C>
    void handle_stream_list(L& stream_list, const C& callable)
    {
        if (stream_list.empty())
            return;
        auto item = std::begin(stream_list);
        auto& stream = *item;
        if (!stream) {
            // Nothing to log or close for an empty slot: just drop it.
            stream_list.erase(item);
            return;
        }
        try {
            if (callable(stream))
                return;
            JAMI_DBG() << "EOF on stream #" << stream->getId();
        } catch (const std::system_error& e) {
            JAMI_WARN() << "Stream #" << stream->getId()
                        << " IO failed with code = " << e.code();
        } catch (const std::exception& e) {
            JAMI_ERR() << "Unexpected exception during IO with stream #"
                       << stream->getId()
                       << ": " << e.what();
        }
        stream->close();
        stream_list.erase(item);
    }
};

void
PeerConnection::PeerConnectionImpl::eventLoop()
{
Adrien Béraud's avatar
Adrien Béraud committed
655
    JAMI_DBG() << "[CNX] Peer connection to " << peer_uri << " ready";
656 657 658 659 660
    while (true) {
        // Process ctrl orders first
        while (true) {
            std::unique_ptr<CtrlMsg> msg;
            if (outputs_.empty() and inputs_.empty()) {
661 662 663 664 665 666
                if (!ctrlChannel.empty()) {
                    msg = ctrlChannel.receive();
                } else {
                    std::error_code ec;
                    if (endpoint_->waitForData(100, ec) > 0) {
                        std::vector<uint8_t> buf(IO_BUFFER_SIZE);
Adrien Béraud's avatar
Adrien Béraud committed
667
                        JAMI_DBG("A good buffer arrived before any input or output attachment");
668
                        auto size = endpoint_->read(buf, ec);
669 670
                        if (ec)
                            throw std::system_error(ec);
671 672 673 674
                        // If it's a good read, we should store the buffer somewhere
                        // and give it to the next input or output.
                        if (size < IO_BUFFER_SIZE)
                            bufferPool_.insert(bufferPool_.end(), buf.begin(), buf.begin() + size);
675 676 677
                    }
                    break;
                }
678 679 680 681 682 683 684 685 686
            } else if (!ctrlChannel.empty()) {
                msg = ctrlChannel.receive();
            } else
                break;

            switch (msg->type()) {
                case CtrlMsgType::ATTACH_INPUT:
                {
                    auto& input_msg = static_cast<AttachInputCtrlMsg&>(*msg);
687
                    inputs_.emplace_back(std::move(input_msg.stream));
688 689 690 691 692 693
                }
                break;

                case CtrlMsgType::ATTACH_OUTPUT:
                {
                    auto& output_msg = static_cast<AttachOutputCtrlMsg&>(*msg);
694
                    outputs_.emplace_back(std::move(output_msg.stream));
695 696 697 698
                }
                break;

                case CtrlMsgType::STOP:
699
                  return;
700

Adrien Béraud's avatar
Adrien Béraud committed
701
                default: JAMI_ERR("BUG: got unhandled control msg!");  break;
702 703 704 705
            }
        }

        // Then handles IO streams
706
        std::vector<uint8_t> buf;
707
        std::error_code ec;
708 709 710

        bool sleep = true;

Guillaume Roguez's avatar
Guillaume Roguez committed
711
        // sending loop
712
        handle_stream_list(inputs_, [&] (auto& stream) {
713
                if (!stream) return false;
714 715
                buf.resize(IO_BUFFER_SIZE);
                if (stream->read(buf)) {
Guillaume Roguez's avatar
Guillaume Roguez committed
716
                    if (not buf.empty()) {
717 718 719 720
                      endpoint_->write(buf, ec);
                      if (ec)
                        throw std::system_error(ec);
                      sleep = false;
Guillaume Roguez's avatar
Guillaume Roguez committed
721
                    }
722 723
                } else {
                    // EOF on outgoing stream => finished
724
                    return false;
725
                }
726
                if (!bufferPool_.empty()) {
727 728
                  stream->write(bufferPool_);
                  bufferPool_.clear();
729
                } else if (endpoint_->waitForData(0, ec) > 0) {
730 731 732 733 734
                  buf.resize(IO_BUFFER_SIZE);
                  endpoint_->read(buf, ec);
                  if (ec)
                    throw std::system_error(ec);
                  return stream->write(buf);
735 736 737
                } else if (ec)
                    throw std::system_error(ec);
                return true;
738
            });
739

Guillaume Roguez's avatar
Guillaume Roguez committed
740
        // receiving loop
741
        handle_stream_list(outputs_, [&] (auto& stream) {
742
                if (!stream) return false;
743
                buf.resize(IO_BUFFER_SIZE);
Guillaume Roguez's avatar
Guillaume Roguez committed
744 745 746
                auto eof = stream->read(buf);
                // if eof we let a chance to send a reply before leaving
                if (not buf.empty()) {
747 748 749 750
                    endpoint_->write(buf, ec);
                    if (ec)
                        throw std::system_error(ec);
                }
Guillaume Roguez's avatar
Guillaume Roguez committed
751 752
                if (not eof)
                    return false;
753

754 755 756 757
                if (!bufferPool_.empty()) {
                    stream->write(bufferPool_);
                    bufferPool_.clear();
                } else if (endpoint_->waitForData(0, ec) > 0) {
758 759 760
                  buf.resize(IO_BUFFER_SIZE);
                  endpoint_->read(buf, ec);
                  if (ec)
761
                    throw std::system_error(ec);
762 763 764 765
                  sleep = false;
                  return stream->write(buf);
                } else if (ec)
                  throw std::system_error(ec);
766
                return true;
767
            });
768 769 770

        if (sleep)
            std::this_thread::sleep_for(std::chrono::milliseconds(100));
771 772 773 774 775
    }
}

//==============================================================================

776
PeerConnection::PeerConnection(std::function<void()>&& done,
777
                               const std::string& peer_uri,
778
                               std::unique_ptr<GenericSocket<uint8_t>> endpoint)
779
    : pimpl_(std::make_unique<PeerConnectionImpl>(std::move(done), peer_uri, std::move(endpoint)))
780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796
{}

// Defined here, where PeerConnectionImpl is complete; destroying it stops the event loop.
PeerConnection::~PeerConnection() = default;

/// Queue @p stream for attachment as an input; the event loop thread picks it up.
void
PeerConnection::attachInputStream(const std::shared_ptr<Stream>& stream)
{
    auto& channel = pimpl_->ctrlChannel;
    channel << std::make_unique<AttachInputCtrlMsg>(stream);
}

/// Queue @p stream for attachment as an output; the event loop thread picks it up.
void
PeerConnection::attachOutputStream(const std::shared_ptr<Stream>& stream)
{
    auto& channel = pimpl_->ctrlChannel;
    channel << std::make_unique<AttachOutputCtrlMsg>(stream);
}

797 798 799 800 801 802
/// @return true if a stream with @p id is currently attached (input or output).
bool
PeerConnection::hasStreamWithId(const DRing::DataTransferId& id)
{
    auto& impl = *pimpl_;
    return impl.hasStreamWithId(id);
}

803 804 805 806 807 808
/// @return a copy of the URI of the remote peer given at construction.
std::string
PeerConnection::getPeerUri() const
{
    const auto& uri = pimpl_->peer_uri;
    return uri;
}

Adrien Béraud's avatar
Adrien Béraud committed
809
} // namespace jami