Publishing 2019 R3 content
[platform/upstream/dldt.git] / model-optimizer / mo / ops / broadcast.py
index b3894a5..d2d8070 100644 (file)
@@ -49,7 +49,13 @@ class Broadcast(Op):
     @staticmethod
     def infer(node: Node):
         # TODO Add necessary checks and asserts
-        node.out_node().shape = node.in_node(1).value
+        b_value = node.in_port(0).data.get_value()
+        b_shape = node.in_port(1).data.get_value()
+        assert b_shape is not None
+        node.out_port(0).data.set_shape(b_shape)
+
         PermuteInputs().set_input_permutation(node.in_node(1), node, 'output:0', 'shape')
-        if node.in_node(0).value is not None and node.in_node(1).value is not None:
-            node.out_node().value = np.broadcast_to(node.in_node(0).value, node.in_node(1).value)
+        if b_value is not None:
+            new_value = np.broadcast_to(b_value, b_shape)
+            node.out_port(0).data.set_value(new_value)
+