From d4f6befc93938e5462b186f76405b544e50fab60 Mon Sep 17 00:00:00 2001 From: Elias Ellison Date: Fri, 18 Jan 2019 11:17:34 -0800 Subject: [PATCH] Add implicit optional unwrapping (#15587) Summary: Add support for type inference for optional type refinement. If a conditional is of the form "x is None" or "x is not None", or is a boolean expression containing multiple none checks, the proper type refinements are inserted in each branch. For example: if optional_tensor is not None and len(optional_tensor) < 2: # optional_tensor is a Tensor if optional_tensor1 is not None and optional_tensor2 is not None: # both optional_tensor1 and optional_tensor2 are Tensors TODO: - not run an op for unchecked unwrap optional in the interpreter - potentially refine types to prim::None (omitted for now to simplify things & because it's not an actual use case). Pull Request resolved: https://github.com/pytorch/pytorch/pull/15587 Differential Revision: D13733810 Pulled By: eellison fbshipit-source-id: 57c32be9f5a09ab5542ba0144a6059b96de23d7a --- aten/src/ATen/core/interned_strings.h | 1 + .../TestScript.test_if_is_none_dispatch.expect | 13 +- test/test_jit.py | 101 ++++++++++ torch/csrc/jit/passes/constant_propagation.cpp | 1 + torch/csrc/jit/register_prim_ops.cpp | 4 + torch/csrc/jit/script/compiler.cpp | 222 ++++++++++++++++++++- 6 files changed, 326 insertions(+), 16 deletions(-) diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h index 5e9144b..049f5ad 100644 --- a/aten/src/ATen/core/interned_strings.h +++ b/aten/src/ATen/core/interned_strings.h @@ -88,6 +88,7 @@ namespace c10 { _(aten, index_put_) \ _(aten, device) \ _(aten, len) \ + _(prim, unchecked_unwrap_optional)\ FORALL_ATEN_BASE_SYMBOLS(_) \ _(onnx, Add) \ _(onnx, Concat) \ diff --git a/test/expect/TestScript.test_if_is_none_dispatch.expect b/test/expect/TestScript.test_if_is_none_dispatch.expect index bc15fd3f..64a30c4 100644 --- a/test/expect/TestScript.test_if_is_none_dispatch.expect 
+++ b/test/expect/TestScript.test_if_is_none_dispatch.expect @@ -6,17 +6,18 @@ graph(%input : Tensor %5 : int = prim::Constant[value=4]() %x.1 : Tensor = aten::add(%input, %4, %3) %7 : bool = aten::__isnot__(%opt.1, %2) - %opt : Tensor?, %x.3 : Tensor = prim::If(%7) + %opt.4 : Tensor?, %x.3 : Tensor = prim::If(%7) block0() { - %opt.2 : Tensor = aten::_unwrap_optional(%opt.1) - %x.2 : Tensor = aten::add(%opt.2, %x.1, %3) - -> (%opt.2, %x.2) + %opt.2 : Tensor = prim::unchecked_unwrap_optional(%opt.1) + %opt.3 : Tensor = aten::_unwrap_optional(%opt.2) + %x.2 : Tensor = aten::add(%opt.3, %x.1, %3) + -> (%opt.3, %x.2) } block1() { -> (%opt.1, %x.1) } - %12 : bool = aten::__is__(%opt, %2) - %x : Tensor = prim::If(%12) + %13 : bool = aten::__is__(%opt.4, %2) + %x : Tensor = prim::If(%13) block0() { %x.4 : Tensor = aten::add(%x.3, %5, %3) -> (%x.4) diff --git a/test/test_jit.py b/test/test_jit.py index 4f549d6..b26521f 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -4123,6 +4123,107 @@ a") return a + b ''') + def test_optional_refinement(self): + @torch.jit.script + def test_if_none_assignment(x): + # type: (Optional[int]) -> int + if x is None: + x = 1 + return x + 1 + + self.assertEqual(test_if_none_assignment(1), 2) + + @torch.jit.script + def test_ternary(x): + # type: (Optional[int]) -> int + x = x if x is not None else 2 + return x + + @torch.jit.script + def test_not_none(x): + # type: (Optional[int]) -> None + if x is not None: + print(x + 1) + + @torch.jit.script + def test_and(x, y): + # type: (Optional[int], Optional[int]) -> None + if x is not None and y is not None: + print(x + y) + + @torch.jit.script + def test_not(x, y): + # type: (Optional[int], Optional[int]) -> None + if not (x is not None and y is not None): + pass + else: + print(x + y) + + @torch.jit.script + def test_bool_expression(x): + # type: (Optional[int]) -> None + if x is not None and x < 2: + print(x + 1) + + @torch.jit.script + def test_nested_bool_expression(x, y): + # type: 
(Optional[int], Optional[int]) -> int + if x is not None and x < 2 and y is not None: + x = x + y + else: + x = 5 + return x + 2 + + @torch.jit.script + def test_or(x, y): + # type: (Optional[int], Optional[int]) -> None + if y is None or x is None: + pass + else: + print(x + y) + + # backwards compatibility + @torch.jit.script + def test_manual_unwrap_opt(x): + # type: (Optional[int]) -> int + if x is None: + x = 1 + else: + x = torch.jit._unwrap_optional(x) + return x + + with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"): + @torch.jit.script + def or_error(x, y): + # type: (Optional[int], Optional[int]) -> int + if x is None or y is None: + print(x + y) + + with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"): + @torch.jit.script + def and_error(x, y): + # type: (Optional[int], Optional[int]) -> int + if x is None and y is None: + pass + else: + print(x + y) + + with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"): + @torch.jit.script + def named_var(x): + # type: (Optional[int]) -> None + x_none = x is not None + if x_none: + print(x + 1) + + with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"): + @torch.jit.script + def named_var_and(x, y): + # type: (Optional[int], Optional[int]) -> None + x_none = x is not None + if y is not None and x_none: + print(x + y) + def test_while_write_outer_then_read(self): def func(a, b): while bool(a < 10): diff --git a/torch/csrc/jit/passes/constant_propagation.cpp b/torch/csrc/jit/passes/constant_propagation.cpp index 2665982..3951eb5 100644 --- a/torch/csrc/jit/passes/constant_propagation.cpp +++ b/torch/csrc/jit/passes/constant_propagation.cpp @@ -19,6 +19,7 @@ std::unordered_set skip_list = { prim::Loop, // TODO: handle Loop prim::Constant, prim::Undefined, + prim::unchecked_unwrap_optional, //TODO remove prim::None, // it is already a constant and propagating it will lose // important type information about which Optional 
type it is // TODO (zach): we should consider skipping tensor factories in the cases diff --git a/torch/csrc/jit/register_prim_ops.cpp b/torch/csrc/jit/register_prim_ops.cpp index e46fce7..1831a70 100644 --- a/torch/csrc/jit/register_prim_ops.cpp +++ b/torch/csrc/jit/register_prim_ops.cpp @@ -799,6 +799,10 @@ RegisterOperators reg({ return 0; }; }), + // This op can be removed in preprocessing before being run in the interpreter + // (but is currently not removed), even when it is removed it needs to remain + // a registered op so that constant prop can run. + Operator("prim::unchecked_unwrap_optional(t(a)? optional) -> t(a)", noop), Operator( prim::fork, [](const Node* node) { diff --git a/torch/csrc/jit/script/compiler.cpp b/torch/csrc/jit/script/compiler.cpp index a67fc2a..681bee7 100644 --- a/torch/csrc/jit/script/compiler.cpp +++ b/torch/csrc/jit/script/compiler.cpp @@ -1,3 +1,4 @@ +#include #include #include #include @@ -5,7 +6,6 @@ #include #include #include -#include #include #include #include @@ -29,6 +29,115 @@ using ValueTable = std::unordered_map; using AttributeMap = std::unordered_map; using ListAttributeMap = std::unordered_map>; +using TypeAndRange = std::pair; + +// Holds mappings from a variable name to a refined type for that variable +// E.g if x is not None is true than we can refine x from type t? to t. 
+struct Refinements { + // using ordered map for deterministic graph output + std::map mappings_; + + void setRefinement(const std::string& name, TypeAndRange mapping) { + mappings_[name] = std::move(mapping); + } + + c10::optional getRefinement(const std::string& name) const { + const auto& maybe_mapping = mappings_.find(name); + if (maybe_mapping == mappings_.end()) { + return c10::nullopt; + } + return maybe_mapping->second; + } + + // return the intersection of the values to type mappings between this + // types can be unified + void intersectRefinements(const Refinements& other) { + Refinements ret; + for (const auto& name_mapping : mappings_) { + const auto& name = name_mapping.first; + const auto& mapping = name_mapping.second; + if (auto other_mapping = other.getRefinement(name_mapping.first)) { + const auto maybe_unified_type = + unifyTypes(mapping.first, other_mapping->first); + if (maybe_unified_type) { + ret.setRefinement( + name, TypeAndRange(*maybe_unified_type, mapping.second)); + } + } + } + mappings_ = std::move(ret.mappings_); + } + + // return the union of the values to type mappings in a and b whose + // types can be unified + void unionRefinements(const Refinements& other) { + Refinements ret; + for (const auto& name_mapping : mappings_) { + const auto& name = name_mapping.first; + const auto& mapping = name_mapping.second; + TypePtr t_1 = mapping.first; + if (auto other_mapping = other.getRefinement(name_mapping.first)) { + TypePtr t_2 = other_mapping->first; + c10::optional maybe_unified_type = c10::nullopt; + if (t_1->isSubtypeOf(t_2)) { + maybe_unified_type = t_1; + } else if (t_2->isSubtypeOf(t_1)) { + maybe_unified_type = t_2; + } + if (maybe_unified_type) { + ret.setRefinement( + name, TypeAndRange(*maybe_unified_type, mapping.second)); + } + } else { + ret.setRefinement(name, mapping); + } + } + + for (auto& name_mapping : other.mappings_) { + if (!getRefinement(name_mapping.first)) { + ret.setRefinement(name_mapping.first, 
name_mapping.second); + } + } + + mappings_ = std::move(ret.mappings_); + } +}; + +// When a comparison like x is None is made, we associate type refinements +// with its true value and its false value. If a boolean that has refinements +// associated with it is used in a conditional of an if statememt, the true and +// false refinements are inserted into the corresponding blocks + +struct BoolInfo { + BoolInfo(Refinements true_refinements, Refinements false_refinements) + : true_refinements_(std::move(true_refinements)), + false_refinements_(std::move(false_refinements)){}; + BoolInfo() = default; + + Refinements true_refinements_; + Refinements false_refinements_; + + BoolInfo* mergeOr(const BoolInfo& other) { + // if the result of an OR is true, either a & b could have been true, + // so we take the intersection of a.true_refinements & b.true_refinements. + // if the result is false, both a and b had to be false, + // so we take their union. + true_refinements_.intersectRefinements(other.true_refinements_); + false_refinements_.unionRefinements(other.false_refinements_); + return this; + } + + BoolInfo* mergeAnd(const BoolInfo& other) { + // if the result of an AND is true, both a & b had to be true, + // so we take the union of a.true_refinements and b.true_refinements. + // if the result is false, either a or b could have been false, + // so we take their intersection. 
+ true_refinements_.unionRefinements(other.true_refinements_); + false_refinements_.intersectRefinements(other.false_refinements_); + return this; + } +}; + static Value* asSimple(const SugaredValuePtr& value) { if (SimpleValue* sv = dynamic_cast(value.get())) { return sv->getValue(); @@ -817,9 +926,11 @@ struct to_ir { std::shared_ptr emitSingleIfBranch( Block* b, - const List& branch) { + const List& branch, + const Refinements& refinements) { pushFrame(b); WithInsertPoint guard(b); + insertRefinements(refinements); emitStatements(branch); return popFrame(); } @@ -830,23 +941,65 @@ struct to_ir { } Value* emitTernaryIf(const TernaryIf& expr) { + const auto& bool_info = findRefinements(expr.cond()); Value* cond_value = emitCond(expr.cond()); - auto true_expr = [&] { return emitExpr(expr.true_expr()); }; - auto false_expr = [&] { return emitExpr(expr.false_expr()); }; + auto true_expr = [&] { + insertRefinements(bool_info.true_refinements_); + return emitExpr(expr.true_expr()); + }; + auto false_expr = [&] { + insertRefinements(bool_info.false_refinements_); + return emitExpr(expr.false_expr()); + }; return emitIfExpr(expr.range(), cond_value, true_expr, false_expr); } + // Insert subtyping refinements + void insertRefinements(const Refinements& ref) { + for (const auto& name_mappings : ref.mappings_) { + const std::string& name = name_mappings.first; + auto type = name_mappings.second.first; + const auto& range = *name_mappings.second.second; + Value* v = environment_stack->getVar(name, range); + if (type != NoneType::get()) { + Value* output = graph->insert(prim::unchecked_unwrap_optional, {v}); + environment_stack->setVar(range, name, output); + } + // todo @eellison - revisit inserting Nones when None subtypes Optional + } + } + Value* emitShortCircuitIf( const SourceRange& loc, const TreeRef& first_expr, const TreeRef& second_expr, bool is_or) { + const auto first_bool_info = findRefinements(first_expr); Value* first_value = emitCond(Expr(first_expr)); - auto 
get_first_expr = [first_value] { return first_value; }; - auto get_second_expr = [&] { return emitCond(Expr(second_expr)); }; + const Refinements* first_expr_refinements; + const Refinements* second_expr_refinements; + // if it's an OR the first expr is emitted in the true branch + // and the second expr in the false branch, if it's an AND the opposite + if (is_or) { + first_expr_refinements = &first_bool_info.true_refinements_; + second_expr_refinements = &first_bool_info.false_refinements_; + } else { + first_expr_refinements = &first_bool_info.false_refinements_; + second_expr_refinements = &first_bool_info.true_refinements_; + } + + auto get_first_expr = [&] { + insertRefinements(*first_expr_refinements); + return first_value; + }; + + auto get_second_expr = [&] { + insertRefinements(*second_expr_refinements); + return emitCond(Expr(second_expr)); + }; - // if this is an OR, eval second expression if first expr is False. + // if this is an OR, eval second expression if first expr is False // If this is an AND, eval second expression if first expr is True if (is_or) { return emitIfExpr(loc, first_value, get_first_expr, get_second_expr); @@ -910,12 +1063,15 @@ struct to_ir { void emitIfElseBlocks(Value* cond_value, const If& stmt) { Node* n = graph->insertNode(create(prim::If, stmt.range(), 0)); n->addInput(cond_value); + const auto bool_info = findRefinements(stmt.cond()); auto* true_block = n->addBlock(); auto* false_block = n->addBlock(); // Emit both blocks once to get the union of all mutated values - auto save_true = emitSingleIfBranch(true_block, stmt.trueBranch()); - auto save_false = emitSingleIfBranch(false_block, stmt.falseBranch()); + auto save_true = emitSingleIfBranch( + true_block, stmt.trueBranch(), bool_info.true_refinements_); + auto save_false = emitSingleIfBranch( + false_block, stmt.falseBranch(), bool_info.false_refinements_); // In python, every variable assigned in an if statement escapes // the scope of the if statement (all variables are 
scoped to the function). @@ -1039,6 +1195,7 @@ struct to_ir { // emit the whole If stmt as usual, finish emitCond first auto lhs_range = cond_op.lhs().get()->range(); auto rhs_range = cond_op.rhs().get()->range(); + auto kind = getNodeKind(cond.kind(), cond.get()->trees().size()); Value* cond_value = emitBuiltinCall( cond.get()->range(), @@ -1820,6 +1977,51 @@ struct to_ir { } } + BoolInfo findRefinements(const TreeRef& tree) { + switch (tree->kind()) { + case TK_IS: + case TK_ISNOT: { + const auto& inputs = tree->trees(); + if (inputs.at(0)->kind() == TK_VAR && inputs.at(1)->kind() == TK_NONE) { + const std::string& var_name = Var(inputs[0]).name().name(); + Refinements true_info, false_info; + auto type = + environment_stack->getVar(var_name, inputs[0]->range())->type(); + if (auto opt_type = type->cast()) { + false_info.setRefinement( + var_name, + TypeAndRange(opt_type->getElementType(), &tree->range())); + true_info.setRefinement( + var_name, TypeAndRange(NoneType::get(), &tree->range())); + } + if (tree->kind() == TK_IS) { + return BoolInfo(true_info, false_info); + } else { + return BoolInfo(false_info, true_info); + } + } + } break; + case TK_NOT: { + const auto& inputs = tree->trees(); + auto bool_info = findRefinements(inputs[0]); + return BoolInfo( + bool_info.false_refinements_, bool_info.true_refinements_); + } + case TK_OR: + case TK_AND: { + const auto& inputs = tree->trees(); + auto first = findRefinements(inputs[0]); + auto second = findRefinements(inputs[1]); + if (tree->kind() == TK_OR) { + return *first.mergeOr(second); + } else { + return *first.mergeAnd(second); + } + } + } + return BoolInfo(); + } + Value* emitExpr(const Expr& tree, const TypePtr& type_hint = nullptr) { return emitSugaredExpr(tree, 1, type_hint)->asValue(tree.range(), method); } @@ -2024,7 +2226,7 @@ struct to_ir { elem_type = values.at(0)->type(); } for (auto v : values) { - if (*v->type() != *elem_type) { + if (*v->type() != *elem_type) { throw ErrorReport(tree) << "Lists 
must contain only a single type, expected: " << *elem_type << " but found " << *v->type() << " instead"; -- 2.7.4