-// Copyright (C) 2018 Intel Corporation
+// Copyright (C) 2018-2019 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
THROW_IE_EXCEPTION << "Internal error: failed to find validator for layer with type: " << _type;
}
- void validate(CNNLayer* layer, const std::vector<SizeVector>& inShapes,
+ void validate(CNNLayer* layer, const std::vector<Blob::CPtr>& inBlobs,
const std::map<std::string, std::string>& params,
const std::map<std::string, Blob::Ptr>& blobs) {
_validator->parseParams(layer);
_validator->checkCorrespondence(layer, blobs, inShapes);
}
- virtual void inferShapesImpl(const std::vector<SizeVector>& inShapes,
+ virtual void inferShapesImpl(const std::vector<Blob::CPtr>& inBlobs,
const std::map<std::string, std::string>& params,
const std::map<std::string, Blob::Ptr>& blobs,
std::vector<SizeVector>& outShapes) = 0;
const std::map<std::string, Blob::Ptr>& blobs,
std::vector<SizeVector>& outShapes,
ResponseDesc* resp) noexcept override {
+ return DescriptionBuffer(GENERAL_ERROR, resp)
+ << "Unexpected call of deprecated Shape Infer function with input shapes";
+ }
+
+ StatusCode inferShapes(const std::vector<Blob::CPtr>& inBlobs,
+ const std::map<std::string, std::string>& params,
+ const std::map<std::string, Blob::Ptr>& blobs,
+ std::vector<SizeVector>& outShapes,
+ ResponseDesc* resp) noexcept override {
+ inShapes.clear();
+ for (const auto& blob : inBlobs) {
+ inShapes.push_back(blob->getTensorDesc().getDims());
+ }
outShapes.clear();
- std::string errorPrefix = "Failed to infer shapes for " + _type + " layer with error: ";
try {
- inferShapesImpl(inShapes, params, blobs, outShapes);
+ inferShapesImpl(inBlobs, params, blobs, outShapes);
return OK;
} catch (const std::exception& ex) {
- return InferenceEngine::DescriptionBuffer(GENERAL_ERROR, resp) << errorPrefix + ex.what();
+ return InferenceEngine::DescriptionBuffer(GENERAL_ERROR, resp) << ex.what();
} catch (...) {
- return InferenceEngine::DescriptionBuffer(UNEXPECTED) << errorPrefix + " unknown";
+ return InferenceEngine::DescriptionBuffer(UNEXPECTED) << "Unknown error";
}
}
protected:
std::string _type;
details::LayerValidator::Ptr _validator;
+ std::vector<SizeVector> inShapes;
};
} // namespace ShapeInfer