fix list type unification (#17424)
author Elias Ellison <eellison@fb.com>
Mon, 25 Feb 2019 21:26:49 +0000 (13:26 -0800)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Mon, 25 Feb 2019 21:34:50 +0000 (13:34 -0800)
Summary:
Previously we unified the types of lists across if-block outputs. With Optional subtyping this now fails, because two types that unify may have different runtime representations:
```
@torch.jit.script
def list_optional_fails(x):
    # type: (bool) -> Optional[int]
    if x:
        y = [1]
    else:
        y = [None]
    return y[0]
```
Here the indexing op expects `y` to be a generic list, but in the true branch it finds an int list.
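As a possible workaround (not part of this PR), annotating the list literal in both branches with the same element type sidesteps unification entirely, since both branches then build a list with the same runtime representation. A minimal sketch, assuming `torch.jit.annotate` accepts widening the element type of a list literal:

```
import torch
from typing import List, Optional

@torch.jit.script
def list_optional_ok(x):
    # type: (bool) -> Optional[int]
    if x:
        # both branches construct a list with element type Optional[int]
        y = torch.jit.annotate(List[Optional[int]], [1])
    else:
        y = torch.jit.annotate(List[Optional[int]], [None])
    return y[0]
```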
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17424

Differential Revision: D14210903

Pulled By: eellison

fbshipit-source-id: 4b8b26ba2e7e5bebf617e40316475f91e9109cc2
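For context before the diff, a purely illustrative Python sketch of the new list rule, using hypothetical helper names and a toy string-based type encoding; the real logic is the C++ `unifyTypes` change in `aten/src/ATen/core/type.cpp` below:

```
from typing import Optional

def unshaped(t):
    # toy stand-in for unshapedType: drop tensor shape info,
    # e.g. "Float(1, 1)" -> "Tensor"
    return "Tensor" if "(" in t else t

def try_either_is_the_supertype(t1, t2):
    # toy stand-in for tryEitherIsTheSuperType; lists are invariant,
    # so subtyping between list types degenerates to equality here
    return t1 if t1 == t2 else None

def unify_list_types(elem1, elem2):
    # after this PR: element types are never coerced to a common
    # supertype; the unshaped list types themselves must already be
    # in a subtype relation
    u1 = "List[{}]".format(unshaped(elem1))
    u2 = "List[{}]".format(unshaped(elem2))
    return try_either_is_the_supertype(u1, u2)

# lists of differently shaped tensors still unify...
assert unify_list_types("Float(1, 1)", "Float(1, 2)") == "List[Tensor]"
# ...but int[] and None[] no longer coerce to Optional[int][]
assert unify_list_types("int", "None") is None
```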

aten/src/ATen/core/type.cpp
test/test_jit.py

aten/src/ATen/core/type.cpp
index ba8a9cc..e29345c 100644
@@ -199,13 +199,20 @@ bool isSubvalueOf(const IValue& ivalue, TypePtr type) {
   return incompleteInferTypeFrom(ivalue)->isSubtypeOf(type);
 }
 
-
-c10::optional<TypePtr> unifyTypes(const TypePtr& t1, const TypePtr& t2) {
-  //cases that t1 == t2, or t1 is a type refinement of t2 and vice versa
+c10::optional<TypePtr> tryEitherIsTheSuperType(const TypePtr& t1, const TypePtr& t2) {
   if (t1->isSubtypeOf(t2)) {
     return t2;
   } else if (t2->isSubtypeOf(t1)) {
     return t1;
+  } else {
+    return c10::nullopt;
+  }
+}
+
+c10::optional<TypePtr> unifyTypes(const TypePtr& t1, const TypePtr& t2) {
+  // cases where t1 == t2, or t1 is a type refinement of t2, or vice versa
+  if (auto maybe_supertype = tryEitherIsTheSuperType(t1, t2)) {
+    return *maybe_supertype;
   }
 
   // NB: we do not return NumberType because there is not currently enough
@@ -224,12 +231,14 @@ c10::optional<TypePtr> unifyTypes(const TypePtr& t1, const TypePtr& t2) {
 
   //types which contain other types
   if (t1->cast<ListType>() && t2->cast<ListType>()) {
-    auto unified_type = unifyTypes(t1->cast<ListType>()->getElementType(), t2->cast<ListType>()->getElementType());
-    if (unified_type) {
-      return static_cast<TypePtr>(ListType::create(*unified_type));
-    } else {
-      return c10::nullopt;
-    }
+    // because we have runtime specializations of lists,
+    // e.g. int[] = std::vector<int64_t> and int?[] = std::vector<IValue>,
+    // we don't allow type coercion here, since t1 & t2
+    // may have different runtime representations.
+
+    // allow Lists of different tensor types
+    auto unshaped_t1 = unshapedType(t1);
+    auto unshaped_t2 = unshapedType(t2);
+    return tryEitherIsTheSuperType(unshaped_t1, unshaped_t2);
   } else if(t1->cast<TupleType>() && t2->cast<TupleType>()) {
     auto tuple1 = t1->cast<TupleType>();
     auto tuple2 = t2->cast<TupleType>();
test/test_jit.py
index a0c793a..566613a 100644
@@ -6643,6 +6643,37 @@ a")
         tensor_unifying.graph.propagate_shapes((a, b, c), False)
         self.assertExpected(canonical(tensor_unifying.graph))
 
+    def test_list_unify(self):
+        # allowing a unified int?[] would cause a runtime error because
+        # the index operation expects int?[] to be a generic list,
+        # but in the true branch the IValue will be an int list
+        with self.assertRaisesRegex(RuntimeError, r"int\[\] in the true branch and type None\[\]"):
+            @torch.jit.script
+            def list_optional_fails(x):
+                # type: (bool) -> Optional[int]
+                if x:
+                    y = [1]
+                else:
+                    y = [None]
+                return y[0]
+
+        @torch.jit.script
+        def list_tensors(x):
+            # type: (bool) -> Tuple[Tensor, List[Tensor]]
+            if x:
+                a = torch.zeros([1, 1])
+                y = [a]
+            else:
+                a = torch.zeros([1, 2])
+                y = [a]
+            return a, y
+
+        self.run_pass('constant_propagation', list_tensors.graph)
+        m = torch.jit.ScriptModule()
+        m._create_method_from_graph("forward", list_tensors.graph)
+        # test that the tensor type of the lists is unified
+        self.getExportImportCopy(m)
+
     def test_type_annotations_repeated_list(self):
         @torch.jit.script
         def float_fn(x, y):