import torch
from torch.testing import FileCheck
from enum import Enum
+from textwrap import dedent
from typing import Dict, List, Optional, Tuple, Union
# Make the helper files in test/ importable
self.checkScript(fn, (1,))
self.checkScript(fn, (8,))
+
+ def _assert_passes(self, template: str, ann: str, lhs: str):
+ code = template.format(ann=ann, lhs=lhs)
+ self.checkScript(code, (), name="fn")
+
+ def _assert_raises(self, template: str, ann: str, lhs: str, msg: str):
+ code = template.format(ann=ann, lhs=lhs)
+ with self.assertRaisesRegex(RuntimeError, msg):
+ cu = torch.jit.CompilationUnit(code, _frames_up=1)
+ string_frontend = getattr(cu, "fn") # noqa: B009
+
+ def test_union_with_list_assignment(self):
+ template = dedent('''
+ def fn():
+ x: {ann} = {lhs}
+ if torch.jit.isinstance(x, List[torch.Tensor]):
+ x.append(torch.tensor(3))
+ return x
+ ''')
+
+ lhs = {"list_literal_empty" : "[]",
+
+ "list_literal_of_tensor" : "[torch.arange(3), torch.arange(5)]",
+
+ "list_literal_of_str" : "[\"foo\", \"bar\", \"baz\"]",
+
+ "list_literal_of_mixed" : "[torch.arange(5), 1]",
+
+ "list_comprehension_of_tensor" :
+ "[torch.add(x, 1) for x in [torch.arange(3), torch.arange(5)]]",
+
+ "list_comprehension_of_str" :
+ "[x + \"!\" for x in [\"foo\", \"bar\", \"baz\"]]",
+
+ "list_comprehension_of_mixed" :
+ "[torch.add(1, x) for x in [torch.arange(5), 1]]"}
+
+ """
+ Union[List[str], List[torch.Tensor]]
+ """
+ self._assert_raises(template,
+ "Union[List[str], List[torch.Tensor]]",
+ lhs["list_literal_empty"],
+ "there are multiple possible List type "
+ "candidates in the Union annotation")
+
+ self._assert_passes(template,
+ "Union[List[str], List[torch.Tensor]]",
+ lhs["list_literal_of_tensor"])
+
+ self._assert_passes(template,
+ "Union[List[str], List[torch.Tensor]]",
+ lhs["list_literal_of_str"])
+
+ self._assert_raises(template,
+ "Union[List[str], List[torch.Tensor]]",
+ lhs["list_literal_of_mixed"],
+ "none of those list types can hold the "
+ "types of the given list elements")
+
+ self._assert_passes(template,
+ "Union[List[str], List[torch.Tensor]]",
+ lhs["list_comprehension_of_tensor"])
+
+ self._assert_passes(template,
+ "Union[List[str], List[torch.Tensor]]",
+ lhs["list_comprehension_of_str"])
+
+ # TODO: Support mixed list comprehensions
+ self._assert_raises(template,
+ "Union[List[str], List[torch.Tensor]]",
+ lhs["list_comprehension_of_mixed"],
+ "Arguments for call are not valid")
+
+ """
+ Union[int, torch.Tensor]
+ """
+ self._assert_raises(template,
+ "Union[int, torch.Tensor]",
+ lhs["list_literal_empty"],
+ "Expected an Union type annotation with an "
+ "inner List type")
+
+ self._assert_raises(template, "Union[int, torch.Tensor]",
+ lhs["list_literal_of_tensor"],
+ "Expected an Union type annotation with an "
+ "inner List type")
+
+ self._assert_raises(template, "Union[int, torch.Tensor]",
+ lhs["list_comprehension_of_tensor"],
+ "Expected an Union type annotation with an "
+ "inner List type")
+
+ """
+ Union[List[torch.Tensor], int]
+ """
+ self._assert_passes(template,
+ "Union[List[torch.Tensor], int]",
+ lhs["list_literal_empty"])
+
+ self._assert_passes(template,
+ "Union[List[torch.Tensor], int]",
+ lhs["list_literal_of_tensor"])
+
+ self._assert_raises(template, "Union[List[torch.Tensor], int]",
+ lhs["list_literal_of_str"],
+ r"List type annotation `List\[Tensor\]` did "
+ "not match the types of the given list "
+ "elements")
+
+ self._assert_raises(template, "Union[List[torch.Tensor], int]",
+ lhs["list_literal_of_mixed"],
+ r"List type annotation `List\[Tensor\]` did "
+ "not match the types of the given list "
+ "elements")
+
+ self._assert_passes(template,
+ "Union[List[torch.Tensor], int]",
+ lhs["list_comprehension_of_tensor"])
+
+ self._assert_raises(template,
+ "Union[List[torch.Tensor], int]",
+ lhs["list_comprehension_of_str"],
+ r"List type annotation `List\[Tensor\]` did "
+ "not match the types of the given list "
+ "elements")
+
+ # TODO: Support mixed list comprehensions
+ self._assert_raises(template,
+ "Union[List[torch.Tensor], int]",
+ lhs["list_comprehension_of_mixed"],
+ "Arguments for call are not valid")
+
+ def test_union_with_dict_assignment(self):
+ template = dedent('''
+ def fn():
+ x: {ann} = {lhs}
+ if torch.jit.isinstance(x, Dict[str, torch.Tensor]):
+ x["foo"] = torch.tensor(3)
+ return x
+ ''')
+
+ lhs = {"dict_literal_empty" : "{}",
+
+ "dict_literal_of_str_tensor" :
+ "{\"foo\" : torch.arange(3), \"bar\" : torch.arange(5)}",
+
+ "dict_literal_of_str_int" :
+ "{\"foo\" : 1, \"bar\" : 2}",
+
+ "dict_literal_of_mixed" :
+ "{\"foo\" : torch.arange(3), \"bar\" : 2}",
+
+ "dict_comprehension_of_str_tensor" :
+ "{x : torch.add(y, 1) for x, y in \
+ zip([\"foo\", \"bar\"], [torch.arange(3), torch.arange(5)])}",
+
+ "dict_comprehension_of_str_int" :
+ "{x : torch.add(y, 1) for x, y in \
+ zip([\"foo\", \"bar\"], [1, 2]}",
+
+ "dict_comprehension_of_mixed" :
+ "{x : torch.add(y, 1) for x, y in \
+ zip([\"foo\", \"bar\"], [torch.arange(3), 2])}",
+
+ "dict_keyword" :
+ "dict(foo=torch.arange(3), baz=torch.arange(5))"}
+
+ """
+ Union[Dict[str, torch.Tensor], Dict[str, int]]
+ """
+ self._assert_raises(template,
+ "Union[List[str], List[torch.Tensor]]",
+ lhs["dict_literal_empty"],
+ "Expected an Union type annotation with an "
+ "inner Dict type")
+
+ self._assert_passes(template,
+ "Union[Dict[str, torch.Tensor], Dict[str, int]]",
+ lhs["dict_literal_of_str_tensor"])
+
+ self._assert_passes(template,
+ "Union[Dict[str, torch.Tensor], Dict[str, int]]",
+ lhs["dict_literal_of_str_int"])
+
+ self._assert_raises(template, "Union[Dict[str, torch.Tensor], Dict[str, int]]",
+ lhs["dict_literal_of_mixed"],
+ "none of those types can hold the types "
+ "of the given dict elements")
+
+ # TODO: String frontend does not support tuple unpacking
+ # https://github.com/pytorch/pytorch/issues/64096
+ # self._assert_passes(template, "Union[Dict[str, torch.Tensor], Dict[str, int]]",
+ # lhs["dict_comprehension_of_str_tensor"])
+
+ # self._assert_passes(template, "Union[Dict[str, torch.Tensor], Dict[str, int]]",
+ # lhs["dict_comprehension_of_str_int"])
+
+ # self._assert_raises(template, "Union[Dict[str, torch.Tensor], Dict[str, int]]",
+ # lhs["dict_comprehension_of_mixed"],
+ # "foobar")
+
+ self._assert_passes(template,
+ "Union[Dict[str, torch.Tensor], Dict[str, int]]",
+ lhs["dict_keyword"])
+
+ """
+ Union[int, torch.Tensor]
+ """
+ self._assert_raises(template,
+ "Union[int, torch.Tensor]",
+ lhs["dict_literal_empty"],
+ "Expected an Union type annotation with "
+ "an inner Dict type")
+
+ self._assert_raises(template,
+ "Union[int, torch.Tensor]",
+ lhs["dict_literal_of_str_tensor"],
+ "Expected an Union type annotation with "
+ "an inner Dict type")
+
+ # See above--string frontend does not support tuple unpacking
+ # self._assert_raises(template, "Union[int, torch.Tensor]",
+ # lhs["dict_comprehension_of_tensor"],
+ # "foobar")
+
+ """
+ Union[Dict[str, torch.Tensor], int]
+ """
+ self._assert_passes(template,
+ "Union[Dict[str, torch.Tensor], int]",
+ lhs["dict_literal_empty"])
+
+ self._assert_passes(template,
+ "Union[Dict[str, torch.Tensor], int]",
+ lhs["dict_literal_of_str_tensor"])
+
+ self._assert_raises(template,
+ "Union[Dict[str, torch.Tensor], int]",
+ lhs["dict_literal_of_str_int"],
+ r"Type hint for dict was Dict\[str, Tensor\]"
+ ", but the value at index 0 has type int, "
+ "which is not a valid subtype of Tensor")
+
+ self._assert_raises(template,
+ "Union[Dict[str, torch.Tensor], int]",
+ lhs["dict_literal_of_mixed"],
+ r"Type hint for dict was Dict\[str, Tensor\]"
+ ", but the value at index 1 has type int, "
+ "which is not a valid subtype of Tensor")
+
+ # See above--string frontend does not support tuple unpacking
+ # self._assert_passes(template,
+ # "Union[Dict[str, torch.Tensor], int]",
+ # lhs["dict_comprehension_of_str_tensor"])
+
+ # self._assert_raises(template,
+ # "Union[Dict[str, torch.Tensor], int]",
+ # lhs["dict_comprehension_of_str_int"],
+ # "foobar")
+
+ # self._assert_raises(template,
+ # "Union[Dict[str, torch.Tensor], int]",
+ # lhs["dict_comprehension_of_mixed"],
+ # "foobar")
+
+ self._assert_passes(template,
+ "Union[Dict[str, torch.Tensor], int]",
+ lhs["dict_keyword"])
// to SSA while part of their original graph, and closures are ready to
// be inlined into forked closures
ConvertToSSA(graph);
+
// convert loops with an iter and body condition specified to
// python-recognize while loops. we do this so they can be exported,
// and run the pass early to avoid jitter. Like conversion to SSA,
const auto loc = lc.range();
const auto targets_list = List<Expr>::create(lc.range(), {lc.target()});
const auto itrs = List<Expr>::create(lc.range(), {lc.iter()});
+
// If there is no type hint, and this is emitted over an iterable that is
// unrolled and of length 0, then we emit a List of tensors
Value* list_value = graph->insertNode(graph->create(prim::ListConstruct, 1))
->output()
->setType(ListType::ofTensors());
- bool type_set = false;
- if (type_hint) {
- if (!type_hint->cast<ListType>()) {
- throw ErrorReport(loc)
- << "Expected list type annotation for list comprehension"
- ", found "
- << type_hint->repr_str();
+
+ // See notes on logic in `emitListLiteral`
+
+ TypePtr refined_type_hint = type_hint;
+ TypePtr annotated_union_type =
+ type_hint && type_hint->kind() == UnionType::Kind ? type_hint : nullptr;
+
+ std::vector<TypePtr> all_candidates = {};
+
+ if (refined_type_hint) {
+ // If necessary/possible, refine `refined_type_hint` to a ListType
+ if (auto union_type_hint = refined_type_hint->cast<UnionType>()) {
+ std::vector<TypePtr> list_types;
+ std::copy_if(
+ union_type_hint->containedTypes().begin(),
+ union_type_hint->containedTypes().end(),
+ std::back_inserter(list_types),
+ [&](TypePtr type_ptr) {
+ return type_ptr->kind() == ListType::Kind;
+ });
+ if (list_types.empty()) {
+ throw ErrorReport(lc) << "Expected an Union type annotation "
+ << "with an inner List type, but got "
+ << refined_type_hint->repr_str();
+ } else if (list_types.size() == 1) {
+ refined_type_hint = list_types[0];
+ } else {
+ all_candidates = std::move(list_types);
+ }
+ } else if (
+ auto optional_type_hint = refined_type_hint->cast<OptionalType>()) {
+ refined_type_hint = optional_type_hint->getElementType();
+ }
+
+ if (all_candidates.empty()) {
+ if (refined_type_hint->kind() == ListType::Kind) {
+ list_value->setType(refined_type_hint);
+ } else {
+ throw ErrorReport(lc) << "Expected an annotation of type "
+ << "List, but got " << type_hint->repr_str();
+ }
}
- list_value->setType(type_hint);
- type_set = true;
}
- // comprehension introduces its own scope. no variable assigned
- // leaks into the rest of the graph
+ bool seen_first_elem = false;
+
+ // A list comprehension introduces its own scope
Node* n =
graph->insertNode(create(prim::ComprehensionScope, lc.range(), 0));
auto* comprehension_block = n->addBlock();
// be set to `Tensor`. We don't want to unify this default type
// with the actual elements in the list, so let the type begin as
// the first element in the list
- if (!type_set) {
+ if (!seen_first_elem) {
list_value->setType(ListType::create(out->type()));
- type_set = true;
+ seen_first_elem = true;
}
- ListTypePtr lt = list_value->type()->expect<ListType>();
+ const auto elem_type_hint =
+ refined_type_hint && refined_type_hint->kind() == ListType::Kind
+ ? refined_type_hint->cast<ListType>()->getElementType()
+ : nullptr;
- const TypePtr element_type_hint =
- type_hint ? type_hint->expect<ListType>()->getElementType() : nullptr;
-
- auto unified = unifyTypes(
- lt->getElementType(),
+ c10::optional<TypePtr> unified_elem_type = unifyTypes(
+ list_value->type()->expect<ListType>()->getElementType(),
out->type(),
/*default_to_union=*/true,
- element_type_hint);
+ elem_type_hint);
- if (lt->getElementType() != AnyType::get() &&
- *unified == AnyType::get()) {
+ if (!type_hint && (*unified_elem_type)->kind() == UnionType::Kind) {
TORCH_WARN(
"List consists of heterogeneous types, which means",
- " that it has been typed as `List[Any]`. To use "
- "any of the values in the List, it will be "
- "necessary to add an `assert isinstance` statement "
- "before first use to trigger type refinement. The first ",
- "non-matching element was typed as ",
+ " that it has been typed as containing ",
+ (*unified_elem_type)->repr_str(),
+ ". To use any of the "
+ "values in this List, it will be necessary to add an "
+ "`assert isinstance` statement before first use to trigger "
+ "type refinement. The first non-matching element was typed",
+ " as ",
out->type()->repr_str(),
- ", while the elements before it "
- "were ",
- lt->getElementType()->repr_str(),
+ ", while the elements "
+ " before it were ",
+ list_value->type()
+ ->expect<ListType>()
+ ->getElementType()
+ ->repr_str(),
"\n",
lc.range().str());
}
- if (!type_hint) {
- list_value->setType(ListType::create(*unified));
+ if (all_candidates.empty() && refined_type_hint &&
+ !(*unified_elem_type)
+ ->isSubtypeOf(
+ refined_type_hint->expect<ListType>()->getElementType())) {
+ throw ErrorReport(lc)
+ << "List type annotation `" << refined_type_hint->repr_str()
+ << "` did not match the types of the given list elements,"
+ << " which were unified to " << (*unified_elem_type)->repr_str();
+ }
+
+ if (!all_candidates.empty()) {
+ TypePtr greatest_elem_type = nullptr;
+ std::for_each(
+ all_candidates.begin(),
+ all_candidates.end(),
+ [&](TypePtr candidate) {
+ auto candidate_elem_type =
+ candidate->expect<ListType>()->getElementType();
+ if ((*unified_elem_type)->isSubtypeOf(candidate_elem_type)) {
+ if (!greatest_elem_type) {
+ greatest_elem_type = candidate_elem_type;
+ } else {
+ greatest_elem_type =
+ *(unifyTypes(greatest_elem_type, candidate_elem_type));
+ }
+ }
+ });
+ if (!greatest_elem_type) {
+ std::stringstream vector_repr;
+ for (size_t i = 0; i < all_candidates.size(); ++i) {
+ if (i > 0 && all_candidates.size() > 2) {
+ vector_repr << ", ";
+ }
+ if (i != 0 && i == all_candidates.size() - 1) {
+ vector_repr << " or ";
+ }
+ vector_repr << all_candidates[i]->repr_str();
+ }
+ throw ErrorReport(lc)
+ << "Union type annotation `" << type_hint->repr_str()
+ << "` can hold " << vector_repr.str() << ", but none of "
+ << "those types match the types of the given list "
+ << "elements, which were unified to "
+ << (*unified_elem_type)->repr_str();
+ } else {
+ refined_type_hint = greatest_elem_type;
+ }
+ }
+
+ if (!refined_type_hint) {
+ list_value->setType(ListType::create(*unified_elem_type));
}
NamedValue self = NamedValue(loc, "self", list_value);
Value* dict_value =
graph->insertNode(graph->create(prim::DictConstruct, 1))->output();
- // Set the default type to be Dict[Str, Tensor]
+
+ // Set the default type to be Dict[str, Tensor]
dict_value->setType(DictType::create(StringType::get(), TensorType::get()));
- bool type_set = false;
- if (type_hint) {
- if (!type_hint->cast<DictType>()) {
- throw ErrorReport(loc)
- << "Expected Dict type annotation for dict comprehension"
- ", found "
- << type_hint->repr_str();
+
+ TypePtr refined_type_hint = nullptr;
+ TypePtr annotated_union_type =
+ type_hint && type_hint->kind() == UnionType::Kind ? type_hint : nullptr;
+
+ std::vector<TypePtr> all_candidates = {};
+
+ // See notes on logic in `emitListLiteral`
+ if (refined_type_hint) {
+ // If necessary/possible, make `type_hint` a DictType
+ if (auto union_type_hint = refined_type_hint->cast<UnionType>()) {
+ std::vector<TypePtr> dict_types;
+ std::copy_if(
+ union_type_hint->containedTypes().begin(),
+ union_type_hint->containedTypes().end(),
+ std::back_inserter(dict_types),
+ [&](TypePtr type_ptr) {
+ return type_ptr->kind() == DictType::Kind;
+ });
+ if (dict_types.empty()) {
+ throw ErrorReport(dc) << "Expected an Union type annotation "
+ << "with an inner Dict type, but got "
+ << refined_type_hint->repr_str();
+ } else if (dict_types.size() == 1) {
+ refined_type_hint = dict_types[0];
+ } else {
+ all_candidates = std::move(dict_types);
+ }
+ } else if (
+ auto optional_type_hint = refined_type_hint->cast<OptionalType>()) {
+ refined_type_hint = optional_type_hint->getElementType();
+ }
+
+ if (all_candidates.empty()) {
+ if (refined_type_hint->kind() == DictType::Kind) {
+ dict_value->setType(refined_type_hint);
+ } else {
+ throw ErrorReport(dc) << "Expected an annotation of type "
+ << "Dict, but got " << type_hint->repr_str();
+ }
}
- dict_value->setType(type_hint);
- type_set = true;
}
+ TypePtr first_generated_key_type = nullptr;
+ TypePtr first_generated_value_type = nullptr;
+
// A dict comprehension introduces its own scope. No variable assigned
// may leak into the rest of the graph
Node* n =
auto k = emitExpr(dc.key());
auto v = emitExpr(dc.value());
- // Make sure that any key and value types are subtypes of the
- // annotatated key/value types
- if (type_hint) {
- DictTypePtr dict_type_hint = type_hint->expect<DictType>();
+ // If we didn't have a type annotation, the type of the dict would
+ // be set to `(str, Tensor)`. We don't want to unify this default
+ // type with the actual elements in the dict, so let the type
+ // begin as the first element in the dict
+ if (k->type()->kind() == UnionType::Kind) {
+ throw ErrorReport(dc)
+ << "Dicts may only contain homogeneous keys, but the type of "
+ << "the first generated key was " << k->type()->repr_str();
+ } else if (
+ first_generated_key_type && first_generated_key_type != k->type()) {
+ // Values can be heterogenous, so we only need to check that the
+ // key types are all the same
+ throw ErrorReport(dc)
+ << "Dicts may only contain homogeneous keys. Expected "
+ << "dict comprehension to generate type "
+ << first_generated_key_type->repr_str() << ", but got "
+ << k->type()->repr_str();
+ } else {
+ dict_value->setType(DictType::create(k->type(), v->type()));
+ first_generated_key_type = k->type();
+ first_generated_value_type = v->type();
+ }
+
+ // If we had any annotation OTHER THAN a Union that can hold more
+ // than one type of Dict
+ if (refined_type_hint && all_candidates.empty()) {
+ DictTypePtr dict_type_hint = refined_type_hint->expect<DictType>();
std::stringstream ss;
std::stringstream err;
}
}
- // If we didn't have a type annotation, the type of the dict would
- // be set to `(str, Tensor)`. We don't want to unify this default
- // type with the actual elements in the dict, so let the type
- // begin as the first element in the dict
- if (!type_set) {
- dict_value->setType(DictType::create(k->type(), v->type()));
- type_set = true;
- }
-
- DictTypePtr dt = dict_value->type()->expect<DictType>();
-
const TypePtr value_type_hint =
- type_hint ? type_hint->expect<DictType>()->getKeyType() : nullptr;
+ refined_type_hint && refined_type_hint->kind() == DictType::Kind
+ ? refined_type_hint->expect<DictType>()->getValueType()
+ : nullptr;
- c10::optional<TypePtr> unified = unifyTypes(
- dt->getValueType(),
+ c10::optional<TypePtr> unified_value_type = unifyTypes(
+ first_generated_value_type,
v->type(),
/*default_to_union=*/true,
value_type_hint);
- // Warn the user if we inferred the type of the values to be `Any`
- // even though the annotation was something else
- if (dt->getValueType() != AnyType::get() && *unified == AnyType::get()) {
+ if (!type_hint && (*unified_value_type)->kind() == UnionType::Kind) {
TORCH_WARN(
- "Dict consists of heterogeneous types, which means",
- " that it has been typed as `Dict[str, Any]`. To use "
- "any of the values in the Dict, it will be "
- "necessary to add an `assert isinstance` statement "
- "before first use to trigger type refinement. The first ",
- "non-matching element was typed as ",
+ "Dict values consist of heterogeneous types, which means",
+ " that they have been typed as being ",
+ (*unified_value_type)->repr_str(),
+ ". To use any of the "
+ "values in this dict, it will be necessary to add an "
+ "`assert isinstance` statement before first use to trigger "
+ "type refinement. The first non-matching element was typed",
+ " as ",
v->type()->repr_str(),
- ", while the elements before it "
- "were ",
- dt->getValueType()->repr_str(),
+ ", while the elements "
+ " before it were ",
+ first_generated_value_type->repr_str(),
"\n",
dc.range().str());
}
- // We only want to set `dict_value` if we don't have a type hint
- // to allow for the case that `*unified` is a subtype of
- // the value type given by `type_hint`
- if (!type_hint) {
- dict_value->setType(DictType::create(k->type(), *unified));
+ if (type_hint && !all_candidates.empty()) {
+ auto known_key_type = k->type();
+ auto known_value_type = *unified_value_type;
+
+ TypePtr candidate_key_type = nullptr;
+ TypePtr candidate_value_type = nullptr;
+ TypePtr candidate = nullptr;
+
+ for (const auto& current_candidate : all_candidates) {
+ auto current_key_type =
+ current_candidate->expect<DictType>()->getKeyType();
+ auto current_value_type =
+ current_candidate->expect<DictType>()->getValueType();
+ if (known_key_type->isSubtypeOf(current_key_type) &&
+ known_value_type->isSubtypeOf(current_value_type)) {
+ if (!candidate ||
+ (candidate_key_type->isSubtypeOf(current_key_type) &&
+ candidate_value_type->isSubtypeOf(current_value_type))) {
+ candidate_key_type = current_key_type;
+ candidate_value_type = current_value_type;
+ candidate = current_candidate;
+ }
+ }
+ }
+
+ if (!candidate) {
+ std::stringstream vector_repr;
+ for (size_t i = 0; i < all_candidates.size(); ++i) {
+ if (i > 0 && all_candidates.size() > 2) {
+ vector_repr << ", ";
+ }
+ if (i != 0 && i == all_candidates.size() - 1) {
+ vector_repr << " or ";
+ }
+ vector_repr << all_candidates[i]->repr_str();
+ }
+ throw ErrorReport(dc)
+ << "Union type annotation `" << type_hint->repr_str()
+ << "` can hold " << vector_repr.str() << ", but none of "
+ << "those list types can hold the types of the given dict"
+ << " elements, which were unified to " << candidate->repr_str();
+ } else {
+ refined_type_hint = candidate;
+ }
+ }
+
+ if (!refined_type_hint) {
+ dict_value->setType(DictType::create(k->type(), *unified_value_type));
+ } else {
+ dict_value->setType(type_hint);
}
NamedValue self = NamedValue(loc, "self", dict_value);
};
emitFor(targets_list, itrs, loc, emit_body);
popFrame();
+
+ if (annotated_union_type) {
+ Node* n =
+ graph->insertNode(graph->create(prim::unchecked_cast, {dict_value}));
+ n->output()->setType(std::move(annotated_union_type));
+ dict_value = n->output();
+ }
+
return dict_value;
}
->call(tree->range(), method, named_values, {}, 0));
}
- Value* emitListLiteral(ListLiteral ll, TypePtr type_hint) {
+ Value* emitListLiteral(ListLiteral ll, const TypePtr& type_hint) {
auto values = getValues(ll.inputs(), /*maybe_unpack=*/true);
// Determine the element type of the list. If we have a type hint
// of `List[T]`, use `T`. If the list is non-empty, find the
// greatest common supertype of all the list elements (defaulting to
// `Any` as a catch-all supertype). Assume `[]` is `List[Tensor]`
- TypePtr elem_type = TensorType::get();
+ TypePtr inferred_elem_type = TensorType::get();
+
+ TypePtr refined_type_hint = type_hint;
+
+ // If `type_hint` is a Union, we're going to change it to be
+ // the type of the rhs List, so we need to store the original
+ // UnionType for later. `nullptr` means that we don't need to emit
+ // an `unchecked_cast` node (either because we don't have a type
+ // hint or because the type hint wasn't a Union)
+ TypePtr annotated_union_type =
+ refined_type_hint && refined_type_hint->kind() == UnionType::Kind
+ ? refined_type_hint
+ : nullptr;
+
+ // This is used for error reporting in the case that we have a Union
+ // annotation that contains multiple Lists. We need to determine the
+ // actual type based on the rhs values
+ std::vector<TypePtr> all_candidates = {};
+
+ // Basic `type_hint` check here for better error reporting. We also
+ // see if we can narrow down the actual type if the given type hint
+ // is a Union
+ if (refined_type_hint) {
+ // If necessary/possible, make `type_hint` a ListType
+ if (auto union_type_hint = refined_type_hint->cast<UnionType>()) {
+ std::vector<TypePtr> list_types;
+ std::copy_if(
+ union_type_hint->containedTypes().begin(),
+ union_type_hint->containedTypes().end(),
+ std::back_inserter(list_types),
+ [&](TypePtr type_ptr) {
+ return type_ptr->kind() == ListType::Kind;
+ });
+ if (list_types.empty()) {
+ throw ErrorReport(ll) << "Expected an Union type annotation "
+ << "with an inner List type, but got "
+ << refined_type_hint->repr_str();
+ } else if (list_types.size() > 1) {
+ if (values.empty()) {
+ throw ErrorReport(ll)
+ << "Cannot assign an empty list to a "
+ << "variable annotated to be type "
+ << refined_type_hint->repr_str()
+ << " because there are multiple possible List "
+ << "type candidates in the Union annotation";
+ } else {
+ all_candidates = std::move(list_types);
+ }
+ } else {
+ refined_type_hint = list_types[0];
+ }
+ } else if (
+ auto optional_type_hint = refined_type_hint->cast<OptionalType>()) {
+ refined_type_hint = optional_type_hint->getElementType();
+ }
- if (type_hint) {
- if (type_hint->kind() == TypeKind::ListType) {
- elem_type = type_hint->expectRef<ListType>().getElementType();
- } else {
- // If the type hint was not `List[T]`, throw an error
- throw ErrorReport(ll) << "Expected a List type hint but instead got "
- << type_hint->repr_str();
+ // If we had any annotation OTHER THAN a Union that can hold more
+ // than one type of List
+ if (all_candidates.empty()) {
+ if (refined_type_hint->kind() == ListType::Kind) {
+ auto list_type_hint = refined_type_hint->cast<ListType>();
+ inferred_elem_type = list_type_hint->getElementType();
+ } else {
+ throw ErrorReport(ll)
+ << "Expected an annotation of type "
+ << "List, but got " << refined_type_hint->repr_str();
+ }
}
}
std::stringstream nowhere; // never used
- const TypePtr element_type_hint =
- type_hint ? type_hint->expect<ListType>()->getElementType() : nullptr;
+ // We don't want to use `elem_type` as the final argument to
+ // `unifyTypeList` because there's a chance that `elem_type` is
+ // the Tensor default
+ const auto elem_type_hint =
+ refined_type_hint && refined_type_hint->kind() == ListType::Kind
+ ? refined_type_hint->cast<ListType>()->getElementType()
+ : nullptr;
- c10::optional<TypePtr> unified = unifyTypeList(
- types, nowhere, /*default_to_union=*/true, element_type_hint);
+ c10::optional<TypePtr> unified_elem_type = unifyTypeList(
+ types, nowhere, /*default_to_union=*/true, elem_type_hint);
- if (!type_hint && *unified == AnyType::get()) {
+ if (!refined_type_hint &&
+ (*unified_elem_type)->kind() == UnionType::Kind) {
TORCH_WARN(
"List consists of heterogeneous types, which means",
- " that it has been typed as `List[Any]`. To use "
- "any of the values in the List, it will be "
- "necessary to add an `assert isinstance` statement "
- "before first use to trigger type refinement. \n",
+ " that it has been typed as containing ",
+ (*unified_elem_type)->repr_str(),
+ ". To use any of the "
+ "values in this List, it will be necessary to add an "
+ "`assert isinstance` statement before first use to trigger "
+ "type refinement.\n",
ll.range().str());
}
- if (type_hint && !(*unified)->isSubtypeOf(elem_type)) {
+ if (all_candidates.empty() && refined_type_hint &&
+ !(*unified_elem_type)->isSubtypeOf(inferred_elem_type)) {
throw ErrorReport(ll)
- << "List type annotation `" << type_hint->repr_str()
+ << "List type annotation `" << refined_type_hint->repr_str()
<< "` did not match the types of the given list elements,"
- << " which were unified to " << (*unified)->repr_str();
+ << " which were unified to " << (*unified_elem_type)->repr_str();
+ }
+
+ if (!all_candidates.empty()) {
+ TypePtr greatest_elem_type = nullptr;
+ std::for_each(
+ all_candidates.begin(),
+ all_candidates.end(),
+ [&](TypePtr candidate) {
+ auto candidate_elem_type =
+ candidate->expect<ListType>()->getElementType();
+ if ((*unified_elem_type)->isSubtypeOf(candidate_elem_type)) {
+ if (!greatest_elem_type) {
+ greatest_elem_type = candidate_elem_type;
+ } else {
+ greatest_elem_type =
+ *(unifyTypes(greatest_elem_type, candidate_elem_type));
+ }
+ }
+ });
+ if (!greatest_elem_type) {
+ std::stringstream vector_repr;
+ for (size_t i = 0; i < all_candidates.size(); ++i) {
+ if (i > 0 && all_candidates.size() > 2) {
+ vector_repr << ", ";
+ }
+ if (i != 0 && i == all_candidates.size() - 1) {
+ vector_repr << " or ";
+ }
+ vector_repr << all_candidates[i]->repr_str();
+ }
+ throw ErrorReport(ll)
+ << "Union type annotation `" << type_hint->repr_str()
+ << "` can hold " << vector_repr.str() << ", but none of "
+ << "those list types can hold the types of the given list"
+ << " elements, which were unified to "
+ << (*unified_elem_type)->repr_str();
+ } else {
+ refined_type_hint = ListType::create(greatest_elem_type);
+ inferred_elem_type =
+ refined_type_hint->expect<ListType>()->getElementType();
+ }
}
// We only want to set `elem_type` if we don't have a type hint
// to allow for the case that `*unified` is a subtype of
// `type_hint`
- if (!type_hint) {
- elem_type = *unified;
+ if (!refined_type_hint) {
+ inferred_elem_type = *unified_elem_type;
}
}
- Value* result =
- graph->insertNode(graph->createList(elem_type, values))->output();
- return result;
+ Node* result =
+ graph->insertNode(graph->createList(inferred_elem_type, values));
+ if (annotated_union_type) {
+ Node* n = graph->insertNode(
+ graph->create(prim::unchecked_cast, {result->output()}));
+ n->output()->setType(std::move(annotated_union_type));
+ result = n;
+ }
+
+ return result->output();
}
- Value* emitSimpleExpr(
- const TreeRef& tree,
- const TypePtr& type_hint = nullptr) {
+ Value* emitSimpleExpr(const TreeRef& tree, TypePtr type_hint = nullptr) {
switch (tree->kind()) {
case TK_FLOOR_DIV:
case '@': {
TypePtr key_type = nullptr;
TypePtr value_type = nullptr;
- if (type_hint && type_hint->kind() == TypeKind::DictType) {
- auto dict_type = type_hint->expect<DictType>();
- key_type = dict_type->getKeyType();
- value_type = dict_type->getValueType();
+ // See notes on logic in `emitListLiteral`
+
+ TypePtr refined_type_hint = type_hint;
+
+ TypePtr annotated_union_type =
+ refined_type_hint && refined_type_hint->kind() == UnionType::Kind
+ ? refined_type_hint
+ : nullptr;
+
+ std::vector<TypePtr> all_candidates = {};
+
+ if (refined_type_hint) {
+ if (auto union_type_hint = refined_type_hint->cast<UnionType>()) {
+ std::vector<TypePtr> dict_types;
+ std::copy_if(
+ union_type_hint->containedTypes().begin(),
+ union_type_hint->containedTypes().end(),
+ std::back_inserter(dict_types),
+ [&](TypePtr type_ptr) {
+ return type_ptr->kind() == DictType::Kind;
+ });
+ if (dict_types.empty()) {
+ throw ErrorReport(dl) << "Expected an Union type annotation "
+ << "with an inner Dict type, but got "
+ << type_hint->repr_str();
+ } else if (dict_types.size() > 1) {
+ if (values.empty()) {
+ throw ErrorReport(dl)
+ << "Cannot assign an empty dict to a "
+ << "variable annotated to be type " << type_hint->repr_str()
+ << " because there are multiple possible Dict "
+ << "type candidates in the Union annotation";
+ } else {
+ all_candidates = std::move(dict_types);
+ }
+ } else {
+ refined_type_hint = dict_types[0];
+ }
+ } else if (
+ auto optional_type_hint =
+ refined_type_hint->cast<OptionalType>()) {
+ refined_type_hint = optional_type_hint->getElementType();
+ }
+
+ if (all_candidates.empty()) {
+ if (auto dict_type_hint = refined_type_hint->cast<DictType>()) {
+ auto dict_type = refined_type_hint->expect<DictType>();
+ key_type = dict_type->getKeyType();
+ value_type = dict_type->getValueType();
+ } else if (refined_type_hint == AnyType::get()) {
+ // @ansley: Clean up later
+ if (keys.empty()) {
+ key_type = StringType::get();
+ value_type = TensorType::get();
+ } else {
+ key_type = keys.at(0)->type();
+ value_type = values.at(0)->type();
+ }
+ } else {
+ throw ErrorReport(dl)
+ << "Expected an annotation of type "
+ << "Dict, but got " << type_hint->repr_str();
+ }
+ }
} else if (keys.empty()) {
key_type = StringType::get();
value_type = TensorType::get();
key_type = keys.at(0)->type();
value_type = values.at(0)->type();
}
- AT_ASSERT(key_type != nullptr && value_type != nullptr);
-
- for (const auto i : c10::irange(keys.size())) {
- std::stringstream ss;
- if (!keys[i]->type()->isSubtypeOfExt(key_type, &ss)) {
- throw ErrorReport(key_trees[i])
- << "Dict keys must contain "
- << "only a single type. Expected: " << key_type->repr_str()
- << " but found " << keys[i]->type()->repr_str() << " instead.\n"
- << ss.str();
+
+ AT_ASSERT(
+ !all_candidates.empty() ||
+ (key_type != nullptr && value_type != nullptr));
+
+ if (!keys.empty()) {
+ auto key_types = fmap(keys, [](const Value* v) { return v->type(); });
+
+ std::stringstream nowhere; // never used
+
+ TypePtr first_key_type = key_types[0];
+
+ for (const auto i : c10::irange(keys.size())) {
+ std::stringstream ss;
+ if (!keys[i]->type()->isSubtypeOfExt(first_key_type, &ss) &&
+ !first_key_type->isSubtypeOfExt(keys[i]->type(), &ss)) {
+ throw ErrorReport(key_trees[i])
+ << "Dict keys must contain "
+ << "only a single type. Expected: "
+ << first_key_type->repr_str() << " but found "
+ << keys[i]->type()->repr_str() << " instead.\n"
+ << ss.str();
+ }
}
}
if (!values.empty()) {
- auto types = fmap(values, [](const Value* v) { return v->type(); });
+ auto value_types =
+ fmap(values, [](const Value* v) { return v->type(); });
std::stringstream nowhere; // never used
- const TypePtr value_type_hint =
- type_hint ? type_hint->expect<DictType>()->getKeyType() : nullptr;
-
- c10::optional<TypePtr> unified = unifyTypeList(
- types,
+ c10::optional<TypePtr> unified_value_type = unifyTypeList(
+ value_types,
/*why_not=*/nowhere,
/*default_to_union=*/true,
- value_type_hint);
+ value_type);
+
+ if (refined_type_hint && !all_candidates.empty()) {
+ auto known_key_type = keys[0]->type();
+ auto known_value_type = *unified_value_type;
+
+ TypePtr candidate_key_type = nullptr;
+ TypePtr candidate_value_type = nullptr;
+ TypePtr candidate = nullptr;
+
+ for (const auto& current_candidate : all_candidates) {
+ auto current_key_type =
+ current_candidate->expect<DictType>()->getKeyType();
+ auto current_value_type =
+ current_candidate->expect<DictType>()->getValueType();
+ if (known_key_type->isSubtypeOf(current_key_type) &&
+ known_value_type->isSubtypeOf(current_value_type)) {
+ if (!candidate ||
+ (candidate_key_type->isSubtypeOf(current_key_type) &&
+ candidate_value_type->isSubtypeOf(current_value_type))) {
+ candidate_key_type = current_key_type;
+ candidate_value_type = current_value_type;
+ candidate = current_candidate;
+ }
+ }
+ }
- if (!type_hint && *unified == AnyType::get()) {
+ if (!candidate) {
+ std::stringstream vector_repr;
+ for (size_t i = 0; i < all_candidates.size(); ++i) {
+ if (i > 0 && all_candidates.size() > 2) {
+ vector_repr << ", ";
+ }
+ if (i != 0 && i == all_candidates.size() - 1) {
+ vector_repr << " or ";
+ }
+ vector_repr << all_candidates[i]->repr_str();
+ }
+ throw ErrorReport(dl)
+ << "Union type annotation `" << refined_type_hint->repr_str()
+ << "` can hold " << vector_repr.str() << ", but none of "
+ << "those types can hold the types of the given dict"
+ << " elements, which were unified to Dict["
+ << known_key_type->repr_str() << ", "
+ << known_value_type->repr_str() << "]";
+ } else {
+ key_type = candidate_key_type;
+ value_type = candidate_value_type;
+ refined_type_hint = candidate;
+ }
+ }
+
+ if (!refined_type_hint &&
+ (*unified_value_type)->kind() == UnionType::Kind) {
TORCH_WARN(
"Dict values consist of heterogeneous types, which "
"means that they have been typed as `Any`. To use "
dl.range().str());
}
- if (type_hint) {
+ if (refined_type_hint) {
TypePtr value_type_hint =
- type_hint->expect<DictType>()->getValueType();
- for (const auto i : c10::irange(types.size())) {
+ refined_type_hint->expect<DictType>()->getValueType();
+ for (const auto i : c10::irange(value_types.size())) {
TORCH_CHECK(
- types[i]->isSubtypeOf(value_type_hint),
+ value_types[i]->isSubtypeOf(value_type_hint),
"Type "
- "hint for dict was",
- type_hint->repr_str(),
- "but the value ",
+ "hint for dict was ",
+ refined_type_hint->repr_str(),
+ ", but the value ",
"at index ",
i,
" has type ",
- types[i]->repr_str(),
+ value_types[i]->repr_str(),
", which is not a valid"
" subtype of ",
value_type_hint->repr_str());
// We only want to set `value_type` if we don't have a type
// hint to allow for the case that `*unified` is a subtype of
// the value type given by `type_hint`
- if (!type_hint) {
- value_type = *unified;
+ if (!refined_type_hint) {
+ value_type = *unified_value_type;
}
}
- return graph
- ->insertNode(graph->createDict(key_type, value_type, keys, values))
- ->output();
+ Node* result = graph->insertNode(
+ graph->createDict(key_type, value_type, keys, values));
+ if (annotated_union_type) {
+ Node* n = graph->insertNode(
+ graph->create(prim::unchecked_cast, {result->output()}));
+ n->output()->setType(std::move(annotated_union_type));
+ result = n;
+ }
+
+ return result->output();
} break;
case TK_LIST_COMP: {
auto lc = ListComp(tree);