Allow int/float cast to bool (#13391)
authorDavid Riazati <davidriazati@fb.com>
Thu, 27 Dec 2018 23:58:32 +0000 (15:58 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 28 Dec 2018 00:01:08 +0000 (16:01 -0800)
Summary:
This PR adds explicit `bool()` casts to match Python semantics

`bool(1)` is `True`
`bool(0)` is `False`
`bool(0.0)` is `False`
`bool(0.1)` is `True`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13391

Differential Revision: D12871213

Pulled By: driazati

fbshipit-source-id: 773a48b2647973138efe854abe725d647f1d727d

test/test_jit.py
torch/_jit_internal.py
torch/_six.py
torch/csrc/jit/register_prim_ops.cpp

index 2f7a9b9..4657d80 100644 (file)
@@ -11,7 +11,7 @@ from torch.autograd import Variable, Function
 from torch.autograd.function import traceable
 from torch.testing import assert_allclose
 from torch.onnx import OperatorExportTypes
-from torch._six import inf, PY2
+from torch._six import inf, PY2, builtins
 from common_utils import TestCase, run_tests, IS_WINDOWS, TEST_WITH_UBSAN, \
     skipIfRocm, skipIfNoLapack, suppress_warnings, load_tests, IS_SANDCASTLE, \
     freeze_rng_state, set_rng_seed
@@ -4215,35 +4215,31 @@ a")
             throwsAnd(t)
 
     def test_type_cast(self):
-        def test_int_to_float():
-            b = float(2)
-            return b + 1.0
-        self.checkScript(test_int_to_float, ())
-
-        with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
-            @torch.jit.script
-            def test_int_to_bool():
-                return bool(5)
+        template = dedent('''
+        def cast(v):
+            # type: ({from_type}) -> {to_type}
+            return {to_type}(v)
+        ''')
 
-        def test_float_to_int():
-            b = int(5.0)
-            return b + 1
-        self.checkScript(test_float_to_int, ())
+        def check_cast(from_type, to_type, value, raises=False):
+            code = template.format(from_type=from_type, to_type=to_type)
+            expected = getattr(builtins, to_type)(value)
+            if raises:
+                with self.assertRaisesRegex(RuntimeError, "Cannot cast"):
+                    cu = torch.jit.CompilationUnit(code)
+            else:
+                self.checkScript(code, (value,), name='cast', outputs=expected)
 
-        with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
-            @torch.jit.script
-            def test_float_to_bool():
-                return bool(5.0)
+        check_cast('int', 'float', 1)
+        check_cast('int', 'bool', 1)
+        check_cast('int', 'bool', 0)
 
-        with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
-            @torch.jit.script
-            def test_bool_to_float():
-                return float(True)
+        check_cast('float', 'int', 1.)
+        check_cast('float', 'bool', 1.)
+        check_cast('float', 'bool', 0.)
 
-        with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
-            @torch.jit.script
-            def test_bool_to_int():
-                return int(True)
+        check_cast('bool', 'int', True)
+        check_cast('bool', 'float', True)
 
     def test_multiple_assignment(self):
         def outer_func(x):
index 5594d41..154336c 100644 (file)
@@ -6,10 +6,7 @@ circular dependency problems
 
 import weakref
 import inspect
-try:
-    import builtins  # PY3
-except Exception:
-    import __builtin__ as builtins  # PY2
+from torch._six import builtins
 
 # Tracks standalone weak script functions
 _compiled_weak_fns = weakref.WeakKeyDictionary()
index 3a4c2ad..ad50cf2 100644 (file)
@@ -130,3 +130,8 @@ if PY2:
 elif PY3:
     def get_function_from_type(cls, name):
         return getattr(cls, name, None)
+
+if PY2:
+    import __builtin__ as builtins
+elif PY3:
+    import builtins
index b4df638..c1af45e 100644 (file)
@@ -70,6 +70,7 @@ RegisterOperators reg({
             return 0;
           };
         }),
+
     Operator(
         "prim::Bool(Tensor a) -> bool",
         [](const Node* node) -> Operation {
@@ -81,6 +82,26 @@ RegisterOperators reg({
           };
         }),
     Operator(
+        "prim::Bool(int a) -> bool",
+        [](const Node* node) -> Operation {
+          return [](Stack& stack) {
+            int64_t i;
+            pop(stack, i);
+            push(stack, (bool) i);
+            return 0;
+          };
+        }),
+    Operator(
+        "prim::Bool(float a) -> bool",
+        [](const Node* node) -> Operation {
+          return [](Stack& stack) {
+            double d;
+            pop(stack, d);
+            push(stack, (bool) d);
+            return 0;
+          };
+        }),
+    Operator(
         "prim::Int(Tensor a) -> int",
         [](const Node* node) -> Operation {
           return [](Stack& stack) {
@@ -164,6 +185,26 @@ RegisterOperators reg({
           };
         }),
     Operator(
+        "prim::Float(bool a) -> float",
+        [](const Node* node) -> Operation {
+          return [](Stack& stack) {
+            bool b;
+            pop(stack, b);
+            push(stack, (float) b);
+            return 0;
+          };
+        }),
+    Operator(
+        "prim::Int(bool a) -> int",
+        [](const Node* node) -> Operation {
+          return [](Stack& stack) {
+            bool b;
+            pop(stack, b);
+            push(stack, (int) b);
+            return 0;
+          };
+        }),
+    Operator(
         "prim::Float(str a) -> float",
         [](const Node* node) -> Operation {
           return [](Stack& stack) {