From b74f4cbd03c76e95f87212165c84f5e315bd5cf1 Mon Sep 17 00:00:00 2001
From: agsantos <aline.gondimsantos@savoirfairelinux.com>
Date: Thu, 1 Oct 2020 14:30:43 -0400
Subject: [PATCH] GreenScreen: use opencv grabcut in post-processing

Change-Id: I58a2d3bb91917fb2b0ffc131f275fd58e4b5f4d0
---
 GreenScreen/pluginProcessor.cpp | 202 +++++++++++++++++---------------
 GreenScreen/pluginProcessor.h   |  12 +-
 2 files changed, 120 insertions(+), 94 deletions(-)

diff --git a/GreenScreen/pluginProcessor.cpp b/GreenScreen/pluginProcessor.cpp
index 4ec733e..9e67ffe 100644
--- a/GreenScreen/pluginProcessor.cpp
+++ b/GreenScreen/pluginProcessor.cpp
@@ -117,20 +117,7 @@ void
 PluginProcessor::setBackgroundRotation(int angle)
 {
     if (backgroundRotation != angle && (backgroundRotation - angle) != 0) {
-        switch (backgroundRotation - angle) {
-        case 90:
-            cv::rotate(backgroundImage, backgroundImage, cv::ROTATE_90_CLOCKWISE);
-            break;
-        case 180:
-            cv::rotate(backgroundImage, backgroundImage, cv::ROTATE_180);
-            break;
-        case -180:
-            cv::rotate(backgroundImage, backgroundImage, cv::ROTATE_180);
-            break;
-        case -90:
-            cv::rotate(backgroundImage, backgroundImage, cv::ROTATE_90_COUNTERCLOCKWISE);
-            break;
-        }
+        rotateFrame(backgroundRotation - angle, backgroundImage); // rotate backgroundImage in place by the angle difference using the shared helper instead of a duplicated switch
         backgroundRotation = angle;
     }
 }
@@ -138,12 +125,14 @@ PluginProcessor::setBackgroundRotation(int angle)
 void
 PluginProcessor::computePredictions()
 {
-    // Run the graph
-    pluginInference.runGraph();
-    auto predictions = pluginInference.masksPredictions();
+    if (count == 0) { // run inference only when count is 0; on other frames computedMask keeps its previous content
+        // Run the graph
+        pluginInference.runGraph();
+        auto predictions = pluginInference.masksPredictions();
 
-    // Save the predictions
-    computedMask = predictions;
+        // Save the predictions
+        computedMask = predictions;
+    } // count is advanced (modulo frameCount) at the end of drawMaskOnFrame
 }
 
 void
@@ -182,73 +171,110 @@ PluginProcessor::drawMaskOnFrame(
     if (computedMask.empty()) {
         return;
     }
+
     if (previousMasks[0].empty()) {
         previousMasks[0] = cv::Mat(frameReduced.rows, frameReduced.cols, CV_32FC1, double(0.));
         previousMasks[1] = cv::Mat(frameReduced.rows, frameReduced.cols, CV_32FC1, double(0.));
     }
+
     int maskSize = static_cast<int>(std::sqrt(computedMask.size()));
     cv::Mat maskImg(maskSize, maskSize, CV_32FC1, computedMask.data());
+    cv::Mat* applyMask = &frameReduced; // alias, not a copy: grabCut below reads the frame through this pointer
+    cv::Mat output; // receives the blended frame/background result
+    kSize = cv::Size(frameReduced.cols * 0.1, frameReduced.rows * 0.1); // kernel spans 10% of the reduced frame
+    if (kSize.height % 2 == 0) // force odd dimensions: GaussianBlur below needs an odd kernel size
+        kSize.height -= 1;
+    if (kSize.width % 2 == 0)
+        kSize.width -= 1;
 
-    rotateFrame(-angle, maskImg);
+    if (count == 0) { // rebuild the mask only when count is 0; other frames reuse previousMasks[0]
+        rotateFrame(-angle, maskImg);
 #ifdef TFLITE
-    for (int i = 0; i < maskImg.cols; i++) {
-        for (int j = 0; j < maskImg.rows; j++) {
-            if (maskImg.at<float>(j, i) == 15)
-                maskImg.at<float>(j, i) = 255.;
-            else
-                maskImg.at<float>(j, i) = (float) ((int) ((0.6 * maskImg.at<float>(j, i)
-                                                           + 0.3 * previousMasks[0].at<float>(j, i)
-                                                           + 0.1 * previousMasks[1].at<float>(j, i)))
-                                                   % 256);
-        }
-    }
-#else // TFLITE
-    cv::resize(maskImg, maskImg, cv::Size(frameReduced.cols, frameReduced.rows));
-
-    double m, M;
-    cv::minMaxLoc(maskImg, &m, &M);
-
-    if (M < 2) { // avoid detection if there is any one in frame
-        maskImg = 0. * maskImg;
-    } else {
         for (int i = 0; i < maskImg.cols; i++) {
             for (int j = 0; j < maskImg.rows; j++) {
-                maskImg.at<float>(j, i) = (maskImg.at<float>(j, i) - m) / (M - m);
-
-                if (maskImg.at<float>(j, i) < 0.4)
-                    maskImg.at<float>(j, i) = 0.;
-                else if (maskImg.at<float>(j, i) < 0.7) {
-                    float value = maskImg.at<float>(j, i) * 0.6
-                                  + previousMasks[0].at<float>(j, i) * 0.3
-                                  + previousMasks[1].at<float>(j, i) * 0.1;
-                    maskImg.at<float>(j, i) = 0.;
-                    if (value > 0.7)
-                        maskImg.at<float>(j, i) = 1.;
-                } else
+                if (maskImg.at<float>(j, i) == 15) // class index 15: presumably the model's "person" label - TODO confirm against the model
                     maskImg.at<float>(j, i) = 1.;
+                else // other pixels take a weighted sum of the two previous masks; the current value is not used
+                    maskImg.at<float>(j, i) = smoothFactors[0] * previousMasks[0].at<float>(j, i)
+                                              + smoothFactors[1] * previousMasks[1].at<float>(j, i);
+            }
+        }
+        cv::morphologyEx(maskImg,
+                         maskImg,
+                         cv::MORPH_CLOSE,
+                         cv::getStructuringElement(cv::MORPH_ELLIPSE, kSize),
+                         cv::Point(-1, -1),
+                         4); // morphological closing, 4 iterations
+#else // !TFLITE
+        cv::resize(maskImg, maskImg, cv::Size(frameReduced.cols, frameReduced.rows));
+
+        double m, M;
+        cv::minMaxLoc(maskImg, &m, &M);
+
+        if (M < 2) { // avoid detection if there is no one in frame
+            maskImg = 0. * maskImg;
+        } else {
+            for (int i = 0; i < maskImg.cols; i++) {
+                for (int j = 0; j < maskImg.rows; j++) {
+                    maskImg.at<float>(j, i) = (maskImg.at<float>(j, i) - m) / (M - m); // normalise to [0, 1]
+
+                    if (maskImg.at<float>(j, i) < 0.4)
+                        maskImg.at<float>(j, i) = 0.;
+                    else if (maskImg.at<float>(j, i) < 0.7) { // ambiguous score: blend with the two previous masks before thresholding
+                        float value = maskImg.at<float>(j, i) * smoothFactors[0]
+                                      + previousMasks[0].at<float>(j, i) * smoothFactors[1]
+                                      + previousMasks[1].at<float>(j, i) * smoothFactors[2];
+                        maskImg.at<float>(j, i) = 0.;
+                        if (value > 0.7)
+                            maskImg.at<float>(j, i) = 1.;
+                    } else
+                        maskImg.at<float>(j, i) = 1.;
+                }
             }
         }
-    }
 #endif
-
-    previousMasks[1] = previousMasks[0].clone();
-    previousMasks[0] = maskImg.clone();
-
-    kSize = cv::Size(maskImg.cols * 0.05, maskImg.rows * 0.05);
-    if (kSize.height % 2 == 0)
-        kSize.height -= 1;
-    if (kSize.width % 2 == 0)
-        kSize.width -= 1;
-
-#ifndef TFLITE
-    cv::dilate(maskImg, maskImg, cv::getStructuringElement(cv::MORPH_CROSS, kSize));
-    maskImg = maskImg * 255.;
+        if (cv::countNonZero(maskImg) != 0) { // skip the refinement when the mask is entirely zero
+#ifdef TFLITE
+            cv::Mat tfMask;
+            tfMask = maskImg.clone();
+            tfMask *= 255.;
+            tfMask.convertTo(tfMask, CV_8UC1);
+            cv::threshold(tfMask, tfMask, 127, 255, cv::THRESH_BINARY); // binary 0/255 copy, ANDed with the grabCut result below
+            if (cv::countNonZero(tfMask) != 0) {
+#endif
+                cv::Mat dilate;
+                cv::dilate(maskImg, dilate, cv::getStructuringElement(cv::MORPH_ELLIPSE, kSize));
+                cv::erode(maskImg,
+                          maskImg,
+                          cv::getStructuringElement(cv::MORPH_ELLIPSE, kSize),
+                          cv::Point(-1, -1),
+                          2);
+                for (int i = 0; i < maskImg.cols; i++) {
+                    for (int j = 0; j < maskImg.rows; j++) {
+                        if (dilate.at<float>(j, i) != maskImg.at<float>(j, i)) // pixel where the dilated and eroded masks disagree
+                            maskImg.at<float>(j, i) = grabcutClass; // mark it with the "probable" class so grabCut decides
+                    }
+                }
+                maskImg.convertTo(maskImg, CV_8UC1); // grabCut takes an 8-bit single-channel mask
+                applyMask->convertTo(*applyMask, CV_8UC1); // converts frameReduced in place to 8-bit depth (convertTo keeps the channel count)
+                cv::Rect rect(1, 1, maskImg.cols, maskImg.rows); // cv::Rect is (x, y, width, height); only read by grabCut in GC_INIT_WITH_RECT mode
+                cv::Mat bgdModel, fgdModel;
+                cv::grabCut(*applyMask, maskImg, rect, bgdModel, fgdModel, 2, cv::GC_INIT_WITH_MASK);
+                maskImg = maskImg & 1; // GC_FGD (1) and GC_PR_FGD (3) become 1, GC_BGD (0) and GC_PR_BGD (2) become 0
+#ifdef TFLITE
+                cv::bitwise_and(maskImg, tfMask, maskImg);
+            }
 #endif
-    GaussianBlur(maskImg, maskImg, kSize, 0); // mask from 0 to 255.
-    maskImg = maskImg / 255.;
+            maskImg.convertTo(maskImg, CV_32FC1);
+            maskImg *= 255.;
+            GaussianBlur(maskImg, maskImg, kSize, 0); // float mask from 0 to 255.
+            maskImg = maskImg / 255.;
+        }
+        previousMasks[1] = previousMasks[0].clone();
+        previousMasks[0] = maskImg.clone();
+    }
 
-    cv::Mat applyMask = frameReduced.clone();
-    cv::Mat roiMaskImg = maskImg.clone();
+    cv::Mat roiMaskImg = previousMasks[0].clone(); // most recent computed mask; frames with count != 0 reuse it unchanged
     cv::Mat roiMaskImgComplementary = 1. - roiMaskImg; // mask from 1. to 0
 
     std::vector<cv::Mat> channels;
@@ -267,35 +293,27 @@ PluginProcessor::drawMaskOnFrame(
     int origType = frameReduced.type();
     int roiMaskType = roiMaskImg.type();
 
-    applyMask.convertTo(applyMask, roiMaskType);
-    applyMask = applyMask.mul(roiMaskImg);
-    applyMask += backgroundImage.mul(roiMaskImgComplementary);
-    applyMask.convertTo(applyMask, origType);
+    frameReduced.convertTo(output, roiMaskType); // blend in the mask's type, writing into output
+    output = output.mul(roiMaskImg);
+    output += backgroundImage.mul(roiMaskImgComplementary);
+    output.convertTo(output, origType);
 
-    cv::resize(applyMask, applyMask, cv::Size(frame.cols, frame.rows));
+    cv::resize(output, output, cv::Size(frame.cols, frame.rows));
 
-    copyByLine(frame.data, applyMask.data, lineSize, cv::Size(frame.cols, frame.rows));
+    copyByLine(frame.data, output.data, lineSize, cv::Size(frame.cols, frame.rows));
+    count++; // advance the position in the frame cycle
+    count = count % frameCount; // wraps to 0 every frameCount frames, which triggers a fresh prediction and mask
 }
 
 void
 PluginProcessor::rotateFrame(int angle, cv::Mat& mat)
 {
-    if (angle != 0) {
-        switch (angle) {
-        case -90:
-            cv::rotate(mat, mat, cv::ROTATE_90_COUNTERCLOCKWISE);
-            break;
-        case 180:
-            cv::rotate(mat, mat, cv::ROTATE_180);
-            break;
-        case -180:
-            cv::rotate(mat, mat, cv::ROTATE_180);
-            break;
-        case 90:
-            cv::rotate(mat, mat, cv::ROTATE_90_CLOCKWISE);
-            break;
-        }
-    }
+    if (angle == -90 || angle == 270) // 270 is the same turn as -90; setBackgroundRotation can pass it as a difference of two angles
+        cv::rotate(mat, mat, cv::ROTATE_90_COUNTERCLOCKWISE);
+    else if (std::abs(angle) == 180)
+        cv::rotate(mat, mat, cv::ROTATE_180);
+    else if (angle == 90 || angle == -270) // -270 is the same turn as 90
+        cv::rotate(mat, mat, cv::ROTATE_90_CLOCKWISE); // any other angle (including 0) leaves mat untouched
 }
 
 bool
diff --git a/GreenScreen/pluginProcessor.h b/GreenScreen/pluginProcessor.h
index 528988f..e7f3ab0 100644
--- a/GreenScreen/pluginProcessor.h
+++ b/GreenScreen/pluginProcessor.h
@@ -81,16 +81,24 @@ public:
     cv::Mat backgroundImage;
 
     cv::Size kSize;
-    float scaleX = 0;
-    float scaleY = 0;
 
     PluginInference pluginInference;
     std::string backgroundPath;
+    int count = 0; // position in the frame cycle; inference and mask refresh run only when it is 0
 
 private:
     // Frame
     cv::Mat frame;
     int backgroundRotation = 0;
     bool hasBackground_ = false;
+#ifdef TFLITE
+    int grabcutClass = 2; // value written into the grabCut mask for uncertain pixels; 2 is cv::GC_PR_BGD
+    int frameCount = 3; // cycle length: count wraps modulo this value
+    float smoothFactors[2] = {0.3f, 0.05f}; // weights applied to previousMasks[0] and previousMasks[1]
+#else
+    int grabcutClass = 3; // value written into the grabCut mask for uncertain pixels; 3 is cv::GC_PR_FGD
+    int frameCount = 5; // cycle length: count wraps modulo this value
+    float smoothFactors[3] = {0.6f, 0.3f, 0.1f}; // weights for the current mask, previousMasks[0] and previousMasks[1]
+#endif
 };
 } // namespace jami
-- 
GitLab