[Relay][Backend] Fix interpreter argument conversion for tuples. (#3349)
authorJared Roesch <roeschinc@gmail.com>
Wed, 12 Jun 2019 16:58:15 +0000 (09:58 -0700)
committerTianqi Chen <tqchen@users.noreply.github.com>
Wed, 12 Jun 2019 16:58:15 +0000 (09:58 -0700)
* Support taking a tuple as an argument

* Add test

python/tvm/relay/backend/interpreter.py
tests/python/relay/test_backend_interpreter.py

index fc47f4e1b7c8607786d7ffd2ef5dce81fabbd527..593cf7cfbdf7f106fe2b7b2c780d135a63370aeb 100644 (file)
@@ -118,6 +118,8 @@ def _arg_to_ast(arg):
         return Constant(arg.data.copyto(nd.cpu(0)))
     elif isinstance(arg, TupleValue):
         return Tuple([_arg_to_ast(field) for field in arg.fields])
+    elif isinstance(arg, tuple):
+        return Tuple([_arg_to_ast(field) for field in arg])
     elif isinstance(arg, RefValue):
         return RefCreate(_arg_to_ast(arg.value))
     elif isinstance(arg, ConstructorValue):
index e8a99e14d741d617b95f8e9388ec452899b43519..1e5e2310e927f3b546076cf77e0b63e640ba97c4 100644 (file)
@@ -217,6 +217,31 @@ def test_function_taking_adt_ref_tuple():
         tvm.testing.assert_allclose(res_tuple.fields[i].asnumpy(),
                                     tuple_value.fields[i].asnumpy())
 
+def test_tuple_passing():
+    x = relay.var('x', type_annotation=relay.ty.TupleType([
+        relay.ty.TensorType((), 'int64'),
+        relay.ty.TensorType((), 'int64')]))
+
+    fn = relay.Function([x], relay.expr.TupleGetItem(x, 0))
+    mod = relay.Module({})
+    gv = relay.GlobalVar('fn')
+    mod[gv] = fn
+    mod.entry_func = gv
+    mod[gv] = relay.ir_pass.infer_type(mod[gv], mod=mod)
+
+    ctx = tvm.cpu()
+    target = tvm.target.create('llvm')
+    exec = relay.create_executor(mod=mod, ctx=ctx, target=target)
+    f = exec.evaluate(gv)
+    # First use a Python tuple.
+    out = f((10, 8))
+    tvm.testing.assert_allclose(out.asnumpy(), np.array(10))
+    # Second use a tuple value.
+    value_tuple = TupleValue(
+        TensorValue(np.array(11)),
+        TensorValue(np.array(12)))
+    out = f(value_tuple)
+    tvm.testing.assert_allclose(out.asnumpy(), np.array(11))
 
 if __name__ == "__main__":
     test_id()
@@ -231,4 +256,5 @@ if __name__ == "__main__":
     test_tensor_value()
     test_tuple_value()
     test_tuple_getitem()
-    test_function_taking_adt_ref_tuple()
\ No newline at end of file
+    test_function_taking_adt_ref_tuple()
+    test_tuple_passing()