2 Copyright (c) 2018-2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
import numpy as np

from mo.front.common.partial_infer.utils import int64_array
from mo.front.common.replacement import FrontReplacementSubgraph
from mo.graph.graph import Node, Graph
from mo.utils.error import Error
class CTCGreedyDecoderReplacement(FrontReplacementSubgraph):
    """
    The TF implementation of the CTCGreedyDecoder produces a tuple with two tensors. The first element in the tuple is
    the SparseTensor which is converted to a regular tensor with the SparseToDense operation. This replacer matches
    CTCGreedyDecoder and SparseToDense operations and removes the SparseToDense and Cast operation which is also used
    in the SparseToDense operation, because Inference Engine implementation of the CTCGreedyDecoder produces a regular
    tensor as output.

    The second input to the CTCGreedyDecoder in the TensorFlow is a 1D tensor with sequence lengths. In the Inference
    Engine the second input to the CTCGreedyDecoder is a 2D tensor where the first element in each row is equal to 0
    and all others are equal to 1. The length of the row is equal to the sequence length. The replacer modifies the
    second input to be compatible with the Inference Engine CTCGreedyDecoder layer implementation.
    """

    @staticmethod
    def pattern(**kwargs):
        """Describe the matched subgraph: CTCGreedyDecoder feeding SparseToDense both directly and through a Cast."""
        return dict(
            nodes=[
                ('decoder', dict(op='CTCGreedyDecoder')),
                ('cast', dict(op='Cast')),
                ('sparse_to_dense', dict(op='SparseToDense')),
            ],
            edges=[
                # decoder port 0 goes straight into SparseToDense
                ('decoder', 'sparse_to_dense', {'out': 0}),
                # decoder port 1 reaches SparseToDense through the Cast
                ('decoder', 'cast', {'out': 1}),
                ('cast', 'sparse_to_dense', {'out': 0}),
            ],
        )

    def nodes_to_remove(self, graph: Graph, match: dict):
        """Return the ids of the matched Cast and SparseToDense nodes, which this replacer drops.

        :param graph: the graph being transformed (unused).
        :param match: mapping from pattern node names to matched nodes.
        :return: list of node ids (both entries are ids, consistent with the id-based graph calls below).
        """
        return [match['cast'].id, match['sparse_to_dense'].id]

    def replace_sub_graph(self, graph: Graph, match: dict):
        """Cut the decoder loose from Cast/SparseToDense and install the Inference Engine-compatible infer function.

        :param graph: the graph being transformed.
        :param match: mapping from pattern node names to matched nodes.
        """
        decoder_node = match['decoder']
        graph.remove_edge(decoder_node.id, match['sparse_to_dense'].id)
        graph.remove_edge(decoder_node.id, match['cast'].id)
        # presumably reroutes the SparseToDense consumers to the decoder node (see mo Node.replace_node)
        match['sparse_to_dense'].replace_node(decoder_node)

        # update the TensorFlow infer function for the CTCGreedyDecoder to make necessary changes with the second input;
        # the original infer is kept in 'old_infer' so that the replacement can chain to it
        decoder_node['old_infer'] = decoder_node.infer
        decoder_node.infer = CTCGreedyDecoderReplacement.tf_greedy_decoder_infer

    @staticmethod
    def tf_greedy_decoder_infer(node: Node):
        """Rewrite the constant sequence-length input into the Inference Engine mask, then run the saved infer.

        The sequence-length input is replaced by a 2D array of shape [seq_len, batch]: for every batch entry the
        first step is 0 and all following steps are 1, where seq_len is the first value of the original input.

        :param node: the CTCGreedyDecoder node; dimension 1 of its input 0 is taken as the batch size.
        :raises Error: if the sequence-length input (input 1) is not a constant.
        """
        sequence_length_node = node.in_node(1)
        if sequence_length_node.value is None:
            raise Error('The second input to the CTCGreedyDecoder node "{}" is not constant. This case is not '
                        'supported with the Inference Engine.'.format(node.soft_get('name')))
        # the batch size is the dimension with index 1 for the layer CTCGreedyDecoder
        batch_size = node.in_node(0).shape[1]
        # NOTE(review): only the first sequence length is used, so all batch entries are assumed to share it
        new_value = np.ones([batch_size, sequence_length_node.value[0]])
        # each row is one sequence: its first element must be 0 and the rest 1
        new_value[:, 0] = 0
        new_value = np.transpose(new_value)
        sequence_length_node.value = new_value
        sequence_length_node.shape = int64_array(sequence_length_node.value.shape)

        # chain to the original TensorFlow infer saved by replace_sub_graph so the decoder outputs get inferred
        node.old_infer(node)