1 // Copyright (C) 2018 Intel Corporation
3 // SPDX-License-Identifier: Apache-2.0
9 #include "caseless.hpp"
15 namespace InferenceEngine {
/**
 * @brief Snapshot of a layer's input and output dimensions, one entry per collected Data
 */
struct InOutDims {
    std::vector<std::vector<size_t>> inDims;   ///< dimensions of the collected input Data
    std::vector<std::vector<size_t>> outDims;  ///< dimensions of the collected output Data
};
24 * @brief Contains methods to validate layer of specific type
26 class INFERENCE_ENGINE_API_CLASS(LayerValidator) {
28 using Ptr = std::shared_ptr<LayerValidator>;
30 explicit LayerValidator(const std::string& _type) : _type(_type) {}
33 * @brief It parses map of params <string,string> and applies to the layer's fields.
34 * This checks for presence of all required attributes, and that there's no extraneous parameters only.
35 * Throws exception in case of parsing error
37 virtual void parseParams(CNNLayer* layer) {}
40 * @brief Validates layer parameters separately from blobs and shapes
41 * This is semantic check, like height and width more than kernel sizes, stride > 0, beta > 0, axis is correct and etc
42 * Throws exception if the check fails
44 virtual void checkParams(const CNNLayer* layer) {}
47 * @brief Checks correspondence of input shapes and layer parameters.
48 * @note: This function doesn't touch ins and out Data of the layer.
49 * Throws exception if the check fails
51 virtual void checkShapes(const CNNLayer* layer,
52 const std::vector<SizeVector>& inShapes) const {}
55 * @brief Checks correspondence of all parameters in the aggregate, except output shapes.
56 * @note: This function doesn't touch ins and out Data of the layer.
57 * Throws exception if the check fails
59 virtual void checkCorrespondence(const CNNLayer* layer,
60 const std::map<std::string, Blob::Ptr>& blobs,
61 const std::vector<SizeVector>& inShapes) const {}
68 * @brief Contains all validators, registered for specific layer type
70 class INFERENCE_ENGINE_API_CLASS(LayerValidators) {
72 static LayerValidators* getInstance();
74 LayerValidators(LayerValidators const&) = delete;
76 void operator=(LayerValidators const&) = delete;
78 LayerValidator::Ptr getValidator(const std::string& type);
80 void addImpl(const std::string& type, const LayerValidator::Ptr& validator);
83 LayerValidators() = default;
86 static LayerValidators* _instance;
87 caseless_unordered_map<std::string, LayerValidator::Ptr> _validators;
90 static void checkWeakData(const DataWeakPtr& data) {
93 static void checkData(const DataPtr& data) {
98 * @brief Checks that input Data is not empty and pointers are not null, number of inputs correspond number of input shapes, dimensions in Data are not empty
100 static void checkInputs(const CNNLayer* layer, const std::vector<SizeVector>& inShapes) {
101 // TODO: not finished implementation
102 if (layer->insData.size() != inShapes.size())
103 return THROW_IE_EXCEPTION << "Number of layer's inputs don't correspond number of new input shapes";
105 auto inData = layer->insData[0].lock();
106 bool isCorrect = false;
107 SizeVector inDims, inShape;
109 inDims = inData->getDims();
110 inShape = inShapes[0];
111 isCorrect = inShape.size() == inDims.size() && !inShape.empty() && !inDims.empty();
115 return THROW_IE_EXCEPTION << " Failed with invalid shapes: shapes are empty"
116 << "new input shape size=" << inShape.size() << ", input shape size in IR="
121 * @brief Checks that output Data is not empty and pointers are not null, number of outputs correspond number of output shapes, dimensions in Data are not empty
123 static void checkOutputs(const CNNLayer* layer, const std::vector<SizeVector>& outShapes) {}
125 static void getInOutShapes(const CNNLayer* layer, InOutDims& inOutShapes) {
126 inOutShapes.inDims.clear();
127 inOutShapes.outDims.clear();
129 for (const auto& inData : layer->insData) {
130 auto locked = inData.lock();
132 inOutShapes.inDims.push_back(locked->getDims());
135 for (const auto& outData : layer->outData) {
137 inOutShapes.outDims.push_back(outData->getDims());
143 class GeneralValidator : public LayerValidator {
145 explicit GeneralValidator(const std::string& _type);
148 class INFERENCE_ENGINE_API_CLASS(ConvolutionValidator) : public LayerValidator {
150 void parseParams(CNNLayer* layer) override;
152 void checkParams(const CNNLayer* layer) override;
154 explicit ConvolutionValidator(const std::string& _type);
156 void checkCorrespondence(const CNNLayer* layer,
157 const std::map<std::string, Blob::Ptr>& blobs,
158 const std::vector<SizeVector>& inShapes) const override;
161 class INFERENCE_ENGINE_API_CLASS(DeconvolutionValidator) : public LayerValidator {
163 void parseParams(CNNLayer* layer) override;
165 void checkParams(const CNNLayer* layer) override;
167 explicit DeconvolutionValidator(const std::string& _type);
169 void checkCorrespondence(const CNNLayer* layer,
170 const std::map<std::string, Blob::Ptr>& blobs,
171 const std::vector<SizeVector>& inShapes) const override;
175 class INFERENCE_ENGINE_API_CLASS(PoolingValidator) : public LayerValidator {
177 void parseParams(CNNLayer* layer) override;
179 void checkParams(const CNNLayer* layer) override;
181 explicit PoolingValidator(const std::string& _type);
184 class INFERENCE_ENGINE_API_CLASS(FullyConnectedValidator) : public LayerValidator {
186 explicit FullyConnectedValidator(const std::string& _type);
188 void parseParams(CNNLayer* layer) override;
190 void checkParams(const CNNLayer* layer) override;
192 void checkCorrespondence(const CNNLayer* layer,
193 const std::map<std::string, Blob::Ptr>& blobs,
194 const std::vector<SizeVector>& inShapes) const override;
197 class INFERENCE_ENGINE_API_CLASS(CropValidator) : public LayerValidator {
199 explicit CropValidator(const std::string& _type);
201 void parseParams(CNNLayer* layer) override;
203 void checkParams(const CNNLayer* layer) override;
205 void checkShapes(const CNNLayer* layer, const std::vector<SizeVector>& inShapes) const override;
208 class INFERENCE_ENGINE_API_CLASS(TileValidator) : public LayerValidator {
210 explicit TileValidator(const std::string& _type);
212 void parseParams(CNNLayer* layer) override;
214 void checkParams(const CNNLayer* layer) override;
217 class INFERENCE_ENGINE_API_CLASS(BatchNormalizationValidator) : public LayerValidator {
219 explicit BatchNormalizationValidator(const std::string& _type);
221 void parseParams(CNNLayer* layer) override;
223 void checkParams(const CNNLayer* layer) override;
226 class INFERENCE_ENGINE_API_CLASS(PowerValidator) : public LayerValidator {
228 explicit PowerValidator(const std::string& _type);
230 void parseParams(CNNLayer* layer) override;
232 void checkParams(const CNNLayer* layer) override;
235 class INFERENCE_ENGINE_API_CLASS(PReLUValidator) : public LayerValidator {
237 explicit PReLUValidator(const std::string& _type);
239 void parseParams(CNNLayer* layer) override;
241 void checkParams(const CNNLayer* layer) override;
244 class INFERENCE_ENGINE_API_CLASS(ScaleShiftValidator) : public LayerValidator {
246 explicit ScaleShiftValidator(const std::string& _type);
248 void parseParams(CNNLayer* layer) override;
250 void checkParams(const CNNLayer* layer) override;
253 class INFERENCE_ENGINE_API_CLASS(ReshapeValidator) : public LayerValidator {
255 explicit ReshapeValidator(const std::string& _type);
257 void parseParams(CNNLayer* layer) override;
259 void checkParams(const CNNLayer* layer) override;
262 void calculateIn2Out(ReshapeLayer* layer);
265 class INFERENCE_ENGINE_API_CLASS(EltwiseValidator) : public LayerValidator {
267 explicit EltwiseValidator(const std::string& _type);
269 void parseParams(CNNLayer* layer) override;
271 void checkParams(const CNNLayer* layer) override;
274 class INFERENCE_ENGINE_API_CLASS(ClampValidator) : public LayerValidator {
276 explicit ClampValidator(const std::string& _type);
278 void parseParams(CNNLayer* layer) override;
280 void checkParams(const CNNLayer* layer) override;
283 class INFERENCE_ENGINE_API_CLASS(ReLUValidator) : public LayerValidator {
285 explicit ReLUValidator(const std::string& _type);
287 void parseParams(CNNLayer* layer) override;
289 void checkParams(const CNNLayer* layer) override;
292 class INFERENCE_ENGINE_API_CLASS(MVNValidator) : public LayerValidator {
294 explicit MVNValidator(const std::string& _type);
296 void parseParams(CNNLayer* layer) override;
298 void checkParams(const CNNLayer* layer) override;
301 class INFERENCE_ENGINE_API_CLASS(GRNValidator) : public LayerValidator {
303 explicit GRNValidator(const std::string& _type);
305 void parseParams(CNNLayer* layer) override;
307 void checkParams(const CNNLayer* layer) override;
310 class INFERENCE_ENGINE_API_CLASS(SoftMaxValidator) : public LayerValidator {
312 explicit SoftMaxValidator(const std::string& _type);
314 void parseParams(CNNLayer* layer) override;
316 void checkParams(const CNNLayer* layer) override;
319 class INFERENCE_ENGINE_API_CLASS(NormValidator) : public LayerValidator {
321 explicit NormValidator(const std::string& _type);
323 void parseParams(CNNLayer* layer) override;
325 void checkParams(const CNNLayer* layer) override;
328 class INFERENCE_ENGINE_API_CLASS(SplitValidator) : public LayerValidator {
330 explicit SplitValidator(const std::string& _type);
332 void parseParams(CNNLayer* layer) override;
334 void checkParams(const CNNLayer* layer) override;
337 class INFERENCE_ENGINE_API_CLASS(ConcatValidator) : public LayerValidator {
339 explicit ConcatValidator(const std::string& _type);
341 void parseParams(CNNLayer* layer) override;
343 void checkParams(const CNNLayer* layer) override;
346 template<typename Validator>
347 class ValidatorRegisterBase {
349 explicit ValidatorRegisterBase(const std::string& type) {
350 LayerValidators::getInstance()->addImpl(type, std::make_shared<Validator>(type));
/**
 * @brief Registers validator class `validator` for the layer type named `type` (the name is stringized).
 * Expands to a static ValidatorRegisterBase<validator> object whose constructor performs the registration
 * during static initialization of the translation unit that expands the macro.
 * @note Identifiers containing a double underscore are reserved in C++, so neither the macro parameters
 * nor the generated object name use them.
 */
#define REG_LAYER_VALIDATOR_FOR_TYPE(validator, type) \
    static ValidatorRegisterBase<validator> layer_validator_reg_##type(#type)
357 REG_LAYER_VALIDATOR_FOR_TYPE(ConvolutionValidator, Convolution);
358 REG_LAYER_VALIDATOR_FOR_TYPE(DeconvolutionValidator, Deconvolution);
359 REG_LAYER_VALIDATOR_FOR_TYPE(PoolingValidator, Pooling);
360 REG_LAYER_VALIDATOR_FOR_TYPE(FullyConnectedValidator, InnerProduct);
361 REG_LAYER_VALIDATOR_FOR_TYPE(FullyConnectedValidator, FullyConnected);
362 REG_LAYER_VALIDATOR_FOR_TYPE(CropValidator, Crop);
363 REG_LAYER_VALIDATOR_FOR_TYPE(BatchNormalizationValidator, BatchNormalization);
364 REG_LAYER_VALIDATOR_FOR_TYPE(PowerValidator, Power);
365 REG_LAYER_VALIDATOR_FOR_TYPE(PReLUValidator, PReLU);
366 REG_LAYER_VALIDATOR_FOR_TYPE(ScaleShiftValidator, ScaleShift);
367 REG_LAYER_VALIDATOR_FOR_TYPE(TileValidator, Tile);
368 REG_LAYER_VALIDATOR_FOR_TYPE(ReshapeValidator, Reshape);
369 REG_LAYER_VALIDATOR_FOR_TYPE(ReshapeValidator, Flatten);
370 REG_LAYER_VALIDATOR_FOR_TYPE(EltwiseValidator, Eltwise);
371 REG_LAYER_VALIDATOR_FOR_TYPE(ClampValidator, Clamp);
372 REG_LAYER_VALIDATOR_FOR_TYPE(ReLUValidator, ReLU);
373 REG_LAYER_VALIDATOR_FOR_TYPE(MVNValidator, MVN);
374 REG_LAYER_VALIDATOR_FOR_TYPE(GRNValidator, GRN);
375 REG_LAYER_VALIDATOR_FOR_TYPE(SoftMaxValidator, SoftMax);
376 REG_LAYER_VALIDATOR_FOR_TYPE(NormValidator, Norm);
377 REG_LAYER_VALIDATOR_FOR_TYPE(NormValidator, LRN);
378 REG_LAYER_VALIDATOR_FOR_TYPE(SplitValidator, Split);
379 REG_LAYER_VALIDATOR_FOR_TYPE(SplitValidator, Slice);
380 REG_LAYER_VALIDATOR_FOR_TYPE(ConcatValidator, Concat);
382 } // namespace details
383 } // namespace InferenceEngine