#ifndef __NEURUN_GRAPH_SHAPE_INFERENCE_H__
#define __NEURUN_GRAPH_SHAPE_INFERENCE_H__
-#include "model/OperationVisitor.h"
+#include "model/operation/AvgPool2DNode.h"
+#include "model/operation/ConcatNode.h"
+#include "model/operation/MaxPool2DNode.h"
+#include "model/operation/Conv2DNode.h"
+#include "model/operation/DepthwiseConv2DNode.h"
#include "model/Operands.h"
#include "model/Index.h"
#include "model/Layout.h"
const model::operation::AvgPool2DNode::Param ¶m,
model::Layout layout = model::Layout::NHWC);
+Shapes inferConcatShape(const Shapes &in_shapes, const model::operation::ConcatNode::Param ¶m);
+
Shapes inferMaxPoolShape(const model::Shape &in_shape,
const model::operation::MaxPool2DNode::Param ¶m,
model::Layout layout = model::Layout::NHWC);
return {model::Shape{ifm_shape.N, out_h_w.first, out_h_w.second, ifm_shape.C}};
}
+Shapes inferConcatShape(const Shapes &in_shapes, const model::operation::ConcatNode::Param ¶m)
+{
+ const int32_t concat_axis = param.axis;
+ const auto &first_in_shape = in_shapes[0];
+
+ // Check that all shapes are equal except for concat axis dimension
+ for (const auto &in_shape : in_shapes)
+ {
+ assert(in_shape.rank() == first_in_shape.rank());
+ for (int64_t dim_idx = 0; dim_idx < in_shape.rank(); ++dim_idx)
+ assert(dim_idx == concat_axis || in_shape.dim(dim_idx) == first_in_shape.dim(dim_idx));
+ }
+
+ // Calculate output shape
+ model::Shape out_shape(first_in_shape);
+ out_shape.dim(concat_axis) = 0;
+ for (const auto &in_shape : in_shapes)
+ out_shape.dim(concat_axis) += in_shape.dim(concat_axis);
+ return {out_shape};
+}
+
Shapes inferMaxPoolShape(const model::Shape &in_shape,
const model::operation::MaxPool2DNode::Param ¶m,
const model::Layout layout)
ASSERT_EQ(infered_out_shape.asFeature().C, 60);
}
+TEST(ShapeInference, ConcatNode)
+{
+ Shape in1{10, 20, 30, 3, 50};
+ Shape in2{10, 20, 30, 2, 50};
+ Shape in3{10, 20, 30, 2, 50};
+
+ operation::ConcatNode::Param param{3};
+ auto infered_shapes = neurun::shape_inference::inferConcatShape({in1, in2, in3}, param);
+ auto infered_out_shape = infered_shapes[0];
+
+ ASSERT_EQ(infered_out_shape.rank(), 5);
+ ASSERT_EQ(infered_out_shape.dim(0), 10);
+ ASSERT_EQ(infered_out_shape.dim(1), 20);
+ ASSERT_EQ(infered_out_shape.dim(2), 30);
+ ASSERT_EQ(infered_out_shape.dim(3), 7);
+ ASSERT_EQ(infered_out_shape.dim(4), 50);
+}
+
TEST(ShapeInference, FullyConnectedNode)
{
Shape in_shape{3, 4, 5, 6};