fix tuple matching (#17687)
authorElias Ellison <eellison@fb.com>
Wed, 6 Mar 2019 19:21:09 +0000 (11:21 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 6 Mar 2019 19:25:36 +0000 (11:25 -0800)
Summary:
Check for Tuple Matching in isSubvalueOf, since they may contain container types that need to be recursed within isSubvalueOf

Fix for https://github.com/pytorch/pytorch/issues/17650
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17687

Differential Revision: D14324642

Pulled By: eellison

fbshipit-source-id: 7f1e019875286b2640a3b9c003d1635dda8cf543

aten/src/ATen/core/type.cpp
test/cpp/api/jit.cpp

index e8892db..634bbd9 100644 (file)
@@ -176,6 +176,19 @@ TypePtr attemptToRecoverType(const IValue& ivalue) {
 
 // Checks if input_ivalue is a subvalue of type.
 bool isSubvalueOf(const IValue& ivalue, TypePtr type) {
+  if (ivalue.isTuple()) {
+    const auto& ivalue_elem = ivalue.toTuple()->elements();
+    auto tuple_type = type->cast<TupleType>();
+    if (!tuple_type || tuple_type->elements().size() != ivalue_elem.size()) {
+      return false;
+    }
+    auto type_elem = tuple_type->elements();
+    bool is_subvalue = true;
+    for (size_t i = 0; i < type_elem.size() && is_subvalue; ++i) {
+      is_subvalue = isSubvalueOf(ivalue_elem[i], type_elem[i]);
+    }
+    return is_subvalue;
+  }
   if (ivalue.isGenericList()) {
     auto list_type = type->cast<ListType>();
     if (!list_type) {
index a433578..ed44f7b 100644 (file)
@@ -98,3 +98,17 @@ TEST(TorchScriptTest, TestDictArgMatching) {
   auto output = module->run_method("dict_op", dict, std::string("hello"));
   ASSERT_EQ(1, output.toTensor()[0].item<int64_t>());
 }
+
+TEST(TorchScriptTest, TestTupleArgMatching) {
+  auto module = torch::jit::compile(R"JIT(
+      def tuple_op(a: Tuple[List[int]]):
+        return a
+    )JIT");
+
+  std::vector<int64_t> int_list = {1};
+  auto tuple_generic_list = torch::jit::Tuple::create({ int_list });
+
+  // doesn't fail on arg matching
+  module->run_method("tuple_op", tuple_generic_list);
+
+}