[loco] TensorConcat Shape Inference (#6315)
author박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Wed, 7 Aug 2019 08:18:58 +0000 (17:18 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Wed, 7 Aug 2019 08:18:58 +0000 (17:18 +0900)
* [loco] TensorConcat Shape Inference

CanonicalShapeInferenceRule is now able to infer the shape of TensorConcat
nodes.

Signed-off-by: Jonghyun Park <jh1302.park@samsung.com>
* Remove invalid comments

compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp
compiler/loco/src/Service/CanonicalShapeInferenceRule.test.cpp
compiler/loco/src/Service/GraphTestcase.h

index 51b6315..b1e49af 100644 (file)
@@ -283,7 +283,35 @@ public:
   loco::NodeShape visit(const loco::ReLU6 *node) final { return loco::shape_get(node->input()); }
 
   // TODO Support TensorBiasAdd
-  // TODO SUpport TensorConcat
+
+  // CASE: TensorConcat
+  loco::NodeShape visit(const loco::TensorConcat *node)
+  {
+    auto const lhs_shape = loco::shape_get(node->lhs()).as<loco::TensorShape>();
+    auto const rhs_shape = loco::shape_get(node->rhs()).as<loco::TensorShape>();
+
+    assert(lhs_shape.rank() == rhs_shape.rank());
+    uint32_t const out_rank = lhs_shape.rank();
+
+    loco::TensorShape out_shape;
+
+    out_shape.rank(out_rank);
+
+    for (uint32_t axis = 0; axis < out_rank; ++axis)
+    {
+      if (axis == node->axis())
+      {
+        out_shape.dim(axis) = lhs_shape.dim(axis).value() + rhs_shape.dim(axis).value();
+      }
+      else
+      {
+        assert(lhs_shape.dim(axis) == rhs_shape.dim(axis));
+        out_shape.dim(axis) = lhs_shape.dim(axis);
+      }
+    }
+
+    return loco::NodeShape{out_shape};
+  }
 };
 
 } // namespace
index aff8937..05069d5 100644 (file)
@@ -181,3 +181,28 @@ TEST(CanonicalShapeInferenceRuleTest, maxpool2d)
   ASSERT_EQ(loco::shape_get(testcase.maxpool2d_node).as<FeatureShape>().height(), 4);
   ASSERT_EQ(loco::shape_get(testcase.maxpool2d_node).as<FeatureShape>().width(), 2);
 }
+
+TEST(CanonicalShapeInferenceRuleTest, tensor_concat)
+{
+  using namespace loco;
+
+  // Build a sample network: concatenation of two rank-3 tensors along axis 1
+  GraphTestcase<GraphCode::TensorConcat> testcase;
+
+  testcase.lhs_node->shape({1, 2, 3});
+  testcase.rhs_node->shape({1, 4, 3});
+  testcase.concat_node->axis(1);
+
+  // Run shape inference
+  CanonicalShapeInferenceRule rule;
+  apply(&rule).to(testcase.graph());
+
+  // Verify! Expected output shape: (1, 2, 3) ++ (1, 4, 3) -> (1, 6, 3)
+  ASSERT_TRUE(shape_known(testcase.concat_node));
+  ASSERT_EQ(shape_get(testcase.concat_node).domain(), Domain::Tensor);
+
+  auto const concat_shape = shape_get(testcase.concat_node).as<TensorShape>();
+  ASSERT_EQ(concat_shape.rank(), 3);
+  ASSERT_EQ(concat_shape.dim(0), 1);
+  ASSERT_EQ(concat_shape.dim(1), 6);
+  ASSERT_EQ(concat_shape.dim(2), 3);
+}
index 13c4e94..e3fce2b 100644 (file)
@@ -14,6 +14,7 @@ enum class GraphCode
   FeatureCodec,
   AvgPool2D,
   MaxPool2D,
+  TensorConcat,
 };
 
 template <GraphCode Code> class GraphTestcase;
@@ -314,6 +315,58 @@ private:
   std::unique_ptr<loco::Graph> _graph;
 };
 
+template <> class GraphTestcase<GraphCode::TensorConcat> final
+{
+public:
+  GraphTestcase()
+  {
+    using namespace loco;
+
+    // Create a sample network
+    _graph = make_graph();
+
+    // Create Graph Input/Output
+    auto graph_lhs = _graph->inputs()->create();
+    auto graph_rhs = _graph->inputs()->create();
+    auto graph_out = _graph->outputs()->create();
+
+    graph_lhs->name("lhs");
+    graph_rhs->name("rhs");
+    graph_out->name("output");
+
+    // Create and connect nodes
+    lhs_node = _graph->nodes()->create<Pull>();
+    lhs_node->index(0);
+
+    rhs_node = _graph->nodes()->create<Pull>();
+    rhs_node->index(1);
+
+    concat_node = _graph->nodes()->create<TensorConcat>();
+    concat_node->lhs(lhs_node);
+    concat_node->rhs(rhs_node);
+
+    push_node = _graph->nodes()->create<loco::Push>();
+    push_node->index(0);
+    push_node->from(concat_node);
+
+    // Create a link between input/output and corresponding nodes
+    loco::link(graph_lhs, lhs_node);
+    loco::link(graph_rhs, rhs_node);
+    loco::link(graph_out, push_node);
+  }
+
+public:
+  loco::Graph *graph() { return _graph.get(); }
+
+  loco::Pull *lhs_node = nullptr;
+  loco::Pull *rhs_node = nullptr;
+  loco::TensorConcat *concat_node = nullptr;
+  loco::Push *push_node = nullptr;
+
+private:
+  std::unique_ptr<loco::Graph> _graph;
+};
+
 namespace
 {