[ MO GroupNorm ] Covered float Multiplication with Converts (#1602)
author Evgenya Stepyreva <evgenya.stepyreva@intel.com>
Mon, 3 Aug 2020 11:45:39 +0000 (14:45 +0300)
committer GitHub <noreply@github.com>
Mon, 3 Aug 2020 11:45:39 +0000 (14:45 +0300)
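
GroupNormToMVN builds the shape for its internal Reshape from a ShapeOf sub-graph. ShapeOf produces integer values, while the batch * group_size multiplication and channels / group_size division applied to them are elementwise float operations, so Convert (Cast) nodes are inserted: the shape components are cast to the model's floating-point data type before the arithmetic, and the assembled shape is cast back to int64 before it feeds the Reshape.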
model-optimizer/extensions/middle/GroupNorm.py

index 4e7ed9b..166fcd5 100644
@@ -18,10 +18,12 @@ from typing import Dict
 
 import numpy as np
 
+from extensions.ops.Cast import Cast
 from extensions.ops.elementwise import Mul, Add
 from extensions.ops.mvn import MVN
 from mo.front.common.partial_infer.utils import int64_array
 from mo.graph.graph import Graph, Node
+from mo.middle.passes.convert_data_type import data_type_str_to_np
 from mo.middle.replacement import MiddleReplacementPattern
 from mo.ops.const import Const
 from mo.ops.reshape import Reshape
@@ -58,9 +60,19 @@ class GroupNormToMVN(MiddleReplacementPattern):
         initial_shape_op_node = Shape(graph, {'name': group_norm_node.name + '/Shape'}).create_node()
         initial_shape_op_node.in_port(0).connect(group_norm_node.in_port(0).get_source())
 
-        initial_batch_dim_node = node_to_get_batch_value(initial_shape_op_node)
-        initial_features_dim_node = node_to_get_features_dimension_value(initial_shape_op_node)
-        initial_spatial_dims_node = node_to_get_spatial_dimensions_value(initial_shape_op_node)
+        initial_shape_op_node_float = Cast(
+            graph, {'name': initial_shape_op_node.name + '/to_float',
+                    'dst_type': data_type_str_to_np(graph.graph['cmd_params'].data_type)}).create_node()
+        initial_shape_op_node.out_port(0).connect(initial_shape_op_node_float.in_port(0))
+
+        initial_batch_dim_node = node_to_get_batch_value(initial_shape_op_node_float)
+        initial_features_dim_node = node_to_get_features_dimension_value(initial_shape_op_node_float)
+        initial_spatial_dims_node_int = node_to_get_spatial_dimensions_value(initial_shape_op_node)
+        initial_spatial_dims_node = Cast(
+            graph, {'name': initial_spatial_dims_node_int.name + '/to_float',
+                    'dst_type': data_type_str_to_np(graph.graph['cmd_params'].data_type)}).create_node()
+        initial_spatial_dims_node_int.out_port(0).connect(initial_spatial_dims_node.in_port(0))
+
         group_size_node = Const(graph, {'value': int64_array([group_norm_node.num_groups]),
                                         'name': group_norm_node.name + '/GroupSize'}).create_node()
 
@@ -77,8 +89,11 @@ class GroupNormToMVN(MiddleReplacementPattern):
         batch_mul_group_size_node.in_port(1).connect(group_size_node.out_port(0))
 
         # create new node which concatenates several dims to one
-        new_shape_node = new_shape_node_from_shape_nodes([batch_mul_group_size_node, c_div_g_node,
-                                                          initial_spatial_dims_node])
+        new_shape_node_float = new_shape_node_from_shape_nodes([batch_mul_group_size_node, c_div_g_node,
+                                                                initial_spatial_dims_node])
+        new_shape_node = Cast(graph,
+                              {'name': new_shape_node_float.name + '/to_int64', 'dst_type': np.int64}).create_node()
+        new_shape_node_float.out_port(0).connect(new_shape_node.in_port(0))
 
         reshape_for_mvn_node = Reshape(graph, {}).create_node()
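
For reference, the shape arithmetic performed by the transformed sub-graph can be sketched in plain NumPy. This is only an illustration of the data flow; the function name, the NCHW example shape and the group count below are illustrative and not part of the commit.

import numpy as np

def group_norm_reshape_shape(input_shape, num_groups, float_dtype=np.float32):
    # Mimics the shape sub-graph built by GroupNormToMVN after this change:
    # shape values are converted to the model's floating-point type before
    # the elementwise multiply/divide and converted back to int64 before
    # they feed the Reshape.
    shape_f = np.asarray(input_shape).astype(float_dtype)   # Convert: int -> float
    batch, features = shape_f[0], shape_f[1]                 # N and C
    spatial = shape_f[2:]                                     # H, W, ...
    group_size = float_dtype(num_groups)
    new_shape_f = np.concatenate(([batch * group_size,        # N * G
                                   features / group_size],    # C / G
                                  spatial))
    return new_shape_f.astype(np.int64)                       # Convert: float -> int64

# Example: NCHW input [2, 32, 28, 28] with 8 groups -> [16, 4, 28, 28]
print(group_norm_reshape_shape([2, 32, 28, 28], 8))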