fix aliasing on unwrap optional (#15748)
authorElias Ellison <eellison@fb.com>
Thu, 10 Jan 2019 20:49:54 +0000 (12:49 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 10 Jan 2019 20:52:53 +0000 (12:52 -0800)
Summary:
Fix for https://github.com/pytorch/pytorch/issues/15604
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15748

Differential Revision: D13583632

Pulled By: eellison

fbshipit-source-id: 9655ee010494179e17e34f3047363477dad15fb1

torch/csrc/jit/operator.cpp
torch/csrc/jit/passes/alias_analysis.cpp
torch/csrc/jit/register_prim_ops.cpp

index 723e7c1..50c0bbb 100644 (file)
@@ -72,7 +72,8 @@ struct SchemaParser {
   TreeRef parseIdent() {
     return String::create(L.expect(TK_IDENT).text());
   }
-  TypePtr parseBaseType() {
+  using TypeAndAlias = std::pair<TypePtr, c10::optional<AliasInfo>>;
+  TypeAndAlias parseBaseType() {
     static std::unordered_map<std::string, TypePtr> type_map = {
         {"Generator", GeneratorType::get()},
         {"ScalarType", IntType::get()},
@@ -91,11 +92,11 @@ struct SchemaParser {
       if (text.size() > 0 && islower(text[0])) {
         // lower case identifiers that are not otherwise valid types
         // are treated as type variables
-        return VarType::create(text);
+        return TypeAndAlias(VarType::create(text), parseAliasAnnotation());
       }
       throw ErrorReport(tok.range) << "unknown type specifier";
     }
-    return it->second;
+    return TypeAndAlias(it->second, c10::nullopt);
   }
   // Examples:
   // Tensor(a) // Tensor is in set a
@@ -160,7 +161,9 @@ struct SchemaParser {
       value = DynamicType::get();
       alias_info = parseAliasAnnotation();
     } else {
-      value = parseBaseType();
+      auto value_alias = parseBaseType();
+      value = value_alias.first;
+      alias_info = value_alias.second;
     }
     while (true) {
       if (L.cur().kind == '[' && L.lookahead().kind == ']') {
index 895e68a..384fc2b 100644 (file)
@@ -9,6 +9,7 @@ bool shouldAnnotate(const TypePtr& type) {
   return type->isSubtypeOf(DynamicType::get()) ||
       type->kind() == TypeKind::ListType ||
       type->kind() == TypeKind::TupleType ||
+      type->kind() == TypeKind::VarType ||
       (type->kind() == TypeKind::OptionalType &&
        shouldAnnotate(type->cast<OptionalType>()->getElementType()));
 }
@@ -366,6 +367,11 @@ void AliasDb::analyze(Node* node) {
       continue;
     }
 
+    // If this type cannot alias, continue. Can occur with a VarType schema
+    if (!shouldAnnotate(actualValue)) {
+      continue;
+    }
+
     // We don't support composite types for alias analysis yet.
     JIT_ASSERT(formal->containedTypes().size() == 0);
     // TODO neither unions nor wildcards make sense on an input. We should
@@ -401,6 +407,11 @@ void AliasDb::analyze(Node* node) {
       continue;
     }
 
+    // If this type cannot alias, continue. Can occur with a VarType schema
+    if (!shouldAnnotate(actual)) {
+      continue;
+    }
+
     // We don't support composite types for alias analysis yet.
     JIT_ASSERT(formal->containedTypes().size() == 0);
 
index 00c5d91..ee2ac40 100644 (file)
@@ -697,7 +697,7 @@ RegisterOperators reg({
           }
         }),
     Operator(
-        "aten::_unwrap_optional(t? optional) -> t",
+        "aten::_unwrap_optional(t(a)? optional) -> t(a)",
         [](const Node* node) -> Operation {
           return [=](Stack& stack) {
             auto val = pop(stack);