[loco] Add CanonicalShapeInferenceRule (#6126)
author박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Fri, 2 Aug 2019 04:28:35 +0000 (13:28 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Fri, 2 Aug 2019 04:28:35 +0000 (13:28 +0900)
This commit adds CanonicalShapeInferenceRule class with minimal
implementation.

The current implementation supports shape inference only for Push and
Pull nodes.

Signed-off-by: Jonghyun Park <jh1302.park@samsung.com>
compiler/loco/include/loco/Service/CanonicalShapeInferenceRule.h [new file with mode: 0644]
compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp [new file with mode: 0644]
compiler/loco/src/Service/CanonicalShapeInferenceRule.test.cpp [new file with mode: 0644]

diff --git a/compiler/loco/include/loco/Service/CanonicalShapeInferenceRule.h b/compiler/loco/include/loco/Service/CanonicalShapeInferenceRule.h
new file mode 100644 (file)
index 0000000..3ef6fee
--- /dev/null
@@ -0,0 +1,36 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LOCO_SERVICE_CANONICAL_SHAPE_INFERENCE_RULE_H__
+#define __LOCO_SERVICE_CANONICAL_SHAPE_INFERENCE_RULE_H__
+
+#include "loco/Service/ShapeInferenceRule.h"
+
+namespace loco
+{
+
+/**
+ * @brief Shape inference rule for canonical dialect
+ */
+struct CanonicalShapeInferenceRule final : public ShapeInferenceRule
+{
+  bool recognize(const Dialect *) const final;
+  bool infer(const Node *, NodeShape &) const final;
+};
+
+} // namespace loco
+
+#endif // __LOCO_SERVICE_CANONICAL_SHAPE_INFERENCE_RULE_H__
diff --git a/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp b/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp
new file mode 100644 (file)
index 0000000..77d887e
--- /dev/null
@@ -0,0 +1,108 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/Service/CanonicalShapeInferenceRule.h"
+#include "loco/Service/ShapeInference.h"
+
+#include <loco/IR/CanonicalDialect.h>
+#include <loco/IR/CanonicalNode.h>
+#include <loco/IR/CanonicalNodeVisitor.h>
+
+#include <cassert>
+
+namespace
+{
+
+/**
+ * There are two possible maintenance policies.
+ * - Introduce a new canonical node first, and then extend this algorithm later
+ * - Introduce a new canonical node and extend this algorithm at the same time
+ *
+ * The current implementation assumes the former one (for historical reason).
+ *
+ * TODO Evaluate the impact of the latter one
+ *
+ * NOTE "Forward" means that this algorithm computes the ouput shape from inputs shapes
+ */
+class ForwardShapeInferenceAlgorithm final : public loco::CanonicalNodeVisitor<loco::NodeShape>
+{
+public:
+  // TODO Support AvgPool2D
+  // TODO Support BiasEncode
+  // TODO Support ConstGen
+  // TODO Support Conv2D
+  // TODO Support DepthwiseConv2D
+  // TODO Support DepthwiseFilterEncode
+  // TODO Support EltwiseAdd
+  // TODO Support EltwiseMul
+  // TODO Support Forward
+  // TODO Support FeatureBiasAdd
+  // TODO Support FeatureDecode
+  // TODO Support FeatureEncode
+  // TODO Support FilterEncode
+  // TODO Support FixedReshape
+  // TODO Support MaxPool2D
+
+  // CASE: Push
+  loco::NodeShape visit(const loco::Push *node) final
+  {
+    assert(loco::shape_known(node->from()));
+    return loco::shape_get(node->from());
+  }
+
+  // CASE: Pull
+  loco::NodeShape visit(const loco::Pull *node) final
+  {
+    // Build a tensor shape from "Pull" node
+    loco::TensorShape tensor_shape;
+
+    tensor_shape.rank(node->rank());
+    for (uint32_t axis = 0; axis < node->rank(); ++axis)
+    {
+      tensor_shape.dim(axis) = node->dim(axis);
+    }
+
+    return loco::NodeShape{tensor_shape};
+  }
+
+  // TODO Support ReLU
+  // TODO Support ReLU6
+  // TODO Support TensorBiasAdd
+  // TODO SUpport TensorConcat
+};
+
+} // namespace
+
+namespace loco
+{
+
+bool CanonicalShapeInferenceRule::recognize(const Dialect *d) const
+{
+  return CanonicalDialect::get() == d;
+}
+
+bool CanonicalShapeInferenceRule::infer(const Node *node, NodeShape &shape) const
+{
+  assert(node->dialect() == loco::CanonicalDialect::get());
+  assert(dynamic_cast<const loco::CanonicalNode *>(node) != nullptr);
+
+  ForwardShapeInferenceAlgorithm alg;
+  shape = dynamic_cast<const loco::CanonicalNode *>(node)->accept(&alg);
+
+  return true;
+}
+
+} // namespace loco
diff --git a/compiler/loco/src/Service/CanonicalShapeInferenceRule.test.cpp b/compiler/loco/src/Service/CanonicalShapeInferenceRule.test.cpp
new file mode 100644 (file)
index 0000000..c0e3b3f
--- /dev/null
@@ -0,0 +1,46 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/Service/CanonicalShapeInferenceRule.h"
+#include "loco/Service/ShapeInference.h"
+
+#include "GraphTestcase.h"
+
+#include <vector>
+
+#include <gtest/gtest.h>
+
+TEST(CanonicalShapeInferenceRuleTest, minimal)
+{
+  // Create a sample network
+  GraphTestcase<GraphCode::Identity> testcase;
+
+  testcase.pull_node->shape({1, 2, 3, 4});
+
+  // Run Inference
+  loco::CanonicalShapeInferenceRule rule;
+
+  loco::apply(&rule).to(testcase.graph());
+
+  // Verify!
+  ASSERT_TRUE(loco::shape_known(testcase.push_node));
+  ASSERT_EQ(loco::shape_get(testcase.push_node).domain(), loco::Domain::Tensor);
+  ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().rank(), 4);
+  ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().dim(0), 1);
+  ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().dim(1), 2);
+  ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().dim(2), 3);
+  ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().dim(3), 4);
+}