return incompleteInferTypeFrom(ivalue)->isSubtypeOf(type);
}
-
-c10::optional<TypePtr> unifyTypes(const TypePtr& t1, const TypePtr& t2) {
- //cases that t1 == t2, or t1 is a type refinement of t2 and vice versa
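+// Returns whichever of t1 and t2 is a supertype of the other, or
+// c10::nullopt if neither subsumes the other.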
+c10::optional<TypePtr> tryEitherIsTheSuperType(const TypePtr& t1, const TypePtr& t2) {
if (t1->isSubtypeOf(t2)) {
return t2;
} else if (t2->isSubtypeOf(t1)) {
return t1;
+ } else {
+ return c10::nullopt;
+ }
+}
+
+c10::optional<TypePtr> unifyTypes(const TypePtr& t1, const TypePtr& t2) {
+  // cases where t1 == t2, or where one type is a refinement of the other
+ if (auto maybe_supertype = tryEitherIsTheSuperType(t1, t2)) {
+ return *maybe_supertype;
}
  // NB: we do not return NumberType because there is not currently enough
  // operator support for it
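+  // e.g. unifying IntType with FloatType yields nullopt here rather than
+  // NumberType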
  // types which contain other types
if (t1->cast<ListType>() && t2->cast<ListType>()) {
- auto unified_type = unifyTypes(t1->cast<ListType>()->getElementType(), t2->cast<ListType>()->getElementType());
- if (unified_type) {
- return static_cast<TypePtr>(ListType::create(*unified_type));
- } else {
- return c10::nullopt;
- }
+    // Because we have runtime specializations of lists
+    // (e.g. int[] = std::vector<int64_t>, int?[] = std::vector<IValue>),
+    // we don't allow type coercion here: t1 and t2 may have different
+    // runtime representations.
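+    // For instance, coercing int[] to int?[] would require boxing each
+    // int64_t element into an IValue, which the runtime does not do
+    // implicitly.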
+
+ // allow Lists of different tensor types
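+    // e.g. a list of 1x1 tensors and a list of 1x2 tensors both erase to
+    // Tensor[] once shape information is dropped, so the two lists unify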
+ auto unshaped_t1 = unshapedType(t1);
+ auto unshaped_t2 = unshapedType(t2);
+ return tryEitherIsTheSuperType(unshaped_t1, unshaped_t2);
} else if(t1->cast<TupleType>() && t2->cast<TupleType>()) {
auto tuple1 = t1->cast<TupleType>();
auto tuple2 = t2->cast<TupleType>();
tensor_unifying.graph.propagate_shapes((a, b, c), False)
self.assertExpected(canonical(tensor_unifying.graph))
+
+    def test_list_unify(self):
+        # allowing a unified int?[] would cause a runtime error because
+        # the index operation expects int?[] to be a generic list,
+        # but in the true branch the IValue will be an int list
+        with self.assertRaisesRegex(RuntimeError, r"int\[\] in the true branch and type None\[\]"):
+ @torch.jit.script
+ def list_optional_fails(x):
+ # type: (bool) -> Optional[int]
+ if x:
+ y = [1]
+ else:
+ y = [None]
+ return y[0]
+
+ @torch.jit.script
+ def list_tensors(x):
+ # type: (bool) -> Tuple[Tensor, List[Tensor]]
+ if x:
+ a = torch.zeros([1, 1])
+ y = [a]
+ else:
+ a = torch.zeros([1, 2])
+ y = [a]
+ return a, y
+
+ self.run_pass('constant_propagation', list_tensors.graph)
+ m = torch.jit.ScriptModule()
+ m._create_method_from_graph("forward", list_tensors.graph)
+        # testing that the tensor types of the lists are unified
+ self.getExportImportCopy(m)
+
def test_type_annotations_repeated_list(self):
@torch.jit.script
def float_fn(x, y):