@staticmethod
def infer(node: Node):
# TODO Add necessary checks and asserts
- node.out_node().shape = node.in_node(1).value
+ b_value = node.in_port(0).data.get_value()
+ b_shape = node.in_port(1).data.get_value()
+ assert b_shape is not None
+ node.out_port(0).data.set_shape(b_shape)
+
PermuteInputs().set_input_permutation(node.in_node(1), node, 'output:0', 'shape')
- if node.in_node(0).value is not None and node.in_node(1).value is not None:
- node.out_node().value = np.broadcast_to(node.in_node(0).value, node.in_node(1).value)
+ if b_value is not None:
+ new_value = np.broadcast_to(b_value, b_shape)
+ node.out_port(0).data.set_value(new_value)
+