[exo-tflite] Introducing TFLShapeInference (#6991)
author윤현식/On-Device Lab(SR)/Principal Engineer/삼성전자 <hyunsik.yoon@samsung.com>
Thu, 29 Aug 2019 08:19:11 +0000 (17:19 +0900)
committer박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Thu, 29 Aug 2019 08:19:11 +0000 (17:19 +0900)
* [exo-tflite] Introducing TFLShapeInference

This adds TFLShapeInference that infer the shape of TFL nodes and TFLShapeAnnot that passes shape from canonical node.

Signed-off-by: Hyun Sik Yoon <hyunsik.yoon@samsung.com>
* adding missing graph input

* fix typo

* remove Forward word from internal class naming

compiler/exo-tflite/src/Dialect/Service/TFLShapeAnnot.h [new file with mode: 0644]
compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.cpp [new file with mode: 0644]
compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.h [new file with mode: 0644]
compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.test.cpp [new file with mode: 0644]

diff --git a/compiler/exo-tflite/src/Dialect/Service/TFLShapeAnnot.h b/compiler/exo-tflite/src/Dialect/Service/TFLShapeAnnot.h
new file mode 100644 (file)
index 0000000..32d7e82
--- /dev/null
@@ -0,0 +1,43 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LOCOEX_SERVICE_TFLSHAPE_ANNOT__
+#define __LOCOEX_SERVICE_TFLSHAPE_ANNOT__
+
+#include <loco/IR/Node.h>
+#include <loco/IR/TensorShape.h>
+
+namespace locoex
+{
+
+/**
+ * @brief Class to annotate shape to a TFL node
+ */
+struct TFLShapeAnnot : public loco::NodeAnnotation
+{
+public:
+  TFLShapeAnnot(const loco::TensorShape &shape) : _shape(shape) {}
+
+public:
+  const loco::TensorShape shape(void) const { return _shape; }
+
+private:
+  const loco::TensorShape _shape;
+};
+
+} // namespace locoex
+
+#endif // __LOCOEX_SERVICE_TFLSHAPE_ANNOT__
diff --git a/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.cpp b/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.cpp
new file mode 100644 (file)
index 0000000..92615a8
--- /dev/null
@@ -0,0 +1,78 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TFLShapeInferenceRule.h"
+#include "TFLShapeAnnot.h"
+
+#include "Dialect/IR/TFLNodes.h"
+#include "Dialect/IR/TFLDialect.h"
+#include "Dialect/IR/TFLNodeVisitor.h"
+
+#include "ShapeInference.h"
+
+#include <cassert>
+
+namespace
+{
+
+/**
+ * @brief Class to infer the shape of TFLNode
+ *
+ * @note All TFLNode's inputs and outouts are always loco::Domain::Tensor
+ */
+class ShapeInferenceAlgorithm final : public locoex::TFLNodeVisitor<loco::NodeShape>
+{
+public:
+  loco::NodeShape visit(const locoex::TFLNode *node) final
+  {
+    if (loco::shape_known(node)) // if shape was already inferred by inference rule
+    {
+      assert(loco::shape_get(node).domain() == loco::Domain::Tensor);
+
+      return loco::shape_get(node);
+    }
+    else
+    { // getting Shape data that was annotated while converting canonical node to TFLNode
+      auto shape_hint = node->annot<locoex::TFLShapeAnnot>();
+      assert(shape_hint != nullptr);
+
+      return shape_hint->shape();
+    }
+  }
+};
+
+} // namespace
+
+namespace locoex
+{
+
+bool TFLShapeInferenceRule::recognize(const loco::Dialect *d) const
+{
+  return TFLDialect::get() == d;
+}
+
+bool TFLShapeInferenceRule::infer(const loco::Node *node, loco::NodeShape &shape) const
+{
+  assert(node->dialect() == TFLDialect::get());
+  assert(dynamic_cast<const TFLNode *>(node) != nullptr);
+
+  ShapeInferenceAlgorithm alg;
+  shape = dynamic_cast<const TFLNode *>(node)->accept(&alg);
+
+  return true;
+}
+
+} // namespace locoex
diff --git a/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.h b/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.h
new file mode 100644 (file)
index 0000000..434a145
--- /dev/null
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LOCOEX_SERVICE_TFLSHAPE_INFERENCE_RULE_H__
+#define __LOCOEX_SERVICE_TFLSHAPE_INFERENCE_RULE_H__
+
+#include <loco/Service/ShapeInference.h>
+
+namespace locoex
+{
+
+struct TFLShapeInferenceRule final : public loco::ShapeInferenceRule
+{
+  bool recognize(const loco::Dialect *) const final;
+  bool infer(const loco::Node *, loco::NodeShape &) const final;
+};
+
+} // namespace locoex
+
+#endif // __LOCOEX_SERVICE_TFLSHAPE_INFERENCE_RULE_H__
diff --git a/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.test.cpp b/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.test.cpp
new file mode 100644 (file)
index 0000000..de2e966
--- /dev/null
@@ -0,0 +1,90 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Dialect/IR/TFLNodes.h"
+#include "Dialect/Service/TFLShapeAnnot.h"
+#include "Dialect/Service/TFLShapeInferenceRule.h"
+
+#include <loco.h>
+#include <loco/Service/ShapeInference.h>
+
+#include <stdex/Memory.h>
+
+#include <gtest/gtest.h>
+
+TEST(TFLShapeInferenceRuleTest, minimal_with_TFLRelu)
+{
+  // Create a simple network
+  auto g = loco::make_graph();
+
+  auto pull_node = g->nodes()->create<loco::Pull>();
+
+  auto tfl_node = g->nodes()->create<locoex::TFLRelu>();
+  tfl_node->input(pull_node);
+
+  auto push_node = g->nodes()->create<loco::Push>();
+  push_node->from(tfl_node);
+
+  auto input = g->inputs()->create();
+  {
+    input->name("input");
+    loco::link(input, pull_node);
+  }
+  auto output = g->outputs()->create();
+  {
+    output->name("output");
+    loco::link(output, push_node);
+  }
+
+  // pre-check
+  ASSERT_FALSE(loco::shape_known(tfl_node));
+
+  // scenario.
+  // step 1. add annotation and run shape inference.
+  //         TFLShapeInference will get shape info from annotated data
+  // step 2. then, run shape inference again
+
+  // step 1.
+  loco::TensorShape ts;
+  {
+    ts.rank(2);
+    ts.dim(0) = 1;
+    ts.dim(1) = 3;
+  }
+  auto shape_annot = stdex::make_unique<locoex::TFLShapeAnnot>(ts);
+  tfl_node->annot<locoex::TFLShapeAnnot>(std::move(shape_annot));
+
+  locoex::TFLShapeInferenceRule tfl_rule;
+  loco::apply(&tfl_rule).to(g.get());
+
+  // Verify
+  auto check_shape = [](locoex::TFLRelu *tfl_node) {
+    ASSERT_TRUE(loco::shape_known(tfl_node));
+    ASSERT_EQ(loco::shape_get(tfl_node).domain(), loco::Domain::Tensor);
+
+    auto shape = loco::shape_get(tfl_node).as<loco::TensorShape>();
+    ASSERT_EQ(shape.rank(), 2);
+    ASSERT_EQ(shape.dim(0), 1);
+    ASSERT_EQ(shape.dim(1), 3);
+  };
+
+  check_shape(tfl_node);
+
+  // step 2.
+  loco::apply(&tfl_rule).to(g.get());
+
+  check_shape(tfl_node);
+}