Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / src / inference_engine / ie_layer_validators.hpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #pragma once
6
7 #include "ie_layers.h"
8 #include "details/caseless.hpp"
9 #include <memory>
10 #include <string>
11 #include <map>
12 #include <vector>
13
14 namespace InferenceEngine {
15 namespace details {
16
17 struct InOutDims {
18     std::vector<std::vector<size_t>> inDims;
19     std::vector<std::vector<size_t>> outDims;
20 };
21
22 /**
23  * @brief Contains methods to validate layer of specific type
24  */
25 class INFERENCE_ENGINE_API_CLASS(LayerValidator) {
26 public:
27     using Ptr = std::shared_ptr<LayerValidator>;
28
29     explicit LayerValidator(const std::string& _type) : _type(_type) {}
30
31     /**
32      * @brief It parses map of params <string,string> and applies to the layer's fields.
33      * This checks for presence of all required attributes, and that there's no extraneous parameters only.
34      * Throws exception in case of parsing error
35      */
36     virtual void parseParams(CNNLayer* layer) {}
37
38     /**
39      * @brief Validates layer parameters separately from blobs and shapes
40      * This is semantic check, like height and width more than kernel sizes, stride > 0, beta > 0, axis is correct and etc
41      * Throws exception if the check fails
42      */
43     virtual void checkParams(const CNNLayer* layer) {}
44
45     /**
46      * @brief Checks correspondence of input shapes and layer parameters.
47      * @note: This function doesn't touch ins and out Data of the layer.
48      * Throws exception if the check fails
49      */
50     virtual void checkShapes(const CNNLayer* layer, const std::vector<SizeVector>& inShapes) const {}
51
52     /**
53      * @brief Checks correspondence of all parameters in the aggregate, except output shapes.
54      * @note: This function doesn't touch ins and out Data of the layer.
55      * Throws exception if the check fails
56      */
57     virtual void checkCorrespondence(const CNNLayer* layer,
58                                      const std::map<std::string, Blob::Ptr>& blobs,
59                                      const std::vector<SizeVector>& inShapes) const {}
60
61 protected:
62     std::string _type;
63 };
64
65 /**
66  * @brief Contains all validators, registered for specific layer type
67  */
68 class INFERENCE_ENGINE_API_CLASS(LayerValidators) {
69 public:
70     static LayerValidators* getInstance();
71
72     LayerValidators(LayerValidators const&) = delete;
73
74     void operator=(LayerValidators const&)  = delete;
75
76     LayerValidator::Ptr getValidator(const std::string& type);
77
78     void addImpl(const std::string& type, const LayerValidator::Ptr& validator);
79
80 private:
81     LayerValidators() = default;
82
83 private:
84     static LayerValidators* _instance;
85     InferenceEngine::details::caseless_unordered_map<std::string, LayerValidator::Ptr> _validators;
86 };
87
88 static void getInOutShapes(const CNNLayer* layer, InOutDims& inOutShapes) {
89     inOutShapes.inDims.clear();
90     inOutShapes.outDims.clear();
91     if (layer) {
92         for (const auto& inData : layer->insData) {
93             auto locked = inData.lock();
94             if (locked) {
95                 inOutShapes.inDims.push_back(locked->getDims());
96             }
97         }
98         for (const auto& outData : layer->outData) {
99             if (outData) {
100                 inOutShapes.outDims.push_back(outData->getDims());
101             }
102         }
103     }
104 }
105
106 class GeneralValidator : public LayerValidator {
107 public:
108     explicit GeneralValidator(const std::string& _type);
109 };
110
111 class INFERENCE_ENGINE_API_CLASS(ConvolutionValidator) : public LayerValidator {
112 public:
113     void parseParams(CNNLayer* layer) override;
114
115     void checkParams(const CNNLayer* layer) override;
116
117     explicit ConvolutionValidator(const std::string& _type);
118
119     void checkCorrespondence(const CNNLayer* layer,
120                              const std::map<std::string, Blob::Ptr>& blobs,
121                              const std::vector<SizeVector>& inShapes) const override;
122
123     void checkShapes(const CNNLayer* layer, const std::vector<SizeVector>& inShapes) const override;
124 };
125
126 class INFERENCE_ENGINE_API_CLASS(DeconvolutionValidator) : public ConvolutionValidator {
127 public:
128     void parseParams(CNNLayer* layer) override;
129
130     void checkParams(const CNNLayer* layer) override;
131
132     explicit DeconvolutionValidator(const std::string& _type);
133
134     void checkCorrespondence(const CNNLayer* layer,
135                              const std::map<std::string, Blob::Ptr>& blobs,
136                              const std::vector<SizeVector>& inShapes) const override;
137
138     void checkShapes(const CNNLayer* layer, const std::vector<SizeVector>& inShapes) const override;
139 };
140
141
142 class INFERENCE_ENGINE_API_CLASS(PoolingValidator) : public LayerValidator {
143 public:
144     void parseParams(CNNLayer* layer) override;
145
146     void checkParams(const CNNLayer* layer) override;
147
148     void checkShapes(const CNNLayer* layer, const std::vector<SizeVector>& inShapes) const override;
149
150     explicit PoolingValidator(const std::string& _type);
151 };
152
153 class INFERENCE_ENGINE_API_CLASS(FullyConnectedValidator) : public LayerValidator {
154 public:
155     explicit FullyConnectedValidator(const std::string& _type);
156
157     void parseParams(CNNLayer* layer) override;
158
159     void checkParams(const CNNLayer* layer) override;
160
161     void checkCorrespondence(const CNNLayer* layer,
162                              const std::map<std::string, Blob::Ptr>& blobs,
163                              const std::vector<SizeVector>& inShapes) const override;
164
165     void checkShapes(const CNNLayer* layer, const std::vector<SizeVector>& inShapes) const override;
166 };
167
168 class INFERENCE_ENGINE_API_CLASS(CropValidator) : public LayerValidator {
169 public:
170     explicit CropValidator(const std::string& _type);
171
172     void parseParams(CNNLayer* layer) override;
173
174     void checkParams(const CNNLayer* layer) override;
175
176     void checkShapes(const CNNLayer* layer, const std::vector<SizeVector>& inShapes) const override;
177 };
178
179 class INFERENCE_ENGINE_API_CLASS(TileValidator) : public LayerValidator {
180 public:
181     explicit TileValidator(const std::string& _type);
182
183     void parseParams(CNNLayer* layer) override;
184
185     void checkParams(const CNNLayer* layer) override;
186
187     void checkShapes(const CNNLayer* layer, const std::vector<SizeVector>& inShapes) const override;
188 };
189
190 class INFERENCE_ENGINE_API_CLASS(BatchNormalizationValidator) : public LayerValidator {
191 public:
192     explicit BatchNormalizationValidator(const std::string& _type);
193
194     void parseParams(CNNLayer* layer) override;
195
196     void checkParams(const CNNLayer* layer) override;
197
198     void checkShapes(const CNNLayer* layer, const std::vector<SizeVector>& inShapes) const override;
199 };
200
201 class INFERENCE_ENGINE_API_CLASS(PowerValidator) : public LayerValidator {
202 public:
203     explicit PowerValidator(const std::string& _type);
204
205     void parseParams(CNNLayer* layer) override;
206
207     void checkParams(const CNNLayer* layer) override;
208
209     void checkShapes(const CNNLayer* layer, const std::vector<SizeVector>& inShapes) const override;
210 };
211
212 class INFERENCE_ENGINE_API_CLASS(PReLUValidator) : public LayerValidator {
213 public:
214     explicit PReLUValidator(const std::string& _type);
215
216     void parseParams(CNNLayer* layer) override;
217
218     void checkParams(const CNNLayer* layer) override;
219
220     void checkShapes(const CNNLayer* layer, const std::vector<SizeVector>& inShapes) const override;
221 };
222
223 class INFERENCE_ENGINE_API_CLASS(ScaleShiftValidator) : public LayerValidator {
224 public:
225     explicit ScaleShiftValidator(const std::string& _type);
226
227     void parseParams(CNNLayer* layer) override;
228
229     void checkParams(const CNNLayer* layer) override;
230
231     void checkShapes(const CNNLayer* layer, const std::vector<SizeVector>& inShapes) const override;
232 };
233
234 class INFERENCE_ENGINE_API_CLASS(ReshapeValidator) : public LayerValidator {
235 public:
236     explicit ReshapeValidator(const std::string& _type);
237
238     void parseParams(CNNLayer* layer) override;
239
240     void checkParams(const CNNLayer* layer) override;
241 };
242
243 class INFERENCE_ENGINE_API_CLASS(EltwiseValidator) : public LayerValidator {
244 public:
245     explicit EltwiseValidator(const std::string& _type);
246
247     void parseParams(CNNLayer* layer) override;
248
249     void checkParams(const CNNLayer* layer) override;
250
251     void checkShapes(const CNNLayer* layer, const std::vector<SizeVector>& inShapes) const override;
252 };
253
254 class INFERENCE_ENGINE_API_CLASS(ClampValidator) : public LayerValidator {
255 public:
256     explicit ClampValidator(const std::string& _type);
257
258     void parseParams(CNNLayer* layer) override;
259
260     void checkShapes(const CNNLayer* layer, const std::vector<SizeVector>& inShapes) const override;
261 };
262
263 class INFERENCE_ENGINE_API_CLASS(ReLUValidator) : public LayerValidator {
264 public:
265     explicit ReLUValidator(const std::string& _type);
266
267     void parseParams(CNNLayer* layer) override;
268
269     void checkParams(const CNNLayer* layer) override;
270
271     void checkShapes(const CNNLayer* layer, const std::vector<SizeVector>& inShapes) const override;
272 };
273
274 class INFERENCE_ENGINE_API_CLASS(MVNValidator) : public LayerValidator {
275 public:
276     explicit MVNValidator(const std::string& _type);
277
278     void parseParams(CNNLayer* layer) override;
279
280     void checkParams(const CNNLayer* layer) override;
281
282     void checkShapes(const CNNLayer* layer, const std::vector<SizeVector>& inShapes) const override;
283 };
284
285 class INFERENCE_ENGINE_API_CLASS(GRNValidator) : public LayerValidator {
286 public:
287     explicit GRNValidator(const std::string& _type);
288
289     void parseParams(CNNLayer* layer) override;
290
291     void checkParams(const CNNLayer* layer) override;
292
293     void checkShapes(const CNNLayer* layer, const std::vector<SizeVector>& inShapes) const override;
294 };
295
296 class INFERENCE_ENGINE_API_CLASS(SoftMaxValidator) : public LayerValidator {
297 public:
298     explicit SoftMaxValidator(const std::string& _type);
299
300     void parseParams(CNNLayer* layer) override;
301
302     void checkParams(const CNNLayer* layer) override;
303
304     void checkShapes(const CNNLayer* layer, const std::vector<SizeVector>& inShapes) const override;
305 };
306
307 class INFERENCE_ENGINE_API_CLASS(NormValidator) : public LayerValidator {
308 public:
309     explicit NormValidator(const std::string& _type);
310
311     void parseParams(CNNLayer* layer) override;
312
313     void checkParams(const CNNLayer* layer) override;
314
315     void checkShapes(const CNNLayer* layer, const std::vector<SizeVector>& inShapes) const override;
316 };
317
318 class INFERENCE_ENGINE_API_CLASS(SplitValidator) : public LayerValidator {
319 public:
320     explicit SplitValidator(const std::string& _type);
321
322     void parseParams(CNNLayer* layer) override;
323
324     void checkParams(const CNNLayer* layer) override;
325
326     void checkShapes(const CNNLayer* layer, const std::vector<SizeVector>& inShapes) const override;
327 };
328
329 class INFERENCE_ENGINE_API_CLASS(ConcatValidator) : public LayerValidator {
330 public:
331     explicit ConcatValidator(const std::string& _type);
332
333     void parseParams(CNNLayer* layer) override;
334
335     void checkParams(const CNNLayer* layer) override;
336
337     void checkShapes(const CNNLayer* layer, const std::vector<SizeVector>& inShapes) const override;
338 };
339
340 class INFERENCE_ENGINE_API_CLASS(GemmValidator) : public LayerValidator {
341 public:
342     explicit GemmValidator(const std::string& _type);
343
344     void parseParams(CNNLayer* layer) override;
345
346     void checkParams(const CNNLayer* layer) override;
347
348     void checkShapes(const CNNLayer* layer, const std::vector<SizeVector>& inShapes) const override;
349 };
350
351 class INFERENCE_ENGINE_API_CLASS(PadValidator) : public LayerValidator {
352 public:
353     explicit PadValidator(const std::string& _type);
354
355     void parseParams(CNNLayer* layer) override;
356
357     void checkParams(const CNNLayer* layer) override;
358
359     void checkShapes(const CNNLayer* layer, const std::vector<SizeVector>& inShapes) const override;
360 };
361
362 class INFERENCE_ENGINE_API_CLASS(GatherValidator) : public LayerValidator {
363 public:
364     explicit GatherValidator(const std::string& _type);
365
366     void parseParams(CNNLayer* layer) override;
367
368     void checkParams(const CNNLayer* layer) override;
369
370     void checkShapes(const CNNLayer* layer, const std::vector<SizeVector>& inShapes) const override;
371 };
372
373 class INFERENCE_ENGINE_API_CLASS(StridedSliceValidator) : public LayerValidator {
374 public:
375     explicit StridedSliceValidator(const std::string& _type);
376
377     void parseParams(CNNLayer* layer) override;
378
379     void checkParams(const CNNLayer* layer) override;
380
381     void checkShapes(const CNNLayer* layer, const std::vector<SizeVector>& inShapes) const override;
382 };
383
384 class INFERENCE_ENGINE_API_CLASS(ShuffleChannelsValidator) : public LayerValidator {
385 public:
386     explicit ShuffleChannelsValidator(const std::string& _type);
387
388     void parseParams(CNNLayer* layer) override;
389
390     void checkParams(const CNNLayer* layer) override;
391
392     void checkShapes(const CNNLayer* layer, const std::vector<SizeVector>& inShapes) const override;
393 };
394
395 class INFERENCE_ENGINE_API_CLASS(DepthToSpaceValidator) : public LayerValidator {
396 public:
397     explicit DepthToSpaceValidator(const std::string& _type);
398
399     void parseParams(CNNLayer* layer) override;
400
401     void checkParams(const CNNLayer* layer) override;
402
403     void checkShapes(const CNNLayer* layer, const std::vector<SizeVector>& inShapes) const override;
404 };
405
406 class INFERENCE_ENGINE_API_CLASS(SpaceToDepthValidator) : public LayerValidator {
407 public:
408     explicit SpaceToDepthValidator(const std::string& _type);
409
410     void parseParams(CNNLayer* layer) override;
411
412     void checkParams(const CNNLayer* layer) override;
413
414     void checkShapes(const CNNLayer* layer, const std::vector<SizeVector>& inShapes) const override;
415 };
416
417 class INFERENCE_ENGINE_API_CLASS(ReverseSequenceValidator) : public LayerValidator {
418 public:
419     explicit ReverseSequenceValidator(const std::string& _type);
420
421     void parseParams(CNNLayer* layer) override;
422
423     void checkParams(const CNNLayer* layer) override;
424
425     void checkShapes(const CNNLayer* layer, const std::vector<SizeVector>& inShapes) const override;
426 };
427
428 class INFERENCE_ENGINE_API_CLASS(SqueezeValidator) : public LayerValidator {
429 public:
430     explicit SqueezeValidator(const std::string& _type);
431
432     void parseParams(CNNLayer* layer) override;
433
434     void checkParams(const CNNLayer* layer) override;
435
436     void checkShapes(const CNNLayer* layer, const std::vector<SizeVector>& inShapes) const override;
437 };
438
439 class INFERENCE_ENGINE_API_CLASS(UnsqueezeValidator) : public LayerValidator {
440 public:
441     explicit UnsqueezeValidator(const std::string& _type);
442
443     void parseParams(CNNLayer* layer) override;
444
445     void checkParams(const CNNLayer* layer) override;
446
447     void checkShapes(const CNNLayer* layer, const std::vector<SizeVector>& inShapes) const override;
448 };
449
450 class INFERENCE_ENGINE_API_CLASS(RangeValidator) : public LayerValidator {
451 public:
452     explicit RangeValidator(const std::string& _type);
453
454     void parseParams(CNNLayer* layer) override;
455
456     void checkParams(const CNNLayer* layer) override;
457
458     void checkShapes(const CNNLayer* layer, const std::vector<SizeVector>& inShapes) const override;
459 };
460
461 class INFERENCE_ENGINE_API_CLASS(FillValidator) : public LayerValidator {
462 public:
463     explicit FillValidator(const std::string& _type);
464
465     void parseParams(CNNLayer* layer) override;
466
467     void checkParams(const CNNLayer* layer) override;
468
469     void checkShapes(const CNNLayer* layer, const std::vector<SizeVector>& inShapes) const override;
470 };
471
472 class INFERENCE_ENGINE_API_CLASS(ExpandValidator) : public LayerValidator {
473 public:
474     explicit ExpandValidator(const std::string& _type);
475
476     void parseParams(CNNLayer* layer) override;
477
478     void checkParams(const CNNLayer* layer) override;
479
480     void checkShapes(const CNNLayer* layer, const std::vector<SizeVector>& inShapes) const override;
481 };
482
483 template<RNNSequenceLayer::CellType CELL>
484 class INFERENCE_ENGINE_API_CLASS(RNNBaseValidator) : public LayerValidator {
485 public:
486     explicit RNNBaseValidator(const std::string& _type);
487
488     void parseParams(CNNLayer* layer) override;
489
490     void checkParams(const CNNLayer* layer) override;
491
492     void checkCorrespondence(const CNNLayer* layer,
493                              const std::map<std::string, Blob::Ptr>& blobs,
494                              const std::vector<SizeVector>& inShapes) const override;
495
496 protected:
497     static std::vector<std::string> def_acts;  // Default values for cell gate activations
498     static std::vector<float> def_alpha;  // Default activation alpha parameter
499     static std::vector<float> def_beta;   // Default activation beta parameter
500     static size_t G;   // gate number
501     static size_t NS;  // state number
502 };
503
504 template<RNNSequenceLayer::CellType CELL>
505 class INFERENCE_ENGINE_API_CLASS(RNNCellValidator) : public RNNBaseValidator<CELL> {
506 public:
507     explicit RNNCellValidator(const std::string& _type);
508
509     void checkShapes(const CNNLayer* layer, const std::vector<SizeVector>& inShapes) const override;
510 };
511
512 extern template class INFERENCE_ENGINE_API_CLASS(RNNCellValidator)<RNNSequenceLayer::LSTM>;
513 extern template class INFERENCE_ENGINE_API_CLASS(RNNCellValidator)<RNNSequenceLayer::GRU>;
514 extern template class INFERENCE_ENGINE_API_CLASS(RNNCellValidator)<RNNSequenceLayer::RNN>;
515
516 template<RNNSequenceLayer::CellType CELL>
517 class INFERENCE_ENGINE_API_CLASS(RNNSequenceValidator) : public RNNBaseValidator<CELL> {
518 public:
519     explicit RNNSequenceValidator(const std::string& _type);
520
521     void parseParams(CNNLayer* layer) override;
522
523     void checkParams(const CNNLayer* layer) override;
524
525     void checkShapes(const CNNLayer* layer, const std::vector<SizeVector>& inShapes) const override;
526 };
527
528 extern template class INFERENCE_ENGINE_API_CLASS(RNNSequenceValidator)<RNNSequenceLayer::LSTM>;
529 extern template class INFERENCE_ENGINE_API_CLASS(RNNSequenceValidator)<RNNSequenceLayer::GRU>;
530 extern template class INFERENCE_ENGINE_API_CLASS(RNNSequenceValidator)<RNNSequenceLayer::RNN>;
531
532 class INFERENCE_ENGINE_API_CLASS(ArgMaxValidator) : public LayerValidator {
533 public:
534     explicit ArgMaxValidator(const std::string& _type);
535
536     void checkParams(const CNNLayer* layer) override;
537
538     void checkShapes(const CNNLayer* layer, const std::vector<SizeVector>& inShapes) const override;
539 };
540
541 class INFERENCE_ENGINE_API_CLASS(CTCGreedyDecoderValidator) : public LayerValidator {
542 public:
543     explicit CTCGreedyDecoderValidator(const std::string& _type);
544
545     void checkParams(const CNNLayer* layer) override;
546
547     void checkShapes(const CNNLayer* layer, const std::vector<SizeVector>& inShapes) const override;
548 };
549
550 class INFERENCE_ENGINE_API_CLASS(DetectionOutputValidator) : public LayerValidator {
551 public:
552     explicit DetectionOutputValidator(const std::string& _type);
553
554     void parseParams(CNNLayer* layer) override;
555
556     void checkParams(const CNNLayer* layer) override;
557
558     void checkShapes(const CNNLayer* layer, const std::vector<SizeVector>& inShapes) const override;
559 };
560
561 class INFERENCE_ENGINE_API_CLASS(InterpValidator) : public LayerValidator {
562 public:
563     explicit InterpValidator(const std::string& _type);
564
565     void parseParams(CNNLayer* layer) override;
566
567     void checkParams(const CNNLayer* layer) override;
568
569     void checkShapes(const CNNLayer* layer, const std::vector<SizeVector>& inShapes) const override;
570 };
571
572 class INFERENCE_ENGINE_API_CLASS(PermuteValidator) : public LayerValidator {
573 public:
574     explicit PermuteValidator(const std::string& _type);
575
576     void checkParams(const CNNLayer* layer) override;
577
578     void checkShapes(const CNNLayer* layer, const std::vector<SizeVector>& inShapes) const override;
579 };
580
581 class INFERENCE_ENGINE_API_CLASS(PriorBoxValidator) : public LayerValidator {
582 public:
583     explicit PriorBoxValidator(const std::string& _type);
584
585     void checkParams(const CNNLayer* layer) override;
586
587     void checkShapes(const CNNLayer* layer, const std::vector<SizeVector>& inShapes) const override;
588 };
589
590 class INFERENCE_ENGINE_API_CLASS(PriorBoxClusteredValidator) : public LayerValidator {
591 public:
592     explicit PriorBoxClusteredValidator(const std::string& _type);
593
594     void checkParams(const CNNLayer* layer) override;
595
596     void checkShapes(const CNNLayer* layer, const std::vector<SizeVector>& inShapes) const override;
597 };
598
599 class INFERENCE_ENGINE_API_CLASS(ProposalValidator) : public LayerValidator {
600 public:
601     explicit ProposalValidator(const std::string& _type);
602
603     void checkParams(const CNNLayer* layer) override;
604
605     void checkShapes(const CNNLayer* layer, const std::vector<SizeVector>& inShapes) const override;
606 };
607
608 class INFERENCE_ENGINE_API_CLASS(PSROIPoolingValidator) : public LayerValidator {
609 public:
610     explicit PSROIPoolingValidator(const std::string& _type);
611
612     void checkParams(const CNNLayer* layer) override;
613
614     void checkShapes(const CNNLayer* layer, const std::vector<SizeVector>& inShapes) const override;
615 };
616
617 class INFERENCE_ENGINE_API_CLASS(RegionYoloValidator) : public LayerValidator {
618 public:
619     explicit RegionYoloValidator(const std::string& _type);
620
621     void checkParams(const CNNLayer* layer) override;
622
623     void checkShapes(const CNNLayer* layer, const std::vector<SizeVector>& inShapes) const override;
624 };
625
626 class INFERENCE_ENGINE_API_CLASS(ReorgYoloValidator) : public LayerValidator {
627 public:
628     explicit ReorgYoloValidator(const std::string& _type);
629
630     void checkParams(const CNNLayer* layer) override;
631
632     void checkShapes(const CNNLayer* layer, const std::vector<SizeVector>& inShapes) const override;
633 };
634
635 class INFERENCE_ENGINE_API_CLASS(ResampleValidator) : public LayerValidator {
636 public:
637     explicit ResampleValidator(const std::string& _type);
638
639     void checkParams(const CNNLayer* layer) override;
640
641     void checkShapes(const CNNLayer* layer, const std::vector<SizeVector>& inShapes) const override;
642 };
643
644 class INFERENCE_ENGINE_API_CLASS(ROIPoolingValidator) : public LayerValidator {
645 public:
646     explicit ROIPoolingValidator(const std::string& _type);
647
648     void checkParams(const CNNLayer* layer) override;
649
650     void checkShapes(const CNNLayer* layer, const std::vector<SizeVector>& inShapes) const override;
651 };
652
653 class INFERENCE_ENGINE_API_CLASS(SimplerNMSValidator) : public LayerValidator {
654 public:
655     explicit SimplerNMSValidator(const std::string& _type);
656
657     void checkParams(const CNNLayer* layer) override;
658
659     void checkShapes(const CNNLayer* layer, const std::vector<SizeVector>& inShapes) const override;
660 };
661
662 class INFERENCE_ENGINE_API_CLASS(SpatialTransformerValidator) : public LayerValidator {
663 public:
664     explicit SpatialTransformerValidator(const std::string& _type);
665
666     void checkParams(const CNNLayer* layer) override;
667
668     void checkShapes(const CNNLayer* layer, const std::vector<SizeVector>& inShapes) const override;
669 };
670
671 class INFERENCE_ENGINE_API_CLASS(UpsamplingValidator) : public LayerValidator {
672 public:
673     explicit UpsamplingValidator(const std::string& _type);
674
675     void checkParams(const CNNLayer* layer) override;
676
677     void checkShapes(const CNNLayer* layer, const std::vector<SizeVector>& inShapes) const override;
678 };
679
680 class INFERENCE_ENGINE_API_CLASS(ActivationValidator) : public LayerValidator {
681 public:
682     explicit ActivationValidator(const std::string& _type);
683
684     void checkParams(const CNNLayer* layer) override;
685
686     void checkShapes(const CNNLayer* layer, const std::vector<SizeVector>& inShapes) const override;
687 };
688
689 class INFERENCE_ENGINE_API_CLASS(ConstValidator) : public LayerValidator {
690 public:
691     explicit ConstValidator(const std::string& _type);
692
693     void checkParams(const CNNLayer* layer) override;
694
695     void checkShapes(const CNNLayer* layer, const std::vector<SizeVector>& inShapes) const override;
696 };
697
698 class INFERENCE_ENGINE_API_CLASS(ELUValidator) : public LayerValidator {
699 public:
700     explicit ELUValidator(const std::string& _type);
701
702     void checkParams(const CNNLayer* layer) override;
703
704     void checkShapes(const CNNLayer* layer, const std::vector<SizeVector>& inShapes) const override;
705 };
706
707 class INFERENCE_ENGINE_API_CLASS(InputValidator) : public LayerValidator {
708 public:
709     explicit InputValidator(const std::string& _type);
710
711     void checkParams(const CNNLayer* layer) override;
712
713     void checkShapes(const CNNLayer* layer, const std::vector<SizeVector>& inShapes) const override;
714 };
715
716 class INFERENCE_ENGINE_API_CLASS(MemoryValidator) : public LayerValidator {
717 public:
718     explicit MemoryValidator(const std::string& _type);
719
720     void checkParams(const CNNLayer* layer) override;
721
722     void checkShapes(const CNNLayer* layer, const std::vector<SizeVector>& inShapes) const override;
723 };
724
725 class INFERENCE_ENGINE_API_CLASS(NormalizeValidator) : public LayerValidator {
726 public:
727     explicit NormalizeValidator(const std::string& _type);
728
729     void checkParams(const CNNLayer* layer) override;
730
731     void checkShapes(const CNNLayer* layer, const std::vector<SizeVector>& inShapes) const override;
732 };
733
734 class INFERENCE_ENGINE_API_CLASS(CopyValidator) : public LayerValidator {
735 public:
736     explicit CopyValidator(const std::string& _type);
737
738     void checkParams(const CNNLayer* layer) override;
739
740     void checkShapes(const CNNLayer* layer, const std::vector<SizeVector>& inShapes) const override;
741 };
742
743 class INFERENCE_ENGINE_API_CLASS(PowerFileValidator) : public LayerValidator {
744 public:
745     explicit PowerFileValidator(const std::string& _type);
746
747     void checkParams(const CNNLayer* layer) override;
748
749     void checkShapes(const CNNLayer* layer, const std::vector<SizeVector>& inShapes) const override;
750 };
751
752 class INFERENCE_ENGINE_API_CLASS(ReLU6Validator) : public LayerValidator {
753 public:
754     explicit ReLU6Validator(const std::string& _type);
755
756     void checkParams(const CNNLayer* layer) override;
757
758     void checkShapes(const CNNLayer* layer, const std::vector<SizeVector>& inShapes) const override;
759 };
760
761 class INFERENCE_ENGINE_API_CLASS(SigmoidValidator) : public LayerValidator {
762 public:
763     explicit SigmoidValidator(const std::string& _type);
764
765     void checkParams(const CNNLayer* layer) override;
766
767     void checkShapes(const CNNLayer* layer, const std::vector<SizeVector>& inShapes) const override;
768 };
769
770 class INFERENCE_ENGINE_API_CLASS(TanHValidator) : public LayerValidator {
771 public:
772     explicit TanHValidator(const std::string& _type);
773
774     void checkShapes(const CNNLayer* layer, const std::vector<SizeVector>& inShapes) const override;
775 };
776
777 class INFERENCE_ENGINE_API_CLASS(UnpoolingValidator) : public LayerValidator {
778 public:
779     explicit UnpoolingValidator(const std::string& _type);
780
781     void checkParams(const CNNLayer* layer) override;
782
783     void checkShapes(const CNNLayer* layer, const std::vector<SizeVector>& inShapes) const override;
784 };
785
786 class INFERENCE_ENGINE_API_CLASS(QuantizeValidator) : public LayerValidator {
787 public:
788     explicit QuantizeValidator(const std::string& _type);
789
790     void parseParams(CNNLayer* layer) override;
791
792     void checkParams(const CNNLayer* layer) override;
793
794     void checkShapes(const CNNLayer* layer, const std::vector<SizeVector>& inShapes) const override;
795 };
796
797 class INFERENCE_ENGINE_API_CLASS(BinaryConvolutionValidator) : public LayerValidator {
798 public:
799     void parseParams(CNNLayer* layer) override;
800
801     void checkParams(const CNNLayer* layer) override;
802
803     explicit BinaryConvolutionValidator(const std::string& _type);
804
805     void checkCorrespondence(const CNNLayer* layer,
806                              const std::map<std::string, Blob::Ptr>& blobs,
807                              const std::vector<SizeVector>& inShapes) const override;
808
809     void checkShapes(const CNNLayer* layer, const std::vector<SizeVector>& inShapes) const override;
810 };
811
812 template<typename Validator>
813 class ValidatorRegisterBase {
814 public:
815     explicit ValidatorRegisterBase(const std::string& type) {
816         LayerValidators::getInstance()->addImpl(type, std::make_shared<Validator>(type));
817     }
818 };
819
820 #define REG_LAYER_VALIDATOR_FOR_TYPE(__validator, __type) \
821 static ValidatorRegisterBase<__validator> __reg__##__type(#__type)
822
823 REG_LAYER_VALIDATOR_FOR_TYPE(ActivationValidator, Activation);
824 REG_LAYER_VALIDATOR_FOR_TYPE(ArgMaxValidator, ArgMax);
825 REG_LAYER_VALIDATOR_FOR_TYPE(BatchNormalizationValidator, BatchNormalization);
826 REG_LAYER_VALIDATOR_FOR_TYPE(CTCGreedyDecoderValidator, CTCGreedyDecoder);
827 REG_LAYER_VALIDATOR_FOR_TYPE(ClampValidator, Clamp);
828 REG_LAYER_VALIDATOR_FOR_TYPE(ConcatValidator, Concat);
829 REG_LAYER_VALIDATOR_FOR_TYPE(ConstValidator, Const);
830 REG_LAYER_VALIDATOR_FOR_TYPE(ConvolutionValidator, Convolution);
831 REG_LAYER_VALIDATOR_FOR_TYPE(CopyValidator, Copy);
832 REG_LAYER_VALIDATOR_FOR_TYPE(CropValidator, Crop);
833 REG_LAYER_VALIDATOR_FOR_TYPE(DeconvolutionValidator, Deconvolution);
834 REG_LAYER_VALIDATOR_FOR_TYPE(DetectionOutputValidator, DetectionOutput);
835 REG_LAYER_VALIDATOR_FOR_TYPE(ELUValidator, ELU);
836 REG_LAYER_VALIDATOR_FOR_TYPE(EltwiseValidator, Eltwise);
837 REG_LAYER_VALIDATOR_FOR_TYPE(FullyConnectedValidator, InnerProduct);
838 REG_LAYER_VALIDATOR_FOR_TYPE(FullyConnectedValidator, FullyConnected);
839 REG_LAYER_VALIDATOR_FOR_TYPE(GRNValidator, GRN);
840 REG_LAYER_VALIDATOR_FOR_TYPE(InputValidator, Input);
841 REG_LAYER_VALIDATOR_FOR_TYPE(InterpValidator, Interp);
842 REG_LAYER_VALIDATOR_FOR_TYPE(MVNValidator, MVN);
843 REG_LAYER_VALIDATOR_FOR_TYPE(MemoryValidator, Memory);
844 REG_LAYER_VALIDATOR_FOR_TYPE(NormValidator, Norm);
845 REG_LAYER_VALIDATOR_FOR_TYPE(NormValidator, LRN);
846 REG_LAYER_VALIDATOR_FOR_TYPE(NormalizeValidator, Normalize);
847 REG_LAYER_VALIDATOR_FOR_TYPE(PReLUValidator, PReLU);
848 REG_LAYER_VALIDATOR_FOR_TYPE(PSROIPoolingValidator, PSROIPooling);
849 REG_LAYER_VALIDATOR_FOR_TYPE(PermuteValidator, Permute);
850 REG_LAYER_VALIDATOR_FOR_TYPE(PoolingValidator, Pooling);
851 REG_LAYER_VALIDATOR_FOR_TYPE(PowerValidator, Power);
852 REG_LAYER_VALIDATOR_FOR_TYPE(PowerFileValidator, PowerFile);
853 REG_LAYER_VALIDATOR_FOR_TYPE(PriorBoxClusteredValidator, PriorBoxClustered);
854 REG_LAYER_VALIDATOR_FOR_TYPE(PriorBoxValidator, PriorBox);
855 REG_LAYER_VALIDATOR_FOR_TYPE(ProposalValidator, Proposal);
856 REG_LAYER_VALIDATOR_FOR_TYPE(ROIPoolingValidator, ROIPooling);
857 REG_LAYER_VALIDATOR_FOR_TYPE(ReLUValidator, ReLU);
858 REG_LAYER_VALIDATOR_FOR_TYPE(ReLU6Validator, ReLU6);
859 REG_LAYER_VALIDATOR_FOR_TYPE(RegionYoloValidator, RegionYolo);
860 REG_LAYER_VALIDATOR_FOR_TYPE(ReorgYoloValidator, ReorgYolo);
861 REG_LAYER_VALIDATOR_FOR_TYPE(ResampleValidator, Resample);
862 REG_LAYER_VALIDATOR_FOR_TYPE(ReshapeValidator, Reshape);
863 REG_LAYER_VALIDATOR_FOR_TYPE(ReshapeValidator, Flatten);
864 REG_LAYER_VALIDATOR_FOR_TYPE(ScaleShiftValidator, ScaleShift);
865 REG_LAYER_VALIDATOR_FOR_TYPE(SigmoidValidator, Sigmoid);
866 REG_LAYER_VALIDATOR_FOR_TYPE(SigmoidValidator, Logistic);
867 REG_LAYER_VALIDATOR_FOR_TYPE(SimplerNMSValidator, SimplerNMS);
868 REG_LAYER_VALIDATOR_FOR_TYPE(SoftMaxValidator, SoftMax);
869 REG_LAYER_VALIDATOR_FOR_TYPE(SpatialTransformerValidator, SpatialTransformer);
870 REG_LAYER_VALIDATOR_FOR_TYPE(SplitValidator, Split);
871 REG_LAYER_VALIDATOR_FOR_TYPE(SplitValidator, Slice);
872 REG_LAYER_VALIDATOR_FOR_TYPE(GemmValidator, Gemm);
873 REG_LAYER_VALIDATOR_FOR_TYPE(PadValidator, Pad);
874 REG_LAYER_VALIDATOR_FOR_TYPE(GatherValidator, Gather);
875 REG_LAYER_VALIDATOR_FOR_TYPE(StridedSliceValidator, StridedSlice);
876 REG_LAYER_VALIDATOR_FOR_TYPE(ShuffleChannelsValidator, ShuffleChannels);
877 REG_LAYER_VALIDATOR_FOR_TYPE(DepthToSpaceValidator, DepthToSpace);
878 REG_LAYER_VALIDATOR_FOR_TYPE(SpaceToDepthValidator, SpaceToDepth);
879 REG_LAYER_VALIDATOR_FOR_TYPE(ReverseSequenceValidator, ReverseSequence);
880 REG_LAYER_VALIDATOR_FOR_TYPE(RNNCellValidator<RNNSequenceLayer::RNN>, RNNCell);
881 REG_LAYER_VALIDATOR_FOR_TYPE(RNNCellValidator<RNNSequenceLayer::GRU>, GRUCell);
882 REG_LAYER_VALIDATOR_FOR_TYPE(RNNCellValidator<RNNSequenceLayer::LSTM>, LSTMCell);
883 REG_LAYER_VALIDATOR_FOR_TYPE(RNNSequenceValidator<RNNSequenceLayer::RNN>, RNNSequence);
884 REG_LAYER_VALIDATOR_FOR_TYPE(RNNSequenceValidator<RNNSequenceLayer::GRU>, GRUSequence);
885 REG_LAYER_VALIDATOR_FOR_TYPE(RNNSequenceValidator<RNNSequenceLayer::LSTM>, LSTMSequence);
886 REG_LAYER_VALIDATOR_FOR_TYPE(SqueezeValidator, Squeeze);
887 REG_LAYER_VALIDATOR_FOR_TYPE(UnsqueezeValidator, Unsqueeze);
888 REG_LAYER_VALIDATOR_FOR_TYPE(RangeValidator, Range);
889 REG_LAYER_VALIDATOR_FOR_TYPE(FillValidator, Fill);
890 REG_LAYER_VALIDATOR_FOR_TYPE(ExpandValidator, Expand);
891 REG_LAYER_VALIDATOR_FOR_TYPE(TanHValidator, TanH);
892 REG_LAYER_VALIDATOR_FOR_TYPE(TileValidator, Tile);
893 REG_LAYER_VALIDATOR_FOR_TYPE(UnpoolingValidator, Unpooling);
894 REG_LAYER_VALIDATOR_FOR_TYPE(UpsamplingValidator, Upsampling);
895 REG_LAYER_VALIDATOR_FOR_TYPE(QuantizeValidator, Quantize);
896 REG_LAYER_VALIDATOR_FOR_TYPE(BinaryConvolutionValidator, BinaryConvolution);
897 }  // namespace details
898 }  // namespace InferenceEngine