From e11541ef68adf3ea7f3cc5f93b9d2b2fa3a88bc4 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=A2=85=ED=98=84/On-Device=20Lab=28SR=29/Staff?= =?utf8?q?=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Wed, 7 Aug 2019 17:18:58 +0900 Subject: [PATCH] [loco] TensorConcat Shape Inference (#6315) * [loco] TensorConcat Shape Inference CanonicalShapeInferenceRule is now able to infer the shape of TensorConcat nodes. Signed-off-by: Jonghyun Park * Remove invalid comments --- .../src/Service/CanonicalShapeInferenceRule.cpp | 30 +++++++++++- .../Service/CanonicalShapeInferenceRule.test.cpp | 25 ++++++++++ compiler/loco/src/Service/GraphTestcase.h | 53 ++++++++++++++++++++++ 3 files changed, 107 insertions(+), 1 deletion(-) diff --git a/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp b/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp index 51b6315..b1e49af 100644 --- a/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp +++ b/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp @@ -283,7 +283,35 @@ public: loco::NodeShape visit(const loco::ReLU6 *node) final { return loco::shape_get(node->input()); } // TODO Support TensorBiasAdd - // TODO SUpport TensorConcat + + // CASE: TensorConcat + loco::NodeShape visit(const loco::TensorConcat *node) + { + auto const lhs_shape = loco::shape_get(node->lhs()).as(); + auto const rhs_shape = loco::shape_get(node->rhs()).as(); + + assert(lhs_shape.rank() == rhs_shape.rank()); + uint32_t const out_rank = lhs_shape.rank(); + + loco::TensorShape out_shape; + + out_shape.rank(out_rank); + + for (uint32_t axis = 0; axis < out_rank; ++axis) + { + if (axis == node->axis()) + { + out_shape.dim(axis) = lhs_shape.dim(axis).value() + rhs_shape.dim(axis).value(); + } + else + { + assert(lhs_shape.dim(axis) == rhs_shape.dim(axis)); + out_shape.dim(axis) = lhs_shape.dim(axis); + } + } + + return loco::NodeShape{out_shape}; + } }; } // namespace diff --git a/compiler/loco/src/Service/CanonicalShapeInferenceRule.test.cpp b/compiler/loco/src/Service/CanonicalShapeInferenceRule.test.cpp index aff8937..05069d5 100644 --- a/compiler/loco/src/Service/CanonicalShapeInferenceRule.test.cpp +++ b/compiler/loco/src/Service/CanonicalShapeInferenceRule.test.cpp @@ -181,3 +181,28 @@ TEST(CanonicalShapeInferenceRuleTest, maxpool2d) ASSERT_EQ(loco::shape_get(testcase.maxpool2d_node).as().height(), 4); ASSERT_EQ(loco::shape_get(testcase.maxpool2d_node).as().width(), 2); } + +TEST(CanonicalShapeInferenceRuleTest, tensor_concat) +{ + using namespace loco; + + // Create a sample network + GraphTestcase testcase; + + testcase.lhs_node->shape({1, 2, 3}); + testcase.rhs_node->shape({1, 4, 3}); + testcase.concat_node->axis(1); + + // Run Inference + loco::CanonicalShapeInferenceRule rule; + + loco::apply(&rule).to(testcase.graph()); + + // Verify! + ASSERT_TRUE(loco::shape_known(testcase.concat_node)); + ASSERT_EQ(loco::shape_get(testcase.concat_node).domain(), loco::Domain::Tensor); + ASSERT_EQ(loco::shape_get(testcase.concat_node).as().rank(), 3); + ASSERT_EQ(loco::shape_get(testcase.concat_node).as().dim(0), 1); + ASSERT_EQ(loco::shape_get(testcase.concat_node).as().dim(1), 6); + ASSERT_EQ(loco::shape_get(testcase.concat_node).as().dim(2), 3); +} diff --git a/compiler/loco/src/Service/GraphTestcase.h b/compiler/loco/src/Service/GraphTestcase.h index 13c4e94..e3fce2b 100644 --- a/compiler/loco/src/Service/GraphTestcase.h +++ b/compiler/loco/src/Service/GraphTestcase.h @@ -14,6 +14,7 @@ enum class GraphCode FeatureCodec, AvgPool2D, MaxPool2D, + TensorConcat, }; template class GraphTestcase; @@ -314,6 +315,58 @@ private: std::unique_ptr _graph; }; +template <> class GraphTestcase final +{ +public: + GraphTestcase() + { + using namespace loco; + + // Create a sample network + _graph = make_graph(); + + // Create Graph Input/Output + auto graph_lhs = _graph->inputs()->create(); + auto graph_rhs = _graph->inputs()->create(); + auto graph_out = _graph->outputs()->create(); + + graph_lhs->name("lhs"); + graph_rhs->name("rhs"); + graph_out->name("output"); + + // Create and connect nodes + lhs_node = _graph->nodes()->create(); + lhs_node->index(0); + + rhs_node = _graph->nodes()->create(); + rhs_node->index(1); + + concat_node = _graph->nodes()->create(); + concat_node->lhs(lhs_node); + concat_node->rhs(rhs_node); + + push_node = _graph->nodes()->create(); + push_node->index(0); + push_node->from(concat_node); + + // Create a link between input/output and corresponding nodes + loco::link(graph_lhs, lhs_node); + loco::link(graph_rhs, rhs_node); + loco::link(graph_out, push_node); + } + +public: + loco::Graph *graph() { return _graph.get(); } + + loco::Pull *lhs_node = nullptr; + loco::Pull *rhs_node = nullptr; + loco::TensorConcat *concat_node = nullptr; + loco::Push *push_node = nullptr; + +private: + std::unique_ptr _graph; +}; + namespace { -- 2.7.4