Remove remaining IR Expect files (#17886)
authorElias Ellison <eellison@fb.com>
Tue, 12 Mar 2019 00:23:27 +0000 (17:23 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 12 Mar 2019 00:32:19 +0000 (17:32 -0700)
Summary:
Last batch of IR expect files removed. Includes some removal of expect files that are no longer used.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17886

Differential Revision: D14414435

Pulled By: eellison

fbshipit-source-id: 0bfd7ce66ac2f72a57f15f45ebd60b95e80b6c16

19 files changed:
test/cpp/jit/gtest.cpp
test/cpp/jit/no-gtest.cpp
test/cpp/jit/test_constant_pooling.h [new file with mode: 0644]
test/cpp/jit/test_irparser.h
test/cpp/jit/test_misc.h
test/cpp/jit/test_netdef_converter.h
test/expect/TestJit.test_cpp_cuda.expect [deleted file]
test/expect/TestJit.test_python_ir.expect [deleted file]
test/expect/TestScript.test_constant_pooling.expect [deleted file]
test/expect/TestScript.test_if_for_in_range.expect [deleted file]
test/expect/TestScript.test_loop_unroll_unused_counter.expect [deleted file]
test/expect/TestScript.test_math_tensor_number.expect [deleted file]
test/expect/TestScript.test_string_cu.expect [deleted file]
test/test_jit.py
torch/csrc/jit/init.cpp
torch/csrc/jit/irparser.cpp
torch/csrc/jit/script/init.cpp
torch/csrc/jit/testing/file_check.cpp
torch/csrc/jit/testing/file_check.h

index 8f6c2a8..0732729 100644 (file)
@@ -1,6 +1,7 @@
 #include <gtest/gtest.h>
 
 #include <test/cpp/jit/test_alias_analysis.h>
+#include <test/cpp/jit/test_constant_pooling.h>
 #include <test/cpp/jit/test_irparser.h>
 #include <test/cpp/jit/test_misc.h>
 #include <test/cpp/jit/test_netdef_converter.h>
@@ -38,6 +39,7 @@ JIT_TEST(WriteTracking)
 JIT_TEST(Wildcards)
 JIT_TEST(MemoryDAG)
 JIT_TEST(IRParser)
+JIT_TEST(ConstantPooling)
 
 JIT_TEST(NetDefConverter)
 
index 00b6892..a2c950f 100644 (file)
@@ -1,4 +1,5 @@
 #include <test/cpp/jit/test_alias_analysis.h>
+#include <test/cpp/jit/test_constant_pooling.h>
 #include <test/cpp/jit/test_irparser.h>
 #include <test/cpp/jit/test_misc.h>
 #include <test/cpp/jit/test_netdef_converter.h>
@@ -8,20 +9,19 @@
 
 namespace torch {
 namespace jit {
-std::string runJITCPPTests() {
-  std::stringstream out;
+void runJITCPPTests() {
   testNoneSchemaMatch();
   testAutogradProfiler();
   testADFormulas();
   testArgumentSpec();
   testAttributes();
-  testBlocks(out);
+  testBlocks();
   testCodeTemplate();
   testControlFlow();
-  testCreateAutodiffSubgraphs(out);
+  testCreateAutodiffSubgraphs();
   testCustomOperators();
-  testDifferentiate(out);
-  testDifferentiateWithRequiresGrad(out);
+  testDifferentiate();
+  testDifferentiateWithRequiresGrad();
   testDynamicDAG();
   testEvalModeForLoadedModule();
   testFromQualString();
@@ -42,9 +42,9 @@ std::string runJITCPPTests() {
   testWriteTracking();
   testWildcards();
   testMemoryDAG();
-  testNetDefConverter(out);
-  testIRParser(out);
-  return out.str();
+  testNetDefConverter();
+  testIRParser();
+  testConstantPooling();
 }
 
 } // namespace jit
diff --git a/test/cpp/jit/test_constant_pooling.h b/test/cpp/jit/test_constant_pooling.h
new file mode 100644 (file)
index 0000000..c77df2f
--- /dev/null
@@ -0,0 +1,87 @@
+#pragma once
+
+#include <torch/csrc/jit/ir.h>
+#include <torch/csrc/jit/irparser.h>
+#include <torch/csrc/jit/passes/constant_pooling.h>
+#include <torch/csrc/jit/passes/constant_propagation.h>
+#include <torch/csrc/jit/testing/file_check.h>
+#include "test/cpp/jit/test_base.h"
+
+#include <sstream>
+#include <string>
+
+namespace torch {
+namespace jit {
+
+void testConstantPooling() {
+  {
+    auto graph = std::make_shared<Graph>();
+    script::parseIR(
+        R"IR(
+graph():
+  %8 : int = prim::Constant[value=1]()
+  %10 : int = prim::Constant[value=1]()
+  return (%8, %10)
+  )IR",
+        &*graph);
+    ConstantPooling(graph);
+    testing::FileCheck()
+        .check_count("prim::Constant", 1, /*exactly*/ true)
+        ->run(*graph);
+  }
+  {
+    auto graph = std::make_shared<Graph>();
+    script::parseIR(
+        R"IR(
+graph(%cond : Tensor):
+  %a : string = prim::Constant[value="bcd"]()
+  %3 : bool = prim::Bool(%cond)
+  %b : string = prim::If(%3)
+    block0():
+      %b.1 : string = prim::Constant[value="abc"]()
+      -> (%b.1)
+    block1():
+      %b.2 : string = prim::Constant[value="abc"]()
+      -> (%b.2)
+  %7 : (string, string) = prim::TupleConstruct(%a, %b)
+  return (%7)
+  )IR",
+        &*graph);
+    ConstantPooling(graph);
+    testing::FileCheck()
+        .check_count("prim::Constant[value=\"abc\"]", 1, /*exactly*/ true)
+        ->check_count("prim::Constant[value=\"bcd\"]", 1, /*exactly*/ true)
+        ->run(*graph);
+  }
+  {
+    auto graph = std::make_shared<Graph>();
+    script::parseIR(
+        R"IR(
+graph():
+  %2 : int = prim::Constant[value=2]()
+  %1 : int = prim::Constant[value=1]()
+  %5 : int? = prim::Constant()
+  %7 : Device? = prim::Constant()
+  %10 : int = prim::Constant[value=6]()
+  %3 : int[] = prim::ListConstruct(%1, %2)
+  %x : Tensor = aten::tensor(%3, %5, %7)
+  %y : Tensor = aten::tensor(%3, %10, %7)
+  %9 : int[] = prim::ListConstruct(%1, %2)
+  %z : Tensor = aten::tensor(%9, %10, %7)
+  %14 : (Tensor, Tensor) = prim::TupleConstruct(%x, %y)
+  return (%14)
+  )IR",
+        &*graph);
+    // three tensors created - two different devices among the three
+    // don't have good support for parsing tensor constants
+    ConstantPropagation(graph);
+    ConstantPooling(graph);
+    testing::FileCheck()
+        .check_count("Float(2) = prim::Constant", 1, /*exactly*/ true)
+        ->check_count("Long(2) = prim::Constant", 1, /*exactly*/ true)
+        ->run(*graph);
+  }
+}
+
+} // namespace jit
+} // namespace torch
index b12da1f..c8c3eea 100644 (file)
@@ -39,7 +39,7 @@ static void checkRoundtrip(const std::string& s) {
   AT_ASSERT(original == parsed);
 }
 
-void testIRParser(std::ostream& out = std::cout) {
+void testIRParser() {
   {
     auto graph = std::make_shared<Graph>();
     script::parseIR(
index 52077ee..cdf5642 100644 (file)
@@ -2,6 +2,7 @@
 
 #include "test/cpp/jit/test_base.h"
 
+#include <torch/csrc/jit/passes/canonicalize.h>
 #include "ATen/core/interned_strings.h"
 #include "torch/csrc/autograd/generated/variable_factories.h"
 #include "torch/csrc/autograd/variable.h"
@@ -34,6 +35,7 @@
 #include "torch/csrc/autograd/engine.h"
 #include "torch/csrc/autograd/variable.h"
 
+#include <torch/csrc/jit/testing/file_check.h>
 #include "ATen/core/ivalue.h"
 #include "torch/csrc/jit/graph_executor.h"
 #include "torch/csrc/jit/script/compiler.h"
@@ -831,7 +833,7 @@ void testADFormulas() {
   }
 }
 
-void testDifferentiate(std::ostream& out = std::cout) {
+void testDifferentiate() {
   auto graph = std::make_shared<Graph>();
   at::ScalarType s = at::ScalarType::Float;
   auto type = CompleteTensorType::create(s, at::kCPU, {2, 3, 4}, {12, 4, 1});
@@ -852,13 +854,19 @@ void testDifferentiate(std::ostream& out = std::cout) {
   ASSERT_EQ(grad_spec.df_input_captured_outputs, expected_captured_outputs);
   ASSERT_EQ(grad_spec.df_input_vjps, expected_input_vjps);
   ASSERT_EQ(grad_spec.df_output_vjps, expected_output_vjps);
-  out << "testDifferentiate\n";
-  out << *grad_spec.f;
-  out << *grad_spec.df;
-  out << "\n";
+  testing::FileCheck()
+      .check_count("aten::mul", 2)
+      ->check("aten::size")
+      ->check("aten::add")
+      ->run(*grad_spec.f);
+  testing::FileCheck()
+      .check("prim::GradOf[name=\"aten::add\"]")
+      ->check_count("prim::GradOf[name=\"aten::mul\"]", 2)
+      ->check_count("AutogradAdd", 2)
+      ->run(*grad_spec.df);
 }
 
-void testDifferentiateWithRequiresGrad(std::ostream& out = std::cout) {
+void testDifferentiateWithRequiresGrad() {
   // Build up a fake graph
   auto graph = std::make_shared<Graph>();
   auto a = SymbolicVariable::asNewInput(*graph);
@@ -884,10 +892,17 @@ void testDifferentiateWithRequiresGrad(std::ostream& out = std::cout) {
   ASSERT_EQ(grad_spec.df_input_captured_outputs, std::vector<size_t>({2, 3}));
   ASSERT_EQ(grad_spec.df_input_vjps, expected_input_vjps);
   ASSERT_EQ(grad_spec.df_output_vjps, expected_output_vjps);
-  out << "testDifferentiateWithRequiresGrad\n";
-  out << *grad_spec.f;
-  out << *grad_spec.df;
-  out << "\n";
+  testing::FileCheck()
+      .check("aten::mul")
+      ->check_count("aten::add", 2)
+      ->check("aten::mul")
+      ->check("aten::size")
+      ->check("aten::add")
+      ->run(*grad_spec.f);
+
+  testing::FileCheck()
+      .check_count("prim::GradOf[name=\"aten::mul\"]", 1, /*exactly*/ true)
+      ->run(*grad_spec.df);
 }
 
 void testRegisterFusionCachesKernel(std::ostream& out = std::cout) {
@@ -937,11 +952,18 @@ void testRegisterFusionCachesKernel(std::ostream& out = std::cout) {
   ASSERT_EQ(second_key, expected_key);
 }
 
-void testCreateAutodiffSubgraphs(std::ostream& out = std::cout) {
+void testCreateAutodiffSubgraphs() {
   auto graph = build_lstm();
   CreateAutodiffSubgraphs(graph, /*threshold=*/2);
-  out << "testCreateAutodiffSubgraphs\n";
-  out << *graph << "\n";
+  // all of the ops are within the DifferentiableGraph
+  testing::FileCheck()
+      .check_not("aten::mm")
+      ->check_not("aten::sigmoid")
+      ->check_not("aten::tanh")
+      ->check_not("aten::mul")
+      ->check("DifferentiableGraph")
+      ->check_next("return")
+      ->run(*graph);
 }
 
 void testSubgraphUtils() {
@@ -1088,11 +1110,13 @@ void testGraphExecutor() {
 }
 
 void testBlocks(std::ostream& out = std::cout) {
-  Graph g;
-  auto a = Var::asNewInput(g, "a");
-  auto b = Var::asNewInput(g, "b");
+  auto g = std::make_shared<Graph>();
+  // auto g = *graph;
+  auto a = Var::asNewInput(*g, "a");
+  auto b = Var::asNewInput(*g, "b");
   auto c = a + b;
-  auto r = g.appendNode(g.create(prim::If, {Var::asNewInput(g, "c").value()}));
+  auto r =
+      g->appendNode(g->create(prim::If, {Var::asNewInput(*g, "c").value()}));
   auto then_block = r->addBlock();
   auto else_block = r->addBlock();
   {
@@ -1106,15 +1130,32 @@ void testBlocks(std::ostream& out = std::cout) {
     auto e = d + c;
     else_block->registerOutput(e.value());
   }
-  g.registerOutput((Var(r->output()) + c).value());
-  g.lint();
-  out << "testBlocks\n" << g << "\n";
+  g->registerOutput((Var(r->output()) + c).value());
+  g->lint();
+  testing::FileCheck()
+      .check("add")
+      ->check("prim::If")
+      ->check("block0")
+      ->check("aten::add")
+      ->check("block1")
+      ->check_count("aten::add", 3)
+      ->run(*g);
   r->eraseBlock(0);
-  out << g << "\n";
-  g.lint();
+  testing::FileCheck()
+      .check("add")
+      ->check("prim::If")
+      ->check("block0")
+      ->check_not("block")
+      ->run(*g);
+  g->lint();
   // test recursive copy of blocks works
-  auto g2 = g.copy();
-  out << *g2 << "\n";
+  auto g2 = g->copy();
+  testing::FileCheck()
+      .check("add")
+      ->check("prim::If")
+      ->check("block0")
+      ->check_not("block")
+      ->run(*g2);
 }
 
 const auto cf_examples = R"JIT(
@@ -1508,9 +1549,7 @@ void testSchemaParser() {
         Symbol::fromQualString("alias::b"),
     };
     const auto expectedAfter = std::unordered_set<Symbol>{
-        Symbol::fromQualString("alias::b"),
-        Symbol::fromQualString("alias::c")
-    };
+        Symbol::fromQualString("alias::b"), Symbol::fromQualString("alias::c")};
     ASSERT_TRUE(containedAliasInfo.beforeSets() == expectedBefore);
     ASSERT_TRUE(containedAliasInfo.afterSets() == expectedAfter);
     ASSERT_FALSE(containedAliasInfo.isWrite());
@@ -1849,7 +1888,6 @@ void testNoneSchemaMatch() {
   // checking that constant propagation ran wo/failure
   AT_ASSERT(std::distance(nodes.begin(), nodes.end()) == 1);
 }
-
 } // namespace
 } // namespace jit
 } // namespace torch
index 3dba9b8..0ac7d92 100644 (file)
@@ -24,7 +24,7 @@ static caffe2::OperatorDef createOperator(
   return op;
 }
 
-void testNetDefConverter(std::ostream& out = std::cout) {
+void testNetDefConverter() {
   {
     // Check a simple net conversion back and forth.
 
diff --git a/test/expect/TestJit.test_cpp_cuda.expect b/test/expect/TestJit.test_cpp_cuda.expect
deleted file mode 100644 (file)
index e15ab3c..0000000
+++ /dev/null
@@ -1,169 +0,0 @@
-testBlocks
-graph(%a : Tensor,
-      %b : Tensor,
-      %c : Tensor):
-  %2 : int = prim::Constant[value=1]()
-  %3 : Tensor = aten::add(%a, %b, %2)
-  %5 : Tensor = prim::If(%c)
-    block0():
-      %6 : int = prim::Constant[value=1]()
-      %7 : Tensor = aten::add(%3, %3, %6)
-      -> (%7)
-    block1():
-      %8 : int = prim::Constant[value=1]()
-      %9 : Tensor = aten::add(%b, %3, %8)
-      %10 : int = prim::Constant[value=1]()
-      %11 : Tensor = aten::add(%9, %3, %10)
-      -> (%11)
-  %12 : int = prim::Constant[value=1]()
-  %13 : Tensor = aten::add(%5, %3, %12)
-  return (%13)
-
-graph(%a : Tensor,
-      %b : Tensor,
-      %c : Tensor):
-  %2 : int = prim::Constant[value=1]()
-  %3 : Tensor = aten::add(%a, %b, %2)
-  %5 : Tensor = prim::If(%c)
-    block0():
-      %8 : int = prim::Constant[value=1]()
-      %9 : Tensor = aten::add(%b, %3, %8)
-      %10 : int = prim::Constant[value=1]()
-      %11 : Tensor = aten::add(%9, %3, %10)
-      -> (%11)
-  %12 : int = prim::Constant[value=1]()
-  %13 : Tensor = aten::add(%5, %3, %12)
-  return (%13)
-
-graph(%a : Tensor,
-      %b : Tensor,
-      %c : Tensor):
-  %3 : int = prim::Constant[value=1]()
-  %4 : Tensor = aten::add(%a, %b, %3)
-  %5 : Tensor = prim::If(%c)
-    block0():
-      %6 : int = prim::Constant[value=1]()
-      %7 : Tensor = aten::add(%b, %4, %6)
-      %8 : int = prim::Constant[value=1]()
-      %9 : Tensor = aten::add(%7, %4, %8)
-      -> (%9)
-  %10 : int = prim::Constant[value=1]()
-  %11 : Tensor = aten::add(%5, %4, %10)
-  return (%11)
-
-testCreateAutodiffSubgraphs
-graph(%0 : Tensor,
-      %1 : Tensor,
-      %2 : Tensor,
-      %3 : Tensor,
-      %4 : Tensor):
-  %7 : int = prim::Constant[value=1]()
-  %23 : Tensor, %24 : Tensor = prim::DifferentiableGraph_0(%2, %1, %4, %0, %3)
-  return (%23, %24)
-with prim::DifferentiableGraph_0 = graph(%13 : Tensor,
-      %32 : Tensor,
-      %33 : Tensor,
-      %35 : Tensor,
-      %36 : Tensor):
-  %37 : Tensor = aten::mm(%35, %36)
-  %34 : Tensor = aten::mm(%32, %33)
-  %30 : int = prim::Constant[value=1]()
-  %31 : Tensor = aten::add(%37, %34, %30)
-  %24 : Tensor, %25 : Tensor, %26 : Tensor, %27 : Tensor = prim::ConstantChunk[chunks=4, dim=1](%31)
-  %22 : Tensor = aten::sigmoid(%24)
-  %20 : Tensor = aten::sigmoid(%27)
-  %18 : Tensor = aten::tanh(%26)
-  %16 : Tensor = aten::sigmoid(%25)
-  %14 : Tensor = aten::mul(%16, %13)
-  %11 : Tensor = aten::mul(%22, %18)
-  %8 : Tensor = aten::add(%14, %11, %30)
-  %4 : Tensor = aten::tanh(%8)
-  %2 : Tensor = aten::mul(%20, %4)
-  return (%2, %8)
-
-testDifferentiate
-graph(%0 : Float(2, 3, 4),
-      %1 : Float(2, 3, 4)):
-  %2 : Float(2, 3, 4) = aten::mul(%0, %1)
-  %3 : Float(2, 3, 4) = aten::mul(%2, %0)
-  %4 : int = prim::Constant[value=1]()
-  %7 : int[] = aten::size(%3)
-  %5 : Float(2, 3, 4) = aten::add(%3, %1, %4)
-  return (%5, %2, %7)
-graph(%0 : Float(2, 3, 4),
-      %1 : Float(2, 3, 4),
-      %2 : Float(2, 3, 4),
-      %3 : Float(2, 3, 4),
-      %4 : Float(2, 3, 4),
-      %5 : int[]):
-  %7 : int = prim::Constant[value=1]()
-  %6 : int[] = aten::size(%3)
-  %8 : Tensor, %9 : Tensor = prim::GradOf[name="aten::add"](%0)
-    block0():
-      %10 : Tensor = aten::_grad_sum_to_size(%0, %5)
-      %11 : Float(2, 3, 4) = aten::mul(%0, %7)
-      %12 : Tensor = aten::_grad_sum_to_size(%11, %6)
-      -> (%10, %12)
-  %grad_self.2 : Tensor, %grad_other.2 : Tensor = prim::GradOf[name="aten::mul"](%8)
-    block0():
-      %15 : Tensor = aten::mul(%8, %2)
-      %16 : int[] = aten::size(%4)
-      %grad_self.1 : Tensor = aten::_grad_sum_to_size(%15, %16)
-      %18 : Tensor = aten::mul(%8, %4)
-      %19 : int[] = aten::size(%2)
-      %grad_other.1 : Tensor = aten::_grad_sum_to_size(%18, %19)
-      -> (%grad_self.1, %grad_other.1)
-  %21 : Tensor = prim::AutogradAdd(%1, %grad_self.2)
-  %grad_self : Tensor, %grad_other : Tensor = prim::GradOf[name="aten::mul"](%21)
-    block0():
-      %24 : Tensor = aten::mul(%21, %3)
-      %25 : int[] = aten::size(%2)
-      %grad_self.3 : Tensor = aten::_grad_sum_to_size(%24, %25)
-      %27 : Tensor = aten::mul(%21, %2)
-      %28 : int[] = aten::size(%3)
-      %grad_other.3 : Tensor = aten::_grad_sum_to_size(%27, %28)
-      -> (%grad_self.3, %grad_other.3)
-  %30 : Tensor = prim::AutogradAdd(%grad_other.2, %grad_self)
-  %31 : Tensor = prim::AutogradAdd(%9, %grad_other)
-  return (%30, %31)
-
-testDifferentiateWithRequiresGrad
-graph(%0 : Float(*),
-      %1 : Float(*)):
-  %2 : Float(*) = aten::mul(%1, %1)
-  %3 : int = prim::Constant[value=1]()
-  %4 : Float(*) = aten::add(%2, %1, %3)
-  %6 : Float(*) = aten::add(%4, %0, %3)
-  %7 : Float(*) = aten::mul(%6, %0)
-  %11 : int[] = aten::size(%7)
-  %9 : Float(*) = aten::add(%7, %1, %3)
-  return (%4, %9, %6, %11)
-graph(%0 : Float(*),
-      %1 : Float(*),
-      %2 : Float(*),
-      %3 : Float(*),
-      %4 : int[]):
-  %6 : int = prim::Constant[value=1]()
-  %5 : int[] = aten::size(%2)
-  %7 : Tensor = prim::GradOf[name="aten::add"](%0)
-    block0():
-      %8 : Tensor = aten::_grad_sum_to_size(%0, %4)
-      -> (%8)
-  %grad_self : Tensor, %grad_other : Tensor = prim::GradOf[name="aten::mul"](%7)
-    block0():
-      %11 : Tensor = aten::mul(%7, %2)
-      %12 : int[] = aten::size(%3)
-      %grad_self.1 : Tensor = aten::_grad_sum_to_size(%11, %12)
-      %14 : Tensor = aten::mul(%7, %3)
-      %15 : int[] = aten::size(%2)
-      %grad_other.1 : Tensor = aten::_grad_sum_to_size(%14, %15)
-      -> (%grad_self.1, %grad_other.1)
-  %17 : Tensor = prim::AutogradAdd(%1, %grad_self)
-  %18 : Tensor = prim::GradOf[name="aten::add"](%17)
-    block0():
-      %19 : Tensor = aten::mul(%17, %6)
-      %20 : Tensor = aten::_grad_sum_to_size(%19, %5)
-      -> (%20)
-  %21 : Tensor = prim::AutogradAdd(%grad_other, %18)
-  return (%21)
-
diff --git a/test/expect/TestJit.test_python_ir.expect b/test/expect/TestJit.test_python_ir.expect
deleted file mode 100644 (file)
index b7fa4c1..0000000
+++ /dev/null
@@ -1,9 +0,0 @@
-graph(%0 : Tensor,
-      %1 : Tensor):
-  %2 : int = prim::Constant[value=1]()
-  %3 : Double(1) = aten::add(%0, %1, %2)
-  %4 : Double(1) = aten::mul(%0, %3)
-  %5 : Double(1) = aten::tanh(%4)
-  %6 : Double(1) = aten::sigmoid(%5)
-  %7 : Tensor = prim::TensorTest[a= 1  1  1  1 [ Variable[CPUDoubleType]{2,2} ]]()
-  return (%6)
diff --git a/test/expect/TestScript.test_constant_pooling.expect b/test/expect/TestScript.test_constant_pooling.expect
deleted file mode 100644 (file)
index 9bfb2c0..0000000
+++ /dev/null
@@ -1,29 +0,0 @@
-graph(%cond : Tensor):
-  %1 : int[] = prim::Constant[value=[1]]()
-  %2 : int[] = prim::Constant[value=[0]]()
-  %3 : int = prim::Constant[value=3]()
-  %4 : Float(2) = prim::Constant[value= 4  4 [ Variable[CPUFloatType]{2} ]]()
-  %5 : Float(2) = prim::Constant[value= 1  1 [ Variable[CPUFloatType]{2} ]]()
-  %c.1 : int = prim::Constant[value=0]()
-  %a : int = prim::Constant[value=1]()
-  %d : string = prim::Constant[value="abc"]()
-  %e : string = prim::Constant[value="bcd"]()
-  %10 : int = prim::Constant[value=6]()
-  %11 : Device = prim::Constant[value="cpu"]()
-  %12 : bool = prim::Bool(%cond)
-  %c : int, %y : Tensor = prim::If(%12)
-    block0():
-      -> (%3, %4)
-    block1():
-      %y.2 : Tensor = aten::rand(%2, %10, %c.1, %11)
-      %16 : bool = prim::Bool(%cond)
-      %y.4 : Tensor = prim::If(%16)
-        block0():
-          %y.3 : Tensor = aten::rand(%1, %10, %c.1, %11)
-          -> (%y.3)
-        block1():
-          -> (%y.2)
-       = prim::Print(%d, %e, %d, %5, %y.4, %5)
-      -> (%c.1, %y.4)
-  %19 : (int, int, int, Tensor, Tensor) = prim::TupleConstruct(%a, %3, %c, %5, %y)
-  return (%19)
diff --git a/test/expect/TestScript.test_if_for_in_range.expect b/test/expect/TestScript.test_if_for_in_range.expect
deleted file mode 100644 (file)
index a32202c..0000000
+++ /dev/null
@@ -1,22 +0,0 @@
-graph(%a.1 : Dynamic,
-      %b.1 : Dynamic):
-  %d.1 : Long() = prim::Constant[value={3}]()
-  %3 : Long() = prim::Constant[value={20}]()
-  %4 : Byte() = prim::Constant[value={1}]()
-  %a : Dynamic, %d : Long(), %b : Dynamic = prim::Loop(%3, %4, %a.1, %d.1, %b.1)
-    block0(%_ : Dynamic, %6 : Dynamic, %10 : Long(), %14 : Dynamic):
-      %7 : Long() = prim::Constant[value={10}]()
-      %8 : Dynamic = aten::gt(%6, %7)
-      %a.3 : Dynamic, %b.3 : Dynamic, %d.3 : Long() = prim::If(%8)
-        block0():
-          %9 : Long() = prim::Constant[value={3}]()
-          %a.2 : Dynamic = aten::add[alpha={1}](%9, %10)
-          -> (%a.2, %14, %10)
-        block1():
-          %12 : Long() = prim::Constant[value={3}]()
-          %b.2 : Dynamic = aten::add[alpha={1}](%12, %10)
-          %d.2 : Long() = prim::Constant[value={4}]()
-          -> (%6, %b.2, %d.2)
-      %20 : Byte() = prim::Constant[value={1}]()
-      -> (%20, %a.3, %d.3, %b.3)
-  return (%d)
diff --git a/test/expect/TestScript.test_loop_unroll_unused_counter.expect b/test/expect/TestScript.test_loop_unroll_unused_counter.expect
deleted file mode 100644 (file)
index 4f66287..0000000
+++ /dev/null
@@ -1,26 +0,0 @@
-graph(%x : Tensor):
-  %1 : bool = prim::Constant[value=1]()
-  %y.1 : int = prim::Constant[value=0]()
-  %3 : int = prim::Constant[value=1]()
-  %4 : int = prim::Int(%x)
-  %5 : int = prim::Constant[value=8]()
-  %6 : int = aten::__round_to_zero_floordiv(%4, %5)
-  %7 : int = prim::Constant[value=8]()
-  %8 : int = aten::mul(%6, %7)
-  %9 : int = aten::sub(%4, %8)
-  %y.3 : int = prim::Loop(%6, %1, %y.1)
-    block0(%11 : int, %12 : int):
-      %y.12 : int = aten::add(%12, %3)
-      %y.5 : int = aten::add(%y.12, %3)
-      %y.6 : int = aten::add(%y.5, %3)
-      %y.7 : int = aten::add(%y.6, %3)
-      %y.8 : int = aten::add(%y.7, %3)
-      %y.9 : int = aten::add(%y.8, %3)
-      %y.10 : int = aten::add(%y.9, %3)
-      %y.11 : int = aten::add(%y.10, %3)
-      -> (%1, %y.11)
-  %y : int = prim::Loop(%9, %1, %y.3)
-    block0(%22 : int, %23 : int):
-      %y.4 : int = aten::add(%23, %3)
-      -> (%1, %y.4)
-  return (%y)
diff --git a/test/expect/TestScript.test_math_tensor_number.expect b/test/expect/TestScript.test_math_tensor_number.expect
deleted file mode 100644 (file)
index 13555dd..0000000
+++ /dev/null
@@ -1,5 +0,0 @@
-graph(%x : Tensor):
-  %1 : int = prim::Constant[value=1]()
-  %2 : int = prim::Constant[value=7]()
-  %3 : Tensor = aten::add(%x, %2, %1)
-  return (%3)
diff --git a/test/expect/TestScript.test_string_cu.expect b/test/expect/TestScript.test_string_cu.expect
deleted file mode 100644 (file)
index 0604991..0000000
+++ /dev/null
@@ -1,6 +0,0 @@
-graph(%a : Tensor):
-  %3 : string = prim::Constant[value="aa"]()
-  %1 : string = prim::Constant[value="a\n\tb\n"]()
-  %2 : int = prim::Constant[value=2]()
-   = prim::Print(%a, %1, %2, %3)
-  return (%a)
index 47c408f..2bb4a27 100644 (file)
@@ -1322,9 +1322,7 @@ class TestJit(JitTestCase):
     def test_cpp_cuda(self):
         from cpp.jit import tests_setup
         tests_setup.setup()
-        # rather than rebuild assertExpected in cpp,
-        # just glob all the cpp outputs into one file for now
-        self.assertExpected(torch._C._jit_run_cpp_tests())
+        torch._C._jit_run_cpp_tests()
         tests_setup.shutdown()
 
     def test_batchnorm(self):
@@ -3255,7 +3253,7 @@ class TestScript(JitTestCase):
 a")
                 return a
         ''')
-        self.assertExpected(str(cu.foo.graph))
+        FileCheck().check("aa").check("a\\n\\tb\\n").run(str(cu.foo.graph))
 
     def test_string_ops(self):
         def foo():
@@ -3646,33 +3644,6 @@ a")
         y2 = torch.sum(x, dim=0)
         self.assertEqual(y, y2)
 
-    def test_constant_pooling(self):
-        def func(cond):
-            a = 1
-            b = 4
-            c = 0
-            d = "abc"
-            e = "bcd"
-            f = "abc"
-            x = torch.ones([2])
-            y = x * 4
-            z = torch.ones([2])
-            if bool(cond):
-                c = b - a
-            else:
-                y = torch.rand(0)
-                if bool(cond):
-                    y = torch.rand(1)
-                print(d, e, f, x, y, z)
-            b = b - a
-            return a, b, c, x, y
-
-        self.checkScript(func, torch.tensor([1]))
-        graph = torch.jit.script(func).graph
-        self.run_pass('constant_propagation', graph)
-        self.run_pass('constant_pooling', graph)
-        self.assertExpectedGraph(graph)
-
     def test_constant_pooling_none(self):
         @torch.jit.script
         def typed_nones(a=None, b=None, c=None):
index ad0e0d9..2fcaa68 100644 (file)
@@ -80,11 +80,11 @@ bool loadPythonClasses() {
 } // anonymous namespace
 
 #if defined(_WIN32)
-std::string runJITCPPTests() {
+void runJITCPPTests() {
   AT_ERROR("JIT tests not yet supported on Windows");
 }
 #else
-std::string runJITCPPTests();
+void runJITCPPTests();
 #endif
 
 void initJITBindings(PyObject* module) {
index 0b64bc1..4af5c80 100644 (file)
@@ -79,13 +79,8 @@ void parseIR(const std::string& str, torch::jit::Graph* graph) {
 }
 
 VarWithType IRParser::parseVarWithType() {
-  L.expect('%');
   VarWithType r;
-  if (L.cur().kind == TK_IDENT) {
-    r.name = L.expect(TK_IDENT).text();
-  } else {
-    r.name = L.expect(TK_NUMBER).text();
-  }
+  r.name = parseVar();
   r.type = TensorType::get();
   if (L.nextIf(':')) {
     auto type_alias = type_parser.parseType();
@@ -98,9 +93,16 @@ VarWithType IRParser::parseVarWithType() {
 std::string IRParser::parseVar() {
   L.expect('%');
   if (L.cur().kind == TK_IDENT) {
-    return L.expect(TK_IDENT).text();
+    auto name = L.expect(TK_IDENT).text();
+    if (L.cur().kind == TK_NUMBER) {
+      auto suffix = L.expect(TK_NUMBER).text();
+      AT_ASSERT(suffix[0] == '.');
+      name += suffix;
+    }
+    return name;
+  } else {
+    return L.expect(TK_NUMBER).text();
   }
-  return L.expect(TK_NUMBER).text();
 }
 
 void IRParser::parseOperatorOutputs(std::vector<VarWithType>* outs) {
index 5425219..baa1bd2 100644 (file)
@@ -269,7 +269,6 @@ struct VISIBILITY_HIDDEN ConstantParameterList : public SugaredValue {
   std::shared_ptr<Module> module_;
 };
 
-
 // defines how modules/methods behave inside the script subset.
 // for now this does not have any interaction with python.
 // in the future, we will add the ability to resolve `self.foo` to python
@@ -771,8 +770,7 @@ void initJitScriptBindings(PyObject* module) {
               auto& p = parameters[i];
               py::tuple r(2);
               result[i] = std::make_tuple(
-                  p.key(),
-                  autograd::as_variable_ref(p->slot()->toTensor()));
+                  p.key(), autograd::as_variable_ref(p->slot()->toTensor()));
             }
             return result;
           })
@@ -786,9 +784,7 @@ void initJitScriptBindings(PyObject* module) {
               py::tuple r(3);
               IValue v = *buffer->slot();
               result[i] = std::make_tuple(
-                  buffer.key(),
-                  buffer->type,
-                  toPyObject(std::move(v)));
+                  buffer.key(), buffer->type, toPyObject(std::move(v)));
             }
             return result;
           })
@@ -932,8 +928,7 @@ void initJitScriptBindings(PyObject* module) {
              std::shared_ptr<Module> orig) {
             std::vector<IValue*> member_inputs;
             for (auto& p : params) {
-              NamedIValue* np =
-                  std::get<0>(p)->find_parameter(std::get<1>(p));
+              NamedIValue* np = std::get<0>(p)->find_parameter(std::get<1>(p));
               if (np == nullptr) {
                 np = std::get<0>(p)->find_buffer(std::get<1>(p));
               }
@@ -942,8 +937,7 @@ void initJitScriptBindings(PyObject* module) {
             }
 
             Method* orig_method = orig->find_method(name);
-            m->create_method(
-                name, orig_method->graph()->copy(), member_inputs);
+            m->create_method(name, orig_method->graph()->copy(), member_inputs);
           });
 
   py::class_<Method>(m, "ScriptMethod", py::dynamic_attr())
@@ -1071,7 +1065,14 @@ void initJitScriptBindings(PyObject* module) {
           py::arg("str"),
           py::arg("count"),
           py::arg("exactly") = false)
-      .def("run", &testing::FileCheck::run);
+      .def(
+          "run",
+          [](testing::FileCheck& f, const std::string& str) {
+            return f.run(str);
+          })
+      .def("run", [](testing::FileCheck& f, const Graph& g) {
+        return f.run(g);
+      });
 }
 } // namespace script
 } // namespace jit
index d428699..3af8c5a 100644 (file)
@@ -18,7 +18,7 @@
 #include <sstream>
 #include <string>
 
-#include <torch/csrc/jit/passes/python_print.h>
+#include <torch/csrc/jit/ir.h>
 #include <torch/csrc/jit/testing/file_check.h>
 
 namespace torch {
@@ -303,6 +303,12 @@ void FileCheck::run(const std::string& test_file) {
   fcImpl->run(test_file);
 };
 
+void FileCheck::run(const Graph& graph) {
+  std::stringstream graph_str;
+  graph_str << graph;
+  fcImpl->run(graph_str.str());
+};
+
 FileCheck* FileCheck::check(const std::string& str) {
   fcImpl->addCheck(CHECK, str);
   return this;
index 20f987f..cf80575 100644 (file)
@@ -5,6 +5,9 @@
 
 namespace torch {
 namespace jit {
+
+struct Graph;
+
 namespace testing {
 
 struct FileCheckImpl;
@@ -17,6 +20,9 @@ struct FileCheck {
   // Run FileCheck against test string
   TORCH_API void run(const std::string& test_string);
 
+  // Run FileCheck against dump of graph IR
+  TORCH_API void run(const Graph& graph);
+
   // Checks that the string occurs, starting at the end of the most recent match
   TORCH_API FileCheck* check(const std::string& str);