From 12558b6f725e4341f3e5c905ee158a50d919df1b Mon Sep 17 00:00:00 2001 From: Balint Cristian Date: Sun, 14 Jun 2020 06:39:20 +0300 Subject: [PATCH] [ONNX] Skip multiply with 1.0f constant for GEMM import (#5800) * [ONNX] Skip ADD inside Gemm op when vector is zero * [ONNX] Skip multiply with 1.0f constant for GEMM import --- python/tvm/relay/frontend/onnx.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index dabe55f..05a067d 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -477,8 +477,11 @@ class Gemm(OnnxOpConverter): if not transB: inputs[1] = _op.transpose(inputs[1], axes=(1, 0)) inputs[0] = _op.nn.batch_flatten(inputs[0]) - out = _op.nn.dense(_expr.const(alpha) * inputs[0], - inputs[1], units=channels) + + if alpha != 1.0: + inputs[0] *= _expr.const(alpha) + out = _op.nn.dense(inputs[0], inputs[1], units=channels) + # skip (beta * C) if zero C_array = params[inputs[2].name_hint].asnumpy() if (beta == 0.0) or np.array_equal(C_array, np.array([0])): -- 2.7.4