// Checks if input_ivalue is a subvalue of type.
bool isSubvalueOf(const IValue& ivalue, TypePtr type) {
+ if (ivalue.isTuple()) {
+ const auto& ivalue_elem = ivalue.toTuple()->elements();
+ auto tuple_type = type->cast<TupleType>();
+ if (!tuple_type || tuple_type->elements().size() != ivalue_elem.size()) {
+ return false;
+ }
+ auto type_elem = tuple_type->elements();
+ bool is_subvalue = true;
+ for (size_t i = 0; i < type_elem.size() && is_subvalue; ++i) {
+ is_subvalue = isSubvalueOf(ivalue_elem[i], type_elem[i]);
+ }
+ return is_subvalue;
+ }
if (ivalue.isGenericList()) {
auto list_type = type->cast<ListType>();
if (!list_type) {
auto output = module->run_method("dict_op", dict, std::string("hello"));
ASSERT_EQ(1, output.toTensor()[0].item<int64_t>());
}
+
+TEST(TorchScriptTest, TestTupleArgMatching) {
+ auto module = torch::jit::compile(R"JIT(
+ def tuple_op(a: Tuple[List[int]]):
+ return a
+ )JIT");
+
+ std::vector<int64_t> int_list = {1};
+ auto tuple_generic_list = torch::jit::Tuple::create({ int_list });
+
+ // doesn't fail on arg matching
+ module->run_method("tuple_op", tuple_generic_list);
+
+}