[loco/Service] FixedReshape shape inference (#6594)
author박천교/On-Device Lab(SR)/Engineer/삼성전자 <ch.bahk@samsung.com>
Wed, 14 Aug 2019 07:29:40 +0000 (16:29 +0900)
committer박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Wed, 14 Aug 2019 07:29:40 +0000 (16:29 +0900)
This commit introduces shape inference logic of FixedReshape and its
test

Signed-off-by: Cheongyo Bahk <ch.bahk@samsung.com>
compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp
compiler/loco/src/Service/CanonicalShapeInferenceRule.test.cpp

index 243693a..440a35d 100644 (file)
@@ -309,7 +309,19 @@ public:
     return loco::NodeShape{node->encoder()->shape(input_tensor_shape)};
   }
 
-  // TODO Support FixedReshape
+  // CASE: FixedReshape
+  loco::NodeShape visit(const loco::FixedReshape *node) final
+  {
+    loco::TensorShape tensor_shape;
+
+    tensor_shape.rank(node->rank());
+    for (uint32_t axis = 0; axis < node->rank(); ++axis)
+    {
+      tensor_shape.dim(axis) = node->dim(axis);
+    }
+
+    return loco::NodeShape{tensor_shape};
+  }
 
   // CASE: MaxPool2D
   loco::NodeShape visit(const loco::MaxPool2D *node) final
index 05069d5..97114d3 100644 (file)
@@ -206,3 +206,24 @@ TEST(CanonicalShapeInferenceRuleTest, tensor_concat)
   ASSERT_EQ(loco::shape_get(testcase.concat_node).as<TensorShape>().dim(1), 6);
   ASSERT_EQ(loco::shape_get(testcase.concat_node).as<TensorShape>().dim(2), 3);
 }
+
+TEST(CanonicalShapeInferenceRuleTest, fixed_reshape)
+{
+  // Create a sample network
+  GraphTestcase<GraphCode::FixedReshape> testcase;
+
+  testcase.pull_node->shape({6, 6});
+  testcase.reshape_node->shape({4, 9});
+
+  // Run Inference
+  loco::CanonicalShapeInferenceRule rule;
+
+  loco::apply(&rule).to(testcase.graph());
+
+  // Verify!
+  ASSERT_TRUE(loco::shape_known(testcase.push_node));
+  ASSERT_EQ(loco::shape_get(testcase.push_node).domain(), loco::Domain::Tensor);
+  ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().rank(), 2);
+  ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().dim(0), 4);
+  ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().dim(1), 9);
+}