Remove Expect Files from python / tracing / script interop
authorElias Ellison <eellison@fb.com>
Tue, 5 Mar 2019 06:38:41 +0000 (22:38 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 5 Mar 2019 07:04:54 +0000 (23:04 -0800)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17622

Differential Revision: D14308307

Pulled By: eellison

fbshipit-source-id: bda249d38ac2570000a12b0ca328c26233ecefe8

30 files changed:
test/expect/TestCustomOperators.test_script_graph_contains_custom_op.expect [deleted file]
test/expect/TestJit.test_nested_inplace.expect [deleted file]
test/expect/TestJit.test_recursive_cse.expect [deleted file]
test/expect/TestJit.test_shared_param.expect [deleted file]
test/expect/TestScript.test_augmented_assign.expect [deleted file]
test/expect/TestScript.test_call_python_fn_from_script_fn.expect [deleted file]
test/expect/TestScript.test_call_python_fn_from_script_module.expect [deleted file]
test/expect/TestScript.test_call_python_fn_from_traced_module.expect [deleted file]
test/expect/TestScript.test_call_python_mod_from_script_fn.expect [deleted file]
test/expect/TestScript.test_call_python_mod_from_script_module.expect [deleted file]
test/expect/TestScript.test_call_python_mod_from_traced_module.expect [deleted file]
test/expect/TestScript.test_call_python_mod_from_tracing_fn.expect [deleted file]
test/expect/TestScript.test_call_script_fn_from_script_fn.expect [deleted file]
test/expect/TestScript.test_call_script_fn_from_script_module.expect [deleted file]
test/expect/TestScript.test_call_script_fn_from_traced_module.expect [deleted file]
test/expect/TestScript.test_call_script_fn_from_tracing_fn.expect [deleted file]
test/expect/TestScript.test_call_script_mod_from_script_fn.expect [deleted file]
test/expect/TestScript.test_call_script_mod_from_script_module.expect [deleted file]
test/expect/TestScript.test_call_script_mod_from_tracing_fn.expect [deleted file]
test/expect/TestScript.test_call_script_module_from_traced_module.expect [deleted file]
test/expect/TestScript.test_call_traced_fn_from_script_fn.expect [deleted file]
test/expect/TestScript.test_call_traced_fn_from_traced_module.expect [deleted file]
test/expect/TestScript.test_call_traced_fn_from_tracing_fn.expect [deleted file]
test/expect/TestScript.test_call_traced_mod_from_script_fn.expect [deleted file]
test/expect/TestScript.test_call_traced_mod_from_tracing_fn.expect [deleted file]
test/expect/TestScript.test_call_traced_module_from_traced_module.expect [deleted file]
test/expect/TestScript.test_call_tracing_fn_from_script_module.expect [deleted file]
test/expect/TestScript.test_call_tracing_mod_from_script_module.expect [deleted file]
test/test_jit.py
torch/csrc/jit/testing/file_check.cpp

diff --git a/test/expect/TestCustomOperators.test_script_graph_contains_custom_op.expect b/test/expect/TestCustomOperators.test_script_graph_contains_custom_op.expect
deleted file mode 100644 (file)
index 490b53f..0000000
+++ /dev/null
@@ -1,3 +0,0 @@
-graph(%x : Dynamic):
-  %1 : Dynamic = aten::relu(%x)
-  return (%1)
diff --git a/test/expect/TestJit.test_nested_inplace.expect b/test/expect/TestJit.test_nested_inplace.expect
deleted file mode 100644 (file)
index 6803ff1..0000000
+++ /dev/null
@@ -1,5 +0,0 @@
-graph(%x : Double(2, 2)):
-  %1 : int = prim::Constant[value=0]()
-  %2 : int = prim::Constant[value=0]()
-  %3 : Double(2, 2) = aten::threshold_(%x, %1, %2)
-  return (%3)
diff --git a/test/expect/TestJit.test_recursive_cse.expect b/test/expect/TestJit.test_recursive_cse.expect
deleted file mode 100644 (file)
index c117774..0000000
+++ /dev/null
@@ -1,12 +0,0 @@
-graph(%x : Tensor,
-      %y : Tensor):
-  %2 : int = prim::Constant[value=1]()
-  %3 : Tensor = aten::add(%x, %y, %2)
-  %4 : Tensor = aten::gt(%3, %x)
-  %5 : bool = prim::Bool(%4)
-  %z : Tensor = prim::If(%5)
-    block0():
-      -> (%3)
-    block1():
-      -> (%x)
-  return (%z)
diff --git a/test/expect/TestJit.test_shared_param.expect b/test/expect/TestJit.test_shared_param.expect
deleted file mode 100644 (file)
index c8b2976..0000000
+++ /dev/null
@@ -1,6 +0,0 @@
-graph(%0 : Double(2, 2),
-      %1 : Double(2, 2)):
-  %2 : Double(2, 2) = aten::mul(%0, %1), scope: MyModule
-  %3 : int = prim::Constant[value=1](), scope: MyModule
-  %4 : Double(2, 2) = aten::add(%2, %1, %3), scope: MyModule
-  return (%4)
diff --git a/test/expect/TestScript.test_augmented_assign.expect b/test/expect/TestScript.test_augmented_assign.expect
deleted file mode 100644 (file)
index e5d047d..0000000
+++ /dev/null
@@ -1,9 +0,0 @@
-graph(%a.1 : Tensor,
-      %b : Tensor):
-  %2 : int = prim::Constant[value=1]()
-  %a.2 : Tensor = aten::add_(%a.1, %b, %2)
-  %a.3 : Tensor = aten::sub_(%a.2, %b, %2)
-  %a.4 : Tensor = aten::div_(%a.3, %b)
-  %a : Tensor = aten::mul_(%a.4, %b)
-  %7 : (Tensor, Tensor) = prim::TupleConstruct(%a, %b)
-  return (%7)
diff --git a/test/expect/TestScript.test_call_python_fn_from_script_fn.expect b/test/expect/TestScript.test_call_python_fn_from_script_fn.expect
deleted file mode 100644 (file)
index 9f7bc73..0000000
+++ /dev/null
@@ -1,5 +0,0 @@
-graph(%x : Tensor):
-  %1 : int = prim::Constant[value=1]()
-  %2 : Tensor = ^python_fn()(%x)
-  %3 : Tensor = aten::add(%2, %1, %1)
-  return (%3)
diff --git a/test/expect/TestScript.test_call_python_fn_from_script_module.expect b/test/expect/TestScript.test_call_python_fn_from_script_module.expect
deleted file mode 100644 (file)
index ec5349f..0000000
+++ /dev/null
@@ -1,5 +0,0 @@
-graph(%x : Tensor,
-      %1 : Tensor):
-  %2 : Tensor = aten::mm(%x, %1)
-  %3 : Tensor = ^python_fn()(%2)
-  return (%3)
diff --git a/test/expect/TestScript.test_call_python_fn_from_traced_module.expect b/test/expect/TestScript.test_call_python_fn_from_traced_module.expect
deleted file mode 100644 (file)
index 503fe26..0000000
+++ /dev/null
@@ -1,5 +0,0 @@
-graph(%x : Double(3, 4),
-      %1 : Double(4, 3)):
-  %2 : Double(3, 4) = aten::neg(%x), scope: TracedModule
-  %3 : Double(3, 3) = aten::mm(%2, %1), scope: TracedModule
-  return (%3)
diff --git a/test/expect/TestScript.test_call_python_mod_from_script_fn.expect b/test/expect/TestScript.test_call_python_mod_from_script_fn.expect
deleted file mode 100644 (file)
index 7140db6..0000000
+++ /dev/null
@@ -1,5 +0,0 @@
-graph(%x : Tensor):
-  %2 : int = prim::Constant[value=1]()
-  %1 : Tensor = ^<python_value>()(%x)
-  %4 : Tensor = aten::add(%1, %2, %2)
-  return (%4)
diff --git a/test/expect/TestScript.test_call_python_mod_from_script_module.expect b/test/expect/TestScript.test_call_python_mod_from_script_module.expect
deleted file mode 100644 (file)
index 2512a45..0000000
+++ /dev/null
@@ -1,5 +0,0 @@
-graph(%x : Tensor,
-      %1 : Tensor):
-  %2 : Tensor = aten::mm(%x, %1)
-  %3 : Tensor = ^<python_value>()(%2)
-  return (%3)
diff --git a/test/expect/TestScript.test_call_python_mod_from_traced_module.expect b/test/expect/TestScript.test_call_python_mod_from_traced_module.expect
deleted file mode 100644 (file)
index fe0cd06..0000000
+++ /dev/null
@@ -1,9 +0,0 @@
-graph(%x.1 : Double(3, 4),
-      %1 : Double(4, 5),
-      %2 : Double(5, 7)):
-  %x : Double(3, 5) = aten::mm(%x.1, %1), scope: TracedModule
-  %4 : Double(3, 7) = aten::mm(%x, %2), scope: TracedModule/PythonModule[mod]
-  %5 : Double() = prim::Constant[value={1}](), scope: TracedModule
-  %6 : int = prim::Constant[value=1](), scope: TracedModule
-  %7 : Double(3, 7) = aten::add(%4, %5, %6), scope: TracedModule
-  return (%7)
diff --git a/test/expect/TestScript.test_call_python_mod_from_tracing_fn.expect b/test/expect/TestScript.test_call_python_mod_from_tracing_fn.expect
deleted file mode 100644 (file)
index f66d55b..0000000
+++ /dev/null
@@ -1,7 +0,0 @@
-graph(%x : Double(3, 4)):
-  %1 : Double(4, 3) = prim::Constant[value=<Tensor>](), scope: PythonMod
-  %2 : Double(3, 3) = aten::mm(%x, %1), scope: PythonMod
-  %3 : Double() = prim::Constant[value={1}]()
-  %4 : int = prim::Constant[value=1]()
-  %5 : Double(3, 3) = aten::add(%2, %3, %4)
-  return (%5)
diff --git a/test/expect/TestScript.test_call_script_fn_from_script_fn.expect b/test/expect/TestScript.test_call_script_fn_from_script_fn.expect
deleted file mode 100644 (file)
index 51ea22e..0000000
+++ /dev/null
@@ -1,5 +0,0 @@
-graph(%x : Tensor):
-  %1 : int = prim::Constant[value=1]()
-  %2 : Tensor = aten::neg(%x)
-  %3 : Tensor = aten::add(%2, %1, %1)
-  return (%3)
diff --git a/test/expect/TestScript.test_call_script_fn_from_script_module.expect b/test/expect/TestScript.test_call_script_fn_from_script_module.expect
deleted file mode 100644 (file)
index df05460..0000000
+++ /dev/null
@@ -1,5 +0,0 @@
-graph(%x : Tensor,
-      %1 : Tensor):
-  %2 : Tensor = aten::mm(%x, %1)
-  %3 : Tensor = aten::neg(%2)
-  return (%3)
diff --git a/test/expect/TestScript.test_call_script_fn_from_traced_module.expect b/test/expect/TestScript.test_call_script_fn_from_traced_module.expect
deleted file mode 100644 (file)
index 03d0d35..0000000
+++ /dev/null
@@ -1,5 +0,0 @@
-graph(%x : Double(3, 4),
-      %1 : Double(4, 5)):
-  %2 : Double(3, 5) = aten::mm(%x, %1), scope: TracedModule
-  %3 : Double(*, *) = aten::neg(%2), scope: TracedModule/ScriptModule
-  return (%3)
diff --git a/test/expect/TestScript.test_call_script_fn_from_tracing_fn.expect b/test/expect/TestScript.test_call_script_fn_from_tracing_fn.expect
deleted file mode 100644 (file)
index a9d071f..0000000
+++ /dev/null
@@ -1,6 +0,0 @@
-graph(%0 : Double(3, 4)):
-  %1 : Double(*, *) = aten::neg(%0), scope: ScriptModule
-  %2 : Long() = prim::Constant[value={1}]()
-  %3 : int = prim::Constant[value=1]()
-  %4 : Double(3, 4) = aten::add(%1, %2, %3)
-  return (%4)
diff --git a/test/expect/TestScript.test_call_script_mod_from_script_fn.expect b/test/expect/TestScript.test_call_script_mod_from_script_fn.expect
deleted file mode 100644 (file)
index 2afa03c..0000000
+++ /dev/null
@@ -1,12 +0,0 @@
-graph(%x : Tensor):
-  %1 : int = prim::Constant[value=3]()
-  %2 : int = prim::Constant[value=4]()
-  %3 : int = prim::Constant[value=6]()
-  %4 : int = prim::Constant[value=0]()
-  %5 : Device = prim::Constant[value="cpu"]()
-  %6 : int = prim::Constant[value=1]()
-  %7 : int[] = prim::ListConstruct(%2, %1)
-  %8 : Tensor = aten::zeros(%7, %3, %4, %5)
-  %9 : Tensor = aten::mm(%x, %8)
-  %10 : Tensor = aten::add(%9, %6, %6)
-  return (%10)
diff --git a/test/expect/TestScript.test_call_script_mod_from_script_module.expect b/test/expect/TestScript.test_call_script_mod_from_script_module.expect
deleted file mode 100644 (file)
index 73d8984..0000000
+++ /dev/null
@@ -1,6 +0,0 @@
-graph(%x : Tensor,
-      %1 : Tensor,
-      %2 : Tensor):
-  %3 : Tensor = aten::mm(%x, %1)
-  %4 : Tensor = aten::mm(%3, %2)
-  return (%4)
diff --git a/test/expect/TestScript.test_call_script_mod_from_tracing_fn.expect b/test/expect/TestScript.test_call_script_mod_from_tracing_fn.expect
deleted file mode 100644 (file)
index f66186a..0000000
+++ /dev/null
@@ -1,7 +0,0 @@
-graph(%0 : Double(3, 4)):
-  %1 : Double(4, 3) = prim::Constant[value=<Tensor>](), scope: ScriptMod
-  %2 : Double(*, *) = aten::mm(%0, %1), scope: ScriptMod
-  %3 : Double() = prim::Constant[value={1}]()
-  %4 : int = prim::Constant[value=1]()
-  %5 : Double(3, 3) = aten::add(%2, %3, %4)
-  return (%5)
diff --git a/test/expect/TestScript.test_call_script_module_from_traced_module.expect b/test/expect/TestScript.test_call_script_module_from_traced_module.expect
deleted file mode 100644 (file)
index 184b0ff..0000000
+++ /dev/null
@@ -1,9 +0,0 @@
-graph(%x : Double(3, 4),
-      %1 : Double(4, 5),
-      %2 : Double(5, 7)):
-  %3 : Double(3, 5) = aten::mm(%x, %1), scope: TracedModule
-  %4 : Double(*, *) = aten::mm(%3, %2), scope: TracedModule/ScriptMod[mod]
-  %5 : Double() = prim::Constant[value={1}](), scope: TracedModule
-  %6 : int = prim::Constant[value=1](), scope: TracedModule
-  %7 : Double(3, 7) = aten::add(%4, %5, %6), scope: TracedModule
-  return (%7)
diff --git a/test/expect/TestScript.test_call_traced_fn_from_script_fn.expect b/test/expect/TestScript.test_call_traced_fn_from_script_fn.expect
deleted file mode 100644 (file)
index 7b68f79..0000000
+++ /dev/null
@@ -1,5 +0,0 @@
-graph(%x : Tensor):
-  %2 : int = prim::Constant[value=1]()
-  %1 : Double(3, 4) = aten::neg(%x)
-  %4 : Tensor = aten::add(%1, %2, %2)
-  return (%4)
diff --git a/test/expect/TestScript.test_call_traced_fn_from_traced_module.expect b/test/expect/TestScript.test_call_traced_fn_from_traced_module.expect
deleted file mode 100644 (file)
index 9d7399d..0000000
+++ /dev/null
@@ -1,5 +0,0 @@
-graph(%x : Double(3, 4),
-      %1 : Double(4, 5)):
-  %2 : Double(3, 5) = aten::mm(%x, %1), scope: TracedModule
-  %3 : Double(*, *) = aten::neg(%2), scope: TracedModule/traced_fn
-  return (%3)
diff --git a/test/expect/TestScript.test_call_traced_fn_from_tracing_fn.expect b/test/expect/TestScript.test_call_traced_fn_from_tracing_fn.expect
deleted file mode 100644 (file)
index 91de7c2..0000000
+++ /dev/null
@@ -1,6 +0,0 @@
-graph(%0 : Double(3, 4)):
-  %1 : Double(*, *) = aten::neg(%0), scope: traced_fn1
-  %2 : Long() = prim::Constant[value={1}]()
-  %3 : int = prim::Constant[value=1]()
-  %4 : Double(3, 4) = aten::add(%1, %2, %3)
-  return (%4)
diff --git a/test/expect/TestScript.test_call_traced_mod_from_script_fn.expect b/test/expect/TestScript.test_call_traced_mod_from_script_fn.expect
deleted file mode 100644 (file)
index 0f44f74..0000000
+++ /dev/null
@@ -1,12 +0,0 @@
-graph(%x : Tensor):
-  %6 : Device = prim::Constant[value="cpu"](), scope: TracedModule
-  %5 : int = prim::Constant[value=0](), scope: TracedModule
-  %4 : int = prim::Constant[value=7](), scope: TracedModule
-  %2 : int = prim::Constant[value=3](), scope: TracedModule
-  %1 : int = prim::Constant[value=4](), scope: TracedModule
-  %9 : int = prim::Constant[value=1]()
-  %3 : int[] = prim::ListConstruct(%1, %2), scope: TracedModule
-  %7 : Double(4, 3) = aten::zeros(%3, %4, %5, %6), scope: TracedModule
-  %8 : Double(3, 3) = aten::mm(%x, %7), scope: TracedModule
-  %11 : Tensor = aten::add(%8, %9, %9)
-  return (%11)
diff --git a/test/expect/TestScript.test_call_traced_mod_from_tracing_fn.expect b/test/expect/TestScript.test_call_traced_mod_from_tracing_fn.expect
deleted file mode 100644 (file)
index a114654..0000000
+++ /dev/null
@@ -1,7 +0,0 @@
-graph(%0 : Double(3, 4)):
-  %1 : Double(4, 3) = prim::Constant[value=<Tensor>](), scope: TracedModule[TracedModule]
-  %2 : Double(*, *) = aten::mm(%0, %1), scope: TracedModule
-  %3 : Double() = prim::Constant[value={1}]()
-  %4 : int = prim::Constant[value=1]()
-  %5 : Double(3, 3) = aten::add(%2, %3, %4)
-  return (%5)
diff --git a/test/expect/TestScript.test_call_traced_module_from_traced_module.expect b/test/expect/TestScript.test_call_traced_module_from_traced_module.expect
deleted file mode 100644 (file)
index de442fc..0000000
+++ /dev/null
@@ -1,9 +0,0 @@
-graph(%x : Double(3, 4),
-      %1 : Double(4, 5),
-      %2 : Double(5, 7)):
-  %3 : Double(3, 5) = aten::mm(%x, %1), scope: TracedModule
-  %4 : Double(*, *) = aten::mm(%3, %2), scope: TracedModule1
-  %5 : Double() = prim::Constant[value={1}](), scope: TracedModule
-  %6 : int = prim::Constant[value=1](), scope: TracedModule
-  %7 : Double(3, 7) = aten::add(%4, %5, %6), scope: TracedModule
-  return (%7)
diff --git a/test/expect/TestScript.test_call_tracing_fn_from_script_module.expect b/test/expect/TestScript.test_call_tracing_fn_from_script_module.expect
deleted file mode 100644 (file)
index c60ec63..0000000
+++ /dev/null
@@ -1,5 +0,0 @@
-graph(%x : Tensor,
-      %1 : Tensor):
-  %2 : Tensor = aten::mm(%x, %1)
-  %3 : Double(3, 3) = aten::neg(%2)
-  return (%3)
diff --git a/test/expect/TestScript.test_call_tracing_mod_from_script_module.expect b/test/expect/TestScript.test_call_tracing_mod_from_script_module.expect
deleted file mode 100644 (file)
index 5786878..0000000
+++ /dev/null
@@ -1,6 +0,0 @@
-graph(%x : Tensor,
-      %1 : Tensor,
-      %3 : Tensor):
-  %2 : Tensor = aten::mm(%x, %1)
-  %4 : Double(3, 5) = aten::mm(%2, %3), scope: TracedMod
-  return (%4)
index 320af77..d9e9966 100644 (file)
@@ -8370,7 +8370,8 @@ a")
 
         # Note: the parameter self.param from the Python module is inlined
         # into the graph
-        self.assertExpected(canonical(traced_fn.graph))
+        self.assertTrue(len(list(traced_fn.graph.inputs())) == 1)
+        FileCheck().check("aten::mm").check("aten::add").run(str(traced_fn.graph))
 
     def test_call_traced_fn_from_tracing_fn(self):
         @_trace(torch.rand(3, 4))
@@ -8381,7 +8382,8 @@ a")
         def traced_fn(x):
             return traced_fn1(x) + 1
 
-        self.assertExpected(canonical(traced_fn.graph))
+        FileCheck().check("aten::neg").check_same("scope: traced_fn1").check("aten::add") \
+            .run(str(traced_fn.graph))
 
     def test_call_traced_mod_from_tracing_fn(self):
         class TracedModule(torch.nn.Module):
@@ -8400,7 +8402,8 @@ a")
 
         # Note: the parameter self.param from the Python module is inlined
         # into the graph
-        self.assertExpected(canonical(traced_fn.graph))
+        FileCheck().check("prim::Constant[value=<Tensor>]").check("aten::mm") \
+            .check("aten::add").run(str(traced_fn.graph))
 
     def test_call_script_fn_from_tracing_fn(self):
         @torch.jit.script
@@ -8411,25 +8414,30 @@ a")
         def traced_fn(x):
             return script_fn(x) + 1
 
-        self.assertExpected(canonical(traced_fn.graph))
+        FileCheck().check("aten::neg").check("aten::add").run(str(traced_fn.graph))
 
     def test_call_script_mod_from_tracing_fn(self):
-        class ScriptMod(torch.jit.ScriptModule):
-            def __init__(self):
-                super(ScriptMod, self).__init__()
-                self.param = torch.nn.Parameter(torch.rand(4, 3), requires_grad=False)
+        with self.disableModuleHook():
+            class ScriptMod(torch.jit.ScriptModule):
+                def __init__(self):
+                    super(ScriptMod, self).__init__()
+                    self.param = torch.nn.Parameter(torch.rand(3, 4), requires_grad=False)
 
-            @torch.jit.script_method
-            def forward(self, x):
-                return torch.mm(x, self.param)
+                @torch.jit.script_method
+                def forward(self, x):
+                    for _i in range(4):
+                        x += self.param
+                    return x
 
-        sm = ScriptMod()
+            sm = ScriptMod()
 
-        @_trace(torch.rand(3, 4))
-        def traced_fn(x):
-            return sm(x) + 1.0
+            @_trace(torch.rand(3, 4))
+            def traced_fn(x):
+                return sm(x) + 1.0
 
-        self.assertExpected(canonical(traced_fn.graph))
+            # parameter turns into constant and loop is perserved
+            FileCheck().check("prim::Constant[value=<Tensor>]").check("Loop") \
+                .run(str(traced_fn.graph))
 
     def test_call_python_fn_from_traced_module(self):
         def python_fn(x):
@@ -8448,7 +8456,8 @@ a")
         # Note: parameter self.param from the traced module should appear as
         # an input to the graph and the neg op from the Python function should
         # be properly inlined
-        self.assertExpected(canonical(tm.graph))
+        self.assertTrue(len(list(tm.graph.inputs())) == 2)
+        FileCheck().check("aten::neg").check("aten::mm").run(str(tm.graph))
 
     def test_call_python_mod_from_traced_module(self):
         class PythonModule(torch.nn.Module):
@@ -8472,7 +8481,9 @@ a")
 
         # Note: the parameters from both modules should appear in the flattened
         # inputs of the graph. All ops from both modules should be inlined.
-        self.assertExpected(canonical(tm.graph))
+        self.assertTrue(len(list(tm.graph.inputs())) == 3)
+        FileCheck().check_not("value=<Tensor>").check_count("aten::mm", 2).check("aten::add") \
+            .run(str(tm.graph))
 
     def test_call_traced_fn_from_traced_module(self):
         @_trace(torch.rand(3, 4))
@@ -8489,7 +8500,9 @@ a")
 
         tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
         # Note: neg op from the traced function should be properly inlined
-        self.assertExpected(canonical(tm.graph))
+        FileCheck().check("aten::mm").check_same("scope: TracedModule") \
+            .check_next("aten::neg").check("scope: TracedModule/traced_fn") \
+            .run(str(tm.graph))
 
     def test_trace_hierarchy(self):
         # Test that we preserve the module hierarchy for a ScriptModule
@@ -8605,7 +8618,8 @@ a")
 
         # Note: the parameters from both modules should appear in the flattened
         # inputs of the graph. All ops from both modules should be inlined.
-        self.assertExpected(canonical(tm.graph))
+        self.assertTrue(len(list(tm.graph.inputs())) == 3)
+        FileCheck().check_count("aten::mm", 2).check("aten::add").run(str(tm.graph))
 
     def test_call_script_fn_from_traced_module(self):
         @torch.jit.script
@@ -8622,7 +8636,7 @@ a")
 
         tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
         # Note: neg op from the script function should be properly inlined
-        self.assertExpected(canonical(tm.graph))
+        FileCheck().check("aten::mm").check("aten::neg").run(str(tm.graph))
 
     def test_call_script_module_from_traced_module(self):
         class ScriptMod(torch.jit.ScriptModule):
@@ -8647,7 +8661,8 @@ a")
 
         # Note: the parameters from both modules should appear in the flattened
         # inputs of the graph. All ops from both modules should be inlined.
-        self.assertExpected(canonical(tm.graph))
+        self.assertTrue(len(list(tm.graph.inputs())) == 3)
+        FileCheck().check_count("aten::mm", 2).check("aten::add").run(str(tm.graph))
 
     def test_call_python_fn_from_script_fn(self):
         def python_fn(x):
@@ -8659,7 +8674,9 @@ a")
 
         # Note: the call to python_fn appears as `^python_fn()` and is called
         # as a PythonOp in the interpreter
-        self.assertExpected(canonical(script_fn.graph))
+        a = torch.tensor(1)
+        self.assertEqual(script_fn(a), torch.tensor(0))
+        FileCheck().check("python_fn").run(str(script_fn.graph))
 
     def test_call_python_mod_from_script_fn(self):
         class PythonModule(torch.nn.Module):
@@ -8678,7 +8695,7 @@ a")
 
         # Note: call to pm(x) appears as ^<python_value>() in the trace.
         # Parameters are NOT inlined.
-        self.assertExpected(str(script_fn.graph))
+        FileCheck().check("python_value").check("aten::add").run(str(script_fn.graph))
 
     def test_call_traced_fn_from_script_fn(self):
         @_trace(torch.rand(3, 4))
@@ -8691,7 +8708,7 @@ a")
 
         # Note: the neg op from traced_fn should be properly inlined into the
         # script function's graph
-        self.assertExpected(str(script_fn.graph))
+        FileCheck().check("aten::neg").check("aten::add").run(str(script_fn.graph))
 
     def test_call_traced_mod_from_script_fn(self):
         class TracedModule(torch.nn.Module):
@@ -8707,7 +8724,8 @@ a")
         def script_fn(x):
             return tm(x) + 1
 
-        self.assertExpected(str(script_fn.graph))
+        FileCheck().check("aten::zeros").check_same("scope: TracedModule").check("aten::mm") \
+            .check("aten::add").run(str(script_fn.graph))
 
     def test_call_script_fn_from_script_fn(self):
         @torch.jit.script
@@ -8720,7 +8738,7 @@ a")
 
         # Note: the neg op from script_fn1 should be properly inlined into the
         # graph of script_fn
-        self.assertExpected(canonical(script_fn.graph))
+        FileCheck().check("aten::neg").run(str(script_fn.graph))
 
     def test_call_script_mod_from_script_fn(self):
         class ScriptMod(torch.jit.ScriptModule):
@@ -8737,7 +8755,7 @@ a")
         def script_fn(x):
             return sm(x) + 1
 
-        self.assertExpected(canonical(script_fn.graph))
+        FileCheck().check("zeros").check("aten::mm").check("add").run(str(script_fn.graph))
 
     def test_call_python_fn_from_script_module(self):
         def python_fn(x):
@@ -8753,7 +8771,8 @@ a")
                 return python_fn(torch.mm(x, self.param))
 
         sm = ScriptMod()
-        self.assertExpected(str(sm.__getattr__('forward').graph))
+        FileCheck().check("aten::mm").check("python_fn") \
+            .run(str(sm.__getattr__('forward').graph))
 
     def test_call_python_mod_from_script_module(self):
         class PythonMod(torch.nn.Module):
@@ -8777,7 +8796,7 @@ a")
         sm = ScriptMod()
         # Note: the call into PythonMod appears as ^<python_value>(). Parameters
         # are NOT inlined
-        self.assertExpected(str(sm.graph))
+        FileCheck().check("aten::mm").check("python_value").run(str(sm.graph))
 
     def test_call_tracing_fn_from_script_module(self):
         @_trace(torch.rand(3, 3))
@@ -8794,7 +8813,7 @@ a")
                 return traced_fn(torch.mm(x, self.param))
 
         sm = ScriptMod()
-        self.assertExpected(str(sm.__getattr__('forward').graph))
+        FileCheck().check("aten::mm").check("aten::neg").run(str(sm.__getattr__('forward').graph))
 
     def test_call_tracing_mod_from_script_module(self):
         class TracedMod(torch.nn.Module):
@@ -8819,7 +8838,8 @@ a")
         # Note: the parameters from both modules should appear in the flattened
         # input list to the graph. The mm op from TracedMod should be properly
         # inlined
-        self.assertExpected(str(sm.graph))
+        self.assertTrue(len(list(sm.graph.inputs())) == 3)
+        FileCheck().check("aten::mm").check("aten::mm").run(str(sm.graph))
 
     def test_call_script_fn_from_script_module(self):
         @torch.jit.script
@@ -8836,7 +8856,8 @@ a")
                 return script_fn(torch.mm(x, self.param))
 
         sm = ScriptMod()
-        self.assertExpected(canonical(sm.__getattr__('forward').graph))
+        graph = (sm.__getattr__('forward').graph)
+        FileCheck().check("aten::mm").check("aten::neg").run(str(graph))
 
     def test_call_script_mod_from_script_module(self):
         class ScriptMod1(torch.jit.ScriptModule):
@@ -8862,7 +8883,8 @@ a")
         # Note: the parameters from both modules should appear in the flattened
         # input list to the graph. The mm op from ScriptMod1 should be properly
         # inlined
-        self.assertExpected(canonical(sm.graph))
+        # 3 % values in graph input lists, two mms in body
+        FileCheck().check_count('%', 3).check(":").check_count("mm", 2).run(str(sm.graph))
 
     def test_module_with_params_called_fails(self):
         with self.assertRaisesRegex(RuntimeError, "Attempted to inline a Module with parameters. Stateful "
@@ -9981,7 +10003,7 @@ a")
             a /= b
             a *= b
             return a, b
-        self.checkScript(foo, (torch.rand(3), torch.rand(3)), check_expected=True)
+        self.checkScript(foo, (torch.rand(3), torch.rand(3)))
 
     def test_pass(self):
         def foo(x):
index 8025e5f..d428699 100644 (file)
@@ -156,6 +156,8 @@ struct FileCheckImpl {
 
   bool has_run = false;
 
+  friend std::ostream& operator<<(std::ostream& out, const FileCheckImpl& fc);
+
  private:
   void doCheckNot(
       const std::vector<Check>& nots,
@@ -281,9 +283,18 @@ struct FileCheckImpl {
 
 FileCheck::FileCheck() : fcImpl(new FileCheckImpl()){};
 
+std::ostream& operator<<(std::ostream& out, const FileCheckImpl& fc) {
+  out << "FileCheck checks:\n";
+  for (const Check& c : fc.checks) {
+    out << "\t" << c << "\n";
+  }
+  return out;
+};
+
 FileCheck::~FileCheck() {
   if (!fcImpl->has_run) {
     std::cout << "You have not run this instance of FileCheck!\n";
+    std::cout << *fcImpl;
   }
   fcImpl.reset();
 };