assert(param.pad_h() == 0);
assert(param.pad_w() == 0);
- // NOTE Stride is not supported, yet
- // TODO Support padding
- assert(param.stride().size() == 0);
- assert(!param.has_stride_h());
- assert(!param.has_stride_w());
-
// NOTE Dilation is not supported, yet
// TODO Support dilation
assert(param.dilation().size() == 0);
return ifm_rank() + axis;
}
+uint32_t ConvolutionSpec::stride(uint32_t spatial_axis) const
+{
+ assert(spatial_axis < num_spatial_axes());
+
+ // TODO Support stride_h/stride_w parameters
+ assert(!_param.has_stride_h());
+ assert(!_param.has_stride_w());
+
+ if (_param.stride().size() == 0)
+ {
+ // NOTE default stride is 1
+ return 1;
+ }
+
+ if (_param.stride().size() == 1)
+ {
+ return _param.stride(0);
+ }
+
+ assert(_param.stride().size() == num_spatial_axes());
+ return _param.stride(spatial_axis);
+}
+
uint32_t ConvolutionSpec::ker_dim(uint32_t spatial_axis) const
{
assert(spatial_axis < num_spatial_axes());
uint32_t dim = 0;
dim += ifm_dim(full_axis) - ker_dim(spatial_axis);
+ dim /= stride(spatial_axis);
dim += 1;
res.dim(full_axis) = dim;
ASSERT_EQ(expected, obtained);
}
}
+
+namespace
+{
+// NOTE This example is derived from conv1_3x3_s2 layer in reference inception v3 layer
+// clang-format off
+const char *conv_2 = STRING(
+layer {
+ name: "data"
+ type: "Input"
+ top: "data"
+ input_param {
+ shape: { dim: 1 dim: 3 dim: 299 dim: 299 }
+ }
+}
+layer {
+ name: "conv"
+ type: "Convolution"
+ bottom: "data"
+ top: "conv"
+ convolution_param {
+ bias_term: false
+ num_output: 2
+ stride: 2
+ kernel_size: 3
+ }
+}
+);
+// clang-format on
+} // namespace
+
+TEST_F(ConvolutionSpecTest, conv_2)
+{
+ ::caffe::NetParameter param;
+
+ ASSERT_TRUE(load(conv_2, param));
+
+ ::caffe::Net<float> net{param};
+
+ const tensor::Shape ifm_shape{1, 3, 299, 299};
+ ConvolutionSpec spec{param.layer(1).convolution_param()};
+
+ spec.ifm_shape(ifm_shape);
+
+ // Check 'stride'
+ ASSERT_EQ(spec.stride(0), 2);
+ ASSERT_EQ(spec.stride(1), 2);
+
+ // Check 'ofm_shape'
+ {
+ auto expected = as_tensor_shape(net.blob_by_name("conv")->shape());
+ auto obtained = spec.ofm_shape();
+
+ ASSERT_EQ(expected, obtained);
+ }
+}