[tflchef] Prepare adding Ops with AveragePool2D (#2368)
author박세희/동작제어Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Thu, 22 Nov 2018 01:33:22 +0000 (10:33 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Thu, 22 Nov 2018 01:33:22 +0000 (10:33 +0900)
This will add AveragePool2D operator handler and needed changes to test.
Also place holders for other operators to prevent merge conflict.

Signed-off-by: SaeHie Park <saehie.park@samsung.com>
contrib/tflchef/tests/CMakeLists.txt
contrib/tflchef/tflite/CMakeLists.txt
contrib/tflchef/tflite/src/Op/AveragePool2D.cpp [new file with mode: 0644]
contrib/tflchef/tflite/src/Op/AveragePool2D.h [new file with mode: 0644]
contrib/tflchef/tflite/src/TFliteOpChefs.h
contrib/tflchef/tflite/src/TFliteOpRegistry.h

index 673100c..81fb38d 100644 (file)
@@ -32,6 +32,48 @@ foreach(RECIPE IN ITEMS ${RECIPES})
   list(APPEND TESTS ${RECIPE_PREFIX})
 endforeach(RECIPE)
 
+# Test tflchef-reverse
+list(APPEND GEN_TFLITEFILES "averagepool2d/test.recipe")
+#list(APPEND GEN_TFLITEFILES "concatenation/test.recipe")
+#list(APPEND GEN_TFLITEFILES "conv2d/test.recipe")
+#list(APPEND GEN_TFLITEFILES "depthwiseconv2d/test.recipe")
+#list(APPEND GEN_TFLITEFILES "maxpool2d/test.recipe")
+#list(APPEND GEN_TFLITEFILES "relu/test.recipe")
+#list(APPEND GEN_TFLITEFILES "relu6/test.recipe")
+#list(APPEND GEN_TFLITEFILES "reshape/test.recipe")
+# TODO if all operates are added we can just use one line of "*/test.recipe"
+
+foreach(TFLITEFILE IN ITEMS ${GEN_TFLITEFILES})
+  get_filename_component(TFLITE_PREFIX ${TFLITEFILE} DIRECTORY)
+
+  # file from above tflchef-file block
+  # use tflite file as input of tflchef-reverse generated from tflchef-file
+  set(RECIPE_OUTPUT_FILE "${TFLITE_PREFIX}.tflite")
+  set(RECIPE_OUTPUT_TARGET tflchef_${TFLITE_PREFIX}_tflite)
+
+  set(RECIPE_GEN_OUTPUT_FILE "${TFLITE_PREFIX}.gen.recipe")
+  set(RECIPE_GEN_OUTPUT_TARGET tflchef_${TFLITE_PREFIX}_gen_recipe)
+
+  # Generate .gen.recipe from generated .tflite
+  add_custom_target(${RECIPE_GEN_OUTPUT_TARGET}
+                    ALL $<TARGET_FILE:tflchef-reverse> ${RECIPE_OUTPUT_FILE} ${RECIPE_GEN_OUTPUT_FILE}
+                    DEPENDS ${RECIPE_OUTPUT_TARGET}
+                    COMMENT "Generate ${TFLITE_PREFIX}.gen.recipe")
+
+  # now we are going to generate .gen.tflite from .gen.recipe
+  # to check generated .gen.recipe file is correct by using it.
+  # as weight values may be different, binary comparision is not acceptable.
+  set(RECIPE_GEN_OUTPUT_FILE2 "${TFLITE_PREFIX}.gen.tflite")
+  set(RECIPE_GEN_OUTPUT_TARGET2 tflchef_${TFLITE_PREFIX}_gen_tflite)
+
+  add_custom_target(${RECIPE_GEN_OUTPUT_TARGET2}
+                    ALL $<TARGET_FILE:tflchef-file> ${RECIPE_GEN_OUTPUT_FILE} ${RECIPE_GEN_OUTPUT_FILE2}
+                    DEPENDS ${RECIPE_GEN_OUTPUT_TARGET}
+                    COMMENT "Generate ${TFLITE_PREFIX}.gen.tflite")
+
+  list(APPEND TESTS ${TFLITE_PREFIX}.gen)
+endforeach(TFLITEFILE)
+
 add_test(NAME tflchef_test
          COMMAND "${CMAKE_CURRENT_SOURCE_DIR}/runall"
                  $<TARGET_FILE:nni>
index 0c33ee0..091f4af 100644 (file)
@@ -5,3 +5,4 @@ target_include_directories(tflchef_tflite PUBLIC include)
 target_include_directories(tflchef_tflite PRIVATE src)
 target_link_libraries(tflchef_tflite tflchef_proto)
 target_link_libraries(tflchef_tflite tflchef_flatbuffer)
+target_link_libraries(tflchef_tflite stdex)
diff --git a/contrib/tflchef/tflite/src/Op/AveragePool2D.cpp b/contrib/tflchef/tflite/src/Op/AveragePool2D.cpp
new file mode 100644 (file)
index 0000000..1f269e4
--- /dev/null
@@ -0,0 +1,52 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "AveragePool2D.h"
+
+#include "Convert.h"
+
+namespace tflchef
+{
+
+void TFliteOpAveragePool2D::filler(const tflite::Operator *op, TFliteImport *import,
+                                   tflchef::ModelRecipe *model_recipe) const
+{
+  // Nothing to do with filler
+}
+
+tflchef::Operation *TFliteOpAveragePool2D::build(const tflite::Operator *op, TFliteImport *import,
+                                                 tflchef::ModelRecipe *model_recipe) const
+{
+  auto op_params = op->builtin_options_as_Pool2DOptions();
+  assert(op_params != nullptr);
+
+  auto operation = model_recipe->add_operation();
+
+  operation->set_type("AveragePool2D");
+
+  auto op_options = operation->mutable_averagepool2d_options();
+
+  op_options->set_padding(as_tflchef_padding(op_params->padding()));
+  op_options->set_stride_h(op_params->stride_h());
+  op_options->set_stride_w(op_params->stride_w());
+  op_options->set_filter_height(op_params->filter_height());
+  op_options->set_filter_width(op_params->filter_width());
+  op_options->set_activation(as_tflchef_activation(op_params->fused_activation_function()));
+
+  return operation;
+}
+
+} // namespace tflchef
diff --git a/contrib/tflchef/tflite/src/Op/AveragePool2D.h b/contrib/tflchef/tflite/src/Op/AveragePool2D.h
new file mode 100644 (file)
index 0000000..f9e9fb2
--- /dev/null
@@ -0,0 +1,39 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __TFLITE_OP_AVERAGEPOOL2D_H__
+#define __TFLITE_OP_AVERAGEPOOL2D_H__
+
+#include "TFliteOpChef.h"
+
+namespace tflchef
+{
+
+/**
+ * @brief tflchef operator builder for AveragePool2D
+ */
+class TFliteOpAveragePool2D : public TFliteOpChef
+{
+public:
+  void filler(const tflite::Operator *op, TFliteImport *import,
+              tflchef::ModelRecipe *model_recipe) const override;
+  tflchef::Operation *build(const tflite::Operator *op, TFliteImport *import,
+                            tflchef::ModelRecipe *model_recipe) const override;
+};
+
+} // namespace tflchef
+
+#endif // __TFLITE_OP_AVERAGEPOOL2D_H__
index ac770b1..73e1ed6 100644 (file)
 #ifndef __TFLITE_OP_CHEFS_H__
 #define __TFLITE_OP_CHEFS_H__
 
-// TODO add operator headers
 // In alphabet order
+#include "Op/AveragePool2D.h"
+//#include "Op/Concatenation.h"
+//#include "Op/Conv2D.h"
+//#include "Op/DepthwiseConv2D.h"
+//#include "Op/MaxPool2D.h"
+//#include "Op/ReLU.h"
+//#include "Op/ReLU6.h"
+//#include "Op/Reshape.h"
 
 #endif // __TFLITE_OP_CHEFS_H__
index ed67d26..f976221 100644 (file)
 #include "TFliteOpChef.h"
 #include "TFliteOpChefs.h"
 
+#include <stdex/Memory.h>
+
+using stdex::make_unique;
+
 namespace tflchef
 {
 
@@ -50,7 +54,15 @@ public:
 private:
   TFliteOpRegistry()
   {
-    // TODO add TFliteOpChef for each tflite operation.
+    _tfliteop_map[tflite::BuiltinOperator_AVERAGE_POOL_2D] = make_unique<TFliteOpAveragePool2D>();
+    //_tfliteop_map[tflite::BuiltinOperator_CONCATENATION] = make_unique<TFliteOpConcatenation>();
+    //_tfliteop_map[tflite::BuiltinOperator_CONV_2D] = make_unique<TFliteOpConv2D>();
+    //_tfliteop_map[tflite::BuiltinOperator_DEPTHWISE_CONV_2D] =
+    //    make_unique<TFliteOpDepthwiseConv2D>();
+    //_tfliteop_map[tflite::BuiltinOperator_MAX_POOL_2D] = make_unique<TFliteOpMaxPool2D>();
+    //_tfliteop_map[tflite::BuiltinOperator_RELU] = make_unique<TFliteOpReLU>();
+    //_tfliteop_map[tflite::BuiltinOperator_RELU6] = make_unique<TFliteOpReLU6>();
+    //_tfliteop_map[tflite::BuiltinOperator_RESHAPE] = make_unique<TFliteOpReshape>();
   }
 
 private: