From 7950c1d3fe62aea363bc6e28c81725f0c370d53a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?S=C3=A9bastien=20Blin?=
 <sebastien.blin@savoirfairelinux.com>
Date: Tue, 25 Aug 2020 14:18:15 -0400
Subject: [PATCH] filetransfer: structures should be correctly deleted at the
 end of a transfer

add a onStateChange callback to delete channeled structures on finished states.

Change-Id: I7670cc655719029806a7a887d48197dc2f488651
---
 src/data_transfer.cpp               |  14 +++
 src/data_transfer.h                 |   2 +
 src/ftp_server.cpp                  |   2 +
 src/ftp_server.h                    |  11 +++
 src/jamidht/channeled_transfers.cpp |   9 +-
 src/jamidht/channeled_transfers.h   |   7 +-
 src/jamidht/p2p.cpp                 | 140 ++++++++++++++++++----------
 src/peer_connection.cpp             |  17 +++-
 src/peer_connection.h               |   5 +
 9 files changed, 151 insertions(+), 56 deletions(-)

diff --git a/src/data_transfer.cpp b/src/data_transfer.cpp
index 0ac7441efe..5312cb4d72 100644
--- a/src/data_transfer.cpp
+++ b/src/data_transfer.cpp
@@ -110,12 +110,15 @@ public:
 
     virtual void cancel() {}
 
+    void setOnStateChangedCb(const OnStateChangedCb& cb) override;
+
 protected:
     mutable std::mutex infoMutex_;
     mutable DRing::DataTransferInfo info_;
     mutable std::atomic_bool started_ {false};
     std::atomic_bool wasStarted_ {false};
     InternalCompletionCb internalCompletionCb_ {};
+    OnStateChangedCb stateChangedCb_ {};
 };
 
 void
@@ -125,6 +128,8 @@ DataTransfer::emit(DRing::DataTransferEventCode code) const
         std::lock_guard<std::mutex> lk {infoMutex_};
         info_.lastEvent = code;
     }
+    if (stateChangedCb_)
+        stateChangedCb_(id, code);
     if (internalCompletionCb_)
         return; // VCard transfer is just for the daemon
     runOnMainThread([id = id, code]() {
@@ -132,6 +137,12 @@ DataTransfer::emit(DRing::DataTransferEventCode code) const
     });
 }
 
+void
+DataTransfer::setOnStateChangedCb(const OnStateChangedCb& cb)
+{
+    stateChangedCb_ = std::move(cb);
+}
+
 //==============================================================================
 
 /**
@@ -453,6 +464,8 @@ SubOutgoingFileTransfer::emit(DRing::DataTransferEventCode code) const
         std::lock_guard<std::mutex> lk {infoMutex_};
         info_.lastEvent = code;
     }
+    if (stateChangedCb_)
+        stateChangedCb_(id, code);
     metaInfo_->updateInfo(info_);
     if (code == DRing::DataTransferEventCode::wait_peer_acceptance) {
         timeoutThread_ = std::unique_ptr<std::thread>(new std::thread([this]() {
@@ -489,6 +502,7 @@ public:
                                                                      peer_uri,
                                                                      internalCompletionCb_,
                                                                      metaInfo_);
+        newTransfer->setOnStateChangedCb(stateChangedCb_);
         subtransfer_.emplace_back(newTransfer);
         newTransfer->start();
         return newTransfer;
diff --git a/src/data_transfer.h b/src/data_transfer.h
index 9d02441baf..0e4f760bab 100644
--- a/src/data_transfer.h
+++ b/src/data_transfer.h
@@ -36,6 +36,8 @@ struct IncomingFileInfo
 };
 
 typedef std::function<void(const std::string&)> InternalCompletionCb;
+typedef std::function<void(const DRing::DataTransferId&, const DRing::DataTransferEventCode&)>
+    OnStateChangedCb;
 
 /// Front-end to data transfer service
 class DataTransferFacade
diff --git a/src/ftp_server.cpp b/src/ftp_server.cpp
index df4283fb7e..a44fcba24f 100644
--- a/src/ftp_server.cpp
+++ b/src/ftp_server.cpp
@@ -85,6 +85,8 @@ FtpServer::startNewFile()
         JAMI_DBG() << "[FTP] transfer aborted by client";
         closed_ = true; // send NOK msg at next read()
     } else {
+        if (tmpOnStateChangedCb_)
+            out_.stream->setOnStateChangedCb(std::move(tmpOnStateChangedCb_));
         go_ = true;
     }
 
diff --git a/src/ftp_server.h b/src/ftp_server.h
index 17e3734bf8..8269570d6b 100644
--- a/src/ftp_server.h
+++ b/src/ftp_server.h
@@ -46,6 +46,16 @@ public:
     void close() noexcept override;
 
     void setOnRecv(RecvCb&& cb) { onRecvCb_ = cb; }
+    void setOnStateChangedCb(const OnStateChangedCb& cb)
+    {
+        // If out_ is not attached, store the callback
+        // inside a temporary object. Will be linked when out_.stream
+        // will be attached
+        if (out_.stream)
+            out_.stream->setOnStateChangedCb(std::move(cb));
+        else
+            tmpOnStateChangedCb_ = std::move(cb);
+    }
 
 private:
     bool parseStream(const std::vector<uint8_t>&);
@@ -77,6 +87,7 @@ private:
 
     RecvCb onRecvCb_ {};
     InternalCompletionCb cb_ {};
+    OnStateChangedCb tmpOnStateChangedCb_ {};
 };
 
 } // namespace jami
diff --git a/src/jamidht/channeled_transfers.cpp b/src/jamidht/channeled_transfers.cpp
index 904d8304d9..f0da0a279b 100644
--- a/src/jamidht/channeled_transfers.cpp
+++ b/src/jamidht/channeled_transfers.cpp
@@ -29,8 +29,10 @@
 
 namespace jami {
 
-ChanneledOutgoingTransfer::ChanneledOutgoingTransfer(const std::shared_ptr<ChannelSocket>& channel)
+ChanneledOutgoingTransfer::ChanneledOutgoingTransfer(const std::shared_ptr<ChannelSocket>& channel,
+                                                     OnStateChangedCb&& cb)
     : channel_(channel)
+    , stateChangedCb_(cb)
 {}
 
 ChanneledOutgoingTransfer::~ChanneledOutgoingTransfer()
@@ -67,10 +69,12 @@ ChanneledOutgoingTransfer::linkTransfer(const std::shared_ptr<Stream>& file)
                 c->write(data.data(), data.size(), ec);
             }
         });
+    file_->setOnStateChangedCb(stateChangedCb_);
 }
 
 ChanneledIncomingTransfer::ChanneledIncomingTransfer(const std::shared_ptr<ChannelSocket>& channel,
-                                                     const std::shared_ptr<FtpServer>& ftp)
+                                                     const std::shared_ptr<FtpServer>& ftp,
+                                                     OnStateChangedCb&& cb)
     : ftp_(ftp)
     , channel_(channel)
 {
@@ -88,6 +92,7 @@ ChanneledIncomingTransfer::ChanneledIncomingTransfer(const std::shared_ptr<Chann
             c->write(data.data(), data.size(), ec);
         }
     });
+    ftp_->setOnStateChangedCb(cb);
 }
 
 ChanneledIncomingTransfer::~ChanneledIncomingTransfer()
diff --git a/src/jamidht/channeled_transfers.h b/src/jamidht/channeled_transfers.h
index cf38350f5f..1054dc20b6 100644
--- a/src/jamidht/channeled_transfers.h
+++ b/src/jamidht/channeled_transfers.h
@@ -24,6 +24,7 @@
 #include <memory>
 
 #include "dring/datatransfer_interface.h"
+#include "data_transfer.h"
 
 namespace jami {
 
@@ -34,12 +35,13 @@ class FtpServer;
 class ChanneledOutgoingTransfer
 {
 public:
-    ChanneledOutgoingTransfer(const std::shared_ptr<ChannelSocket>& channel);
+    ChanneledOutgoingTransfer(const std::shared_ptr<ChannelSocket>& channel, OnStateChangedCb&& cb);
     ~ChanneledOutgoingTransfer();
     void linkTransfer(const std::shared_ptr<Stream>& file);
     std::string peer() const;
 
 private:
+    OnStateChangedCb stateChangedCb_ {};
     std::shared_ptr<ChannelSocket> channel_ {};
     std::shared_ptr<Stream> file_;
 };
@@ -48,7 +50,8 @@ class ChanneledIncomingTransfer
 {
 public:
     ChanneledIncomingTransfer(const std::shared_ptr<ChannelSocket>& channel,
-                              const std::shared_ptr<FtpServer>& ftp);
+                              const std::shared_ptr<FtpServer>& ftp,
+                              OnStateChangedCb&& cb);
     ~ChanneledIncomingTransfer();
     DRing::DataTransferId id() const;
 
diff --git a/src/jamidht/p2p.cpp b/src/jamidht/p2p.cpp
index 5ad9b8e1f1..aca42ce804 100644
--- a/src/jamidht/p2p.cpp
+++ b/src/jamidht/p2p.cpp
@@ -211,6 +211,10 @@ public:
                      const std::function<void(PeerConnection*)>&);
     bool turnConnect();
     bool validatePeerCertificate(const dht::crypto::Certificate&, dht::InfoHash&);
+    void stateChanged(const std::string& peer_id,
+                      const DRing::DataTransferId& tid,
+                      const DRing::DataTransferEventCode& code);
+    void closeConnection(const std::string& peer_id, const DRing::DataTransferId& tid);
 
     std::future<void> loopFut_; // keep it last member
 
@@ -466,6 +470,10 @@ private:
                 connection_ = std::make_unique<PeerConnection>([this] { cancel(); },
                                                                peer_.toString(),
                                                                std::move(tls_ep_));
+                connection_->setOnStateChangedCb([this](const DRing::DataTransferId& id,
+                                                        const DRing::DataTransferEventCode& code) {
+                    parent_.stateChanged(peer_.toString(), id, code);
+                });
                 for (auto& cb : listeners_) {
                     cb(connection_.get());
                 }
@@ -727,6 +735,11 @@ DhtPeerConnector::Impl::answerToRequest(PeerConnectionMsg&& request,
                                                                    peer_h,
                                                                    std::move(
                                                                        waitForReadyEndpoints_[idx]));
+                connection->setOnStateChangedCb(
+                    [this, peer_h](const DRing::DataTransferId& id,
+                                   const DRing::DataTransferEventCode& code) {
+                        stateChanged(peer_h, id, code);
+                    });
                 connection->attachOutputStream(std::make_shared<FtpServer>(accountId, peer_h));
                 {
                     std::lock_guard<std::mutex> lk(serversMutex_);
@@ -841,6 +854,24 @@ DhtPeerConnector::Impl::cancelChanneled(const std::string& peerId, const DRing::
     });
 }
 
+void
+DhtPeerConnector::Impl::stateChanged(const std::string& peer_id,
+                                     const DRing::DataTransferId& tid,
+                                     const DRing::DataTransferEventCode& code)
+{
+    if (code == DRing::DataTransferEventCode::finished
+        or code == DRing::DataTransferEventCode::closed_by_peer
+        or code == DRing::DataTransferEventCode::timeout_expired)
+        closeConnection(peer_id, tid);
+}
+
+void
+DhtPeerConnector::Impl::closeConnection(const std::string& peer_id, const DRing::DataTransferId& tid)
+{
+    cancel(peer_id, tid);
+    cancelChanneled(peer_id, tid);
+}
+
 //==============================================================================
 
 DhtPeerConnector::DhtPeerConnector(JamiAccount& account)
@@ -896,57 +927,63 @@ DhtPeerConnector::requestConnection(
 
     const auto peer_h = dht::InfoHash(peer_id);
 
-    auto channelReadyCb = [this, tid, channeledConnectedCb, onChanneledCancelled, connect_cb](
-                              const std::shared_ptr<ChannelSocket>& channel) {
-        auto shared = pimpl_->account.lock();
-        if (!channel) {
-            onChanneledCancelled();
-            return;
-        }
-        if (!shared)
-            return;
-        JAMI_INFO("New file channel for outgoing transfer with id(%lu)", tid);
-
-        auto outgoingFile = std::make_shared<ChanneledOutgoingTransfer>(channel);
-        if (!outgoingFile)
-            return;
-        {
-            std::lock_guard<std::mutex> lk(pimpl_->channeledOutgoingMtx_);
-            pimpl_->channeledOutgoing_[tid].emplace_back(outgoingFile);
-        }
+    auto channelReadyCb =
+        [this, tid, peer_id, channeledConnectedCb, onChanneledCancelled, connect_cb](
+            const std::shared_ptr<ChannelSocket>& channel) {
+            auto shared = pimpl_->account.lock();
+            if (!channel) {
+                onChanneledCancelled();
+                return;
+            }
+            if (!shared)
+                return;
+            JAMI_INFO("New file channel for outgoing transfer with id(%lu)", tid);
+
+            auto outgoingFile = std::make_shared<ChanneledOutgoingTransfer>(
+                channel,
+                [this, peer_id](const DRing::DataTransferId& id,
+                                const DRing::DataTransferEventCode& code) {
+                    pimpl_->stateChanged(peer_id, id, code);
+                });
+            if (!outgoingFile)
+                return;
+            {
+                std::lock_guard<std::mutex> lk(pimpl_->channeledOutgoingMtx_);
+                pimpl_->channeledOutgoing_[tid].emplace_back(outgoingFile);
+            }
 
-        channel->onShutdown([this, tid, onChanneledCancelled, peer = outgoingFile->peer()]() {
-            JAMI_INFO("Channel down for outgoing transfer with id(%lu)", tid);
-            onChanneledCancelled();
-            dht::ThreadPool::io().run([w = pimpl_->weak(), tid, peer] {
-                auto shared = w.lock();
-                if (!shared)
-                    return;
-                // Cancel outgoing files
-                {
-                    std::lock_guard<std::mutex> lk(shared->channeledOutgoingMtx_);
-                    auto outgoingTransfers = shared->channeledOutgoing_.find(tid);
-                    if (outgoingTransfers != shared->channeledOutgoing_.end()) {
-                        auto& currentTransfers = outgoingTransfers->second;
-                        auto it = currentTransfers.begin();
-                        while (it != currentTransfers.end()) {
-                            auto& transfer = *it;
-                            if (transfer && transfer->peer() == peer)
-                                it = currentTransfers.erase(it);
-                            else
-                                ++it;
+            channel->onShutdown([this, tid, onChanneledCancelled, peer = outgoingFile->peer()]() {
+                JAMI_INFO("Channel down for outgoing transfer with id(%lu)", tid);
+                onChanneledCancelled();
+                dht::ThreadPool::io().run([w = pimpl_->weak(), tid, peer] {
+                    auto shared = w.lock();
+                    if (!shared)
+                        return;
+                    // Cancel outgoing files
+                    {
+                        std::lock_guard<std::mutex> lk(shared->channeledOutgoingMtx_);
+                        auto outgoingTransfers = shared->channeledOutgoing_.find(tid);
+                        if (outgoingTransfers != shared->channeledOutgoing_.end()) {
+                            auto& currentTransfers = outgoingTransfers->second;
+                            auto it = currentTransfers.begin();
+                            while (it != currentTransfers.end()) {
+                                auto& transfer = *it;
+                                if (transfer && transfer->peer() == peer)
+                                    it = currentTransfers.erase(it);
+                                else
+                                    ++it;
+                            }
+                            if (currentTransfers.empty())
+                                shared->channeledOutgoing_.erase(outgoingTransfers);
                         }
-                        if (currentTransfers.empty())
-                            shared->channeledOutgoing_.erase(outgoingTransfers);
                     }
-                }
-                Manager::instance().dataTransfers->close(tid);
+                    Manager::instance().dataTransfers->close(tid);
+                });
             });
-        });
-        // Cancel via DHT because we will use the channeled path
-        connect_cb(nullptr);
-        channeledConnectedCb(outgoingFile);
-    };
+            // Cancel via DHT because we will use the channeled path
+            connect_cb(nullptr);
+            channeledConnectedCb(outgoingFile);
+        };
 
     if (isVCard) {
         acc->connectionManager().connectDevice(peer_id,
@@ -999,8 +1036,7 @@ DhtPeerConnector::requestConnection(
 void
 DhtPeerConnector::closeConnection(const std::string& peer_id, const DRing::DataTransferId& tid)
 {
-    pimpl_->cancel(peer_id, tid);
-    pimpl_->cancelChanneled(peer_id, tid);
+    pimpl_->closeConnection(peer_id, tid);
 }
 
 bool
@@ -1028,7 +1064,11 @@ DhtPeerConnector::onIncomingConnection(const std::string& peer_id,
     if (!acc)
         return;
     auto incomingFile = std::make_unique<ChanneledIncomingTransfer>(
-        channel, std::make_shared<FtpServer>(acc->getAccountID(), peer_id, tid, std::move(cb)));
+        channel,
+        std::make_shared<FtpServer>(acc->getAccountID(), peer_id, tid, std::move(cb)),
+        [this, peer_id](const DRing::DataTransferId& id, const DRing::DataTransferEventCode& code) {
+            pimpl_->stateChanged(peer_id, id, code);
+        });
     {
         std::lock_guard<std::mutex> lk(pimpl_->channeledIncomingMtx_);
         pimpl_->channeledIncoming_.emplace(tid, std::move(incomingFile));
diff --git a/src/peer_connection.cpp b/src/peer_connection.cpp
index 94cd52f4a3..ad00ce9e73 100644
--- a/src/peer_connection.cpp
+++ b/src/peer_connection.cpp
@@ -766,11 +766,12 @@ public:
 
     const std::string peer_uri;
     Channel<std::unique_ptr<CtrlMsg>> ctrlChannel;
+    OnStateChangedCb stateChangedCb_;
+    std::vector<std::shared_ptr<Stream>> inputs_;
+    std::vector<std::shared_ptr<Stream>> outputs_;
 
 private:
     std::unique_ptr<SocketType> endpoint_;
-    std::vector<std::shared_ptr<Stream>> inputs_;
-    std::vector<std::shared_ptr<Stream>> outputs_;
     std::future<void> eventLoopFut_;
     std::vector<uint8_t> bufferPool_; // will store non rattached buffers
 
@@ -832,11 +833,13 @@ PeerConnection::PeerConnectionImpl::eventLoop()
             switch (msg->type()) {
             case CtrlMsgType::ATTACH_INPUT: {
                 auto& input_msg = static_cast<AttachInputCtrlMsg&>(*msg);
+                input_msg.stream->setOnStateChangedCb(stateChangedCb_);
                 inputs_.emplace_back(std::move(input_msg.stream));
             } break;
 
             case CtrlMsgType::ATTACH_OUTPUT: {
                 auto& output_msg = static_cast<AttachOutputCtrlMsg&>(*msg);
+                output_msg.stream->setOnStateChangedCb(stateChangedCb_);
                 outputs_.emplace_back(std::move(output_msg.stream));
             } break;
 
@@ -954,4 +957,14 @@ PeerConnection::getPeerUri() const
     return pimpl_->peer_uri;
 }
 
+void
+PeerConnection::setOnStateChangedCb(const OnStateChangedCb& cb)
+{
+    pimpl_->stateChangedCb_ = cb;
+    for (auto& input : pimpl_->inputs_)
+        input->setOnStateChangedCb(pimpl_->stateChangedCb_);
+    for (auto& output : pimpl_->outputs_)
+        output->setOnStateChangedCb(pimpl_->stateChangedCb_);
+}
+
 } // namespace jami
diff --git a/src/peer_connection.h b/src/peer_connection.h
index 63a1174de5..744afd1129 100644
--- a/src/peer_connection.h
+++ b/src/peer_connection.h
@@ -22,6 +22,7 @@
 #pragma once
 
 #include "dring/datatransfer_interface.h"
+#include "data_transfer.h"
 #include "ip_utils.h"
 #include "generic_io.h"
 #include "security/diffie-hellman.h"
@@ -82,6 +83,8 @@ public:
     {
         // Not implemented
     }
+
+    virtual void setOnStateChangedCb(const OnStateChangedCb& cb) {}
 };
 
 //==============================================================================
@@ -269,6 +272,8 @@ public:
 
     std::string getPeerUri() const;
 
+    void setOnStateChangedCb(const OnStateChangedCb&);
+
 private:
     class PeerConnectionImpl;
     std::unique_ptr<PeerConnectionImpl> pimpl_;
-- 
GitLab