1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
#include <description_buffer.hpp>
#include "ie_built_in_impl.hpp"
#include <functional>
#include <map>
#include <numeric>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
17 namespace InferenceEngine {
18 namespace ShapeInfer {
21 *@brief Implementation of Shape inference for Reshape layer
23 class FlattenShapeProp : public BuiltInShapeInferImpl {
25 explicit FlattenShapeProp(const std::string &type) : BuiltInShapeInferImpl(type) {}
27 void inferShapesImpl(const std::vector<Blob::CPtr>& inBlobs,
28 const std::map<std::string, std::string>& params,
29 const std::map<std::string, Blob::Ptr>& blobs,
30 std::vector<SizeVector>& outShapes) override {
32 ReshapeLayer reshapeLayer(lp);
33 reshapeLayer.params = params;
34 reshapeLayer.type = _type;
35 validate(&reshapeLayer, inBlobs, params, blobs);
37 auto inputShape = inShapes[0];
38 size_t inputShapeTotal = std::accumulate(inputShape.begin(), inputShape.end(), 1lu, std::multiplies<size_t>());
41 int numAxes = reshapeLayer.num_axes;
42 int axis = reshapeLayer.axis;
43 size_t notFlatten = 1;
44 if (numAxes == -1 && axis == 0) {
45 outShape = {inputShapeTotal};
48 for (int i = 0; i < axis; i++) {
49 notFlatten *= inputShape[i];
50 outShape.push_back(inputShape[i]);
53 outShape.push_back(1);
55 for (int i = numAxes + 1; i < inputShape.size(); i++) {
56 notFlatten *= inputShape[i];
57 outShape.push_back(inputShape[i]);
60 outShape[axis] = inputShapeTotal / notFlatten;
63 outShapes.emplace_back(outShape);
67 } // namespace ShapeInfer
68 } // namespace InferenceEngine