Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / middle / TensorIteratorConditionChecker.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16 import logging as log
17
18 import numpy as np
19
20 from mo.middle.replacement import MiddleReplacementPattern
21
22
23 class ConditionChecks(MiddleReplacementPattern):
24     enabled = True
25     graph_condition = [lambda graph: graph.graph['is_cyclic']]
26
27     def run_after(self):
28         from extensions.middle.TensorIteratorBackEdge import BackEdgesMatching
29         return [BackEdgesMatching]
30
31     def run_before(self):
32         from extensions.middle.TensorIteratorMerge import TensorIteratorMerge
33         return [TensorIteratorMerge]
34
35     @staticmethod
36     def pattern():
37         log.debug('+++++++++++++++ ConditionCheckerMatching ++++++++++++++++')
38         return dict(
39             nodes=[
40                 ('condition', dict(kind='op', op='TensorIteratorCondition')),
41                 ('Strided_slice', dict(kind='op', op='StridedSlice')),
42                 ('Strided_slice_data', dict(kind='data')),
43                 ('shape', dict(kind='op', op='Shape')),
44                 ('shape_data', dict(kind='data')),
45
46                 ('minimum', dict(kind='op', op='Minimum')),
47                 ('minimum_data', dict(kind='data')),
48                 ('Maximum', dict(kind='op', op='Maximum')),
49                 ('Maximum_data', dict(kind='data')),
50             ],
51             edges=[
52                 ('shape', 'shape_data'),
53                 ('shape_data', 'Strided_slice'),
54                 ('Strided_slice', 'Strided_slice_data'),
55                 ('Strided_slice_data', 'condition'),
56                 ('Strided_slice_data', 'minimum'),
57
58                 ('Maximum', 'Maximum_data'),
59                 ('Maximum_data', 'minimum'),
60                 ('minimum', 'minimum_data'),
61                 ('minimum_data', 'condition'),
62             ],
63         )
64
65     @staticmethod
66     def replace_pattern(graph, match: dict):
67         # Check for SS params
68         # Sanity check that we iterate over axis of some tensor
69         ss = match['Strided_slice']
70         params = ss.in_nodes()
71         assert np.all(params[1].in_node().value == 0)
72         assert np.all(params[2].in_node().value == 1)
73         assert np.all(params[3].in_node().value == 1)
74
75         # Check Maximum/Minimum params
76
77         # Check for comparing SS and seq_length source (it should be one tensor)
78         # SIMPLE CHECK
79         assert match['Strided_slice_data'].value is not None
80         if match['minimum_data'].value is None:
81             log.warning('TF loop doesn\'t have a constant upper bound produced by node {}, or ModelOptimizer '
82                         'cannot detect a constant in this case. Loops with a dynamic number of iterations are not '
83                         'supported, so in the resulting IR, generated TensorIterator will have '
84                         'a maximum number of iterations determined by input tensor size: {}'
85                         ''.format(match['minimum_data'].soft_get('name'), match['Strided_slice_data'].value)
86                         )
87         else:
88             assert match['Strided_slice_data'].value == match['minimum_data'].value, \
89                 'Values do not match: {} and {}'.format(match['Strided_slice_data'].value, match['minimum_data'].value)
90
91         # SMART CHECK
92         # TODO: add here some smart check for tensors equality
93
94         # Check that bound for Condition and Inputs/Outputs sizes match
95         condition_time = match['condition'].out_node(0)
96         inputs_and_outputs = condition_time.out_nodes()
97         type_list = ['TensorIteratorInput', 'TensorIteratorOutput']
98
99         for ta in inputs_and_outputs:
100             if ta.has_valid('kind') and ta['kind'] == 'op' and ta['op'] in type_list:
101                 assert ta.in_node(0).id == ss.id
102
103         log.debug('+++++++++++++++ Condition Check was successful ++++++++++++++++')