From b0b7541ca4ddfa2f5967e17dac01f5bf471b8566 Mon Sep 17 00:00:00 2001
From: Elias Ellison
Date: Mon, 25 Feb 2019 13:26:49 -0800
Subject: [PATCH] fix list type unification (#17424)

Summary:
Previously we were unifying the types of lists across if block outputs. This now fails with Optional subtyping, because two types which can be unified may have different runtime representations:

```
@torch.jit.script
def list_optional_fails(x):
    # type: (bool) -> Optional[int]
    if x:
        y = [1]
    else:
        y = [None]
    return y[0]
```

The indexing op will expect y to be a generic list, but it will find an int list.
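For reference, code like the above can be rewritten so that both branches build the same list type and no cross-branch unification is needed. This is not part of this patch; it is a minimal sketch using `torch.jit.annotate`, assuming the usual behavior of annotating list literals:

```
from typing import List, Optional

import torch

@torch.jit.script
def list_optional_ok(x):
    # type: (bool) -> Optional[int]
    # Annotate both branches as List[Optional[int]] so each branch
    # constructs the same runtime list specialization (a generic list),
    # instead of an int list in one branch and a None list in the other.
    if x:
        y = torch.jit.annotate(List[Optional[int]], [1])
    else:
        y = torch.jit.annotate(List[Optional[int]], [None])
    return y[0]
```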
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17424

Differential Revision: D14210903

Pulled By: eellison

fbshipit-source-id: 4b8b26ba2e7e5bebf617e40316475f91e9109cc2
---
 aten/src/ATen/core/type.cpp | 27 ++++++++++++++++++---------
 test/test_jit.py            | 31 +++++++++++++++++++++++++++++++
 2 files changed, 49 insertions(+), 9 deletions(-)

diff --git a/aten/src/ATen/core/type.cpp b/aten/src/ATen/core/type.cpp
index ba8a9cc..e29345c 100644
--- a/aten/src/ATen/core/type.cpp
+++ b/aten/src/ATen/core/type.cpp
@@ -199,13 +199,20 @@ bool isSubvalueOf(const IValue& ivalue, TypePtr type) {
   return incompleteInferTypeFrom(ivalue)->isSubtypeOf(type);
 }
 
-
-c10::optional<TypePtr> unifyTypes(const TypePtr& t1, const TypePtr& t2) {
-  //cases that t1 == t2, or t1 is a type refinement of t2 and vice versa
+c10::optional<TypePtr> tryEitherIsTheSuperType(const TypePtr& t1, const TypePtr& t2) {
   if (t1->isSubtypeOf(t2)) {
     return t2;
   } else if (t2->isSubtypeOf(t1)) {
     return t1;
+  } else {
+    return c10::nullopt;
+  }
+}
+
+c10::optional<TypePtr> unifyTypes(const TypePtr& t1, const TypePtr& t2) {
+  //cases that t1 == t2, or t1 is a type refinement of t2 and vice versa
+  if (auto maybe_supertype = tryEitherIsTheSuperType(t1, t2)) {
+    return *maybe_supertype;
   }
 
   // NB: we do not return NumberType because there is not currently enough
@@ -224,12 +231,14 @@ c10::optional<TypePtr> unifyTypes(const TypePtr& t1, const TypePtr& t2) {
 
   //types which contain other types
   if (t1->cast<ListType>() && t2->cast<ListType>()) {
-    auto unified_type = unifyTypes(t1->cast<ListType>()->getElementType(), t2->cast<ListType>()->getElementType());
-    if (unified_type) {
-      return static_cast<TypePtr>(ListType::create(*unified_type));
-    } else {
-      return c10::nullopt;
-    }
+    // because we have runtime specializations of lists, e.g. int[] = std::vector<int64_t>
+    // int?[] = std::vector<IValue>, we don't allow type coercion,
+    // since t1 & t2 may have different runtime representations.
+
+    // allow Lists of different tensor types
+    auto unshaped_t1 = unshapedType(t1);
+    auto unshaped_t2 = unshapedType(t2);
+    return tryEitherIsTheSuperType(unshaped_t1, unshaped_t2);
   } else if(t1->cast<TupleType>() && t2->cast<TupleType>()) {
     auto tuple1 = t1->cast<TupleType>();
     auto tuple2 = t2->cast<TupleType>();
diff --git a/test/test_jit.py b/test/test_jit.py
index a0c793a..566613a 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -6643,6 +6643,37 @@ a")
         tensor_unifying.graph.propagate_shapes((a, b, c), False)
         self.assertExpected(canonical(tensor_unifying.graph))
 
+    def test_list_unify(self):
+        # allowing a unified int?[] would cause a runtime error b/c
+        # the index operation expects int?[] to be a generic list,
+        # but in the true branch the IValue will be an int list
+        with self.assertRaisesRegex(RuntimeError, "int[] in the true branch and type None[]"):
+            @torch.jit.script
+            def list_optional_fails(x):
+                # type: (bool) -> Optional[int]
+                if x:
+                    y = [1]
+                else:
+                    y = [None]
+                return y[0]
+
+        @torch.jit.script
+        def list_tensors(x):
+            # type: (bool) -> Tuple[Tensor, List[Tensor]]
+            if x:
+                a = torch.zeros([1, 1])
+                y = [a]
+            else:
+                a = torch.zeros([1, 2])
+                y = [a]
+            return a, y
+
+        self.run_pass('constant_propagation', list_tensors.graph)
+        m = torch.jit.ScriptModule()
+        m._create_method_from_graph("forward", list_tensors.graph)
+        # testing that the tensor type of lists is unified
+        self.getExportImportCopy(m)
+
     def test_type_annotations_repeated_list(self):
         @torch.jit.script
         def float_fn(x, y):
-- 
2.7.4