From ff6ac445291edfa8db8faaaa5d73615952fe6037 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=A2=85=ED=98=84/=EB=8F=99=EC=9E=91=EC=A0=9C?= =?utf8?q?=EC=96=B4Lab=28SR=29/Staff=20Engineer/=EC=82=BC=EC=84=B1?= =?utf8?q?=EC=A0=84=EC=9E=90?= Date: Thu, 30 Aug 2018 17:33:57 +0900 Subject: [PATCH] [enco] Convolution Shape inference with stride (#1251) This commit revises ConvolutionSpec to infer stride values for each spatial axis, and use these values for OFM shape inference. Signed-off-by: Jonghyun Park --- .../enco/frontend/caffe/src/ConvolutionSpec.cpp | 30 +++++++++--- contrib/enco/frontend/caffe/src/ConvolutionSpec.h | 1 + .../frontend/caffe/src/ConvolutionSpec.test.cpp | 55 ++++++++++++++++++++++ 3 files changed, 80 insertions(+), 6 deletions(-) diff --git a/contrib/enco/frontend/caffe/src/ConvolutionSpec.cpp b/contrib/enco/frontend/caffe/src/ConvolutionSpec.cpp index bc691f0..345a499 100644 --- a/contrib/enco/frontend/caffe/src/ConvolutionSpec.cpp +++ b/contrib/enco/frontend/caffe/src/ConvolutionSpec.cpp @@ -11,12 +11,6 @@ ConvolutionSpec::ConvolutionSpec(const ::caffe::ConvolutionParameter ¶m) assert(param.pad_h() == 0); assert(param.pad_w() == 0); - // NOTE Stride is not supported, yet - // TODO Support padding - assert(param.stride().size() == 0); - assert(!param.has_stride_h()); - assert(!param.has_stride_w()); - // NOTE Dilation is not supported, yet // TODO Support dilation assert(param.dilation().size() == 0); @@ -39,6 +33,29 @@ uint32_t ConvolutionSpec::channel_axis(void) const return ifm_rank() + axis; } +uint32_t ConvolutionSpec::stride(uint32_t spatial_axis) const +{ + assert(spatial_axis < num_spatial_axes()); + + // TODO Support stride_h/stride_w parameters + assert(!_param.has_stride_h()); + assert(!_param.has_stride_w()); + + if (_param.stride().size() == 0) + { + // NOTE default stride is 1 + return 1; + } + + if (_param.stride().size() == 1) + { + return _param.stride(0); + } + + assert(_param.stride().size() == num_spatial_axes()); + return _param.stride(spatial_axis); +} + uint32_t ConvolutionSpec::ker_dim(uint32_t spatial_axis) const { assert(spatial_axis < num_spatial_axes()); @@ -96,6 +113,7 @@ nncc::core::ADT::tensor::Shape ConvolutionSpec::ofm_shape(void) const uint32_t dim = 0; dim += ifm_dim(full_axis) - ker_dim(spatial_axis); + dim /= stride(spatial_axis); dim += 1; res.dim(full_axis) = dim; diff --git a/contrib/enco/frontend/caffe/src/ConvolutionSpec.h b/contrib/enco/frontend/caffe/src/ConvolutionSpec.h index 64b6688..6c6e288 100644 --- a/contrib/enco/frontend/caffe/src/ConvolutionSpec.h +++ b/contrib/enco/frontend/caffe/src/ConvolutionSpec.h @@ -19,6 +19,7 @@ public: uint32_t num_batch_axes(void) const { return channel_axis(); } uint32_t num_spatial_axes(void) const { return ifm_rank() - channel_axis() - 1; } + uint32_t stride(uint32_t spatial_axis) const; uint32_t ker_dim(uint32_t spatial_axis) const; public: diff --git a/contrib/enco/frontend/caffe/src/ConvolutionSpec.test.cpp b/contrib/enco/frontend/caffe/src/ConvolutionSpec.test.cpp index 354b5ea..4870a52 100644 --- a/contrib/enco/frontend/caffe/src/ConvolutionSpec.test.cpp +++ b/contrib/enco/frontend/caffe/src/ConvolutionSpec.test.cpp @@ -171,3 +171,58 @@ TEST_F(ConvolutionSpecTest, conv_1) ASSERT_EQ(expected, obtained); } } + +namespace +{ +// NOTE This example is derived from conv1_3x3_s2 layer in reference inception v3 layer +// clang-format off +const char *conv_2 = STRING( +layer { + name: "data" + type: "Input" + top: "data" + input_param { + shape: { dim: 1 dim: 3 dim: 299 dim: 299 } + } +} +layer { + name: "conv" + type: "Convolution" + bottom: "data" + top: "conv" + convolution_param { + bias_term: false + num_output: 2 + stride: 2 + kernel_size: 3 + } +} +); +// clang-format on +} // namespace + +TEST_F(ConvolutionSpecTest, conv_2) +{ + ::caffe::NetParameter param; + + ASSERT_TRUE(load(conv_2, param)); + + ::caffe::Net net{param}; + + const tensor::Shape ifm_shape{1, 3, 299, 299}; + ConvolutionSpec spec{param.layer(1).convolution_param()}; + + spec.ifm_shape(ifm_shape); + + // Check 'stride' + ASSERT_EQ(spec.stride(0), 2); + ASSERT_EQ(spec.stride(1), 2); + + // Check 'ofm_shape' + { + auto expected = as_tensor_shape(net.blob_by_name("conv")->shape()); + auto obtained = spec.ofm_shape(); + + ASSERT_EQ(expected, obtained); + } +} -- 2.7.4