Don't multiply by constant 1 uselessly in dense (#5911)
authorThomas Viehmann <tv.code@beamnet.de>
Wed, 24 Jun 2020 11:49:43 +0000 (13:49 +0200)
committerGitHub <noreply@github.com>
Wed, 24 Jun 2020 11:49:43 +0000 (20:49 +0900)
python/tvm/relay/frontend/pytorch.py
tests/python/frontend/pytorch/test_forward.py

index 9237303..84b0907 100644 (file)
@@ -995,11 +995,11 @@ def _dense():
         beta = inputs[3]
         alpha = inputs[4]
 
-        if not isinstance(alpha, _expr.Expr):
+        if not isinstance(alpha, _expr.Expr) and alpha != 1:
             alpha = _create_typed_const(alpha, data_type)
             data *= alpha
 
-        if not isinstance(beta, _expr.Expr):
+        if not isinstance(beta, _expr.Expr) and beta != 1:
             beta = _create_typed_const(beta, data_type)
             weight *= beta
 
index 12d1260..0694fa5 100644 (file)
@@ -33,6 +33,18 @@ from tvm.relay.testing.config import ctx_list
 
 sys.setrecursionlimit(10000)
 
+def list_ops(expr):
+    class OpLister(tvm.relay.ExprVisitor):
+        def visit_op(self, expr):
+            if expr not in self.node_set:
+                self.node_list.append(expr)
+            return super().visit_op(expr)
+        def list_nodes(self, expr):
+            self.node_set = {}
+            self.node_list = []
+            self.visit(expr)
+            return self.node_list
+    return OpLister().list_nodes(expr)
 
 def assert_shapes_match(tru, est):
     if tru.shape != est.shape:
@@ -1047,6 +1059,13 @@ def test_forward_dense():
     verify_model(Dense1().float().eval(), input_data=input_data)
     verify_model(Dense2().float().eval(), input_data=input_data)
 
+    trace = torch.jit.trace(Dense1(), [input_data])
+    mod, params = relay.frontend.from_pytorch(
+        trace,
+        [('input', input_shape)],
+    )
+    assert not any([op.name == "multiply" for op in list_ops(mod['main'])])
+
 def test_forward_dropout():
     torch.set_grad_enabled(False)
     input_shape = [1, 3, 10, 10]