Add tf_classify reference application (#3642)
author윤지영/동작제어Lab(SR)/Engineer/삼성전자 <jy910.yun@samsung.com>
Tue, 4 Dec 2018 04:44:49 +0000 (13:44 +0900)
committer오형석/동작제어Lab(SR)/Staff Engineer/삼성전자 <hseok82.oh@samsung.com>
Tue, 4 Dec 2018 04:44:49 +0000 (13:44 +0900)
* Add tf_classify reference application

This patch creates an `example` directory
The tf_classify application is based on TendofFlow Android Camera Demo.

Signed-off-by: Jiyoung Yun <jy910.yun@samsung.com>
* Change application name and move the location

Use tflite_classify instead of tf_classify
Move application under contrib directory
Do not support to download files

Signed-off-by: Jiyoung Yun <jy910.yun@samsung.com>
* Add const keyword and remove unnecessary library link

Signed-off-by: Jiyoung Yun <jy910.yun@samsung.com>
* Use std::vector instead of c-style array

Signed-off-by: Jiyoung Yun <jy910.yun@samsung.com>
* Use OpenCV C++ apis

Signed-off-by: Jiyoung Yun <jy910.yun@samsung.com>
cmake/CfgOptionFlags.cmake
contrib/tflite_classify/.FORMATCHECKED [new file with mode: 0644]
contrib/tflite_classify/CMakeLists.txt [new file with mode: 0644]
contrib/tflite_classify/src/ImageClassifier.cc [new file with mode: 0644]
contrib/tflite_classify/src/ImageClassifier.h [new file with mode: 0644]
contrib/tflite_classify/src/InferenceInterface.cc [new file with mode: 0644]
contrib/tflite_classify/src/InferenceInterface.h [new file with mode: 0644]
contrib/tflite_classify/src/tflite_classify.cc [new file with mode: 0644]

index 9062e61..1668f78 100644 (file)
@@ -9,6 +9,7 @@ option(BUILD_ANDROID_NN_RUNTIME_TEST "Build Android NN Runtime Test" ON)
 option(BUILD_DETECTION_APP "Build detection example app" OFF)
 option(BUILD_NNAPI_QUICKCHECK "Build NN API Quickcheck tools" OFF)
 option(BUILD_TFLITE_BENCHMARK_MODEL "Build tflite benchmark model" OFF)
+option(BUILD_TFLITE_CLASSIFY_APP "Build tflite_classify app" OFF)
 
 if("${TARGET_ARCH}" STREQUAL "armv7l" AND NOT "${TARGET_OS}" STREQUAL "tizen")
   set(BUILD_PURE_ARM_COMPUTE ON)
diff --git a/contrib/tflite_classify/.FORMATCHECKED b/contrib/tflite_classify/.FORMATCHECKED
new file mode 100644 (file)
index 0000000..e69de29
diff --git a/contrib/tflite_classify/CMakeLists.txt b/contrib/tflite_classify/CMakeLists.txt
new file mode 100644 (file)
index 0000000..b5605f5
--- /dev/null
@@ -0,0 +1,22 @@
+if(NOT BUILD_TFLITE_CLASSIFY_APP)
+  return()
+endif(NOT BUILD_TFLITE_CLASSIFY_APP)
+
+list(APPEND SOURCES "src/tflite_classify.cc")
+list(APPEND SOURCES "src/ImageClassifier.cc")
+list(APPEND SOURCES "src/InferenceInterface.cc")
+
+## Required package
+find_package(OpenCV REQUIRED)
+find_package(Boost REQUIRED COMPONENTS system filesystem)
+
+# Without this line, this appliation couldn't search the opencv library that were already installed in ${ROOTFS_ARM}/usr/lib/arm-linux-gnueabihf directory
+set(CMAKE_EXE_LINKER_FLAGS "-Wl,--as-needed -Wl,--rpath=${ROOTFS_ARM}/usr/lib/arm-linux-gnueabihf -Wl,--rpath=${ROOTFS_ARM}/lib/arm-linux-gnueabihf")
+
+add_executable(tflite_classify ${SOURCES})
+target_include_directories(tflite_classify PRIVATE src)
+target_link_libraries(tflite_classify tensorflow-lite ${LIB_PTHREAD} dl nnfw_support_tflite)
+target_link_libraries(tflite_classify ${Boost_LIBRARIES})
+target_link_libraries(tflite_classify ${OpenCV_LIBRARIES})
+
+install(TARGETS tflite_classify DESTINATION bin)
diff --git a/contrib/tflite_classify/src/ImageClassifier.cc b/contrib/tflite_classify/src/ImageClassifier.cc
new file mode 100644 (file)
index 0000000..fae4f06
--- /dev/null
@@ -0,0 +1,107 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ImageClassifier.h"
+
+#include <fstream>
+#include <queue>
+#include <algorithm>
+
+ImageClassifier::ImageClassifier(const std::string &model_file, const std::string &label_file,
+                                 const int input_size, const int image_mean, const int image_std,
+                                 const std::string &input_name, const std::string &output_name,
+                                 const bool use_nnapi)
+    : _inference(new InferenceInterface(model_file, use_nnapi)), _input_size(input_size),
+      _image_mean(image_mean), _image_std(image_std), _input_name(input_name),
+      _output_name(output_name)
+{
+  // Load label
+  std::ifstream label_stream(label_file.c_str());
+  assert(label_stream);
+
+  std::string line;
+  while (std::getline(label_stream, line))
+  {
+    _labels.push_back(line);
+  }
+  _num_classes = _inference->getTensorSize(_output_name);
+  std::cout << "Output tensor size is " << _num_classes << ", label size is " << _labels.size()
+            << std::endl;
+
+  // Pre-allocate buffers
+  _fdata.reserve(_input_size * _input_size * 3);
+  _outputs.reserve(_num_classes);
+}
+
+std::vector<Recognition> ImageClassifier::recognizeImage(const cv::Mat &image)
+{
+  // Resize image
+  cv::Mat cropped;
+  cv::resize(image, cropped, cv::Size(_input_size, _input_size), 0, 0, cv::INTER_AREA);
+
+  // Preprocess the image data from 0~255 int to normalized float based
+  // on the provided parameters
+  _fdata.clear();
+  for (int y = 0; y < cropped.rows; ++y)
+  {
+    for (int x = 0; x < cropped.cols; ++x)
+    {
+      cv::Vec3b color = cropped.at<cv::Vec3b>(y, x);
+      color[0] = color[0] - (float)_image_mean / _image_std;
+      color[1] = color[1] - (float)_image_mean / _image_std;
+      color[2] = color[2] - (float)_image_mean / _image_std;
+
+      _fdata.push_back(color[0]);
+      _fdata.push_back(color[1]);
+      _fdata.push_back(color[2]);
+
+      cropped.at<cv::Vec3b>(y, x) = color;
+    }
+  }
+
+  // Copy the input data into model
+  _inference->feed(_input_name, _fdata, 1, _input_size, _input_size, 3);
+
+  // Run the inference call
+  _inference->run(_output_name);
+
+  // Copy the output tensor back into the output array
+  _inference->fetch(_output_name, _outputs);
+
+  // Find the best classifications
+  auto compare = [](const Recognition &lhs, const Recognition &rhs) {
+    return lhs.confidence < rhs.confidence;
+  };
+
+  std::priority_queue<Recognition, std::vector<Recognition>, decltype(compare)> pq(compare);
+  for (int i = 0; i < _num_classes; ++i)
+  {
+    if (_outputs[i] > _threshold)
+    {
+      pq.push(Recognition(_outputs[i], _labels[i]));
+    }
+  }
+
+  std::vector<Recognition> results;
+  int min = std::min(pq.size(), _max_results);
+  for (int i = 0; i < min; ++i)
+  {
+    results.push_back(pq.top());
+    pq.pop();
+  }
+
+  return results;
+}
diff --git a/contrib/tflite_classify/src/ImageClassifier.h b/contrib/tflite_classify/src/ImageClassifier.h
new file mode 100644 (file)
index 0000000..1ba19af
--- /dev/null
@@ -0,0 +1,99 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * @file     ImageClassifier.h
+ * @brief    This file contains ImageClassifier class and Recognition structure
+ * @ingroup  COM_AI_RUNTIME
+ */
+
+#ifndef __TFLITE_CLASSIFY_IMAGE_CLASSIFIER_H__
+#define __TFLITE_CLASSIFY_IMAGE_CLASSIFIER_H__
+
+#include "InferenceInterface.h"
+
+#include <iostream>
+#include <string>
+#include <vector>
+
+#include <opencv2/opencv.hpp>
+
+/**
+ * @brief struct to define an immutable result returned by a Classifier
+ */
+struct Recognition
+{
+public:
+  /**
+   * @brief Construct a new Recognition object with confidence and title
+   * @param[in] _confidence A sortable score for how good the recognition is relative to others.
+   * Higher should be better.
+   * @param[in] _title      Display name for the recognition
+   */
+  Recognition(float _confidence, std::string _title) : confidence(_confidence), title(_title) {}
+
+  float confidence;  /** A sortable score for how good the recognition is relative to others. Higher
+                        should be better. */
+  std::string title; /** Display name for the recognition */
+};
+
+/**
+ * @brief Class to define a classifier specialized to label images
+ */
+class ImageClassifier
+{
+public:
+  /**
+   * @brief Construct a new ImageClassifier object with parameters
+   * @param[in] model_file  The filepath of the model FlatBuffer protocol buffer
+   * @param[in] label_file  The filepath of label file for classes
+   * @param[in] input_size  The input size. A square image of input_size x input_size is assumed
+   * @param[in] image_mean  The assumed mean of the image values
+   * @param[in] image_std   The assumed std of the image values
+   * @param[in] input_name  The label of the image input node
+   * @param[in] output_name The label of the output node
+   * @param[in] use_nnapi   The flag to distinguish between TfLite interpreter and NNFW runtime
+   */
+  ImageClassifier(const std::string &model_file, const std::string &label_file,
+                  const int input_size, const int image_mean, const int image_std,
+                  const std::string &input_name, const std::string &output_name,
+                  const bool use_nnapi);
+
+  /**
+   * @brief Recognize the given image data
+   * @param[in] image   The image data to recognize
+   * @return An immutable result vector array
+   */
+  std::vector<Recognition> recognizeImage(const cv::Mat &image);
+
+private:
+  const float _threshold = 0.1f;
+  const unsigned int _max_results = 3;
+
+  std::unique_ptr<InferenceInterface> _inference;
+  int _input_size;
+  int _image_mean;
+  int _image_std;
+  std::string _input_name;
+  std::string _output_name;
+
+  std::vector<std::string> _labels;
+  std::vector<float> _fdata;
+  std::vector<float> _outputs;
+  int _num_classes;
+};
+
+#endif // __TFLITE_CLASSIFY_IMAGE_CLASSIFIER_H__
diff --git a/contrib/tflite_classify/src/InferenceInterface.cc b/contrib/tflite_classify/src/InferenceInterface.cc
new file mode 100644 (file)
index 0000000..27d3382
--- /dev/null
@@ -0,0 +1,114 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "InferenceInterface.h"
+
+using namespace tflite;
+using namespace tflite::ops::builtin;
+
+InferenceInterface::InferenceInterface(const std::string &model_file, const bool use_nnapi)
+    : _interpreter(nullptr), _model(nullptr), _sess(nullptr)
+{
+  // Load model
+  StderrReporter error_reporter;
+  _model = FlatBufferModel::BuildFromFile(model_file.c_str(), &error_reporter);
+  BuiltinOpResolver resolver;
+  InterpreterBuilder builder(*_model, resolver);
+  builder(&_interpreter);
+
+  if (use_nnapi)
+  {
+    _sess = std::make_shared<nnfw::support::tflite::NNAPISession>(_interpreter.get());
+  }
+  else
+  {
+    _sess = std::make_shared<nnfw::support::tflite::InterpreterSession>(_interpreter.get());
+  }
+
+  _sess->prepare();
+}
+
+InferenceInterface::~InferenceInterface() { _sess->teardown(); }
+
+void InferenceInterface::feed(const std::string &input_name, const std::vector<float> &data,
+                              const int batch, const int height, const int width, const int channel)
+{
+  // Set input tensor
+  for (const auto &id : _interpreter->inputs())
+  {
+    if (_interpreter->tensor(id)->name == input_name)
+    {
+      assert(_interpreter->tensor(id)->type == kTfLiteFloat32);
+      float *p = _interpreter->tensor(id)->data.f;
+
+      // TODO consider batch
+      for (int y = 0; y < height; ++y)
+      {
+        for (int x = 0; x < width; ++x)
+        {
+          for (int c = 0; c < channel; ++c)
+          {
+            *p++ = data[y * width * channel + x * channel + c];
+          }
+        }
+      }
+    }
+  }
+}
+
+void InferenceInterface::run(const std::string &output_name)
+{
+  // Run model
+  _sess->run();
+}
+
+void InferenceInterface::fetch(const std::string &output_name, std::vector<float> &outputs)
+{
+  // Get output tensor
+  for (const auto &id : _interpreter->outputs())
+  {
+    if (_interpreter->tensor(id)->name == output_name)
+    {
+      assert(_interpreter->tensor(id)->type == kTfLiteFloat32);
+      assert(getTensorSize(output_name) == outputs.capacity());
+      float *p = _interpreter->tensor(id)->data.f;
+
+      outputs.clear();
+      for (int i = 0; i < outputs.capacity(); ++i)
+      {
+        outputs.push_back(p[i]);
+      }
+    }
+  }
+}
+
+int InferenceInterface::getTensorSize(const std::string &name)
+{
+  for (const auto &id : _interpreter->outputs())
+  {
+    if (_interpreter->tensor(id)->name == name)
+    {
+      TfLiteTensor *t = _interpreter->tensor(id);
+      int v = 1;
+      for (int i = 0; i < t->dims->size; ++i)
+      {
+        v *= t->dims->data[i];
+      }
+      return v;
+    }
+  }
+  return -1;
+}
diff --git a/contrib/tflite_classify/src/InferenceInterface.h b/contrib/tflite_classify/src/InferenceInterface.h
new file mode 100644 (file)
index 0000000..2f01190
--- /dev/null
@@ -0,0 +1,93 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * @file     InferenceInterface.h
+ * @brief    This file contains class for running the actual inference model
+ * @ingroup  COM_AI_RUNTIME
+ */
+
+#ifndef __TFLITE_CLASSIFY_INFERENCE_INTERFACE_H__
+#define __TFLITE_CLASSIFY_INFERENCE_INTERFACE_H__
+
+#include "support/tflite/kernels/register.h"
+#include "tensorflow/contrib/lite/model.h"
+
+#include "support/tflite/InterpreterSession.h"
+#include "support/tflite/NNAPISession.h"
+
+#include <iostream>
+#include <string>
+
+/**
+ * @brief Class to define a inference interface for recognizing data
+ */
+class InferenceInterface
+{
+public:
+  /**
+   * @brief Construct a new InferenceInterface object with parameters
+   * @param[in] model_file  The filepath of the model FlatBuffer protocol buffer
+   * @param[in] use_nnapi   The flag to distinguish between TfLite interpreter and NNFW runtime
+   */
+  InferenceInterface(const std::string &model_file, const bool use_nnapi);
+
+  /**
+   * @brief Destructor an InferenceInterface object
+   */
+  ~InferenceInterface();
+
+  /**
+   * @brief Copy the input data into model
+   * @param[in] input_name  The label of the image input node
+   * @param[in] data        The actual data to be copied into input tensor
+   * @param[in] batch       The number of batch size
+   * @param[in] height      The number of height size
+   * @param[in] width       The number of width size
+   * @param[in] channel     The number of channel size
+   * @return N/A
+   */
+  void feed(const std::string &input_name, const std::vector<float> &data, const int batch,
+            const int height, const int width, const int channel);
+  /**
+   * @brief Run the inference call
+   * @param[in] output_name The label of the output node
+   * @return N/A
+   */
+  void run(const std::string &output_name);
+
+  /**
+   * @brief Copy the output tensor back into the output array
+   * @param[in] output_node The label of the output node
+   * @param[in] outputs     The output data array
+   * @return N/A
+   */
+  void fetch(const std::string &output_name, std::vector<float> &outputs);
+
+  /**
+   * @brief Get tensor size
+   * @param[in] name  The label of the node
+   * @result The size of tensor
+   */
+  int getTensorSize(const std::string &name);
+
+private:
+  std::unique_ptr<tflite::Interpreter> _interpreter;
+  std::unique_ptr<tflite::FlatBufferModel> _model;
+  std::shared_ptr<nnfw::support::tflite::Session> _sess;
+};
+
+#endif // __TFLITE_CLASSIFY_INFERENCE_INTERFACE_H__
diff --git a/contrib/tflite_classify/src/tflite_classify.cc b/contrib/tflite_classify/src/tflite_classify.cc
new file mode 100644 (file)
index 0000000..40c15f3
--- /dev/null
@@ -0,0 +1,132 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ImageClassifier.h"
+
+#include <iostream>
+
+#include <boost/filesystem.hpp>
+#include <opencv2/opencv.hpp>
+
+namespace fs = boost::filesystem;
+
+int main(const int argc, char **argv)
+{
+  const std::string MODEL_FILE = "tensorflow_inception_graph.tflite";
+  const std::string LABEL_FILE = "imagenet_comp_graph_label_strings.txt";
+
+  const std::string INPUT_NAME = "input";
+  const std::string OUTPUT_NAME = "output";
+  const int INPUT_SIZE = 224;
+  const int IMAGE_MEAN = 117;
+  const int IMAGE_STD = 1;
+  const int OUTPUT_SIZE = 1008;
+
+  const int FRAME_WIDTH = 640;
+  const int FRAME_HEIGHT = 480;
+
+  bool use_nnapi = false;
+  bool debug_mode = false;
+
+  if (std::getenv("USE_NNAPI") != nullptr)
+  {
+    use_nnapi = true;
+  }
+
+  if (std::getenv("DEBUG_MODE") != nullptr)
+  {
+    debug_mode = true;
+  }
+
+  std::cout << "USE_NNAPI : " << use_nnapi << std::endl;
+  std::cout << "DEBUG_MODE : " << debug_mode << std::endl;
+
+  std::cout << "Model : " << MODEL_FILE << std::endl;
+  std::cout << "Label : " << LABEL_FILE << std::endl;
+
+  if (!fs::exists(MODEL_FILE))
+  {
+    std::cerr << "model file not found: " << MODEL_FILE << std::endl;
+    exit(1);
+  }
+
+  if (!fs::exists(LABEL_FILE))
+  {
+    std::cerr << "label file not found: " << LABEL_FILE << std::endl;
+    exit(1);
+  }
+
+  // Create ImageClassifier
+  std::unique_ptr<ImageClassifier> classifier(
+      new ImageClassifier(MODEL_FILE, LABEL_FILE, INPUT_SIZE, IMAGE_MEAN, IMAGE_STD, INPUT_NAME,
+                          OUTPUT_NAME, use_nnapi));
+
+  // Cam setting
+  cv::VideoCapture cap(0);
+  cv::Mat frame;
+
+  // Initialize camera
+  cap.set(CV_CAP_PROP_FRAME_WIDTH, FRAME_WIDTH);
+  cap.set(CV_CAP_PROP_FRAME_HEIGHT, FRAME_HEIGHT);
+  cap.set(CV_CAP_PROP_FPS, 5);
+
+  std::vector<Recognition> results;
+  clock_t begin, end;
+  while (cap.isOpened())
+  {
+    // Get image data
+    if (!cap.read(frame))
+    {
+      std::cout << "Frame is null..." << std::endl;
+      break;
+    }
+
+    if (debug_mode)
+    {
+      begin = clock();
+    }
+    // Recognize image
+    results = classifier->recognizeImage(frame);
+    if (debug_mode)
+    {
+      end = clock();
+    }
+
+    // Show result data
+    std::cout << std::endl;
+    if (results.size() > 0)
+    {
+      for (int i = 0; i < results.size(); ++i)
+      {
+        std::cout << results[i].title << ": " << results[i].confidence << std::endl;
+      }
+    }
+    else
+    {
+      std::cout << "." << std::endl;
+    }
+    if (debug_mode)
+    {
+      std::cout << "Frame: " << FRAME_WIDTH << "x" << FRAME_HEIGHT << std::endl;
+      std::cout << "Crop: " << INPUT_SIZE << "x" << INPUT_SIZE << std::endl;
+      std::cout << "Inference time(ms): " << ((end - begin) / (CLOCKS_PER_SEC / 1000)) << std::endl;
+    }
+  }
+
+  cap.release();
+
+  return 0;
+}