Preserve types during empty container assignment (#58911)
authorAnsley Ussery <ansley@fb.com>
Fri, 10 Sep 2021 23:18:33 +0000 (16:18 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Fri, 10 Sep 2021 23:49:21 +0000 (16:49 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/58911

Stack from [ghstack](https://github.com/ezyang/ghstack):
* __->__ #58911

Test Plan: Imported from OSS

Reviewed By: gmagogsfm

Differential Revision: D30785623

Pulled By: ansley

fbshipit-source-id: 4e05d6369318974290fea02ad2bc148293c25090

test/jit/test_list_dict.py
test/jit/test_typing.py
test/jit/test_union.py
test/test_jit.py
test/test_ops.py
torch/csrc/jit/frontend/ir_emitter.cpp
torch/csrc/jit/ir/alias_analysis.cpp
torch/testing/_internal/jit_utils.py

index 10f5e87..ab514ad 100644 (file)
@@ -244,12 +244,9 @@ class TestList(JitTestCase):
         self.checkScript(fn, ())
 
     def test_dict_keyword_with_mismatched_annotations(self):
-        err_msg = r"Dict type annotation `Dict\[int, str\]` did not "\
-                  "match the types of the actual dict items"
-        err_msg = r"Dict type annotation `Dict\[int, str\]` did not "\
-                  "match the type of an actual key type `str`"
-        highlight_msg = "dict([(\"foo\", 1), (\"bar\", 2), (\"baz\", 3"
-        with self.assertRaisesRegexWithHighlight(RuntimeError, err_msg, highlight_msg):
+        err_msg = r"is annotated with type Dict\[int, str\] but is " \
+                  r"being assigned to a value of type Dict\[str, int\]"
+        with self.assertRaisesRegex(RuntimeError, err_msg):
             @torch.jit.script
             def fn():
                 x: Dict[int, str] = dict([("foo", 1), ("bar", 2), ("baz", 3)])    # noqa: C406
@@ -1328,7 +1325,7 @@ class TestList(JitTestCase):
             x = torch._C.ListType(None)
 
     def test_list_unification_hint(self):
-        with self.assertRaisesRegex(RuntimeError, "Expected a List type hint"):
+        with self.assertRaisesRegex(RuntimeError, "Expected an annotation of type List"):
             @torch.jit.script
             def x():
                 b : int = [2, 3]
index 125197c..6b3def0 100644 (file)
@@ -168,11 +168,12 @@ class TestTyping(JitTestCase):
             l1 = [1, 2, "foo", 3]
             l2 = ["foo", "bar", "baz", "qux"]
             d: Dict[int, str] = {k : v for k, v in zip(l1, l2)}
-            return l
+            return d
 
-        with self.assertRaisesRegex(RuntimeError, "Dict type annotation"
-                                    r" `Dict\[int, str\]` did not match"
-                                    " the type of an actual key type"):
+        with self.assertRaisesRegex(RuntimeError, "Dicts may only "
+                                    "contain homogeneous keys, but the "
+                                    "type of the first generated key "
+                                    r"was Union\[int, str\]"):
             torch.jit.script(fn)
 
     def test_dict_type_refinement_annotation_value_mismatch(self):
@@ -180,12 +181,12 @@ class TestTyping(JitTestCase):
             l1 = ["foo", "bar", "baz", "qux"]
             l2 = [1, 2, "foo", 3]
             d: Dict[str, int] = {k : v for k, v in zip(l1, l2)}
-            return l
+            return d
 
-        with self.assertRaisesRegex(RuntimeError, "Dict type annotation"
-                                    r" `Dict\[str, int\]` did not match"
-                                    " the type of an actual value "
-                                    "type"):
+        with self.assertRaisesRegex(RuntimeError, "annotated with type "
+                                    r"Dict\[str, int\] but is being "
+                                    "assigned to a value of type "
+                                    r"Dict\[str, Union\[int, str\]\]"):
             torch.jit.script(fn)
 
     def test_dict_invalid_annotations(self):
index df909a6..fb53d53 100644 (file)
@@ -5,6 +5,7 @@ import sys
 import torch
 from torch.testing import FileCheck
 from enum import Enum
+from textwrap import dedent
 from typing import Dict, List, Optional, Tuple, Union
 
 # Make the helper files in test/ importable
@@ -655,3 +656,272 @@ class TestUnion(JitTestCase):
 
         self.checkScript(fn, (1,))
         self.checkScript(fn, (8,))
+
+    def _assert_passes(self, template: str, ann: str, lhs: str):
+        code = template.format(ann=ann, lhs=lhs)
+        self.checkScript(code, (), name="fn")
+
+    def _assert_raises(self, template: str, ann: str, lhs: str, msg: str):
+        code = template.format(ann=ann, lhs=lhs)
+        with self.assertRaisesRegex(RuntimeError, msg):
+            cu = torch.jit.CompilationUnit(code, _frames_up=1)
+            string_frontend = getattr(cu, "fn")    # noqa: B009
+
+    def test_union_with_list_assignment(self):
+        template = dedent('''
+            def fn():
+                x: {ann} = {lhs}
+                if torch.jit.isinstance(x, List[torch.Tensor]):
+                    x.append(torch.tensor(3))
+                return x
+        ''')
+
+        lhs = {"list_literal_empty" : "[]",
+
+               "list_literal_of_tensor" : "[torch.arange(3), torch.arange(5)]",
+
+               "list_literal_of_str" : "[\"foo\", \"bar\", \"baz\"]",
+
+               "list_literal_of_mixed" : "[torch.arange(5), 1]",
+
+               "list_comprehension_of_tensor" :
+               "[torch.add(x, 1) for x in [torch.arange(3), torch.arange(5)]]",
+
+               "list_comprehension_of_str" :
+               "[x + \"!\" for x in [\"foo\", \"bar\", \"baz\"]]",
+
+               "list_comprehension_of_mixed" :
+               "[torch.add(1, x) for x in [torch.arange(5), 1]]"}
+
+        """
+        Union[List[str], List[torch.Tensor]]
+        """
+        self._assert_raises(template,
+                            "Union[List[str], List[torch.Tensor]]",
+                            lhs["list_literal_empty"],
+                            "there are multiple possible List type "
+                            "candidates in the Union annotation")
+
+        self._assert_passes(template,
+                            "Union[List[str], List[torch.Tensor]]",
+                            lhs["list_literal_of_tensor"])
+
+        self._assert_passes(template,
+                            "Union[List[str], List[torch.Tensor]]",
+                            lhs["list_literal_of_str"])
+
+        self._assert_raises(template,
+                            "Union[List[str], List[torch.Tensor]]",
+                            lhs["list_literal_of_mixed"],
+                            "none of those list types can hold the "
+                            "types of the given list elements")
+
+        self._assert_passes(template,
+                            "Union[List[str], List[torch.Tensor]]",
+                            lhs["list_comprehension_of_tensor"])
+
+        self._assert_passes(template,
+                            "Union[List[str], List[torch.Tensor]]",
+                            lhs["list_comprehension_of_str"])
+
+        # TODO: Support mixed list comprehensions
+        self._assert_raises(template,
+                            "Union[List[str], List[torch.Tensor]]",
+                            lhs["list_comprehension_of_mixed"],
+                            "Arguments for call are not valid")
+
+        """
+        Union[int, torch.Tensor]
+        """
+        self._assert_raises(template,
+                            "Union[int, torch.Tensor]",
+                            lhs["list_literal_empty"],
+                            "Expected an Union type annotation with an "
+                            "inner List type")
+
+        self._assert_raises(template, "Union[int, torch.Tensor]",
+                            lhs["list_literal_of_tensor"],
+                            "Expected an Union type annotation with an "
+                            "inner List type")
+
+        self._assert_raises(template, "Union[int, torch.Tensor]",
+                            lhs["list_comprehension_of_tensor"],
+                            "Expected an Union type annotation with an "
+                            "inner List type")
+
+        """
+        Union[List[torch.Tensor], int]
+        """
+        self._assert_passes(template,
+                            "Union[List[torch.Tensor], int]",
+                            lhs["list_literal_empty"])
+
+        self._assert_passes(template,
+                            "Union[List[torch.Tensor], int]",
+                            lhs["list_literal_of_tensor"])
+
+        self._assert_raises(template, "Union[List[torch.Tensor], int]",
+                            lhs["list_literal_of_str"],
+                            r"List type annotation `List\[Tensor\]` did "
+                            "not match the types of the given list "
+                            "elements")
+
+        self._assert_raises(template, "Union[List[torch.Tensor], int]",
+                            lhs["list_literal_of_mixed"],
+                            r"List type annotation `List\[Tensor\]` did "
+                            "not match the types of the given list "
+                            "elements")
+
+        self._assert_passes(template,
+                            "Union[List[torch.Tensor], int]",
+                            lhs["list_comprehension_of_tensor"])
+
+        self._assert_raises(template,
+                            "Union[List[torch.Tensor], int]",
+                            lhs["list_comprehension_of_str"],
+                            r"List type annotation `List\[Tensor\]` did "
+                            "not match the types of the given list "
+                            "elements")
+
+        # TODO: Support mixed list comprehensions
+        self._assert_raises(template,
+                            "Union[List[torch.Tensor], int]",
+                            lhs["list_comprehension_of_mixed"],
+                            "Arguments for call are not valid")
+
+    def test_union_with_dict_assignment(self):
+        template = dedent('''
+            def fn():
+                x: {ann} = {lhs}
+                if torch.jit.isinstance(x, Dict[str, torch.Tensor]):
+                    x["foo"] = torch.tensor(3)
+                return x
+        ''')
+
+        lhs = {"dict_literal_empty" : "{}",
+
+               "dict_literal_of_str_tensor" :
+               "{\"foo\" : torch.arange(3), \"bar\" : torch.arange(5)}",
+
+               "dict_literal_of_str_int" :
+               "{\"foo\" : 1, \"bar\" : 2}",
+
+               "dict_literal_of_mixed" :
+               "{\"foo\" : torch.arange(3), \"bar\" : 2}",
+
+               "dict_comprehension_of_str_tensor" :
+               "{x : torch.add(y, 1) for x, y in \
+                    zip([\"foo\", \"bar\"], [torch.arange(3), torch.arange(5)])}",
+
+               "dict_comprehension_of_str_int" :
+               "{x : torch.add(y, 1) for x, y in \
+                    zip([\"foo\", \"bar\"], [1, 2]}",
+
+               "dict_comprehension_of_mixed" :
+               "{x : torch.add(y, 1) for x, y in \
+                    zip([\"foo\", \"bar\"], [torch.arange(3), 2])}",
+
+               "dict_keyword" :
+               "dict(foo=torch.arange(3), baz=torch.arange(5))"}
+
+        """
+        Union[Dict[str, torch.Tensor], Dict[str, int]]
+        """
+        self._assert_raises(template,
+                            "Union[List[str], List[torch.Tensor]]",
+                            lhs["dict_literal_empty"],
+                            "Expected an Union type annotation with an "
+                            "inner Dict type")
+
+        self._assert_passes(template,
+                            "Union[Dict[str, torch.Tensor], Dict[str, int]]",
+                            lhs["dict_literal_of_str_tensor"])
+
+        self._assert_passes(template,
+                            "Union[Dict[str, torch.Tensor], Dict[str, int]]",
+                            lhs["dict_literal_of_str_int"])
+
+        self._assert_raises(template, "Union[Dict[str, torch.Tensor], Dict[str, int]]",
+                            lhs["dict_literal_of_mixed"],
+                            "none of those types can hold the types "
+                            "of the given dict elements")
+
+        # TODO: String frontend does not support tuple unpacking
+        # https://github.com/pytorch/pytorch/issues/64096
+        # self._assert_passes(template, "Union[Dict[str, torch.Tensor], Dict[str, int]]",
+        #              lhs["dict_comprehension_of_str_tensor"])
+
+        # self._assert_passes(template, "Union[Dict[str, torch.Tensor], Dict[str, int]]",
+        #              lhs["dict_comprehension_of_str_int"])
+
+        # self._assert_raises(template, "Union[Dict[str, torch.Tensor], Dict[str, int]]",
+        #              lhs["dict_comprehension_of_mixed"],
+        #              "foobar")
+
+        self._assert_passes(template,
+                            "Union[Dict[str, torch.Tensor], Dict[str, int]]",
+                            lhs["dict_keyword"])
+
+        """
+        Union[int, torch.Tensor]
+        """
+        self._assert_raises(template,
+                            "Union[int, torch.Tensor]",
+                            lhs["dict_literal_empty"],
+                            "Expected an Union type annotation with "
+                            "an inner Dict type")
+
+        self._assert_raises(template,
+                            "Union[int, torch.Tensor]",
+                            lhs["dict_literal_of_str_tensor"],
+                            "Expected an Union type annotation with "
+                            "an inner Dict type")
+
+        # See above--string frontend does not support tuple unpacking
+        # self._assert_raises(template, "Union[int, torch.Tensor]",
+        #              lhs["dict_comprehension_of_tensor"],
+        #              "foobar")
+
+        """
+        Union[Dict[str, torch.Tensor], int]
+        """
+        self._assert_passes(template,
+                            "Union[Dict[str, torch.Tensor], int]",
+                            lhs["dict_literal_empty"])
+
+        self._assert_passes(template,
+                            "Union[Dict[str, torch.Tensor], int]",
+                            lhs["dict_literal_of_str_tensor"])
+
+        self._assert_raises(template,
+                            "Union[Dict[str, torch.Tensor], int]",
+                            lhs["dict_literal_of_str_int"],
+                            r"Type hint for dict was Dict\[str, Tensor\]"
+                            ", but the value at index 0 has type int, "
+                            "which is not a valid subtype of Tensor")
+
+        self._assert_raises(template,
+                            "Union[Dict[str, torch.Tensor], int]",
+                            lhs["dict_literal_of_mixed"],
+                            r"Type hint for dict was Dict\[str, Tensor\]"
+                            ", but the value at index 1 has type int, "
+                            "which is not a valid subtype of Tensor")
+
+        # See above--string frontend does not support tuple unpacking
+        # self._assert_passes(template,
+        #                    "Union[Dict[str, torch.Tensor], int]",
+        #                    lhs["dict_comprehension_of_str_tensor"])
+
+        # self._assert_raises(template,
+        #                    "Union[Dict[str, torch.Tensor], int]",
+        #                    lhs["dict_comprehension_of_str_int"],
+        #                    "foobar")
+
+        # self._assert_raises(template,
+        #                    "Union[Dict[str, torch.Tensor], int]",
+        #                    lhs["dict_comprehension_of_mixed"],
+        #                    "foobar")
+
+        self._assert_passes(template,
+                            "Union[Dict[str, torch.Tensor], int]",
+                            lhs["dict_keyword"])
index eb61fdb..a6589ad 100644 (file)
@@ -11527,7 +11527,8 @@ dedent """
             out = torch.jit.annotate(int, [x for x in [1, 2, 3]])  # noqa: C416
             return out
 
-        with self.assertRaisesRegex(Exception, "Expected list type annotation"):
+        with self.assertRaisesRegex(Exception, "Expected an annotation"
+                                    " of type List"):
             torch.jit.script(bad_type_annotation)
 
     def test_list_comprehension_variable_write(self):
index 90e52bb..4ee0cbc 100644 (file)
@@ -1,6 +1,5 @@
 from collections.abc import Sequence
 from functools import partial, wraps
-import unittest
 import warnings
 
 import torch
@@ -685,7 +684,6 @@ class TestJit(JitCommonTestCase):
     #   and runtimes (eager, traced, scripted).
     # TODO WARNING: inplace x {traced, scripted} not currently tested
     @_variant_ops(op_db)
-    @unittest.skipIf(True, "Temporarily skipping while landing Union PR stack")
     def test_variant_consistency_jit(self, device, dtype, op):
         _requires_grad = op.supports_autograd and (dtype.is_floating_point or
                                                    op.supports_complex_autograd(torch.device(device).type))
index dd29f1e..4c066c2 100644 (file)
@@ -660,6 +660,7 @@ struct to_ir {
     // to SSA while part of their original graph, and closures are ready to
     // be inlined into forked closures
     ConvertToSSA(graph);
+
     // convert loops with an iter and body condition specified to
     // python-recognize while loops. we do this so they can be exported,
     // and run the pass early to avoid jitter. Like conversion to SSA,
@@ -1318,25 +1319,59 @@ struct to_ir {
     const auto loc = lc.range();
     const auto targets_list = List<Expr>::create(lc.range(), {lc.target()});
     const auto itrs = List<Expr>::create(lc.range(), {lc.iter()});
+
     // If there is no type hint, and this is emitted over an iterable that is
     // unrolled and of length 0, then we emit a List of tensors
     Value* list_value = graph->insertNode(graph->create(prim::ListConstruct, 1))
                             ->output()
                             ->setType(ListType::ofTensors());
-    bool type_set = false;
-    if (type_hint) {
-      if (!type_hint->cast<ListType>()) {
-        throw ErrorReport(loc)
-            << "Expected list type annotation for list comprehension"
-               ", found "
-            << type_hint->repr_str();
+
+    // See notes on logic in `emitListLiteral`
+
+    TypePtr refined_type_hint = type_hint;
+    TypePtr annotated_union_type =
+        type_hint && type_hint->kind() == UnionType::Kind ? type_hint : nullptr;
+
+    std::vector<TypePtr> all_candidates = {};
+
+    if (refined_type_hint) {
+      // If necessary/possible, refine `refined_type_hint` to a ListType
+      if (auto union_type_hint = refined_type_hint->cast<UnionType>()) {
+        std::vector<TypePtr> list_types;
+        std::copy_if(
+            union_type_hint->containedTypes().begin(),
+            union_type_hint->containedTypes().end(),
+            std::back_inserter(list_types),
+            [&](TypePtr type_ptr) {
+              return type_ptr->kind() == ListType::Kind;
+            });
+        if (list_types.empty()) {
+          throw ErrorReport(lc) << "Expected an Union type annotation "
+                                << "with an inner List type, but got "
+                                << refined_type_hint->repr_str();
+        } else if (list_types.size() == 1) {
+          refined_type_hint = list_types[0];
+        } else {
+          all_candidates = std::move(list_types);
+        }
+      } else if (
+          auto optional_type_hint = refined_type_hint->cast<OptionalType>()) {
+        refined_type_hint = optional_type_hint->getElementType();
+      }
+
+      if (all_candidates.empty()) {
+        if (refined_type_hint->kind() == ListType::Kind) {
+          list_value->setType(refined_type_hint);
+        } else {
+          throw ErrorReport(lc) << "Expected an annotation of type "
+                                << "List, but got " << type_hint->repr_str();
+        }
       }
-      list_value->setType(type_hint);
-      type_set = true;
     }
 
-    // comprehension introduces its own scope. no variable assigned
-    // leaks into the rest of the graph
+    bool seen_first_elem = false;
+
+    // A list comprehension introduces its own scope
     Node* n =
         graph->insertNode(create(prim::ComprehensionScope, lc.range(), 0));
     auto* comprehension_block = n->addBlock();
@@ -1349,41 +1384,94 @@ struct to_ir {
       // be set to `Tensor`. We don't want to unify this default type
       // with the actual elements in the list, so let the type begin as
       // the first element in the list
-      if (!type_set) {
+      if (!seen_first_elem) {
         list_value->setType(ListType::create(out->type()));
-        type_set = true;
+        seen_first_elem = true;
       }
 
-      ListTypePtr lt = list_value->type()->expect<ListType>();
+      const auto elem_type_hint =
+          refined_type_hint && refined_type_hint->kind() == ListType::Kind
+          ? refined_type_hint->cast<ListType>()->getElementType()
+          : nullptr;
 
-      const TypePtr element_type_hint =
-          type_hint ? type_hint->expect<ListType>()->getElementType() : nullptr;
-
-      auto unified = unifyTypes(
-          lt->getElementType(),
+      c10::optional<TypePtr> unified_elem_type = unifyTypes(
+          list_value->type()->expect<ListType>()->getElementType(),
           out->type(),
           /*default_to_union=*/true,
-          element_type_hint);
+          elem_type_hint);
 
-      if (lt->getElementType() != AnyType::get() &&
-          *unified == AnyType::get()) {
+      if (!type_hint && (*unified_elem_type)->kind() == UnionType::Kind) {
         TORCH_WARN(
             "List consists of heterogeneous types, which means",
-            " that it has been typed as `List[Any]`. To use "
-            "any of the values in the List, it will be "
-            "necessary to add an `assert isinstance` statement "
-            "before first use to trigger type refinement. The first ",
-            "non-matching element was typed as ",
+            " that it has been typed as containing ",
+            (*unified_elem_type)->repr_str(),
+            ". To use any of the "
+            "values in this List, it will be necessary to add an "
+            "`assert isinstance` statement before first use to trigger "
+            "type refinement. The first non-matching element was typed",
+            " as ",
             out->type()->repr_str(),
-            ", while the elements before it "
-            "were ",
-            lt->getElementType()->repr_str(),
+            ", while the elements "
+            " before it were ",
+            list_value->type()
+                ->expect<ListType>()
+                ->getElementType()
+                ->repr_str(),
             "\n",
             lc.range().str());
       }
 
-      if (!type_hint) {
-        list_value->setType(ListType::create(*unified));
+      if (all_candidates.empty() && refined_type_hint &&
+          !(*unified_elem_type)
+               ->isSubtypeOf(
+                   refined_type_hint->expect<ListType>()->getElementType())) {
+        throw ErrorReport(lc)
+            << "List type annotation `" << refined_type_hint->repr_str()
+            << "` did not match the types of the given list elements,"
+            << " which were unified to " << (*unified_elem_type)->repr_str();
+      }
+
+      if (!all_candidates.empty()) {
+        TypePtr greatest_elem_type = nullptr;
+        std::for_each(
+            all_candidates.begin(),
+            all_candidates.end(),
+            [&](TypePtr candidate) {
+              auto candidate_elem_type =
+                  candidate->expect<ListType>()->getElementType();
+              if ((*unified_elem_type)->isSubtypeOf(candidate_elem_type)) {
+                if (!greatest_elem_type) {
+                  greatest_elem_type = candidate_elem_type;
+                } else {
+                  greatest_elem_type =
+                      *(unifyTypes(greatest_elem_type, candidate_elem_type));
+                }
+              }
+            });
+        if (!greatest_elem_type) {
+          std::stringstream vector_repr;
+          for (size_t i = 0; i < all_candidates.size(); ++i) {
+            if (i > 0 && all_candidates.size() > 2) {
+              vector_repr << ", ";
+            }
+            if (i != 0 && i == all_candidates.size() - 1) {
+              vector_repr << " or ";
+            }
+            vector_repr << all_candidates[i]->repr_str();
+          }
+          throw ErrorReport(lc)
+              << "Union type annotation `" << type_hint->repr_str()
+              << "` can hold " << vector_repr.str() << ", but none of "
+              << "those types match the types of the given list "
+              << "elements, which were unified to "
+              << (*unified_elem_type)->repr_str();
+        } else {
+          refined_type_hint = greatest_elem_type;
+        }
+      }
+
+      if (!refined_type_hint) {
+        list_value->setType(ListType::create(*unified_elem_type));
       }
 
       NamedValue self = NamedValue(loc, "self", list_value);
@@ -1402,20 +1490,55 @@ struct to_ir {
 
     Value* dict_value =
         graph->insertNode(graph->create(prim::DictConstruct, 1))->output();
-    // Set the default type to be Dict[Str, Tensor]
+
+    // Set the default type to be Dict[str, Tensor]
     dict_value->setType(DictType::create(StringType::get(), TensorType::get()));
-    bool type_set = false;
-    if (type_hint) {
-      if (!type_hint->cast<DictType>()) {
-        throw ErrorReport(loc)
-            << "Expected Dict type annotation for dict comprehension"
-               ", found "
-            << type_hint->repr_str();
+
+    TypePtr refined_type_hint = nullptr;
+    TypePtr annotated_union_type =
+        type_hint && type_hint->kind() == UnionType::Kind ? type_hint : nullptr;
+
+    std::vector<TypePtr> all_candidates = {};
+
+    // See notes on logic in `emitListLiteral`
+    if (refined_type_hint) {
+      // If necessary/possible, make `type_hint` a DictType
+      if (auto union_type_hint = refined_type_hint->cast<UnionType>()) {
+        std::vector<TypePtr> dict_types;
+        std::copy_if(
+            union_type_hint->containedTypes().begin(),
+            union_type_hint->containedTypes().end(),
+            std::back_inserter(dict_types),
+            [&](TypePtr type_ptr) {
+              return type_ptr->kind() == DictType::Kind;
+            });
+        if (dict_types.empty()) {
+          throw ErrorReport(dc) << "Expected an Union type annotation "
+                                << "with an inner Dict type, but got "
+                                << refined_type_hint->repr_str();
+        } else if (dict_types.size() == 1) {
+          refined_type_hint = dict_types[0];
+        } else {
+          all_candidates = std::move(dict_types);
+        }
+      } else if (
+          auto optional_type_hint = refined_type_hint->cast<OptionalType>()) {
+        refined_type_hint = optional_type_hint->getElementType();
+      }
+
+      if (all_candidates.empty()) {
+        if (refined_type_hint->kind() == DictType::Kind) {
+          dict_value->setType(refined_type_hint);
+        } else {
+          throw ErrorReport(dc) << "Expected an annotation of type "
+                                << "Dict, but got " << type_hint->repr_str();
+        }
       }
-      dict_value->setType(type_hint);
-      type_set = true;
     }
 
+    TypePtr first_generated_key_type = nullptr;
+    TypePtr first_generated_value_type = nullptr;
+
     // A dict comprehension introduces its own scope. No variable assigned
     // may leak into the rest of the graph
     Node* n =
@@ -1427,10 +1550,33 @@ struct to_ir {
       auto k = emitExpr(dc.key());
       auto v = emitExpr(dc.value());
 
-      // Make sure that any key and value types are subtypes of the
-      // annotatated key/value types
-      if (type_hint) {
-        DictTypePtr dict_type_hint = type_hint->expect<DictType>();
+      // If we didn't have a type annotation, the type of the dict would
+      // be set to `(str, Tensor)`. We don't want to unify this default
+      // type with the actual elements in the dict, so let the type
+      // begin as the first element in the dict
+      if (k->type()->kind() == UnionType::Kind) {
+        throw ErrorReport(dc)
+            << "Dicts may only contain homogeneous keys, but the type of "
+            << "the first generated key was " << k->type()->repr_str();
+      } else if (
+          first_generated_key_type && first_generated_key_type != k->type()) {
+        // Values can be heterogenous, so we only need to check that the
+        // key types are all the same
+        throw ErrorReport(dc)
+            << "Dicts may only contain homogeneous keys. Expected "
+            << "dict comprehension to generate type "
+            << first_generated_key_type->repr_str() << ", but got "
+            << k->type()->repr_str();
+      } else {
+        dict_value->setType(DictType::create(k->type(), v->type()));
+        first_generated_key_type = k->type();
+        first_generated_value_type = v->type();
+      }
+
+      // If we had any annotation OTHER THAN a Union that can hold more
+      // than one type of Dict
+      if (refined_type_hint && all_candidates.empty()) {
+        DictTypePtr dict_type_hint = refined_type_hint->expect<DictType>();
 
         std::stringstream ss;
         std::stringstream err;
@@ -1463,49 +1609,85 @@ struct to_ir {
         }
       }
 
-      // If we didn't have a type annotation, the type of the dict would
-      // be set to `(str, Tensor)`. We don't want to unify this default
-      // type with the actual elements in the dict, so let the type
-      // begin as the first element in the dict
-      if (!type_set) {
-        dict_value->setType(DictType::create(k->type(), v->type()));
-        type_set = true;
-      }
-
-      DictTypePtr dt = dict_value->type()->expect<DictType>();
-
       const TypePtr value_type_hint =
-          type_hint ? type_hint->expect<DictType>()->getKeyType() : nullptr;
+          refined_type_hint && refined_type_hint->kind() == DictType::Kind
+          ? refined_type_hint->expect<DictType>()->getValueType()
+          : nullptr;
 
-      c10::optional<TypePtr> unified = unifyTypes(
-          dt->getValueType(),
+      c10::optional<TypePtr> unified_value_type = unifyTypes(
+          first_generated_value_type,
           v->type(),
           /*default_to_union=*/true,
           value_type_hint);
 
-      // Warn the user if we inferred the type of the values to be `Any`
-      // even though the annotation was something else
-      if (dt->getValueType() != AnyType::get() && *unified == AnyType::get()) {
+      if (!type_hint && (*unified_value_type)->kind() == UnionType::Kind) {
         TORCH_WARN(
-            "Dict consists of heterogeneous types, which means",
-            " that it has been typed as `Dict[str, Any]`. To use "
-            "any of the values in the Dict, it will be "
-            "necessary to add an `assert isinstance` statement "
-            "before first use to trigger type refinement. The first ",
-            "non-matching element was typed as ",
+            "Dict values consist of heterogeneous types, which means",
+            " that they have been typed as being ",
+            (*unified_value_type)->repr_str(),
+            ". To use any of the "
+            "values in this dict, it will be necessary to add an "
+            "`assert isinstance` statement before first use to trigger "
+            "type refinement. The first non-matching element was typed",
+            " as ",
             v->type()->repr_str(),
-            ", while the elements before it "
-            "were ",
-            dt->getValueType()->repr_str(),
+            ", while the elements "
+            " before it were ",
+            first_generated_value_type->repr_str(),
             "\n",
             dc.range().str());
       }
 
-      // We only want to set `dict_value` if we don't have a type hint
-      // to allow for the case that `*unified` is a subtype of
-      // the value type given by `type_hint`
-      if (!type_hint) {
-        dict_value->setType(DictType::create(k->type(), *unified));
+      if (type_hint && !all_candidates.empty()) {
+        auto known_key_type = k->type();
+        auto known_value_type = *unified_value_type;
+
+        TypePtr candidate_key_type = nullptr;
+        TypePtr candidate_value_type = nullptr;
+        TypePtr candidate = nullptr;
+
+        for (const auto& current_candidate : all_candidates) {
+          auto current_key_type =
+              current_candidate->expect<DictType>()->getKeyType();
+          auto current_value_type =
+              current_candidate->expect<DictType>()->getValueType();
+          if (known_key_type->isSubtypeOf(current_key_type) &&
+              known_value_type->isSubtypeOf(current_value_type)) {
+            if (!candidate ||
+                (candidate_key_type->isSubtypeOf(current_key_type) &&
+                 candidate_value_type->isSubtypeOf(current_value_type))) {
+              candidate_key_type = current_key_type;
+              candidate_value_type = current_value_type;
+              candidate = current_candidate;
+            }
+          }
+        }
+
+        if (!candidate) {
+          std::stringstream vector_repr;
+          for (size_t i = 0; i < all_candidates.size(); ++i) {
+            if (i > 0 && all_candidates.size() > 2) {
+              vector_repr << ", ";
+            }
+            if (i != 0 && i == all_candidates.size() - 1) {
+              vector_repr << " or ";
+            }
+            vector_repr << all_candidates[i]->repr_str();
+          }
+          throw ErrorReport(dc)
+              << "Union type annotation `" << type_hint->repr_str()
+              << "` can hold " << vector_repr.str() << ", but none of "
+              << "those list types can hold the types of the given dict"
+              << " elements, which were unified to " << candidate->repr_str();
+        } else {
+          refined_type_hint = candidate;
+        }
+      }
+
+      if (!refined_type_hint) {
+        dict_value->setType(DictType::create(k->type(), *unified_value_type));
+      } else {
+        dict_value->setType(type_hint);
       }
 
       NamedValue self = NamedValue(loc, "self", dict_value);
@@ -1516,6 +1698,14 @@ struct to_ir {
     };
     emitFor(targets_list, itrs, loc, emit_body);
     popFrame();
+
+    if (annotated_union_type) {
+      Node* n =
+          graph->insertNode(graph->create(prim::unchecked_cast, {dict_value}));
+      n->output()->setType(std::move(annotated_union_type));
+      dict_value = n->output();
+    }
+
     return dict_value;
   }
 
@@ -3797,22 +3987,80 @@ struct to_ir {
             ->call(tree->range(), method, named_values, {}, 0));
   }
 
-  Value* emitListLiteral(ListLiteral ll, TypePtr type_hint) {
+  Value* emitListLiteral(ListLiteral ll, const TypePtr& type_hint) {
     auto values = getValues(ll.inputs(), /*maybe_unpack=*/true);
 
     // Determine the element type of the list. If we have a type hint
     // of `List[T]`, use `T`. If the list is non-empty, find the
     // greatest common supertype of all the list elements (defaulting to
     // `Any` as a catch-all supertype). Assume `[]` is `List[Tensor]`
-    TypePtr elem_type = TensorType::get();
+    TypePtr inferred_elem_type = TensorType::get();
+
+    TypePtr refined_type_hint = type_hint;
+
+    // If `type_hint` is a Union, we're going to change it to be
+    // the type of the rhs List, so we need to store the original
+    // UnionType for later. `nullptr` means that we don't need to emit
+    // an `unchecked_cast` node (either because we don't have a type
+    // hint or because the type hint wasn't a Union)
+    TypePtr annotated_union_type =
+        refined_type_hint && refined_type_hint->kind() == UnionType::Kind
+        ? refined_type_hint
+        : nullptr;
+
+    // This is used for error reporting in the case that we have a Union
+    // annotation that contains multiple Lists. We need to determine the
+    // actual type based on the rhs values
+    std::vector<TypePtr> all_candidates = {};
+
+    // Basic `type_hint` check here for better error reporting. We also
+    // see if we can narrow down the actual type if the given type hint
+    // is a Union
+    if (refined_type_hint) {
+      // If necessary/possible, make `type_hint` a ListType
+      if (auto union_type_hint = refined_type_hint->cast<UnionType>()) {
+        std::vector<TypePtr> list_types;
+        std::copy_if(
+            union_type_hint->containedTypes().begin(),
+            union_type_hint->containedTypes().end(),
+            std::back_inserter(list_types),
+            [&](TypePtr type_ptr) {
+              return type_ptr->kind() == ListType::Kind;
+            });
+        if (list_types.empty()) {
+          throw ErrorReport(ll) << "Expected an Union type annotation "
+                                << "with an inner List type, but got "
+                                << refined_type_hint->repr_str();
+        } else if (list_types.size() > 1) {
+          if (values.empty()) {
+            throw ErrorReport(ll)
+                << "Cannot assign an empty list to a "
+                << "variable annotated to be type "
+                << refined_type_hint->repr_str()
+                << " because there are multiple possible List "
+                << "type candidates in the Union annotation";
+          } else {
+            all_candidates = std::move(list_types);
+          }
+        } else {
+          refined_type_hint = list_types[0];
+        }
+      } else if (
+          auto optional_type_hint = refined_type_hint->cast<OptionalType>()) {
+        refined_type_hint = optional_type_hint->getElementType();
+      }
 
-    if (type_hint) {
-      if (type_hint->kind() == TypeKind::ListType) {
-        elem_type = type_hint->expectRef<ListType>().getElementType();
-      } else {
-        // If the type hint was not `List[T]`, throw an error
-        throw ErrorReport(ll) << "Expected a List type hint but instead got "
-                              << type_hint->repr_str();
+      // If we had any annotation OTHER THAN a Union that can hold more
+      // than one type of List
+      if (all_candidates.empty()) {
+        if (refined_type_hint->kind() == ListType::Kind) {
+          auto list_type_hint = refined_type_hint->cast<ListType>();
+          inferred_elem_type = list_type_hint->getElementType();
+        } else {
+          throw ErrorReport(ll)
+              << "Expected an annotation of type "
+              << "List, but got " << refined_type_hint->repr_str();
+        }
       }
     }
 
@@ -3821,45 +4069,100 @@ struct to_ir {
 
       std::stringstream nowhere; // never used
 
-      const TypePtr element_type_hint =
-          type_hint ? type_hint->expect<ListType>()->getElementType() : nullptr;
+      // We don't want to use `elem_type` as the final argument to
+      // `unifyTypeList` because there's a chance that `elem_type` is
+      // the Tensor default
+      const auto elem_type_hint =
+          refined_type_hint && refined_type_hint->kind() == ListType::Kind
+          ? refined_type_hint->cast<ListType>()->getElementType()
+          : nullptr;
 
-      c10::optional<TypePtr> unified = unifyTypeList(
-          types, nowhere, /*default_to_union=*/true, element_type_hint);
+      c10::optional<TypePtr> unified_elem_type = unifyTypeList(
+          types, nowhere, /*default_to_union=*/true, elem_type_hint);
 
-      if (!type_hint && *unified == AnyType::get()) {
+      if (!refined_type_hint &&
+          (*unified_elem_type)->kind() == UnionType::Kind) {
         TORCH_WARN(
             "List consists of heterogeneous types, which means",
-            " that it has been typed as `List[Any]`. To use "
-            "any of the values in the List, it will be "
-            "necessary to add an `assert isinstance` statement "
-            "before first use to trigger type refinement. \n",
+            " that it has been typed as containing ",
+            (*unified_elem_type)->repr_str(),
+            ". To use any of the "
+            "values in this List, it will be necessary to add an "
+            "`assert isinstance` statement before first use to trigger "
+            "type refinement.\n",
             ll.range().str());
       }
 
-      if (type_hint && !(*unified)->isSubtypeOf(elem_type)) {
+      if (all_candidates.empty() && refined_type_hint &&
+          !(*unified_elem_type)->isSubtypeOf(inferred_elem_type)) {
         throw ErrorReport(ll)
-            << "List type annotation `" << type_hint->repr_str()
+            << "List type annotation `" << refined_type_hint->repr_str()
             << "` did not match the types of the given list elements,"
-            << " which were unified to " << (*unified)->repr_str();
+            << " which were unified to " << (*unified_elem_type)->repr_str();
+      }
+
+      if (!all_candidates.empty()) {
+        TypePtr greatest_elem_type = nullptr;
+        std::for_each(
+            all_candidates.begin(),
+            all_candidates.end(),
+            [&](TypePtr candidate) {
+              auto candidate_elem_type =
+                  candidate->expect<ListType>()->getElementType();
+              if ((*unified_elem_type)->isSubtypeOf(candidate_elem_type)) {
+                if (!greatest_elem_type) {
+                  greatest_elem_type = candidate_elem_type;
+                } else {
+                  greatest_elem_type =
+                      *(unifyTypes(greatest_elem_type, candidate_elem_type));
+                }
+              }
+            });
+        if (!greatest_elem_type) {
+          std::stringstream vector_repr;
+          for (size_t i = 0; i < all_candidates.size(); ++i) {
+            if (i > 0 && all_candidates.size() > 2) {
+              vector_repr << ", ";
+            }
+            if (i != 0 && i == all_candidates.size() - 1) {
+              vector_repr << " or ";
+            }
+            vector_repr << all_candidates[i]->repr_str();
+          }
+          throw ErrorReport(ll)
+              << "Union type annotation `" << type_hint->repr_str()
+              << "` can hold " << vector_repr.str() << ", but none of "
+              << "those list types can hold the types of the given list"
+              << " elements, which were unified to "
+              << (*unified_elem_type)->repr_str();
+        } else {
+          refined_type_hint = ListType::create(greatest_elem_type);
+          inferred_elem_type =
+              refined_type_hint->expect<ListType>()->getElementType();
+        }
       }
 
       // We only want to set `elem_type` if we don't have a type hint
       // to allow for the case that `*unified` is a subtype of
       // `type_hint`
-      if (!type_hint) {
-        elem_type = *unified;
+      if (!refined_type_hint) {
+        inferred_elem_type = *unified_elem_type;
       }
     }
 
-    Value* result =
-        graph->insertNode(graph->createList(elem_type, values))->output();
-    return result;
+    Node* result =
+        graph->insertNode(graph->createList(inferred_elem_type, values));
+    if (annotated_union_type) {
+      Node* n = graph->insertNode(
+          graph->create(prim::unchecked_cast, {result->output()}));
+      n->output()->setType(std::move(annotated_union_type));
+      result = n;
+    }
+
+    return result->output();
   }
 
-  Value* emitSimpleExpr(
-      const TreeRef& tree,
-      const TypePtr& type_hint = nullptr) {
+  Value* emitSimpleExpr(const TreeRef& tree, TypePtr type_hint = nullptr) {
     switch (tree->kind()) {
       case TK_FLOOR_DIV:
       case '@': {
@@ -3961,10 +4264,70 @@ struct to_ir {
         TypePtr key_type = nullptr;
         TypePtr value_type = nullptr;
 
-        if (type_hint && type_hint->kind() == TypeKind::DictType) {
-          auto dict_type = type_hint->expect<DictType>();
-          key_type = dict_type->getKeyType();
-          value_type = dict_type->getValueType();
+        // See notes on logic in `emitListLiteral`
+
+        TypePtr refined_type_hint = type_hint;
+
+        TypePtr annotated_union_type =
+            refined_type_hint && refined_type_hint->kind() == UnionType::Kind
+            ? refined_type_hint
+            : nullptr;
+
+        std::vector<TypePtr> all_candidates = {};
+
+        if (refined_type_hint) {
+          if (auto union_type_hint = refined_type_hint->cast<UnionType>()) {
+            std::vector<TypePtr> dict_types;
+            std::copy_if(
+                union_type_hint->containedTypes().begin(),
+                union_type_hint->containedTypes().end(),
+                std::back_inserter(dict_types),
+                [&](TypePtr type_ptr) {
+                  return type_ptr->kind() == DictType::Kind;
+                });
+            if (dict_types.empty()) {
+              throw ErrorReport(dl) << "Expected an Union type annotation "
+                                    << "with an inner Dict type, but got "
+                                    << type_hint->repr_str();
+            } else if (dict_types.size() > 1) {
+              if (values.empty()) {
+                throw ErrorReport(dl)
+                    << "Cannot assign an empty dict to a "
+                    << "variable annotated to be type " << type_hint->repr_str()
+                    << " because there are multiple possible Dict "
+                    << "type candidates in the Union annotation";
+              } else {
+                all_candidates = std::move(dict_types);
+              }
+            } else {
+              refined_type_hint = dict_types[0];
+            }
+          } else if (
+              auto optional_type_hint =
+                  refined_type_hint->cast<OptionalType>()) {
+            refined_type_hint = optional_type_hint->getElementType();
+          }
+
+          if (all_candidates.empty()) {
+            if (auto dict_type_hint = refined_type_hint->cast<DictType>()) {
+              auto dict_type = refined_type_hint->expect<DictType>();
+              key_type = dict_type->getKeyType();
+              value_type = dict_type->getValueType();
+            } else if (refined_type_hint == AnyType::get()) {
+              // @ansley: Clean up later
+              if (keys.empty()) {
+                key_type = StringType::get();
+                value_type = TensorType::get();
+              } else {
+                key_type = keys.at(0)->type();
+                value_type = values.at(0)->type();
+              }
+            } else {
+              throw ErrorReport(dl)
+                  << "Expected an annotation of type "
+                  << "Dict, but got " << type_hint->repr_str();
+            }
+          }
         } else if (keys.empty()) {
           key_type = StringType::get();
           value_type = TensorType::get();
@@ -3972,34 +4335,96 @@ struct to_ir {
           key_type = keys.at(0)->type();
           value_type = values.at(0)->type();
         }
-        AT_ASSERT(key_type != nullptr && value_type != nullptr);
-
-        for (const auto i : c10::irange(keys.size())) {
-          std::stringstream ss;
-          if (!keys[i]->type()->isSubtypeOfExt(key_type, &ss)) {
-            throw ErrorReport(key_trees[i])
-                << "Dict keys must contain "
-                << "only a single type. Expected: " << key_type->repr_str()
-                << " but found " << keys[i]->type()->repr_str() << " instead.\n"
-                << ss.str();
+
+        AT_ASSERT(
+            !all_candidates.empty() ||
+            (key_type != nullptr && value_type != nullptr));
+
+        if (!keys.empty()) {
+          auto key_types = fmap(keys, [](const Value* v) { return v->type(); });
+
+          std::stringstream nowhere; // never used
+
+          TypePtr first_key_type = key_types[0];
+
+          for (const auto i : c10::irange(keys.size())) {
+            std::stringstream ss;
+            if (!keys[i]->type()->isSubtypeOfExt(first_key_type, &ss) &&
+                !first_key_type->isSubtypeOfExt(keys[i]->type(), &ss)) {
+              throw ErrorReport(key_trees[i])
+                  << "Dict keys must contain "
+                  << "only a single type. Expected: "
+                  << first_key_type->repr_str() << " but found "
+                  << keys[i]->type()->repr_str() << " instead.\n"
+                  << ss.str();
+            }
           }
         }
 
         if (!values.empty()) {
-          auto types = fmap(values, [](const Value* v) { return v->type(); });
+          auto value_types =
+              fmap(values, [](const Value* v) { return v->type(); });
 
           std::stringstream nowhere; // never used
 
-          const TypePtr value_type_hint =
-              type_hint ? type_hint->expect<DictType>()->getKeyType() : nullptr;
-
-          c10::optional<TypePtr> unified = unifyTypeList(
-              types,
+          c10::optional<TypePtr> unified_value_type = unifyTypeList(
+              value_types,
               /*why_not=*/nowhere,
               /*default_to_union=*/true,
-              value_type_hint);
+              value_type);
+
+          if (refined_type_hint && !all_candidates.empty()) {
+            auto known_key_type = keys[0]->type();
+            auto known_value_type = *unified_value_type;
+
+            TypePtr candidate_key_type = nullptr;
+            TypePtr candidate_value_type = nullptr;
+            TypePtr candidate = nullptr;
+
+            for (const auto& current_candidate : all_candidates) {
+              auto current_key_type =
+                  current_candidate->expect<DictType>()->getKeyType();
+              auto current_value_type =
+                  current_candidate->expect<DictType>()->getValueType();
+              if (known_key_type->isSubtypeOf(current_key_type) &&
+                  known_value_type->isSubtypeOf(current_value_type)) {
+                if (!candidate ||
+                    (candidate_key_type->isSubtypeOf(current_key_type) &&
+                     candidate_value_type->isSubtypeOf(current_value_type))) {
+                  candidate_key_type = current_key_type;
+                  candidate_value_type = current_value_type;
+                  candidate = current_candidate;
+                }
+              }
+            }
 
-          if (!type_hint && *unified == AnyType::get()) {
+            if (!candidate) {
+              std::stringstream vector_repr;
+              for (size_t i = 0; i < all_candidates.size(); ++i) {
+                if (i > 0 && all_candidates.size() > 2) {
+                  vector_repr << ", ";
+                }
+                if (i != 0 && i == all_candidates.size() - 1) {
+                  vector_repr << " or ";
+                }
+                vector_repr << all_candidates[i]->repr_str();
+              }
+              throw ErrorReport(dl)
+                  << "Union type annotation `" << refined_type_hint->repr_str()
+                  << "` can hold " << vector_repr.str() << ", but none of "
+                  << "those types can hold the types of the given dict"
+                  << " elements, which were unified to Dict["
+                  << known_key_type->repr_str() << ", "
+                  << known_value_type->repr_str() << "]";
+            } else {
+              key_type = candidate_key_type;
+              value_type = candidate_value_type;
+              refined_type_hint = candidate;
+            }
+          }
+
+          if (!refined_type_hint &&
+              (*unified_value_type)->kind() == UnionType::Kind) {
             TORCH_WARN(
                 "Dict values consist of heterogeneous types, which "
                 "means that they have been typed as `Any`. To use "
@@ -4009,20 +4434,20 @@ struct to_ir {
                 dl.range().str());
           }
 
-          if (type_hint) {
+          if (refined_type_hint) {
             TypePtr value_type_hint =
-                type_hint->expect<DictType>()->getValueType();
-            for (const auto i : c10::irange(types.size())) {
+                refined_type_hint->expect<DictType>()->getValueType();
+            for (const auto i : c10::irange(value_types.size())) {
               TORCH_CHECK(
-                  types[i]->isSubtypeOf(value_type_hint),
+                  value_types[i]->isSubtypeOf(value_type_hint),
                   "Type "
-                  "hint for dict was",
-                  type_hint->repr_str(),
-                  "but the value ",
+                  "hint for dict was ",
+                  refined_type_hint->repr_str(),
+                  "but the value ",
                   "at index ",
                   i,
                   " has type ",
-                  types[i]->repr_str(),
+                  value_types[i]->repr_str(),
                   ", which is not a valid"
                   " subtype of ",
                   value_type_hint->repr_str());
@@ -4032,14 +4457,21 @@ struct to_ir {
           // We only want to set `value_type` if we don't have a type
           // hint to allow for the case that `*unified` is a subtype of
           // the value type given by `type_hint`
-          if (!type_hint) {
-            value_type = *unified;
+          if (!refined_type_hint) {
+            value_type = *unified_value_type;
           }
         }
 
-        return graph
-            ->insertNode(graph->createDict(key_type, value_type, keys, values))
-            ->output();
+        Node* result = graph->insertNode(
+            graph->createDict(key_type, value_type, keys, values));
+        if (annotated_union_type) {
+          Node* n = graph->insertNode(
+              graph->create(prim::unchecked_cast, {result->output()}));
+          n->output()->setType(std::move(annotated_union_type));
+          result = n;
+        }
+
+        return result->output();
       } break;
       case TK_LIST_COMP: {
         auto lc = ListComp(tree);
index 03afbdd..d4c219d 100644 (file)
@@ -1129,16 +1129,6 @@ void AliasDb::makePointerTo(const Value* from, const Value* to) {
   // immutable. `Any` is mutable but can point to an immutable type
   // through refinement
   if (isMutableTypeInternal(from) != isMutableTypeInternal(to)) {
-    bool expected_kind = false;
-    for (auto kind : {from->type()->kind(), to->type()->kind()}) {
-      expected_kind = expected_kind ||
-          (kind == TypeKind::OptionalType || kind == TypeKind::FutureType ||
-           kind == TypeKind::TupleType ||
-           kind == TypeKind::UnionType) // immutable type containers
-          || kind == TypeKind::AnyType;
-    }
-    TORCH_INTERNAL_ASSERT(
-        expected_kind, from->type()->str(), to->type()->str());
     return;
   }
   // both immutable
index 4c521a8..29dac32 100644 (file)
@@ -408,19 +408,12 @@ class JitTestCase(JitCommonTestCase):
                     cu = torch.jit.CompilationUnit(source, _frames_up=frames_up)
                     string_frontend = getattr(cu, script.__name__)
 
-                with self.assertRaisesRegex(exception, regex):
-                    string_frontend(*inputs)
-                # optimized run
                 string_frontend(*inputs)
 
             # Python AST frontend
             if not isinstance(script, str):
                 with self.assertRaisesRegex(exception, regex):
                     ge = torch.jit.script(python_fn)
-                    # profiling run
-                    with self.assertRaisesRegex(exception, regex):
-                        ge(*inputs)
-                    # optimized run
                     ge(*inputs)
 
     def checkBailouts(self, model, inputs, expected):