mv_machine_learning: convert ImageClassification class into template class
authorVibhav Aggarwal <v.aggarwal@samsung.com>
Tue, 14 Nov 2023 11:01:23 +0000 (20:01 +0900)
committerKwanghoon Son <k.son@samsung.com>
Wed, 6 Dec 2023 01:36:46 +0000 (10:36 +0900)
[Issue type] code refactoring

Change-Id: I07a76684aece1773a3ef9dd2b3abbb430dad2394
Signed-off-by: Vibhav Aggarwal <v.aggarwal@samsung.com>
mv_machine_learning/image_classification/include/iimage_classification.h [new file with mode: 0644]
mv_machine_learning/image_classification/include/image_classification.h
mv_machine_learning/image_classification/include/image_classification_adapter.h
mv_machine_learning/image_classification/include/image_classification_default.h
mv_machine_learning/image_classification/src/image_classification.cpp
mv_machine_learning/image_classification/src/image_classification_adapter.cpp
mv_machine_learning/image_classification/src/image_classification_default.cpp

diff --git a/mv_machine_learning/image_classification/include/iimage_classification.h b/mv_machine_learning/image_classification/include/iimage_classification.h
new file mode 100644 (file)
index 0000000..fcca973
--- /dev/null
@@ -0,0 +1,49 @@
+/**
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __IIMAGE_CLASSIFICATION_H__
+#define __IIMAGE_CLASSIFICATION_H__
+
+#include <mv_common.h>
+
+#include "image_classification_type.h"
+
+namespace mediavision
+{
+namespace machine_learning
+{
+class IImageClassification
+{
+public:
+       virtual ~IImageClassification() {};
+
+       virtual void preDestroy() = 0;
+       virtual void setEngineInfo(std::string engine_type, std::string device_type) = 0;
+       virtual void getNumberOfEngines(unsigned int *number_of_engines) = 0;
+       virtual void getEngineType(unsigned int engine_index, char **engine_type) = 0;
+       virtual void getNumberOfDevices(const char *engine_type, unsigned int *number_of_devices) = 0;
+       virtual void getDeviceType(const char *engine_type, const unsigned int device_index, char **device_type) = 0;
+       virtual void configure() = 0;
+       virtual void prepare() = 0;
+       virtual void perform(mv_source_h &mv_src) = 0;
+       virtual void performAsync(ImageClassificationInput &input) = 0;
+       virtual ImageClassificationResult &getOutput() = 0;
+};
+
+} // machine_learning
+} // mediavision
+
+#endif
\ No newline at end of file
index f8ee160..3029af9 100644 (file)
 #include "ImageClassificationParser.h"
 #include "machine_learning_preprocess.h"
 #include "async_manager.h"
+#include "iimage_classification.h"
 
 namespace mediavision
 {
 namespace machine_learning
 {
-class ImageClassification
+template<typename T> class ImageClassification : public IImageClassification
 {
 private:
        std::unique_ptr<AsyncManager<ImageClassificationResult> > _async_manager;
@@ -44,11 +45,8 @@ private:
        void loadLabel();
        void getEngineList();
        void getDeviceList(const char *engine_type);
-       template<typename T>
        void preprocess(mv_source_h &mv_src, std::shared_ptr<MetaInfo> metaInfo, std::vector<T> &inputVector);
        std::shared_ptr<MetaInfo> getInputMetaInfo();
-       template<typename T> void perform(mv_source_h &mv_src, std::shared_ptr<MetaInfo> metaInfo);
-       template<typename T> void performAsync(ImageClassificationInput &input, std::shared_ptr<MetaInfo> metaInfo);
 
 protected:
        std::unique_ptr<mediavision::inference::Inference> _inference;
@@ -60,7 +58,7 @@ protected:
 
        void getOutputNames(std::vector<std::string> &names);
        void getOutpuTensor(std::string &target_name, std::vector<float> &tensor);
-       template<typename T> void inference(std::vector<std::vector<T> > &inputVectors);
+       void inference(std::vector<std::vector<T> > &inputVectors);
        virtual ImageClassificationResult &result() = 0;
 
 public:
index e625efe..d2f24d0 100644 (file)
@@ -23,6 +23,7 @@
 #include "itask.h"
 #include "machine_learning_config.h"
 #include "image_classification_default.h"
+#include "iimage_classification.h"
 
 namespace mediavision
 {
@@ -31,7 +32,7 @@ namespace machine_learning
 template<typename T, typename V> class ImageClassificationAdapter : public mediavision::common::ITask<T, V>
 {
 private:
-       std::unique_ptr<ImageClassification> _image_classification;
+       std::unique_ptr<IImageClassification> _image_classification;
        std::shared_ptr<MachineLearningConfig> _config;
        T _source;
        const std::string _config_file_name = "image_classification.json";
index 870b4eb..d409276 100644 (file)
@@ -29,8 +29,11 @@ namespace mediavision
 {
 namespace machine_learning
 {
-class ImageClassificationDefault : public ImageClassification
+template<typename T> class ImageClassificationDefault : public ImageClassification<T>
 {
+       using ImageClassification<T>::_config;
+       using ImageClassification<T>::_labels;
+
 private:
        ImageClassificationResult _result;
 
index 369cd23..b3f3587 100644 (file)
@@ -35,12 +35,13 @@ namespace mediavision
 {
 namespace machine_learning
 {
-ImageClassification::ImageClassification(std::shared_ptr<MachineLearningConfig> config) : _config(config)
+template<typename T>
+ImageClassification<T>::ImageClassification(std::shared_ptr<MachineLearningConfig> config) : _config(config)
 {
        _inference = make_unique<Inference>();
 }
 
-void ImageClassification::preDestroy()
+template<typename T> void ImageClassification<T>::preDestroy()
 {
        if (!_async_manager)
                return;
@@ -48,9 +49,8 @@ void ImageClassification::preDestroy()
        _async_manager->stop();
 }
 
-void ImageClassification::configure()
+template<typename T> void ImageClassification<T>::configure()
 {
-       _config->loadMetaFile(make_unique<ImageClassificationParser>());
        loadLabel();
 
        int ret = _inference->bind(_config->getBackendType(), _config->getTargetDeviceType());
@@ -58,7 +58,7 @@ void ImageClassification::configure()
                throw InvalidOperation("Fail to bind a backend engine.");
 }
 
-void ImageClassification::loadLabel()
+template<typename T> void ImageClassification<T>::loadLabel()
 {
        if (_config->getLabelFilePath().empty())
                return;
@@ -79,7 +79,7 @@ void ImageClassification::loadLabel()
        readFile.close();
 }
 
-void ImageClassification::getEngineList()
+template<typename T> void ImageClassification<T>::getEngineList()
 {
        for (auto idx = MV_INFERENCE_BACKEND_NONE + 1; idx < MV_INFERENCE_BACKEND_MAX; ++idx) {
                auto backend = _inference->getSupportedInferenceBackend(idx);
@@ -91,7 +91,7 @@ void ImageClassification::getEngineList()
        }
 }
 
-void ImageClassification::getDeviceList(const char *engine_type)
+template<typename T> void ImageClassification<T>::getDeviceList(const char *engine_type)
 {
        // TODO. add device types available for a given engine type later.
        //       In default, cpu and gpu only.
@@ -99,7 +99,8 @@ void ImageClassification::getDeviceList(const char *engine_type)
        _valid_devices.push_back("gpu");
 }
 
-void ImageClassification::setEngineInfo(std::string engine_type_name, std::string device_type_name)
+template<typename T>
+void ImageClassification<T>::setEngineInfo(std::string engine_type_name, std::string device_type_name)
 {
        if (engine_type_name.empty() || device_type_name.empty())
                throw InvalidParameter("Invalid engine info.");
@@ -120,7 +121,7 @@ void ImageClassification::setEngineInfo(std::string engine_type_name, std::strin
                 device_type_name.c_str(), device_type);
 }
 
-void ImageClassification::getNumberOfEngines(unsigned int *number_of_engines)
+template<typename T> void ImageClassification<T>::getNumberOfEngines(unsigned int *number_of_engines)
 {
        if (!_valid_backends.empty()) {
                *number_of_engines = _valid_backends.size();
@@ -131,7 +132,7 @@ void ImageClassification::getNumberOfEngines(unsigned int *number_of_engines)
        *number_of_engines = _valid_backends.size();
 }
 
-void ImageClassification::getEngineType(unsigned int engine_index, char **engine_type)
+template<typename T> void ImageClassification<T>::getEngineType(unsigned int engine_index, char **engine_type)
 {
        if (!_valid_backends.empty()) {
                if (_valid_backends.size() <= engine_index)
@@ -149,7 +150,8 @@ void ImageClassification::getEngineType(unsigned int engine_index, char **engine
        *engine_type = const_cast<char *>(_valid_backends[engine_index].data());
 }
 
-void ImageClassification::getNumberOfDevices(const char *engine_type, unsigned int *number_of_devices)
+template<typename T>
+void ImageClassification<T>::getNumberOfDevices(const char *engine_type, unsigned int *number_of_devices)
 {
        if (!_valid_devices.empty()) {
                *number_of_devices = _valid_devices.size();
@@ -160,7 +162,8 @@ void ImageClassification::getNumberOfDevices(const char *engine_type, unsigned i
        *number_of_devices = _valid_devices.size();
 }
 
-void ImageClassification::getDeviceType(const char *engine_type, const unsigned int device_index, char **device_type)
+template<typename T>
+void ImageClassification<T>::getDeviceType(const char *engine_type, const unsigned int device_index, char **device_type)
 {
        if (!_valid_devices.empty()) {
                if (_valid_devices.size() <= device_index)
@@ -178,7 +181,7 @@ void ImageClassification::getDeviceType(const char *engine_type, const unsigned
        *device_type = const_cast<char *>(_valid_devices[device_index].data());
 }
 
-void ImageClassification::prepare()
+template<typename T> void ImageClassification<T>::prepare()
 {
        int ret = _inference->configureInputMetaInfo(_config->getInputMetaMap());
        if (ret != MEDIA_VISION_ERROR_NONE)
@@ -196,7 +199,7 @@ void ImageClassification::prepare()
                throw InvalidOperation("Fail to load model files.");
 }
 
-shared_ptr<MetaInfo> ImageClassification::getInputMetaInfo()
+template<typename T> shared_ptr<MetaInfo> ImageClassification<T>::getInputMetaInfo()
 {
        TensorBuffer &tensor_buffer = _inference->getInputTensorBuffer();
        IETensorBuffer &tensor_info_map = tensor_buffer.getIETensorBuffer();
@@ -212,7 +215,7 @@ shared_ptr<MetaInfo> ImageClassification::getInputMetaInfo()
 }
 
 template<typename T>
-void ImageClassification::preprocess(mv_source_h &mv_src, shared_ptr<MetaInfo> metaInfo, vector<T> &inputVector)
+void ImageClassification<T>::preprocess(mv_source_h &mv_src, shared_ptr<MetaInfo> metaInfo, vector<T> &inputVector)
 {
        LOGI("ENTER");
 
@@ -244,7 +247,7 @@ void ImageClassification::preprocess(mv_source_h &mv_src, shared_ptr<MetaInfo> m
        LOGI("LEAVE");
 }
 
-template<typename T> void ImageClassification::inference(vector<vector<T> > &inputVectors)
+template<typename T> void ImageClassification<T>::inference(vector<vector<T> > &inputVectors)
 {
        LOGI("ENTER");
 
@@ -255,30 +258,18 @@ template<typename T> void ImageClassification::inference(vector<vector<T> > &inp
        LOGI("LEAVE");
 }
 
-template<typename T> void ImageClassification::perform(mv_source_h &mv_src, shared_ptr<MetaInfo> metaInfo)
+template<typename T> void ImageClassification<T>::perform(mv_source_h &mv_src)
 {
+       shared_ptr<MetaInfo> metaInfo = getInputMetaInfo();
        vector<T> inputVector;
 
-       preprocess<T>(mv_src, metaInfo, inputVector);
+       preprocess(mv_src, metaInfo, inputVector);
 
        vector<vector<T> > inputVectors = { inputVector };
-
-       inference<T>(inputVectors);
-}
-
-void ImageClassification::perform(mv_source_h &mv_src)
-{
-       shared_ptr<MetaInfo> metaInfo = getInputMetaInfo();
-
-       if (metaInfo->dataType == MV_INFERENCE_DATA_UINT8)
-               perform<unsigned char>(mv_src, metaInfo);
-       else if (metaInfo->dataType == MV_INFERENCE_DATA_FLOAT32)
-               perform<float>(mv_src, metaInfo);
-       else
-               throw InvalidOperation("Invalid model data type.");
+       inference(inputVectors);
 }
 
-ImageClassificationResult &ImageClassification::getOutput()
+template<typename T> ImageClassificationResult &ImageClassification<T>::getOutput()
 {
        if (_async_manager) {
                if (!_async_manager->isWorking())
@@ -294,14 +285,13 @@ ImageClassificationResult &ImageClassification::getOutput()
        return _current_result;
 }
 
-template<typename T>
-void ImageClassification::performAsync(ImageClassificationInput &input, shared_ptr<MetaInfo> metaInfo)
+template<typename T> void ImageClassification<T>::performAsync(ImageClassificationInput &input)
 {
        if (!_async_manager) {
                _async_manager = make_unique<AsyncManager<ImageClassificationResult> >([this]() {
                        AsyncInputQueue<T> inputQueue = _async_manager->popFromInput<T>();
 
-                       inference<T>(inputQueue.inputs);
+                       inference(inputQueue.inputs);
 
                        ImageClassificationResult &resultQueue = result();
 
@@ -310,30 +300,16 @@ void ImageClassification::performAsync(ImageClassificationInput &input, shared_p
                });
        }
 
+       shared_ptr<MetaInfo> metaInfo = getInputMetaInfo();
        vector<T> inputVector;
 
-       preprocess<T>(input.inference_src, metaInfo, inputVector);
+       preprocess(input.inference_src, metaInfo, inputVector);
 
        vector<vector<T> > inputVectors = { inputVector };
-
        _async_manager->push(inputVectors);
 }
 
-void ImageClassification::performAsync(ImageClassificationInput &input)
-{
-       shared_ptr<MetaInfo> metaInfo = getInputMetaInfo();
-
-       if (metaInfo->dataType == MV_INFERENCE_DATA_UINT8) {
-               performAsync<unsigned char>(input, metaInfo);
-       } else if (metaInfo->dataType == MV_INFERENCE_DATA_FLOAT32) {
-               performAsync<float>(input, metaInfo);
-               // TODO
-       } else {
-               throw InvalidOperation("Invalid model data type.");
-       }
-}
-
-void ImageClassification::getOutputNames(vector<string> &names)
+template<typename T> void ImageClassification<T>::getOutputNames(vector<string> &names)
 {
        TensorBuffer &tensor_buffer_obj = _inference->getOutputTensorBuffer();
        IETensorBuffer &ie_tensor_buffer = tensor_buffer_obj.getIETensorBuffer();
@@ -342,7 +318,7 @@ void ImageClassification::getOutputNames(vector<string> &names)
                names.push_back(it->first);
 }
 
-void ImageClassification::getOutpuTensor(string &target_name, vector<float> &tensor)
+template<typename T> void ImageClassification<T>::getOutpuTensor(string &target_name, vector<float> &tensor)
 {
        LOGI("ENTER");
 
@@ -359,18 +335,8 @@ void ImageClassification::getOutpuTensor(string &target_name, vector<float> &ten
        LOGI("LEAVE");
 }
 
-template void ImageClassification::preprocess<float>(mv_source_h &mv_src, shared_ptr<MetaInfo> metaInfo,
-                                                                                                        vector<float> &inputVector);
-template void ImageClassification::inference<float>(vector<vector<float> > &inputVectors);
-template void ImageClassification::perform<float>(mv_source_h &mv_src, shared_ptr<MetaInfo> metaInfo);
-template void ImageClassification::performAsync<float>(ImageClassificationInput &input, shared_ptr<MetaInfo> metaInfo);
-
-template void ImageClassification::preprocess<unsigned char>(mv_source_h &mv_src, shared_ptr<MetaInfo> metaInfo,
-                                                                                                                        vector<unsigned char> &inputVector);
-template void ImageClassification::inference<unsigned char>(vector<vector<unsigned char> > &inputVectors);
-template void ImageClassification::perform<unsigned char>(mv_source_h &mv_src, shared_ptr<MetaInfo> metaInfo);
-template void ImageClassification::performAsync<unsigned char>(ImageClassificationInput &input,
-                                                                                                                          shared_ptr<MetaInfo> metaInfo);
+template class ImageClassification<unsigned char>;
+template class ImageClassification<float>;
 
 }
 }
index 17b5a1b..fdb2e30 100644 (file)
@@ -44,7 +44,19 @@ template<typename T, typename V> ImageClassificationAdapter<T, V>::~ImageClassif
 
 template<typename T, typename V> void ImageClassificationAdapter<T, V>::create()
 {
-       _image_classification = make_unique<ImageClassificationDefault>(_config);
+       _config->loadMetaFile(make_unique<ImageClassificationParser>());
+       mv_inference_data_type_e dataType = _config->getInputMetaMap().begin()->second->dataType;
+
+       switch (dataType) {
+       case MV_INFERENCE_DATA_UINT8:
+               _image_classification = make_unique<ImageClassificationDefault<unsigned char> >(_config);
+               break;
+       case MV_INFERENCE_DATA_FLOAT32:
+               _image_classification = make_unique<ImageClassificationDefault<float> >(_config);
+               break;
+       default:
+               throw InvalidOperation("Invalid image classification data type.");
+       }
 }
 
 template<typename T, typename V>
index 37a1873..59e8f33 100644 (file)
@@ -30,23 +30,24 @@ namespace mediavision
 {
 namespace machine_learning
 {
-ImageClassificationDefault::ImageClassificationDefault(shared_ptr<MachineLearningConfig> config)
-               : ImageClassification(config), _result()
+template<typename T>
+ImageClassificationDefault<T>::ImageClassificationDefault(shared_ptr<MachineLearningConfig> config)
+               : ImageClassification<T>(config), _result()
 {}
 
-ImageClassificationDefault::~ImageClassificationDefault()
+template<typename T> ImageClassificationDefault<T>::~ImageClassificationDefault()
 {}
 
-ImageClassificationResult &ImageClassificationDefault::result()
+template<typename T> ImageClassificationResult &ImageClassificationDefault<T>::result()
 {
        vector<string> names;
 
-       ImageClassification::getOutputNames(names);
+       ImageClassification<T>::getOutputNames(names);
 
        vector<float> output_vec;
 
        // In case of image classification model, only one output tensor is used.
-       ImageClassification::getOutpuTensor(names[0], output_vec);
+       ImageClassification<T>::getOutpuTensor(names[0], output_vec);
 
        auto metaInfo = _config->getOutputMetaMap().at(names[0]);
        auto decodingScore = static_pointer_cast<DecodingScore>(metaInfo->decodingTypeMap.at(DecodingType::SCORE));
@@ -62,5 +63,8 @@ ImageClassificationResult &ImageClassificationDefault::result()
        return _result;
 }
 
+template class ImageClassificationDefault<unsigned char>;
+template class ImageClassificationDefault<float>;
+
 }
 }