auto tuple = TupleType(func_type->arg_types);
auto in_types = FlattenTupleType(tuple);
auto out_types = FlattenTupleType(func_type->ret_type);
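+ // Shape function inputs are the flattened tuple fields, so expand the per-argument
+ // is_input flags into one flag per flattened field.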
+ Array<Integer> is_input;
+ for (size_t i = 0; i < func_type->arg_types.size(); ++i) {
+   auto const& aty = func_type->arg_types[i];
+   size_t num_types = 1;
+   if (aty.as<TupleTypeNode>()) {
+     num_types = FlattenTupleType(aty).size();
+   }
+   for (size_t j = 0; j < num_types; ++j) {
+     is_input.push_back(shape_func_attrs->is_input[i]);
+   }
+ }
Array<Type> shape_func_ins, shape_func_outs;
for (size_t i = 0; i < in_types.size(); i++) {
  auto in_type = in_types[i];
- if (shape_func_attrs->is_input[i]) {
+ if (is_input[i]) {
    shape_func_ins.push_back(in_type);
  } else {
    auto shape = RankShape(in_type->shape);
    except Exception as e:
        assert "in particular dimension 0 conflicts 2 does not match 1" in str(e)
+def test_tuple_get_item():
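+    # split returns a tuple; take its first field via TupleGetItem on a dynamically shaped input.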
+    mod = tvm.IRModule()
+    dtype = "float32"
+    static_data_shape = (9, 4)
+    data_shape = (relay.Any(), 4)
+    indices_or_sections = 2
+    axis = 1
+    data = relay.var('data', shape=data_shape, dtype=dtype)
+    y = relay.split(data, indices_or_sections, axis)
+    y = relay.expr.TupleGetItem(y.astuple(), 0)
+    mod["main"] = relay.Function([data], y)
+    data_np = np.random.uniform(size=static_data_shape).astype(dtype)
+    ref_out_shape = (9, 2)
+    for kind in ["vm"]:
+        ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
+        result = ex.evaluate()(data_np)
+        assert result.asnumpy().shape == ref_out_shape, \
+            "Shape mismatch: expect %s but got %s." % (str(ref_out_shape), str(result.asnumpy().shape))
+
+def test_mixed_input_type():
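+    # Mix a nested tuple argument ((tensor, tensor), tensor) with a plain tensor argument.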
+    mod = tvm.IRModule()
+    dtype = "float32"
+    static_data_shape = (9, 4)
+    data_shape = (relay.Any(), 4)
+    tensor_type = relay.TensorType(data_shape, dtype)
+    tuple_type = relay.TupleType([tensor_type, tensor_type])
+    data0 = relay.var("d0", type_annotation=relay.TupleType([tuple_type, tensor_type]))
+    data1 = relay.var("d1", shape=(relay.Any(), 4), dtype=dtype)
+    data_tuple = relay.expr.TupleWrapper(data0, 2)
+    nested_data_tuple = relay.expr.TupleWrapper(data_tuple[0], 2)
+    y = nested_data_tuple[1] * data_tuple[1] + data1
+    mod["main"] = relay.Function([data0, data1], y)
+    data_np0 = np.random.uniform(size=static_data_shape).astype(dtype)
+    data_np1 = np.random.uniform(size=static_data_shape).astype(dtype)
+    ref_out_shape = (9, 4)
+    for kind in ["vm"]:
+        ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
+        result = ex.evaluate()([[data_np0, data_np0], data_np0], data_np1)
+        assert result.asnumpy().shape == ref_out_shape, \
+            "Shape mismatch: expect %s but got %s." % (str(ref_out_shape), str(result.asnumpy().shape))
+
if __name__ == "__main__":
    test_any_full()
    test_any_broadcast()
    test_arange_with_dynamic_shape()
    test_recursive_concat()
    test_recursive_concat_with_wrong_annotation()
+    test_tuple_get_item()
+    test_mixed_input_type()
+