return Instruction::AllocTensor(last_register, dltype, NewRegister());
}
- void EmitInvokePrimitive(const Function& func, std::vector<Index> args_registers,
+ void EmitInvokePrimitive(const Function& func,
+ const std::vector<Index>& args_registers,
const Type& ret_type) {
+ // Argument registers after flattening each tuple argument into one
+ // register per tensor field (packed funcs take only flat tensor args).
+ std::vector<Index> unpacked_arg_regs;
std::vector<Instruction> allocs;
- size_t return_num = 0;
+
+ // Arity calculation must flatten tuples.
+ // NOTE(review): arity here counts flattened tensor slots, not Relay-level
+ // params — a tuple param contributes one slot per field.
+ size_t arity = 0;
+ CHECK_EQ(func->params.size(), args_registers.size());
+ for (size_t i = 0; i < func->params.size(); i++) {
+ auto ty = func->params[i]->checked_type();
+ if (ty.as<TensorTypeNode>()) {
+ unpacked_arg_regs.push_back(args_registers[i]);
+ arity += 1;
+ } else if (auto tuple_ty = ty.as<TupleTypeNode>()) {
+ // Unpack each tuple field into its own register via GetField.
+ for (size_t f = 0; f < tuple_ty->fields.size(); f++) {
+ const auto& field = tuple_ty->fields[f];
+ CHECK(field.as<TensorTypeNode>())
+ << "only supports non-nested tuples currently "
+ << "found " << field;
+ auto dst = NewRegister();
+ Emit(Instruction::GetField(args_registers[i], f, dst));
+ unpacked_arg_regs.push_back(dst);
+ }
+ arity += tuple_ty->fields.size();
+ } else {
+ LOG(FATAL) << "unsupported parameter type " << ty;
+ }
+ }
+
+ // Number of output tensors: 1 for a Tensor return, field count for a Tuple.
+ size_t return_val_count = 0;
if (const TensorTypeNode* ttype = ret_type.as<TensorTypeNode>()) {
// Allocate space for the return tensor.
auto alloc = AllocTensorFromType(ttype);
allocs.push_back(alloc);
- return_num = 1;
+ return_val_count = 1;
} else if (const TupleTypeNode* ttype = ret_type.as<TupleTypeNode>()) {
std::vector<Index> fields_registers;
allocs.push_back(AllocTensorFromType(f_type));
fields_registers.push_back(allocs.back().dst);
}
- return_num = ttype->fields.size();
+ return_val_count = ttype->fields.size();
} else {
LOG(FATAL) << "Unsupported return value type";
}
+ // Output registers are passed as trailing args to the packed function,
+ // so they count toward the packed-call arity as well.
+ arity += return_val_count;
for (auto& alloc : allocs) {
Emit(alloc);
- args_registers.push_back(alloc.dst);
+ unpacked_arg_regs.push_back(alloc.dst);
}
// Next generate the invoke instruction.
op_index = seen_funcs[cfunc->funcs[0]];
}
- // If Tensor, 1
- // If Tuple, size of tuple
- size_t arity = func->params.size() + return_num;
- Emit(Instruction::InvokePacked(op_index, arity, return_num, args_registers));
- if (return_num > 1) {
+ Emit(Instruction::InvokePacked(op_index, arity, return_val_count, unpacked_arg_regs));
+
+ if (return_val_count > 1) {
// return value is a tuple, we need to create a tuple
std::vector<Index> fields_registers;
+ // The last return_val_count entries of unpacked_arg_regs are the outputs.
- for (size_t i = func->params.size(); i < arity; ++i) {
- fields_registers.push_back(args_registers[i]);
+ for (size_t i = arity - return_val_count; i < arity; ++i) {
+ fields_registers.push_back(unpacked_arg_regs[i]);
}
- Emit(Instruction::AllocDatatype(0, return_num, fields_registers, NewRegister()));
+ Emit(Instruction::AllocDatatype(0, return_val_count, fields_registers, NewRegister()));
}
}
res = veval(f, x_data)
tvm.testing.assert_allclose(res.asnumpy(), np.split(x_data, 3, axis=0)[0])
+def test_split_no_fuse():
+ """Split a (12,) tensor into 3, rebuild the first chunk via concatenate with
+ fusion stopped, and check the VM result against np.split's first chunk."""
+ x = relay.var('x', shape=(12,))
+ y = relay.split(x, 3, axis=0).astuple()
+ # stop_fusion forces the tuple to cross a function boundary unfused,
+ # exercising the VM's tuple-argument unpacking path.
+ z = relay.concatenate([relay.TupleGetItem(y, 0)], axis=0)
+ z = relay.annotation.stop_fusion(z)
+ f = relay.Function([x], z)
+ x_data = np.random.rand(12,).astype('float32')
+ res = veval(f, x_data)
+ tvm.testing.assert_allclose(res.asnumpy(), np.split(x_data, 3, axis=0)[0])
+
+
def test_id():
x = relay.var('x', shape=(10, 10))
f = relay.Function([x], x)
test_tuple_second()
test_let_scalar()
test_let_tensor()
+ test_split()
+ test_split_no_fuse()
# TODO(@jroesch): restore when match is supported
# test_list_constructor()
test_closure()