Add implicit optional unwrapping (#15587)
authorElias Ellison <eellison@fb.com>
Fri, 18 Jan 2019 19:17:34 +0000 (11:17 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 18 Jan 2019 19:25:01 +0000 (11:25 -0800)
Summary:
Add support for type inference for optional type refinement.

If a conditional is of the form "x is None" or "x is not None", or is a boolean expression containing multiple none checks, the proper type refinements are inserted in each branch.

For example:
if optional_tensor is not None and len(optional_tensor) < 2:
# optional_tensor is a Tensor

if optional_tensor1 is not None and optional_tensor2 is not None:
# both optional_tensor1 and optional_tensor2 are Tensors

TODO:

- not run an op for unchecked unwrap optional in the interpreter

- potentially refine types to prim::None (omitted for now to simply things & because it's not an actual use cause).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15587

Differential Revision: D13733810

Pulled By: eellison

fbshipit-source-id: 57c32be9f5a09ab5542ba0144a6059b96de23d7a

aten/src/ATen/core/interned_strings.h
test/expect/TestScript.test_if_is_none_dispatch.expect
test/test_jit.py
torch/csrc/jit/passes/constant_propagation.cpp
torch/csrc/jit/register_prim_ops.cpp
torch/csrc/jit/script/compiler.cpp

index 5e9144b..049f5ad 100644 (file)
@@ -88,6 +88,7 @@ namespace c10 {
   _(aten, index_put_)              \
   _(aten, device)                  \
   _(aten, len)                     \
+  _(prim, unchecked_unwrap_optional)\
   FORALL_ATEN_BASE_SYMBOLS(_)      \
   _(onnx, Add)                     \
   _(onnx, Concat)                  \
index bc15fd3..64a30c4 100644 (file)
@@ -6,17 +6,18 @@ graph(%input : Tensor
   %5 : int = prim::Constant[value=4]()
   %x.1 : Tensor = aten::add(%input, %4, %3)
   %7 : bool = aten::__isnot__(%opt.1, %2)
-  %opt : Tensor?, %x.3 : Tensor = prim::If(%7)
+  %opt.4 : Tensor?, %x.3 : Tensor = prim::If(%7)
     block0() {
-      %opt.2 : Tensor = aten::_unwrap_optional(%opt.1)
-      %x.2 : Tensor = aten::add(%opt.2, %x.1, %3)
-      -> (%opt.2, %x.2)
+      %opt.2 : Tensor = prim::unchecked_unwrap_optional(%opt.1)
+      %opt.3 : Tensor = aten::_unwrap_optional(%opt.2)
+      %x.2 : Tensor = aten::add(%opt.3, %x.1, %3)
+      -> (%opt.3, %x.2)
     }
     block1() {
       -> (%opt.1, %x.1)
     }
-  %12 : bool = aten::__is__(%opt, %2)
-  %x : Tensor = prim::If(%12)
+  %13 : bool = aten::__is__(%opt.4, %2)
+  %x : Tensor = prim::If(%13)
     block0() {
       %x.4 : Tensor = aten::add(%x.3, %5, %3)
       -> (%x.4)
index 4f549d6..b26521f 100644 (file)
@@ -4123,6 +4123,107 @@ a")
                 return a + b
             ''')
 
+    def test_optional_refinement(self):
+        @torch.jit.script
+        def test_if_none_assignment(x):
+            # type: (Optional[int]) -> int
+            if x is None:
+                x = 1
+            return x + 1
+
+        self.assertEqual(test_if_none_assignment(1), 2)
+
+        @torch.jit.script
+        def test_ternary(x):
+            # type: (Optional[int]) -> int
+            x = x if x is not None else 2
+            return x
+
+        @torch.jit.script
+        def test_not_none(x):
+            # type: (Optional[int]) -> None
+            if x is not None:
+                print(x + 1)
+
+        @torch.jit.script
+        def test_and(x, y):
+            # type: (Optional[int], Optional[int]) -> None
+            if x is not None and y is not None:
+                print(x + y)
+
+        @torch.jit.script
+        def test_not(x, y):
+            # type: (Optional[int], Optional[int]) -> None
+            if not (x is not None and y is not None):
+                pass
+            else:
+                print(x + y)
+
+        @torch.jit.script
+        def test_bool_expression(x):
+            # type: (Optional[int]) -> None
+            if x is not None and x < 2:
+                print(x + 1)
+
+        @torch.jit.script
+        def test_nested_bool_expression(x, y):
+            # type: (Optional[int], Optional[int]) -> int
+            if x is not None and x < 2 and y is not None:
+                x = x + y
+            else:
+                x = 5
+            return x + 2
+
+        @torch.jit.script
+        def test_or(x, y):
+            # type: (Optional[int], Optional[int]) -> None
+            if y is None or x is None:
+                pass
+            else:
+                print(x + y)
+
+        # backwards compatibility
+        @torch.jit.script
+        def test_manual_unwrap_opt(x):
+            # type: (Optional[int]) -> int
+            if x is None:
+                x = 1
+            else:
+                x = torch.jit._unwrap_optional(x)
+            return x
+
+        with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
+            @torch.jit.script
+            def or_error(x, y):
+                # type: (Optional[int], Optional[int]) -> int
+                if x is None or y is None:
+                    print(x + y)
+
+        with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
+            @torch.jit.script
+            def and_error(x, y):
+                # type: (Optional[int], Optional[int]) -> int
+                if x is None and y is None:
+                    pass
+                else:
+                    print(x + y)
+
+        with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
+            @torch.jit.script
+            def named_var(x):
+                # type: (Optional[int]) -> None
+                x_none = x is not None
+                if x_none:
+                    print(x + 1)
+
+        with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
+            @torch.jit.script
+            def named_var_and(x, y):
+                # type: (Optional[int], Optional[int]) -> None
+                x_none = x is not None
+                if y is not None and x_none:
+                    print(x + y)
+
     def test_while_write_outer_then_read(self):
         def func(a, b):
             while bool(a < 10):
index 2665982..3951eb5 100644 (file)
@@ -19,6 +19,7 @@ std::unordered_set<Symbol> skip_list = {
     prim::Loop, // TODO: handle Loop
     prim::Constant,
     prim::Undefined,
+    prim::unchecked_unwrap_optional, //TODO remove
     prim::None, // it is already a constant and propagating it will lose
                 // important type information about which Optional type it is
     // TODO (zach): we should consider skipping tensor factories in the cases
index e46fce7..1831a70 100644 (file)
@@ -799,6 +799,10 @@ RegisterOperators reg({
             return 0;
           };
         }),
+    // This op can be removed in preprocessing before being run in the interpreter
+    // (but is currently not removed), even when it is removed it needs to remain
+    // a registered op so that constant prop can run.
+    Operator("prim::unchecked_unwrap_optional(t(a)? optional) -> t(a)", noop),
     Operator(
         prim::fork,
         [](const Node* node) {
index a67fc2a..681bee7 100644 (file)
@@ -1,3 +1,4 @@
+#include <torch/csrc/jit/script/compiler.h>
 #include <torch/csrc/jit/assertions.h>
 #include <torch/csrc/jit/hooks_for_testing.h>
 #include <torch/csrc/jit/interpreter.h>
@@ -5,7 +6,6 @@
 #include <torch/csrc/jit/operator.h>
 #include <torch/csrc/jit/passes/constant_pooling.h>
 #include <torch/csrc/jit/passes/lower_tuples.h>
-#include <torch/csrc/jit/script/compiler.h>
 #include <torch/csrc/jit/script/final_returns.h>
 #include <torch/csrc/jit/script/parser.h>
 #include <torch/csrc/jit/script/schema_matching.h>
@@ -29,6 +29,115 @@ using ValueTable = std::unordered_map<std::string, SugaredValuePtr>;
 using AttributeMap = std::unordered_map<std::string, Const>;
 using ListAttributeMap = std::unordered_map<std::string, std::vector<Const>>;
 
+using TypeAndRange = std::pair<TypePtr, const SourceRange*>;
+
+// Holds mappings from a variable name to a refined type for that variable
+// E.g if x is not None is true than we can refine x from type t? to t.
+struct Refinements {
+  // using ordered map for deterministic graph output
+  std::map<std::string, TypeAndRange> mappings_;
+
+  void setRefinement(const std::string& name, TypeAndRange mapping) {
+    mappings_[name] = std::move(mapping);
+  }
+
+  c10::optional<TypeAndRange> getRefinement(const std::string& name) const {
+    const auto& maybe_mapping = mappings_.find(name);
+    if (maybe_mapping == mappings_.end()) {
+      return c10::nullopt;
+    }
+    return maybe_mapping->second;
+  }
+
+  // return the intersection of the values to type mappings between this
+  // types can be unified
+  void intersectRefinements(const Refinements& other) {
+    Refinements ret;
+    for (const auto& name_mapping : mappings_) {
+      const auto& name = name_mapping.first;
+      const auto& mapping = name_mapping.second;
+      if (auto other_mapping = other.getRefinement(name_mapping.first)) {
+        const auto maybe_unified_type =
+            unifyTypes(mapping.first, other_mapping->first);
+        if (maybe_unified_type) {
+          ret.setRefinement(
+              name, TypeAndRange(*maybe_unified_type, mapping.second));
+        }
+      }
+    }
+    mappings_ = std::move(ret.mappings_);
+  }
+
+  // return the union of the values to type mappings in a and b whose
+  // types can be unified
+  void unionRefinements(const Refinements& other) {
+    Refinements ret;
+    for (const auto& name_mapping : mappings_) {
+      const auto& name = name_mapping.first;
+      const auto& mapping = name_mapping.second;
+      TypePtr t_1 = mapping.first;
+      if (auto other_mapping = other.getRefinement(name_mapping.first)) {
+        TypePtr t_2 = other_mapping->first;
+        c10::optional<TypePtr> maybe_unified_type = c10::nullopt;
+        if (t_1->isSubtypeOf(t_2)) {
+          maybe_unified_type = t_1;
+        } else if (t_2->isSubtypeOf(t_1)) {
+          maybe_unified_type = t_2;
+        }
+        if (maybe_unified_type) {
+          ret.setRefinement(
+              name, TypeAndRange(*maybe_unified_type, mapping.second));
+        }
+      } else {
+        ret.setRefinement(name, mapping);
+      }
+    }
+
+    for (auto& name_mapping : other.mappings_) {
+      if (!getRefinement(name_mapping.first)) {
+        ret.setRefinement(name_mapping.first, name_mapping.second);
+      }
+    }
+
+    mappings_ = std::move(ret.mappings_);
+  }
+};
+
+// When a comparison like x is None is made, we associate type refinements
+// with its true value and its false value. If a boolean that has refinements
+// associated with it is used in a conditional of an if statememt, the true and
+// false refinements are inserted into the corresponding blocks
+
+struct BoolInfo {
+  BoolInfo(Refinements true_refinements, Refinements false_refinements)
+      : true_refinements_(std::move(true_refinements)),
+        false_refinements_(std::move(false_refinements)){};
+  BoolInfo() = default;
+
+  Refinements true_refinements_;
+  Refinements false_refinements_;
+
+  BoolInfo* mergeOr(const BoolInfo& other) {
+    // if the result of an OR is true, either a & b could have been true,
+    // so we take the intersection of a.true_refinements & b.true_refinements.
+    // if the result is false, both a and b had to be false,
+    // so we take their union.
+    true_refinements_.intersectRefinements(other.true_refinements_);
+    false_refinements_.unionRefinements(other.false_refinements_);
+    return this;
+  }
+
+  BoolInfo* mergeAnd(const BoolInfo& other) {
+    // if the result of an AND is true, both a & b had to be true,
+    // so we take the union of a.true_refinements and b.true_refinements.
+    // if the result is false, either a or b could have been false,
+    // so we take their intersection.
+    true_refinements_.unionRefinements(other.true_refinements_);
+    false_refinements_.intersectRefinements(other.false_refinements_);
+    return this;
+  }
+};
+
 static Value* asSimple(const SugaredValuePtr& value) {
   if (SimpleValue* sv = dynamic_cast<SimpleValue*>(value.get())) {
     return sv->getValue();
@@ -817,9 +926,11 @@ struct to_ir {
 
   std::shared_ptr<Environment> emitSingleIfBranch(
       Block* b,
-      const List<Stmt>& branch) {
+      const List<Stmt>& branch,
+      const Refinements& refinements) {
     pushFrame(b);
     WithInsertPoint guard(b);
+    insertRefinements(refinements);
     emitStatements(branch);
     return popFrame();
   }
@@ -830,23 +941,65 @@ struct to_ir {
   }
 
   Value* emitTernaryIf(const TernaryIf& expr) {
+    const auto& bool_info = findRefinements(expr.cond());
     Value* cond_value = emitCond(expr.cond());
-    auto true_expr = [&] { return emitExpr(expr.true_expr()); };
-    auto false_expr = [&] { return emitExpr(expr.false_expr()); };
+    auto true_expr = [&] {
+      insertRefinements(bool_info.true_refinements_);
+      return emitExpr(expr.true_expr());
+    };
+    auto false_expr = [&] {
+      insertRefinements(bool_info.false_refinements_);
+      return emitExpr(expr.false_expr());
+    };
     return emitIfExpr(expr.range(), cond_value, true_expr, false_expr);
   }
 
+  // Insert subtyping refinements
+  void insertRefinements(const Refinements& ref) {
+    for (const auto& name_mappings : ref.mappings_) {
+      const std::string& name = name_mappings.first;
+      auto type = name_mappings.second.first;
+      const auto& range = *name_mappings.second.second;
+      Value* v = environment_stack->getVar(name, range);
+      if (type != NoneType::get()) {
+        Value* output = graph->insert(prim::unchecked_unwrap_optional, {v});
+        environment_stack->setVar(range, name, output);
+      }
+      // todo @eellison - revisit inserting Nones when None subtypes Optional
+    }
+  }
+
   Value* emitShortCircuitIf(
       const SourceRange& loc,
       const TreeRef& first_expr,
       const TreeRef& second_expr,
       bool is_or) {
+    const auto first_bool_info = findRefinements(first_expr);
     Value* first_value = emitCond(Expr(first_expr));
 
-    auto get_first_expr = [first_value] { return first_value; };
-    auto get_second_expr = [&] { return emitCond(Expr(second_expr)); };
+    const Refinements* first_expr_refinements;
+    const Refinements* second_expr_refinements;
+    // if it's an OR the first expr is emitted in the true branch
+    // and the second expr in the false branch, if it's an AND the opposite
+    if (is_or) {
+      first_expr_refinements = &first_bool_info.true_refinements_;
+      second_expr_refinements = &first_bool_info.false_refinements_;
+    } else {
+      first_expr_refinements = &first_bool_info.false_refinements_;
+      second_expr_refinements = &first_bool_info.true_refinements_;
+    }
+
+    auto get_first_expr = [&] {
+      insertRefinements(*first_expr_refinements);
+      return first_value;
+    };
+
+    auto get_second_expr = [&] {
+      insertRefinements(*second_expr_refinements);
+      return emitCond(Expr(second_expr));
+    };
 
-    // if this is an OR, eval second expression if first expr is False.
+    // if this is an OR, eval second expression if first expr is False
     // If this is an AND, eval second expression if first expr is True
     if (is_or) {
       return emitIfExpr(loc, first_value, get_first_expr, get_second_expr);
@@ -910,12 +1063,15 @@ struct to_ir {
   void emitIfElseBlocks(Value* cond_value, const If& stmt) {
     Node* n = graph->insertNode(create(prim::If, stmt.range(), 0));
     n->addInput(cond_value);
+    const auto bool_info = findRefinements(stmt.cond());
     auto* true_block = n->addBlock();
     auto* false_block = n->addBlock();
 
     // Emit both blocks once to get the union of all mutated values
-    auto save_true = emitSingleIfBranch(true_block, stmt.trueBranch());
-    auto save_false = emitSingleIfBranch(false_block, stmt.falseBranch());
+    auto save_true = emitSingleIfBranch(
+        true_block, stmt.trueBranch(), bool_info.true_refinements_);
+    auto save_false = emitSingleIfBranch(
+        false_block, stmt.falseBranch(), bool_info.false_refinements_);
 
     // In python, every variable assigned in an if statement escapes
     // the scope of the if statement (all variables are scoped to the function).
@@ -1039,6 +1195,7 @@ struct to_ir {
       // emit the whole If stmt as usual, finish emitCond first
       auto lhs_range = cond_op.lhs().get()->range();
       auto rhs_range = cond_op.rhs().get()->range();
+
       auto kind = getNodeKind(cond.kind(), cond.get()->trees().size());
       Value* cond_value = emitBuiltinCall(
           cond.get()->range(),
@@ -1820,6 +1977,51 @@ struct to_ir {
     }
   }
 
+  BoolInfo findRefinements(const TreeRef& tree) {
+    switch (tree->kind()) {
+      case TK_IS:
+      case TK_ISNOT: {
+        const auto& inputs = tree->trees();
+        if (inputs.at(0)->kind() == TK_VAR && inputs.at(1)->kind() == TK_NONE) {
+          const std::string& var_name = Var(inputs[0]).name().name();
+          Refinements true_info, false_info;
+          auto type =
+              environment_stack->getVar(var_name, inputs[0]->range())->type();
+          if (auto opt_type = type->cast<OptionalType>()) {
+            false_info.setRefinement(
+                var_name,
+                TypeAndRange(opt_type->getElementType(), &tree->range()));
+            true_info.setRefinement(
+                var_name, TypeAndRange(NoneType::get(), &tree->range()));
+          }
+          if (tree->kind() == TK_IS) {
+            return BoolInfo(true_info, false_info);
+          } else {
+            return BoolInfo(false_info, true_info);
+          }
+        }
+      } break;
+      case TK_NOT: {
+        const auto& inputs = tree->trees();
+        auto bool_info = findRefinements(inputs[0]);
+        return BoolInfo(
+            bool_info.false_refinements_, bool_info.true_refinements_);
+      }
+      case TK_OR:
+      case TK_AND: {
+        const auto& inputs = tree->trees();
+        auto first = findRefinements(inputs[0]);
+        auto second = findRefinements(inputs[1]);
+        if (tree->kind() == TK_OR) {
+          return *first.mergeOr(second);
+        } else {
+          return *first.mergeAnd(second);
+        }
+      }
+    }
+    return BoolInfo();
+  }
+
   Value* emitExpr(const Expr& tree, const TypePtr& type_hint = nullptr) {
     return emitSugaredExpr(tree, 1, type_hint)->asValue(tree.range(), method);
   }
@@ -2024,7 +2226,7 @@ struct to_ir {
           elem_type = values.at(0)->type();
         }
         for (auto v : values) {
-          if (*v->type() != *elem_type)  {
+          if (*v->type() != *elem_type) {
             throw ErrorReport(tree)
                 << "Lists must contain only a single type, expected: "
                 << *elem_type << " but found " << *v->type() << " instead";