From 9173cd5a4d31abd8cf42fcae0341db8287a69074 Mon Sep 17 00:00:00 2001 From: Elias Ellison Date: Thu, 10 Jan 2019 12:49:54 -0800 Subject: [PATCH] fix aliasing on unwrap optional (#15748) Summary: Fix for https://github.com/pytorch/pytorch/issues/15604 Pull Request resolved: https://github.com/pytorch/pytorch/pull/15748 Differential Revision: D13583632 Pulled By: eellison fbshipit-source-id: 9655ee010494179e17e34f3047363477dad15fb1 --- torch/csrc/jit/operator.cpp | 11 +++++++---- torch/csrc/jit/passes/alias_analysis.cpp | 11 +++++++++++ torch/csrc/jit/register_prim_ops.cpp | 2 +- 3 files changed, 19 insertions(+), 5 deletions(-) diff --git a/torch/csrc/jit/operator.cpp b/torch/csrc/jit/operator.cpp index 723e7c1..50c0bbb 100644 --- a/torch/csrc/jit/operator.cpp +++ b/torch/csrc/jit/operator.cpp @@ -72,7 +72,8 @@ struct SchemaParser { TreeRef parseIdent() { return String::create(L.expect(TK_IDENT).text()); } - TypePtr parseBaseType() { + using TypeAndAlias = std::pair>; + TypeAndAlias parseBaseType() { static std::unordered_map type_map = { {"Generator", GeneratorType::get()}, {"ScalarType", IntType::get()}, @@ -91,11 +92,11 @@ struct SchemaParser { if (text.size() > 0 && islower(text[0])) { // lower case identifiers that are not otherwise valid types // are treated as type variables - return VarType::create(text); + return TypeAndAlias(VarType::create(text), parseAliasAnnotation()); } throw ErrorReport(tok.range) << "unknown type specifier"; } - return it->second; + return TypeAndAlias(it->second, c10::nullopt); } // Examples: // Tensor(a) // Tensor is in set a @@ -160,7 +161,9 @@ struct SchemaParser { value = DynamicType::get(); alias_info = parseAliasAnnotation(); } else { - value = parseBaseType(); + auto value_alias = parseBaseType(); + value = value_alias.first; + alias_info = value_alias.second; } while (true) { if (L.cur().kind == '[' && L.lookahead().kind == ']') { diff --git a/torch/csrc/jit/passes/alias_analysis.cpp b/torch/csrc/jit/passes/alias_analysis.cpp index 895e68a..384fc2b 100644 --- a/torch/csrc/jit/passes/alias_analysis.cpp +++ b/torch/csrc/jit/passes/alias_analysis.cpp @@ -9,6 +9,7 @@ bool shouldAnnotate(const TypePtr& type) { return type->isSubtypeOf(DynamicType::get()) || type->kind() == TypeKind::ListType || type->kind() == TypeKind::TupleType || + type->kind() == TypeKind::VarType || (type->kind() == TypeKind::OptionalType && shouldAnnotate(type->cast()->getElementType())); } @@ -366,6 +367,11 @@ void AliasDb::analyze(Node* node) { continue; } + // If this type cannot alias, continue. Can occur with a VarType schema + if (!shouldAnnotate(actualValue)) { + continue; + } + // We don't support composite types for alias analysis yet. JIT_ASSERT(formal->containedTypes().size() == 0); // TODO neither unions nor wildcards make sense on an input. We should @@ -401,6 +407,11 @@ void AliasDb::analyze(Node* node) { continue; } + // If this type cannot alias, continue. Can occur with a VarType schema + if (!shouldAnnotate(actual)) { + continue; + } + // We don't support composite types for alias analysis yet. JIT_ASSERT(formal->containedTypes().size() == 0); diff --git a/torch/csrc/jit/register_prim_ops.cpp b/torch/csrc/jit/register_prim_ops.cpp index 00c5d91..ee2ac40 100644 --- a/torch/csrc/jit/register_prim_ops.cpp +++ b/torch/csrc/jit/register_prim_ops.cpp @@ -697,7 +697,7 @@ RegisterOperators reg({ } }), Operator( - "aten::_unwrap_optional(t? optional) -> t", + "aten::_unwrap_optional(t(a)? optional) -> t(a)", [](const Node* node) -> Operation { return [=](Stack& stack) { auto val = pop(stack); -- 2.7.4