See the License for the specific language governing permissions and
limitations under the License.
"""
+import numpy as np
+
+from extensions.ops.Cast import Cast
from extensions.ops.elementwise import Sub
from extensions.ops.rank import Rank
from extensions.ops.split import Split
# add zeros/ones to the related inputs to align them with the rank of the data input
in0_rank = Rank(graph, {'name': node.name + '/rank_0'}).create_node()
- in1_rank = Shape(graph, {'name': node.name + '/rank_1'}).create_node()
+ in1_shape = Shape(graph, {'name': node.name + '/rank_1'}).create_node()
diff_size = Sub(graph, {'name': node.name + '/sub_0'}).create_node()
diff = Sub(graph, {'name': node.name + '/sub_1'}).create_node()
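+ # diff_size = rank(data input) - len(auxiliary 1-D input); diff = diff_size - 1,
+ # which leaves one slot for the single element prepended at the front (see wiring below)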
-
const_begin = Const(graph, {'value': int64_array([1])}).create_node()
- const_pad_val = Const(graph, {'value': int64_array([1])}).create_node()
+ const_pad_val = Const(graph, {'value': int64_array(1)}).create_node()
block_shape = Pad(graph, {'name': node.name + '/aligned_block_shape', 'mode': 'constant'}).create_node()
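+ # block_shape is padded with ones (const_pad_val), presumably because a block
+ # size of 1 leaves an axis untouched; begin/end keep Pad's default fill_value of 0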
in0_rank_1d = create_op_node_with_second_input(graph, Unsqueeze, int64_array([0]),
{'name': node.name + '/1d_rank_of_0'}, in0_rank)
- in1_rank_1d = create_op_node_with_second_input(graph, Unsqueeze, int64_array([0]),
- {'name': node.name + '/1d_rank_of_1'}, in1_rank)
node.in_port(0).get_source().connect(in0_rank.in_port(0))
- node.in_port(1).get_source().connect(in1_rank.in_port(0))
+ node.in_port(1).get_source().connect(in1_shape.in_port(0))
in0_rank_1d.out_port(0).connect(diff_size.in_port(0))
- in1_rank_1d.out_port(0).connect(diff_size.in_port(1))
+ in1_shape.out_port(0).connect(diff_size.in_port(1))
diff_size.out_port(0).connect(diff.in_port(0))
const_begin.out_port(0).connect(diff.in_port(1))
const_pad_val.out_port(0).connect(block_shape.in_port(3))
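+ # net effect: an auxiliary input of length M comes out of its Pad with
+ # rank(data) = N elements: 1 in front (const_begin) and N - M - 1 at the back (diff)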
inputs_array = [block_shape, begin, end]
for idx, input_to_node in enumerate(inputs_array):
+ name_of_input_to_node = input_to_node.name
node.in_port(idx + 1).get_connection().set_destination(input_to_node.in_port(0))
const_begin.out_port(0).connect(input_to_node.in_port(1))
diff.out_port(0).connect(input_to_node.in_port(2))
input_to_node.out_port(0).connect(node.in_port(idx + 1))
+ convert = Cast(graph, {'name': name_of_input_to_node + '/i64', 'dst_type': np.int64}).create_node()
+ input_to_node.in_port(0).get_connection().insert_node(convert)
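+ # the inserted Cast guarantees each Pad consumes int64 values (the original
+ # inputs may be int32), matching the 'force_precision_in_ports' on Pad below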
super().__init__(graph, {
'op': self.op,
'type': self.op,
+
'version': 'opset1',
- 'infer': __class__.infer,
- 'in_ports_count': 4,
- 'out_ports_count': 1,
+ 'infer': self.infer,
+
'mode': 'constant',
'fill_value': float(0),
+
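+ # pads_begin (port 1) and pads_end (port 2) must be integer, so they are
+ # forced to int64 when the IR is emitted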
'force_precision_in_ports': {
1: 'int64',
2: 'int64',
},
- }, attrs)
- def supported_attrs(self):
- return ['mode', 'fill_value', 'pads']
+ 'in_ports_count': 4,
+ 'out_ports_count': 1,
+ }, attrs)
def backend_attrs(self):
return [('pad_mode', 'mode'),
assert len(input_shape) == len(pad_end), \
'Length of end padding "{}" does not correspond to input tensor shape "{}" for node "{}".' \
''.format(pad_end, input_shape, pad_node_name)
+ assert not node.is_in_port_connected(3) or node.in_port(3).data.get_shape().size == 0, \
+ 'Optional input on port 3 of Pad operation should be scalar, but has shape {} for node {}' \
+ ''.format(node.in_port(3).data.get_shape(), pad_node_name)
node.out_port(0).data.set_shape(input_shape + pad_beg + pad_end)
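+ # e.g. input_shape = [1, 3, 10, 10] with pad_beg = [0, 0, 1, 1] and
+ # pad_end = [0, 0, 2, 2] (hypothetical values) gives output shape [1, 3, 13, 13]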