Implement reshapeable CTCGreedyDecoderPlusSparseToDense transformation and test ...
authorRoman Kazantsev <roman.kazantsev@intel.com>
Fri, 28 Aug 2020 11:28:32 +0000 (14:28 +0300)
committerGitHub <noreply@github.com>
Fri, 28 Aug 2020 11:28:32 +0000 (14:28 +0300)
* Implement reshapeable CTCGreedyDecoderPlusSparseToDense transformation and test

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>
* Fix consts (after code-review #1)

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>
* Add CTCGreedyDecoderTransformation with more generic pattern

Also it adds new middle-replacer for transforming sequence length to a mask
along with tests.

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>
* Do fixes after review #2

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>
* Fix after review #3

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>
* Fix after review #4

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>
model-optimizer/automation/package_BOM.txt
model-optimizer/extensions/front/tf/CTCGreedyDecoderReplacement.py
model-optimizer/extensions/front/tf/CTCGreedyDecoderReplacement_test.py [new file with mode: 0644]
model-optimizer/extensions/middle/SequenceLengthToMask.py [new file with mode: 0644]
model-optimizer/extensions/middle/SequenceLenthToMask_test.py [new file with mode: 0644]
model-optimizer/extensions/ops/ctc_greedy_decoder.py

index 4167d68..02e61f4 100644 (file)
@@ -561,6 +561,7 @@ extensions/middle/ReverseTransposeNormalization.py
 extensions/middle/ReverseV2ToReverseSequence.py
 extensions/middle/RNNSequenceNormalizeToIE.py
 extensions/middle/ScaleInput.py
+extensions/middle/SequenceLengthToMask.py
 extensions/middle/SharedWeightsDuplication.py
 extensions/middle/SliceConverter.py
 extensions/middle/SliceLikeToStridedSlice.py
index 2086c1c..c3df89e 100644 (file)
  limitations under the License.
 """
 
+import logging as log
+
 import numpy as np
 
+from extensions.ops.Cast import Cast
+from extensions.front.Pack import Pack
+from extensions.front.FillToBroadcast import FillToBroadcast
+from mo.front.common.partial_infer.utils import int64_array
 from mo.front.common.replacement import FrontReplacementSubgraph
-from mo.graph.graph import Graph
-from mo.ops.const import Const
-from mo.utils.error import Error
+from mo.front.tf.graph_utils import create_op_with_const_inputs
+from mo.graph.graph import Graph, rename_nodes
+from mo.ops.broadcast import Broadcast
+from mo.ops.concat import Concat
+from mo.ops.squeeze import Squeeze
+from mo.ops.unsqueeze import Unsqueeze
 
 
 class CTCGreedyDecoderReplacement(FrontReplacementSubgraph):
     """
+    TensorFlow CTCGreedyDecoder produces output in a sparse tensor that is not supported by Inference Engine and
+    Inference Engine's CTCGreedyDecoder has different output that is in a dense format. So this transformation
+    intents to replace TF CTCGreedyDecoder+SparseToDense with IE one.
+    Also Inference Engine's CTCGreedyDecoder has a specific format for the second input tensor, a sequence length,
+    different from TF's one so this transformation cares about transformation of its format.
+    The second input to the CTCGreedyDecoder in the TensorFlow is a 1D tensor with sequence lengths. In the Inference
+    Engine the second input to the CTCGreedyDecoder is a 2D tensor, a sequence mask, where the first element
+    in each row is equal to 1 and all others in the tail are equal to 0. The number of ones represents
+    a sequence length.
+    """
+    enabled = True
+
+    def run_after(self):
+        # CTCGreedyDecoderReplacement is not reshape-able transformation
+        # so reshape-able CTCGreedyDecoderReplacement2 transformation is applied first
+        return [CTCGreedyDecoderReplacement2]
+
+    @staticmethod
+    def pattern(**kwargs):
+        return dict(
+            nodes=[('decoder', dict(op='CTCGreedyDecoder')),
+                   ('cast', dict(op='Cast')),
+                   ('sparse_to_dense', dict(op='SparseToDense'))
+                   ],
+            edges=[('decoder', 'sparse_to_dense', {'out': 0}),
+                   ('decoder', 'cast', {'out': 1}),
+                   ('cast', 'sparse_to_dense', {'out': 0})
+                   ]
+        )
+
+    def replace_sub_graph(self, graph: Graph, match: dict):
+        # TODO: Once Inference Engine's CTCGreedyDecoder starts to support sequence length format like in TensorFlow,
+        # CTCGreedyDecoderReplacement2 needs to be removed and CTCGreedyDecoderReplacement, a more generic
+        # transformation, needs to be adopted for all cases
+        ctc_greedy_decoder = match['decoder']
+        cast = match['cast']
+        sparse_to_dense = match['sparse_to_dense']
+        sparse_to_dense_name = sparse_to_dense.soft_get('name', sparse_to_dense.id)
+
+        # disconnect SparseToDense and Cast nodes
+        sparse_to_dense.in_port(0).disconnect()
+        cast.in_port(0).disconnect()
+
+        # transform CTCGreedyDecoder output to TensorFlow's one:
+        # 1. squeeze the output to [N, T] shape
+        # 2. cast it to integer
+        squeeze_dec_seq = create_op_with_const_inputs(graph, Squeeze, {1: int64_array([2, 3])},
+                                                      {'name': sparse_to_dense_name})
+        squeeze_dec_seq.in_port(0).connect(ctc_greedy_decoder.out_port(0))
+        cast_to_int = Cast(graph, {'name': sparse_to_dense_name + '/CastToInt',
+                                   'dst_type': np.int32}).create_node()
+        cast_to_int.in_port(0).connect(squeeze_dec_seq.out_port(0))
+
+        # preserve output name from original graph
+        rename_nodes([(sparse_to_dense, sparse_to_dense_name + '/AbandonedName'),
+                      (cast_to_int, sparse_to_dense_name)])
+
+        # set output of the new sub-graph as a source for SparseToDense consumer
+        sparse_to_dense.out_port(0).get_connection().set_source(cast_to_int.out_port(0))
+
+        # remove no longer needed nodes
+        graph.remove_nodes_from([sparse_to_dense.id, cast.id])
+
+        # mark CTCGreedyDecoder node as a node that requires transformation of sequence length to a mask format
+        # in the middle phase
+        ctc_greedy_decoder['use_mask_format'] = True
+
+        # unless the second input of CTCGreedyDecoder is a parameter, it enforces MO to use --static-shape
+        # to try getting the second input with a value
+        sequence_length_node = ctc_greedy_decoder.in_node(1)
+        if sequence_length_node.soft_get('op') != 'Parameter' and not graph.graph['cmd_params'].static_shape:
+            log.error(
+                "Model can not be translated in a reshape-able way.\n"
+                "Model Optimizer key static_shape was turned on to prevent related errors.\n"
+                "There will be no success changing input shapes of the model with the help of "
+                "InferenceEngine reshape method", extra={'is_warning': True})
+            graph.graph['cmd_params'].static_shape = True
+
+
+class CTCGreedyDecoderReplacement2(FrontReplacementSubgraph):
+    """
     The TF implementation of the CTCGreedyDecoder produces a tuple with two tensors. The first element in the tuple is
     the SparseTensor which is converted to a regular tensor with the SparseToDense operation. This replacer matches
     CTCGreedyDecoder and SparseToDense operations and removes the SparseToDense and Cast operation which is also used
     in the SparseToDense operation, because Inference Engine implementation of the CTCGreedyDecoder produces regular
     tensor as output.
-
-    The second input to the CTCGreedyDecoder in the TensorFlow is a 1D tensor with sequence lengths. In the Inference
-    Engine the second input to the CTCGreedyDecoder is a 2D tensor where the first element in each row is equal to 0
-    and all others are equal to 1. The length of the row is equal to the sequence length. The replacer modifies the
-    second input to be compatible with the Inference Engine CTCGreedyDecoder layer implementation.
+    Also, Inference Engine CTCGreedyDecoder requires a mask format for sequence lengths that is a different from
+    original one. Hence, this transformation changes a format of sequence length to a mask by replacing Fill and Pack
+    nodes with a special graph that produces a tensor of ones with shape [T, N] accepted by opset CTCGreedyDecoder.
     """
     enabled = True
 
+    def run_before(self):
+        return [Pack, FillToBroadcast]
+
     @staticmethod
     def pattern(**kwargs):
         return dict(
             nodes=[
+                ('transpose', dict(op='Transpose')),
+                ('shape', dict(op='ShapeOf')),
+                ('shape_1', dict(op='ShapeOf')),
+                ('strided_slice', dict(op='StridedSlice')),
+                ('stack', dict(op='Const', value=lambda v: v is not None and np.array_equal(v, [1]))),
+                ('stack1', dict(op='Const', value=lambda v: v is not None and np.array_equal(v, [2]))),
+                ('stack2', dict(op='Const', value=lambda v: v is not None and np.array_equal(v, [1]))),
+                ('strided_slice_1', dict(op='StridedSlice')),
+                ('stack_1', dict(op='Const', value=lambda v: v is not None and np.array_equal(v, [0]))),
+                ('stack1_1', dict(op='Const', value=lambda v: v is not None and np.array_equal(v, [1]))),
+                ('stack2_1', dict(op='Const', value=lambda v: v is not None and np.array_equal(v, [1]))),
+                ('dims', dict(op='Pack')),
+                ('fill', dict(op='Fill')),
                 ('decoder', dict(op='CTCGreedyDecoder')),
                 ('cast', dict(op='Cast')),
                 ('sparse_to_dense', dict(op='SparseToDense')),
             ],
             edges=[
-                ('decoder', 'sparse_to_dense', {'out': 0}),
-                ('decoder', 'cast', {'out': 1}),
+                ('transpose', 'shape', {'out': 0}),
+                ('transpose', 'shape_1', {'out': 0}),
+                ('transpose', 'decoder', {'out': 0, 'in': 0}),
+                ('shape', 'strided_slice', {'out': 0, 'in': 0}),
+                ('stack', 'strided_slice', {'out': 0, 'in': 1}),
+                ('stack1', 'strided_slice', {'out': 0, 'in': 2}),
+                ('stack2', 'strided_slice', {'out': 0, 'in': 3}),
+                ('shape_1', 'strided_slice_1', {'out': 0, 'in': 0}),
+                ('stack_1', 'strided_slice_1', {'out': 0, 'in': 1}),
+                ('stack1_1', 'strided_slice_1', {'out': 0, 'in': 2}),
+                ('stack2_1', 'strided_slice_1', {'out': 0, 'in': 3}),
+                ('strided_slice', 'dims', {'out': 0, 'in': 0}),
+                ('dims', 'fill', {'out': 0, 'in': 0}),
+                ('strided_slice_1', 'fill', {'out': 0, 'in': 1}),
+                ('fill', 'decoder', {'out': 0, 'in': 1}),
+                ('decoder', 'sparse_to_dense', {'out': 0, 'in': 0}),
+                ('decoder', 'cast', {'out': 1, 'in': 0}),
                 ('cast', 'sparse_to_dense', {'out': 0}),
             ]
         )
 
-    def nodes_to_remove(self, graph: Graph, match: dict):
-        return [match['cast'].id, match['sparse_to_dense']]
-
     def replace_sub_graph(self, graph: Graph, match: dict):
-        # TODO: it requires further refactoring and improvement to provide reshape-ability
-        decoder_node = match['decoder']
-        decoder_node_name = decoder_node.soft_get('name', decoder_node.id)
-        graph.remove_edge(decoder_node.id, match['sparse_to_dense'].id)
-        graph.remove_edge(decoder_node.id, match['cast'].id)
-        match['sparse_to_dense'].replace_node(decoder_node)
-
-        sequence_length_node = decoder_node.in_node(1)
-        if sequence_length_node.value is None:
-            raise Error('The second input to the CTCGreedyDecoder node "{}" is not constant. This case is not '
-                        'supported with the Inference Engine.'.format(decoder_node_name))
-
-        # the batch size is the dimension with index 1 for the layer CTCGreedyDecoder
-        mask_value = np.ones([decoder_node.in_node(0).shape[1], sequence_length_node.value[0]])
-        mask_value[:, 0] = 0
-        mask_value = np.transpose(mask_value)
-        mask_node = Const(graph, {'name': decoder_node_name + '/Mask',
-                                  'value': mask_value}).create_node()
-        decoder_node.in_port(1).disconnect()
-        decoder_node.in_port(1).connect(mask_node.out_port(0))
-
-        return {}
+        # obtain references to necessary nodes and their names
+        fill = match['fill']
+        dims = match['dims']
+        strided_slice = match['strided_slice']
+        strided_slice_1 = match['strided_slice_1']
+        ctc_greedy_decoder = match['decoder']
+        cast = match['cast']
+        sparse_to_dense = match['sparse_to_dense']
+        strided_slice_name = strided_slice.soft_get('name', strided_slice.id)
+        strided_slice_1_name = strided_slice_1.soft_get('name', strided_slice_1.id)
+        ctc_greedy_decoder_name = ctc_greedy_decoder.soft_get('name', ctc_greedy_decoder.id)
+        sparse_to_dense_name = sparse_to_dense.soft_get('name', sparse_to_dense.id)
+
+        # unsqueeze scalar values with batch size and time dimension
+        unsqueeze_batch_size = create_op_with_const_inputs(graph, Unsqueeze, {1: int64_array(0)},
+                                                           {'name': strided_slice_name + '/Unsqueeze'})
+        dims.in_port(0).get_connection().set_destination(unsqueeze_batch_size.in_port(0))
+        unsqueeze_time_size = create_op_with_const_inputs(graph, Unsqueeze, {1: int64_array(0)},
+                                                           {'name': strided_slice_1_name + '/Unsqueeze'})
+        fill.in_port(1).get_connection().set_destination(unsqueeze_time_size.in_port(0))
+
+        # compute a sequence mask shape [T, N] required for CTCGreedyDecoder
+        seq_mask_shape = Concat(graph, {'axis': 0, 'in_ports_count': 2,
+                                        'name': ctc_greedy_decoder_name + '/SequenceMaskShape'}).create_node()
+        seq_mask_shape.in_port(0).connect(unsqueeze_time_size.out_port(0))
+        seq_mask_shape.in_port(1).connect(unsqueeze_batch_size.out_port(0))
+
+        # compute a sequence mask
+        sequence_mask = create_op_with_const_inputs(graph, Broadcast, {0: np.array([1.0], dtype=np.float)},
+                                                    {'mode': 'numpy',
+                                                     'name': ctc_greedy_decoder_name + '/SequenceMask'})
+        sequence_mask.in_port(1).connect(seq_mask_shape.out_port(0))
+
+        # create CTCGreedyDecoder with the sequence mask instead of sequence length
+        ctc_greedy_decoder.in_port(1).disconnect()
+        ctc_greedy_decoder.in_port(1).connect(sequence_mask.out_port(0))
+
+        # remove fill and pack nodes since they are now in unconnected component
+        graph.remove_nodes_from([fill.id, dims.id])
+
+        # transform opset CTCGreedyDecoder output to TensorFlow's one that has a shape [N, T]
+        # opset CTCGreedyDecoder has an output with a shape [N, T, 1, 1]
+        squeeze_dec_seq = create_op_with_const_inputs(graph, Squeeze, {1: int64_array([2, 3])},
+                                                      {'name': sparse_to_dense_name})
+        squeeze_dec_seq.in_port(0).connect(ctc_greedy_decoder.out_port(0))
+        cast_to_int = Cast(graph, {'name': sparse_to_dense_name + '/CastToInt',
+                                   'dst_type': np.int32}).create_node()
+        cast_to_int.in_port(0).connect(squeeze_dec_seq.out_port(0))
+
+        # preserve output name from original graph
+        rename_nodes([(sparse_to_dense, sparse_to_dense_name + '/AbandonedName'),
+                      (cast_to_int, sparse_to_dense_name)])
+
+        # set output of the new sub-graph as a source for SparseToDense consumer
+        sparse_to_dense.out_port(0).get_connection().set_source(cast_to_int.out_port(0))
+
+        # cleanup a graph
+        graph.remove_nodes_from([cast.id, sparse_to_dense.id])
diff --git a/model-optimizer/extensions/front/tf/CTCGreedyDecoderReplacement_test.py b/model-optimizer/extensions/front/tf/CTCGreedyDecoderReplacement_test.py
new file mode 100644 (file)
index 0000000..8fe482d
--- /dev/null
@@ -0,0 +1,165 @@
+"""
+ Copyright (C) 2020 Intel Corporation
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+import numpy as np
+import unittest
+
+from extensions.front.tf.CTCGreedyDecoderReplacement import CTCGreedyDecoderReplacement, CTCGreedyDecoderReplacement2
+from mo.front.common.partial_infer.utils import int64_array
+from mo.utils.ir_engine.compare_graphs import compare_graphs
+from mo.utils.unittest.graph import build_graph, const
+
+
+class CTCGreedyDecoderReplacementTests(unittest.TestCase):
+    def test1(self):
+        nodes_attributes = {
+            # nodes from original graph
+            'logits': {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
+            'seq_len': {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
+            'decoder': {'kind': 'op', 'op': 'CTCGreedyDecoder'},
+            'cast': {'kind': 'op', 'op': 'Cast'},
+            'sparse_to_dense': {'kind': 'op', 'op': 'SparseToDense'},
+            'last': {'type': None, 'value': None, 'kind': 'op', 'op': 'Result'},
+
+            # new nodes
+            'new_decoder': {'kind': 'op', 'op': 'CTCGreedyDecoder', 'use_mask_format': True},
+            **const('squeeze_axes', int64_array([2, 3])),
+            'squeeze_dec_seq': {'kind': 'op', 'op': 'Squeeze'},
+            'cast_to_int': {'kind': 'op', 'op': 'Cast'},
+        }
+
+        graph = build_graph(nodes_attributes,
+                            [('logits', 'decoder', {'out': 0, 'in': 0}),
+                             ('seq_len', 'decoder', {'out': 0, 'in': 1}),
+                             ('decoder', 'sparse_to_dense', {'out': 0, 'in': 0}),
+                             ('decoder', 'cast', {'out': 1, 'in': 0}),
+                             ('cast', 'sparse_to_dense', {'out': 0}),
+                             ('sparse_to_dense', 'last', {'out': 0, 'in': 0}),
+                             ], nodes_with_edges_only=True)
+        graph.stage = 'front'
+        CTCGreedyDecoderReplacement().find_and_replace_pattern(graph)
+
+        graph_ref = build_graph(nodes_attributes,
+                                [('logits', 'decoder', {'out': 0, 'in': 0}),
+                                 ('seq_len', 'decoder', {'out': 0, 'in': 1}),
+                                 ('decoder', 'squeeze_dec_seq', {'out': 0, 'in': 0}),
+                                 ('squeeze_axes', 'squeeze_dec_seq', {'out': 0, 'in': 1}),
+                                 ('squeeze_dec_seq', 'cast_to_int', {'out': 0, 'in': 0}),
+                                 ('cast_to_int', 'last', {'out': 0, 'in': 0}),
+                                 ],
+                                nodes_with_edges_only=True)
+
+        (flag, resp) = compare_graphs(graph, graph_ref, 'last', check_op_attrs=True)
+        self.assertEqual(len(graph.get_op_nodes(op='Cast')) == 1 and
+                         graph.get_op_nodes(op='Cast')[0]['name'] == 'sparse_to_dense', True,
+                         'Name is not inherited from original node for CTCGreedyDecoderReplacement')
+        self.assertTrue(flag, resp)
+
+    def test2(self):
+        nodes_attributes = {
+            # nodes from original graph
+            'logits': {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
+            'transpose': {'kind': 'op', 'op': 'Transpose'},
+            'shape': {'kind': 'op', 'op': 'ShapeOf'},
+            'shape_1': {'kind': 'op', 'op': 'ShapeOf'},
+            'strided_slice': {'kind': 'op', 'op': 'StridedSlice'},
+            **const('stack', int64_array([1])),
+            **const('stack1', int64_array([2])),
+            **const('stack2', int64_array([1])),
+            'strided_slice_1': {'kind': 'op', 'op': 'StridedSlice'},
+            **const('stack_1', int64_array([0])),
+            **const('stack1_1', int64_array([1])),
+            **const('stack2_1', int64_array([1])),
+            'dims': {'kind': 'op', 'op': 'Pack'},
+            'fill': {'kind': 'op', 'op': 'Fill'},
+            'decoder': {'kind': 'op', 'op': 'CTCGreedyDecoder'},
+            'cast': {'kind': 'op', 'op': 'Cast'},
+            'sparse_to_dense': {'kind': 'op', 'op': 'SparseToDense'},
+
+            # new nodes
+            **const('unsqueeze_batch_size_axis', int64_array(0)),
+            'unsqueeze_batch_size': {'kind': 'op', 'op': 'Unsqueeze'},
+            **const('unsqueeze_time_size_axis', int64_array(0)),
+            'unsqueeze_time_size': {'kind': 'op', 'op': 'Unsqueeze'},
+            'seq_mask_shape': {'kind': 'op', 'op': 'Concat'},
+            'sequence_mask': {'kind': 'op', 'op': 'Broadcast'},
+            **const('one', np.array([1.0], dtype=np.float)),
+            **const('squeeze_axes', int64_array([2, 3])),
+            'squeeze_dec_seq': {'kind': 'op', 'op': 'Squeeze'},
+            'cast_to_int': {'kind': 'op', 'op': 'Cast'},
+
+            'last': {'type': None, 'value': None, 'kind': 'op', 'op': 'Result'},
+        }
+
+        graph = build_graph(nodes_attributes,
+                            [('logits', 'transpose', {'out': 0}),
+                             ('transpose', 'shape', {'out': 0}),
+                             ('transpose', 'shape_1', {'out': 0}),
+                             ('transpose', 'decoder', {'out': 0, 'in': 0}),
+                             ('shape', 'strided_slice', {'out': 0, 'in': 0}),
+                             ('stack', 'strided_slice', {'out': 0, 'in': 1}),
+                             ('stack1', 'strided_slice', {'out': 0, 'in': 2}),
+                             ('stack2', 'strided_slice', {'out': 0, 'in': 3}),
+                             ('shape_1', 'strided_slice_1', {'out': 0, 'in': 0}),
+                             ('stack_1', 'strided_slice_1', {'out': 0, 'in': 1}),
+                             ('stack1_1', 'strided_slice_1', {'out': 0, 'in': 2}),
+                             ('stack2_1', 'strided_slice_1', {'out': 0, 'in': 3}),
+                             ('strided_slice', 'dims', {'out': 0, 'in': 0}),
+                             ('dims', 'fill', {'out': 0, 'in': 0}),
+                             ('strided_slice_1', 'fill', {'out': 0, 'in': 1}),
+                             ('fill', 'decoder', {'out': 0, 'in': 1}),
+                             ('decoder', 'sparse_to_dense', {'out': 0, 'in': 0}),
+                             ('decoder', 'cast', {'out': 1, 'in': 0}),
+                             ('cast', 'sparse_to_dense', {'out': 0}),
+                             ('sparse_to_dense', 'last', {'out': 0, 'in': 0}),
+                             ], nodes_with_edges_only=True)
+        graph.stage = 'front'
+        CTCGreedyDecoderReplacement2().find_and_replace_pattern(graph)
+
+        graph_ref = build_graph(nodes_attributes,
+                                [('logits', 'transpose', {'out': 0}),
+                                 ('transpose', 'shape', {'out': 0}),
+                                 ('transpose', 'shape_1', {'out': 0}),
+                                 ('transpose', 'decoder', {'out': 0, 'in': 0}),
+                                 ('shape', 'strided_slice', {'out': 0, 'in': 0}),
+                                 ('stack', 'strided_slice', {'out': 0, 'in': 1}),
+                                 ('stack1', 'strided_slice', {'out': 0, 'in': 2}),
+                                 ('stack2', 'strided_slice', {'out': 0, 'in': 3}),
+                                 ('shape_1', 'strided_slice_1', {'out': 0, 'in': 0}),
+                                 ('stack_1', 'strided_slice_1', {'out': 0, 'in': 1}),
+                                 ('stack1_1', 'strided_slice_1', {'out': 0, 'in': 2}),
+                                 ('stack2_1', 'strided_slice_1', {'out': 0, 'in': 3}),
+                                 ('strided_slice', 'unsqueeze_batch_size', {'out': 0, 'in': 0}),
+                                 ('unsqueeze_batch_size_axis', 'unsqueeze_batch_size', {'out': 0, 'in': 1}),
+                                 ('strided_slice_1', 'unsqueeze_time_size', {'out': 0, 'in': 0}),
+                                 ('unsqueeze_time_size_axis', 'unsqueeze_time_size', {'out': 0, 'in': 1}),
+                                 ('unsqueeze_batch_size', 'seq_mask_shape', {'out': 0, 'in': 1}),
+                                 ('unsqueeze_time_size', 'seq_mask_shape', {'out': 0, 'in': 0}),
+                                 ('one', 'sequence_mask', {'out': 0, 'in': 0}),
+                                 ('seq_mask_shape', 'sequence_mask', {'out': 0, 'in': 1}),
+                                 ('sequence_mask', 'decoder', {'out': 0, 'in': 1}),
+                                 ('decoder', 'squeeze_dec_seq', {'out': 0, 'in': 0}),
+                                 ('squeeze_axes', 'squeeze_dec_seq', {'out': 0, 'in': 1}),
+                                 ('squeeze_dec_seq', 'cast_to_int', {'out': 0, 'in': 0}),
+                                 ('cast_to_int', 'last', {'out': 0, 'in': 0}),
+                                 ],
+                                nodes_with_edges_only=True)
+
+        (flag, resp) = compare_graphs(graph, graph_ref, 'last', check_op_attrs=True)
+        self.assertEqual(len(graph.get_op_nodes(op='Cast')) == 1 and
+                         graph.get_op_nodes(op='Cast')[0]['name'] == 'sparse_to_dense', True,
+                         'Name is not inherited from original node for CTCGreedyDecoderReplacement2')
+        self.assertTrue(flag, resp)
diff --git a/model-optimizer/extensions/middle/SequenceLengthToMask.py b/model-optimizer/extensions/middle/SequenceLengthToMask.py
new file mode 100644 (file)
index 0000000..65b156d
--- /dev/null
@@ -0,0 +1,63 @@
+"""
+ Copyright (C) 2020 Intel Corporation
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+import numpy as np
+
+from mo.graph.graph import Graph
+from mo.middle.replacement import MiddleReplacementPattern
+from mo.ops.const import Const
+from mo.utils.error import Error
+
+
+class SequenceLengthToMask(MiddleReplacementPattern):
+    """
+    Convert a sequence length to a sequence mask for CTCGreedyDecoder if its value is available.
+    """
+    enabled = True
+
+    def run_before(self):
+        from extensions.middle.pass_separator import MiddleFinish
+        return [MiddleFinish]
+
+    def find_and_replace_pattern(self, graph: Graph):
+        for ctc_greedy_decoder in graph.get_op_nodes(op='CTCGreedyDecoder', use_mask_format=True):
+            ctc_greedy_decoder_name = ctc_greedy_decoder.soft_get('name', ctc_greedy_decoder.id)
+
+            sequence_length_value = ctc_greedy_decoder.in_port(1).data.get_value()
+            if sequence_length_value is None:
+                raise Error('The second input to the CTCGreedyDecoder node "{}" is not constant. This case is not '
+                            'supported with the Inference Engine.'.format(ctc_greedy_decoder_name))
+
+            # transform a sequence length to a sequence mask
+            logits_shape = ctc_greedy_decoder.in_port(0).data.get_shape()
+            assert logits_shape is not None and len(logits_shape) == 3, \
+                "Incorrect shape for logits input of {} node".format(ctc_greedy_decoder_name)
+            batch_size = logits_shape[1]
+            time_size = logits_shape[0]
+            mask_value = np.zeros([batch_size, time_size], dtype=np.float)
+            for sample_ind, sample_seq_length in enumerate(sequence_length_value):
+                mask_value[sample_ind, 0:sample_seq_length] = 1
+            mask_value = np.transpose(mask_value)
+
+            # create Const node with computed mask value
+            mask_node = Const(graph, {'name': ctc_greedy_decoder_name + '/Mask',
+                                      'value': mask_value}).create_node()
+
+            # connect computed mask to CTCGreedyDecoder node
+            ctc_greedy_decoder.in_port(1).get_connection().set_source(mask_node.out_port(0))
+
+            # remove attribute-marker
+            del ctc_greedy_decoder['use_mask_format']
diff --git a/model-optimizer/extensions/middle/SequenceLenthToMask_test.py b/model-optimizer/extensions/middle/SequenceLenthToMask_test.py
new file mode 100644 (file)
index 0000000..5ff0d7d
--- /dev/null
@@ -0,0 +1,67 @@
+"""
+ Copyright (C) 2020 Intel Corporation
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+import unittest
+
+import numpy as np
+
+from extensions.middle.SequenceLengthToMask import SequenceLengthToMask
+from mo.front.common.partial_infer.utils import int64_array
+from mo.utils.ir_engine.compare_graphs import compare_graphs
+from mo.utils.unittest.graph import build_graph
+
+
+nodes_attributes = {'logits': {'shape': int64_array([5, 3, 30]), 'type': 'Parameter', 'kind': 'op',
+                               'op': 'Parameter'},
+                    'logits_data': {'value': None, 'shape': int64_array([5, 3, 30]), 'kind': 'data'},
+                    'seq_length': {'type': 'Const', 'kind': 'op', 'op': 'Const', 'value': int64_array([5, 2, 3])},
+                    'seq_length_data': {'value': int64_array([5, 2, 3]), 'kind': 'data'},
+                    'ctc_greedy_decoder': {'type': None, 'kind': 'op', 'op': 'CTCGreedyDecoder',
+                                           'use_mask_format': True},
+                    'ctc_greedy_decoder_data': {'value': None, 'shape': None, 'kind': 'data'},
+                    'last': {'kind': 'op', 'op': 'Result'},
+
+                    # new nodes
+                    'seq_mask': {'type': 'Const', 'kind': 'op', 'op': 'Const',
+                                 'value': np.transpose(np.array([[1, 1, 1, 1, 1],
+                                                                 [1, 1, 0, 0, 0],
+                                                                 [1, 1, 1, 0, 0]], dtype=np.float))},
+                    'seq_mask_data': {'value': None, 'kind': 'data'},
+                    'new_ctc_greedy_decoder': {'type': None, 'kind': 'op', 'op': 'CTCGreedyDecoder'},
+                    }
+
+class ScaleInputTests(unittest.TestCase):
+    def test1(self):
+        graph = build_graph(nodes_attributes,
+                            [('logits', 'logits_data'),
+                             ('logits_data', 'ctc_greedy_decoder'),
+                             ('seq_length', 'seq_length_data'),
+                             ('seq_length_data', 'ctc_greedy_decoder'),
+                             ('ctc_greedy_decoder', 'ctc_greedy_decoder_data'),
+                             ('ctc_greedy_decoder_data', 'last')],
+                            nodes_with_edges_only=True)
+
+        graph_ref = build_graph(nodes_attributes,
+                                [('logits', 'logits_data'),
+                                 ('logits_data', 'new_ctc_greedy_decoder'),
+                                 ('seq_mask', 'seq_mask_data'),
+                                 ('seq_mask_data', 'new_ctc_greedy_decoder'),
+                                 ('new_ctc_greedy_decoder', 'ctc_greedy_decoder_data'),
+                                 ('ctc_greedy_decoder_data', 'last')],
+                                nodes_with_edges_only=True)
+        SequenceLengthToMask().find_and_replace_pattern(graph)
+        (flag, resp) = compare_graphs(graph, graph_ref, 'last')
+        self.assertTrue(flag, resp)
index fbbb252..74f523c 100644 (file)
@@ -52,12 +52,22 @@ class CTCGreedyDecoderOp(Op):
         sequence_mask_shape = node.in_port(1).data.get_shape()
 
         # check shapes of input tensors
-        assert len(logits_shape) == 3 and len(sequence_mask_shape) == 2, \
-            'Incorrect rank of some input tensor for {} node'.format(node_name)
-        assert logits_shape[1] == sequence_mask_shape[1], \
-            'Batch dimensions of input tensors must be the same for {} node'.format(node_name)
-        assert logits_shape[0] == sequence_mask_shape[0], \
-            'Time dimensions of input tensors must be the same for {} node'.format(node_name)
+        assert len(logits_shape) == 3, \
+            'Incorrect rank of logits for {} node'.format(node_name)
+        if node.has_valid('use_mask_format') and node.use_mask_format is True:
+            # it is a case when CTCGreedyDecoder still uses an original format for sequence_length
+            assert len(sequence_mask_shape) == 1, \
+                'Incorrect rank of sequence length tensor for {} node'.format(node_name)
+            assert logits_shape[1] == sequence_mask_shape[0], \
+                'Batch dimensions of input tensors must be the same for {} node'.format(node_name)
+        else:
+            # it is a case when CTCGreedyDecoder uses a sequence mask
+            assert len(sequence_mask_shape) == 2, \
+                'Incorrect rank of sequence length tensor for {} node'.format(node_name)
+            assert logits_shape[1] == sequence_mask_shape[1], \
+                'Batch dimensions of input tensors must be the same for {} node'.format(node_name)
+            assert logits_shape[0] == sequence_mask_shape[0], \
+                'Time dimensions of input tensors must be the same for {} node'.format(node_name)
 
         batch_size = logits_shape[1]
         time_size = logits_shape[0]