[ MO ] KSO=ON for Kaldi (#2028)
author Evgenya Stepyreva <evgenya.stepyreva@intel.com>
Fri, 23 Oct 2020 10:14:00 +0000 (13:14 +0300)
committer GitHub <noreply@github.com>
Fri, 23 Oct 2020 10:14:00 +0000 (13:14 +0300)
* [ MO ] KSO=ON for Kaldi

* [ MO ] Kaldi KSO

* set static_shape in the transformation that makes the graph cyclic

model-optimizer/extensions/front/kaldi/add_reshape_around_convolution.py
model-optimizer/extensions/front/kaldi/add_reshape_around_pooling.py
model-optimizer/extensions/middle/ReplaceMemoryOffsetWithSplice.py
model-optimizer/mo/main.py

diff --git a/model-optimizer/extensions/front/kaldi/add_reshape_around_convolution.py b/model-optimizer/extensions/front/kaldi/add_reshape_around_convolution.py
index 15bf12b..8173328 100644
  See the License for the specific language governing permissions and
  limitations under the License.
 """
-from extensions.ops.elementwise import Mul, Pow
-from extensions.ops.split import VariadicSplit
-from mo.front.common.partial_infer.utils import int64_array
+import numpy as np
+
+from extensions.ops.Cast import Cast
+from extensions.ops.elementwise import Div
+from mo.front.common.partial_infer.utils import int64_array, float_array
 from mo.front.common.replacement import FrontReplacementPattern
 from mo.front.tf.graph_utils import create_op_with_const_inputs, create_op_node_with_second_input
 from mo.graph.graph import Graph
+from mo.middle.passes.convert_data_type import data_type_str_to_np
 from mo.ops.concat import Concat
-from mo.ops.const import Const
 from mo.ops.reshape import Reshape
 from mo.ops.shape import Shape
+from mo.utils.shape import node_to_get_shape_value_of_indices
 
 
 class ReplaceConvolutionReshape(FrontReplacementPattern):
@@ -57,30 +60,28 @@ class ReplaceConvolutionReshape(FrontReplacementPattern):
 
         # create Reshape before convolution
         # shape = [in_shape[0], in_shape[1]/patch_stride, 1, patch_stride]
-        shape = Shape(graph, {'name': node_name + '/Shape'}).create_node()
-        shape.in_port(0).connect(node.in_port(0).get_source())
-
-        split = create_op_with_const_inputs(graph, VariadicSplit, {1: int64_array(0), 2: int64_array([1, -1])},
-                                            {'name': shape.name + '/split_batch', 'out_ports_count': 2}, shape)
+        i_shape = Shape(graph, {'name': node_name + '/Shape'}).create_node()
+        shape = Cast(graph, {'name': node_name + '/to_float',
+                             'dst_type': data_type_str_to_np(graph.graph['cmd_params'].data_type)}).create_node()
+        i_shape.in_port(0).connect(node.in_port(0).get_source())
+        shape.in_port(0).connect(i_shape.out_port(0))
 
-        pow_node = create_op_node_with_second_input(graph, Pow, int64_array([-1]), {'name': node_name + '/patch_stride/inverse'})
-        conv_patch_stride = Const(graph, {'value': int64_array([node.patch_stride]),
-                                          'name': node_name + '/patch_stride/'}).create_node()
-        pow_node.in_port(0).connect(conv_patch_stride.out_port(0))
+        N, H = node_to_get_shape_value_of_indices(shape, [0]), node_to_get_shape_value_of_indices(shape, [1])
 
-        mul = Mul(graph, {'name': node_name + '/mul_inverse_stride_h'}).create_node()
-        mul.in_port(0).connect(split.out_port(1))
-        mul.in_port(1).connect(pow_node.out_port(0))
+        div = create_op_with_const_inputs(
+            graph, Div, {1: float_array([node.patch_stride])}, {'name': node_name + '/div_stride_h'})
+        div.in_port(0).connect(H.out_port(0))
 
-        concat = create_op_with_const_inputs(graph, Concat, {2: int64_array([1])},
+        concat = create_op_with_const_inputs(graph, Concat, {2: float_array([1]), 3: float_array([node.patch_stride])},
                                              {'name': node_name + '/concat_all_dims', 'in_ports_count': 4, 'axis': 0})
+        concat.in_port(0).connect(N.out_port(0))
+        concat.in_port(1).connect(div.out_port(0))
 
-        concat.in_port(0).connect(split.out_port(0))
-        concat.in_port(1).connect(mul.out_port(0))
-        concat.in_port(3).connect(conv_patch_stride.out_port(0))
+        reshape_pattern = Cast(graph, {'name': node_name + '/to_int', 'dst_type': np.int64}).create_node()
+        concat.out_port(0).connect(reshape_pattern.in_port(0))
 
         reshape_in = Reshape(graph, {'name': node_name + '/reshape_in'}).create_node()
-        reshape_in.in_port(1).connect(concat.out_port(0))
+        reshape_in.in_port(1).connect(reshape_pattern.out_port(0))
 
         # create Reshape after Convolution
         reshape_out = create_op_node_with_second_input(graph, Reshape, int64_array([0, -1]),
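For reference, here is a minimal NumPy sketch (plain Python, not MO code; the input shape and stride are made-up values) of the target shape the rewritten subgraph computes: cast the Shape result to float, gather dims 0 and 1, divide dim 1 by patch_stride, concatenate all four dims, and cast back to int64 for Reshape. The float detour presumably avoids the old integer Pow(stride, -1), which is ill-defined in int64 arithmetic (NumPy rejects negative integer powers) once the shape subgraph is kept in the IR.

    import numpy as np

    # Hedged sketch of the '/Shape -> /to_float -> /div_stride_h ->
    # /concat_all_dims -> /to_int' chain above; values chosen so the
    # division is exact.
    in_shape = np.int64([2, 40])               # [N, H] of the Convolution input
    patch_stride = 8

    shape_f = in_shape.astype(np.float32)      # Cast '/to_float'
    N, H = shape_f[0:1], shape_f[1:2]          # node_to_get_shape_value_of_indices
    div = H / np.float32([patch_stride])       # Div '/div_stride_h'
    target = np.concatenate([N, div,           # Concat '/concat_all_dims'
                             np.float32([1.0]),
                             np.float32([patch_stride])])
    reshape_pattern = target.astype(np.int64)  # Cast '/to_int'
    assert (reshape_pattern == np.int64([2, 5, 1, 8])).all()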
diff --git a/model-optimizer/extensions/front/kaldi/add_reshape_around_pooling.py b/model-optimizer/extensions/front/kaldi/add_reshape_around_pooling.py
index 1978790..68c5629 100644
  See the License for the specific language governing permissions and
  limitations under the License.
 """
-from extensions.ops.elementwise import Mul, Pow
-from extensions.ops.split import VariadicSplit
-from mo.front.common.partial_infer.utils import int64_array
+import numpy as np
+
+from extensions.ops.Cast import Cast
+from extensions.ops.elementwise import Div
+from mo.front.common.partial_infer.utils import int64_array, float_array
 from mo.front.common.replacement import FrontReplacementPattern
 from mo.front.tf.graph_utils import create_op_node_with_second_input, create_op_with_const_inputs
 from mo.graph.graph import Graph
+from mo.middle.passes.convert_data_type import data_type_str_to_np
 from mo.ops.concat import Concat
-from mo.ops.const import Const
 from mo.ops.reshape import Reshape
 from mo.ops.shape import Shape
+from mo.utils.shape import node_to_get_shape_value_of_indices
 
 
 class ReplacePoolingReshape(FrontReplacementPattern):
@@ -55,31 +58,29 @@ class ReplacePoolingReshape(FrontReplacementPattern):
             node.stride = int64_array([1, 1, node.window[-1], node.window[-1]])
 
         # create Reshape before pooling
-        # shape = [in_shape[0], in_shape[1]/patch_stride, 1, patch_stride]
-        shape = Shape(graph, {'name': node_name + '/Shape'}).create_node()
-        shape.in_port(0).connect(node.in_port(0).get_source())
-
-        split = create_op_with_const_inputs(graph, VariadicSplit, {1: int64_array(0), 2: int64_array([1, -1])},
-                                            {'name': shape.name + '/split_batch', 'out_ports_count': 2}, shape)
+        # shape = [in_shape[0], pool_stride, 1, in_shape[1]/pool_stride]
+        i_shape = Shape(graph, {'name': node_name + '/Shape'}).create_node()
+        shape = Cast(graph, {'name': node_name + '/to_float',
+                             'dst_type': data_type_str_to_np(graph.graph['cmd_params'].data_type)}).create_node()
+        i_shape.in_port(0).connect(node.in_port(0).get_source())
+        shape.in_port(0).connect(i_shape.out_port(0))
 
-        pow_node = create_op_node_with_second_input(graph, Pow, int64_array([-1]), {'name': node_name + '/pool_stride/inverse'})
-        node_pool_stride = Const(graph, {'value': int64_array([node.pool_stride]),
-                                         'name': node_name + '/pool_stride/'}).create_node()
-        pow_node.in_port(0).connect(node_pool_stride.out_port(0))
+        N, H = node_to_get_shape_value_of_indices(shape, [0]), node_to_get_shape_value_of_indices(shape, [1])
 
-        mul = Mul(graph, {'name': node_name + '/mul_inverse_stride_h'}).create_node()
-        mul.in_port(0).connect(split.out_port(1))
-        mul.in_port(1).connect(pow_node.out_port(0))
+        div = create_op_with_const_inputs(
+            graph, Div, {1: float_array([node.pool_stride])}, {'name': node_name + '/div_stride_h'})
+        div.in_port(0).connect(H.out_port(0))
 
-        concat = create_op_with_const_inputs(graph, Concat, {2: int64_array([1])},
+        concat = create_op_with_const_inputs(graph, Concat, {1: float_array([node.pool_stride]), 2: float_array([1])},
                                              {'name': node_name + '/concat_all_dims', 'in_ports_count': 4, 'axis': 0})
+        concat.in_port(0).connect(N.out_port(0))
+        concat.in_port(3).connect(div.out_port(0))
 
-        concat.in_port(0).connect(split.out_port(0))
-        concat.in_port(3).connect(mul.out_port(0))
-        concat.in_port(1).connect(node_pool_stride.out_port(0))
+        reshape_pattern = Cast(graph, {'name': node_name + '/to_int', 'dst_type': np.int64}).create_node()
+        concat.out_port(0).connect(reshape_pattern.in_port(0))
 
         reshape_in = Reshape(graph, {'name': node_name + '/reshape_in'}).create_node()
-        reshape_in.in_port(1).connect(concat.out_port(0))
+        reshape_in.in_port(1).connect(reshape_pattern.out_port(0))
 
         # create Reshape after Pooling
         reshape_out = create_op_node_with_second_input(graph, Reshape, int64_array([0, -1]),
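The pooling variant builds the same four dims in the different order given by the updated comment, with the stride in dim 1 and the divided height in dim 3. A matching sketch (same made-up values as above):

    import numpy as np

    # Only the Concat order differs from the convolution case.
    in_shape, pool_stride = np.int64([2, 40]), 8
    N, H = in_shape[0:1].astype(np.float32), in_shape[1:2].astype(np.float32)
    target = np.concatenate([N, np.float32([pool_stride]),
                             np.float32([1.0]), H / np.float32(pool_stride)])
    assert (target.astype(np.int64) == np.int64([2, 8, 1, 5])).all()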
diff --git a/model-optimizer/extensions/middle/ReplaceMemoryOffsetWithSplice.py b/model-optimizer/extensions/middle/ReplaceMemoryOffsetWithSplice.py
index 552f816..48d9ba0 100644
@@ -14,6 +14,7 @@
  limitations under the License.
 """
 import numpy as np
+import logging as log
 
 from extensions.front.kaldi.replace_lstm_node_pattern import create_zero_value_with_batch_from_input
 from extensions.ops.splice import Splice
@@ -174,6 +175,14 @@ class ReplaceMemoryOffsetWithMemoryNodePattern(MiddleReplacementPattern):
             memory_in.out_port(0).connect(out.in_port(0))
             out_port.get_connection().set_source(memory_out.out_port(0))
 
+        if not graph.graph['cmd_params'].static_shape:
+            log.error(
+                "The model cannot be converted in a reshape-able way.\n"
+                "The Model Optimizer --static_shape option was turned on to prevent related errors.\n"
+                "Changing the input shapes of the model with the InferenceEngine reshape method "
+                "will not succeed", extra={'is_warning': True})
+            graph.graph['cmd_params'].static_shape = True
+
         graph.remove_node(op_output_id)
         graph.remove_node(node.id)
         graph.remove_node(pair_node.id)
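The hunk above logs at error level but tags the record with extra={'is_warning': True}, which MO's logging setup uses to present the record as a warning instead of failing conversion. A hedged sketch of a filter honoring such a flag (MO's actual handler differs in detail; the class name is hypothetical):

    import logging as log

    class DowngradeIsWarning(log.Filter):
        # Records tagged with extra={'is_warning': True} are reported
        # at WARNING level instead of their original ERROR level.
        def filter(self, record):
            if getattr(record, 'is_warning', False):
                record.levelno = log.WARNING
                record.levelname = 'WARNING'
            return True

    log.getLogger().addFilter(DowngradeIsWarning())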
diff --git a/model-optimizer/mo/main.py b/model-optimizer/mo/main.py
index 79d9153..59c4fe1 100644
@@ -225,7 +225,6 @@ def prepare_ir(argv: argparse.Namespace):
         from mo.front.mxnet.register_custom_ops import get_front_classes
         import_extensions.load_dirs(argv.framework, extensions, get_front_classes)
     elif is_kaldi:
-        argv.static_shape = True
         from mo.front.kaldi.register_custom_ops import get_front_classes
         import_extensions.load_dirs(argv.framework, extensions, get_front_classes)
     elif is_onnx:
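With the unconditional argv.static_shape = True removed here, static_shape is forced only when the MemoryOffset replacement above actually fires, so Kaldi models without such pairs stay reshape-able. An illustrative use of the resulting IR (2020-era Inference Engine Python API; model file names are made up):

    from openvino.inference_engine import IECore

    ie = IECore()
    net = ie.read_network(model='kaldi_model.xml', weights='kaldi_model.bin')
    input_name = next(iter(net.input_info))
    net.reshape({input_name: [4, 40]})   # illustrative new input shape
    exec_net = ie.load_network(net, 'CPU')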