From: Ansley Ussery Date: Fri, 10 Sep 2021 23:18:33 +0000 (-0700) Subject: Preserve types during empty container assignment (#58911) X-Git-Tag: accepted/tizen/8.0/unified/20231005.095509~292 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=c60075d4b59ca6518c2181d3a97983e900d2a8b7;p=platform%2Fupstream%2Fpytorch.git Preserve types during empty container assignment (#58911) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/58911 Stack from [ghstack](https://github.com/ezyang/ghstack): * __->__ #58911 Test Plan: Imported from OSS Reviewed By: gmagogsfm Differential Revision: D30785623 Pulled By: ansley fbshipit-source-id: 4e05d6369318974290fea02ad2bc148293c25090 --- diff --git a/test/jit/test_list_dict.py b/test/jit/test_list_dict.py index 10f5e87..ab514ad 100644 --- a/test/jit/test_list_dict.py +++ b/test/jit/test_list_dict.py @@ -244,12 +244,9 @@ class TestList(JitTestCase): self.checkScript(fn, ()) def test_dict_keyword_with_mismatched_annotations(self): - err_msg = r"Dict type annotation `Dict\[int, str\]` did not "\ - "match the types of the actual dict items" - err_msg = r"Dict type annotation `Dict\[int, str\]` did not "\ - "match the type of an actual key type `str`" - highlight_msg = "dict([(\"foo\", 1), (\"bar\", 2), (\"baz\", 3" - with self.assertRaisesRegexWithHighlight(RuntimeError, err_msg, highlight_msg): + err_msg = r"is annotated with type Dict\[int, str\] but is " \ + r"being assigned to a value of type Dict\[str, int\]" + with self.assertRaisesRegex(RuntimeError, err_msg): @torch.jit.script def fn(): x: Dict[int, str] = dict([("foo", 1), ("bar", 2), ("baz", 3)]) # noqa: C406 @@ -1328,7 +1325,7 @@ class TestList(JitTestCase): x = torch._C.ListType(None) def test_list_unification_hint(self): - with self.assertRaisesRegex(RuntimeError, "Expected a List type hint"): + with self.assertRaisesRegex(RuntimeError, "Expected an annotation of type List"): @torch.jit.script def x(): b : int = [2, 3] diff --git a/test/jit/test_typing.py b/test/jit/test_typing.py index 125197c..6b3def0 100644 --- a/test/jit/test_typing.py +++ b/test/jit/test_typing.py @@ -168,11 +168,12 @@ class TestTyping(JitTestCase): l1 = [1, 2, "foo", 3] l2 = ["foo", "bar", "baz", "qux"] d: Dict[int, str] = {k : v for k, v in zip(l1, l2)} - return l + return d - with self.assertRaisesRegex(RuntimeError, "Dict type annotation" - r" `Dict\[int, str\]` did not match" - " the type of an actual key type"): + with self.assertRaisesRegex(RuntimeError, "Dicts may only " + "contain homogeneous keys, but the " + "type of the first generated key " + r"was Union\[int, str\]"): torch.jit.script(fn) def test_dict_type_refinement_annotation_value_mismatch(self): @@ -180,12 +181,12 @@ class TestTyping(JitTestCase): l1 = ["foo", "bar", "baz", "qux"] l2 = [1, 2, "foo", 3] d: Dict[str, int] = {k : v for k, v in zip(l1, l2)} - return l + return d - with self.assertRaisesRegex(RuntimeError, "Dict type annotation" - r" `Dict\[str, int\]` did not match" - " the type of an actual value " - "type"): + with self.assertRaisesRegex(RuntimeError, "annotated with type " + r"Dict\[str, int\] but is being " + "assigned to a value of type " + r"Dict\[str, Union\[int, str\]\]"): torch.jit.script(fn) def test_dict_invalid_annotations(self): diff --git a/test/jit/test_union.py b/test/jit/test_union.py index df909a6..fb53d53 100644 --- a/test/jit/test_union.py +++ b/test/jit/test_union.py @@ -5,6 +5,7 @@ import sys import torch from torch.testing import FileCheck from enum import Enum +from textwrap import dedent from typing import Dict, List, Optional, Tuple, Union # Make the helper files in test/ importable @@ -655,3 +656,272 @@ class TestUnion(JitTestCase): self.checkScript(fn, (1,)) self.checkScript(fn, (8,)) + + def _assert_passes(self, template: str, ann: str, lhs: str): + code = template.format(ann=ann, lhs=lhs) + self.checkScript(code, (), name="fn") + + def _assert_raises(self, template: str, ann: str, lhs: str, msg: str): + code = template.format(ann=ann, lhs=lhs) + with self.assertRaisesRegex(RuntimeError, msg): + cu = torch.jit.CompilationUnit(code, _frames_up=1) + string_frontend = getattr(cu, "fn") # noqa: B009 + + def test_union_with_list_assignment(self): + template = dedent(''' + def fn(): + x: {ann} = {lhs} + if torch.jit.isinstance(x, List[torch.Tensor]): + x.append(torch.tensor(3)) + return x + ''') + + lhs = {"list_literal_empty" : "[]", + + "list_literal_of_tensor" : "[torch.arange(3), torch.arange(5)]", + + "list_literal_of_str" : "[\"foo\", \"bar\", \"baz\"]", + + "list_literal_of_mixed" : "[torch.arange(5), 1]", + + "list_comprehension_of_tensor" : + "[torch.add(x, 1) for x in [torch.arange(3), torch.arange(5)]]", + + "list_comprehension_of_str" : + "[x + \"!\" for x in [\"foo\", \"bar\", \"baz\"]]", + + "list_comprehension_of_mixed" : + "[torch.add(1, x) for x in [torch.arange(5), 1]]"} + + """ + Union[List[str], List[torch.Tensor]] + """ + self._assert_raises(template, + "Union[List[str], List[torch.Tensor]]", + lhs["list_literal_empty"], + "there are multiple possible List type " + "candidates in the Union annotation") + + self._assert_passes(template, + "Union[List[str], List[torch.Tensor]]", + lhs["list_literal_of_tensor"]) + + self._assert_passes(template, + "Union[List[str], List[torch.Tensor]]", + lhs["list_literal_of_str"]) + + self._assert_raises(template, + "Union[List[str], List[torch.Tensor]]", + lhs["list_literal_of_mixed"], + "none of those list types can hold the " + "types of the given list elements") + + self._assert_passes(template, + "Union[List[str], List[torch.Tensor]]", + lhs["list_comprehension_of_tensor"]) + + self._assert_passes(template, + "Union[List[str], List[torch.Tensor]]", + lhs["list_comprehension_of_str"]) + + # TODO: Support mixed list comprehensions + self._assert_raises(template, + "Union[List[str], List[torch.Tensor]]", + lhs["list_comprehension_of_mixed"], + "Arguments for call are not valid") + + """ + Union[int, torch.Tensor] + """ + self._assert_raises(template, + "Union[int, torch.Tensor]", + lhs["list_literal_empty"], + "Expected an Union type annotation with an " + "inner List type") + + self._assert_raises(template, "Union[int, torch.Tensor]", + lhs["list_literal_of_tensor"], + "Expected an Union type annotation with an " + "inner List type") + + self._assert_raises(template, "Union[int, torch.Tensor]", + lhs["list_comprehension_of_tensor"], + "Expected an Union type annotation with an " + "inner List type") + + """ + Union[List[torch.Tensor], int] + """ + self._assert_passes(template, + "Union[List[torch.Tensor], int]", + lhs["list_literal_empty"]) + + self._assert_passes(template, + "Union[List[torch.Tensor], int]", + lhs["list_literal_of_tensor"]) + + self._assert_raises(template, "Union[List[torch.Tensor], int]", + lhs["list_literal_of_str"], + r"List type annotation `List\[Tensor\]` did " + "not match the types of the given list " + "elements") + + self._assert_raises(template, "Union[List[torch.Tensor], int]", + lhs["list_literal_of_mixed"], + r"List type annotation `List\[Tensor\]` did " + "not match the types of the given list " + "elements") + + self._assert_passes(template, + "Union[List[torch.Tensor], int]", + lhs["list_comprehension_of_tensor"]) + + self._assert_raises(template, + "Union[List[torch.Tensor], int]", + lhs["list_comprehension_of_str"], + r"List type annotation `List\[Tensor\]` did " + "not match the types of the given list " + "elements") + + # TODO: Support mixed list comprehensions + self._assert_raises(template, + "Union[List[torch.Tensor], int]", + lhs["list_comprehension_of_mixed"], + "Arguments for call are not valid") + + def test_union_with_dict_assignment(self): + template = dedent(''' + def fn(): + x: {ann} = {lhs} + if torch.jit.isinstance(x, Dict[str, torch.Tensor]): + x["foo"] = torch.tensor(3) + return x + ''') + + lhs = {"dict_literal_empty" : "{}", + + "dict_literal_of_str_tensor" : + "{\"foo\" : torch.arange(3), \"bar\" : torch.arange(5)}", + + "dict_literal_of_str_int" : + "{\"foo\" : 1, \"bar\" : 2}", + + "dict_literal_of_mixed" : + "{\"foo\" : torch.arange(3), \"bar\" : 2}", + + "dict_comprehension_of_str_tensor" : + "{x : torch.add(y, 1) for x, y in \ + zip([\"foo\", \"bar\"], [torch.arange(3), torch.arange(5)])}", + + "dict_comprehension_of_str_int" : + "{x : torch.add(y, 1) for x, y in \ + zip([\"foo\", \"bar\"], [1, 2]}", + + "dict_comprehension_of_mixed" : + "{x : torch.add(y, 1) for x, y in \ + zip([\"foo\", \"bar\"], [torch.arange(3), 2])}", + + "dict_keyword" : + "dict(foo=torch.arange(3), baz=torch.arange(5))"} + + """ + Union[Dict[str, torch.Tensor], Dict[str, int]] + """ + self._assert_raises(template, + "Union[List[str], List[torch.Tensor]]", + lhs["dict_literal_empty"], + "Expected an Union type annotation with an " + "inner Dict type") + + self._assert_passes(template, + "Union[Dict[str, torch.Tensor], Dict[str, int]]", + lhs["dict_literal_of_str_tensor"]) + + self._assert_passes(template, + "Union[Dict[str, torch.Tensor], Dict[str, int]]", + lhs["dict_literal_of_str_int"]) + + self._assert_raises(template, "Union[Dict[str, torch.Tensor], Dict[str, int]]", + lhs["dict_literal_of_mixed"], + "none of those types can hold the types " + "of the given dict elements") + + # TODO: String frontend does not support tuple unpacking + # https://github.com/pytorch/pytorch/issues/64096 + # self._assert_passes(template, "Union[Dict[str, torch.Tensor], Dict[str, int]]", + # lhs["dict_comprehension_of_str_tensor"]) + + # self._assert_passes(template, "Union[Dict[str, torch.Tensor], Dict[str, int]]", + # lhs["dict_comprehension_of_str_int"]) + + # self._assert_raises(template, "Union[Dict[str, torch.Tensor], Dict[str, int]]", + # lhs["dict_comprehension_of_mixed"], + # "foobar") + + self._assert_passes(template, + "Union[Dict[str, torch.Tensor], Dict[str, int]]", + lhs["dict_keyword"]) + + """ + Union[int, torch.Tensor] + """ + self._assert_raises(template, + "Union[int, torch.Tensor]", + lhs["dict_literal_empty"], + "Expected an Union type annotation with " + "an inner Dict type") + + self._assert_raises(template, + "Union[int, torch.Tensor]", + lhs["dict_literal_of_str_tensor"], + "Expected an Union type annotation with " + "an inner Dict type") + + # See above--string frontend does not support tuple unpacking + # self._assert_raises(template, "Union[int, torch.Tensor]", + # lhs["dict_comprehension_of_tensor"], + # "foobar") + + """ + Union[Dict[str, torch.Tensor], int] + """ + self._assert_passes(template, + "Union[Dict[str, torch.Tensor], int]", + lhs["dict_literal_empty"]) + + self._assert_passes(template, + "Union[Dict[str, torch.Tensor], int]", + lhs["dict_literal_of_str_tensor"]) + + self._assert_raises(template, + "Union[Dict[str, torch.Tensor], int]", + lhs["dict_literal_of_str_int"], + r"Type hint for dict was Dict\[str, Tensor\]" + ", but the value at index 0 has type int, " + "which is not a valid subtype of Tensor") + + self._assert_raises(template, + "Union[Dict[str, torch.Tensor], int]", + lhs["dict_literal_of_mixed"], + r"Type hint for dict was Dict\[str, Tensor\]" + ", but the value at index 1 has type int, " + "which is not a valid subtype of Tensor") + + # See above--string frontend does not support tuple unpacking + # self._assert_passes(template, + # "Union[Dict[str, torch.Tensor], int]", + # lhs["dict_comprehension_of_str_tensor"]) + + # self._assert_raises(template, + # "Union[Dict[str, torch.Tensor], int]", + # lhs["dict_comprehension_of_str_int"], + # "foobar") + + # self._assert_raises(template, + # "Union[Dict[str, torch.Tensor], int]", + # lhs["dict_comprehension_of_mixed"], + # "foobar") + + self._assert_passes(template, + "Union[Dict[str, torch.Tensor], int]", + lhs["dict_keyword"]) diff --git a/test/test_jit.py b/test/test_jit.py index eb61fdb..a6589ad 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -11527,7 +11527,8 @@ dedent """ out = torch.jit.annotate(int, [x for x in [1, 2, 3]]) # noqa: C416 return out - with self.assertRaisesRegex(Exception, "Expected list type annotation"): + with self.assertRaisesRegex(Exception, "Expected an annotation" + " of type List"): torch.jit.script(bad_type_annotation) def test_list_comprehension_variable_write(self): diff --git a/test/test_ops.py b/test/test_ops.py index 90e52bb..4ee0cbc 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1,6 +1,5 @@ from collections.abc import Sequence from functools import partial, wraps -import unittest import warnings import torch @@ -685,7 +684,6 @@ class TestJit(JitCommonTestCase): # and runtimes (eager, traced, scripted). # TODO WARNING: inplace x {traced, scripted} not currently tested @_variant_ops(op_db) - @unittest.skipIf(True, "Temporarily skipping while landing Union PR stack") def test_variant_consistency_jit(self, device, dtype, op): _requires_grad = op.supports_autograd and (dtype.is_floating_point or op.supports_complex_autograd(torch.device(device).type)) diff --git a/torch/csrc/jit/frontend/ir_emitter.cpp b/torch/csrc/jit/frontend/ir_emitter.cpp index dd29f1e..4c066c2 100644 --- a/torch/csrc/jit/frontend/ir_emitter.cpp +++ b/torch/csrc/jit/frontend/ir_emitter.cpp @@ -660,6 +660,7 @@ struct to_ir { // to SSA while part of their original graph, and closures are ready to // be inlined into forked closures ConvertToSSA(graph); + // convert loops with an iter and body condition specified to // python-recognize while loops. we do this so they can be exported, // and run the pass early to avoid jitter. Like conversion to SSA, @@ -1318,25 +1319,59 @@ struct to_ir { const auto loc = lc.range(); const auto targets_list = List::create(lc.range(), {lc.target()}); const auto itrs = List::create(lc.range(), {lc.iter()}); + // If there is no type hint, and this is emitted over an iterable that is // unrolled and of length 0, then we emit a List of tensors Value* list_value = graph->insertNode(graph->create(prim::ListConstruct, 1)) ->output() ->setType(ListType::ofTensors()); - bool type_set = false; - if (type_hint) { - if (!type_hint->cast()) { - throw ErrorReport(loc) - << "Expected list type annotation for list comprehension" - ", found " - << type_hint->repr_str(); + + // See notes on logic in `emitListLiteral` + + TypePtr refined_type_hint = type_hint; + TypePtr annotated_union_type = + type_hint && type_hint->kind() == UnionType::Kind ? type_hint : nullptr; + + std::vector all_candidates = {}; + + if (refined_type_hint) { + // If necessary/possible, refine `refined_type_hint` to a ListType + if (auto union_type_hint = refined_type_hint->cast()) { + std::vector list_types; + std::copy_if( + union_type_hint->containedTypes().begin(), + union_type_hint->containedTypes().end(), + std::back_inserter(list_types), + [&](TypePtr type_ptr) { + return type_ptr->kind() == ListType::Kind; + }); + if (list_types.empty()) { + throw ErrorReport(lc) << "Expected an Union type annotation " + << "with an inner List type, but got " + << refined_type_hint->repr_str(); + } else if (list_types.size() == 1) { + refined_type_hint = list_types[0]; + } else { + all_candidates = std::move(list_types); + } + } else if ( + auto optional_type_hint = refined_type_hint->cast()) { + refined_type_hint = optional_type_hint->getElementType(); + } + + if (all_candidates.empty()) { + if (refined_type_hint->kind() == ListType::Kind) { + list_value->setType(refined_type_hint); + } else { + throw ErrorReport(lc) << "Expected an annotation of type " + << "List, but got " << type_hint->repr_str(); + } } - list_value->setType(type_hint); - type_set = true; } - // comprehension introduces its own scope. no variable assigned - // leaks into the rest of the graph + bool seen_first_elem = false; + + // A list comprehension introduces its own scope Node* n = graph->insertNode(create(prim::ComprehensionScope, lc.range(), 0)); auto* comprehension_block = n->addBlock(); @@ -1349,41 +1384,94 @@ struct to_ir { // be set to `Tensor`. We don't want to unify this default type // with the actual elements in the list, so let the type begin as // the first element in the list - if (!type_set) { + if (!seen_first_elem) { list_value->setType(ListType::create(out->type())); - type_set = true; + seen_first_elem = true; } - ListTypePtr lt = list_value->type()->expect(); + const auto elem_type_hint = + refined_type_hint && refined_type_hint->kind() == ListType::Kind + ? refined_type_hint->cast()->getElementType() + : nullptr; - const TypePtr element_type_hint = - type_hint ? type_hint->expect()->getElementType() : nullptr; - - auto unified = unifyTypes( - lt->getElementType(), + c10::optional unified_elem_type = unifyTypes( + list_value->type()->expect()->getElementType(), out->type(), /*default_to_union=*/true, - element_type_hint); + elem_type_hint); - if (lt->getElementType() != AnyType::get() && - *unified == AnyType::get()) { + if (!type_hint && (*unified_elem_type)->kind() == UnionType::Kind) { TORCH_WARN( "List consists of heterogeneous types, which means", - " that it has been typed as `List[Any]`. To use " - "any of the values in the List, it will be " - "necessary to add an `assert isinstance` statement " - "before first use to trigger type refinement. The first ", - "non-matching element was typed as ", + " that it has been typed as containing ", + (*unified_elem_type)->repr_str(), + ". To use any of the " + "values in this List, it will be necessary to add an " + "`assert isinstance` statement before first use to trigger " + "type refinement. The first non-matching element was typed", + " as ", out->type()->repr_str(), - ", while the elements before it " - "were ", - lt->getElementType()->repr_str(), + ", while the elements " + " before it were ", + list_value->type() + ->expect() + ->getElementType() + ->repr_str(), "\n", lc.range().str()); } - if (!type_hint) { - list_value->setType(ListType::create(*unified)); + if (all_candidates.empty() && refined_type_hint && + !(*unified_elem_type) + ->isSubtypeOf( + refined_type_hint->expect()->getElementType())) { + throw ErrorReport(lc) + << "List type annotation `" << refined_type_hint->repr_str() + << "` did not match the types of the given list elements," + << " which were unified to " << (*unified_elem_type)->repr_str(); + } + + if (!all_candidates.empty()) { + TypePtr greatest_elem_type = nullptr; + std::for_each( + all_candidates.begin(), + all_candidates.end(), + [&](TypePtr candidate) { + auto candidate_elem_type = + candidate->expect()->getElementType(); + if ((*unified_elem_type)->isSubtypeOf(candidate_elem_type)) { + if (!greatest_elem_type) { + greatest_elem_type = candidate_elem_type; + } else { + greatest_elem_type = + *(unifyTypes(greatest_elem_type, candidate_elem_type)); + } + } + }); + if (!greatest_elem_type) { + std::stringstream vector_repr; + for (size_t i = 0; i < all_candidates.size(); ++i) { + if (i > 0 && all_candidates.size() > 2) { + vector_repr << ", "; + } + if (i != 0 && i == all_candidates.size() - 1) { + vector_repr << " or "; + } + vector_repr << all_candidates[i]->repr_str(); + } + throw ErrorReport(lc) + << "Union type annotation `" << type_hint->repr_str() + << "` can hold " << vector_repr.str() << ", but none of " + << "those types match the types of the given list " + << "elements, which were unified to " + << (*unified_elem_type)->repr_str(); + } else { + refined_type_hint = greatest_elem_type; + } + } + + if (!refined_type_hint) { + list_value->setType(ListType::create(*unified_elem_type)); } NamedValue self = NamedValue(loc, "self", list_value); @@ -1402,20 +1490,55 @@ struct to_ir { Value* dict_value = graph->insertNode(graph->create(prim::DictConstruct, 1))->output(); - // Set the default type to be Dict[Str, Tensor] + + // Set the default type to be Dict[str, Tensor] dict_value->setType(DictType::create(StringType::get(), TensorType::get())); - bool type_set = false; - if (type_hint) { - if (!type_hint->cast()) { - throw ErrorReport(loc) - << "Expected Dict type annotation for dict comprehension" - ", found " - << type_hint->repr_str(); + + TypePtr refined_type_hint = nullptr; + TypePtr annotated_union_type = + type_hint && type_hint->kind() == UnionType::Kind ? type_hint : nullptr; + + std::vector all_candidates = {}; + + // See notes on logic in `emitListLiteral` + if (refined_type_hint) { + // If necessary/possible, make `type_hint` a DictType + if (auto union_type_hint = refined_type_hint->cast()) { + std::vector dict_types; + std::copy_if( + union_type_hint->containedTypes().begin(), + union_type_hint->containedTypes().end(), + std::back_inserter(dict_types), + [&](TypePtr type_ptr) { + return type_ptr->kind() == DictType::Kind; + }); + if (dict_types.empty()) { + throw ErrorReport(dc) << "Expected an Union type annotation " + << "with an inner Dict type, but got " + << refined_type_hint->repr_str(); + } else if (dict_types.size() == 1) { + refined_type_hint = dict_types[0]; + } else { + all_candidates = std::move(dict_types); + } + } else if ( + auto optional_type_hint = refined_type_hint->cast()) { + refined_type_hint = optional_type_hint->getElementType(); + } + + if (all_candidates.empty()) { + if (refined_type_hint->kind() == DictType::Kind) { + dict_value->setType(refined_type_hint); + } else { + throw ErrorReport(dc) << "Expected an annotation of type " + << "Dict, but got " << type_hint->repr_str(); + } } - dict_value->setType(type_hint); - type_set = true; } + TypePtr first_generated_key_type = nullptr; + TypePtr first_generated_value_type = nullptr; + // A dict comprehension introduces its own scope. No variable assigned // may leak into the rest of the graph Node* n = @@ -1427,10 +1550,33 @@ struct to_ir { auto k = emitExpr(dc.key()); auto v = emitExpr(dc.value()); - // Make sure that any key and value types are subtypes of the - // annotatated key/value types - if (type_hint) { - DictTypePtr dict_type_hint = type_hint->expect(); + // If we didn't have a type annotation, the type of the dict would + // be set to `(str, Tensor)`. We don't want to unify this default + // type with the actual elements in the dict, so let the type + // begin as the first element in the dict + if (k->type()->kind() == UnionType::Kind) { + throw ErrorReport(dc) + << "Dicts may only contain homogeneous keys, but the type of " + << "the first generated key was " << k->type()->repr_str(); + } else if ( + first_generated_key_type && first_generated_key_type != k->type()) { + // Values can be heterogenous, so we only need to check that the + // key types are all the same + throw ErrorReport(dc) + << "Dicts may only contain homogeneous keys. Expected " + << "dict comprehension to generate type " + << first_generated_key_type->repr_str() << ", but got " + << k->type()->repr_str(); + } else { + dict_value->setType(DictType::create(k->type(), v->type())); + first_generated_key_type = k->type(); + first_generated_value_type = v->type(); + } + + // If we had any annotation OTHER THAN a Union that can hold more + // than one type of Dict + if (refined_type_hint && all_candidates.empty()) { + DictTypePtr dict_type_hint = refined_type_hint->expect(); std::stringstream ss; std::stringstream err; @@ -1463,49 +1609,85 @@ struct to_ir { } } - // If we didn't have a type annotation, the type of the dict would - // be set to `(str, Tensor)`. We don't want to unify this default - // type with the actual elements in the dict, so let the type - // begin as the first element in the dict - if (!type_set) { - dict_value->setType(DictType::create(k->type(), v->type())); - type_set = true; - } - - DictTypePtr dt = dict_value->type()->expect(); - const TypePtr value_type_hint = - type_hint ? type_hint->expect()->getKeyType() : nullptr; + refined_type_hint && refined_type_hint->kind() == DictType::Kind + ? refined_type_hint->expect()->getValueType() + : nullptr; - c10::optional unified = unifyTypes( - dt->getValueType(), + c10::optional unified_value_type = unifyTypes( + first_generated_value_type, v->type(), /*default_to_union=*/true, value_type_hint); - // Warn the user if we inferred the type of the values to be `Any` - // even though the annotation was something else - if (dt->getValueType() != AnyType::get() && *unified == AnyType::get()) { + if (!type_hint && (*unified_value_type)->kind() == UnionType::Kind) { TORCH_WARN( - "Dict consists of heterogeneous types, which means", - " that it has been typed as `Dict[str, Any]`. To use " - "any of the values in the Dict, it will be " - "necessary to add an `assert isinstance` statement " - "before first use to trigger type refinement. The first ", - "non-matching element was typed as ", + "Dict values consist of heterogeneous types, which means", + " that they have been typed as being ", + (*unified_value_type)->repr_str(), + ". To use any of the " + "values in this dict, it will be necessary to add an " + "`assert isinstance` statement before first use to trigger " + "type refinement. The first non-matching element was typed", + " as ", v->type()->repr_str(), - ", while the elements before it " - "were ", - dt->getValueType()->repr_str(), + ", while the elements " + " before it were ", + first_generated_value_type->repr_str(), "\n", dc.range().str()); } - // We only want to set `dict_value` if we don't have a type hint - // to allow for the case that `*unified` is a subtype of - // the value type given by `type_hint` - if (!type_hint) { - dict_value->setType(DictType::create(k->type(), *unified)); + if (type_hint && !all_candidates.empty()) { + auto known_key_type = k->type(); + auto known_value_type = *unified_value_type; + + TypePtr candidate_key_type = nullptr; + TypePtr candidate_value_type = nullptr; + TypePtr candidate = nullptr; + + for (const auto& current_candidate : all_candidates) { + auto current_key_type = + current_candidate->expect()->getKeyType(); + auto current_value_type = + current_candidate->expect()->getValueType(); + if (known_key_type->isSubtypeOf(current_key_type) && + known_value_type->isSubtypeOf(current_value_type)) { + if (!candidate || + (candidate_key_type->isSubtypeOf(current_key_type) && + candidate_value_type->isSubtypeOf(current_value_type))) { + candidate_key_type = current_key_type; + candidate_value_type = current_value_type; + candidate = current_candidate; + } + } + } + + if (!candidate) { + std::stringstream vector_repr; + for (size_t i = 0; i < all_candidates.size(); ++i) { + if (i > 0 && all_candidates.size() > 2) { + vector_repr << ", "; + } + if (i != 0 && i == all_candidates.size() - 1) { + vector_repr << " or "; + } + vector_repr << all_candidates[i]->repr_str(); + } + throw ErrorReport(dc) + << "Union type annotation `" << type_hint->repr_str() + << "` can hold " << vector_repr.str() << ", but none of " + << "those list types can hold the types of the given dict" + << " elements, which were unified to " << candidate->repr_str(); + } else { + refined_type_hint = candidate; + } + } + + if (!refined_type_hint) { + dict_value->setType(DictType::create(k->type(), *unified_value_type)); + } else { + dict_value->setType(type_hint); } NamedValue self = NamedValue(loc, "self", dict_value); @@ -1516,6 +1698,14 @@ struct to_ir { }; emitFor(targets_list, itrs, loc, emit_body); popFrame(); + + if (annotated_union_type) { + Node* n = + graph->insertNode(graph->create(prim::unchecked_cast, {dict_value})); + n->output()->setType(std::move(annotated_union_type)); + dict_value = n->output(); + } + return dict_value; } @@ -3797,22 +3987,80 @@ struct to_ir { ->call(tree->range(), method, named_values, {}, 0)); } - Value* emitListLiteral(ListLiteral ll, TypePtr type_hint) { + Value* emitListLiteral(ListLiteral ll, const TypePtr& type_hint) { auto values = getValues(ll.inputs(), /*maybe_unpack=*/true); // Determine the element type of the list. If we have a type hint // of `List[T]`, use `T`. If the list is non-empty, find the // greatest common supertype of all the list elements (defaulting to // `Any` as a catch-all supertype). Assume `[]` is `List[Tensor]` - TypePtr elem_type = TensorType::get(); + TypePtr inferred_elem_type = TensorType::get(); + + TypePtr refined_type_hint = type_hint; + + // If `type_hint` is a Union, we're going to change it to be + // the type of the rhs List, so we need to store the original + // UnionType for later. `nullptr` means that we don't need to emit + // an `unchecked_cast` node (either because we don't have a type + // hint or because the type hint wasn't a Union) + TypePtr annotated_union_type = + refined_type_hint && refined_type_hint->kind() == UnionType::Kind + ? refined_type_hint + : nullptr; + + // This is used for error reporting in the case that we have a Union + // annotation that contains multiple Lists. We need to determine the + // actual type based on the rhs values + std::vector all_candidates = {}; + + // Basic `type_hint` check here for better error reporting. We also + // see if we can narrow down the actual type if the given type hint + // is a Union + if (refined_type_hint) { + // If necessary/possible, make `type_hint` a ListType + if (auto union_type_hint = refined_type_hint->cast()) { + std::vector list_types; + std::copy_if( + union_type_hint->containedTypes().begin(), + union_type_hint->containedTypes().end(), + std::back_inserter(list_types), + [&](TypePtr type_ptr) { + return type_ptr->kind() == ListType::Kind; + }); + if (list_types.empty()) { + throw ErrorReport(ll) << "Expected an Union type annotation " + << "with an inner List type, but got " + << refined_type_hint->repr_str(); + } else if (list_types.size() > 1) { + if (values.empty()) { + throw ErrorReport(ll) + << "Cannot assign an empty list to a " + << "variable annotated to be type " + << refined_type_hint->repr_str() + << " because there are multiple possible List " + << "type candidates in the Union annotation"; + } else { + all_candidates = std::move(list_types); + } + } else { + refined_type_hint = list_types[0]; + } + } else if ( + auto optional_type_hint = refined_type_hint->cast()) { + refined_type_hint = optional_type_hint->getElementType(); + } - if (type_hint) { - if (type_hint->kind() == TypeKind::ListType) { - elem_type = type_hint->expectRef().getElementType(); - } else { - // If the type hint was not `List[T]`, throw an error - throw ErrorReport(ll) << "Expected a List type hint but instead got " - << type_hint->repr_str(); + // If we had any annotation OTHER THAN a Union that can hold more + // than one type of List + if (all_candidates.empty()) { + if (refined_type_hint->kind() == ListType::Kind) { + auto list_type_hint = refined_type_hint->cast(); + inferred_elem_type = list_type_hint->getElementType(); + } else { + throw ErrorReport(ll) + << "Expected an annotation of type " + << "List, but got " << refined_type_hint->repr_str(); + } } } @@ -3821,45 +4069,100 @@ struct to_ir { std::stringstream nowhere; // never used - const TypePtr element_type_hint = - type_hint ? type_hint->expect()->getElementType() : nullptr; + // We don't want to use `elem_type` as the final argument to + // `unifyTypeList` because there's a chance that `elem_type` is + // the Tensor default + const auto elem_type_hint = + refined_type_hint && refined_type_hint->kind() == ListType::Kind + ? refined_type_hint->cast()->getElementType() + : nullptr; - c10::optional unified = unifyTypeList( - types, nowhere, /*default_to_union=*/true, element_type_hint); + c10::optional unified_elem_type = unifyTypeList( + types, nowhere, /*default_to_union=*/true, elem_type_hint); - if (!type_hint && *unified == AnyType::get()) { + if (!refined_type_hint && + (*unified_elem_type)->kind() == UnionType::Kind) { TORCH_WARN( "List consists of heterogeneous types, which means", - " that it has been typed as `List[Any]`. To use " - "any of the values in the List, it will be " - "necessary to add an `assert isinstance` statement " - "before first use to trigger type refinement. \n", + " that it has been typed as containing ", + (*unified_elem_type)->repr_str(), + ". To use any of the " + "values in this List, it will be necessary to add an " + "`assert isinstance` statement before first use to trigger " + "type refinement.\n", ll.range().str()); } - if (type_hint && !(*unified)->isSubtypeOf(elem_type)) { + if (all_candidates.empty() && refined_type_hint && + !(*unified_elem_type)->isSubtypeOf(inferred_elem_type)) { throw ErrorReport(ll) - << "List type annotation `" << type_hint->repr_str() + << "List type annotation `" << refined_type_hint->repr_str() << "` did not match the types of the given list elements," - << " which were unified to " << (*unified)->repr_str(); + << " which were unified to " << (*unified_elem_type)->repr_str(); + } + + if (!all_candidates.empty()) { + TypePtr greatest_elem_type = nullptr; + std::for_each( + all_candidates.begin(), + all_candidates.end(), + [&](TypePtr candidate) { + auto candidate_elem_type = + candidate->expect()->getElementType(); + if ((*unified_elem_type)->isSubtypeOf(candidate_elem_type)) { + if (!greatest_elem_type) { + greatest_elem_type = candidate_elem_type; + } else { + greatest_elem_type = + *(unifyTypes(greatest_elem_type, candidate_elem_type)); + } + } + }); + if (!greatest_elem_type) { + std::stringstream vector_repr; + for (size_t i = 0; i < all_candidates.size(); ++i) { + if (i > 0 && all_candidates.size() > 2) { + vector_repr << ", "; + } + if (i != 0 && i == all_candidates.size() - 1) { + vector_repr << " or "; + } + vector_repr << all_candidates[i]->repr_str(); + } + throw ErrorReport(ll) + << "Union type annotation `" << type_hint->repr_str() + << "` can hold " << vector_repr.str() << ", but none of " + << "those list types can hold the types of the given list" + << " elements, which were unified to " + << (*unified_elem_type)->repr_str(); + } else { + refined_type_hint = ListType::create(greatest_elem_type); + inferred_elem_type = + refined_type_hint->expect()->getElementType(); + } } // We only want to set `elem_type` if we don't have a type hint // to allow for the case that `*unified` is a subtype of // `type_hint` - if (!type_hint) { - elem_type = *unified; + if (!refined_type_hint) { + inferred_elem_type = *unified_elem_type; } } - Value* result = - graph->insertNode(graph->createList(elem_type, values))->output(); - return result; + Node* result = + graph->insertNode(graph->createList(inferred_elem_type, values)); + if (annotated_union_type) { + Node* n = graph->insertNode( + graph->create(prim::unchecked_cast, {result->output()})); + n->output()->setType(std::move(annotated_union_type)); + result = n; + } + + return result->output(); } - Value* emitSimpleExpr( - const TreeRef& tree, - const TypePtr& type_hint = nullptr) { + Value* emitSimpleExpr(const TreeRef& tree, TypePtr type_hint = nullptr) { switch (tree->kind()) { case TK_FLOOR_DIV: case '@': { @@ -3961,10 +4264,70 @@ struct to_ir { TypePtr key_type = nullptr; TypePtr value_type = nullptr; - if (type_hint && type_hint->kind() == TypeKind::DictType) { - auto dict_type = type_hint->expect(); - key_type = dict_type->getKeyType(); - value_type = dict_type->getValueType(); + // See notes on logic in `emitListLiteral` + + TypePtr refined_type_hint = type_hint; + + TypePtr annotated_union_type = + refined_type_hint && refined_type_hint->kind() == UnionType::Kind + ? refined_type_hint + : nullptr; + + std::vector all_candidates = {}; + + if (refined_type_hint) { + if (auto union_type_hint = refined_type_hint->cast()) { + std::vector dict_types; + std::copy_if( + union_type_hint->containedTypes().begin(), + union_type_hint->containedTypes().end(), + std::back_inserter(dict_types), + [&](TypePtr type_ptr) { + return type_ptr->kind() == DictType::Kind; + }); + if (dict_types.empty()) { + throw ErrorReport(dl) << "Expected an Union type annotation " + << "with an inner Dict type, but got " + << type_hint->repr_str(); + } else if (dict_types.size() > 1) { + if (values.empty()) { + throw ErrorReport(dl) + << "Cannot assign an empty dict to a " + << "variable annotated to be type " << type_hint->repr_str() + << " because there are multiple possible Dict " + << "type candidates in the Union annotation"; + } else { + all_candidates = std::move(dict_types); + } + } else { + refined_type_hint = dict_types[0]; + } + } else if ( + auto optional_type_hint = + refined_type_hint->cast()) { + refined_type_hint = optional_type_hint->getElementType(); + } + + if (all_candidates.empty()) { + if (auto dict_type_hint = refined_type_hint->cast()) { + auto dict_type = refined_type_hint->expect(); + key_type = dict_type->getKeyType(); + value_type = dict_type->getValueType(); + } else if (refined_type_hint == AnyType::get()) { + // @ansley: Clean up later + if (keys.empty()) { + key_type = StringType::get(); + value_type = TensorType::get(); + } else { + key_type = keys.at(0)->type(); + value_type = values.at(0)->type(); + } + } else { + throw ErrorReport(dl) + << "Expected an annotation of type " + << "Dict, but got " << type_hint->repr_str(); + } + } } else if (keys.empty()) { key_type = StringType::get(); value_type = TensorType::get(); @@ -3972,34 +4335,96 @@ struct to_ir { key_type = keys.at(0)->type(); value_type = values.at(0)->type(); } - AT_ASSERT(key_type != nullptr && value_type != nullptr); - - for (const auto i : c10::irange(keys.size())) { - std::stringstream ss; - if (!keys[i]->type()->isSubtypeOfExt(key_type, &ss)) { - throw ErrorReport(key_trees[i]) - << "Dict keys must contain " - << "only a single type. Expected: " << key_type->repr_str() - << " but found " << keys[i]->type()->repr_str() << " instead.\n" - << ss.str(); + + AT_ASSERT( + !all_candidates.empty() || + (key_type != nullptr && value_type != nullptr)); + + if (!keys.empty()) { + auto key_types = fmap(keys, [](const Value* v) { return v->type(); }); + + std::stringstream nowhere; // never used + + TypePtr first_key_type = key_types[0]; + + for (const auto i : c10::irange(keys.size())) { + std::stringstream ss; + if (!keys[i]->type()->isSubtypeOfExt(first_key_type, &ss) && + !first_key_type->isSubtypeOfExt(keys[i]->type(), &ss)) { + throw ErrorReport(key_trees[i]) + << "Dict keys must contain " + << "only a single type. Expected: " + << first_key_type->repr_str() << " but found " + << keys[i]->type()->repr_str() << " instead.\n" + << ss.str(); + } } } if (!values.empty()) { - auto types = fmap(values, [](const Value* v) { return v->type(); }); + auto value_types = + fmap(values, [](const Value* v) { return v->type(); }); std::stringstream nowhere; // never used - const TypePtr value_type_hint = - type_hint ? type_hint->expect()->getKeyType() : nullptr; - - c10::optional unified = unifyTypeList( - types, + c10::optional unified_value_type = unifyTypeList( + value_types, /*why_not=*/nowhere, /*default_to_union=*/true, - value_type_hint); + value_type); + + if (refined_type_hint && !all_candidates.empty()) { + auto known_key_type = keys[0]->type(); + auto known_value_type = *unified_value_type; + + TypePtr candidate_key_type = nullptr; + TypePtr candidate_value_type = nullptr; + TypePtr candidate = nullptr; + + for (const auto& current_candidate : all_candidates) { + auto current_key_type = + current_candidate->expect()->getKeyType(); + auto current_value_type = + current_candidate->expect()->getValueType(); + if (known_key_type->isSubtypeOf(current_key_type) && + known_value_type->isSubtypeOf(current_value_type)) { + if (!candidate || + (candidate_key_type->isSubtypeOf(current_key_type) && + candidate_value_type->isSubtypeOf(current_value_type))) { + candidate_key_type = current_key_type; + candidate_value_type = current_value_type; + candidate = current_candidate; + } + } + } - if (!type_hint && *unified == AnyType::get()) { + if (!candidate) { + std::stringstream vector_repr; + for (size_t i = 0; i < all_candidates.size(); ++i) { + if (i > 0 && all_candidates.size() > 2) { + vector_repr << ", "; + } + if (i != 0 && i == all_candidates.size() - 1) { + vector_repr << " or "; + } + vector_repr << all_candidates[i]->repr_str(); + } + throw ErrorReport(dl) + << "Union type annotation `" << refined_type_hint->repr_str() + << "` can hold " << vector_repr.str() << ", but none of " + << "those types can hold the types of the given dict" + << " elements, which were unified to Dict[" + << known_key_type->repr_str() << ", " + << known_value_type->repr_str() << "]"; + } else { + key_type = candidate_key_type; + value_type = candidate_value_type; + refined_type_hint = candidate; + } + } + + if (!refined_type_hint && + (*unified_value_type)->kind() == UnionType::Kind) { TORCH_WARN( "Dict values consist of heterogeneous types, which " "means that they have been typed as `Any`. To use " @@ -4009,20 +4434,20 @@ struct to_ir { dl.range().str()); } - if (type_hint) { + if (refined_type_hint) { TypePtr value_type_hint = - type_hint->expect()->getValueType(); - for (const auto i : c10::irange(types.size())) { + refined_type_hint->expect()->getValueType(); + for (const auto i : c10::irange(value_types.size())) { TORCH_CHECK( - types[i]->isSubtypeOf(value_type_hint), + value_types[i]->isSubtypeOf(value_type_hint), "Type " - "hint for dict was", - type_hint->repr_str(), - "but the value ", + "hint for dict was ", + refined_type_hint->repr_str(), + ", but the value ", "at index ", i, " has type ", - types[i]->repr_str(), + value_types[i]->repr_str(), ", which is not a valid" " subtype of ", value_type_hint->repr_str()); @@ -4032,14 +4457,21 @@ struct to_ir { // We only want to set `value_type` if we don't have a type // hint to allow for the case that `*unified` is a subtype of // the value type given by `type_hint` - if (!type_hint) { - value_type = *unified; + if (!refined_type_hint) { + value_type = *unified_value_type; } } - return graph - ->insertNode(graph->createDict(key_type, value_type, keys, values)) - ->output(); + Node* result = graph->insertNode( + graph->createDict(key_type, value_type, keys, values)); + if (annotated_union_type) { + Node* n = graph->insertNode( + graph->create(prim::unchecked_cast, {result->output()})); + n->output()->setType(std::move(annotated_union_type)); + result = n; + } + + return result->output(); } break; case TK_LIST_COMP: { auto lc = ListComp(tree); diff --git a/torch/csrc/jit/ir/alias_analysis.cpp b/torch/csrc/jit/ir/alias_analysis.cpp index 03afbdd..d4c219d 100644 --- a/torch/csrc/jit/ir/alias_analysis.cpp +++ b/torch/csrc/jit/ir/alias_analysis.cpp @@ -1129,16 +1129,6 @@ void AliasDb::makePointerTo(const Value* from, const Value* to) { // immutable. `Any` is mutable but can point to an immutable type // through refinement if (isMutableTypeInternal(from) != isMutableTypeInternal(to)) { - bool expected_kind = false; - for (auto kind : {from->type()->kind(), to->type()->kind()}) { - expected_kind = expected_kind || - (kind == TypeKind::OptionalType || kind == TypeKind::FutureType || - kind == TypeKind::TupleType || - kind == TypeKind::UnionType) // immutable type containers - || kind == TypeKind::AnyType; - } - TORCH_INTERNAL_ASSERT( - expected_kind, from->type()->str(), to->type()->str()); return; } // both immutable diff --git a/torch/testing/_internal/jit_utils.py b/torch/testing/_internal/jit_utils.py index 4c521a8..29dac32 100644 --- a/torch/testing/_internal/jit_utils.py +++ b/torch/testing/_internal/jit_utils.py @@ -408,19 +408,12 @@ class JitTestCase(JitCommonTestCase): cu = torch.jit.CompilationUnit(source, _frames_up=frames_up) string_frontend = getattr(cu, script.__name__) - with self.assertRaisesRegex(exception, regex): - string_frontend(*inputs) - # optimized run string_frontend(*inputs) # Python AST frontend if not isinstance(script, str): with self.assertRaisesRegex(exception, regex): ge = torch.jit.script(python_fn) - # profiling run - with self.assertRaisesRegex(exception, regex): - ge(*inputs) - # optimized run ge(*inputs) def checkBailouts(self, model, inputs, expected):