Skip to content
Snippets Groups Projects
Select Git revision
  • ac1940dcd1c7078c1030a0e84f33c342d7a6d1ae
  • master default protected
2 results

TFInference.h

Blame
  • Code owners
    Assign users and groups as approvers for specific file changes. Learn more.
    TFInference.h 3.84 KiB
    /**
     *  Copyright (C) 2020 Savoir-faire Linux Inc.
     *
     *  Author: Aline Gondim Santos <aline.gondimsantos@savoirfairelinux.com>
     *
     *  This program is free software; you can redistribute it and/or modify
     *  it under the terms of the GNU General Public License as published by
     *  the Free Software Foundation; either version 3 of the License, or
     *  (at your option) any later version.
     *
     *  This program is distributed in the hope that it will be useful,
     *  but WITHOUT ANY WARRANTY; without even the implied warranty of
     *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
     *  GNU General Public License for more details.
     *
     *  You should have received a copy of the GNU General Public License
     *  along with this program; if not, write to the Free Software
     *  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301 USA.
     */
    
    #pragma once
    
    // Library headers
    #include "TFModels.h"
    
    // STL
    #include <memory>
    #include <string>
    #include <vector>
    
    #ifdef TFLITE
    #include <tensorflow/lite/interpreter.h>
    #include <tensorflow/lite/delegates/nnapi/nnapi_delegate.h>
    
    namespace tflite {
    class FlatBufferModel;
    class Interpreter;
    class StatefulNnApiDelegate;
    } // namespace tflite
    
    #else
    #ifdef WIN32
    #define NOMINMAX
    #undef min
    #undef max
    #endif
    
    #include <tensorflow/core/lib/core/status.h>
    #include <tensorflow/core/public/session.h>
    #include <tensorflow/core/framework/tensor.h>
    #include <tensorflow/core/framework/types.pb.h>
    #include <tensorflow/core/platform/init_main.h>
    #include <tensorflow/core/protobuf/config.pb.h>
    
    namespace tensorflow {
    class Tensor;
    class Status;
    class GraphDef;
    class Session;
    struct SessionOptions;
    class TensorShape;
    class Env;
    enum DataType : int;
    } // namespace tensorflow
    
    #endif
    
    namespace jami {
    class TensorflowInference
    {
    public:
        /**
         * @brief TensorflowInference
         * Takes a supervised model where the model and labels files are defined
         * @param model
         */
        TensorflowInference(TFModel model);
        ~TensorflowInference();
    
    #ifdef TFLITE
        /**
         * @brief loadModel
         * Load the model from the file described in the Supervised Model
         */
        void loadModel();
        void buildInterpreter();
        void setInterpreterSettings();
    
        /**
         * @brief allocateTensors
         * Tries to allocate space for the tensors
         * In case of success isAllocated() should return true
         */
        void allocateTensors();
    
        // Debug methods
        void describeModelTensors() const;
        void describeTensor(std::string prefix, int index) const;
    
    #else
        void LoadGraph();
        tensorflow::Tensor imageTensor;
    
    #endif // TFLITE
    
        /**
         * @brief runGraph
         * runs the underlaying graph model.numberOfRuns times
         * Where numberOfRuns is defined in the model
         */
        void runGraph();
    
        /**
         * @brief init
         * Inits the model, interpreter, allocates tensors and load the labels
         */
        void init();
        // Getters
        bool isAllocated() const;
    
    protected:
    #ifdef TFLITE
        /**
         * @brief getTensorDimensions
         * Utility method to get Tensorflow Tensor dimensions
         * Given the index of the tensor, the function gives back a vector
         * Where each element is the dimension of the vector-space (finite dimension)
         * Thus, vector.size() is the number of vector-space used by the tensor
         * @param index
         * @return
         */
        std::vector<int> getTensorDimensions(int index) const;
    
        // Tensorflow model and interpreter
        std::unique_ptr<tflite::FlatBufferModel> flatbufferModel;
        std::unique_ptr<tflite::Interpreter> interpreter;
    #else
        std::unique_ptr<tensorflow::Session> session;
        std::vector<tensorflow::Tensor> outputs;
    #endif
        TFModel tfModel;
        std::vector<std::string> labels;
    
        /**
         * @brief nbLabels
         * The real number of labels may not match the labels.size() because of padding
         */
        size_t nbLabels;
    
        bool allocated_ = false;
    };
    } // namespace jami