beta = inputs[3]
alpha = inputs[4]
- if not isinstance(alpha, _expr.Expr):
+ if not isinstance(alpha, _expr.Expr) and alpha != 1:
alpha = _create_typed_const(alpha, data_type)
data *= alpha
- if not isinstance(beta, _expr.Expr):
+ if not isinstance(beta, _expr.Expr) and beta != 1:
beta = _create_typed_const(beta, data_type)
weight *= beta
sys.setrecursionlimit(10000)
+def list_ops(expr):
+ class OpLister(tvm.relay.ExprVisitor):
+ def visit_op(self, expr):
+ if expr not in self.node_set:
+ self.node_list.append(expr)
+ return super().visit_op(expr)
+ def list_nodes(self, expr):
+ self.node_set = {}
+ self.node_list = []
+ self.visit(expr)
+ return self.node_list
+ return OpLister().list_nodes(expr)
def assert_shapes_match(tru, est):
if tru.shape != est.shape:
verify_model(Dense1().float().eval(), input_data=input_data)
verify_model(Dense2().float().eval(), input_data=input_data)
+ trace = torch.jit.trace(Dense1(), [input_data])
+ mod, params = relay.frontend.from_pytorch(
+ trace,
+ [('input', input_shape)],
+ )
+ assert not any([op.name == "multiply" for op in list_ops(mod['main'])])
+
def test_forward_dropout():
torch.set_grad_enabled(False)
input_shape = [1, 3, 10, 10]