[exo-tflite] shape inference for TFLAveragePool2D (#7345)
author윤현식/On-Device Lab(SR)/Principal Engineer/삼성전자 <hyunsik.yoon@samsung.com>
Wed, 11 Sep 2019 05:04:37 +0000 (14:04 +0900)
committer박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Wed, 11 Sep 2019 05:04:37 +0000 (14:04 +0900)
Adding shape inference for TFLAveragePool2D and two test cases.

Signed-off-by: Hyun Sik Yoon <hyunsik.yoon@samsung.com>
compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.cpp
compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.test.cpp
compiler/exo-tflite/src/TestGraph.h [new file with mode: 0644]

index e3ba39c..0f82185 100644 (file)
 
 #include "ShapeInference.h"
 
+#include "Check.h"
+
 #include <cassert>
 
 namespace
 {
 
+// Call this for TFLAvgPool2D and TFLMaxPool2D only
+template <class Pool2DType> loco::NodeShape infer_pool_2d_shape(const Pool2DType *node)
+{
+  EXO_ASSERT(loco::shape_known(node->value()), "Shape must be known");
+
+  auto ifm_shape = loco::shape_get(node->value()).template as<loco::TensorShape>();
+
+  uint32_t input_height = ifm_shape.dim(1).value();
+  uint32_t input_width = ifm_shape.dim(2).value();
+  uint32_t stride_height = node->stride()->h();
+  uint32_t stride_width = node->stride()->w();
+  uint32_t window_height = node->filter()->h();
+  uint32_t window_width = node->filter()->w();
+  uint32_t dilation_height = 1; // dilation for TFLAvgPool2D and TFLMaxPool2D is 1
+  uint32_t dilation_width = 1;
+  uint32_t effective_window_height = dilation_height * (window_height - 1) + 1;
+  uint32_t effective_window_width = dilation_width * (window_width - 1) + 1;
+
+  uint32_t output_height;
+  uint32_t output_width;
+
+  if (node->padding() == locoex::Padding::VALID)
+  {
+    output_height = (input_height + stride_height - effective_window_height) / stride_height;
+    output_width = (input_width + stride_width - effective_window_width) / stride_width;
+  }
+  else if (node->padding() == locoex::Padding::SAME)
+  {
+    output_height = (input_height + stride_height - 1) / stride_height;
+    output_width = (input_width + stride_width - 1) / stride_width;
+  }
+  else
+    EXO_ASSERT(false, "Wrong padding type");
+
+  loco::TensorShape ofm_shape;
+  ofm_shape.rank(4);
+  ofm_shape.dim(0) = ifm_shape.dim(0);
+  ofm_shape.dim(1) = output_height;
+  ofm_shape.dim(2) = output_width;
+  ofm_shape.dim(3) = ifm_shape.dim(3);
+
+  return loco::NodeShape{ofm_shape};
+}
+
 /**
  * @brief Class to infer the shape of TFLNode
  *
@@ -52,7 +98,10 @@ public:
 
   // TFLAdd
 
-  // TFLAveragePool2D
+  loco::NodeShape visit(const locoex::TFLAveragePool2D *node) final
+  {
+    return infer_pool_2d_shape(node);
+  }
 
   // TODO TFLConcatenation
 
index c5c375d..eca72e6 100644 (file)
@@ -14,6 +14,8 @@
  * limitations under the License.
  */
 
+#include "TestGraph.h"
+
 #include "Dialect/IR/TFLNodes.h"
 #include "Dialect/IR/TFLDialect.h"
 #include "Dialect/Service/TFLShapeInferenceRule.h"
@@ -81,3 +83,91 @@ TEST(TFLShapeInferenceRuleTest, minimal_with_TFLRelu)
     ASSERT_EQ(shape.dim(1), 4);
   }
 }
+
+// based on the case shown in
+// https://www.corvil.com/kb/what-is-the-difference-between-same-and-valid-padding-in-tf-nn-max-pool-of-tensorflow
+TEST(TFLShapeInferenceRuleTest, avgpool2d_valid)
+{
+  exo::test::PullPushGraph<locoex::TFLAveragePool2D> test_graph;
+  auto pull = test_graph.pull;
+  {
+    pull->shape({1, 4, 3, 1});
+  }
+  auto tfl_node = test_graph.middle_node;
+  {
+    tfl_node->filter()->h(2);
+    tfl_node->filter()->w(2);
+    tfl_node->stride()->h(2);
+    tfl_node->stride()->w(2);
+    tfl_node->fusedActivationFunction(locoex::FusedActFunc::NONE);
+    tfl_node->padding(locoex::Padding::VALID);
+  }
+  ASSERT_FALSE(loco::shape_known(tfl_node));
+
+  // shape inference
+  locoex::TFLShapeInferenceRule tfl_rule;
+  loco::CanonicalShapeInferenceRule canonical_rule;
+  loco::MultiDialectShapeInferenceRule rules;
+
+  rules.bind(loco::CanonicalDialect::get(), &canonical_rule)
+      .bind(locoex::TFLDialect::get(), &tfl_rule);
+
+  loco::apply(&rules).to(test_graph.g.get());
+
+  // Verify
+  {
+    ASSERT_TRUE(loco::shape_known(tfl_node));
+    ASSERT_EQ(loco::shape_get(tfl_node).domain(), loco::Domain::Tensor);
+
+    auto shape = loco::shape_get(tfl_node).as<loco::TensorShape>();
+    ASSERT_EQ(shape.rank(), 4);
+    ASSERT_EQ(shape.dim(0).value(), 1);
+    ASSERT_EQ(shape.dim(1).value(), 2);
+    ASSERT_EQ(shape.dim(2).value(), 1);
+    ASSERT_EQ(shape.dim(3).value(), 1);
+  }
+}
+
+TEST(TFLShapeInferenceRuleTest, avgpool2d_same)
+{
+  exo::test::PullPushGraph<locoex::TFLAveragePool2D> test_graph;
+  auto pull = test_graph.pull;
+  {
+    pull->shape({1, 4, 3, 1});
+  }
+
+  auto tfl_node = test_graph.middle_node;
+  {
+    tfl_node->filter()->h(2);
+    tfl_node->filter()->w(2);
+    tfl_node->stride()->h(2);
+    tfl_node->stride()->w(2);
+    tfl_node->fusedActivationFunction(locoex::FusedActFunc::NONE);
+    tfl_node->padding(locoex::Padding::SAME);
+  }
+
+  ASSERT_FALSE(loco::shape_known(tfl_node));
+
+  // shape inference
+  locoex::TFLShapeInferenceRule tfl_rule;
+  loco::CanonicalShapeInferenceRule canonical_rule;
+  loco::MultiDialectShapeInferenceRule rules;
+
+  rules.bind(loco::CanonicalDialect::get(), &canonical_rule)
+      .bind(locoex::TFLDialect::get(), &tfl_rule);
+
+  loco::apply(&rules).to(test_graph.g.get());
+
+  // Verify
+  {
+    ASSERT_TRUE(loco::shape_known(tfl_node));
+    ASSERT_EQ(loco::shape_get(tfl_node).domain(), loco::Domain::Tensor);
+
+    auto shape = loco::shape_get(tfl_node).as<loco::TensorShape>();
+    ASSERT_EQ(shape.rank(), 4);
+    ASSERT_EQ(shape.dim(0).value(), 1);
+    ASSERT_EQ(shape.dim(1).value(), 2);
+    ASSERT_EQ(shape.dim(2).value(), 2);
+    ASSERT_EQ(shape.dim(3).value(), 1);
+  }
+}
diff --git a/compiler/exo-tflite/src/TestGraph.h b/compiler/exo-tflite/src/TestGraph.h
new file mode 100644 (file)
index 0000000..11903d3
--- /dev/null
@@ -0,0 +1,79 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __TEST_GRAPH_H__
+#define __TEST_GRAPH_H__
+
+#include "Dialect/IR/TFLNodes.h"
+
+#include <loco.h>
+
+#include <stdex/Memory.h>
+
+namespace exo
+{
+namespace test
+{
+
+// graph to build [Pull - some node of type T - Push]
+template <typename T> struct PullPushGraph
+{
+public:
+  std::unique_ptr<loco::Graph> g;
+  loco::Pull *pull;
+  loco::Push *push;
+  T *middle_node;
+
+  PullPushGraph()
+  {
+    // g = Pull - T - Push
+    g = loco::make_graph();
+
+    pull = g->nodes()->create<loco::Pull>();
+
+    middle_node = g->nodes()->create<T>();
+    {
+      setInput();
+    }
+
+    push = g->nodes()->create<loco::Push>();
+    {
+      push->from(middle_node);
+    }
+
+    auto input = g->inputs()->create();
+    {
+      input->name("input");
+      loco::link(input, pull);
+    }
+    auto output = g->outputs()->create();
+    {
+      output->name("output");
+      loco::link(output, push);
+    }
+  }
+
+private:
+  void setInput(); // set the input of T
+};
+
+// setInput of TFL nodes
+template <> void PullPushGraph<locoex::TFLAveragePool2D>::setInput() { middle_node->value(pull); }
+
+} // namespace test
+} // namespace exo
+
+#endif // __TEST_GRAPH_H__