From 7fc3aa8c49cda3433a2a04e7be504d3dfffbd426 Mon Sep 17 00:00:00 2001 From: Elias Ellison Date: Mon, 4 Mar 2019 22:38:41 -0800 Subject: [PATCH] Remove Expect Files from python / tracing / script interop Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17622 Differential Revision: D14308307 Pulled By: eellison fbshipit-source-id: bda249d38ac2570000a12b0ca328c26233ecefe8 --- ...ors.test_script_graph_contains_custom_op.expect | 3 - test/expect/TestJit.test_nested_inplace.expect | 5 -- test/expect/TestJit.test_recursive_cse.expect | 12 --- test/expect/TestJit.test_shared_param.expect | 6 -- .../expect/TestScript.test_augmented_assign.expect | 9 --- ...cript.test_call_python_fn_from_script_fn.expect | 5 -- ...t.test_call_python_fn_from_script_module.expect | 5 -- ...t.test_call_python_fn_from_traced_module.expect | 5 -- ...ript.test_call_python_mod_from_script_fn.expect | 5 -- ....test_call_python_mod_from_script_module.expect | 5 -- ....test_call_python_mod_from_traced_module.expect | 9 --- ...ipt.test_call_python_mod_from_tracing_fn.expect | 7 -- ...cript.test_call_script_fn_from_script_fn.expect | 5 -- ...t.test_call_script_fn_from_script_module.expect | 5 -- ...t.test_call_script_fn_from_traced_module.expect | 5 -- ...ript.test_call_script_fn_from_tracing_fn.expect | 6 -- ...ript.test_call_script_mod_from_script_fn.expect | 12 --- ....test_call_script_mod_from_script_module.expect | 6 -- ...ipt.test_call_script_mod_from_tracing_fn.expect | 7 -- ...st_call_script_module_from_traced_module.expect | 9 --- ...cript.test_call_traced_fn_from_script_fn.expect | 5 -- ...t.test_call_traced_fn_from_traced_module.expect | 5 -- ...ript.test_call_traced_fn_from_tracing_fn.expect | 6 -- ...ript.test_call_traced_mod_from_script_fn.expect | 12 --- ...ipt.test_call_traced_mod_from_tracing_fn.expect | 7 -- ...st_call_traced_module_from_traced_module.expect | 9 --- ....test_call_tracing_fn_from_script_module.expect | 5 -- ...test_call_tracing_mod_from_script_module.expect | 6 -- test/test_jit.py | 92 ++++++++++++++-------- torch/csrc/jit/testing/file_check.cpp | 11 +++ 30 files changed, 68 insertions(+), 221 deletions(-) delete mode 100644 test/expect/TestCustomOperators.test_script_graph_contains_custom_op.expect delete mode 100644 test/expect/TestJit.test_nested_inplace.expect delete mode 100644 test/expect/TestJit.test_recursive_cse.expect delete mode 100644 test/expect/TestJit.test_shared_param.expect delete mode 100644 test/expect/TestScript.test_augmented_assign.expect delete mode 100644 test/expect/TestScript.test_call_python_fn_from_script_fn.expect delete mode 100644 test/expect/TestScript.test_call_python_fn_from_script_module.expect delete mode 100644 test/expect/TestScript.test_call_python_fn_from_traced_module.expect delete mode 100644 test/expect/TestScript.test_call_python_mod_from_script_fn.expect delete mode 100644 test/expect/TestScript.test_call_python_mod_from_script_module.expect delete mode 100644 test/expect/TestScript.test_call_python_mod_from_traced_module.expect delete mode 100644 test/expect/TestScript.test_call_python_mod_from_tracing_fn.expect delete mode 100644 test/expect/TestScript.test_call_script_fn_from_script_fn.expect delete mode 100644 test/expect/TestScript.test_call_script_fn_from_script_module.expect delete mode 100644 test/expect/TestScript.test_call_script_fn_from_traced_module.expect delete mode 100644 test/expect/TestScript.test_call_script_fn_from_tracing_fn.expect delete mode 100644 test/expect/TestScript.test_call_script_mod_from_script_fn.expect delete mode 100644 test/expect/TestScript.test_call_script_mod_from_script_module.expect delete mode 100644 test/expect/TestScript.test_call_script_mod_from_tracing_fn.expect delete mode 100644 test/expect/TestScript.test_call_script_module_from_traced_module.expect delete mode 100644 test/expect/TestScript.test_call_traced_fn_from_script_fn.expect delete mode 100644 test/expect/TestScript.test_call_traced_fn_from_traced_module.expect delete mode 100644 test/expect/TestScript.test_call_traced_fn_from_tracing_fn.expect delete mode 100644 test/expect/TestScript.test_call_traced_mod_from_script_fn.expect delete mode 100644 test/expect/TestScript.test_call_traced_mod_from_tracing_fn.expect delete mode 100644 test/expect/TestScript.test_call_traced_module_from_traced_module.expect delete mode 100644 test/expect/TestScript.test_call_tracing_fn_from_script_module.expect delete mode 100644 test/expect/TestScript.test_call_tracing_mod_from_script_module.expect diff --git a/test/expect/TestCustomOperators.test_script_graph_contains_custom_op.expect b/test/expect/TestCustomOperators.test_script_graph_contains_custom_op.expect deleted file mode 100644 index 490b53f..0000000 --- a/test/expect/TestCustomOperators.test_script_graph_contains_custom_op.expect +++ /dev/null @@ -1,3 +0,0 @@ -graph(%x : Dynamic): - %1 : Dynamic = aten::relu(%x) - return (%1) diff --git a/test/expect/TestJit.test_nested_inplace.expect b/test/expect/TestJit.test_nested_inplace.expect deleted file mode 100644 index 6803ff1..0000000 --- a/test/expect/TestJit.test_nested_inplace.expect +++ /dev/null @@ -1,5 +0,0 @@ -graph(%x : Double(2, 2)): - %1 : int = prim::Constant[value=0]() - %2 : int = prim::Constant[value=0]() - %3 : Double(2, 2) = aten::threshold_(%x, %1, %2) - return (%3) diff --git a/test/expect/TestJit.test_recursive_cse.expect b/test/expect/TestJit.test_recursive_cse.expect deleted file mode 100644 index c117774..0000000 --- a/test/expect/TestJit.test_recursive_cse.expect +++ /dev/null @@ -1,12 +0,0 @@ -graph(%x : Tensor, - %y : Tensor): - %2 : int = prim::Constant[value=1]() - %3 : Tensor = aten::add(%x, %y, %2) - %4 : Tensor = aten::gt(%3, %x) - %5 : bool = prim::Bool(%4) - %z : Tensor = prim::If(%5) - block0(): - -> (%3) - block1(): - -> (%x) - return (%z) diff --git a/test/expect/TestJit.test_shared_param.expect b/test/expect/TestJit.test_shared_param.expect deleted file mode 100644 index c8b2976..0000000 --- a/test/expect/TestJit.test_shared_param.expect +++ /dev/null @@ -1,6 +0,0 @@ -graph(%0 : Double(2, 2), - %1 : Double(2, 2)): - %2 : Double(2, 2) = aten::mul(%0, %1), scope: MyModule - %3 : int = prim::Constant[value=1](), scope: MyModule - %4 : Double(2, 2) = aten::add(%2, %1, %3), scope: MyModule - return (%4) diff --git a/test/expect/TestScript.test_augmented_assign.expect b/test/expect/TestScript.test_augmented_assign.expect deleted file mode 100644 index e5d047d..0000000 --- a/test/expect/TestScript.test_augmented_assign.expect +++ /dev/null @@ -1,9 +0,0 @@ -graph(%a.1 : Tensor, - %b : Tensor): - %2 : int = prim::Constant[value=1]() - %a.2 : Tensor = aten::add_(%a.1, %b, %2) - %a.3 : Tensor = aten::sub_(%a.2, %b, %2) - %a.4 : Tensor = aten::div_(%a.3, %b) - %a : Tensor = aten::mul_(%a.4, %b) - %7 : (Tensor, Tensor) = prim::TupleConstruct(%a, %b) - return (%7) diff --git a/test/expect/TestScript.test_call_python_fn_from_script_fn.expect b/test/expect/TestScript.test_call_python_fn_from_script_fn.expect deleted file mode 100644 index 9f7bc73..0000000 --- a/test/expect/TestScript.test_call_python_fn_from_script_fn.expect +++ /dev/null @@ -1,5 +0,0 @@ -graph(%x : Tensor): - %1 : int = prim::Constant[value=1]() - %2 : Tensor = ^python_fn()(%x) - %3 : Tensor = aten::add(%2, %1, %1) - return (%3) diff --git a/test/expect/TestScript.test_call_python_fn_from_script_module.expect b/test/expect/TestScript.test_call_python_fn_from_script_module.expect deleted file mode 100644 index ec5349f..0000000 --- a/test/expect/TestScript.test_call_python_fn_from_script_module.expect +++ /dev/null @@ -1,5 +0,0 @@ -graph(%x : Tensor, - %1 : Tensor): - %2 : Tensor = aten::mm(%x, %1) - %3 : Tensor = ^python_fn()(%2) - return (%3) diff --git a/test/expect/TestScript.test_call_python_fn_from_traced_module.expect b/test/expect/TestScript.test_call_python_fn_from_traced_module.expect deleted file mode 100644 index 503fe26..0000000 --- a/test/expect/TestScript.test_call_python_fn_from_traced_module.expect +++ /dev/null @@ -1,5 +0,0 @@ -graph(%x : Double(3, 4), - %1 : Double(4, 3)): - %2 : Double(3, 4) = aten::neg(%x), scope: TracedModule - %3 : Double(3, 3) = aten::mm(%2, %1), scope: TracedModule - return (%3) diff --git a/test/expect/TestScript.test_call_python_mod_from_script_fn.expect b/test/expect/TestScript.test_call_python_mod_from_script_fn.expect deleted file mode 100644 index 7140db6..0000000 --- a/test/expect/TestScript.test_call_python_mod_from_script_fn.expect +++ /dev/null @@ -1,5 +0,0 @@ -graph(%x : Tensor): - %2 : int = prim::Constant[value=1]() - %1 : Tensor = ^()(%x) - %4 : Tensor = aten::add(%1, %2, %2) - return (%4) diff --git a/test/expect/TestScript.test_call_python_mod_from_script_module.expect b/test/expect/TestScript.test_call_python_mod_from_script_module.expect deleted file mode 100644 index 2512a45..0000000 --- a/test/expect/TestScript.test_call_python_mod_from_script_module.expect +++ /dev/null @@ -1,5 +0,0 @@ -graph(%x : Tensor, - %1 : Tensor): - %2 : Tensor = aten::mm(%x, %1) - %3 : Tensor = ^()(%2) - return (%3) diff --git a/test/expect/TestScript.test_call_python_mod_from_traced_module.expect b/test/expect/TestScript.test_call_python_mod_from_traced_module.expect deleted file mode 100644 index fe0cd06..0000000 --- a/test/expect/TestScript.test_call_python_mod_from_traced_module.expect +++ /dev/null @@ -1,9 +0,0 @@ -graph(%x.1 : Double(3, 4), - %1 : Double(4, 5), - %2 : Double(5, 7)): - %x : Double(3, 5) = aten::mm(%x.1, %1), scope: TracedModule - %4 : Double(3, 7) = aten::mm(%x, %2), scope: TracedModule/PythonModule[mod] - %5 : Double() = prim::Constant[value={1}](), scope: TracedModule - %6 : int = prim::Constant[value=1](), scope: TracedModule - %7 : Double(3, 7) = aten::add(%4, %5, %6), scope: TracedModule - return (%7) diff --git a/test/expect/TestScript.test_call_python_mod_from_tracing_fn.expect b/test/expect/TestScript.test_call_python_mod_from_tracing_fn.expect deleted file mode 100644 index f66d55b..0000000 --- a/test/expect/TestScript.test_call_python_mod_from_tracing_fn.expect +++ /dev/null @@ -1,7 +0,0 @@ -graph(%x : Double(3, 4)): - %1 : Double(4, 3) = prim::Constant[value=](), scope: PythonMod - %2 : Double(3, 3) = aten::mm(%x, %1), scope: PythonMod - %3 : Double() = prim::Constant[value={1}]() - %4 : int = prim::Constant[value=1]() - %5 : Double(3, 3) = aten::add(%2, %3, %4) - return (%5) diff --git a/test/expect/TestScript.test_call_script_fn_from_script_fn.expect b/test/expect/TestScript.test_call_script_fn_from_script_fn.expect deleted file mode 100644 index 51ea22e..0000000 --- a/test/expect/TestScript.test_call_script_fn_from_script_fn.expect +++ /dev/null @@ -1,5 +0,0 @@ -graph(%x : Tensor): - %1 : int = prim::Constant[value=1]() - %2 : Tensor = aten::neg(%x) - %3 : Tensor = aten::add(%2, %1, %1) - return (%3) diff --git a/test/expect/TestScript.test_call_script_fn_from_script_module.expect b/test/expect/TestScript.test_call_script_fn_from_script_module.expect deleted file mode 100644 index df05460..0000000 --- a/test/expect/TestScript.test_call_script_fn_from_script_module.expect +++ /dev/null @@ -1,5 +0,0 @@ -graph(%x : Tensor, - %1 : Tensor): - %2 : Tensor = aten::mm(%x, %1) - %3 : Tensor = aten::neg(%2) - return (%3) diff --git a/test/expect/TestScript.test_call_script_fn_from_traced_module.expect b/test/expect/TestScript.test_call_script_fn_from_traced_module.expect deleted file mode 100644 index 03d0d35..0000000 --- a/test/expect/TestScript.test_call_script_fn_from_traced_module.expect +++ /dev/null @@ -1,5 +0,0 @@ -graph(%x : Double(3, 4), - %1 : Double(4, 5)): - %2 : Double(3, 5) = aten::mm(%x, %1), scope: TracedModule - %3 : Double(*, *) = aten::neg(%2), scope: TracedModule/ScriptModule - return (%3) diff --git a/test/expect/TestScript.test_call_script_fn_from_tracing_fn.expect b/test/expect/TestScript.test_call_script_fn_from_tracing_fn.expect deleted file mode 100644 index a9d071f..0000000 --- a/test/expect/TestScript.test_call_script_fn_from_tracing_fn.expect +++ /dev/null @@ -1,6 +0,0 @@ -graph(%0 : Double(3, 4)): - %1 : Double(*, *) = aten::neg(%0), scope: ScriptModule - %2 : Long() = prim::Constant[value={1}]() - %3 : int = prim::Constant[value=1]() - %4 : Double(3, 4) = aten::add(%1, %2, %3) - return (%4) diff --git a/test/expect/TestScript.test_call_script_mod_from_script_fn.expect b/test/expect/TestScript.test_call_script_mod_from_script_fn.expect deleted file mode 100644 index 2afa03c..0000000 --- a/test/expect/TestScript.test_call_script_mod_from_script_fn.expect +++ /dev/null @@ -1,12 +0,0 @@ -graph(%x : Tensor): - %1 : int = prim::Constant[value=3]() - %2 : int = prim::Constant[value=4]() - %3 : int = prim::Constant[value=6]() - %4 : int = prim::Constant[value=0]() - %5 : Device = prim::Constant[value="cpu"]() - %6 : int = prim::Constant[value=1]() - %7 : int[] = prim::ListConstruct(%2, %1) - %8 : Tensor = aten::zeros(%7, %3, %4, %5) - %9 : Tensor = aten::mm(%x, %8) - %10 : Tensor = aten::add(%9, %6, %6) - return (%10) diff --git a/test/expect/TestScript.test_call_script_mod_from_script_module.expect b/test/expect/TestScript.test_call_script_mod_from_script_module.expect deleted file mode 100644 index 73d8984..0000000 --- a/test/expect/TestScript.test_call_script_mod_from_script_module.expect +++ /dev/null @@ -1,6 +0,0 @@ -graph(%x : Tensor, - %1 : Tensor, - %2 : Tensor): - %3 : Tensor = aten::mm(%x, %1) - %4 : Tensor = aten::mm(%3, %2) - return (%4) diff --git a/test/expect/TestScript.test_call_script_mod_from_tracing_fn.expect b/test/expect/TestScript.test_call_script_mod_from_tracing_fn.expect deleted file mode 100644 index f66186a..0000000 --- a/test/expect/TestScript.test_call_script_mod_from_tracing_fn.expect +++ /dev/null @@ -1,7 +0,0 @@ -graph(%0 : Double(3, 4)): - %1 : Double(4, 3) = prim::Constant[value=](), scope: ScriptMod - %2 : Double(*, *) = aten::mm(%0, %1), scope: ScriptMod - %3 : Double() = prim::Constant[value={1}]() - %4 : int = prim::Constant[value=1]() - %5 : Double(3, 3) = aten::add(%2, %3, %4) - return (%5) diff --git a/test/expect/TestScript.test_call_script_module_from_traced_module.expect b/test/expect/TestScript.test_call_script_module_from_traced_module.expect deleted file mode 100644 index 184b0ff..0000000 --- a/test/expect/TestScript.test_call_script_module_from_traced_module.expect +++ /dev/null @@ -1,9 +0,0 @@ -graph(%x : Double(3, 4), - %1 : Double(4, 5), - %2 : Double(5, 7)): - %3 : Double(3, 5) = aten::mm(%x, %1), scope: TracedModule - %4 : Double(*, *) = aten::mm(%3, %2), scope: TracedModule/ScriptMod[mod] - %5 : Double() = prim::Constant[value={1}](), scope: TracedModule - %6 : int = prim::Constant[value=1](), scope: TracedModule - %7 : Double(3, 7) = aten::add(%4, %5, %6), scope: TracedModule - return (%7) diff --git a/test/expect/TestScript.test_call_traced_fn_from_script_fn.expect b/test/expect/TestScript.test_call_traced_fn_from_script_fn.expect deleted file mode 100644 index 7b68f79..0000000 --- a/test/expect/TestScript.test_call_traced_fn_from_script_fn.expect +++ /dev/null @@ -1,5 +0,0 @@ -graph(%x : Tensor): - %2 : int = prim::Constant[value=1]() - %1 : Double(3, 4) = aten::neg(%x) - %4 : Tensor = aten::add(%1, %2, %2) - return (%4) diff --git a/test/expect/TestScript.test_call_traced_fn_from_traced_module.expect b/test/expect/TestScript.test_call_traced_fn_from_traced_module.expect deleted file mode 100644 index 9d7399d..0000000 --- a/test/expect/TestScript.test_call_traced_fn_from_traced_module.expect +++ /dev/null @@ -1,5 +0,0 @@ -graph(%x : Double(3, 4), - %1 : Double(4, 5)): - %2 : Double(3, 5) = aten::mm(%x, %1), scope: TracedModule - %3 : Double(*, *) = aten::neg(%2), scope: TracedModule/traced_fn - return (%3) diff --git a/test/expect/TestScript.test_call_traced_fn_from_tracing_fn.expect b/test/expect/TestScript.test_call_traced_fn_from_tracing_fn.expect deleted file mode 100644 index 91de7c2..0000000 --- a/test/expect/TestScript.test_call_traced_fn_from_tracing_fn.expect +++ /dev/null @@ -1,6 +0,0 @@ -graph(%0 : Double(3, 4)): - %1 : Double(*, *) = aten::neg(%0), scope: traced_fn1 - %2 : Long() = prim::Constant[value={1}]() - %3 : int = prim::Constant[value=1]() - %4 : Double(3, 4) = aten::add(%1, %2, %3) - return (%4) diff --git a/test/expect/TestScript.test_call_traced_mod_from_script_fn.expect b/test/expect/TestScript.test_call_traced_mod_from_script_fn.expect deleted file mode 100644 index 0f44f74..0000000 --- a/test/expect/TestScript.test_call_traced_mod_from_script_fn.expect +++ /dev/null @@ -1,12 +0,0 @@ -graph(%x : Tensor): - %6 : Device = prim::Constant[value="cpu"](), scope: TracedModule - %5 : int = prim::Constant[value=0](), scope: TracedModule - %4 : int = prim::Constant[value=7](), scope: TracedModule - %2 : int = prim::Constant[value=3](), scope: TracedModule - %1 : int = prim::Constant[value=4](), scope: TracedModule - %9 : int = prim::Constant[value=1]() - %3 : int[] = prim::ListConstruct(%1, %2), scope: TracedModule - %7 : Double(4, 3) = aten::zeros(%3, %4, %5, %6), scope: TracedModule - %8 : Double(3, 3) = aten::mm(%x, %7), scope: TracedModule - %11 : Tensor = aten::add(%8, %9, %9) - return (%11) diff --git a/test/expect/TestScript.test_call_traced_mod_from_tracing_fn.expect b/test/expect/TestScript.test_call_traced_mod_from_tracing_fn.expect deleted file mode 100644 index a114654..0000000 --- a/test/expect/TestScript.test_call_traced_mod_from_tracing_fn.expect +++ /dev/null @@ -1,7 +0,0 @@ -graph(%0 : Double(3, 4)): - %1 : Double(4, 3) = prim::Constant[value=](), scope: TracedModule[TracedModule] - %2 : Double(*, *) = aten::mm(%0, %1), scope: TracedModule - %3 : Double() = prim::Constant[value={1}]() - %4 : int = prim::Constant[value=1]() - %5 : Double(3, 3) = aten::add(%2, %3, %4) - return (%5) diff --git a/test/expect/TestScript.test_call_traced_module_from_traced_module.expect b/test/expect/TestScript.test_call_traced_module_from_traced_module.expect deleted file mode 100644 index de442fc..0000000 --- a/test/expect/TestScript.test_call_traced_module_from_traced_module.expect +++ /dev/null @@ -1,9 +0,0 @@ -graph(%x : Double(3, 4), - %1 : Double(4, 5), - %2 : Double(5, 7)): - %3 : Double(3, 5) = aten::mm(%x, %1), scope: TracedModule - %4 : Double(*, *) = aten::mm(%3, %2), scope: TracedModule1 - %5 : Double() = prim::Constant[value={1}](), scope: TracedModule - %6 : int = prim::Constant[value=1](), scope: TracedModule - %7 : Double(3, 7) = aten::add(%4, %5, %6), scope: TracedModule - return (%7) diff --git a/test/expect/TestScript.test_call_tracing_fn_from_script_module.expect b/test/expect/TestScript.test_call_tracing_fn_from_script_module.expect deleted file mode 100644 index c60ec63..0000000 --- a/test/expect/TestScript.test_call_tracing_fn_from_script_module.expect +++ /dev/null @@ -1,5 +0,0 @@ -graph(%x : Tensor, - %1 : Tensor): - %2 : Tensor = aten::mm(%x, %1) - %3 : Double(3, 3) = aten::neg(%2) - return (%3) diff --git a/test/expect/TestScript.test_call_tracing_mod_from_script_module.expect b/test/expect/TestScript.test_call_tracing_mod_from_script_module.expect deleted file mode 100644 index 5786878..0000000 --- a/test/expect/TestScript.test_call_tracing_mod_from_script_module.expect +++ /dev/null @@ -1,6 +0,0 @@ -graph(%x : Tensor, - %1 : Tensor, - %3 : Tensor): - %2 : Tensor = aten::mm(%x, %1) - %4 : Double(3, 5) = aten::mm(%2, %3), scope: TracedMod - return (%4) diff --git a/test/test_jit.py b/test/test_jit.py index 320af77..d9e9966 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -8370,7 +8370,8 @@ a") # Note: the parameter self.param from the Python module is inlined # into the graph - self.assertExpected(canonical(traced_fn.graph)) + self.assertTrue(len(list(traced_fn.graph.inputs())) == 1) + FileCheck().check("aten::mm").check("aten::add").run(str(traced_fn.graph)) def test_call_traced_fn_from_tracing_fn(self): @_trace(torch.rand(3, 4)) @@ -8381,7 +8382,8 @@ a") def traced_fn(x): return traced_fn1(x) + 1 - self.assertExpected(canonical(traced_fn.graph)) + FileCheck().check("aten::neg").check_same("scope: traced_fn1").check("aten::add") \ + .run(str(traced_fn.graph)) def test_call_traced_mod_from_tracing_fn(self): class TracedModule(torch.nn.Module): @@ -8400,7 +8402,8 @@ a") # Note: the parameter self.param from the Python module is inlined # into the graph - self.assertExpected(canonical(traced_fn.graph)) + FileCheck().check("prim::Constant[value=]").check("aten::mm") \ + .check("aten::add").run(str(traced_fn.graph)) def test_call_script_fn_from_tracing_fn(self): @torch.jit.script @@ -8411,25 +8414,30 @@ a") def traced_fn(x): return script_fn(x) + 1 - self.assertExpected(canonical(traced_fn.graph)) + FileCheck().check("aten::neg").check("aten::add").run(str(traced_fn.graph)) def test_call_script_mod_from_tracing_fn(self): - class ScriptMod(torch.jit.ScriptModule): - def __init__(self): - super(ScriptMod, self).__init__() - self.param = torch.nn.Parameter(torch.rand(4, 3), requires_grad=False) + with self.disableModuleHook(): + class ScriptMod(torch.jit.ScriptModule): + def __init__(self): + super(ScriptMod, self).__init__() + self.param = torch.nn.Parameter(torch.rand(3, 4), requires_grad=False) - @torch.jit.script_method - def forward(self, x): - return torch.mm(x, self.param) + @torch.jit.script_method + def forward(self, x): + for _i in range(4): + x += self.param + return x - sm = ScriptMod() + sm = ScriptMod() - @_trace(torch.rand(3, 4)) - def traced_fn(x): - return sm(x) + 1.0 + @_trace(torch.rand(3, 4)) + def traced_fn(x): + return sm(x) + 1.0 - self.assertExpected(canonical(traced_fn.graph)) + # parameter turns into constant and loop is perserved + FileCheck().check("prim::Constant[value=]").check("Loop") \ + .run(str(traced_fn.graph)) def test_call_python_fn_from_traced_module(self): def python_fn(x): @@ -8448,7 +8456,8 @@ a") # Note: parameter self.param from the traced module should appear as # an input to the graph and the neg op from the Python function should # be properly inlined - self.assertExpected(canonical(tm.graph)) + self.assertTrue(len(list(tm.graph.inputs())) == 2) + FileCheck().check("aten::neg").check("aten::mm").run(str(tm.graph)) def test_call_python_mod_from_traced_module(self): class PythonModule(torch.nn.Module): @@ -8472,7 +8481,9 @@ a") # Note: the parameters from both modules should appear in the flattened # inputs of the graph. All ops from both modules should be inlined. - self.assertExpected(canonical(tm.graph)) + self.assertTrue(len(list(tm.graph.inputs())) == 3) + FileCheck().check_not("value=").check_count("aten::mm", 2).check("aten::add") \ + .run(str(tm.graph)) def test_call_traced_fn_from_traced_module(self): @_trace(torch.rand(3, 4)) @@ -8489,7 +8500,9 @@ a") tm = torch.jit.trace(TracedModule(), torch.rand(3, 4)) # Note: neg op from the traced function should be properly inlined - self.assertExpected(canonical(tm.graph)) + FileCheck().check("aten::mm").check_same("scope: TracedModule") \ + .check_next("aten::neg").check("scope: TracedModule/traced_fn") \ + .run(str(tm.graph)) def test_trace_hierarchy(self): # Test that we preserve the module hierarchy for a ScriptModule @@ -8605,7 +8618,8 @@ a") # Note: the parameters from both modules should appear in the flattened # inputs of the graph. All ops from both modules should be inlined. - self.assertExpected(canonical(tm.graph)) + self.assertTrue(len(list(tm.graph.inputs())) == 3) + FileCheck().check_count("aten::mm", 2).check("aten::add").run(str(tm.graph)) def test_call_script_fn_from_traced_module(self): @torch.jit.script @@ -8622,7 +8636,7 @@ a") tm = torch.jit.trace(TracedModule(), torch.rand(3, 4)) # Note: neg op from the script function should be properly inlined - self.assertExpected(canonical(tm.graph)) + FileCheck().check("aten::mm").check("aten::neg").run(str(tm.graph)) def test_call_script_module_from_traced_module(self): class ScriptMod(torch.jit.ScriptModule): @@ -8647,7 +8661,8 @@ a") # Note: the parameters from both modules should appear in the flattened # inputs of the graph. All ops from both modules should be inlined. - self.assertExpected(canonical(tm.graph)) + self.assertTrue(len(list(tm.graph.inputs())) == 3) + FileCheck().check_count("aten::mm", 2).check("aten::add").run(str(tm.graph)) def test_call_python_fn_from_script_fn(self): def python_fn(x): @@ -8659,7 +8674,9 @@ a") # Note: the call to python_fn appears as `^python_fn()` and is called # as a PythonOp in the interpreter - self.assertExpected(canonical(script_fn.graph)) + a = torch.tensor(1) + self.assertEqual(script_fn(a), torch.tensor(0)) + FileCheck().check("python_fn").run(str(script_fn.graph)) def test_call_python_mod_from_script_fn(self): class PythonModule(torch.nn.Module): @@ -8678,7 +8695,7 @@ a") # Note: call to pm(x) appears as ^() in the trace. # Parameters are NOT inlined. - self.assertExpected(str(script_fn.graph)) + FileCheck().check("python_value").check("aten::add").run(str(script_fn.graph)) def test_call_traced_fn_from_script_fn(self): @_trace(torch.rand(3, 4)) @@ -8691,7 +8708,7 @@ a") # Note: the neg op from traced_fn should be properly inlined into the # script function's graph - self.assertExpected(str(script_fn.graph)) + FileCheck().check("aten::neg").check("aten::add").run(str(script_fn.graph)) def test_call_traced_mod_from_script_fn(self): class TracedModule(torch.nn.Module): @@ -8707,7 +8724,8 @@ a") def script_fn(x): return tm(x) + 1 - self.assertExpected(str(script_fn.graph)) + FileCheck().check("aten::zeros").check_same("scope: TracedModule").check("aten::mm") \ + .check("aten::add").run(str(script_fn.graph)) def test_call_script_fn_from_script_fn(self): @torch.jit.script @@ -8720,7 +8738,7 @@ a") # Note: the neg op from script_fn1 should be properly inlined into the # graph of script_fn - self.assertExpected(canonical(script_fn.graph)) + FileCheck().check("aten::neg").run(str(script_fn.graph)) def test_call_script_mod_from_script_fn(self): class ScriptMod(torch.jit.ScriptModule): @@ -8737,7 +8755,7 @@ a") def script_fn(x): return sm(x) + 1 - self.assertExpected(canonical(script_fn.graph)) + FileCheck().check("zeros").check("aten::mm").check("add").run(str(script_fn.graph)) def test_call_python_fn_from_script_module(self): def python_fn(x): @@ -8753,7 +8771,8 @@ a") return python_fn(torch.mm(x, self.param)) sm = ScriptMod() - self.assertExpected(str(sm.__getattr__('forward').graph)) + FileCheck().check("aten::mm").check("python_fn") \ + .run(str(sm.__getattr__('forward').graph)) def test_call_python_mod_from_script_module(self): class PythonMod(torch.nn.Module): @@ -8777,7 +8796,7 @@ a") sm = ScriptMod() # Note: the call into PythonMod appears as ^(). Parameters # are NOT inlined - self.assertExpected(str(sm.graph)) + FileCheck().check("aten::mm").check("python_value").run(str(sm.graph)) def test_call_tracing_fn_from_script_module(self): @_trace(torch.rand(3, 3)) @@ -8794,7 +8813,7 @@ a") return traced_fn(torch.mm(x, self.param)) sm = ScriptMod() - self.assertExpected(str(sm.__getattr__('forward').graph)) + FileCheck().check("aten::mm").check("aten::neg").run(str(sm.__getattr__('forward').graph)) def test_call_tracing_mod_from_script_module(self): class TracedMod(torch.nn.Module): @@ -8819,7 +8838,8 @@ a") # Note: the parameters from both modules should appear in the flattened # input list to the graph. The mm op from TracedMod should be properly # inlined - self.assertExpected(str(sm.graph)) + self.assertTrue(len(list(sm.graph.inputs())) == 3) + FileCheck().check("aten::mm").check("aten::mm").run(str(sm.graph)) def test_call_script_fn_from_script_module(self): @torch.jit.script @@ -8836,7 +8856,8 @@ a") return script_fn(torch.mm(x, self.param)) sm = ScriptMod() - self.assertExpected(canonical(sm.__getattr__('forward').graph)) + graph = (sm.__getattr__('forward').graph) + FileCheck().check("aten::mm").check("aten::neg").run(str(graph)) def test_call_script_mod_from_script_module(self): class ScriptMod1(torch.jit.ScriptModule): @@ -8862,7 +8883,8 @@ a") # Note: the parameters from both modules should appear in the flattened # input list to the graph. The mm op from ScriptMod1 should be properly # inlined - self.assertExpected(canonical(sm.graph)) + # 3 % values in graph input lists, two mms in body + FileCheck().check_count('%', 3).check(":").check_count("mm", 2).run(str(sm.graph)) def test_module_with_params_called_fails(self): with self.assertRaisesRegex(RuntimeError, "Attempted to inline a Module with parameters. Stateful " @@ -9981,7 +10003,7 @@ a") a /= b a *= b return a, b - self.checkScript(foo, (torch.rand(3), torch.rand(3)), check_expected=True) + self.checkScript(foo, (torch.rand(3), torch.rand(3))) def test_pass(self): def foo(x): diff --git a/torch/csrc/jit/testing/file_check.cpp b/torch/csrc/jit/testing/file_check.cpp index 8025e5f..d428699 100644 --- a/torch/csrc/jit/testing/file_check.cpp +++ b/torch/csrc/jit/testing/file_check.cpp @@ -156,6 +156,8 @@ struct FileCheckImpl { bool has_run = false; + friend std::ostream& operator<<(std::ostream& out, const FileCheckImpl& fc); + private: void doCheckNot( const std::vector& nots, @@ -281,9 +283,18 @@ struct FileCheckImpl { FileCheck::FileCheck() : fcImpl(new FileCheckImpl()){}; +std::ostream& operator<<(std::ostream& out, const FileCheckImpl& fc) { + out << "FileCheck checks:\n"; + for (const Check& c : fc.checks) { + out << "\t" << c << "\n"; + } + return out; +}; + FileCheck::~FileCheck() { if (!fcImpl->has_run) { std::cout << "You have not run this instance of FileCheck!\n"; + std::cout << *fcImpl; } fcImpl.reset(); }; -- 2.7.4