From 4d10f31314af33b30225e9767d18a2771ea2f381 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EC=9C=A4=ED=98=84=EC=8B=9D/On-Device=20Lab=28SR=29/Princip?= =?utf8?q?al=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Fri, 27 Sep 2019 13:41:33 +0900 Subject: [PATCH] [loco] TensorTranspose: Shape and Type inference (#7796) * [loco] TensorTranspose: Shape and Type inference This enables shape and type inference for loco::TensorTranspose. A TC for shape inference was also added. Signed-off-by: Hyun Sik Yoon * Transpose -> TensorTranspose --- .../src/Service/CanonicalShapeInferenceRule.cpp | 19 +++++++++ .../Service/CanonicalShapeInferenceRule.test.cpp | 28 ++++++++++++++ compiler/loco/src/Service/GraphTestcase.h | 45 ++++++++++++++++++++++ compiler/loco/src/Service/TypeInference.cpp | 1 + 4 files changed, 93 insertions(+) diff --git a/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp b/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp index a6a46f5..8d37c22 100644 --- a/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp +++ b/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp @@ -596,6 +596,25 @@ public: // CASE: TensorSoftmax loco::NodeShape visit(const loco::TensorSoftmax *node) final { return node_shape(node->input()); } + // CASE: TensorTranspose + loco::NodeShape visit(const loco::TensorTranspose *node) final + { + loco::TensorShape output_shape; + + auto input_shape = node_shape(node->input()).as(); + assert(input_shape.rank() == node->perm()->size()); + + output_shape.rank(input_shape.rank()); + + for (uint32_t output_axis = 0; output_axis < output_shape.rank(); output_axis++) + { + auto new_dim = input_shape.dim(node->perm()->axis(output_axis)); + output_shape.dim(output_axis) = new_dim; + } + + return loco::NodeShape(output_shape); + } + // CASE: TransposedConv2D loco::NodeShape visit(const loco::TransposedConv2D *node) final { diff --git a/compiler/loco/src/Service/CanonicalShapeInferenceRule.test.cpp b/compiler/loco/src/Service/CanonicalShapeInferenceRule.test.cpp index 39cfb3f..5cc8c38 100644 --- a/compiler/loco/src/Service/CanonicalShapeInferenceRule.test.cpp +++ b/compiler/loco/src/Service/CanonicalShapeInferenceRule.test.cpp @@ -316,6 +316,34 @@ TEST(CanonicalShapeInferenceRuleTest, tensor_broadcast) ASSERT_EQ(loco::shape_get(testcase.push_node).as().dim(1), 2); } +TEST(CanonicalShapeInferenceRuleTest, tensor_transpose) +{ + // Create a sample network + GraphTestcase tc; + + tc.pull_node->shape({10, 20, 30, 40}); + + tc.transpose_node->perm()->size(4); + tc.transpose_node->perm()->axis(0) = 2; + tc.transpose_node->perm()->axis(1) = 3; + tc.transpose_node->perm()->axis(2) = 0; + tc.transpose_node->perm()->axis(3) = 1; + + // Run Inference + loco::CanonicalShapeInferenceRule rule; + + loco::apply(&rule).to(tc.graph()); + + // Verify! + ASSERT_TRUE(loco::shape_known(tc.push_node)); + ASSERT_EQ(loco::shape_get(tc.push_node).domain(), loco::Domain::Tensor); + ASSERT_EQ(loco::shape_get(tc.push_node).as().rank(), 4); + ASSERT_EQ(loco::shape_get(tc.push_node).as().dim(0), 30); + ASSERT_EQ(loco::shape_get(tc.push_node).as().dim(1), 40); + ASSERT_EQ(loco::shape_get(tc.push_node).as().dim(2), 10); + ASSERT_EQ(loco::shape_get(tc.push_node).as().dim(3), 20); +} + namespace { diff --git a/compiler/loco/src/Service/GraphTestcase.h b/compiler/loco/src/Service/GraphTestcase.h index 0c8fd45..6743b9a 100644 --- a/compiler/loco/src/Service/GraphTestcase.h +++ b/compiler/loco/src/Service/GraphTestcase.h @@ -36,6 +36,7 @@ enum class GraphCode MaxPool2D, TensorBroadcast, TensorConcat, + TensorTranspose, FixedReshape, }; @@ -493,4 +494,48 @@ private: std::unique_ptr _graph; }; +template <> class GraphTestcase final +{ +public: + GraphTestcase() + { + using namespace loco; + + // Create a sample network + _graph = make_graph(); + + // Create Graph Input/Output + auto graph_input = _graph->inputs()->create(); + auto graph_output = _graph->outputs()->create(); + + graph_input->name("input"); + graph_output->name("output"); + + // Create and connect nodes + pull_node = _graph->nodes()->create(); + pull_node->index(0); + + transpose_node = _graph->nodes()->create(); + transpose_node->input(pull_node); + + push_node = _graph->nodes()->create(); + push_node->index(0); + push_node->from(transpose_node); + + // Create a link between input/output and corresponding nodes + loco::link(graph_input, pull_node); + loco::link(graph_output, push_node); + } + +public: + loco::Graph *graph() { return _graph.get(); } + + loco::Pull *pull_node = nullptr; + loco::TensorTranspose *transpose_node = nullptr; + loco::Push *push_node = nullptr; + +private: + std::unique_ptr _graph; +}; + #endif // __GRAPH_TESTCASE_H__ diff --git a/compiler/loco/src/Service/TypeInference.cpp b/compiler/loco/src/Service/TypeInference.cpp index 8827113..f3bb998 100644 --- a/compiler/loco/src/Service/TypeInference.cpp +++ b/compiler/loco/src/Service/TypeInference.cpp @@ -151,6 +151,7 @@ struct CanonicalTypeForwardAlgorithm final : public loco::CanonicalNodeVisitorinput()); } loco::DataType visit(const loco::TensorReduce *node) { return loco::dtype_get(node->input()); } loco::DataType visit(const loco::TensorSoftmax *node) { return loco::dtype_get(node->input()); } + loco::DataType visit(const loco::TensorTranspose *node) { return loco::dtype_get(node->input()); } loco::DataType visit(const loco::TransposedConv2D *node) { return loco::dtype_get(node->ifm()); } }; -- 2.7.4