"""
 Copyright (c) 2018-2019 Intel Corporation

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""
import logging as log

import numpy as np

from mo.middle.replacement import MiddleReplacementPattern

class ConditionChecks(MiddleReplacementPattern):
    """Sanity-check the condition of a TF loop before it is merged into a TensorIterator.

    Matches a ``TensorIteratorCondition`` fed by a length taken from ``Shape(...)``
    through a ``StridedSlice`` and by an upper bound ``Minimum(length, Maximum(...))``,
    then asserts that these pieces are consistent. The graph itself is not modified.
    """

    # Only applied when the graph is flagged as cyclic (i.e. it contains loops).
    graph_condition = [lambda graph: graph.graph['is_cyclic']]

    def run_after(self):
        """Return the passes that must be executed before this check."""
        # Imported locally, presumably to avoid an import cycle between passes.
        from extensions.middle.TensorIteratorBackEdge import BackEdgesMatching
        return [BackEdgesMatching]

    def run_before(self):
        """Return the passes that must be executed after this check."""
        from extensions.middle.TensorIteratorMerge import TensorIteratorMerge
        return [TensorIteratorMerge]

    @staticmethod
    def pattern():
        """Return the sub-graph to match as ``dict(nodes=[...], edges=[...])``.

        ``condition`` consumes two values:
          * ``Strided_slice_data`` -- a slice of ``Shape(...)`` (presumably the sequence length);
          * ``minimum_data`` -- ``Minimum(Strided_slice_data, Maximum_data)`` (the loop bound).
        """
        log.debug('+++++++++++++++ ConditionCheckerMatching ++++++++++++++++')
        return dict(
            nodes=[
                ('condition', dict(kind='op', op='TensorIteratorCondition')),
                ('Strided_slice', dict(kind='op', op='StridedSlice')),
                ('Strided_slice_data', dict(kind='data')),
                ('shape', dict(kind='op', op='Shape')),
                ('shape_data', dict(kind='data')),

                ('minimum', dict(kind='op', op='Minimum')),
                ('minimum_data', dict(kind='data')),
                ('Maximum', dict(kind='op', op='Maximum')),
                ('Maximum_data', dict(kind='data')),
            ],
            edges=[
                ('shape', 'shape_data'),
                ('shape_data', 'Strided_slice'),
                ('Strided_slice', 'Strided_slice_data'),
                ('Strided_slice_data', 'condition'),
                ('Strided_slice_data', 'minimum'),

                ('Maximum', 'Maximum_data'),
                ('Maximum_data', 'minimum'),
                ('minimum', 'minimum_data'),
                ('minimum_data', 'condition'),
            ],
        )

    @staticmethod
    def replace_pattern(graph, match: dict):
        """Check that a matched loop condition is consistent; the graph is not modified.

        :param graph: graph containing the match (not used by the checks).
        :param match: mapping from the node names declared in ``pattern()`` to graph nodes.
        :raises AssertionError: if the StridedSlice inputs on ports 1-3 are not 0/1/1, if the
            sliced length has no inferred value, if a constant upper bound differs from it,
            or if a TensorIteratorInput/Output fed by the condition is not connected to the
            matched StridedSlice.
        """
        # Sanity check that we iterate over axis of some tensor: inputs 1, 2 and 3 of the
        # StridedSlice (begin/end/strides in TF order -- assumed) must hold 0, 1 and 1,
        # i.e. the slice takes the first element of the shape.
        ss = match['Strided_slice']
        params = ss.in_nodes()
        for port, expected in ((1, 0), (2, 1), (3, 1)):
            value = params[port].in_node().value
            assert np.all(value == expected), \
                'StridedSlice {} has unexpected value on input {}: expected {}, got {}'.format(
                    ss.soft_get('name'), port, expected, value)

        # Check for comparing SS and seq_length source (it should be one tensor).
        seq_length = match['Strided_slice_data'].value
        assert seq_length is not None, \
            'Cannot detect a constant value of {}'.format(match['Strided_slice_data'].soft_get('name'))
        upper_bound = match['minimum_data'].value
        if upper_bound is None:
            # A dynamic bound is tolerated: only warn, and let the iteration count be
            # determined by the input tensor size.
            log.warning('TF loop doesn\'t have a constant upper bound produced by node {}, or ModelOptimizer '
                        'cannot detect a constant in this case. Loops with a dynamic number of iterations are not '
                        'supported, so in the resulting IR, generated TensorIterator will have '
                        'a maximum number of iterations determined by input tensor size: {}'
                        ''.format(match['minimum_data'].soft_get('name'), seq_length)
                        )
        else:
            # np.all keeps a mismatch an AssertionError (rather than numpy's
            # "truth value is ambiguous" ValueError) if the values are multi-element arrays.
            assert np.all(seq_length == upper_bound), \
                'Values do not match: {} and {}'.format(seq_length, upper_bound)

        # TODO: add here some smart check for tensors equality

        # Check that bound for Condition and Inputs/Outputs sizes match: every
        # TensorIteratorInput/Output op consuming output 0 of the condition must be
        # connected (input 0) to the matched StridedSlice.
        condition_time = match['condition'].out_node(0)
        type_list = ['TensorIteratorInput', 'TensorIteratorOutput']
        for ta in condition_time.out_nodes():
            if ta.has_valid('kind') and ta['kind'] == 'op' and ta['op'] in type_list:
                # NOTE(review): `ss` is the StridedSlice *op* node, while pattern() uses an
                # op -> data -> op layout, so an op's in_node(0) is usually a data node; this
                # may be meant to compare with match['Strided_slice_data'].id -- confirm.
                assert ta.in_node(0).id == ss.id

        log.debug('+++++++++++++++ Condition Check was successful ++++++++++++++++')