return Constant(arg.data.copyto(nd.cpu(0)))
elif isinstance(arg, TupleValue):
return Tuple([_arg_to_ast(field) for field in arg.fields])
+ elif isinstance(arg, tuple):
+ return Tuple([_arg_to_ast(field) for field in arg])
elif isinstance(arg, RefValue):
return RefCreate(_arg_to_ast(arg.value))
elif isinstance(arg, ConstructorValue):
tvm.testing.assert_allclose(res_tuple.fields[i].asnumpy(),
tuple_value.fields[i].asnumpy())
+def test_tuple_passing():
+ x = relay.var('x', type_annotation=relay.ty.TupleType([
+ relay.ty.TensorType((), 'int64'),
+ relay.ty.TensorType((), 'int64')]))
+
+ fn = relay.Function([x], relay.expr.TupleGetItem(x, 0))
+ mod = relay.Module({})
+ gv = relay.GlobalVar('fn')
+ mod[gv] = fn
+ mod.entry_func = gv
+ mod[gv] = relay.ir_pass.infer_type(mod[gv], mod=mod)
+
+ ctx = tvm.cpu()
+ target = tvm.target.create('llvm')
+ exec = relay.create_executor(mod=mod, ctx=ctx, target=target)
+ f = exec.evaluate(gv)
+ # First use a Python tuple.
+ out = f((10, 8))
+ tvm.testing.assert_allclose(out.asnumpy(), np.array(10))
+ # Second use a tuple value.
+ value_tuple = TupleValue(
+ TensorValue(np.array(11)),
+ TensorValue(np.array(12)))
+ out = f(value_tuple)
+ tvm.testing.assert_allclose(out.asnumpy(), np.array(11))
if __name__ == "__main__":
test_id()
test_tensor_value()
test_tuple_value()
test_tuple_getitem()
- test_function_taking_adt_ref_tuple()
\ No newline at end of file
+ test_function_taking_adt_ref_tuple()
+ test_tuple_passing()