+++ /dev/null
-graph() {
- %0 : float = prim::Constant[value=5]()
- %1 : int = prim::Constant[value=1]()
- %b : int = prim::FloatToInt(%0)
- %3 : int = aten::add(%b, %1)
- return (%3);
-}
+++ /dev/null
-graph() {
- %0 : int = prim::Constant[value=2]()
- %1 : float = prim::Constant[value=1]()
- %b : float = prim::IntToFloat(%0)
- %3 : float = aten::add(%b, %1)
- return (%3);
-}
+++ /dev/null
-graph(%x : Tensor) {
- %1 : int = prim::Constant[value=55]()
- %2 : int = prim::Constant[value=199]()
- %3 : int = prim::Constant[value=1]()
- %4 : int = aten::add(%1, %2)
- %5 : Tensor = ^python_op_in_weak_module()(%x)
- %6 : Tensor = aten::add(%5, %4, %3)
- return (%6);
-}
+++ /dev/null
-graph(%x : Tensor) {
- %1 : int = prim::Constant[value=357]()
- %2 : int = prim::Constant[value=55]()
- %3 : int = prim::Constant[value=199]()
- %4 : int = prim::Constant[value=2]()
- %5 : int = prim::Constant[value=1]()
- %y : Tensor = aten::mul(%x, %4)
- %7 : Tensor = aten::add(%y, %5, %5)
- %8 : int = aten::add(%2, %3)
- %9 : Tensor = ^python_op_in_weak_module()(%y)
- %10 : Tensor = aten::add(%9, %8, %5)
- %11 : Tensor = aten::add(%7, %10, %5)
- %12 : Tensor = aten::add(%y, %1, %5)
- %13 : Tensor = ^python_op_in_strong_module()(%y)
- %14 : Tensor = aten::add(%12, %13, %5)
- %15 : Tensor = aten::add(%11, %14, %5)
- return (%15);
-}
+++ /dev/null
-graph(%x : Tensor
- %1 : Tensor
- %2 : Tensor
- %3 : Tensor
- %4 : Tensor) {
- %5 : int = prim::Constant[value=2]()
- %6 : int = prim::Constant[value=1]()
- %7 : int = prim::Constant[value=3]()
- %8 : int = prim::Constant[value=27]()
- %9 : Tensor = aten::mul(%x, %x)
- %10 : Tensor = aten::add(%9, %7, %6)
- %11 : int = aten::dim(%x)
- %12 : bool = aten::eq(%11, %5)
- %13 : bool = prim::If(%12)
- block0() {
- %14 : None = prim::None()
- %15 : bool = aten::__isnot__(%2, %14)
- -> (%15)
- }
- block1() {
- -> (%12)
- }
- %ret.2 : Tensor = prim::If(%13)
- block0() {
- %17 : Tensor = aten::_unwrap_optional(%2)
- %18 : Tensor = aten::t(%1)
- %ret.3 : Tensor = aten::addmm(%17, %x, %18, %6, %6)
- -> (%ret.3)
- }
- block1() {
- %20 : Tensor = aten::t(%1)
- %output.3 : Tensor = aten::matmul(%x, %20)
- %22 : None = prim::None()
- %23 : bool = aten::__isnot__(%2, %22)
- %output.4 : Tensor = prim::If(%23)
- block0() {
- %25 : Tensor = aten::_unwrap_optional(%2)
- %output.5 : Tensor = aten::add_(%output.3, %25, %6)
- -> (%output.5)
- }
- block1() {
- -> (%output.3)
- }
- -> (%output.4)
- }
- %27 : Tensor = aten::add(%10, %ret.2, %6)
- %28 : Tensor = aten::add(%x, %27, %6)
- %29 : Tensor = aten::add(%x, %8, %6)
- %30 : Tensor = aten::add(%28, %29, %6)
- %31 : int = aten::dim(%x)
- %32 : bool = aten::eq(%31, %5)
- %33 : bool = prim::If(%32)
- block0() {
- %34 : None = prim::None()
- %35 : bool = aten::__isnot__(%4, %34)
- -> (%35)
- }
- block1() {
- -> (%32)
- }
- %ret : Tensor = prim::If(%33)
- block0() {
- %37 : Tensor = aten::_unwrap_optional(%4)
- %38 : Tensor = aten::t(%3)
- %ret.1 : Tensor = aten::addmm(%37, %x, %38, %6, %6)
- -> (%ret.1)
- }
- block1() {
- %40 : Tensor = aten::t(%3)
- %output.1 : Tensor = aten::matmul(%x, %40)
- %42 : None = prim::None()
- %43 : bool = aten::__isnot__(%4, %42)
- %output : Tensor = prim::If(%43)
- block0() {
- %45 : Tensor = aten::_unwrap_optional(%4)
- %output.2 : Tensor = aten::add_(%output.1, %45, %6)
- -> (%output.2)
- }
- block1() {
- -> (%output.1)
- }
- -> (%output)
- }
- %47 : Tensor = aten::add(%30, %ret, %6)
- %48 : Tensor = aten::add(%x, %47, %6)
- return (%48);
-}
+++ /dev/null
-graph(%x : Tensor
- %1 : Tensor
- %2 : Tensor
- %3 : Tensor
- %4 : Tensor
- %5 : Tensor
- %6 : Tensor) {
- %7 : int = prim::Constant[value=1]()
- %8 : int = prim::Constant[value=2]()
- %9 : int = aten::dim(%x)
- %10 : bool = aten::eq(%9, %8)
- %11 : bool = prim::If(%10)
- block0() {
- %12 : None = prim::None()
- %13 : bool = aten::__isnot__(%2, %12)
- -> (%13)
- }
- block1() {
- -> (%10)
- }
- %ret.2 : Tensor = prim::If(%11)
- block0() {
- %15 : Tensor = aten::_unwrap_optional(%2)
- %16 : Tensor = aten::t(%1)
- %ret.3 : Tensor = aten::addmm(%15, %x, %16, %7, %7)
- -> (%ret.3)
- }
- block1() {
- %18 : Tensor = aten::t(%1)
- %output.3 : Tensor = aten::matmul(%x, %18)
- %20 : None = prim::None()
- %21 : bool = aten::__isnot__(%2, %20)
- %output.4 : Tensor = prim::If(%21)
- block0() {
- %23 : Tensor = aten::_unwrap_optional(%2)
- %output.5 : Tensor = aten::add_(%output.3, %23, %7)
- -> (%output.5)
- }
- block1() {
- -> (%output.3)
- }
- -> (%output.4)
- }
- %25 : Tensor = aten::add(%ret.2, %3, %7)
- %26 : Tensor = aten::add(%x, %25, %7)
- %27 : int = aten::dim(%x)
- %28 : bool = aten::eq(%27, %8)
- %29 : bool = prim::If(%28)
- block0() {
- %30 : None = prim::None()
- %31 : bool = aten::__isnot__(%2, %30)
- -> (%31)
- }
- block1() {
- -> (%28)
- }
- %ret.4 : Tensor = prim::If(%29)
- block0() {
- %33 : Tensor = aten::_unwrap_optional(%2)
- %34 : Tensor = aten::t(%1)
- %ret.5 : Tensor = aten::addmm(%33, %x, %34, %7, %7)
- -> (%ret.5)
- }
- block1() {
- %36 : Tensor = aten::t(%1)
- %output.6 : Tensor = aten::matmul(%x, %36)
- %38 : None = prim::None()
- %39 : bool = aten::__isnot__(%2, %38)
- %output.7 : Tensor = prim::If(%39)
- block0() {
- %41 : Tensor = aten::_unwrap_optional(%2)
- %output.8 : Tensor = aten::add_(%output.6, %41, %7)
- -> (%output.8)
- }
- block1() {
- -> (%output.6)
- }
- -> (%output.7)
- }
- %43 : Tensor = aten::add(%ret.4, %3, %7)
- %44 : Tensor = aten::add(%26, %43, %7)
- %45 : int = aten::dim(%x)
- %46 : bool = aten::eq(%45, %8)
- %47 : bool = prim::If(%46)
- block0() {
- %48 : None = prim::None()
- %49 : bool = aten::__isnot__(%5, %48)
- -> (%49)
- }
- block1() {
- -> (%46)
- }
- %ret : Tensor = prim::If(%47)
- block0() {
- %51 : Tensor = aten::_unwrap_optional(%5)
- %52 : Tensor = aten::t(%4)
- %ret.1 : Tensor = aten::addmm(%51, %x, %52, %7, %7)
- -> (%ret.1)
- }
- block1() {
- %54 : Tensor = aten::t(%4)
- %output.1 : Tensor = aten::matmul(%x, %54)
- %56 : None = prim::None()
- %57 : bool = aten::__isnot__(%5, %56)
- %output : Tensor = prim::If(%57)
- block0() {
- %59 : Tensor = aten::_unwrap_optional(%5)
- %output.2 : Tensor = aten::add_(%output.1, %59, %7)
- -> (%output.2)
- }
- block1() {
- -> (%output.1)
- }
- -> (%output)
- }
- %61 : Tensor = aten::add(%ret, %6, %7)
- %62 : Tensor = aten::add(%44, %61, %7)
- return (%62);
-}
+++ /dev/null
-graph(%x.1 : Tensor) {
- %1 : int = prim::Constant[value=5]()
- %2 : int = prim::Constant[value=6]()
- %3 : int = prim::Constant[value=4]()
- %4 : int = prim::Constant[value=3]()
- %5 : int = prim::Constant[value=2]()
- %6 : int = prim::Constant[value=1]()
- %x.3 : Tensor = ^not_a_script_fn()(%x.1)
- %8 : Tensor = aten::norm(%x.3, %5)
- %9 : Tensor = aten::gt(%8, %5)
- %10 : bool = prim::TensorToBool(%9)
- %x.4 : Tensor = prim::If(%10)
- block0() {
- %x.2 : Tensor = aten::add(%x.3, %4, %6)
- -> (%x.2)
- }
- block1() {
- -> (%x.3)
- }
- %13 : Tensor = aten::add(%x.4, %3, %6)
- %14 : Tensor = ^not_a_script_fn()(%x.4)
- %15 : Tensor = aten::add(%14, %x.4, %6)
- %16 : Tensor = aten::add(%x.4, %6, %6)
- %17 : Tensor = aten::add(%15, %16, %6)
- %x : Tensor = aten::add(%13, %17, %6)
- %19 : Tensor = aten::add(%x, %1, %6)
- %20 : Tensor = aten::add(%x, %2, %6)
- %21 : Tensor = ^not_a_script_fn()(%x)
- %22 : Tensor = aten::add(%20, %21, %6)
- %23 : Tensor = aten::add(%19, %22, %6)
- %24 : Tensor = aten::add(%x, %2, %6)
- %25 : Tensor = ^not_a_script_fn()(%x)
- %26 : Tensor = aten::add(%24, %25, %6)
- %27 : Tensor = aten::add(%23, %26, %6)
- return (%27);
-}
throwsAnd(t)
def test_type_cast(self):
- @torch.jit.script
def test_int_to_float():
b = float(2)
return b + 1.0
+ self.checkScript(test_int_to_float, ())
with self.assertRaisesRegex(RuntimeError, "Cannot cast type"):
@torch.jit.script
def test_int_to_bool():
return bool(5)
- @torch.jit.script
def test_float_to_int():
b = int(5.0)
return b + 1
+ self.checkScript(test_float_to_int, ())
with self.assertRaisesRegex(RuntimeError, "Cannot cast type"):
@torch.jit.script
def test_bool_to_int():
return int(True)
- self.assertExpectedGraph(test_int_to_float.graph, "test_int_to_float")
- self.assertExpectedGraph(test_float_to_int.graph, "test_float_to_int")
-
def test_multiple_assignment(self):
def outer_func(x):
return x * 2, x + 2
x = strong_script_fn(x)
return weak_script_fn(x)
- scripted = torch.jit.script(fn)
-
input = torch.randn(3, 4, 5)
- self.assertExpectedGraph(scripted.graph)
- self.assertEqual(scripted(input), fn(input))
+ self.checkScript(fn, (input,))
def test_python_op_exception(self):
def python_op(x):
python_result = weak_mod(x)
strong_mod = Passthrough()
script_result = strong_mod(x)
+
self.assertEqual(python_result, expected_result)
self.assertEqual(script_result, expected_result)
- self.assertExpectedGraph(strong_mod.graph, "basic")
class Strong(torch.jit.ScriptModule):
def __init__(self):
script_result2 = strong_mod2(x)
self.assertEqual(script_result, expected_result)
self.assertEqual(script_result, script_result2)
- self.assertExpectedGraph(strong_mod.graph, "scope_test")
def test_weak_module_parameters_and_buffers(self):
weights = torch.randn(10, 10)
return x + self.fc1(x) + self.fc1(x) + self.fc2(x)
strong_mod = Strong()
- self.assertExpectedGraph(strong_mod.graph)
# Run same calculation as module
inp = torch.ones(10)
expected_result = inp + (lin(inp) + torch.ones(10)) * 2 + lin2(inp) + torch.ones(10)
self.assertEqual(strong_mod(inp), expected_result)
+ self.assertExportImportModule(strong_mod, (inp,))
def test_weak_module_nested(self):
@torch._jit_internal.weak_module
return x + self.weak(x)
strong_mod = Strong()
- self.assertExpectedGraph(strong_mod.graph)
inp = torch.randn(10)
result = strong_mod(inp)
expected_result = inp + (inp + inp * inp + inp + 27) + 3 \