Extend MO for operation CTCLoss and partly refactor CTCGreedyDecoder (#588)
authorRoman Kazantsev <roman.kazantsev@intel.com>
Mon, 17 Aug 2020 16:19:59 +0000 (19:19 +0300)
committerGitHub <noreply@github.com>
Mon, 17 Aug 2020 16:19:59 +0000 (19:19 +0300)
* Extend MO for operation CTCLoss

* Change sequence length format to a mask format

* Add fixes after first-round review

* Add fixes after the second-round review

* Fixing CTCLossPlusCTCGreedyDecoder transformation

12 files changed:
model-optimizer/automation/package_BOM.txt
model-optimizer/extensions/front/caffe/ctcgreedydecoder_ext_test.py
model-optimizer/extensions/front/tf/CTCGreedyDecoder.py [deleted file]
model-optimizer/extensions/front/tf/CTCGreedyDecoderReplacement.py [new file with mode: 0644]
model-optimizer/extensions/front/tf/CTCLossReplacement.py [new file with mode: 0644]
model-optimizer/extensions/front/tf/CTCLossReplacement_test.py [new file with mode: 0644]
model-optimizer/extensions/front/tf/CTCLoss_ext.py [new file with mode: 0644]
model-optimizer/extensions/front/tf/sparse_to_dense_replacer.py
model-optimizer/extensions/ops/ctc_greedy_decoder.py
model-optimizer/extensions/ops/ctc_greedy_decoder_test.py
model-optimizer/extensions/ops/ctc_loss.py [new file with mode: 0644]
model-optimizer/extensions/ops/ctc_loss_test.py [new file with mode: 0644]

index 8b884cb90028b01382d5b4ffb3c6a8209f538689..b756cb5ed911f181d13325c62fe8cdc12ee0ef24 100644 (file)
@@ -351,8 +351,10 @@ extensions/front/tf/const_ext.py
 extensions/front/tf/conv_ext.py
 extensions/front/tf/crop_and_resize_ext.py
 extensions/front/tf/CropAndResizeReplacement.py
-extensions/front/tf/CTCGreedyDecoder.py
 extensions/front/tf/CTCGreedyDecoder_ext.py
+extensions/front/tf/CTCGreedyDecoderReplacement.py
+extensions/front/tf/CTCLoss_ext.py
+extensions/front/tf/CTCLossReplacement.py
 extensions/front/tf/cumsum_ext.py
 extensions/front/tf/deconv_ext.py
 extensions/front/tf/depth_to_space.py
@@ -593,6 +595,7 @@ extensions/ops/constant_fill.py
 extensions/ops/copyop.py
 extensions/ops/correlation.py
 extensions/ops/ctc_greedy_decoder.py
+extensions/ops/ctc_loss.py
 extensions/ops/cumsum.py
 extensions/ops/data_augmentation.py
 extensions/ops/depth_to_space.py
index 25bf849eef3df868948beabd8ad13949688aa6ff..0193e974fa51257fecea55db69044e8a28ea5e15 100644 (file)
@@ -54,7 +54,7 @@ class TestCTCGreedyDecoderExt(unittest.TestCase):
         exp_res = {
             'type': "CTCGreedyDecoder",
             'ctc_merge_repeated': 1,
-            'infer': CTCGreedyDecoderOp.ctc_greedy_decoder_infer
+            'infer': CTCGreedyDecoderOp.infer
         }
 
         for key in exp_res.keys():
diff --git a/model-optimizer/extensions/front/tf/CTCGreedyDecoder.py b/model-optimizer/extensions/front/tf/CTCGreedyDecoder.py
deleted file mode 100644 (file)
index 8b75218..0000000
+++ /dev/null
@@ -1,82 +0,0 @@
-"""
- Copyright (C) 2018-2020 Intel Corporation
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
-      http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
-"""
-
-import numpy as np
-
-from mo.front.common.partial_infer.utils import int64_array
-from mo.front.common.replacement import FrontReplacementSubgraph
-from mo.graph.graph import Node, Graph
-from mo.utils.error import Error
-
-
-class CTCGreedyDecoderReplacement(FrontReplacementSubgraph):
-    """
-    The TF implementation of the CTCGreedyDecoder produces a tuple with two tensors. The first element in the tuple is
-    the SparseTensor which is converted to a regular tensor with the SparseToDense operation. This replacer matches
-    CTCGreedyDecoder and SparseToDense operations and removes the SparseToDense and Cast operation which is also used
-    in the SparseToDense operation, because Inference Engine implementation of the CTCGreedyDecoder produces regular
-    tensor as output.
-
-    The second input to the CTCGreedyDecoder in the TensorFlow is a 1D tensor with sequence lengths. In the Inference
-    Engine the second input to the CTCGreedyDecoder is a 2D tensor where the first element in each row is equal to 0
-    and all others are equal to 1. The length of the row is equal to the sequence length. The replacer modifies the
-    second input to be compatible with the Inference Engine CTCGreedyDecoder layer implementation.
-    """
-    enabled = True
-
-    @staticmethod
-    def pattern(**kwargs):
-        return dict(
-            nodes=[
-                ('decoder', dict(op='CTCGreedyDecoder')),
-                ('cast', dict(op='Cast')),
-                ('sparse_to_dense', dict(op='SparseToDense')),
-            ],
-            edges=[
-                ('decoder', 'sparse_to_dense', {'out': 0}),
-                ('decoder', 'cast', {'out': 1}),
-                ('cast', 'sparse_to_dense', {'out': 0}),
-            ]
-        )
-
-    def nodes_to_remove(self, graph: Graph, match: dict):
-        return [match['cast'].id, match['sparse_to_dense']]
-
-    def replace_sub_graph(self, graph: Graph, match: dict):
-        decoder_node = match['decoder']
-        graph.remove_edge(decoder_node.id, match['sparse_to_dense'].id)
-        graph.remove_edge(decoder_node.id, match['cast'].id)
-        match['sparse_to_dense'].replace_node(decoder_node)
-
-        # update the TensorFlow infer function for the CTCGreedyDecoder to make necessary changes with the second input
-        decoder_node['old_infer'] = decoder_node.infer
-        decoder_node.infer = __class__.tf_greedy_decoder_infer
-        return {}
-
-    @staticmethod
-    def tf_greedy_decoder_infer(node: Node):
-        sequence_length_node = node.in_node(1)
-        if sequence_length_node.value is None:
-            raise Error('The second input to the CTCGreedyDecoder node "{}" is not constant. This case is not '
-                        'supported with the Inference Engine.'.format(node.soft_get('name')))
-        # the batch size is the dimension with index 1 for the layer CTCGreedyDecoder
-        new_value = np.ones([node.in_node(0).shape[1], sequence_length_node.value[0]])
-        new_value[:, 0] = 0
-        new_value = np.transpose(new_value)
-        sequence_length_node.value = new_value
-        sequence_length_node.shape = int64_array(sequence_length_node.value.shape)
-
-        node.old_infer(node)
diff --git a/model-optimizer/extensions/front/tf/CTCGreedyDecoderReplacement.py b/model-optimizer/extensions/front/tf/CTCGreedyDecoderReplacement.py
new file mode 100644 (file)
index 0000000..2086c1c
--- /dev/null
@@ -0,0 +1,80 @@
+"""
+ Copyright (C) 2018-2020 Intel Corporation
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+import numpy as np
+
+from mo.front.common.replacement import FrontReplacementSubgraph
+from mo.graph.graph import Graph
+from mo.ops.const import Const
+from mo.utils.error import Error
+
+
+class CTCGreedyDecoderReplacement(FrontReplacementSubgraph):
+    """
+    The TF implementation of the CTCGreedyDecoder produces a tuple with two tensors. The first element in the tuple is
+    the SparseTensor which is converted to a regular tensor with the SparseToDense operation. This replacer matches
+    CTCGreedyDecoder and SparseToDense operations and removes the SparseToDense and Cast operation which is also used
+    in the SparseToDense operation, because Inference Engine implementation of the CTCGreedyDecoder produces regular
+    tensor as output.
+
+    The second input to the CTCGreedyDecoder in the TensorFlow is a 1D tensor with sequence lengths. In the Inference
+    Engine the second input to the CTCGreedyDecoder is a 2D tensor where the first element in each row is equal to 0
+    and all others are equal to 1. The length of the row is equal to the sequence length. The replacer modifies the
+    second input to be compatible with the Inference Engine CTCGreedyDecoder layer implementation.
+    """
+    enabled = True
+
+    @staticmethod
+    def pattern(**kwargs):
+        return dict(
+            nodes=[
+                ('decoder', dict(op='CTCGreedyDecoder')),
+                ('cast', dict(op='Cast')),
+                ('sparse_to_dense', dict(op='SparseToDense')),
+            ],
+            edges=[
+                ('decoder', 'sparse_to_dense', {'out': 0}),
+                ('decoder', 'cast', {'out': 1}),
+                ('cast', 'sparse_to_dense', {'out': 0}),
+            ]
+        )
+
+    def nodes_to_remove(self, graph: Graph, match: dict):
+        return [match['cast'].id, match['sparse_to_dense']]
+
+    def replace_sub_graph(self, graph: Graph, match: dict):
+        # TODO: it requires further refactoring and improvement to provide reshape-ability
+        decoder_node = match['decoder']
+        decoder_node_name = decoder_node.soft_get('name', decoder_node.id)
+        graph.remove_edge(decoder_node.id, match['sparse_to_dense'].id)
+        graph.remove_edge(decoder_node.id, match['cast'].id)
+        match['sparse_to_dense'].replace_node(decoder_node)
+
+        sequence_length_node = decoder_node.in_node(1)
+        if sequence_length_node.value is None:
+            raise Error('The second input to the CTCGreedyDecoder node "{}" is not constant. This case is not '
+                        'supported with the Inference Engine.'.format(decoder_node_name))
+
+        # the batch size is the dimension with index 1 for the layer CTCGreedyDecoder
+        mask_value = np.ones([decoder_node.in_node(0).shape[1], sequence_length_node.value[0]])
+        mask_value[:, 0] = 0
+        mask_value = np.transpose(mask_value)
+        mask_node = Const(graph, {'name': decoder_node_name + '/Mask',
+                                  'value': mask_value}).create_node()
+        decoder_node.in_port(1).disconnect()
+        decoder_node.in_port(1).connect(mask_node.out_port(0))
+
+        return {}
diff --git a/model-optimizer/extensions/front/tf/CTCLossReplacement.py b/model-optimizer/extensions/front/tf/CTCLossReplacement.py
new file mode 100644 (file)
index 0000000..18670e3
--- /dev/null
@@ -0,0 +1,186 @@
+"""
+ Copyright (C) 2020 Intel Corporation
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+import numpy as np
+import logging as log
+
+from extensions.ops.Cast import Cast
+from extensions.ops.ctc_greedy_decoder import CTCGreedyDecoderOp
+from extensions.ops.ctc_loss import CTCLoss
+from extensions.ops.elementwise import Equal
+from extensions.ops.parameter import Parameter
+from extensions.ops.ReduceOps import ReduceSum
+from extensions.ops.select import Select
+from extensions.ops.transpose import Transpose
+from mo.front.common.partial_infer.utils import int64_array
+from mo.front.common.replacement import FrontReplacementSubgraph
+from mo.front.tf.graph_utils import create_op_with_const_inputs
+from mo.graph.graph import Graph, rename_nodes
+from mo.middle.passes.convert_data_type import data_type_str_to_np
+from mo.ops.broadcast import Broadcast
+from mo.ops.shape import Shape
+from mo.ops.squeeze import Squeeze
+from mo.utils.error import Error
+
+
+class CTCLossReplacement(FrontReplacementSubgraph):
+    """
+    The CTCLoss appears along with CTCGreedyDecoder operation in particular. Since the TensorFlow* CTCGreedyDecoder
+    outputs sparse tensor format, the OpenVINO CTCGreedyDecoder has a different format and the CTCLoss is also affected
+    in terms of different format for its inputs. So the corresponding sub-graph with CTCGreedyDecoding and CTCLoss
+    must be transformed properly.
+    Also, the transformation changes the input sequence length format into a mask format. For example, 1D tensor of
+    sequence lengths equal to [4 2] is coded as 2D tensor [[1 1 1 1 0], [1 1 0 0 0]] with a time dimension is
+    equal to 5.
+    """
+    enabled = True
+
+    def run_before(self):
+        from extensions.front.tf.CTCGreedyDecoderReplacement import CTCGreedyDecoderReplacement
+        return [CTCGreedyDecoderReplacement]
+
+    def pattern(self):
+        return dict(
+            nodes=[
+                ('seq_len', dict(op='Parameter')),
+                ('transpose', dict(op='Transpose')),
+                ('ctc_greedy_decoder', dict(op='CTCGreedyDecoder')),
+                ('cast', dict(op='Cast')),
+                ('sparse_to_dense', dict(op='SparseToDense')),
+                ('const', dict(op='Const')),
+                ('ctc_loss', dict(op='CTCLoss')),
+            ],
+            edges=[
+                ('seq_len', 'ctc_greedy_decoder', {'out': 0, 'in': 1}),
+                ('seq_len', 'ctc_loss', {'out': 0, 'in': 3}),
+                ('transpose', 'ctc_greedy_decoder', {'out': 0, 'in': 0}),
+                ('transpose', 'ctc_loss', {'out': 0, 'in': 0}),
+                ('ctc_greedy_decoder', 'sparse_to_dense', {'out': 0, 'in': 0}),
+                ('ctc_greedy_decoder', 'sparse_to_dense', {'out': 2, 'in': 1}),
+                ('ctc_greedy_decoder', 'sparse_to_dense', {'out': 1, 'in': 2}),
+                ('const', 'sparse_to_dense', {'out': 0, 'in': 3}),
+                ('ctc_greedy_decoder', 'cast', {'out': 1, 'in': 0}),
+                ('ctc_greedy_decoder', 'ctc_loss', {'out': 0, 'in': 1}),
+                ('cast', 'ctc_loss', {'out': 0, 'in': 2})
+            ])
+
+    def replace_sub_graph(self, graph: Graph, match: dict):
+        seq_len_tf = match['seq_len']
+        transpose_tf = match['transpose']
+        ctc_greedy_decoder_tf = match['ctc_greedy_decoder']
+        cast_tf = match['cast']
+        ctc_loss_tf = match['ctc_loss']
+        sparse_to_dense_tf = match['sparse_to_dense']
+
+        output_sparse_to_dense_name = sparse_to_dense_tf.soft_get('name', sparse_to_dense_tf.id)
+        output_ctc_loss_name = ctc_loss_tf.soft_get('name', ctc_loss_tf.id)
+        ctc_greedy_decoder_tf_name = ctc_greedy_decoder_tf.soft_get('name', ctc_greedy_decoder_tf.id)
+
+        log.debug('Found CTCLossFrontReplacer pattern after {} with name {}'.format(ctc_greedy_decoder_tf.op,
+                                                                                    ctc_greedy_decoder_tf.name))
+
+        # create sequence mask node, sub-graph for transforming into sequence length and connect with consumers
+        seq_len_tf_shape = seq_len_tf.soft_get('shape', None)
+        if seq_len_tf_shape is None or len(seq_len_tf_shape) != 2:
+            raise Error('The sequence length that is the second input to the CTCGreedyDecoder node "{}"'
+                        ' must be specified in a mask format.'.format(ctc_greedy_decoder_tf_name))
+        log.error('The format of input sequence length has been changed to a mask format', extra={'is_warning': True})
+        seq_len_tf_type = seq_len_tf.soft_get('data_type', None)
+        seq_len_tf_name = seq_len_tf.soft_get('name', seq_len_tf.id)
+        seq_mask_placeholder = Parameter(graph, {'name': seq_len_tf_name, 'shape': seq_len_tf_shape,
+                                                 'data_type': seq_len_tf_type}).create_node()
+        reduce_to_seq_len_node = create_op_with_const_inputs(graph, ReduceSum, {1: np.array(1, dtype=np.int32)},
+                                                             {'name': seq_len_tf_name + '/ReduceToSeqLen',
+                                                              'keep_dims': False})
+        reduce_to_seq_len_node.in_port(0).connect(seq_mask_placeholder.out_port(0))
+        seq_len_tf.out_port(0).get_connection().set_source(reduce_to_seq_len_node.out_port(0))
+
+        cast_fp_type = data_type_str_to_np(graph.graph['cmd_params'].data_type)
+        casted_seq_mask_node = Cast(graph, {'name': seq_len_tf_name + '/CastToFP32', 'dst_type': cast_fp_type}).create_node()
+        casted_seq_mask_node.in_port(0).connect(seq_mask_placeholder.out_port(0))
+        permuted_casted_seq_mask = create_op_with_const_inputs(graph, Transpose, {1: int64_array([1, 0])},
+                                                               {'name': seq_len_tf_name + '/Permute'})
+        permuted_casted_seq_mask.in_port(0).connect(casted_seq_mask_node.out_port(0))
+        rename_nodes([(seq_len_tf, seq_len_tf_name + '/AbandonedName'), (seq_mask_placeholder, seq_len_tf_name)])
+
+        # create CTCGreedyDecoder node and set mask node
+        ctc_merge_repeated_i = ctc_greedy_decoder_tf.soft_get('ctc_merge_repeated', ctc_greedy_decoder_tf.id)
+        ctc_greedy_decoder = CTCGreedyDecoderOp(graph, {'name': output_sparse_to_dense_name,
+                                                        'ctc_merge_repeated': ctc_merge_repeated_i}).create_node()
+        ctc_greedy_decoder.in_port(1).connect(permuted_casted_seq_mask.out_port(0))
+        rename_nodes([(sparse_to_dense_tf, output_sparse_to_dense_name + '/AbandonedName'),
+                      (ctc_greedy_decoder, output_sparse_to_dense_name)])
+
+        # create CTCLoss node and set attributes
+        assert ctc_loss_tf.has_valid('preprocess_collapse_repeated'), \
+            'The CTCLoss node "{}" misses "preprocess_collapse_repeated" attribute'.format(output_ctc_loss_name)
+        assert ctc_loss_tf.has_valid('ctc_merge_repeated'), \
+            'The CTCLoss node "{}" misses "ctc_merge_repeated" attribute'.format(output_ctc_loss_name)
+        assert ctc_loss_tf.has_valid('unique'), \
+            'The CTCLoss node "{}" misses "unique" attribute'.format(output_ctc_loss_name)
+        preprocess_collapse_repeated = ctc_loss_tf.preprocess_collapse_repeated
+        ctc_merge_repeated = ctc_loss_tf.ctc_merge_repeated
+        unique = ctc_loss_tf.unique
+        ctc_loss = CTCLoss(graph, {'name': output_ctc_loss_name,
+                                   'preprocess_collapse_repeated': preprocess_collapse_repeated,
+                                   'ctc_merge_repeated': ctc_merge_repeated,
+                                   'unique': unique}).create_node()
+        rename_nodes([(ctc_loss_tf, output_ctc_loss_name + '/AbandonedName'), (ctc_loss, output_ctc_loss_name)])
+
+        # connect logits
+        ctc_greedy_decoder_tf.in_port(0).get_connection().set_destination(ctc_greedy_decoder.in_port(0))
+        ctc_loss.in_port(0).disconnect()
+        transpose_tf.in_port(0).get_connection().add_destination(ctc_loss.in_port(0))
+
+        # connect logit lengths
+        ctc_greedy_decoder_tf.in_port(1).disconnect()
+        ctc_loss.in_port(1).connect(reduce_to_seq_len_node.out_port(0))
+
+        # connect labels to ctc_loss
+        squeeze_op = create_op_with_const_inputs(graph, Squeeze, {1: int64_array([2, 3])})
+        cast_labels_op = Cast(graph, {'name': output_sparse_to_dense_name + '/CastLabels', 'dst_type': np.int32}).create_node()
+        squeeze_op.in_port(0).connect(ctc_greedy_decoder.out_port(0))
+        cast_labels_op.in_port(0).connect(squeeze_op.out_port(0))
+        ctc_loss.in_port(2).connect(cast_labels_op.out_port(0))
+
+        # connect label lengths
+        equal_op = create_op_with_const_inputs(graph, Equal, {1: np.array([-1], dtype=np.int32)},
+                                               {'name': output_sparse_to_dense_name + '/Equal'})
+        equal_op.in_port(0).connect(cast_labels_op.out_port(0))
+        labels_shape_op = Shape(graph, {'name': output_sparse_to_dense_name + '/ShapeOf'}).create_node()
+        labels_shape_op.in_port(0).connect(equal_op.out_port(0))
+        broadcast_one = create_op_with_const_inputs(graph, Broadcast, {0: np.array([1], dtype=np.int32)},
+                                                    {'mode': 'numpy',
+                                                     'name': output_sparse_to_dense_name + '/One'})
+        broadcast_one.in_port(1).connect(labels_shape_op.out_port(0))
+        broadcast_zero = create_op_with_const_inputs(graph, Broadcast, {0: np.array([0], dtype=np.int32)},
+                                                     {'mode': 'numpy',
+                                                      'name': output_sparse_to_dense_name + '/Zero'})
+        broadcast_zero.in_port(1).connect(labels_shape_op.out_port(0))
+
+        select_node = Select(graph, {'name': output_sparse_to_dense_name + '/Select'}).create_node()
+        select_node.in_port(0).connect(equal_op.out_port(0))
+        select_node.in_port(1).connect(broadcast_zero.out_port(0))
+        select_node.in_port(2).connect(broadcast_one.out_port(0))
+        label_length_node = create_op_with_const_inputs(graph, ReduceSum, {1: int64_array([1])},
+                                                      op_attrs={'name': output_sparse_to_dense_name + '/LabelLength',
+                                                                'keep_dims': False})
+        label_length_node.in_port(0).connect(select_node.out_port(0))
+        ctc_loss.in_port(3).connect(label_length_node.out_port(0))
+
+        # set source for output of new sub-graph and remove old nodes
+        ctc_loss_tf.out_port(0).get_connection().set_source(ctc_loss.out_port(0))
+        graph.remove_nodes_from([ctc_greedy_decoder_tf.id, ctc_loss_tf.id, cast_tf.id, sparse_to_dense_tf.id])
diff --git a/model-optimizer/extensions/front/tf/CTCLossReplacement_test.py b/model-optimizer/extensions/front/tf/CTCLossReplacement_test.py
new file mode 100644 (file)
index 0000000..1b56635
--- /dev/null
@@ -0,0 +1,134 @@
+"""
+ Copyright (C) 2020 Intel Corporation
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+import numpy as np
+import unittest
+from argparse import Namespace
+
+from extensions.front.tf.CTCLossReplacement import CTCLossReplacement
+from mo.front.common.partial_infer.utils import int64_array
+from mo.utils.ir_engine.compare_graphs import compare_graphs
+from mo.utils.unittest.graph import build_graph, const
+
+
+class CTCLossFrontReplacementTest(unittest.TestCase):
+    def test1(self):
+        nodes_attributes = {
+            'logits': {'shape': int64_array([2, 6, 100]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
+            'seq_mask': {'shape': int64_array([2, 100]), 'data_type': np.int32, 'kind': 'op', 'op': 'Parameter'},
+
+            'reduce_seq_mask': {'kind': 'op', 'op': 'ReduceSum'},
+            's_cast_seq_mask': {'kind': 'op', 'op': 'Cast'},
+            'transpose_cast_seq_mask': {'kind': 'op', 'op': 'Transpose'},
+
+            'transpose': {'kind': 'op', 'op': 'Transpose'},
+            'ctc_greedy_decoder': {'kind': 'op', 'op': 'CTCGreedyDecoder'},
+            'cast': {'kind': 'op', 'op': 'Cast'},
+            'sparse_to_dense': {'kind': 'op', 'op': 'SparseToDense'},
+            'const': {'kind': 'op', 'op': 'Const'},
+            'ctc_loss': {'kind': 'op', 'op': 'CTCLoss', 'preprocess_collapse_repeated': False,
+                         'ctc_merge_repeated': True, 'unique': False},
+
+            'equal_op': {'kind': 'op', 'op': 'Equal'},
+
+            'ctc_greedy_decoder_op': {'kind': 'op', 'op': 'CTCGreedyDecoder'},
+            'ctc_loss_op': {'kind': 'op', 'op': 'CTCLoss'},
+            'squeeze_op': {'kind': 'op', 'op': 'Squeeze'},
+
+            'cast_labels_op': {'kind': 'op', 'op': 'Cast', 'type': 'Convert'},
+            'labels_shape_op': {'kind': 'op', 'op': 'ShapeOf'},
+            'broadcast_one_op': {'kind': 'op', 'op': 'Broadcast'},
+            'broadcast_zero_op': {'kind': 'op', 'op': 'Broadcast'},
+            'select_op': {'kind': 'op', 'op': 'Select'},
+            'label_length_op': {'kind': 'op', 'op': 'ReduceSum'},
+
+            **const('reduce_indices', int64_array(1)),
+            **const('permute_order', int64_array([1, 0])),
+            **const('default_value', int64_array(-1)),
+            **const('squeeze_axis', int64_array([2, 3])),
+            **const('minus_one', np.array([-1], dtype=np.int32)),
+            **const('one', np.array([1], dtype=np.int32)),
+            **const('zero', np.array([0], dtype=np.int32)),
+            **const('reduce_sum_axis', int64_array([1])),
+
+            'last': {'type': None, 'value': None, 'kind': 'op', 'op': 'Result'},
+        }
+
+        graph = build_graph(nodes_attributes,
+                            [('logits', 'transpose', {'out': 0, 'in': 0}),
+                             ('transpose', 'ctc_greedy_decoder', {'out': 0, 'in': 0}),
+                             ('seq_mask', 'ctc_greedy_decoder', {'out': 0, 'in': 1}),
+
+                             ('transpose', 'ctc_loss', {'out': 0, 'in': 0}),
+                             ('seq_mask', 'ctc_loss', {'out': 0, 'in': 3}),
+
+                             ('ctc_greedy_decoder', 'sparse_to_dense', {'out': 0, 'in': 0}),
+                             ('ctc_greedy_decoder', 'sparse_to_dense', {'out': 2, 'in': 1}),
+                             ('ctc_greedy_decoder', 'sparse_to_dense', {'out': 1, 'in': 2}),
+                             ('default_value', 'sparse_to_dense', {'out': 0, 'in': 3}),
+                             ('ctc_greedy_decoder', 'cast', {'out': 1, 'in': 0}),
+                             ('ctc_greedy_decoder', 'ctc_loss', {'out': 0, 'in': 1}),
+                             ('cast', 'ctc_loss', {'out': 0, 'in': 2}),
+
+                             ('ctc_loss', 'last', {'out': 0, 'in': 0}),
+                             ], nodes_with_edges_only=True)
+        graph.graph['cmd_params'] = Namespace(data_type='FP32')
+        graph.stage = 'front'
+        CTCLossReplacement().find_and_replace_pattern(graph)
+
+        graph_ref = build_graph(nodes_attributes,
+                                [('seq_mask', 'reduce_seq_mask', {'out': 0, 'in': 0}),
+                                 ('reduce_indices', 'reduce_seq_mask', {'out': 0, 'in': 1}),
+
+                                 ('seq_mask', 's_cast_seq_mask', {'out': 0, 'in': 0}),
+                                 ('s_cast_seq_mask', 'transpose_cast_seq_mask', {'out': 0, 'in': 0}),
+                                 ('permute_order', 'transpose_cast_seq_mask', {'out': 0, 'in': 1}),
+
+                                 ('logits', 'transpose', {'out': 0, 'in': 0}),
+                                 ('transpose', 'ctc_greedy_decoder_op', {'out': 0, 'in': 0}),
+                                 ('transpose_cast_seq_mask', 'ctc_greedy_decoder_op', {'out': 0, 'in': 1}),
+
+                                 ('ctc_greedy_decoder_op', 'squeeze_op', {'out': 0, 'in': 0}),
+                                 ('squeeze_axis', 'squeeze_op', {'out': 0, 'in': 1}),
+                                 ('squeeze_op', 'cast_labels_op', {'in': 0}),
+
+                                 ('minus_one', 'equal_op', {'out': 0, 'in': 1}),
+
+                                 ('equal_op', 'labels_shape_op', {'out': 0, 'in': 0}),
+                                 ('one', 'broadcast_one_op', {'out': 0, 'in': 0}),
+                                 ('labels_shape_op', 'broadcast_one_op', {'out': 0, 'in': 1}),
+                                 ('zero', 'broadcast_zero_op', {'out': 0, 'in': 0}),
+                                 ('labels_shape_op', 'broadcast_zero_op', {'out': 0, 'in': 1}),
+
+                                 ('equal_op', 'select_op', {'out': 0, 'in': 0}),
+                                 ('broadcast_zero_op', 'select_op', {'out': 0, 'in': 1}),
+                                 ('broadcast_one_op', 'select_op', {'out': 0, 'in': 2}),
+
+                                 ('select_op', 'label_length_op', {'out': 0, 'in': 0}),
+                                 ('reduce_sum_axis', 'label_length_op', {'out': 0, 'in': 1}),
+
+                                 ('logits', 'ctc_loss_op', {'out': 0, 'in': 0}),
+                                 ('reduce_seq_mask', 'ctc_loss_op', {'out': 0, 'in': 1}),
+                                 ('cast_labels_op', 'ctc_loss_op', {'out': 0, 'in': 2}),
+                                 ('label_length_op', 'ctc_loss_op', {'out': 0, 'in': 3}),
+
+                                 ('cast_labels_op', 'equal_op', {'out': 0, 'in': 0}),
+
+                                 ('ctc_loss_op', 'last', {'out': 0, 'in': 0})],
+                                nodes_with_edges_only=True)
+
+        (flag, resp) = compare_graphs(graph, graph_ref, 'last', check_op_attrs=True)
+        self.assertTrue(flag, resp)
diff --git a/model-optimizer/extensions/front/tf/CTCLoss_ext.py b/model-optimizer/extensions/front/tf/CTCLoss_ext.py
new file mode 100644 (file)
index 0000000..97800c5
--- /dev/null
@@ -0,0 +1,33 @@
+"""
+ Copyright (C) 2020 Intel Corporation
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+from extensions.ops.ctc_loss import CTCLoss
+from mo.front.extractor import FrontExtractorOp
+
+
+class CTCLossFrontExtractor(FrontExtractorOp):
+    op = 'CTCLoss'
+    enabled = True
+
+    @classmethod
+    def extract(cls, node):
+        attrs = {
+            'ctc_merge_repeated': node.pb.attr['ctc_merge_repeated'].b,
+            'preprocess_collapse_repeated': node.pb.attr['preprocess_collapse_repeated'].b,
+            # unique is always false for CTCLoss V1
+            'unique': False
+        }
+        CTCLoss.update_node_stat(node, attrs)
+        return cls.enabled
index 67d2c0e295ffd6fc804efbf004721dff3746c3e4..de51c706585c64e6336f31540216fb7d9dc5fa66 100644 (file)
@@ -34,8 +34,9 @@ class SparseToDenseReplacer(FrontReplacementOp):
     enabled = True
 
     def run_after(self):
-        from extensions.front.tf.CTCGreedyDecoder import CTCGreedyDecoderReplacement
-        return [CTCGreedyDecoderReplacement]
+        from extensions.front.tf.CTCGreedyDecoderReplacement import CTCGreedyDecoderReplacement
+        from extensions.front.tf.CTCLossReplacement import CTCLossReplacement
+        return [CTCGreedyDecoderReplacement, CTCLossReplacement]
 
     def replace_op(self, graph: Graph, node: Node):
         node_name = node.soft_get('name', node.id)
index 9262eac5bfaa67419bf986883afd03090000f80b..fbbb252ca524c0cbf14af784dc9c78fa10979b0e 100644 (file)
@@ -14,8 +14,7 @@
  limitations under the License.
 """
 
-import numpy as np
-
+from mo.front.common.partial_infer.utils import int64_array
 from mo.graph.graph import Node, Graph
 from mo.ops.op import Op
 
@@ -25,12 +24,15 @@ class CTCGreedyDecoderOp(Op):
 
     def __init__(self, graph: Graph, attrs: dict):
         mandatory_props = {
-            'type': __class__.op,
-            'op': __class__.op,
+            'type': self.op,
+            'op': self.op,
             'version': 'opset1',
+
+            'infer': self.infer,
+            'reinterp_shape': True,
+
             'in_ports_count': 2,
-            'out_ports_count': 1,
-            'infer': CTCGreedyDecoderOp.ctc_greedy_decoder_infer
+            'out_ports_count': 1
         }
         super().__init__(graph, mandatory_props, attrs)
 
@@ -40,11 +42,23 @@ class CTCGreedyDecoderOp(Op):
         ]
 
     @staticmethod
-    def ctc_greedy_decoder_infer(node: Node):
-        outn = node.out_node(0)
-        inn = node.in_node(0)
-        inn2 = node.in_node(1)
-        outn.shape = np.ones(4, dtype=np.int)
-        assert inn.shape[1] == inn2.shape[1], 'Batch for CTCGreedyDecoder should be the same in both inputs'
-        outn.shape[0] = inn.shape[1]
-        outn.shape[1] = inn.shape[0]
+    def infer(node: Node):
+        node_name = node.soft_get('name', node.id)
+        connected_in_ports = [port for port in node.in_ports().values() if not port.disconnected()]
+        assert len(connected_in_ports) == 2, \
+            "Incorrect number of inputs for {} node".format(node_name)
+
+        logits_shape = node.in_port(0).data.get_shape()
+        sequence_mask_shape = node.in_port(1).data.get_shape()
+
+        # check shapes of input tensors
+        assert len(logits_shape) == 3 and len(sequence_mask_shape) == 2, \
+            'Incorrect rank of some input tensor for {} node'.format(node_name)
+        assert logits_shape[1] == sequence_mask_shape[1], \
+            'Batch dimensions of input tensors must be the same for {} node'.format(node_name)
+        assert logits_shape[0] == sequence_mask_shape[0], \
+            'Time dimensions of input tensors must be the same for {} node'.format(node_name)
+
+        batch_size = logits_shape[1]
+        time_size = logits_shape[0]
+        node.out_port(0).data.set_shape(int64_array([batch_size, time_size, 1, 1]))
index cf65e677f83e98cdc40926209bfc2ef092fbac63..b8bf3038e4ed21e0f48dc103386e6b99ec62484e 100644 (file)
@@ -19,36 +19,77 @@ import unittest
 import numpy as np
 
 from extensions.ops.ctc_greedy_decoder import CTCGreedyDecoderOp
+from mo.front.common.partial_infer.utils import int64_array
 from mo.graph.graph import Node
 from mo.utils.unittest.graph import build_graph
 
-nodes_attributes = {'node_1': {'type': 'Identity', 'kind': 'op'},
-                    'node_2': {'type': 'Identity', 'kind': 'op'},
-                    'ctc': {'type': 'CTCGreedyDecoder', 'kind': 'op'},
-                    'node_3': {'type': 'Identity', 'kind': 'op'},
-                    'op_output': { 'kind': 'op', 'op': 'Result'},
-                    }
-
-
-class TestConcatPartialInfer(unittest.TestCase):
-    def test_tf_concat_infer(self):
-        graph = build_graph(nodes_attributes,
-                            [
-                                ('node_1', 'ctc'),
-                                ('node_2', 'ctc'),
-                                ('ctc', 'node_3'),
-                                ('node_3', 'op_output')
-                            ],
-                            {
-                                'node_3': {'shape': None},
-                                'node_1': {'shape': np.array([88, 2, 71])},
-                                'node_2': {'shape': np.array([88, 2])},
-                                'ctc': {'ctc_merge_repeated': 1}
-                            })
-
-        ctc_node = Node(graph, 'ctc')
-        CTCGreedyDecoderOp.ctc_greedy_decoder_infer(ctc_node)
-        exp_shape = np.array([2, 88, 1, 1])
-        res_shape = graph.node['node_3']['shape']
-        for i in range(0, len(exp_shape)):
-            self.assertEqual(exp_shape[i], res_shape[i])
+
+nodes_attributes = {'logits': {'kind': 'op'},
+                    'logits_data': {'shape': None, 'value': None, 'kind': 'data'},
+                    'seq_mask': {'kind': 'op'},
+                    'seq_mask_data': {'shape': None, 'value': None, 'kind': 'data'},
+                    'ctcgreedydecoder_node': {'op': 'CTCGreedyDecoder', 'kind': 'op',
+                                              'ctc_merge_repeated': True},
+                    'output': {'shape': None, 'value': None, 'kind': 'data'}}
+
+# graph 1
+edges1 = [('logits', 'logits_data'),
+          ('seq_mask', 'seq_mask_data'),
+          ('logits_data', 'ctcgreedydecoder_node', {'in': 0}),
+          ('seq_mask_data', 'ctcgreedydecoder_node', {'in': 1}),
+          ('ctcgreedydecoder_node', 'output', {'out': 0})]
+
+# valid test case
+inputs1 = {'logits_data': {'shape': int64_array([100, 4, 5])},
+           'seq_mask_data': {'shape': int64_array([100, 4])}}
+
+# invalid test case with incorrect rank for the first input tensor
+inputs1_inv = {'logits_data': {'shape': int64_array([100, 4, 5, 6])},
+               'seq_mask_data': {'shape': int64_array([100, 4])}}
+
+# invalid test case with incorrect rank for the second input tensor
+inputs2_inv = {'logits_data': {'shape': int64_array([100, 4, 5])},
+               'seq_mask_data': {'shape': int64_array([100])}}
+
+# invalid test case with incorrect time dimension
+inputs3_inv = {'logits_data': {'shape': int64_array([100, 4, 5])},
+               'seq_mask_data': {'shape': int64_array([101, 4])}}
+
+# invalid test case with incorrect batch dimension
+inputs4_inv = {'logits_data': {'shape': int64_array([100, 4, 5])},
+               'seq_mask_data': {'shape': int64_array([100, 14])}}
+
+class TestCTCGreedyDecoder(unittest.TestCase):
+    def test_infer1(self):
+        graph = build_graph(nodes_attributes, edges1, inputs1)
+        ctcgreedydecoder_node = Node(graph, 'ctcgreedydecoder_node')
+        CTCGreedyDecoderOp.infer(ctcgreedydecoder_node)
+
+        # prepare reference results
+        ref_output_shape = int64_array([4, 100, 1, 1])
+
+        # get the result
+        res_output_shape = graph.node['output']['shape']
+
+        self.assertTrue(np.array_equal(ref_output_shape, res_output_shape),
+                        'shapes do not match expected: {} and given: {}'.format(ref_output_shape, res_output_shape))
+
+    def test_infer_invalid1(self):
+        graph = build_graph(nodes_attributes, edges1, inputs1_inv)
+        ctcgreedydecoder_node = Node(graph, 'ctcgreedydecoder_node')
+        self.assertRaises(AssertionError, CTCGreedyDecoderOp.infer, ctcgreedydecoder_node)
+
+    def test_infer_invalid2(self):
+        graph = build_graph(nodes_attributes, edges1, inputs2_inv)
+        ctcgreedydecoder_node = Node(graph, 'ctcgreedydecoder_node')
+        self.assertRaises(AssertionError, CTCGreedyDecoderOp.infer, ctcgreedydecoder_node)
+
+    def test_infer_invalid3(self):
+        graph = build_graph(nodes_attributes, edges1, inputs3_inv)
+        ctcgreedydecoder_node = Node(graph, 'ctcgreedydecoder_node')
+        self.assertRaises(AssertionError, CTCGreedyDecoderOp.infer, ctcgreedydecoder_node)
+
+    def test_infer_invalid4(self):
+        graph = build_graph(nodes_attributes, edges1, inputs4_inv)
+        ctcgreedydecoder_node = Node(graph, 'ctcgreedydecoder_node')
+        self.assertRaises(AssertionError, CTCGreedyDecoderOp.infer, ctcgreedydecoder_node)
diff --git a/model-optimizer/extensions/ops/ctc_loss.py b/model-optimizer/extensions/ops/ctc_loss.py
new file mode 100644 (file)
index 0000000..6004c63
--- /dev/null
@@ -0,0 +1,89 @@
+"""
+ Copyright (C) 2020 Intel Corporation
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+import numpy as np
+
+from mo.front.common.partial_infer.utils import int64_array
+from mo.graph.graph import Node, Graph
+from mo.ops.op import Op
+
+
+class CTCLoss(Op):
+    op = 'CTCLoss'
+
+    def __init__(self, graph: Graph, attrs: dict):
+        mandatory_props = {
+            'type': self.op,
+            'op': self.op,
+            'version': 'opset4',
+
+            'type_infer': self.type_infer,
+            'infer': self.infer,
+
+            'in_ports_count': 5,
+            'out_ports_count': 1,
+        }
+        super().__init__(graph, mandatory_props, attrs)
+
+    def backend_attrs(self):
+        return ['preprocess_collapse_repeated', 'ctc_merge_repeated', 'unique']
+
+    @staticmethod
+    def type_infer(node):
+        logits_type = node.in_port(0).get_data_type()
+        logit_length_type = node.in_port(1).get_data_type()
+        labels_type = node.in_port(2).get_data_type()
+        label_length_type = node.in_port(3).get_data_type()
+        blank_index_type = labels_type
+        if not node.in_port(4).disconnected():
+            blank_index_type = node.in_port(4).get_data_type()
+
+        assert logit_length_type == label_length_type and logit_length_type in [np.int64, np.int32], \
+            'Inputs with logits and labels lengths for node {} must be the same and int32 or int64, {} and {} found'.format(
+                node.soft_get('name'), logit_length_type, label_length_type)
+        assert labels_type == blank_index_type and labels_type in [np.int64, np.int32], \
+            'Inputs with labels and blank index for node {} must be the same and int32 or int64, {} and {} found'.format(
+                node.soft_get('name'), labels_type, blank_index_type)
+
+        node.out_port(0).set_data_type(logits_type)
+
+    @staticmethod
+    def infer(node: Node):
+        node_name = node.soft_get('name', node.id)
+        connected_in_ports = [port for port in node.in_ports().values() if not port.disconnected()]
+        assert len(connected_in_ports) in [4, 5], \
+            "Incorrect number of inputs for {} node".format(node_name)
+
+        logits_shape = node.in_port(0).data.get_shape()
+        logit_length_shape = node.in_port(1).data.get_shape()
+        labels_shape = node.in_port(2).data.get_shape()
+        label_length_shape = node.in_port(3).data.get_shape()
+        blank_index_shape = int64_array([])
+        if len(node.in_nodes()) == 5:
+            blank_index_shape = node.in_port(4).data.get_shape()
+
+        # check shapes of input tensors
+        assert len(logits_shape) == 3 and len(logit_length_shape) == 1 and len(labels_shape) == 2\
+            and len(label_length_shape) == 1 and len(blank_index_shape) == 0, \
+            'Incorrect rank of some input tensor for {} node'.format(node_name)
+        assert logits_shape[0] == logit_length_shape[0] and logits_shape[0] == labels_shape[0]\
+            and logits_shape[0] == label_length_shape[0], \
+            'Batch dimensions of input tensors must be the same for {} node'.format(node_name)
+        assert logits_shape[1] == labels_shape[1], \
+            'Time dimensions of input tensors must be the same for {} node'.format(node_name)
+
+        batch_size = logits_shape[0]
+        node.out_port(0).data.set_shape(int64_array([batch_size]))
diff --git a/model-optimizer/extensions/ops/ctc_loss_test.py b/model-optimizer/extensions/ops/ctc_loss_test.py
new file mode 100644 (file)
index 0000000..7d53804
--- /dev/null
@@ -0,0 +1,97 @@
+"""
+ Copyright (C) 2020 Intel Corporation
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+import unittest
+
+import numpy as np
+
+from extensions.ops.ctc_loss import CTCLoss
+from mo.front.common.partial_infer.utils import int64_array
+from mo.graph.graph import Node
+from mo.utils.unittest.graph import build_graph
+
+nodes_attributes = {'logits': {'kind': 'op'},
+                    'logits_data': {'shape': None, 'value': None, 'kind': 'data'},
+                    'logit_length': {'kind': 'op'},
+                    'logit_length_data': {'shape': None, 'value': None, 'kind': 'data'},
+                    'labels': {'kind': 'op'},
+                    'labels_data': {'shape': None, 'value': None, 'kind': 'data'},
+                    'label_length': {'kind': 'op'},
+                    'label_length_data': {'shape': None, 'value': None, 'kind': 'data'},
+                    'blank_index': {'kind': 'op'},
+                    'blank_index_data': {'shape': None, 'value': None, 'kind': 'data'},
+                    'ctcloss_node': {'op': 'CTCLoss', 'kind': 'op', 'preprocess_collapse_repeated': False,
+                                     'ctc_merge_repeated': True, 'unique': False},
+                    'output': {'shape': None, 'value': None, 'kind': 'data'}}
+
+# graph 1
+edges1 = [('logits', 'logits_data'),
+          ('logit_length', 'logit_length_data'),
+          ('labels', 'labels_data'),
+          ('label_length', 'label_length_data'),
+          ('blank_index', 'blank_index_data'),
+          ('logits_data', 'ctcloss_node', {'in': 0}),
+          ('logit_length_data', 'ctcloss_node', {'in': 1}),
+          ('labels_data', 'ctcloss_node', {'in': 2}),
+          ('label_length_data', 'ctcloss_node', {'in': 3}),
+          ('blank_index_data', 'ctcloss_node', {'in': 4}),
+          ('ctcloss_node', 'output', {'out': 0})]
+
+# valid test case
+inputs1 = {'logits_data': {'shape': int64_array([4, 100, 5])},
+           'logit_length_data': {'shape': int64_array([4])},
+           'labels_data': {'shape': int64_array([4, 100])},
+           'label_length_data': {'shape': int64_array([4])},
+           'blank_index_data': {'shape': int64_array([])}}
+
+# invalid test case with incorrect rank for the second input tensor
+inputs2 = {'logits_data': {'shape': int64_array([4, 100, 5])},
+           'logit_length_data': {'shape': int64_array([4, 3])},
+           'labels_data': {'shape': int64_array([4, 100])},
+           'label_length_data': {'shape': int64_array([4])},
+           'blank_index_data': {'shape': int64_array([])}}
+
+# invalid test case with incorrect time dimension
+inputs3 = {'logits_data': {'shape': int64_array([4, 100, 5])},
+           'logit_length_data': {'shape': int64_array([4])},
+           'labels_data': {'shape': int64_array([4, 300])},
+           'label_length_data': {'shape': int64_array([4])},
+           'blank_index_data': {'shape': int64_array([])}}
+
+class TestCTCLoss(unittest.TestCase):
+    def test_infer1(self):
+        graph = build_graph(nodes_attributes, edges1, inputs1)
+        ctc_loss_node = Node(graph, 'ctcloss_node')
+        CTCLoss.infer(ctc_loss_node)
+
+        # prepare reference results
+        ref_output_shape = int64_array([4])
+
+        # get the result
+        res_output_shape = graph.node['output']['shape']
+
+        self.assertTrue(np.array_equal(ref_output_shape, res_output_shape),
+                        'shapes do not match expected: {} and given: {}'.format(ref_output_shape, res_output_shape))
+
+    def test_infer_invalid1(self):
+        graph = build_graph(nodes_attributes, edges1, inputs2)
+        ctc_loss_node = Node(graph, 'ctcloss_node')
+        self.assertRaises(AssertionError, CTCLoss.infer, ctc_loss_node)
+
+    def test_infer_invalid2(self):
+        graph = build_graph(nodes_attributes, edges1, inputs3)
+        ctc_loss_node = Node(graph, 'ctcloss_node')
+        self.assertRaises(AssertionError, CTCLoss.infer, ctc_loss_node)