[loco] TensorTranspose: Shape and Type inference (#7796)
author윤현식/On-Device Lab(SR)/Principal Engineer/삼성전자 <hyunsik.yoon@samsung.com>
Fri, 27 Sep 2019 04:41:33 +0000 (13:41 +0900)
committer박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Fri, 27 Sep 2019 04:41:33 +0000 (13:41 +0900)
* [loco] TensorTranspose: Shape and Type inference

This enables shape and type inference for loco::TensorTranspose. A TC for shape inference was also added.

Signed-off-by: Hyun Sik Yoon <hyunsik.yoon@samsung.com>
* Transpose -> TensorTranspose

compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp
compiler/loco/src/Service/CanonicalShapeInferenceRule.test.cpp
compiler/loco/src/Service/GraphTestcase.h
compiler/loco/src/Service/TypeInference.cpp

index a6a46f5..8d37c22 100644 (file)
@@ -596,6 +596,25 @@ public:
   // CASE: TensorSoftmax
   loco::NodeShape visit(const loco::TensorSoftmax *node) final { return node_shape(node->input()); }
 
+  // CASE: TensorTranspose
+  loco::NodeShape visit(const loco::TensorTranspose *node) final
+  {
+    loco::TensorShape output_shape;
+
+    auto input_shape = node_shape(node->input()).as<loco::TensorShape>();
+    assert(input_shape.rank() == node->perm()->size());
+
+    output_shape.rank(input_shape.rank());
+
+    for (uint32_t output_axis = 0; output_axis < output_shape.rank(); output_axis++)
+    {
+      auto new_dim = input_shape.dim(node->perm()->axis(output_axis));
+      output_shape.dim(output_axis) = new_dim;
+    }
+
+    return loco::NodeShape(output_shape);
+  }
+
   // CASE: TransposedConv2D
   loco::NodeShape visit(const loco::TransposedConv2D *node) final
   {
index 39cfb3f..5cc8c38 100644 (file)
@@ -316,6 +316,34 @@ TEST(CanonicalShapeInferenceRuleTest, tensor_broadcast)
   ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().dim(1), 2);
 }
 
+TEST(CanonicalShapeInferenceRuleTest, tensor_transpose)
+{
+  // Create a sample network
+  GraphTestcase<GraphCode::TensorTranspose> tc;
+
+  tc.pull_node->shape({10, 20, 30, 40});
+
+  tc.transpose_node->perm()->size(4);
+  tc.transpose_node->perm()->axis(0) = 2;
+  tc.transpose_node->perm()->axis(1) = 3;
+  tc.transpose_node->perm()->axis(2) = 0;
+  tc.transpose_node->perm()->axis(3) = 1;
+
+  // Run Inference
+  loco::CanonicalShapeInferenceRule rule;
+
+  loco::apply(&rule).to(tc.graph());
+
+  // Verify!
+  ASSERT_TRUE(loco::shape_known(tc.push_node));
+  ASSERT_EQ(loco::shape_get(tc.push_node).domain(), loco::Domain::Tensor);
+  ASSERT_EQ(loco::shape_get(tc.push_node).as<loco::TensorShape>().rank(), 4);
+  ASSERT_EQ(loco::shape_get(tc.push_node).as<loco::TensorShape>().dim(0), 30);
+  ASSERT_EQ(loco::shape_get(tc.push_node).as<loco::TensorShape>().dim(1), 40);
+  ASSERT_EQ(loco::shape_get(tc.push_node).as<loco::TensorShape>().dim(2), 10);
+  ASSERT_EQ(loco::shape_get(tc.push_node).as<loco::TensorShape>().dim(3), 20);
+}
+
 namespace
 {
 
index 0c8fd45..6743b9a 100644 (file)
@@ -36,6 +36,7 @@ enum class GraphCode
   MaxPool2D,
   TensorBroadcast,
   TensorConcat,
+  TensorTranspose,
   FixedReshape,
 };
 
@@ -493,4 +494,48 @@ private:
   std::unique_ptr<loco::Graph> _graph;
 };
 
+template <> class GraphTestcase<GraphCode::TensorTranspose> final
+{
+public:
+  GraphTestcase()
+  {
+    using namespace loco;
+
+    // Create a sample network
+    _graph = make_graph();
+
+    // Create Graph Input/Output
+    auto graph_input = _graph->inputs()->create();
+    auto graph_output = _graph->outputs()->create();
+
+    graph_input->name("input");
+    graph_output->name("output");
+
+    // Create and connect nodes
+    pull_node = _graph->nodes()->create<Pull>();
+    pull_node->index(0);
+
+    transpose_node = _graph->nodes()->create<TensorTranspose>();
+    transpose_node->input(pull_node);
+
+    push_node = _graph->nodes()->create<loco::Push>();
+    push_node->index(0);
+    push_node->from(transpose_node);
+
+    // Create a link between input/output and corresponding nodes
+    loco::link(graph_input, pull_node);
+    loco::link(graph_output, push_node);
+  }
+
+public:
+  loco::Graph *graph() { return _graph.get(); }
+
+  loco::Pull *pull_node = nullptr;
+  loco::TensorTranspose *transpose_node = nullptr;
+  loco::Push *push_node = nullptr;
+
+private:
+  std::unique_ptr<loco::Graph> _graph;
+};
+
 #endif // __GRAPH_TESTCASE_H__
index 8827113..f3bb998 100644 (file)
@@ -151,6 +151,7 @@ struct CanonicalTypeForwardAlgorithm final : public loco::CanonicalNodeVisitor<l
   loco::DataType visit(const loco::TensorBroadcast *node) { return loco::dtype_get(node->input()); }
   loco::DataType visit(const loco::TensorReduce *node) { return loco::dtype_get(node->input()); }
   loco::DataType visit(const loco::TensorSoftmax *node) { return loco::dtype_get(node->input()); }
+  loco::DataType visit(const loco::TensorTranspose *node) { return loco::dtype_get(node->input()); }
   loco::DataType visit(const loco::TransposedConv2D *node) { return loco::dtype_get(node->ifm()); }
 };