[ONNX] Skip multiply with 1.0f constant for GEMM import (#5800)
authorBalint Cristian <cristian.balint@gmail.com>
Sun, 14 Jun 2020 03:39:20 +0000 (06:39 +0300)
committerGitHub <noreply@github.com>
Sun, 14 Jun 2020 03:39:20 +0000 (12:39 +0900)
* [ONNX] Skip ADD inside Gemm op when vector is zero

* [ONNX] Skip multiply with 1.0f constant for GEMM import

python/tvm/relay/frontend/onnx.py

index dabe55f..05a067d 100644 (file)
@@ -477,8 +477,11 @@ class Gemm(OnnxOpConverter):
         if not transB:
             inputs[1] = _op.transpose(inputs[1], axes=(1, 0))
         inputs[0] = _op.nn.batch_flatten(inputs[0])
-        out = _op.nn.dense(_expr.const(alpha) * inputs[0],
-                           inputs[1], units=channels)
+
+        if alpha != 1.0:
+            inputs[0] *= _expr.const(alpha)
+        out = _op.nn.dense(inputs[0], inputs[1], units=channels)
+
         # skip (beta * C) if zero
         C_array = params[inputs[2].name_hint].asnumpy()
         if (beta == 0.0) or np.array_equal(C_array, np.array([0])):