[shape_inference] Implemented shape inference for concat node (#5669)
author Ivan Vagin/AI Tools Lab/SRR/Engineer/Samsung Electronics <ivan.vagin@samsung.com>
Mon, 22 Jul 2019 04:11:49 +0000 (07:11 +0300)
committer Hanjoung Lee/On-Device Lab(SR)/Engineer/Samsung Electronics <hanjoung.lee@samsung.com>
Mon, 22 Jul 2019 04:11:49 +0000 (13:11 +0900)
* [shape_inference] Implemented shape inference for concat node

Implemented shape inference for concat node

Signed-off-by: Ivan Vagin <ivan.vagin@samsung.com>
* Fixed formatting

* Use ConcatNode::Param to get concat axis

Signed-off-by: Ivan Vagin <ivan.vagin@samsung.com>
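For reference, the rule implemented here: all inputs must have the same rank and match on every dimension except the concat axis; the output copies the first input's shape and replaces the concat-axis dimension with the sum of that dimension over all inputs. Below is a minimal usage sketch of the new function, not part of the patch; the include path, namespaces, and types are assumed from the header and test changes in this commit, and the shapes mirror the new unit test.

#include <cassert>
#include "util/ShapeInference.h" // assumed include path (runtimes/neurun/core/include on the include path)

// Namespaces assumed from the header/test below
using neurun::model::Shape;
using neurun::model::operation::ConcatNode;

void concatShapeSketch()
{
  // Three 5-D inputs that agree on every dimension except axis 3
  Shape in1{10, 20, 30, 3, 50};
  Shape in2{10, 20, 30, 2, 50};
  Shape in3{10, 20, 30, 2, 50};

  ConcatNode::Param param{3}; // concatenate along axis 3

  // Output keeps all other dimensions and sums the concat axis:
  // {10, 20, 30, 3 + 2 + 2, 50} == {10, 20, 30, 7, 50}
  auto shapes = neurun::shape_inference::inferConcatShape({in1, in2, in3}, param);
  assert(shapes[0].dim(3) == 7);
}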
runtimes/neurun/core/include/util/ShapeInference.h
runtimes/neurun/core/src/util/ShapeInference.cc
runtimes/neurun/test/util/ShapeInference.cc

index 1ead20f..5407619 100644
--- a/runtimes/neurun/core/include/util/ShapeInference.h
+++ b/runtimes/neurun/core/include/util/ShapeInference.h
 #ifndef __NEURUN_GRAPH_SHAPE_INFERENCE_H__
 #define __NEURUN_GRAPH_SHAPE_INFERENCE_H__
 
-#include "model/OperationVisitor.h"
+#include "model/operation/AvgPool2DNode.h"
+#include "model/operation/ConcatNode.h"
+#include "model/operation/MaxPool2DNode.h"
+#include "model/operation/Conv2DNode.h"
+#include "model/operation/DepthwiseConv2DNode.h"
 #include "model/Operands.h"
 #include "model/Index.h"
 #include "model/Layout.h"
@@ -35,6 +39,8 @@ Shapes inferAvgPoolShape(const model::Shape &in_shape,
                          const model::operation::AvgPool2DNode::Param &param,
                          model::Layout layout = model::Layout::NHWC);
 
+Shapes inferConcatShape(const Shapes &in_shapes, const model::operation::ConcatNode::Param &param);
+
 Shapes inferMaxPoolShape(const model::Shape &in_shape,
                          const model::operation::MaxPool2DNode::Param &param,
                          model::Layout layout = model::Layout::NHWC);
index 353b982..f0138a7 100644
--- a/runtimes/neurun/core/src/util/ShapeInference.cc
+++ b/runtimes/neurun/core/src/util/ShapeInference.cc
@@ -117,6 +117,27 @@ Shapes inferAvgPoolShape(const model::Shape &in_shape,
   return {model::Shape{ifm_shape.N, out_h_w.first, out_h_w.second, ifm_shape.C}};
 }
 
+Shapes inferConcatShape(const Shapes &in_shapes, const model::operation::ConcatNode::Param &param)
+{
+  const int32_t concat_axis = param.axis;
+  const auto &first_in_shape = in_shapes[0];
+
+  // Check that all shapes are equal except for concat axis dimension
+  for (const auto &in_shape : in_shapes)
+  {
+    assert(in_shape.rank() == first_in_shape.rank());
+    for (int64_t dim_idx = 0; dim_idx < in_shape.rank(); ++dim_idx)
+      assert(dim_idx == concat_axis || in_shape.dim(dim_idx) == first_in_shape.dim(dim_idx));
+  }
+
+  // Calculate output shape
+  model::Shape out_shape(first_in_shape);
+  out_shape.dim(concat_axis) = 0;
+  for (const auto &in_shape : in_shapes)
+    out_shape.dim(concat_axis) += in_shape.dim(concat_axis);
+  return {out_shape};
+}
+
 Shapes inferMaxPoolShape(const model::Shape &in_shape,
                          const model::operation::MaxPool2DNode::Param &param,
                          const model::Layout layout)
index ee15535..3f43cfe 100644
--- a/runtimes/neurun/test/util/ShapeInference.cc
+++ b/runtimes/neurun/test/util/ShapeInference.cc
@@ -202,6 +202,24 @@ TEST(ShapeInference, DepthwiseConv2DNode)
   ASSERT_EQ(infered_out_shape.asFeature().C, 60);
 }
 
+TEST(ShapeInference, ConcatNode)
+{
+  Shape in1{10, 20, 30, 3, 50};
+  Shape in2{10, 20, 30, 2, 50};
+  Shape in3{10, 20, 30, 2, 50};
+
+  operation::ConcatNode::Param param{3};
+  auto infered_shapes = neurun::shape_inference::inferConcatShape({in1, in2, in3}, param);
+  auto infered_out_shape = infered_shapes[0];
+
+  ASSERT_EQ(infered_out_shape.rank(), 5);
+  ASSERT_EQ(infered_out_shape.dim(0), 10);
+  ASSERT_EQ(infered_out_shape.dim(1), 20);
+  ASSERT_EQ(infered_out_shape.dim(2), 30);
+  ASSERT_EQ(infered_out_shape.dim(3), 7);
+  ASSERT_EQ(infered_out_shape.dim(4), 50);
+}
+
 TEST(ShapeInference, FullyConnectedNode)
 {
   Shape in_shape{3, 4, 5, 6};