// Copyright (C) 2018-2019 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
#include "ie_built_in_impl.hpp"

#include <cstddef>
#include <map>
#include <string>
#include <vector>
14 namespace InferenceEngine {
15 namespace ShapeInfer {
18 *@brief Implementation of Shape inference for Split layer
20 class SplitShapeProp : public BuiltInShapeInferImpl {
22 explicit SplitShapeProp(const std::string& type) : BuiltInShapeInferImpl(type) {}
24 void inferShapesImpl(const std::vector<Blob::CPtr>& inBlobs,
25 const std::map<std::string, std::string>& params,
26 const std::map<std::string, Blob::Ptr>& blobs,
27 std::vector<SizeVector>& outShapes) override {
29 SplitLayer splitLayer(lp);
30 splitLayer.params = params;
31 splitLayer.type = _type;
32 validate(&splitLayer, inBlobs, params, blobs);
34 std::vector<int> out_sizes = splitLayer.GetParamAsInts("out_sizes", {});
35 if (out_sizes.empty())
36 THROW_IE_EXCEPTION << "Value of out_sizes attribute is empty";
39 for (const auto& size : out_sizes)
41 if (sum != inShapes[0][splitLayer._axis])
42 THROW_IE_EXCEPTION << "The sum of the dimensions on the axis(" << splitLayer._axis
43 << ") is not equal out_sizes: " << details::dumpVec(out_sizes);
45 for (const auto& size : out_sizes) {
46 outShapes.push_back(inShapes[0]);
47 outShapes[outShapes.size() - 1][splitLayer._axis] = static_cast<size_t>(size);
52 } // namespace ShapeInfer
53 } // namespace InferenceEngine