From: Elias Ellison
Date: Thu, 4 Apr 2019 00:09:37 +0000 (-0700)
Subject: Allow ints, floats, and tensors in conditional (#18755)
X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~432
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=b80a4fa201d4a7755bfb82b7fbd502297cef41db;p=platform%2Fupstream%2Fpytorch.git

Allow ints, floats, and tensors in conditional (#18755)

Summary:
Per our offline discussion, allow Tensors, ints, and floats to be cast to bool when used in a conditional.

Fix for https://github.com/pytorch/pytorch/issues/18381
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18755

Reviewed By: driazati

Differential Revision: D14752476

Pulled By: eellison

fbshipit-source-id: 149960c92afcf7e4cc4997bccc57f4e911118ff1
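For example, the following now compile and run under torch.jit.script (a
minimal sketch in the spirit of the new tests in test_jit.py below; the
function names are illustrative, not part of the patch):

    import torch

    @torch.jit.script
    def tensor_cond(x):
        # A single-element Tensor is implicitly cast: true iff nonzero.
        # Tensors with more than one element still raise
        # "bool value of Tensor with more than one value" at runtime.
        if x:
            return 1
        else:
            return 0

    assert tensor_cond(torch.tensor(0.1)) == 1
    assert tensor_cond(torch.tensor(0.0)) == 0

    @torch.jit.script
    def int_cond(x):
        # type: (int) -> int
        # ints and floats follow Python truthiness: nonzero is true.
        if x:
            return 1
        else:
            return 0

    assert int_cond(-1) == 1
    assert int_cond(0) == 0

Condition types that cannot be cast, such as a tuple, now fail to compile
with "expected a bool, int, float, or Tensor expression for condition".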
---

diff --git a/docs/source/jit.rst b/docs/source/jit.rst
index ad968a0..2d7ead8 100644
--- a/docs/source/jit.rst
+++ b/docs/source/jit.rst
@@ -478,6 +478,9 @@ If Statements
         else:
             r = 3 * a
 
+In addition to bools, floats, ints, and Tensors can be used in a conditional
+and will be implicitly cast to a boolean.
+
 While Loops
 ::
 
diff --git a/test/test_jit.py b/test/test_jit.py
index 79cfba3..18bf956 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -4526,13 +4526,48 @@ a")
 
         self.assertTrue(str(test_lhs_none_rhs_never.graph).count(': int = prim::Constant') == 1)
 
-    def test_explicit_bool_cast(self):
-        with self.assertRaisesRegex(RuntimeError, "expected a boolean"):
+    def test_conditional_casting(self):
+        def test_bool_cast_tensor(x):
+            if x:
+                return 1
+            else:
+                return 0
+
+        for make_one_dim in [True, False]:
+            for inp_val in [0.1, 0.0, -0.0, -0.1, -1, 0, 1]:
+                inp_val = [inp_val] if make_one_dim else inp_val
+                self.checkScript(test_bool_cast_tensor, (torch.tensor(inp_val),))
+
+        self.checkScriptRaisesRegex(test_bool_cast_tensor, (torch.tensor([1, 1]),), Exception,
+                                    "bool value of Tensor with more than one value")
+
+        def test_cast_int(x):
+            # type: (int) -> int
+            if x:
+                return 1
+            else:
+                return 0
+        self.checkScript(test_cast_int, (1,))
+        self.checkScript(test_cast_int, (0,))
+        self.checkScript(test_cast_int, (-1,))
+
+        def test_cast_float(x):
+            # type: (float) -> int
+            if x:
+                return 1
+            else:
+                return 0
+        self.checkScript(test_cast_float, (1.,))
+        self.checkScript(test_cast_float, (0.,))
+        self.checkScript(test_cast_float, (-1.,))
+
+        with self.assertRaisesRegex(RuntimeError, "expected a bool, int, float, or Tensor"):
             @torch.jit.script
-            def test_bool_cast(a):
-                if a:
-                    return a + 2
-                return a + 1
+            def test_bad_conditional(x):
+                if (1, 2):
+                    return
+                else:
+                    return 0
 
     def test_while_nonexistent_value(self):
         with self.assertRaisesRegex(RuntimeError, "undefined value x"):
diff --git a/torch/csrc/jit/register_prim_ops.cpp b/torch/csrc/jit/register_prim_ops.cpp
index 5b944b8..688ca0e 100644
--- a/torch/csrc/jit/register_prim_ops.cpp
+++ b/torch/csrc/jit/register_prim_ops.cpp
@@ -8,9 +8,9 @@
 #include
 #include
 #include
-#include
 #include
 #include
+#include
 #include
 #include
 
@@ -139,7 +139,7 @@ RegisterOperators reg(
         [](Stack& stack) {
           at::Tensor a;
           pop(stack, a);
-          push(stack, a.item<int64_t>() != 0);
+          push(stack, a.is_nonzero());
           return 0;
         }),
     Operator(
@@ -889,46 +889,47 @@ RegisterOperators reg(
           userObj->setSlot(slot, std::move(v));
           return 0;
         };
-      })
-  });
-
-RegisterOperators logging_operators({
-    Operator("prim::AddStatValue(str key, int val) -> ()", [](Stack& stack) {
-      auto val = pop(stack).toInt();
-      auto key = pop(stack).toString();
-
-      auto schema = parseSchema("prim::AddStatValue(str key, int val) -> ()");
-      // TODO: remove this custom tracing code once the custom op bugfix lands
-      if (jit::tracer::isTracing()) {
-        const auto& graph = tracer::getTracingState()->graph;
-        Node* node = graph->create(prim::AddStatValue, /*num_outputs=*/0);
-        tracer::recordSourceLocation(node);
-        node->addInput(insertConstant(*graph, key));
-        tracer::addInputs(node, "val", val);
-        graph->insertNode(node);
-      }
-      torch::jit::logging::getLogger()->addStatValue(*key, val);
-      return 0;
-    }),
-    Operator("prim::TimePoint() -> int", [](Stack& stack) {
-      auto schema = parseSchema("prim::TimePoint() -> int");
-      Node* node = nullptr;
-      // TODO: remove this custom tracing code once the custom op bugfix lands
-      if (jit::tracer::isTracing()) {
-        const auto& graph = tracer::getTracingState()->graph;
-        Node* node = graph->create(prim::TimePoint, /*num_outputs=*/0);
-        tracer::recordSourceLocation(node);
-        graph->insertNode(node);
-      }
-      auto output = autograd::profiler::getTime();
-      push(stack, output);
-      if (jit::tracer::isTracing()) {
-        jit::tracer::addOutput(node, output);
-      }
-      return 0;
-    })
-});
+      })});
+RegisterOperators logging_operators(
+    {Operator(
+         "prim::AddStatValue(str key, int val) -> ()",
+         [](Stack& stack) {
+           auto val = pop(stack).toInt();
+           auto key = pop(stack).toString();
+
+           auto schema =
+               parseSchema("prim::AddStatValue(str key, int val) -> ()");
+           // TODO: remove this custom tracing code once the custom op bugfix
+           // lands
+           if (jit::tracer::isTracing()) {
+             const auto& graph = tracer::getTracingState()->graph;
+             Node* node = graph->create(prim::AddStatValue, /*num_outputs=*/0);
+             tracer::recordSourceLocation(node);
+             node->addInput(insertConstant(*graph, key));
+             tracer::addInputs(node, "val", val);
+             graph->insertNode(node);
+           }
+           torch::jit::logging::getLogger()->addStatValue(*key, val);
+           return 0;
+         }),
+     Operator("prim::TimePoint() -> int", [](Stack& stack) {
+       auto schema = parseSchema("prim::TimePoint() -> int");
+       Node* node = nullptr;
+       // TODO: remove this custom tracing code once the custom op bugfix lands
+       if (jit::tracer::isTracing()) {
+         const auto& graph = tracer::getTracingState()->graph;
+         Node* node = graph->create(prim::TimePoint, /*num_outputs=*/0);
+         tracer::recordSourceLocation(node);
+         graph->insertNode(node);
+       }
+       auto output = autograd::profiler::getTime();
+       push(stack, output);
+       if (jit::tracer::isTracing()) {
+         jit::tracer::addOutput(node, output);
+       }
+       return 0;
+     })});
 
 // define implementations for primitive number ops
 #define DEFINE_GENERIC_OP(aten_op, int_op, float_op, int_result, float_result) \
@@ -1577,7 +1578,7 @@ int dictGetDefault(Stack& stack) {
   return 0;
 }
 
-template
+template
 int hashValue(Stack& stack) {
   auto value = pop(stack);
   auto hash = std::hash<T>()(value.to<T>());
@@ -1727,7 +1728,7 @@ RegisterOperators reg2({
 #undef CREATE_MUTABLE_LIST_OPS
 
 #define CREATE_LIST_OPS(decl_type, c_type)                                    \
-  Operator("aten::len(" decl_type "[] a) -> int", listLen<Shared<c_type>>), \
+  Operator("aten::len(" decl_type "[] a) -> int", listLen<Shared<c_type>>),  \
   Operator(                                                                   \
       "aten::add(" decl_type "[] a, " decl_type "[] b) -> " decl_type         \
       "[]",                                                                   \
@@ -1966,28 +1967,27 @@ RegisterOperators reg2({
         push(stack, t);
         return 0;
       }),
-#define CREATE_DICT_OPS(key_type)                                            \
-  Operator("aten::len(Dict(" key_type ", t) self) -> int", dictLen),         \
-  Operator(                                                                  \
-      "aten::keys(Dict(" key_type ", t) self) -> " key_type "[](*)",         \
-      dictKeys),                                                             \
-  Operator(                                                                  \
-      "aten::values(Dict(" key_type ", t) self) -> t[](*)", dictValues),     \
-  Operator(                                                                  \
-      "prim::DictIndex(Dict(" key_type ", t) self, " key_type                \
-      " key) -> t(*)",                                                       \
-      dictIndex),                                                            \
-  Operator(                                                                  \
"aten::get(Dict(" key_type ", t) self, " key_type \ - " key) -> t(*)?", \ - dictGet), \ - Operator( \ - "aten::get(Dict(" key_type ", t) self, " key_type \ - " key, t default_value) -> t(*)", \ - dictGetDefault), \ - Operator( \ - "aten::_set_item(Dict(" key_type ", t)(a!) l, " key_type \ - " idx, t v) -> ()", \ +#define CREATE_DICT_OPS(key_type) \ + Operator("aten::len(Dict(" key_type ", t) self) -> int", dictLen), \ + Operator( \ + "aten::keys(Dict(" key_type ", t) self) -> " key_type "[](*)", \ + dictKeys), \ + Operator( \ + "aten::values(Dict(" key_type ", t) self) -> t[](*)", dictValues), \ + Operator( \ + "prim::DictIndex(Dict(" key_type ", t) self, " key_type \ + " key) -> t(*)", \ + dictIndex), \ + Operator( \ + "aten::get(Dict(" key_type ", t) self, " key_type " key) -> t(*)?", \ + dictGet), \ + Operator( \ + "aten::get(Dict(" key_type ", t) self, " key_type \ + " key, t default_value) -> t(*)", \ + dictGetDefault), \ + Operator( \ + "aten::_set_item(Dict(" key_type ", t)(a!) l, " key_type \ + " idx, t v) -> ()", \ dictSetItem) CREATE_DICT_OPS("str"), @@ -1995,7 +1995,6 @@ RegisterOperators reg2({ CREATE_DICT_OPS("float"), #undef CREATE_DICT_OPS - Operator("aten::hash(str t) -> int", hashValue), Operator("aten::hash(int t) -> int", hashValue), Operator("aten::hash(float t) -> int", hashValue), diff --git a/torch/csrc/jit/script/compiler.cpp b/torch/csrc/jit/script/compiler.cpp index b52b995..dd10f16 100644 --- a/torch/csrc/jit/script/compiler.cpp +++ b/torch/csrc/jit/script/compiler.cpp @@ -405,7 +405,6 @@ struct Environment { {"abs", std::make_shared(prim::abs, at::nullopt)}, {"list", std::make_shared(aten::list, at::nullopt)}, {"ord", std::make_shared(aten::ord, at::nullopt)}, - {"rangelist", std::make_shared(prim::rangelist, at::nullopt)}, {"rangelist", std::make_shared(prim::rangelist, at::nullopt)}, }; @@ -1125,14 +1124,23 @@ struct to_ir { Value* emitCond(const Expr& cond) { Value* v = emitExpr(cond); if (!v->type()->isSubtypeOf(BoolType::get())) { - ErrorReport error(cond); - error << "expected a boolean expression for condition but found " + Value* cast_v = emitBuiltinCall( + cond.get()->range(), + *v->owningGraph(), + prim::Bool, + c10::nullopt, + {v}, + {}, + /*required*/ false); + if (cast_v == nullptr) { + ErrorReport error(cond); + error + << "expected a bool, int, float, or Tensor expression for condition but found " << v->type()->str(); - if (v->type()->isSubtypeOf(TensorType::get())) { - error << ", to use a tensor in a boolean" - << " expression, explicitly cast it with `bool()`"; + throw error; + } else { + v = cast_v; } - throw error; } return v; } @@ -2726,10 +2734,9 @@ struct to_ir { {}, true); } else { - throw ErrorReport(loc) - << "Indexing only supported on List, Dict, " - "Tensor, Tuple, and str but got type '" - << gatherable->type()->python_str() << "'"; + throw ErrorReport(loc) << "Indexing only supported on List, Dict, " + "Tensor, Tuple, and str but got type '" + << gatherable->type()->python_str() << "'"; } } };