From cd23bfe576b01d4533fdde8c8eef315bcc0ab8db Mon Sep 17 00:00:00 2001 From: =?utf8?q?Ivan=20Vagin/AI=20Tools=20Lab=20/SRR/Engineer/=EC=82=BC?= =?utf8?q?=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Wed, 17 Jul 2019 12:29:33 +0300 Subject: [PATCH] [shape_inference] Implemened shape inference for fully connected node (#5671) Implemened shape inference for fully connected node Signed-off-by: Ivan Vagin --- runtimes/neurun/core/include/util/ShapeInference.h | 2 ++ runtimes/neurun/core/src/util/ShapeInference.cc | 14 ++++++++++++++ runtimes/neurun/test/util/ShapeInference.cc | 12 ++++++++++++ 3 files changed, 28 insertions(+) diff --git a/runtimes/neurun/core/include/util/ShapeInference.h b/runtimes/neurun/core/include/util/ShapeInference.h index 5a900c1..e3722d2 100644 --- a/runtimes/neurun/core/include/util/ShapeInference.h +++ b/runtimes/neurun/core/include/util/ShapeInference.h @@ -35,6 +35,8 @@ Shapes inferAvgPoolShape(const model::Shape &in_shape, const model::operation::AvgPool2DNode::Param ¶m, model::Layout layout = model::Layout::NHWC); +Shapes inferFCShape(const model::Shape &in_shape, const model::Shape &ker_shape); + } // namespace shape_inference } // namespace neurun diff --git a/runtimes/neurun/core/src/util/ShapeInference.cc b/runtimes/neurun/core/src/util/ShapeInference.cc index 4a0b998..d859afe 100644 --- a/runtimes/neurun/core/src/util/ShapeInference.cc +++ b/runtimes/neurun/core/src/util/ShapeInference.cc @@ -116,5 +116,19 @@ Shapes inferAvgPoolShape(const model::Shape &in_shape, return {model::Shape{ifm_shape.N, out_h_w.first, out_h_w.second, ifm_shape.C}}; } +Shapes inferFCShape(const model::Shape &in_shape, const model::Shape &ker_shape) +{ + assert(in_shape.rank() >= 2); + assert(ker_shape.rank() == 2); + + const auto input_size_with_batch = in_shape.num_elements(); + const auto num_units = ker_shape.dim(0); + const auto input_size = ker_shape.dim(1); + const auto batch_size = input_size_with_batch / input_size; + assert(input_size_with_batch % input_size == 0); + + return {{model::Shape({static_cast(batch_size), num_units})}}; +} + } // namespace shape_inference } // namespace neurun diff --git a/runtimes/neurun/test/util/ShapeInference.cc b/runtimes/neurun/test/util/ShapeInference.cc index 550d02b..4186af1 100644 --- a/runtimes/neurun/test/util/ShapeInference.cc +++ b/runtimes/neurun/test/util/ShapeInference.cc @@ -94,3 +94,15 @@ TEST(ShapeInference, AvgPool2DNodeExplicit) ASSERT_EQ(infered_out_shape.asFeature().W, 1); ASSERT_EQ(infered_out_shape.asFeature().C, 20); } + +TEST(ShapeInference, FullyConnectedNode) +{ + Shape in_shape{3, 4, 5, 6}; + Shape ker_shape{3, 10}; + auto infered_shapes = neurun::shape_inference::inferFCShape(in_shape, ker_shape); + auto infered_out_shape = infered_shapes[0]; + + ASSERT_EQ(infered_out_shape.rank(), 2); + ASSERT_EQ(infered_out_shape.dim(0), 36); + ASSERT_EQ(infered_out_shape.dim(1), 3); +} -- 2.7.4