Allow ints, floats, and tensors in conditional (#18755)
author Elias Ellison <eellison@fb.com>
Thu, 4 Apr 2019 00:09:37 +0000 (17:09 -0700)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 4 Apr 2019 00:12:17 +0000 (17:12 -0700)
Summary:
Per our offline discussion, allow Tensors, ints, and floats to be cast to bool when used in a conditional.
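
A minimal sketch of the resulting behavior (the function name and inputs here are illustrative, not part of the patch):

    import torch

    @torch.jit.script
    def tensor_cond(x):
        # type: (Tensor) -> int
        if x:  # implicitly cast to bool via prim::Bool(Tensor)
            return 1
        else:
            return 0

    print(tensor_cond(torch.tensor(2.0)))    # 1
    print(tensor_cond(torch.tensor([0.0])))  # 0

Previously this required an explicit bool(x) cast in the condition.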

Fix for https://github.com/pytorch/pytorch/issues/18381
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18755

Reviewed By: driazati

Differential Revision: D14752476

Pulled By: eellison

fbshipit-source-id: 149960c92afcf7e4cc4997bccc57f4e911118ff1

docs/source/jit.rst
test/test_jit.py
torch/csrc/jit/register_prim_ops.cpp
torch/csrc/jit/script/compiler.cpp

index ad968a0..2d7ead8 100644 (file)
@@ -478,6 +478,9 @@ If Statements
         else:
             r = 3 * a
 
+In addition to bools, floats, ints, and Tensors can also be used in a conditional
+and will be implicitly cast to a boolean.
+
 While Loops
 
   ::
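
The doc addition above applies to ints and floats as well; a rough sketch under the same implicit-cast rule (hypothetical function; zero is falsy, any nonzero value truthy):

    import torch

    @torch.jit.script
    def int_cond(x):
        # type: (int) -> int
        if x:  # 0 -> False, any other int -> True
            return 1
        else:
            return 0

    assert int_cond(-1) == 1 and int_cond(0) == 0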
index 79cfba3..18bf956 100644 (file)
@@ -4526,13 +4526,48 @@ a")
 
         self.assertTrue(str(test_lhs_none_rhs_never.graph).count(': int = prim::Constant') == 1)
 
-    def test_explicit_bool_cast(self):
-        with self.assertRaisesRegex(RuntimeError, "expected a boolean"):
+    def test_conditional_casting(self):
+        def test_bool_cast_tensor(x):
+            if x:
+                return 1
+            else:
+                return 0
+
+        for make_one_dim in [True, False]:
+            for inp_val in [0.1, 0.0, -0.0, -0.1, -1, 0, 1]:
+                inp_val = [inp_val] if make_one_dim else inp_val
+                self.checkScript(test_bool_cast_tensor, (torch.tensor(inp_val),))
+
+        self.checkScriptRaisesRegex(test_bool_cast_tensor, (torch.tensor([1, 1]),), Exception,
+                                    "bool value of Tensor with more than one value")
+
+        def test_cast_int(x):
+            # type: (int) -> int
+            if x:
+                return 1
+            else:
+                return 0
+        self.checkScript(test_cast_int, (1,))
+        self.checkScript(test_cast_int, (0,))
+        self.checkScript(test_cast_int, (-1,))
+
+        def test_cast_float(x):
+            # type: (float) -> int
+            if x:
+                return 1
+            else:
+                return 0
+        self.checkScript(test_cast_float, (1.,))
+        self.checkScript(test_cast_float, (0.,))
+        self.checkScript(test_cast_float, (-1.,))
+
+        with self.assertRaisesRegex(RuntimeError, "expected a bool, int, float, or Tensor"):
             @torch.jit.script
-            def test_bool_cast(a):
-                if a:
-                    return a + 2
-                return a + 1
+            def test_bad_conditional(x):
+                if (1, 2):
+                    return
+                else:
+                    return 0
 
     def test_while_nonexistent_value(self):
         with self.assertRaisesRegex(RuntimeError, "undefined value x"):
index 5b944b8..688ca0e 100644 (file)
@@ -8,9 +8,9 @@
 #include <torch/csrc/jit/fuser/interface.h>
 #include <torch/csrc/jit/graph_executor.h>
 #include <torch/csrc/jit/ir.h>
-#include <torch/csrc/jit/script/logging.h>
 #include <torch/csrc/jit/operator.h>
 #include <torch/csrc/jit/script/jit_exception.h>
+#include <torch/csrc/jit/script/logging.h>
 
 #include <ATen/ExpandUtils.h>
 #include <ATen/WrapDimUtils.h>
@@ -139,7 +139,7 @@ RegisterOperators reg(
          [](Stack& stack) {
            at::Tensor a;
            pop(stack, a);
-           push(stack, a.item<int64_t>() != 0);
+           push(stack, a.is_nonzero());
            return 0;
          }),
      Operator(
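
The switch from a.item<int64_t>() != 0 to a.is_nonzero() changes two things, as far as I can tell: is_nonzero() does not truncate floating-point values (the old cast would read torch.tensor(0.1) as 0, i.e. false), and it raises for tensors with more than one element, matching the eager-mode bool(tensor) semantics the tests above exercise:

    import torch

    print(bool(torch.tensor(0.1)))  # True (no int64 truncation)
    print(bool(torch.tensor([0])))  # False; single-element tensors are fine
    bool(torch.tensor([1, 1]))      # RuntimeError: bool value of Tensor with
                                    # more than one value is ambiguous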
@@ -889,46 +889,47 @@ RegisterOperators reg(
          userObj->setSlot(slot, std::move(v));
          return 0;
        };
-     })
-     });
-
-RegisterOperators logging_operators({
-    Operator("prim::AddStatValue(str key, int val) -> ()", [](Stack& stack) {
-          auto val = pop(stack).toInt();
-          auto key = pop(stack).toString();
-
-          auto schema = parseSchema("prim::AddStatValue(str key, int val) -> ()");
-          // TODO: remove this custom tracing code once the custom op bugfix lands
-          if (jit::tracer::isTracing()) {
-            const auto& graph = tracer::getTracingState()->graph;
-            Node* node = graph->create(prim::AddStatValue, /*num_outputs=*/0);
-            tracer::recordSourceLocation(node);
-            node->addInput(insertConstant(*graph, key));
-            tracer::addInputs(node, "val", val);
-            graph->insertNode(node);
-          }
-          torch::jit::logging::getLogger()->addStatValue(*key, val);
-          return 0;
-    }),
-    Operator("prim::TimePoint() -> int", [](Stack& stack) {
-        auto schema = parseSchema("prim::TimePoint() -> int");
-        Node* node = nullptr;
-        // TODO: remove this custom tracing code once the custom op bugfix lands
-        if (jit::tracer::isTracing()) {
-            const auto& graph = tracer::getTracingState()->graph;
-            Node* node = graph->create(prim::TimePoint, /*num_outputs=*/0);
-            tracer::recordSourceLocation(node);
-            graph->insertNode(node);
-        }
-        auto output = autograd::profiler::getTime();
-        push(stack, output);
-        if (jit::tracer::isTracing()) {
-          jit::tracer::addOutput(node, output);
-        }
-        return 0;
-    })
-});
+     })});
 
+RegisterOperators logging_operators(
+    {Operator(
+         "prim::AddStatValue(str key, int val) -> ()",
+         [](Stack& stack) {
+           auto val = pop(stack).toInt();
+           auto key = pop(stack).toString();
+
+           auto schema =
+               parseSchema("prim::AddStatValue(str key, int val) -> ()");
+           // TODO: remove this custom tracing code once the custom op bugfix
+           // lands
+           if (jit::tracer::isTracing()) {
+             const auto& graph = tracer::getTracingState()->graph;
+             Node* node = graph->create(prim::AddStatValue, /*num_outputs=*/0);
+             tracer::recordSourceLocation(node);
+             node->addInput(insertConstant(*graph, key));
+             tracer::addInputs(node, "val", val);
+             graph->insertNode(node);
+           }
+           torch::jit::logging::getLogger()->addStatValue(*key, val);
+           return 0;
+         }),
+     Operator("prim::TimePoint() -> int", [](Stack& stack) {
+       auto schema = parseSchema("prim::TimePoint() -> int");
+       Node* node = nullptr;
+       // TODO: remove this custom tracing code once the custom op bugfix lands
+       if (jit::tracer::isTracing()) {
+         const auto& graph = tracer::getTracingState()->graph;
+         Node* node = graph->create(prim::TimePoint, /*num_outputs=*/0);
+         tracer::recordSourceLocation(node);
+         graph->insertNode(node);
+       }
+       auto output = autograd::profiler::getTime();
+       push(stack, output);
+       if (jit::tracer::isTracing()) {
+         jit::tracer::addOutput(node, output);
+       }
+       return 0;
+     })});
 
 // define implementations for primitive number ops
 #define DEFINE_GENERIC_OP(aten_op, int_op, float_op, int_result, float_result) \
@@ -1577,7 +1578,7 @@ int dictGetDefault(Stack& stack) {
   return 0;
 }
 
-template<typename T>
+template <typename T>
 int hashValue(Stack& stack) {
   auto value = pop(stack);
   auto hash = std::hash<T>()(value.to<T>());
@@ -1727,7 +1728,7 @@ RegisterOperators reg2({
 #undef CREATE_MUTABLE_LIST_OPS
 
 #define CREATE_LIST_OPS(decl_type, c_type)                                          \
-      Operator("aten::len(" decl_type "[] a) -> int", listLen<Shared<c_type>>),     \
+  Operator("aten::len(" decl_type "[] a) -> int", listLen<Shared<c_type>>),         \
       Operator(                                                                     \
           "aten::add(" decl_type "[] a, " decl_type "[] b) -> " decl_type           \
           "[]",                                                                     \
@@ -1966,28 +1967,27 @@ RegisterOperators reg2({
           push(stack, t);
           return 0;
         }),
-#define CREATE_DICT_OPS(key_type)                                            \
-  Operator("aten::len(Dict(" key_type ", t) self) -> int", dictLen),         \
-      Operator(                                                              \
-          "aten::keys(Dict(" key_type ", t) self) -> " key_type "[](*)",     \
-          dictKeys),                                                         \
-      Operator(                                                              \
-          "aten::values(Dict(" key_type ", t) self) -> t[](*)", dictValues), \
-      Operator(                                                              \
-          "prim::DictIndex(Dict(" key_type ", t) self, " key_type            \
-          " key) -> t(*)",                                                   \
-          dictIndex),                                                        \
-      Operator(                                                              \
-          "aten::get(Dict(" key_type ", t) self, " key_type                  \
-          " key) -> t(*)?",                                                   \
-          dictGet),                                                          \
-      Operator(                                                              \
-          "aten::get(Dict(" key_type ", t) self, " key_type                  \
-          " key, t default_value) -> t(*)",                                  \
-          dictGetDefault),                                                   \
-      Operator(                                                              \
-          "aten::_set_item(Dict(" key_type ", t)(a!) l, " key_type           \
-          " idx, t v) -> ()",                                                \
+#define CREATE_DICT_OPS(key_type)                                             \
+  Operator("aten::len(Dict(" key_type ", t) self) -> int", dictLen),          \
+      Operator(                                                               \
+          "aten::keys(Dict(" key_type ", t) self) -> " key_type "[](*)",      \
+          dictKeys),                                                          \
+      Operator(                                                               \
+          "aten::values(Dict(" key_type ", t) self) -> t[](*)", dictValues),  \
+      Operator(                                                               \
+          "prim::DictIndex(Dict(" key_type ", t) self, " key_type             \
+          " key) -> t(*)",                                                    \
+          dictIndex),                                                         \
+      Operator(                                                               \
+          "aten::get(Dict(" key_type ", t) self, " key_type " key) -> t(*)?", \
+          dictGet),                                                           \
+      Operator(                                                               \
+          "aten::get(Dict(" key_type ", t) self, " key_type                   \
+          " key, t default_value) -> t(*)",                                   \
+          dictGetDefault),                                                    \
+      Operator(                                                               \
+          "aten::_set_item(Dict(" key_type ", t)(a!) l, " key_type            \
+          " idx, t v) -> ()",                                                 \
           dictSetItem)
 
     CREATE_DICT_OPS("str"),
@@ -1995,7 +1995,6 @@ RegisterOperators reg2({
     CREATE_DICT_OPS("float"),
 #undef CREATE_DICT_OPS
 
-
     Operator("aten::hash(str t) -> int", hashValue<std::string>),
     Operator("aten::hash(int t) -> int", hashValue<int>),
     Operator("aten::hash(float t) -> int", hashValue<double>),
index b52b995..dd10f16 100644 (file)
@@ -405,7 +405,6 @@ struct Environment {
           {"abs", std::make_shared<BuiltinFunction>(prim::abs, at::nullopt)},
           {"list", std::make_shared<BuiltinFunction>(aten::list, at::nullopt)},
           {"ord", std::make_shared<BuiltinFunction>(aten::ord, at::nullopt)},
-          {"rangelist", std::make_shared<BuiltinFunction>(prim::rangelist, at::nullopt)},
           {"rangelist",
            std::make_shared<BuiltinFunction>(prim::rangelist, at::nullopt)},
       };
@@ -1125,14 +1124,23 @@ struct to_ir {
   Value* emitCond(const Expr& cond) {
     Value* v = emitExpr(cond);
     if (!v->type()->isSubtypeOf(BoolType::get())) {
-      ErrorReport error(cond);
-      error << "expected a boolean expression for condition but found "
+      Value* cast_v = emitBuiltinCall(
+          cond.get()->range(),
+          *v->owningGraph(),
+          prim::Bool,
+          c10::nullopt,
+          {v},
+          {},
+          /*required*/ false);
+      if (cast_v == nullptr) {
+        ErrorReport error(cond);
+        error
+            << "expected a bool, int, float, or Tensor expression for condition but found "
             << v->type()->str();
-      if (v->type()->isSubtypeOf(TensorType::get())) {
-        error << ", to use a tensor in a boolean"
-              << " expression, explicitly cast it with `bool()`";
+        throw error;
+      } else {
+        v = cast_v;
       }
-      throw error;
     }
     return v;
   }
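
With /*required*/ false, emitBuiltinCall returns nullptr instead of throwing when no prim::Bool overload matches, which is what feeds the new error path. Since emitCond also guards while-loop conditions, the implicit cast should apply there too; a sketch of that assumption (not exercised by the tests in this patch):

    import torch

    @torch.jit.script
    def countdown(x):
        # type: (int) -> int
        while x:  # int condition implicitly cast to bool
            x = x - 1
        return x

    assert countdown(3) == 0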
@@ -2726,10 +2734,9 @@ struct to_ir {
           {},
           true);
     } else {
-      throw ErrorReport(loc)
-          << "Indexing only supported on List, Dict, "
-             "Tensor, Tuple, and str but got type '"
-          << gatherable->type()->python_str() << "'";
+      throw ErrorReport(loc) << "Indexing only supported on List, Dict, "
+                                "Tensor, Tuple, and str but got type '"
+                             << gatherable->type()->python_str() << "'";
     }
   }
 };