import numpy as np
+from extensions.ops.Cast import Cast
from extensions.ops.elementwise import Mul, Add
from extensions.ops.mvn import MVN
from mo.front.common.partial_infer.utils import int64_array
from mo.graph.graph import Graph, Node
+from mo.middle.passes.convert_data_type import data_type_str_to_np
from mo.middle.replacement import MiddleReplacementPattern
from mo.ops.const import Const
from mo.ops.reshape import Reshape
initial_shape_op_node = Shape(graph, {'name': group_norm_node.name + '/Shape'}).create_node()
initial_shape_op_node.in_port(0).connect(group_norm_node.in_port(0).get_source())
- initial_batch_dim_node = node_to_get_batch_value(initial_shape_op_node)
- initial_features_dim_node = node_to_get_features_dimension_value(initial_shape_op_node)
- initial_spatial_dims_node = node_to_get_spatial_dimensions_value(initial_shape_op_node)
+ initial_shape_op_node_float = Cast(
+ graph, {'name': initial_shape_op_node.name + '/to_float',
+ 'dst_type': data_type_str_to_np(graph.graph['cmd_params'].data_type)}).create_node()
+ initial_shape_op_node.out_port(0).connect(initial_shape_op_node_float.in_port(0))
+
+ initial_batch_dim_node = node_to_get_batch_value(initial_shape_op_node_float)
+ initial_features_dim_node = node_to_get_features_dimension_value(initial_shape_op_node_float)
+ initial_spatial_dims_node_int = node_to_get_spatial_dimensions_value(initial_shape_op_node)
+ initial_spatial_dims_node = Cast(
+ graph, {'name': initial_spatial_dims_node_int.name + '/to_float',
+ 'dst_type': data_type_str_to_np(graph.graph['cmd_params'].data_type)}).create_node()
+ initial_spatial_dims_node_int.out_port(0).connect(initial_spatial_dims_node.in_port(0))
+
group_size_node = Const(graph, {'value': int64_array([group_norm_node.num_groups]),
'name': group_norm_node.name + '/GroupSize'}).create_node()
batch_mul_group_size_node.in_port(1).connect(group_size_node.out_port(0))
# create new node which concatenates several dims to one
- new_shape_node = new_shape_node_from_shape_nodes([batch_mul_group_size_node, c_div_g_node,
- initial_spatial_dims_node])
+ new_shape_node_float = new_shape_node_from_shape_nodes([batch_mul_group_size_node, c_div_g_node,
+ initial_spatial_dims_node])
+ new_shape_node = Cast(graph,
+ {'name': new_shape_node_float.name + '/to_int64', 'dst_type': np.int64}).create_node()
+ new_shape_node_float.out_port(0).connect(new_shape_node.in_port(0))
reshape_for_mvn_node = Reshape(graph, {}).create_node()