from torch.autograd.function import traceable
from torch.testing import assert_allclose
from torch.onnx import OperatorExportTypes
-from torch._six import inf, PY2
+from torch._six import inf, PY2, builtins
from common_utils import TestCase, run_tests, IS_WINDOWS, TEST_WITH_UBSAN, \
    skipIfRocm, skipIfNoLapack, suppress_warnings, load_tests, IS_SANDCASTLE, \
    freeze_rng_state, set_rng_seed
    def test_type_cast(self):
-        def test_int_to_float():
-            b = float(2)
-            return b + 1.0
-        self.checkScript(test_int_to_float, ())
-
-        with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
-            @torch.jit.script
-            def test_int_to_bool():
-                return bool(5)
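+        # Build each cast as a small script from a template and check the
+        # result against the matching Python builtin.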
+        template = dedent('''
+        def cast(v):
+            # type: ({from_type}) -> {to_type}
+            return {to_type}(v)
+        ''')
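+        # For example, template.format(from_type='int', to_type='bool')
+        # produces roughly:
+        #
+        #     def cast(v):
+        #         # type: (int) -> bool
+        #         return bool(v)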
-        def test_float_to_int():
-            b = int(5.0)
-            return b + 1
-        self.checkScript(test_float_to_int, ())
+        def check_cast(from_type, to_type, value, raises=False):
+            code = template.format(from_type=from_type, to_type=to_type)
+            expected = getattr(builtins, to_type)(value)
+            if raises:
+                with self.assertRaisesRegex(RuntimeError, "Cannot cast"):
+                    cu = torch.jit.CompilationUnit(code)
+            else:
+                self.checkScript(code, (value,), name='cast', outputs=expected)
-        with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
-            @torch.jit.script
-            def test_float_to_bool():
-                return bool(5.0)
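+        # Every cast between int, float, and bool is expected to compile and
+        # match the Python result.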
+        check_cast('int', 'float', 1)
+        check_cast('int', 'bool', 1)
+        check_cast('int', 'bool', 0)
-        with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
-            @torch.jit.script
-            def test_bool_to_float():
-                return float(True)
+        check_cast('float', 'int', 1.)
+        check_cast('float', 'bool', 1.)
+        check_cast('float', 'bool', 0.)
-        with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
-            @torch.jit.script
-            def test_bool_to_int():
-                return int(True)
+        check_cast('bool', 'int', True)
+        check_cast('bool', 'float', True)
    def test_multiple_assignment(self):
        def outer_func(x):
            return 0;
          };
        }),
+
    Operator(
        "prim::Bool(Tensor a) -> bool",
        [](const Node* node) -> Operation {
          };
        }),
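+    // Casting a numeric scalar to bool: any nonzero value maps to true.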
    Operator(
+        "prim::Bool(int a) -> bool",
+        [](const Node* node) -> Operation {
+          return [](Stack& stack) {
+            int64_t i;
+            pop(stack, i);
+            push(stack, (bool) i);
+            return 0;
+          };
+        }),
+    Operator(
+        "prim::Bool(float a) -> bool",
+        [](const Node* node) -> Operation {
+          return [](Stack& stack) {
+            double d;
+            pop(stack, d);
+            push(stack, (bool) d);
+            return 0;
+          };
+        }),
+    Operator(
        "prim::Int(Tensor a) -> int",
        [](const Node* node) -> Operation {
          return [](Stack& stack) {
          };
        }),
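+    // Casting bool to a numeric scalar: true maps to 1, false to 0.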
    Operator(
+        "prim::Float(bool a) -> float",
+        [](const Node* node) -> Operation {
+          return [](Stack& stack) {
+            bool b;
+            pop(stack, b);
+            push(stack, (float) b);
+            return 0;
+          };
+        }),
+    Operator(
+        "prim::Int(bool a) -> int",
+        [](const Node* node) -> Operation {
+          return [](Stack& stack) {
+            bool b;
+            pop(stack, b);
+            push(stack, (int) b);
+            return 0;
+          };
+        }),
+    Operator(
        "prim::Float(str a) -> float",
        [](const Node* node) -> Operation {
          return [](Stack& stack) {