Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / front / tf / CTCGreedyDecoder.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import numpy as np
18
19 from mo.front.common.partial_infer.utils import int64_array
20 from mo.front.common.replacement import FrontReplacementSubgraph
21 from mo.graph.graph import Node, Graph
22 from mo.utils.error import Error
23
24
25 class CTCGreedyDecoderReplacement(FrontReplacementSubgraph):
26     """
27     The TF implementation of the CTCGreedyDecoder produces a tuple with two tensors. The first element in the tuple is
28     the SparseTensor which is converted to a regular tensor with the SparseToDense operation. This replacer matches
29     CTCGreedyDecoder and SparseToDense operations and removes the SparseToDense and Cast operation which is also used
30     in the SparseToDense operation, because Inference Engine implementation of the CTCGreedyDecoder produces regular
31     tensor as output.
32
33     The second input to the CTCGreedyDecoder in the TensorFlow is a 1D tensor with sequence lengths. In the Inference
34     Engine the second input to the CTCGreedyDecoder is a 2D tensor where the first element in each row is equal to 0
35     and all others are equal to 1. The length of the row is equal to the sequence length. The replacer modifies the
36     second input to be compatible with the Inference Engine CTCGreedyDecoder layer implementation.
37     """
38     enabled = True
39
40     @staticmethod
41     def pattern(**kwargs):
42         return dict(
43             nodes=[
44                 ('decoder', dict(op='CTCGreedyDecoder')),
45                 ('cast', dict(op='Cast')),
46                 ('sparse_to_dense', dict(op='SparseToDense')),
47             ],
48             edges=[
49                 ('decoder', 'sparse_to_dense', {'out': 0}),
50                 ('decoder', 'cast', {'out': 1}),
51                 ('cast', 'sparse_to_dense', {'out': 0}),
52             ]
53         )
54
55     def nodes_to_remove(self, graph: Graph, match: dict):
56         return [match['cast'].id, match['sparse_to_dense']]
57
58     def replace_sub_graph(self, graph: Graph, match: dict):
59         decoder_node = match['decoder']
60         graph.remove_edge(decoder_node.id, match['sparse_to_dense'].id)
61         graph.remove_edge(decoder_node.id, match['cast'].id)
62         match['sparse_to_dense'].replace_node(decoder_node)
63
64         # update the TensorFlow infer function for the CTCGreedyDecoder to make necessary changes with the second input
65         decoder_node['old_infer'] = decoder_node.infer
66         decoder_node.infer = __class__.tf_greedy_decoder_infer
67         return {}
68
69     @staticmethod
70     def tf_greedy_decoder_infer(node: Node):
71         sequence_length_node = node.in_node(1)
72         if sequence_length_node.value is None:
73             raise Error('The second input to the CTCGreedyDecoder node "{}" is not constant. This case is not '
74                         'supported with the Inference Engine.'.format(node.soft_get('name')))
75         # the batch size is the dimension with index 1 for the layer CTCGreedyDecoder
76         new_value = np.ones([node.in_node(0).shape[1], sequence_length_node.value[0]])
77         new_value[:, 0] = 0
78         new_value = np.transpose(new_value)
79         sequence_length_node.value = new_value
80         sequence_length_node.shape = int64_array(sequence_length_node.value.shape)
81
82         node.old_infer(node)