From: Balint Cristian Date: Sun, 14 Jun 2020 03:39:20 +0000 (+0300) Subject: [ONNX] Skip multiply with 1.0f constant for GEMM import (#5800) X-Git-Tag: upstream/0.7.0~558 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=12558b6f725e4341f3e5c905ee158a50d919df1b;p=platform%2Fupstream%2Ftvm.git [ONNX] Skip multiply with 1.0f constant for GEMM import (#5800) * [ONNX] Skip ADD inside Gemm op when vector is zero * [ONNX] Skip multiply with 1.0f constant for GEMM import --- diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index dabe55f..05a067d 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -477,8 +477,11 @@ class Gemm(OnnxOpConverter): if not transB: inputs[1] = _op.transpose(inputs[1], axes=(1, 0)) inputs[0] = _op.nn.batch_flatten(inputs[0]) - out = _op.nn.dense(_expr.const(alpha) * inputs[0], - inputs[1], units=channels) + + if alpha != 1.0: + inputs[0] *= _expr.const(alpha) + out = _op.nn.dense(inputs[0], inputs[1], units=channels) + # skip (beta * C) if zero C_array = params[inputs[2].name_hint].asnumpy() if (beta == 0.0) or np.array_equal(C_array, np.array([0])):