[Relay][VM] Fix code generation for packed functions + tuples (#3287)
authorJared Roesch <roeschinc@gmail.com>
Wed, 5 Jun 2019 16:28:52 +0000 (09:28 -0700)
committerTianqi Chen <tqchen@users.noreply.github.com>
Wed, 5 Jun 2019 16:28:52 +0000 (09:28 -0700)
src/relay/backend/vm/compiler.cc
tests/python/relay/test_vm.py

index 602e927..db98a9a 100644 (file)
@@ -334,15 +334,42 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
     return Instruction::AllocTensor(last_register, dltype, NewRegister());
   }
 
-  void EmitInvokePrimitive(const Function& func, std::vector<Index> args_registers,
+  void EmitInvokePrimitive(const Function& func,
+                           const std::vector<Index>& args_registers,
                            const Type& ret_type) {
+    std::vector<Index> unpacked_arg_regs;
     std::vector<Instruction> allocs;
-    size_t return_num = 0;
+
+    // Arity calculation must flatten tuples.
+    size_t arity = 0;
+    CHECK_EQ(func->params.size(), args_registers.size());
+    for (size_t i = 0; i < func->params.size(); i++) {
+      auto ty = func->params[i]->checked_type();
+      if (ty.as<TensorTypeNode>()) {
+        unpacked_arg_regs.push_back(args_registers[i]);
+        arity += 1;
+      } else if (auto tuple_ty = ty.as<TupleTypeNode>()) {
+        for (size_t f = 0; f < tuple_ty->fields.size(); f++) {
+          const auto& field = tuple_ty->fields[f];
+          CHECK(field.as<TensorTypeNode>())
+            << "only supports non-nested tuples currently "
+            << "found " << field;
+          auto dst =  NewRegister();
+          Emit(Instruction::GetField(args_registers[i], f, dst));
+          unpacked_arg_regs.push_back(dst);
+        }
+        arity += tuple_ty->fields.size();
+      } else {
+        LOG(FATAL) << "unsupported parameter type " << ty;
+      }
+    }
+
+    size_t return_val_count = 0;
     if (const TensorTypeNode* ttype = ret_type.as<TensorTypeNode>()) {
       // Allocate space for the return tensor.
       auto alloc = AllocTensorFromType(ttype);
       allocs.push_back(alloc);
-      return_num = 1;
+      return_val_count = 1;
     } else if (const TupleTypeNode* ttype = ret_type.as<TupleTypeNode>()) {
       std::vector<Index> fields_registers;
 
@@ -352,14 +379,15 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
         allocs.push_back(AllocTensorFromType(f_type));
         fields_registers.push_back(allocs.back().dst);
       }
-      return_num = ttype->fields.size();
+      return_val_count = ttype->fields.size();
     } else {
       LOG(FATAL) << "Unsupported return value type";
     }
 
+    arity += return_val_count;
     for (auto& alloc : allocs) {
       Emit(alloc);
-      args_registers.push_back(alloc.dst);
+      unpacked_arg_regs.push_back(alloc.dst);
     }
 
     // Next generate the invoke instruction.
@@ -378,17 +406,15 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
       op_index = seen_funcs[cfunc->funcs[0]];
     }
 
-    // If Tensor, 1
-    // If Tuple, size of tuple
-    size_t arity = func->params.size() + return_num;
-    Emit(Instruction::InvokePacked(op_index, arity, return_num, args_registers));
-    if (return_num > 1) {
+    Emit(Instruction::InvokePacked(op_index, arity, return_val_count, unpacked_arg_regs));
+
+    if (return_val_count > 1) {
       // return value is a tuple, we need to create a tuple
       std::vector<Index> fields_registers;
-      for (size_t i = func->params.size(); i < arity; ++i) {
-        fields_registers.push_back(args_registers[i]);
+      for (size_t i = arity - return_val_count; i < arity; ++i) {
+        fields_registers.push_back(unpacked_arg_regs[i]);
       }
-      Emit(Instruction::AllocDatatype(0, return_num, fields_registers, NewRegister()));
+      Emit(Instruction::AllocDatatype(0, return_val_count, fields_registers, NewRegister()));
     }
   }
 
index bc99418..d727e77 100644 (file)
@@ -49,6 +49,17 @@ def test_split():
     res = veval(f, x_data)
     tvm.testing.assert_allclose(res.asnumpy(), np.split(x_data, 3, axis=0)[0])
 
+def test_split_no_fuse():
+    x = relay.var('x', shape=(12,))
+    y = relay.split(x, 3, axis=0).astuple()
+    z = relay.concatenate([relay.TupleGetItem(y, 0)], axis=0)
+    z = relay.annotation.stop_fusion(z)
+    f = relay.Function([x], z)
+    x_data = np.random.rand(12,).astype('float32')
+    res = veval(f, x_data)
+    tvm.testing.assert_allclose(res.asnumpy(), np.split(x_data, 3, axis=0)[0])
+
+
 def test_id():
     x = relay.var('x', shape=(10, 10))
     f = relay.Function([x], x)
@@ -259,6 +270,8 @@ if __name__ == "__main__":
     test_tuple_second()
     test_let_scalar()
     test_let_tensor()
+    test_split()
+    test_split_no_fuse()
     # TODO(@jroesch): restore when match is supported
     # test_list_constructor()
     test_closure()