+++ /dev/null
-graph(%x : Dynamic):
- %1 : Dynamic = aten::relu(%x)
- return (%1)
+++ /dev/null
-graph(%x : Double(2, 2)):
- %1 : int = prim::Constant[value=0]()
- %2 : int = prim::Constant[value=0]()
- %3 : Double(2, 2) = aten::threshold_(%x, %1, %2)
- return (%3)
+++ /dev/null
-graph(%x : Tensor,
- %y : Tensor):
- %2 : int = prim::Constant[value=1]()
- %3 : Tensor = aten::add(%x, %y, %2)
- %4 : Tensor = aten::gt(%3, %x)
- %5 : bool = prim::Bool(%4)
- %z : Tensor = prim::If(%5)
- block0():
- -> (%3)
- block1():
- -> (%x)
- return (%z)
+++ /dev/null
-graph(%0 : Double(2, 2),
- %1 : Double(2, 2)):
- %2 : Double(2, 2) = aten::mul(%0, %1), scope: MyModule
- %3 : int = prim::Constant[value=1](), scope: MyModule
- %4 : Double(2, 2) = aten::add(%2, %1, %3), scope: MyModule
- return (%4)
+++ /dev/null
-graph(%a.1 : Tensor,
- %b : Tensor):
- %2 : int = prim::Constant[value=1]()
- %a.2 : Tensor = aten::add_(%a.1, %b, %2)
- %a.3 : Tensor = aten::sub_(%a.2, %b, %2)
- %a.4 : Tensor = aten::div_(%a.3, %b)
- %a : Tensor = aten::mul_(%a.4, %b)
- %7 : (Tensor, Tensor) = prim::TupleConstruct(%a, %b)
- return (%7)
+++ /dev/null
-graph(%x : Tensor):
- %1 : int = prim::Constant[value=1]()
- %2 : Tensor = ^python_fn()(%x)
- %3 : Tensor = aten::add(%2, %1, %1)
- return (%3)
+++ /dev/null
-graph(%x : Tensor,
- %1 : Tensor):
- %2 : Tensor = aten::mm(%x, %1)
- %3 : Tensor = ^python_fn()(%2)
- return (%3)
+++ /dev/null
-graph(%x : Double(3, 4),
- %1 : Double(4, 3)):
- %2 : Double(3, 4) = aten::neg(%x), scope: TracedModule
- %3 : Double(3, 3) = aten::mm(%2, %1), scope: TracedModule
- return (%3)
+++ /dev/null
-graph(%x : Tensor):
- %2 : int = prim::Constant[value=1]()
- %1 : Tensor = ^<python_value>()(%x)
- %4 : Tensor = aten::add(%1, %2, %2)
- return (%4)
+++ /dev/null
-graph(%x : Tensor,
- %1 : Tensor):
- %2 : Tensor = aten::mm(%x, %1)
- %3 : Tensor = ^<python_value>()(%2)
- return (%3)
+++ /dev/null
-graph(%x.1 : Double(3, 4),
- %1 : Double(4, 5),
- %2 : Double(5, 7)):
- %x : Double(3, 5) = aten::mm(%x.1, %1), scope: TracedModule
- %4 : Double(3, 7) = aten::mm(%x, %2), scope: TracedModule/PythonModule[mod]
- %5 : Double() = prim::Constant[value={1}](), scope: TracedModule
- %6 : int = prim::Constant[value=1](), scope: TracedModule
- %7 : Double(3, 7) = aten::add(%4, %5, %6), scope: TracedModule
- return (%7)
+++ /dev/null
-graph(%x : Double(3, 4)):
- %1 : Double(4, 3) = prim::Constant[value=<Tensor>](), scope: PythonMod
- %2 : Double(3, 3) = aten::mm(%x, %1), scope: PythonMod
- %3 : Double() = prim::Constant[value={1}]()
- %4 : int = prim::Constant[value=1]()
- %5 : Double(3, 3) = aten::add(%2, %3, %4)
- return (%5)
+++ /dev/null
-graph(%x : Tensor):
- %1 : int = prim::Constant[value=1]()
- %2 : Tensor = aten::neg(%x)
- %3 : Tensor = aten::add(%2, %1, %1)
- return (%3)
+++ /dev/null
-graph(%x : Tensor,
- %1 : Tensor):
- %2 : Tensor = aten::mm(%x, %1)
- %3 : Tensor = aten::neg(%2)
- return (%3)
+++ /dev/null
-graph(%x : Double(3, 4),
- %1 : Double(4, 5)):
- %2 : Double(3, 5) = aten::mm(%x, %1), scope: TracedModule
- %3 : Double(*, *) = aten::neg(%2), scope: TracedModule/ScriptModule
- return (%3)
+++ /dev/null
-graph(%0 : Double(3, 4)):
- %1 : Double(*, *) = aten::neg(%0), scope: ScriptModule
- %2 : Long() = prim::Constant[value={1}]()
- %3 : int = prim::Constant[value=1]()
- %4 : Double(3, 4) = aten::add(%1, %2, %3)
- return (%4)
+++ /dev/null
-graph(%x : Tensor):
- %1 : int = prim::Constant[value=3]()
- %2 : int = prim::Constant[value=4]()
- %3 : int = prim::Constant[value=6]()
- %4 : int = prim::Constant[value=0]()
- %5 : Device = prim::Constant[value="cpu"]()
- %6 : int = prim::Constant[value=1]()
- %7 : int[] = prim::ListConstruct(%2, %1)
- %8 : Tensor = aten::zeros(%7, %3, %4, %5)
- %9 : Tensor = aten::mm(%x, %8)
- %10 : Tensor = aten::add(%9, %6, %6)
- return (%10)
+++ /dev/null
-graph(%x : Tensor,
- %1 : Tensor,
- %2 : Tensor):
- %3 : Tensor = aten::mm(%x, %1)
- %4 : Tensor = aten::mm(%3, %2)
- return (%4)
+++ /dev/null
-graph(%0 : Double(3, 4)):
- %1 : Double(4, 3) = prim::Constant[value=<Tensor>](), scope: ScriptMod
- %2 : Double(*, *) = aten::mm(%0, %1), scope: ScriptMod
- %3 : Double() = prim::Constant[value={1}]()
- %4 : int = prim::Constant[value=1]()
- %5 : Double(3, 3) = aten::add(%2, %3, %4)
- return (%5)
+++ /dev/null
-graph(%x : Double(3, 4),
- %1 : Double(4, 5),
- %2 : Double(5, 7)):
- %3 : Double(3, 5) = aten::mm(%x, %1), scope: TracedModule
- %4 : Double(*, *) = aten::mm(%3, %2), scope: TracedModule/ScriptMod[mod]
- %5 : Double() = prim::Constant[value={1}](), scope: TracedModule
- %6 : int = prim::Constant[value=1](), scope: TracedModule
- %7 : Double(3, 7) = aten::add(%4, %5, %6), scope: TracedModule
- return (%7)
+++ /dev/null
-graph(%x : Tensor):
- %2 : int = prim::Constant[value=1]()
- %1 : Double(3, 4) = aten::neg(%x)
- %4 : Tensor = aten::add(%1, %2, %2)
- return (%4)
+++ /dev/null
-graph(%x : Double(3, 4),
- %1 : Double(4, 5)):
- %2 : Double(3, 5) = aten::mm(%x, %1), scope: TracedModule
- %3 : Double(*, *) = aten::neg(%2), scope: TracedModule/traced_fn
- return (%3)
+++ /dev/null
-graph(%0 : Double(3, 4)):
- %1 : Double(*, *) = aten::neg(%0), scope: traced_fn1
- %2 : Long() = prim::Constant[value={1}]()
- %3 : int = prim::Constant[value=1]()
- %4 : Double(3, 4) = aten::add(%1, %2, %3)
- return (%4)
+++ /dev/null
-graph(%x : Tensor):
- %6 : Device = prim::Constant[value="cpu"](), scope: TracedModule
- %5 : int = prim::Constant[value=0](), scope: TracedModule
- %4 : int = prim::Constant[value=7](), scope: TracedModule
- %2 : int = prim::Constant[value=3](), scope: TracedModule
- %1 : int = prim::Constant[value=4](), scope: TracedModule
- %9 : int = prim::Constant[value=1]()
- %3 : int[] = prim::ListConstruct(%1, %2), scope: TracedModule
- %7 : Double(4, 3) = aten::zeros(%3, %4, %5, %6), scope: TracedModule
- %8 : Double(3, 3) = aten::mm(%x, %7), scope: TracedModule
- %11 : Tensor = aten::add(%8, %9, %9)
- return (%11)
+++ /dev/null
-graph(%0 : Double(3, 4)):
- %1 : Double(4, 3) = prim::Constant[value=<Tensor>](), scope: TracedModule[TracedModule]
- %2 : Double(*, *) = aten::mm(%0, %1), scope: TracedModule
- %3 : Double() = prim::Constant[value={1}]()
- %4 : int = prim::Constant[value=1]()
- %5 : Double(3, 3) = aten::add(%2, %3, %4)
- return (%5)
+++ /dev/null
-graph(%x : Double(3, 4),
- %1 : Double(4, 5),
- %2 : Double(5, 7)):
- %3 : Double(3, 5) = aten::mm(%x, %1), scope: TracedModule
- %4 : Double(*, *) = aten::mm(%3, %2), scope: TracedModule1
- %5 : Double() = prim::Constant[value={1}](), scope: TracedModule
- %6 : int = prim::Constant[value=1](), scope: TracedModule
- %7 : Double(3, 7) = aten::add(%4, %5, %6), scope: TracedModule
- return (%7)
+++ /dev/null
-graph(%x : Tensor,
- %1 : Tensor):
- %2 : Tensor = aten::mm(%x, %1)
- %3 : Double(3, 3) = aten::neg(%2)
- return (%3)
+++ /dev/null
-graph(%x : Tensor,
- %1 : Tensor,
- %3 : Tensor):
- %2 : Tensor = aten::mm(%x, %1)
- %4 : Double(3, 5) = aten::mm(%2, %3), scope: TracedMod
- return (%4)
# Note: the parameter self.param from the Python module is inlined
# into the graph
- self.assertExpected(canonical(traced_fn.graph))
+ self.assertTrue(len(list(traced_fn.graph.inputs())) == 1)
+ FileCheck().check("aten::mm").check("aten::add").run(str(traced_fn.graph))
def test_call_traced_fn_from_tracing_fn(self):
@_trace(torch.rand(3, 4))
def traced_fn(x):
return traced_fn1(x) + 1
- self.assertExpected(canonical(traced_fn.graph))
+ FileCheck().check("aten::neg").check_same("scope: traced_fn1").check("aten::add") \
+ .run(str(traced_fn.graph))
def test_call_traced_mod_from_tracing_fn(self):
class TracedModule(torch.nn.Module):
# Note: the parameter self.param from the Python module is inlined
# into the graph
- self.assertExpected(canonical(traced_fn.graph))
+ FileCheck().check("prim::Constant[value=<Tensor>]").check("aten::mm") \
+ .check("aten::add").run(str(traced_fn.graph))
def test_call_script_fn_from_tracing_fn(self):
@torch.jit.script
def traced_fn(x):
return script_fn(x) + 1
- self.assertExpected(canonical(traced_fn.graph))
+ FileCheck().check("aten::neg").check("aten::add").run(str(traced_fn.graph))
def test_call_script_mod_from_tracing_fn(self):
- class ScriptMod(torch.jit.ScriptModule):
- def __init__(self):
- super(ScriptMod, self).__init__()
- self.param = torch.nn.Parameter(torch.rand(4, 3), requires_grad=False)
+ with self.disableModuleHook():
+ class ScriptMod(torch.jit.ScriptModule):
+ def __init__(self):
+ super(ScriptMod, self).__init__()
+ self.param = torch.nn.Parameter(torch.rand(3, 4), requires_grad=False)
- @torch.jit.script_method
- def forward(self, x):
- return torch.mm(x, self.param)
+ @torch.jit.script_method
+ def forward(self, x):
+ for _i in range(4):
+ x += self.param
+ return x
- sm = ScriptMod()
+ sm = ScriptMod()
- @_trace(torch.rand(3, 4))
- def traced_fn(x):
- return sm(x) + 1.0
+ @_trace(torch.rand(3, 4))
+ def traced_fn(x):
+ return sm(x) + 1.0
- self.assertExpected(canonical(traced_fn.graph))
+ # parameter turns into constant and loop is preserved
+ FileCheck().check("prim::Constant[value=<Tensor>]").check("Loop") \
+ .run(str(traced_fn.graph))
def test_call_python_fn_from_traced_module(self):
def python_fn(x):
# Note: parameter self.param from the traced module should appear as
# an input to the graph and the neg op from the Python function should
# be properly inlined
- self.assertExpected(canonical(tm.graph))
+ self.assertTrue(len(list(tm.graph.inputs())) == 2)
+ FileCheck().check("aten::neg").check("aten::mm").run(str(tm.graph))
def test_call_python_mod_from_traced_module(self):
class PythonModule(torch.nn.Module):
# Note: the parameters from both modules should appear in the flattened
# inputs of the graph. All ops from both modules should be inlined.
- self.assertExpected(canonical(tm.graph))
+ self.assertTrue(len(list(tm.graph.inputs())) == 3)
+ FileCheck().check_not("value=<Tensor>").check_count("aten::mm", 2).check("aten::add") \
+ .run(str(tm.graph))
def test_call_traced_fn_from_traced_module(self):
@_trace(torch.rand(3, 4))
tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
# Note: neg op from the traced function should be properly inlined
- self.assertExpected(canonical(tm.graph))
+ FileCheck().check("aten::mm").check_same("scope: TracedModule") \
+ .check_next("aten::neg").check("scope: TracedModule/traced_fn") \
+ .run(str(tm.graph))
def test_trace_hierarchy(self):
# Test that we preserve the module hierarchy for a ScriptModule
# Note: the parameters from both modules should appear in the flattened
# inputs of the graph. All ops from both modules should be inlined.
- self.assertExpected(canonical(tm.graph))
+ self.assertTrue(len(list(tm.graph.inputs())) == 3)
+ FileCheck().check_count("aten::mm", 2).check("aten::add").run(str(tm.graph))
def test_call_script_fn_from_traced_module(self):
@torch.jit.script
tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
# Note: neg op from the script function should be properly inlined
- self.assertExpected(canonical(tm.graph))
+ FileCheck().check("aten::mm").check("aten::neg").run(str(tm.graph))
def test_call_script_module_from_traced_module(self):
class ScriptMod(torch.jit.ScriptModule):
# Note: the parameters from both modules should appear in the flattened
# inputs of the graph. All ops from both modules should be inlined.
- self.assertExpected(canonical(tm.graph))
+ self.assertTrue(len(list(tm.graph.inputs())) == 3)
+ FileCheck().check_count("aten::mm", 2).check("aten::add").run(str(tm.graph))
def test_call_python_fn_from_script_fn(self):
def python_fn(x):
# Note: the call to python_fn appears as `^python_fn()` and is called
# as a PythonOp in the interpreter
- self.assertExpected(canonical(script_fn.graph))
+ a = torch.tensor(1)
+ self.assertEqual(script_fn(a), torch.tensor(0))
+ FileCheck().check("python_fn").run(str(script_fn.graph))
def test_call_python_mod_from_script_fn(self):
class PythonModule(torch.nn.Module):
# Note: call to pm(x) appears as ^<python_value>() in the trace.
# Parameters are NOT inlined.
- self.assertExpected(str(script_fn.graph))
+ FileCheck().check("python_value").check("aten::add").run(str(script_fn.graph))
def test_call_traced_fn_from_script_fn(self):
@_trace(torch.rand(3, 4))
# Note: the neg op from traced_fn should be properly inlined into the
# script function's graph
- self.assertExpected(str(script_fn.graph))
+ FileCheck().check("aten::neg").check("aten::add").run(str(script_fn.graph))
def test_call_traced_mod_from_script_fn(self):
class TracedModule(torch.nn.Module):
def script_fn(x):
return tm(x) + 1
- self.assertExpected(str(script_fn.graph))
+ FileCheck().check("aten::zeros").check_same("scope: TracedModule").check("aten::mm") \
+ .check("aten::add").run(str(script_fn.graph))
def test_call_script_fn_from_script_fn(self):
@torch.jit.script
# Note: the neg op from script_fn1 should be properly inlined into the
# graph of script_fn
- self.assertExpected(canonical(script_fn.graph))
+ FileCheck().check("aten::neg").run(str(script_fn.graph))
def test_call_script_mod_from_script_fn(self):
class ScriptMod(torch.jit.ScriptModule):
def script_fn(x):
return sm(x) + 1
- self.assertExpected(canonical(script_fn.graph))
+ FileCheck().check("zeros").check("aten::mm").check("add").run(str(script_fn.graph))
def test_call_python_fn_from_script_module(self):
def python_fn(x):
return python_fn(torch.mm(x, self.param))
sm = ScriptMod()
- self.assertExpected(str(sm.__getattr__('forward').graph))
+ FileCheck().check("aten::mm").check("python_fn") \
+ .run(str(sm.__getattr__('forward').graph))
def test_call_python_mod_from_script_module(self):
class PythonMod(torch.nn.Module):
sm = ScriptMod()
# Note: the call into PythonMod appears as ^<python_value>(). Parameters
# are NOT inlined
- self.assertExpected(str(sm.graph))
+ FileCheck().check("aten::mm").check("python_value").run(str(sm.graph))
def test_call_tracing_fn_from_script_module(self):
@_trace(torch.rand(3, 3))
return traced_fn(torch.mm(x, self.param))
sm = ScriptMod()
- self.assertExpected(str(sm.__getattr__('forward').graph))
+ FileCheck().check("aten::mm").check("aten::neg").run(str(sm.__getattr__('forward').graph))
def test_call_tracing_mod_from_script_module(self):
class TracedMod(torch.nn.Module):
# Note: the parameters from both modules should appear in the flattened
# input list to the graph. The mm op from TracedMod should be properly
# inlined
- self.assertExpected(str(sm.graph))
+ self.assertTrue(len(list(sm.graph.inputs())) == 3)
+ FileCheck().check("aten::mm").check("aten::mm").run(str(sm.graph))
def test_call_script_fn_from_script_module(self):
@torch.jit.script
return script_fn(torch.mm(x, self.param))
sm = ScriptMod()
- self.assertExpected(canonical(sm.__getattr__('forward').graph))
+ graph = (sm.__getattr__('forward').graph)
+ FileCheck().check("aten::mm").check("aten::neg").run(str(graph))
def test_call_script_mod_from_script_module(self):
class ScriptMod1(torch.jit.ScriptModule):
# Note: the parameters from both modules should appear in the flattened
# input list to the graph. The mm op from ScriptMod1 should be properly
# inlined
- self.assertExpected(canonical(sm.graph))
+ # 3 % values in graph input lists, two mms in body
+ FileCheck().check_count('%', 3).check(":").check_count("mm", 2).run(str(sm.graph))
def test_module_with_params_called_fails(self):
with self.assertRaisesRegex(RuntimeError, "Attempted to inline a Module with parameters. Stateful "
a /= b
a *= b
return a, b
- self.checkScript(foo, (torch.rand(3), torch.rand(3)), check_expected=True)
+ self.checkScript(foo, (torch.rand(3), torch.rand(3)))
def test_pass(self):
def foo(x):
bool has_run = false;
+ friend std::ostream& operator<<(std::ostream& out, const FileCheckImpl& fc);
+
private:
void doCheckNot(
const std::vector<Check>& nots,
FileCheck::FileCheck() : fcImpl(new FileCheckImpl()){};
+std::ostream& operator<<(std::ostream& out, const FileCheckImpl& fc) {
+ out << "FileCheck checks:\n";
+ for (const Check& c : fc.checks) {
+ out << "\t" << c << "\n";
+ }
+ return out;
+};
+
FileCheck::~FileCheck() {
if (!fcImpl->has_run) {
std::cout << "You have not run this instance of FileCheck!\n";
+ std::cout << *fcImpl;
}
fcImpl.reset();
};