--- /dev/null
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LOCO_SERVICE_CANONICAL_SHAPE_INFERENCE_RULE_H__
+#define __LOCO_SERVICE_CANONICAL_SHAPE_INFERENCE_RULE_H__
+
+#include "loco/Service/ShapeInferenceRule.h"
+
+namespace loco
+{
+
+/**
+ * @brief Shape inference rule for canonical dialect
+ */
+struct CanonicalShapeInferenceRule final : public ShapeInferenceRule
+{
+ bool recognize(const Dialect *) const final;
+ bool infer(const Node *, NodeShape &) const final;
+};
+
+} // namespace loco
+
+#endif // __LOCO_SERVICE_CANONICAL_SHAPE_INFERENCE_RULE_H__
--- /dev/null
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/Service/CanonicalShapeInferenceRule.h"
+#include "loco/Service/ShapeInference.h"
+
+#include <loco/IR/CanonicalDialect.h>
+#include <loco/IR/CanonicalNode.h>
+#include <loco/IR/CanonicalNodeVisitor.h>
+
+#include <cassert>
+
+namespace
+{
+
+/**
+ * There are two possible maintenance policies.
+ * - Introduce a new canonical node first, and then extend this algorithm later
+ * - Introduce a new canonical node and extend this algorithm at the same time
+ *
+ * The current implementation assumes the former one (for historical reason).
+ *
+ * TODO Evaluate the impact of the latter one
+ *
+ * NOTE "Forward" means that this algorithm computes the ouput shape from inputs shapes
+ */
+class ForwardShapeInferenceAlgorithm final : public loco::CanonicalNodeVisitor<loco::NodeShape>
+{
+public:
+ // TODO Support AvgPool2D
+ // TODO Support BiasEncode
+ // TODO Support ConstGen
+ // TODO Support Conv2D
+ // TODO Support DepthwiseConv2D
+ // TODO Support DepthwiseFilterEncode
+ // TODO Support EltwiseAdd
+ // TODO Support EltwiseMul
+ // TODO Support Forward
+ // TODO Support FeatureBiasAdd
+ // TODO Support FeatureDecode
+ // TODO Support FeatureEncode
+ // TODO Support FilterEncode
+ // TODO Support FixedReshape
+ // TODO Support MaxPool2D
+
+ // CASE: Push
+ loco::NodeShape visit(const loco::Push *node) final
+ {
+ assert(loco::shape_known(node->from()));
+ return loco::shape_get(node->from());
+ }
+
+ // CASE: Pull
+ loco::NodeShape visit(const loco::Pull *node) final
+ {
+ // Build a tensor shape from "Pull" node
+ loco::TensorShape tensor_shape;
+
+ tensor_shape.rank(node->rank());
+ for (uint32_t axis = 0; axis < node->rank(); ++axis)
+ {
+ tensor_shape.dim(axis) = node->dim(axis);
+ }
+
+ return loco::NodeShape{tensor_shape};
+ }
+
+ // TODO Support ReLU
+ // TODO Support ReLU6
+ // TODO Support TensorBiasAdd
+ // TODO SUpport TensorConcat
+};
+
+} // namespace
+
+namespace loco
+{
+
+bool CanonicalShapeInferenceRule::recognize(const Dialect *d) const
+{
+ return CanonicalDialect::get() == d;
+}
+
+bool CanonicalShapeInferenceRule::infer(const Node *node, NodeShape &shape) const
+{
+ assert(node->dialect() == loco::CanonicalDialect::get());
+ assert(dynamic_cast<const loco::CanonicalNode *>(node) != nullptr);
+
+ ForwardShapeInferenceAlgorithm alg;
+ shape = dynamic_cast<const loco::CanonicalNode *>(node)->accept(&alg);
+
+ return true;
+}
+
+} // namespace loco
--- /dev/null
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/Service/CanonicalShapeInferenceRule.h"
+#include "loco/Service/ShapeInference.h"
+
+#include "GraphTestcase.h"
+
+#include <vector>
+
+#include <gtest/gtest.h>
+
+TEST(CanonicalShapeInferenceRuleTest, minimal)
+{
+ // Create a sample network
+ GraphTestcase<GraphCode::Identity> testcase;
+
+ testcase.pull_node->shape({1, 2, 3, 4});
+
+ // Run Inference
+ loco::CanonicalShapeInferenceRule rule;
+
+ loco::apply(&rule).to(testcase.graph());
+
+ // Verify!
+ ASSERT_TRUE(loco::shape_known(testcase.push_node));
+ ASSERT_EQ(loco::shape_get(testcase.push_node).domain(), loco::Domain::Tensor);
+ ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().rank(), 4);
+ ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().dim(0), 1);
+ ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().dim(1), 2);
+ ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().dim(2), 3);
+ ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().dim(3), 4);
+}