From f68dd14e5efc47f9bd0dc72cdbb589cfeaaf96b8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrien=20B=C3=A9raud?= <adrien.beraud@savoirfairelinux.com>
Date: Tue, 9 Feb 2021 13:41:34 -0500
Subject: [PATCH] tls_session: protect session in send()

Change-Id: I7b87bad47c66bda0d66140b59315ea45a43a4839
---
 src/security/tls_session.cpp | 14 ++++++--------
 1 file changed, 6 insertions(+), 8 deletions(-)

diff --git a/src/security/tls_session.cpp b/src/security/tls_session.cpp
index 65402612e8..beaad1787d 100644
--- a/src/security/tls_session.cpp
+++ b/src/security/tls_session.cpp
@@ -283,7 +283,8 @@ public:
     std::unique_ptr<TlsAnonymousClientCredendials> cacred_; // ctor init.
     std::unique_ptr<TlsAnonymousServerCredendials> sacred_; // ctor init.
     std::unique_ptr<TlsCertificateCredendials> xcred_;      // ctor init.
-    std::mutex sessionMutex_;
+    std::mutex sessionReadMutex_;
+    std::mutex sessionWriteMutex_;
     gnutls_session_t session_ {nullptr};
     gnutls_datum_t cookie_key_ {nullptr, 0};
     gnutls_dtls_prestate_st prestate_ {};
@@ -832,6 +833,7 @@ TlsSession::TlsSessionImpl::sendOcspRequest(const std::string& uri,
 std::size_t
 TlsSession::TlsSessionImpl::send(const ValueType* tx_data, std::size_t tx_size, std::error_code& ec)
 {
+    std::lock_guard<std::mutex> lk(sessionWriteMutex_);
     if (state_ != TlsSessionState::ESTABLISHED) {
         ec = std::error_code(GNUTLS_E_INVALID_SESSION, std::system_category());
         return 0;
@@ -1039,7 +1041,8 @@ TlsSession::TlsSessionImpl::cleanup()
     stateCondition_.notify_all();
 
     {
-        std::lock_guard<std::mutex> lk(sessionMutex_);
+        std::lock_guard<std::mutex> lk1(sessionReadMutex_);
+        std::lock_guard<std::mutex> lk2(sessionWriteMutex_);
         if (session_) {
             if (transport_->isReliable())
                 gnutls_bye(session_, GNUTLS_SHUT_RDWR);
@@ -1624,11 +1627,6 @@ TlsSession::shutdown()
 std::size_t
 TlsSession::write(const ValueType* data, std::size_t size, std::error_code& ec)
 {
-    if (pimpl_->state_ != TlsSessionState::ESTABLISHED) {
-        ec = std::make_error_code(std::errc::broken_pipe);
-        return 0;
-    }
-
     return pimpl_->send(data, size, ec);
 }
 
@@ -1645,7 +1643,7 @@ TlsSession::read(ValueType* data, std::size_t size, std::error_code& ec)
     while (true) {
         ssize_t ret;
         {
-            std::lock_guard<std::mutex> lk(pimpl_->sessionMutex_);
+            std::lock_guard<std::mutex> lk(pimpl_->sessionReadMutex_);
             if (!pimpl_->session_)
                 return 0;
             ret = gnutls_record_recv(pimpl_->session_, data, size);
-- 
GitLab