1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
13 #include <ie_layers.h>
14 #include <ie_iextension.h>
15 #include "details/caseless.hpp"
16 #include <description_buffer.hpp>
17 #include <ie_layer_validators.hpp>
19 namespace InferenceEngine {
20 namespace ShapeInfer {
23 *@brief Base class for all built-in shape infer implementations. Contains common logic with validators and errors handling
25 class BuiltInShapeInferImpl : public IShapeInferImpl {
27 explicit BuiltInShapeInferImpl(const std::string& type) : _type(type) {
28 _validator = details::LayerValidators::getInstance()->getValidator(_type);
30 THROW_IE_EXCEPTION << "Internal error: failed to find validator for layer with type: " << _type;
33 void validate(CNNLayer* layer, const std::vector<Blob::CPtr>& inBlobs,
34 const std::map<std::string, std::string>& params,
35 const std::map<std::string, Blob::Ptr>& blobs) {
36 _validator->parseParams(layer);
37 _validator->checkParams(layer);
38 _validator->checkShapes(layer, inShapes);
39 _validator->checkCorrespondence(layer, blobs, inShapes);
42 virtual void inferShapesImpl(const std::vector<Blob::CPtr>& inBlobs,
43 const std::map<std::string, std::string>& params,
44 const std::map<std::string, Blob::Ptr>& blobs,
45 std::vector<SizeVector>& outShapes) = 0;
47 StatusCode inferShapes(const std::vector<SizeVector>& inShapes,
48 const std::map<std::string, std::string>& params,
49 const std::map<std::string, Blob::Ptr>& blobs,
50 std::vector<SizeVector>& outShapes,
51 ResponseDesc* resp) noexcept override {
52 return DescriptionBuffer(GENERAL_ERROR, resp)
53 << "Unexpected call of deprecated Shape Infer function with input shapes";
56 StatusCode inferShapes(const std::vector<Blob::CPtr>& inBlobs,
57 const std::map<std::string, std::string>& params,
58 const std::map<std::string, Blob::Ptr>& blobs,
59 std::vector<SizeVector>& outShapes,
60 ResponseDesc* resp) noexcept override {
62 for (const auto& blob : inBlobs) {
63 inShapes.push_back(blob->getTensorDesc().getDims());
67 inferShapesImpl(inBlobs, params, blobs, outShapes);
69 } catch (const std::exception& ex) {
70 return InferenceEngine::DescriptionBuffer(GENERAL_ERROR, resp) << ex.what();
72 return InferenceEngine::DescriptionBuffer(UNEXPECTED) << "Unknown error";
78 details::LayerValidator::Ptr _validator;
79 std::vector<SizeVector> inShapes;
82 } // namespace ShapeInfer
83 } // namespace InferenceEngine