X-Git-Url: http://review.tizen.org/git/?a=blobdiff_plain;f=model-optimizer%2Fmo%2Ffront%2Fcommon%2Fpartial_infer%2Feltwise.py;h=7cfdb1570700416a8edaf2d22931bbba4cf091bb;hb=0923303e0201c5b59386ab146d0e30b2ef79272d;hp=12d4b800adfcdbc33c7a7d53ee7786caea23ee4d;hpb=ba6e22b1b5ee4cbefcc30e8d9493cddb0bb3dfdf;p=platform%2Fupstream%2Fdldt.git diff --git a/model-optimizer/mo/front/common/partial_infer/eltwise.py b/model-optimizer/mo/front/common/partial_infer/eltwise.py index 12d4b80..7cfdb15 100644 --- a/model-optimizer/mo/front/common/partial_infer/eltwise.py +++ b/model-optimizer/mo/front/common/partial_infer/eltwise.py @@ -14,9 +14,8 @@ limitations under the License. """ -import numpy as np -import logging as log import networkx as nx +import numpy as np from mo.front.common.partial_infer.utils import int64_array from mo.graph.graph import Node @@ -90,3 +89,10 @@ def eltwise_infer(node, op=None, **kwargs): node.out_node().value = values[0] for i in range(len(values) - 1): node.out_node().value = op(node.out_node().value, values[i + 1]) + + +def bias_add_infer(node, op): + if node.in_port(0).data.get_value() is not None and node.in_port(1).data.get_value() is not None and op is not None: + node.out_port(0).data.set_value(op(node.in_port(0).data.get_value(), node.in_port(1).data.get_value())) + else: + node.out_port(0).data.set_shape(node.in_port(0).data.get_shape())