// Copyright (C) 2019 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
#include <description_buffer.hpp>
#include "ie_built_in_impl.hpp"
#include <ie_format_parser.h>

#include <cmath>
#include <map>
#include <string>
#include <vector>
17 namespace InferenceEngine {
18 namespace ShapeInfer {
21 *@brief Implementation of Shape inference for BinaryConvolution layer
23 class BinConvShapeProp : public BuiltInShapeInferImpl {
25 explicit BinConvShapeProp(const std::string& type) : BuiltInShapeInferImpl(type) {}
27 void inferShapesImpl(const std::vector<Blob::CPtr>& inBlobs,
28 const std::map<std::string, std::string>& params,
29 const std::map<std::string, Blob::Ptr>& blobs,
30 std::vector<SizeVector>& outShapes) override {
32 BinaryConvolutionLayer binConvLayer(lp);
33 binConvLayer.params = params;
34 binConvLayer.type = _type;
35 validate(&binConvLayer, inBlobs, params, blobs);
37 auto dims = inShapes[0];
38 auto computeSpatialShape = [&](size_t inDim, int axis) {
40 if (binConvLayer._dilation[axis])
41 kernel = (binConvLayer._kernel[axis] - 1) * binConvLayer._dilation[axis] + 1;
43 kernel = binConvLayer._kernel[axis];
44 size_t stride = binConvLayer._stride[axis];
45 size_t pad = binConvLayer._padding[axis];
48 std::string padType = binConvLayer._auto_pad;
49 if (padType == "valid") {
50 outDim = std::ceil((inDim - kernel + 1.f) / stride);
51 } else if (padType == "same_upper") {
52 outDim = std::ceil(1.f * inDim / stride);
53 } else if (padType == "same_lower") {
54 outDim = std::floor(1.f * inDim / stride);
56 int padEnd = binConvLayer._pads_end[axis];
57 outDim = std::floor(1.f * (inDim + pad + padEnd - kernel) / stride) + 1.f;
61 THROW_IE_EXCEPTION << "New shapes " << details::dumpVec(dims) << " make output shape negative";
63 return static_cast<size_t>(outDim);
66 size_t inputN = dims[0];
67 size_t OC = binConvLayer._out_depth;
69 shapes.push_back(inputN);
72 shapes.push_back(computeSpatialShape(dims[dims.size() - 3], Z_AXIS));
73 shapes.push_back(computeSpatialShape(dims[dims.size() - 2], Y_AXIS));
74 shapes.push_back(computeSpatialShape(dims[dims.size() - 1], X_AXIS));
75 outShapes.push_back(shapes);
79 } // namespace ShapeInfer
80 } // namespace InferenceEngine