Deprecate nGraph v0 ops and builders (#1856)
ngraph/core/src/validation_util.cpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16
17 #include <algorithm>
18
19 #include "ngraph/evaluator.hpp"
20 #include "ngraph/op/concat.hpp"
21 #include "ngraph/op/convert.hpp"
22 #include "ngraph/op/min.hpp"
23 #include "ngraph/op/minimum.hpp"
24 #include "ngraph/op/squeeze.hpp"
25 #include "ngraph/op/unsqueeze.hpp"
26 #include "ngraph/runtime/host_tensor.hpp"
27 #include "ngraph/shape.hpp"
28 #include "ngraph/type/element_type_traits.hpp"
29 #include "ngraph/util.hpp"
30 #include "ngraph/validation_util.hpp"
31
32 NGRAPH_SUPPRESS_DEPRECATED_START
33
34 using namespace std;
35 using namespace ngraph;
36
37 Strides ngraph::conv_default_strides(const Node* /* node */,
38                                      const PartialShape& data_batch_shape,
39                                      const PartialShape& filters_shape)
40 {
41     size_t rank;
42
43     if (data_batch_shape.rank().is_static() && data_batch_shape.rank().get_length() >= 2)
44     {
45         rank = data_batch_shape.rank().get_length() - 2;
46     }
47     else if (filters_shape.rank().is_static() && filters_shape.rank().get_length() >= 2)
48     {
49         rank = filters_shape.rank().get_length() - 2;
50     }
51     else
52     {
53         rank = 0;
54     }
55
56     return Strides(rank, 1);
57 }
58
59 CoordinateDiff ngraph::conv_default_padding(const Node* /* node */,
60                                             const PartialShape& data_batch_shape,
61                                             const PartialShape& filters_shape)
62 {
63     size_t rank;
64
65     if (data_batch_shape.rank().is_static() && data_batch_shape.rank().get_length() >= 2)
66     {
67         rank = data_batch_shape.rank().get_length() - 2;
68     }
69     else if (filters_shape.rank().is_static() && filters_shape.rank().get_length() >= 2)
70     {
71         rank = filters_shape.rank().get_length() - 2;
72     }
73     else
74     {
75         rank = 0;
76     }
77
78     return CoordinateDiff(rank, 0);
79 }
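
// Illustrative note (not part of the original source): for a data batch of rank 4,
// e.g. PartialShape{N, C, H, W}, both helpers above derive a spatial rank of 2, so
// conv_default_strides returns Strides{1, 1} and conv_default_padding returns
// CoordinateDiff{0, 0}; when neither rank is static they fall back to rank 0.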
80
81 //
82 // Infers the output shape of a windowed reduction operation, where the data may be dilated and/or
83 // padded, and the reduction window may be strided and/or dilated.
84 //
85 // TODO(amprocte): The messages here would be a bit friendlier if we didn't say "after
86 // padding/after dilation" for cases where there is actually no padding/dilation.
87 //
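// A worked sketch of the per-axis arithmetic implemented below (illustrative only): with
// data dimension D, data dilation d, padding (p_b, p_a), window size W, window dilation w
// and stride s, the padded/dilated data length is
//     D' = d * (D - 1) + 1 + p_b + p_a
// the dilated window length is
//     W' = w * (W - 1) + 1
// and the output length is floor((D' - W') / s) + 1, or ceil_div(D' - W', s) + 1 when
// ceil_mode is set. For example, D = 5, d = 1, p_b = p_a = 1, W = 3, w = 1, s = 2 gives
// D' = 7, W' = 3 and an output length of (7 - 3) / 2 + 1 = 3.
//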
88 PartialShape ngraph::infer_windowed_reduction_output_shape(const Node* node,
89                                                            const PartialShape& data_shape,
90                                                            const Strides& data_dilation,
91                                                            const CoordinateDiff& data_padding_below,
92                                                            const CoordinateDiff& data_padding_above,
93                                                            const PartialShape& window_shape,
94                                                            const Strides& window_strides,
95                                                            const Strides& window_dilation,
96                                                            bool is_window_all_in_padding_allowed,
97                                                            bool ceil_mode)
98 {
99     PartialShape data_shape_merged{PartialShape::dynamic()};
100
101     NODE_VALIDATION_CHECK(node,
102                           data_shape_merged.merge_rank(data_shape.rank()) &&
103                               data_shape_merged.merge_rank(data_dilation.size()) &&
104                               data_shape_merged.merge_rank(data_padding_below.size()) &&
105                               data_shape_merged.merge_rank(data_padding_above.size()) &&
106                               data_shape_merged.merge_rank(window_shape.rank()) &&
107                               data_shape_merged.merge_rank(window_strides.size()) &&
108                               data_shape_merged.merge_rank(window_dilation.size()),
109                           "Ranks for data shape (",
110                           data_shape,
111                           "), data dilation (",
112                           data_dilation,
113                           "), padding below (",
114                           data_padding_below,
115                           "), padding above (",
116                           data_padding_above,
117                           "), window shape (",
118                           window_shape,
119                           "), window strides (",
120                           window_strides,
121                           "), and window dilation (",
122                           window_dilation,
123                           ") do not match.");
124
125     PartialShape output_shape = PartialShape::dynamic(data_shape_merged.rank());
126
127     if (output_shape.rank().is_static())
128     {
129         for (size_t i = 0; i < output_shape.rank().get_length(); i++)
130         {
131             NODE_VALIDATION_CHECK(node,
132                                   data_dilation[i] > 0,
133                                   "Data dilation (",
134                                   data_dilation,
135                                   ") has zero dimension at axis ",
136                                   i,
137                                   ".");
138             NODE_VALIDATION_CHECK(node,
139                                   window_strides[i] > 0,
140                                   "Window strides (",
141                                   window_strides,
142                                   ") has zero dimension at axis ",
143                                   i,
144                                   ".");
145             NODE_VALIDATION_CHECK(node,
146                                   window_dilation[i] > 0,
147                                   "Window dilation (",
148                                   window_dilation,
149                                   ") has zero dimension at axis ",
150                                   i,
151                                   ".");
152
153             bool data_dim_static = data_shape.rank().is_static() && data_shape[i].is_static();
154             bool window_dim_static = window_shape.rank().is_static() && window_shape[i].is_static();
155
156             ptrdiff_t data_padded_dilated_dim = -1;
157             if (data_dim_static)
158             {
159                 data_padded_dilated_dim =
160                     (static_cast<int64_t>(data_dilation[i]) * (data_shape[i].get_length() - 1)) +
161                     1 + data_padding_below[i] + data_padding_above[i];
162                 NODE_VALIDATION_CHECK(
163                     node,
164                     data_padded_dilated_dim > 0,
165                     "Data shape after padding and dilation has dimension less than 1 (dim: ",
166                     data_padded_dilated_dim,
167                     ") at axis ",
168                     i,
169                     ".");
170             }
171
172             ptrdiff_t window_dilated_dim = -1;
173             if (window_dim_static)
174             {
175                 window_dilated_dim =
176                     static_cast<int64_t>(window_dilation[i]) * (window_shape[i].get_length() - 1) +
177                     1;
178
179                 NODE_VALIDATION_CHECK(node,
180                                       window_dilated_dim > 0,
181                                       "Window after dilation has dimension less than 1 (dim: ",
182                                       window_dilated_dim,
183                                       ") at axis ",
184                                       i,
185                                       ".");
186
187                 NODE_VALIDATION_CHECK(
188                     node,
189                     is_window_all_in_padding_allowed ||
190                         (window_dilated_dim > data_padding_below[i] &&
191                          window_dilated_dim > data_padding_above[i]),
192                     "Window after dilation is sometimes entirely in the padding area for axis ",
193                     i,
194                     " (dilated window dimension: ",
195                     window_dilated_dim,
196                     ", padding below dimension: ",
197                     data_padding_below[i],
198                     ", padding above dimension: ",
199                     data_padding_above[i],
200                     ") and this is not ",
201                     "allowed.");
202             }
203
204             if (data_dim_static && window_dim_static)
205             {
206                 NODE_VALIDATION_CHECK(node,
207                                       window_dilated_dim <= data_padded_dilated_dim,
208                                       "Window after dilation has dimension (dim: ",
209                                       window_dilated_dim,
210                                       ") larger than the data shape after padding (dim: ",
211                                       data_padded_dilated_dim,
212                                       ") at axis ",
213                                       i,
214                                       ".");
215
216                 if (ceil_mode)
217                 {
218                     output_shape[i] = ceil_div(static_cast<size_t>(data_padded_dilated_dim) -
219                                                    static_cast<size_t>(window_dilated_dim),
220                                                window_strides[i]) +
221                                       1;
222                 }
223                 else
224                 {
225                     output_shape[i] = ((static_cast<size_t>(data_padded_dilated_dim) -
226                                         static_cast<size_t>(window_dilated_dim)) /
227                                        window_strides[i]) +
228                                       1;
229                 }
230             }
231         }
232     }
233
234     return output_shape;
235 }
236
237 //
238 // Infers the output batch shape and element type for convolution fprop.
239 //
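// Hedged usage sketch (the argument values are hypothetical, not taken from any caller):
//     // data batch: {N, C_in, H, W}, filters: {C_out, C_in, fH, fW}
//     PartialShape out = infer_convolution_forward(node,
//                                                  PartialShape{1, 3, 224, 224},
//                                                  Strides{1, 1},        // data dilation
//                                                  CoordinateDiff{1, 1}, // padding below
//                                                  CoordinateDiff{1, 1}, // padding above
//                                                  PartialShape{64, 3, 3, 3},
//                                                  Strides{1, 1},        // filter strides
//                                                  Strides{1, 1});       // filter dilation
//     // out is {1, 64, 224, 224}: batch size, filter output channel count, then the
//     // spatial dims produced by infer_windowed_reduction_output_shape.
//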
240 PartialShape ngraph::infer_convolution_forward(const Node* node,
241                                                const PartialShape& data_batch_shape,
242                                                const Strides& data_dilation,
243                                                const CoordinateDiff& data_padding_below,
244                                                const CoordinateDiff& data_padding_above,
245                                                const PartialShape& filters_shape,
246                                                const Strides& filter_strides,
247                                                const Strides& filter_dilation)
248 {
249     Rank data_batch_filters_rank{Rank::dynamic()};
250
251     NODE_VALIDATION_CHECK(
252         node,
253         Rank::merge(data_batch_filters_rank, data_batch_shape.rank(), filters_shape.rank()),
254         "Data batch and filters rank do not match (data batch shape: ",
255         data_batch_shape,
256         ", filters shape: ",
257         filters_shape,
258         ").");
259
260     NODE_VALIDATION_CHECK(node,
261                           data_batch_filters_rank.is_dynamic() ||
262                               data_batch_filters_rank.get_length() >= 3,
263                           "Data batch and filters must have rank of at least 3 (one batch axis, ",
264                           "one input-channel axis, and at least one spatial dimension) ",
265                           "(data batch shape: ",
266                           data_batch_shape,
267                           ", filters shape: ",
268                           filters_shape,
269                           ").");
270
271     Rank spatial_rank{Rank::dynamic()};
272     NODE_VALIDATION_CHECK(node,
273                           Rank::merge(spatial_rank, spatial_rank, data_batch_filters_rank - 2) &&
274                               Rank::merge(spatial_rank, spatial_rank, data_dilation.size()) &&
275                               Rank::merge(spatial_rank, spatial_rank, data_padding_below.size()) &&
276                               Rank::merge(spatial_rank, spatial_rank, data_padding_above.size()) &&
277                               Rank::merge(spatial_rank, spatial_rank, filter_strides.size()) &&
278                               Rank::merge(spatial_rank, spatial_rank, filter_dilation.size()),
279                           "Ranks for data item shape/filters shape (data batch has shape ",
280                           data_batch_shape,
281                           ", so data item rank is ",
282                           (data_batch_shape.rank() - 2),
283                           " and filters have shape ",
284                           filters_shape,
285                           ", so filters spatial rank is ",
286                           (filters_shape.rank() - 2),
287                           "), data dilation (",
288                           data_dilation,
289                           "), padding below (",
290                           data_padding_below,
291                           "), padding above (",
292                           data_padding_above,
293                           "), filter strides (",
294                           filter_strides,
295                           "), and filter dilation (",
296                           filter_dilation,
297                           ") do not match.");
298
299     Dimension batch_size =
300         (data_batch_shape.rank().is_static() ? data_batch_shape[0] : Dimension::dynamic());
301     Dimension data_channel_count =
302         (data_batch_shape.rank().is_static() ? data_batch_shape[1] : Dimension::dynamic());
303     PartialShape data_spatial_shape(PartialShape::dynamic(spatial_rank));
304
305     Dimension filter_output_channel_count =
306         (filters_shape.rank().is_static() ? filters_shape[0] : Dimension::dynamic());
307     Dimension filter_input_channel_count =
308         (filters_shape.rank().is_static() ? filters_shape[1] : Dimension::dynamic());
309     PartialShape filter_spatial_shape(PartialShape::dynamic(spatial_rank));
310
311     //
312     // Note: spatial_rank is definitely static at this point.
313     //
314
315     for (size_t i = 0; i < spatial_rank.get_length(); i++)
316     {
317         if (data_batch_shape.rank().is_static())
318         {
319             data_spatial_shape[i] = data_batch_shape[i + 2];
320         }
321
322         if (filters_shape.rank().is_static())
323         {
324             filter_spatial_shape[i] = filters_shape[i + 2];
325         }
326     }
327
328     NODE_VALIDATION_CHECK(
329         node, batch_size.is_dynamic() || batch_size.get_length() > 0, "Batch size is zero.");
330
331     Dimension merged_channel_count;
332
333     NODE_VALIDATION_CHECK(
334         node,
335         Dimension::merge(merged_channel_count, data_channel_count, filter_input_channel_count),
336         "Data batch channel count (",
337         data_channel_count,
338         ") does not match filter input ",
339         "channel count (",
340         filter_input_channel_count,
341         ").");
342
343     NODE_VALIDATION_CHECK(node,
344                           merged_channel_count.is_dynamic() ||
345                               merged_channel_count.get_length() > 0,
346                           "Data batch channel count and/or filter input channel count is zero.");
347
348     NODE_VALIDATION_CHECK(node,
349                           filter_output_channel_count.is_dynamic() ||
350                               filter_output_channel_count.get_length() > 0,
351                           "Filter output channel count is zero.");
352
353     PartialShape data_output_shape = infer_windowed_reduction_output_shape(node,
354                                                                            data_spatial_shape,
355                                                                            data_dilation,
356                                                                            data_padding_below,
357                                                                            data_padding_above,
358                                                                            filter_spatial_shape,
359                                                                            filter_strides,
360                                                                            filter_dilation,
361                                                                            true);
362
363     PartialShape batch_output_shape(PartialShape::dynamic(spatial_rank + 2));
364     batch_output_shape[0] = batch_size;
365     batch_output_shape[1] = filter_output_channel_count;
366
367     for (size_t i = 0; i < spatial_rank.get_length(); i++)
368     {
369         batch_output_shape[i + 2] = data_output_shape[i];
370     }
371
372     return batch_output_shape;
373 }
374
375 //
376 // Infers the output batch shape and element type for batched pooling fprop.
377 //
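// Hedged usage sketch (the values are hypothetical): for a data batch {1, 3, 32, 32}, a
// 2x2 window, stride 2 and no padding,
//     infer_batched_pooling_forward(node, PartialShape{1, 3, 32, 32},
//                                   CoordinateDiff{0, 0}, CoordinateDiff{0, 0},
//                                   PartialShape{2, 2}, Strides{2, 2},
//                                   true, false)
// yields {1, 3, 16, 16}, since each spatial axis becomes (32 - 2) / 2 + 1 = 16.
//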
378 PartialShape ngraph::infer_batched_pooling_forward(const Node* node,
379                                                    const PartialShape& data_batch_shape,
380                                                    const CoordinateDiff& data_padding_below,
381                                                    const CoordinateDiff& data_padding_above,
382                                                    const PartialShape& window_shape,
383                                                    const Strides& window_strides,
384                                                    bool is_window_all_in_padding_allowed,
385                                                    bool ceil_mode)
386 {
387     NODE_VALIDATION_CHECK(node,
388                           data_batch_shape.rank().is_dynamic() ||
389                               data_batch_shape.rank().get_length() >= 3,
390                           "Data batch must have rank of at least 3 (one batch axis, ",
391                           "one input-channel axis, and at least one spatial dimension) ",
392                           "(data batch shape: ",
393                           data_batch_shape,
394                           ").");
395
396     PartialShape data_spatial_shape{PartialShape::dynamic()};
397
398     NODE_VALIDATION_CHECK(node,
399                           data_spatial_shape.merge_rank(data_batch_shape.rank() - 2) &&
400                               data_spatial_shape.merge_rank(data_padding_below.size()) &&
401                               data_spatial_shape.merge_rank(data_padding_above.size()) &&
402                               data_spatial_shape.merge_rank(window_shape.rank()) &&
403                               data_spatial_shape.merge_rank(window_strides.size()),
404                           "Ranks for data item shape (data batch has shape ",
405                           data_batch_shape,
406                           ", so data item rank is ",
407                           (data_batch_shape.rank() - 2),
408                           "), padding below (",
409                           data_padding_below,
410                           "), padding above (",
411                           data_padding_above,
412                           "), window shape (",
413                           window_shape,
414                           "), and window strides (",
415                           window_strides,
416                           ") do not match.");
417
418     Dimension batch_size{Dimension::dynamic()};
419     Dimension channel_count{Dimension::dynamic()};
420     PartialShape data_output_spatial_shape{PartialShape::dynamic(data_spatial_shape.rank())};
421
422     if (data_batch_shape.rank().is_static())
423     {
424         batch_size = data_batch_shape[0];
425         channel_count = data_batch_shape[1];
426
427         for (size_t i = 0; i < data_spatial_shape.rank().get_length(); i++)
428         {
429             data_spatial_shape[i] = data_batch_shape[i + 2];
430         }
431
432         NODE_VALIDATION_CHECK(
433             node, batch_size.is_dynamic() || batch_size.get_length() > 0, "Batch size is zero.");
434
435         NODE_VALIDATION_CHECK(node,
436                               channel_count.is_dynamic() || channel_count.get_length() > 0,
437                               "Channel count is zero.");
438
439         // For pooling ops we don't need dilation, so we fill in the identity value (all 1).
440         Strides data_dilation(data_spatial_shape.rank().get_length(), 1);
441         Strides window_dilation(data_spatial_shape.rank().get_length(), 1);
442
443         data_output_spatial_shape =
444             infer_windowed_reduction_output_shape(node,
445                                                   data_spatial_shape,
446                                                   data_dilation,
447                                                   data_padding_below,
448                                                   data_padding_above,
449                                                   window_shape,
450                                                   window_strides,
451                                                   window_dilation,
452                                                   is_window_all_in_padding_allowed,
453                                                   ceil_mode);
454     }
455
456     PartialShape data_batch_output_shape{
457         PartialShape::dynamic(data_output_spatial_shape.rank() + 2)};
458     data_batch_output_shape[0] = batch_size;
459     data_batch_output_shape[1] = channel_count;
460
461     for (size_t i = 0; i < data_spatial_shape.rank().get_length(); i++)
462     {
463         data_batch_output_shape[i + 2] = data_output_spatial_shape[i];
464     }
465
466     return data_batch_output_shape;
467 }
468
469 struct ChannelShapedInputSpec
470 {
471     element::Type m_element_type;
472     PartialShape m_shape;
473     std::string m_input_name;
474 };
475
476 static std::tuple<element::Type, PartialShape, PartialShape> infer_batch_norm_forward_helper(
477     const Node* node,
478     element::Type input_element_type,
479     const PartialShape& input_shape,
480     const std::vector<ChannelShapedInputSpec>& channel_shaped_inputs)
481 {
482     // Build up a slash-separated string naming all the channel-shaped inputs, for use in error
483     // messages.
484     std::stringstream ss;
485     bool first = true;
486     for (auto& inp : channel_shaped_inputs)
487     {
488         if (!first)
489         {
490             ss << "/";
491         }
492         ss << inp.m_input_name;
493         first = false;
494     }
495     std::string channel_input_names = ss.str();
496
497     // Infer output element type.
498     element::Type et_result{input_element_type};
499
500     for (auto& inp : channel_shaped_inputs)
501     {
502         NODE_VALIDATION_CHECK(node,
503                               element::Type::merge(et_result, et_result, inp.m_element_type),
504                               "Input element types do not match.");
505     }
506
507     // Extract channel dimension from input shape.
508     Dimension channel_dim{Dimension::dynamic()};
509
510     NODE_VALIDATION_CHECK(node,
511                           input_shape.rank().is_dynamic() || input_shape.rank().get_length() >= 2,
512                           "Input argument must have rank of at least 2 (input argument shape: ",
513                           input_shape,
514                           ").");
515
516     if (input_shape.rank().is_static())
517     {
518         channel_dim = input_shape[1];
519     }
520
521     // Infer gamma/beta/mu/sigma shape, which must be consistent with a vector of size
522     // "channel_dim".
523     PartialShape channel_shape{PartialShape::dynamic()};
524
525     for (auto& inp : channel_shaped_inputs)
526     {
527         NODE_VALIDATION_CHECK(node,
528                               PartialShape::merge_into(channel_shape, inp.m_shape),
529                               "Shapes for ",
530                               channel_input_names,
531                               " do not match.");
532     }
533
534     NODE_VALIDATION_CHECK(node,
535                           channel_shape.merge_rank(1),
536                           "Shape for ",
537                           channel_input_names,
538                           " (",
539                           channel_shape,
540                           ") does not have rank 1.");
541
542     NODE_VALIDATION_CHECK(node,
543                           Dimension::merge(channel_dim, channel_dim, channel_shape[0]),
544                           "Input channel dimension (",
545                           channel_dim,
546                           ") does not match shape for ",
547                           channel_input_names,
548                           " (",
549                           channel_shape,
550                           ").");
551
552     NODE_VALIDATION_CHECK(node,
553                           channel_dim.is_dynamic() || channel_dim.get_length() >= 1,
554                           "Channel count must be at least 1.");
555
556     // Batch result shape is same as the input shape, except we may possibly have inferred more
557     // information from the channel count via gamma/beta/etc.
558     PartialShape batch_result_shape{input_shape};
559
560     if (batch_result_shape.rank().is_static())
561     {
562         batch_result_shape[1] = channel_dim;
563     }
564
565     return std::make_tuple(et_result, batch_result_shape, PartialShape{channel_dim});
566 }
567
568 std::tuple<element::Type, PartialShape, PartialShape>
569     ngraph::infer_batch_norm_forward(const Node* node,
570                                      element::Type input_element_type,
571                                      element::Type gamma_element_type,
572                                      element::Type beta_element_type,
573                                      element::Type mean_element_type,
574                                      element::Type variance_element_type,
575                                      const PartialShape& input_shape,
576                                      const PartialShape& gamma_shape,
577                                      const PartialShape& beta_shape,
578                                      const PartialShape& mean_shape,
579                                      const PartialShape& variance_shape)
580 {
581     return infer_batch_norm_forward_helper(node,
582                                            input_element_type,
583                                            input_shape,
584                                            {{gamma_element_type, gamma_shape, "gamma"},
585                                             {beta_element_type, beta_shape, "beta"},
586                                             {mean_element_type, mean_shape, "mean"},
587                                             {variance_element_type, variance_shape, "variance"}});
588 }
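
// Illustrative note (not part of the original source): for an NCHW input such as
// {N, 3, H, W} with f32 element type and gamma/beta/mean/variance all of shape {3},
// the helper above returns the tuple (f32, {N, 3, H, W}, {3}): the merged element type,
// the batch result shape with its channel dimension refined from the channel-shaped
// inputs, and the channel shape itself.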
589
590 std::tuple<element::Type, PartialShape, PartialShape>
591     ngraph::infer_batch_norm_forward(const Node* node,
592                                      element::Type input_element_type,
593                                      element::Type gamma_element_type,
594                                      element::Type beta_element_type,
595                                      const PartialShape& input_shape,
596                                      const PartialShape& gamma_shape,
597                                      const PartialShape& beta_shape)
598 {
599     return infer_batch_norm_forward_helper(
600         node,
601         input_element_type,
602         input_shape,
603         {{gamma_element_type, gamma_shape, "gamma"}, {beta_element_type, beta_shape, "beta"}});
604 }
605
606 void ngraph::infer_auto_padding(const Shape& image_shape,
607                                 const Shape& filter_shape,
608                                 const Strides& filter_strides,
609                                 const Strides& filter_dilations,
610                                 const op::PadType pad_type,
611                                 CoordinateDiff& padding_above,
612                                 CoordinateDiff& padding_below)
613 {
614     const auto image_dims = std::vector<Dimension>(std::begin(image_shape), std::end(image_shape));
615     // Because image_shape is fully known, the result of try_apply_auto_padding is ignored
616     try_apply_auto_padding(image_dims,
617                            filter_shape,
618                            filter_strides,
619                            filter_dilations,
620                            pad_type,
621                            padding_above,
622                            padding_below);
623 }
624
625 bool ngraph::try_apply_auto_padding(const PartialShape& image_shape,
626                                     const Shape& filter_shape,
627                                     const Strides& filter_strides,
628                                     const Strides& filter_dilations,
629                                     const op::PadType pad_type,
630                                     CoordinateDiff& padding_above,
631                                     CoordinateDiff& padding_below)
632 {
633     NGRAPH_CHECK(pad_type == op::PadType::SAME_UPPER || pad_type == op::PadType::SAME_LOWER);
634
635     if (image_shape.rank().is_dynamic())
636     {
637         return false;
638     }
639     const auto image_dims = static_cast<std::vector<Dimension>>(image_shape);
640     const bool are_spatial_dims_static =
641         std::all_of(std::begin(image_dims) + 2, std::end(image_dims), [](const Dimension& dim) {
642             return dim.is_static();
643         });
644     if (!are_spatial_dims_static)
645     {
646         return false;
647     }
648
649     for (size_t i = 0; i < static_cast<size_t>(filter_shape.size()); i++)
650     {
651         int64_t image_size = static_cast<int64_t>(image_dims[i + 2].get_length());
652         int64_t filter_size = (static_cast<int64_t>(filter_shape[i]) - 1) * filter_dilations[i] + 1;
653         int64_t filter_stride = static_cast<int64_t>(filter_strides[i]);
654         auto output_size = (image_size + filter_stride - 1) / filter_stride;
655
656         auto padding_needed =
657             std::max(int64_t(0), (output_size - 1) * filter_stride + filter_size - image_size);
658         auto padding_lhs = padding_needed / 2;
659         auto padding_rhs = padding_needed - padding_lhs;
660         padding_below.push_back(pad_type == op::PadType::SAME_UPPER ? padding_lhs : padding_rhs);
661         padding_above.push_back(pad_type == op::PadType::SAME_UPPER ? padding_rhs : padding_lhs);
662     }
663     return true;
664 }
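
// Worked example for the SAME_* arithmetic above (illustrative only): with an image
// dimension of 5, a filter of 3, stride 2 and dilation 1, output_size = (5 + 2 - 1) / 2 = 3,
// padding_needed = max(0, (3 - 1) * 2 + 3 - 5) = 2, so padding_lhs = 1 and padding_rhs = 1;
// SAME_UPPER places padding_lhs below and padding_rhs above, while SAME_LOWER swaps them.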
665
666 PartialShape ngraph::infer_slice_shape(const Node* node,
667                                        const PartialShape& input_shape,
668                                        const std::vector<int64_t>& begin,
669                                        const std::vector<int64_t>& end,
670                                        const std::vector<int64_t>& strides,
671                                        const AxisSet& begin_mask,
672                                        const AxisSet& end_mask,
673                                        const AxisSet& new_axis_mask,
674                                        const AxisSet& shrink_axis_mask,
675                                        const AxisSet& ellipsis_mask)
676 {
677     if (begin.size() && end.size())
678     {
679         NODE_VALIDATION_CHECK(node,
680                               begin.size() == end.size(),
681                               "Lower bounds and Upper bounds need to have the same number of values");
682     }
683     if (begin.size() && strides.size())
684     {
685         NODE_VALIDATION_CHECK(node,
686                               begin.size() == strides.size(),
687                               "Lower bounds and strides need to have the same number of values");
688     }
689     if (end.size() && strides.size())
690     {
691         NODE_VALIDATION_CHECK(node,
692                               end.size() == strides.size(),
693                               "Upper bounds and strides need to have the same number of values");
694     }
695
696     NODE_VALIDATION_CHECK(node, ellipsis_mask.size() <= 1, "At most one ellipsis is allowed.");
697
698     if (input_shape.rank().is_dynamic())
699     {
700         return PartialShape::dynamic();
701     }
702
703     NODE_VALIDATION_CHECK(node,
704                           input_shape.rank().get_length() + new_axis_mask.size() >= begin.size(),
705                           "Input rank plus number of new axes has to be at least the size of Lower "
706                           "and Upper bounds vector.");
707
708     std::vector<Dimension> dim;
709
710     size_t input_shape_idx = 0;
711     for (size_t axis = 0; axis < begin.size(); ++axis)
712     {
713         // add all dimensions hidden under the ellipsis mask if ellipsis mask is set
714         if (ellipsis_mask.count(axis))
715         {
716             // only one bit in ellipsis mask is allowed
717             int num_new_axis_after_ellipses = 0;
718             int num_input_axis_before_ellipses = 0;
719             for (size_t i = 0; i < axis; ++i)
720             {
721                 if (!new_axis_mask.count(i))
722                 {
723                     num_input_axis_before_ellipses++;
724                 }
725             }
726             for (size_t i = axis + 1; i < begin.size(); ++i)
727             {
728                 if (new_axis_mask.count(i))
729                 {
730                     num_new_axis_after_ellipses++;
731                 }
732             }
733
734             int64_t num_input_axis_after_ellipses =
735                 (begin.size() - axis - num_new_axis_after_ellipses -
736                  1); // -1 because it's the position of the ellipsis
737             int64_t num_of_hidden_dims = input_shape.rank().get_length() -
738                                          num_input_axis_after_ellipses -
739                                          num_input_axis_before_ellipses;
740             for (int64_t i = 0; i < num_of_hidden_dims; ++i)
741             {
742                 dim.emplace_back(input_shape[input_shape_idx]);
743                 input_shape_idx++;
744             }
745         }
746         else
747         {
748             // add new single dimension if new_axis_mask is set
749             if (new_axis_mask.count(axis))
750             {
751                 dim.emplace_back(1);
752             }
753             // skip this dimension if shrink_axis_mask is set
754             else if (shrink_axis_mask.count(axis))
755             {
756                 input_shape_idx++;
757             }
758             // calculating dimension (begin, end, begin_mask, end_mask, stride)
759             else
760             {
761                 // check dynamic dimension
762                 if (input_shape[input_shape_idx].is_dynamic())
763                 {
764                     input_shape_idx++;
765                     dim.emplace_back(Dimension::dynamic());
766                     continue;
767                 }
768
769                 int64_t lb = begin[axis];
770                 int64_t ub = end[axis];
771
772                 // set default value for stride or use given value
773                 int64_t stride = 1;
774                 if (strides.size() > axis)
775                 {
776                     stride = strides[axis];
777                 }
778                 NODE_VALIDATION_CHECK(node, stride != 0, "Stride must be non-zero");
779
780                 // convert negative indices to positive ones
781                 // take max for this case: if abs(lb) > input_shape[input_shape_idx], then after
782                 // conversion lb < 0,
783                 // so according to TensorFlow and NumPy we just get 0
784                 if (lb < 0)
785                 {
786                     lb = std::max(input_shape[input_shape_idx].get_length() + lb, int64_t(0));
787                 }
788
789                 if (ub < 0)
790                 {
791                     ub = std::max(input_shape[input_shape_idx].get_length() + ub,
792                                   stride > 0 ? int64_t(0) : int64_t(-1));
793                 }
794
795                 // apply restrictions when begin or end values are greater than the maximum possible values.
796                 lb = std::min(input_shape[input_shape_idx].get_length(), lb);
797                 ub = std::min(input_shape[input_shape_idx].get_length(), ub);
798
799                 int64_t dimension = 0;
800                 if (stride < 0)
801                 {
802                     // apply masks
803                     if (begin_mask.count(axis))
804                     {
805                         lb = input_shape[input_shape_idx].get_length() - 1;
806                     }
807                     if (end_mask.count(axis))
808                     {
809                         ub = -1;
810                     }
811
812                     lb = std::min(lb, input_shape[input_shape_idx].get_length() - 1);
813                 lb -= 1; // we always get 1st element, so we need to decrease the range
814                     if (ub <= lb)
815                     {
816                         dimension = (ub - lb) / stride + 1;
817                     }
818                 }
819                 else
820                 {
821                     // apply masks
822                     if (begin_mask.count(axis))
823                     {
824                         lb = 0;
825                     }
826                     if (end_mask.count(axis))
827                     {
828                         ub = input_shape[input_shape_idx].get_length();
829                     }
830
831                 lb += 1; // we always get 1st element, so we need to decrease the range
832                     if (ub >= lb)
833                     {
834                         dimension = (ub - lb) / stride + 1;
835                     }
836                 }
837
838                 dim.emplace_back(dimension);
839                 input_shape_idx++;
840             }
841         }
842     }
843     // copy the remaining input dimensions unchanged
844     for (; input_shape_idx < input_shape.rank().get_length(); ++input_shape_idx)
845     {
846         dim.emplace_back(input_shape[input_shape_idx]);
847     }
848
849     return dim;
850 }
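
// Hedged usage sketch (the values are hypothetical): slicing a static shape {10} with
// begin {2}, end {8}, strides {2} and all masks empty follows the positive-stride branch
// above: lb = 2, ub = 8, then lb += 1 and the axis length becomes (8 - 3) / 2 + 1 = 3,
// matching the three elements selected at indices 2, 4 and 6.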
851
852 std::vector<size_t> ngraph::normalize_axes(const std::string& node_description,
853                                            const std::vector<int64_t>& axes,
854                                            const Rank& tensor_rank)
855 {
856     std::vector<size_t> new_axes;
857
858     for (const auto& axis : axes)
859     {
860         new_axes.push_back(normalize_axis(node_description, axis, tensor_rank));
861     }
862
863     return new_axes;
864 }
865
866 int64_t ngraph::normalize_axis(const Node* node, std::int64_t axis, const Rank& tensor_rank)
867 {
868     return normalize_axis(node->description(), axis, tensor_rank);
869 }
870
871 int64_t ngraph::normalize_axis(const std::string& node_description,
872                                std::int64_t axis,
873                                const Rank& tensor_rank)
874 {
875     if (axis < 0)
876     {
877         // Handling negative axis requires static tensor rank
878         NGRAPH_CHECK(tensor_rank.is_static(),
879                      node_description,
880                      " Rank must be static in order to normalize negative axis=",
881                      axis);
882     }
883     if (tensor_rank.is_dynamic())
884     {
885         return axis;
886     }
887
888     const auto tensor_rank_value = tensor_rank.get_length();
889     return normalize_axis(node_description,
890                           axis,
891                           tensor_rank_value,
892                           -tensor_rank_value,
893                           tensor_rank_value ? (tensor_rank_value - 1) : 0);
894 }
895
896 int64_t ngraph::normalize_axis(const Node* node,
897                                std::int64_t axis,
898                                std::uint64_t tensor_rank,
899                                std::int64_t axis_range_min,
900                                std::int64_t axis_range_max)
901 {
902     return ngraph::normalize_axis(
903         node->description(), axis, tensor_rank, axis_range_min, axis_range_max);
904 }
905
906 int64_t ngraph::normalize_axis(const std::string& node_description,
907                                std::int64_t axis,
908                                std::uint64_t tensor_rank,
909                                std::int64_t axis_range_min,
910                                std::int64_t axis_range_max)
911 {
912     // The accepted range of values for axis is [axis_range_min, axis_range_max].
913     NGRAPH_CHECK(((axis >= axis_range_min) && (axis <= axis_range_max)),
914                  node_description,
915                  " Parameter axis ",
916                  axis,
917                  " out of the tensor rank range [",
918                  axis_range_min,
919                  ", ",
920                  axis_range_max,
921                  "].");
922
923     if (axis < 0)
924     {
925         axis = axis + tensor_rank;
926     }
927
928     return int64_t(axis);
929 }
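
// Illustrative note: with a tensor rank of 4, the range-checked overload above accepts
// axis values in [-4, 3], so normalize_axis(desc, -1, 4, -4, 3) returns -1 + 4 = 3 and
// normalize_axis(desc, 2, 4, -4, 3) returns 2 unchanged ("desc" being a hypothetical
// node description string).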
930
931 void ngraph::opset1::infer_conv_backprop_auto_padding(const Shape& input_data_shape,
932                                                       const Shape& filters_shape,
933                                                       const Shape& output_shape,
934                                                       const Strides& strides,
935                                                       const Strides& dilations,
936                                                       const op::PadType auto_pad_type,
937                                                       const CoordinateDiff& output_padding,
938                                                       CoordinateDiff& pads_begin,
939                                                       CoordinateDiff& pads_end)
940 {
941     NGRAPH_CHECK(auto_pad_type == op::PadType::SAME_UPPER ||
942                  auto_pad_type == op::PadType::SAME_LOWER);
943
944     size_t num_spatial_dims = input_data_shape.size();
945     NGRAPH_CHECK(filters_shape.size() == num_spatial_dims && strides.size() == num_spatial_dims &&
946                  dilations.size() == num_spatial_dims && pads_begin.size() == num_spatial_dims &&
947                  pads_end.size() == num_spatial_dims && output_padding.size() == num_spatial_dims);
948
949     pads_begin = CoordinateDiff(num_spatial_dims);
950     pads_end = CoordinateDiff(num_spatial_dims);
951
952     for (uint64_t i = 0; i < num_spatial_dims; ++i)
953     {
954         int total_padding = strides[i] * (input_data_shape[i] - 1) +
955                             dilations[i] * (filters_shape[i] - 1) + 1 - output_shape[i] +
956                             output_padding[i];
957         if (auto_pad_type != op::PadType::SAME_UPPER)
958         {
959             pads_begin[i] = total_padding / 2;
960             pads_end[i] = total_padding - pads_begin[i];
961         }
962         else
963         {
964             pads_end[i] = total_padding / 2;
965             pads_begin[i] = total_padding - pads_end[i];
966         }
967     }
968 }
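
// Worked example for the formula above (illustrative only): with input_data_shape {3},
// filters_shape {3}, output_shape {6}, strides {2}, dilations {1} and zero output_padding,
// total_padding = 2 * (3 - 1) + 1 * (3 - 1) + 1 - 6 + 0 = 1; for SAME_UPPER this is split
// as pads_end = 0 and pads_begin = 1, and the split is mirrored for SAME_LOWER.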
969
970 namespace
971 {
972     /// \brief Scalar variant that describes the value of an Output, for use in max shape determination
973     ///
974     /// For tensor values, we use the maximum value in the tensor
975     struct MaxValue
976     {
977         /// \brief No information known about the output
978         MaxValue() {}
979         /// \brief uint64_t associated with the output
980         MaxValue(uint64_t value)
981             : m_value(value)
982         {
983         }
984         MaxValue(const vector<uint64_t>& slices, int64_t slice_axis)
985             : m_slices(slices)
986             , m_slice_axis(slice_axis)
987         {
988             m_value = *max_element(m_slices.begin(), m_slices.end());
989         }
990         uint64_t m_value{numeric_limits<uint64_t>::max()};
991         vector<uint64_t> m_slices;
992         int64_t m_slice_axis{-1};
993     };
994
995     vector<MaxValue> exec_constant(Node* node, vector<MaxValue>& inputs)
996     {
997         auto result = MaxValue();
998         auto op = as_type<op::Constant>(node);
999         auto element_type = op->get_output_element_type(0);
1000         if (element_type.is_integral())
1001         {
1002             uint64_t max_val = 0;
1003             if (element_type.is_signed())
1004             {
1005                 for (auto elt : op->cast_vector<int64_t>())
1006                 {
1007                     if (elt > 0 && max_val < static_cast<uint64_t>(elt))
1008                     {
1009                         max_val = elt;
1010                     }
1011                 }
1012             }
1013             else
1014             {
1015                 for (auto elt : op->cast_vector<uint64_t>())
1016                 {
1017                     if (max_val < elt)
1018                     {
1019                         max_val = elt;
1020                     }
1021                 }
1022             }
1023             result = MaxValue(max_val);
1024         }
1025         return {result};
1026     }
1027
1028     vector<MaxValue> exec_minimum(Node* node, vector<MaxValue>& inputs)
1029     {
1030         uint64_t min_value = numeric_limits<uint64_t>::max();
1031         switch (node->get_output_element_type(0))
1032         {
1033         case element::Type_t::i8: min_value = numeric_limits<int8_t>::max(); break;
1034         case element::Type_t::i16: min_value = numeric_limits<int16_t>::max(); break;
1035         case element::Type_t::i32: min_value = numeric_limits<int32_t>::max(); break;
1036         case element::Type_t::i64: min_value = numeric_limits<int64_t>::max(); break;
1037         case element::Type_t::u8: min_value = numeric_limits<uint8_t>::max(); break;
1038         case element::Type_t::u16: min_value = numeric_limits<uint16_t>::max(); break;
1039         case element::Type_t::u32: min_value = numeric_limits<uint32_t>::max(); break;
1040         case element::Type_t::u64: min_value = numeric_limits<uint64_t>::max(); break;
1041         default: break;
1042         }
1043         min_value = min(min_value, inputs.at(0).m_value);
1044         min_value = min(min_value, inputs.at(1).m_value);
1045         return {MaxValue(min_value)};
1046     }
1047
1048     vector<MaxValue> exec_concat(Node* node, vector<MaxValue>& inputs)
1049     {
1050         auto op = as_type<op::v0::Concat>(node);
1051         vector<uint64_t> slice_maxen;
1052         for (auto input : inputs)
1053         {
1054             slice_maxen.push_back(input.m_value);
1055         }
1056         auto axis = op->get_concatenation_axis();
1057         return {MaxValue(slice_maxen, axis)};
1058     }
1059
1060     vector<MaxValue> exec_reduce_min(Node* node, vector<MaxValue>& inputs)
1061     {
1062         auto data = inputs.at(0);
1063         if (data.m_slice_axis >= 0 && data.m_slices.size() > 1)
1064         {
1065             if (auto indices_const = as_type<op::v0::Constant>(node->get_input_node_ptr(1)))
1066             {
1067                 if (indices_const->get_output_element_type(0).is_integral())
1068                 {
1069                     auto indices_shape = indices_const->get_output_shape(0);
1070                     if (indices_shape == Shape{1})
1071                     {
1072                         auto indices = indices_const->cast_vector<int64_t>();
1073                         auto axis = indices.at(0);
1074                         if (axis == data.m_slice_axis)
1075                         {
1076                             return {
1077                                 MaxValue(*min_element(data.m_slices.begin(), data.m_slices.end()))};
1078                         }
1079                     }
1080                 }
1081             }
1082         }
1083         // Nothing we can do
1084         return {MaxValue(data.m_value)};
1085     }
1086     vector<MaxValue> exec_nop(Node* node, vector<MaxValue>& inputs) { return {inputs.at(0)}; }
1087 }
1088
1089 pair<bool, uint64_t> ngraph::maximum_value(const Output<Node>& value)
1090 {
1091     static Evaluator<MaxValue>::op_handler_map handlers = {
1092         {op::v0::Concat::type_info, exec_concat},
1093         {op::v0::Constant::type_info, exec_constant},
1094         {op::v0::Convert::type_info, exec_nop},
1095         {op::v0::Minimum::type_info, exec_minimum},
1096         {op::v1::Minimum::type_info, exec_minimum},
1097         {op::v1::ReduceMin::type_info, exec_reduce_min},
1098         {op::v0::Squeeze::type_info, exec_nop},
1099         {op::v0::Unsqueeze::type_info, exec_nop}};
1100     Evaluator<MaxValue>::value_map value_map;
1101     Evaluator<MaxValue> evaluator(handlers, value_map);
1102     auto val = evaluator.evaluate(value);
1103     return pair<bool, uint64_t>(val.m_value < numeric_limits<uint64_t>::max(), val.m_value);
1104 }
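
// Hedged usage sketch (the Constant values are hypothetical):
//     auto c = op::Constant::create(element::i64, Shape{3}, {2, 7, 4});
//     auto result = maximum_value(c->output(0));
// exec_constant above yields 7, so result is {true, 7}; when no handler can bound the
// value, the default of numeric_limits<uint64_t>::max() makes the boolean false.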
1105
1106 void ngraph::evaluate_nodes(std::map<RawNodeOutput, HostTensorPtr>& value_map,
1107                             std::map<RawNodeOutput, HostTensorPtr>& output_tensor_map,
1108                             const OutputVector& outputs)
1109 {
1110     Evaluator<HostTensorPtr> evaluator({}, value_map);
1111     evaluator.set_univeral_handler(
1112         [&output_tensor_map](Node* node,
1113                              const HostTensorVector& input_tensors) -> HostTensorVector {
1114             HostTensorVector output_tensors;
1115             for (auto v : node->outputs())
1116             {
1117                 auto it = output_tensor_map.find(v);
1118                 if (it == output_tensor_map.end())
1119                 {
1120                     auto c = make_shared<HostTensor>(v);
1121                     output_tensors.push_back(c);
1122                 }
1123                 else
1124                 {
1125                     output_tensors.push_back(it->second);
1126                 }
1127             }
1128             if (node->evaluate(output_tensors, input_tensors))
1129             {
1130                 return output_tensors;
1131             }
1132             else
1133             {
1134                 NGRAPH_CHECK(false, "Evaluation failed on ", node);
1135             }
1136         });
1137     for (auto value : outputs)
1138     {
1139         evaluator.evaluate(value);
1140     }
1141 }
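
// Hedged usage sketch (the names param, input_tensor and some_output are hypothetical):
// callers typically seed value_map with HostTensors for the graph inputs, optionally
// pre-place result tensors in output_tensor_map, and then pass the Output<Node> values
// they want computed:
//     std::map<RawNodeOutput, HostTensorPtr> values{{param->output(0), input_tensor}};
//     std::map<RawNodeOutput, HostTensorPtr> results;
//     evaluate_nodes(values, results, {some_output});
// Any output without a pre-placed tensor gets a freshly allocated HostTensor, as in the
// universal handler above.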