From 4d2f6f1bbeeb15b57d62365a6b3aa5e36a4d1f9a Mon Sep 17 00:00:00 2001 From: Elias Ellison Date: Tue, 12 Mar 2019 11:25:37 -0700 Subject: [PATCH] Remove remaining test jit expects redux (#17924) Summary: Trying to reland https://github.com/pytorch/pytorch/pull/17886 since it broke a build and I reverted it Pull Request resolved: https://github.com/pytorch/pytorch/pull/17924 Differential Revision: D14423842 Pulled By: eellison fbshipit-source-id: f219e786bd07f7da3b7f9e866981199f5ccf6318 --- test/cpp/jit/gtest.cpp | 2 + test/cpp/jit/no-gtest.cpp | 18 +-- test/cpp/jit/test_constant_pooling.h | 87 +++++++++++ test/cpp/jit/test_irparser.h | 2 +- test/cpp/jit/test_misc.h | 94 ++++++++---- test/cpp/jit/test_netdef_converter.h | 2 +- test/expect/TestJit.test_cpp_cuda.expect | 169 --------------------- test/expect/TestJit.test_python_ir.expect | 9 -- test/expect/TestScript.test_if_for_in_range.expect | 22 --- ...stScript.test_loop_unroll_unused_counter.expect | 26 ---- .../TestScript.test_math_tensor_number.expect | 5 - test/expect/TestScript.test_string_cu.expect | 6 - test/test_jit.py | 33 +--- torch/csrc/jit/init.cpp | 4 +- torch/csrc/jit/irparser.cpp | 18 ++- torch/csrc/jit/script/init.cpp | 23 +-- torch/csrc/jit/testing/file_check.cpp | 8 +- torch/csrc/jit/testing/file_check.h | 6 + 18 files changed, 205 insertions(+), 329 deletions(-) create mode 100644 test/cpp/jit/test_constant_pooling.h delete mode 100644 test/expect/TestJit.test_cpp_cuda.expect delete mode 100644 test/expect/TestJit.test_python_ir.expect delete mode 100644 test/expect/TestScript.test_if_for_in_range.expect delete mode 100644 test/expect/TestScript.test_loop_unroll_unused_counter.expect delete mode 100644 test/expect/TestScript.test_math_tensor_number.expect delete mode 100644 test/expect/TestScript.test_string_cu.expect diff --git a/test/cpp/jit/gtest.cpp b/test/cpp/jit/gtest.cpp index 2c1edae..1186dd9 100644 --- a/test/cpp/jit/gtest.cpp +++ b/test/cpp/jit/gtest.cpp @@ -2,6 +2,7 @@ #include #include +#include #include #include #include @@ -40,6 +41,7 @@ JIT_TEST(WriteTracking) JIT_TEST(Wildcards) JIT_TEST(MemoryDAG) JIT_TEST(IRParser) +JIT_TEST(ConstantPooling) JIT_TEST(NetDefConverter) diff --git a/test/cpp/jit/no-gtest.cpp b/test/cpp/jit/no-gtest.cpp index 3c6ca37..845a38e 100644 --- a/test/cpp/jit/no-gtest.cpp +++ b/test/cpp/jit/no-gtest.cpp @@ -1,4 +1,5 @@ #include +#include #include #include #include @@ -10,20 +11,19 @@ using namespace torch::jit::script; namespace torch { namespace jit { -std::string runJITCPPTests() { - std::stringstream out; +void runJITCPPTests() { testNoneSchemaMatch(); testAutogradProfiler(); testADFormulas(); testArgumentSpec(); testAttributes(); - testBlocks(out); + testBlocks(); testCodeTemplate(); testControlFlow(); - testCreateAutodiffSubgraphs(out); + testCreateAutodiffSubgraphs(); testCustomOperators(); - testDifferentiate(out); - testDifferentiateWithRequiresGrad(out); + testDifferentiate(); + testDifferentiateWithRequiresGrad(); testDynamicDAG(); testEvalModeForLoadedModule(); testFromQualString(); @@ -44,10 +44,10 @@ std::string runJITCPPTests() { testWriteTracking(); testWildcards(); testMemoryDAG(); - testNetDefConverter(out); - testIRParser(out); + testNetDefConverter(); + testIRParser(); + testConstantPooling(); testClassParser(); - return out.str(); } } // namespace jit diff --git a/test/cpp/jit/test_constant_pooling.h b/test/cpp/jit/test_constant_pooling.h new file mode 100644 index 0000000..c77df2f --- /dev/null +++ b/test/cpp/jit/test_constant_pooling.h @@ -0,0 +1,87 @@ 
+#pragma once + +#include +#include +#include +#include +#include +#include "test/cpp/jit/test_base.h" + +#include +#include + +namespace torch { +namespace jit { + +void testConstantPooling() { + { + auto graph = std::make_shared(); + script::parseIR( + R"IR( +graph(): + %8 : int = prim::Constant[value=1]() + %10 : int = prim::Constant[value=1]() + return (%8, %10) + )IR", + &*graph); + ConstantPooling(graph); + testing::FileCheck() + .check_count("prim::Constant", 1, /*exactly*/ true) + ->run(*graph); + } + { + auto graph = std::make_shared(); + script::parseIR( + R"IR( +graph(%cond : Tensor): + %a : string = prim::Constant[value="bcd"]() + %3 : bool = prim::Bool(%cond) + %b : string = prim::If(%3) + block0(): + %b.1 : string = prim::Constant[value="abc"]() + -> (%b.1) + block1(): + %b.2 : string = prim::Constant[value="abc"]() + -> (%b.2) + %7 : (string, string) = prim::TupleConstruct(%a, %b) + return (%7) + )IR", + &*graph); + ConstantPooling(graph); + testing::FileCheck() + .check_count("prim::Constant[value=\"abc\"]", 1, /*exactly*/ true) + ->check_count("prim::Constant[value=\"bcd\"]", 1, /*exactly*/ true) + ->run(*graph); + } + { + auto graph = std::make_shared(); + script::parseIR( + R"IR( +graph(): + %2 : int = prim::Constant[value=2]() + %1 : int = prim::Constant[value=1]() + %5 : int? = prim::Constant() + %7 : Device? = prim::Constant() + %10 : int = prim::Constant[value=6]() + %3 : int[] = prim::ListConstruct(%1, %2) + %x : Tensor = aten::tensor(%3, %5, %7) + %y : Tensor = aten::tensor(%3, %10, %7) + %9 : int[] = prim::ListConstruct(%1, %2) + %z : Tensor = aten::tensor(%9, %10, %7) + %14 : (Tensor, Tensor) = prim::TupleConstruct(%x, %y) + return (%14) + )IR", + &*graph); + // three tensors created - two different devices among the three + // don't have good support for parsing tensor constants + ConstantPropagation(graph); + ConstantPooling(graph); + testing::FileCheck() + .check_count("Float(2) = prim::Constant", 1, /*exactly*/ true) + ->check_count("Long(2) = prim::Constant", 1, /*exactly*/ true) + ->run(*graph); + } +} + +} // namespace jit +} // namespace torch diff --git a/test/cpp/jit/test_irparser.h b/test/cpp/jit/test_irparser.h index b12da1f..c8c3eea 100644 --- a/test/cpp/jit/test_irparser.h +++ b/test/cpp/jit/test_irparser.h @@ -39,7 +39,7 @@ static void checkRoundtrip(const std::string& s) { AT_ASSERT(original == parsed); } -void testIRParser(std::ostream& out = std::cout) { +void testIRParser() { { auto graph = std::make_shared(); script::parseIR( diff --git a/test/cpp/jit/test_misc.h b/test/cpp/jit/test_misc.h index bd55e33..ad7fa58 100644 --- a/test/cpp/jit/test_misc.h +++ b/test/cpp/jit/test_misc.h @@ -2,6 +2,7 @@ #include "test/cpp/jit/test_base.h" +#include #include "ATen/core/interned_strings.h" #include "torch/csrc/autograd/generated/variable_factories.h" #include "torch/csrc/autograd/variable.h" @@ -34,6 +35,7 @@ #include "torch/csrc/autograd/engine.h" #include "torch/csrc/autograd/variable.h" +#include #include "ATen/core/ivalue.h" #include "torch/csrc/jit/graph_executor.h" #include "torch/csrc/jit/script/compiler.h" @@ -831,7 +833,7 @@ void testADFormulas() { } } -void testDifferentiate(std::ostream& out = std::cout) { +void testDifferentiate() { auto graph = std::make_shared(); at::ScalarType s = at::ScalarType::Float; auto type = CompleteTensorType::create(s, at::kCPU, {2, 3, 4}, {12, 4, 1}); @@ -852,13 +854,19 @@ void testDifferentiate(std::ostream& out = std::cout) { ASSERT_EQ(grad_spec.df_input_captured_outputs, expected_captured_outputs); 
ASSERT_EQ(grad_spec.df_input_vjps, expected_input_vjps); ASSERT_EQ(grad_spec.df_output_vjps, expected_output_vjps); - out << "testDifferentiate\n"; - out << *grad_spec.f; - out << *grad_spec.df; - out << "\n"; + testing::FileCheck() + .check_count("aten::mul", 2) + ->check("aten::size") + ->check("aten::add") + ->run(*grad_spec.f); + testing::FileCheck() + .check("prim::GradOf[name=\"aten::add\"]") + ->check_count("prim::GradOf[name=\"aten::mul\"]", 2) + ->check_count("AutogradAdd", 2) + ->run(*grad_spec.df); } -void testDifferentiateWithRequiresGrad(std::ostream& out = std::cout) { +void testDifferentiateWithRequiresGrad() { // Build up a fake graph auto graph = std::make_shared(); auto a = SymbolicVariable::asNewInput(*graph); @@ -884,10 +892,17 @@ void testDifferentiateWithRequiresGrad(std::ostream& out = std::cout) { ASSERT_EQ(grad_spec.df_input_captured_outputs, std::vector({2, 3})); ASSERT_EQ(grad_spec.df_input_vjps, expected_input_vjps); ASSERT_EQ(grad_spec.df_output_vjps, expected_output_vjps); - out << "testDifferentiateWithRequiresGrad\n"; - out << *grad_spec.f; - out << *grad_spec.df; - out << "\n"; + testing::FileCheck() + .check("aten::mul") + ->check_count("aten::add", 2) + ->check("aten::mul") + ->check("aten::size") + ->check("aten::add") + ->run(*grad_spec.f); + + testing::FileCheck() + .check_count("prim::GradOf[name=\"aten::mul\"]", 1, /*exactly*/ true) + ->run(*grad_spec.df); } void testRegisterFusionCachesKernel(std::ostream& out = std::cout) { @@ -937,11 +952,18 @@ void testRegisterFusionCachesKernel(std::ostream& out = std::cout) { ASSERT_EQ(second_key, expected_key); } -void testCreateAutodiffSubgraphs(std::ostream& out = std::cout) { +void testCreateAutodiffSubgraphs() { auto graph = build_lstm(); CreateAutodiffSubgraphs(graph, /*threshold=*/2); - out << "testCreateAutodiffSubgraphs\n"; - out << *graph << "\n"; + // all of the ops are within the DifferentiableGraph + testing::FileCheck() + .check_not("aten::mm") + ->check_not("aten::sigmoid") + ->check_not("aten::tanh") + ->check_not("aten::mul") + ->check("DifferentiableGraph") + ->check_next("return") + ->run(*graph); } void testSubgraphUtils() { @@ -1088,11 +1110,13 @@ void testGraphExecutor() { } void testBlocks(std::ostream& out = std::cout) { - Graph g; - auto a = Var::asNewInput(g, "a"); - auto b = Var::asNewInput(g, "b"); + auto g = std::make_shared(); + // auto g = *graph; + auto a = Var::asNewInput(*g, "a"); + auto b = Var::asNewInput(*g, "b"); auto c = a + b; - auto r = g.appendNode(g.create(prim::If, {Var::asNewInput(g, "c").value()})); + auto r = + g->appendNode(g->create(prim::If, {Var::asNewInput(*g, "c").value()})); auto then_block = r->addBlock(); auto else_block = r->addBlock(); { @@ -1106,15 +1130,32 @@ void testBlocks(std::ostream& out = std::cout) { auto e = d + c; else_block->registerOutput(e.value()); } - g.registerOutput((Var(r->output()) + c).value()); - g.lint(); - out << "testBlocks\n" << g << "\n"; + g->registerOutput((Var(r->output()) + c).value()); + g->lint(); + testing::FileCheck() + .check("add") + ->check("prim::If") + ->check("block0") + ->check("aten::add") + ->check("block1") + ->check_count("aten::add", 3) + ->run(*g); r->eraseBlock(0); - out << g << "\n"; - g.lint(); + testing::FileCheck() + .check("add") + ->check("prim::If") + ->check("block0") + ->check_not("block") + ->run(*g); + g->lint(); // test recursive copy of blocks works - auto g2 = g.copy(); - out << *g2 << "\n"; + auto g2 = g->copy(); + testing::FileCheck() + .check("add") + ->check("prim::If") + 
->check("block0") + ->check_not("block") + ->run(*g2); } const auto cf_examples = R"JIT( @@ -1508,9 +1549,7 @@ void testSchemaParser() { Symbol::fromQualString("alias::b"), }; const auto expectedAfter = std::unordered_set{ - Symbol::fromQualString("alias::b"), - Symbol::fromQualString("alias::c") - }; + Symbol::fromQualString("alias::b"), Symbol::fromQualString("alias::c")}; ASSERT_TRUE(containedAliasInfo.beforeSets() == expectedBefore); ASSERT_TRUE(containedAliasInfo.afterSets() == expectedAfter); ASSERT_FALSE(containedAliasInfo.isWrite()); @@ -1849,7 +1888,6 @@ void testNoneSchemaMatch() { // checking that constant propagation ran wo/failure AT_ASSERT(std::distance(nodes.begin(), nodes.end()) == 1); } - } // namespace } // namespace jit } // namespace torch diff --git a/test/cpp/jit/test_netdef_converter.h b/test/cpp/jit/test_netdef_converter.h index 3dba9b8..0ac7d92 100644 --- a/test/cpp/jit/test_netdef_converter.h +++ b/test/cpp/jit/test_netdef_converter.h @@ -24,7 +24,7 @@ static caffe2::OperatorDef createOperator( return op; } -void testNetDefConverter(std::ostream& out = std::cout) { +void testNetDefConverter() { { // Check a simple net conversion back and forth. diff --git a/test/expect/TestJit.test_cpp_cuda.expect b/test/expect/TestJit.test_cpp_cuda.expect deleted file mode 100644 index e15ab3c..0000000 --- a/test/expect/TestJit.test_cpp_cuda.expect +++ /dev/null @@ -1,169 +0,0 @@ -testBlocks -graph(%a : Tensor, - %b : Tensor, - %c : Tensor): - %2 : int = prim::Constant[value=1]() - %3 : Tensor = aten::add(%a, %b, %2) - %5 : Tensor = prim::If(%c) - block0(): - %6 : int = prim::Constant[value=1]() - %7 : Tensor = aten::add(%3, %3, %6) - -> (%7) - block1(): - %8 : int = prim::Constant[value=1]() - %9 : Tensor = aten::add(%b, %3, %8) - %10 : int = prim::Constant[value=1]() - %11 : Tensor = aten::add(%9, %3, %10) - -> (%11) - %12 : int = prim::Constant[value=1]() - %13 : Tensor = aten::add(%5, %3, %12) - return (%13) - -graph(%a : Tensor, - %b : Tensor, - %c : Tensor): - %2 : int = prim::Constant[value=1]() - %3 : Tensor = aten::add(%a, %b, %2) - %5 : Tensor = prim::If(%c) - block0(): - %8 : int = prim::Constant[value=1]() - %9 : Tensor = aten::add(%b, %3, %8) - %10 : int = prim::Constant[value=1]() - %11 : Tensor = aten::add(%9, %3, %10) - -> (%11) - %12 : int = prim::Constant[value=1]() - %13 : Tensor = aten::add(%5, %3, %12) - return (%13) - -graph(%a : Tensor, - %b : Tensor, - %c : Tensor): - %3 : int = prim::Constant[value=1]() - %4 : Tensor = aten::add(%a, %b, %3) - %5 : Tensor = prim::If(%c) - block0(): - %6 : int = prim::Constant[value=1]() - %7 : Tensor = aten::add(%b, %4, %6) - %8 : int = prim::Constant[value=1]() - %9 : Tensor = aten::add(%7, %4, %8) - -> (%9) - %10 : int = prim::Constant[value=1]() - %11 : Tensor = aten::add(%5, %4, %10) - return (%11) - -testCreateAutodiffSubgraphs -graph(%0 : Tensor, - %1 : Tensor, - %2 : Tensor, - %3 : Tensor, - %4 : Tensor): - %7 : int = prim::Constant[value=1]() - %23 : Tensor, %24 : Tensor = prim::DifferentiableGraph_0(%2, %1, %4, %0, %3) - return (%23, %24) -with prim::DifferentiableGraph_0 = graph(%13 : Tensor, - %32 : Tensor, - %33 : Tensor, - %35 : Tensor, - %36 : Tensor): - %37 : Tensor = aten::mm(%35, %36) - %34 : Tensor = aten::mm(%32, %33) - %30 : int = prim::Constant[value=1]() - %31 : Tensor = aten::add(%37, %34, %30) - %24 : Tensor, %25 : Tensor, %26 : Tensor, %27 : Tensor = prim::ConstantChunk[chunks=4, dim=1](%31) - %22 : Tensor = aten::sigmoid(%24) - %20 : Tensor = aten::sigmoid(%27) - %18 : Tensor = aten::tanh(%26) - 
%16 : Tensor = aten::sigmoid(%25) - %14 : Tensor = aten::mul(%16, %13) - %11 : Tensor = aten::mul(%22, %18) - %8 : Tensor = aten::add(%14, %11, %30) - %4 : Tensor = aten::tanh(%8) - %2 : Tensor = aten::mul(%20, %4) - return (%2, %8) - -testDifferentiate -graph(%0 : Float(2, 3, 4), - %1 : Float(2, 3, 4)): - %2 : Float(2, 3, 4) = aten::mul(%0, %1) - %3 : Float(2, 3, 4) = aten::mul(%2, %0) - %4 : int = prim::Constant[value=1]() - %7 : int[] = aten::size(%3) - %5 : Float(2, 3, 4) = aten::add(%3, %1, %4) - return (%5, %2, %7) -graph(%0 : Float(2, 3, 4), - %1 : Float(2, 3, 4), - %2 : Float(2, 3, 4), - %3 : Float(2, 3, 4), - %4 : Float(2, 3, 4), - %5 : int[]): - %7 : int = prim::Constant[value=1]() - %6 : int[] = aten::size(%3) - %8 : Tensor, %9 : Tensor = prim::GradOf[name="aten::add"](%0) - block0(): - %10 : Tensor = aten::_grad_sum_to_size(%0, %5) - %11 : Float(2, 3, 4) = aten::mul(%0, %7) - %12 : Tensor = aten::_grad_sum_to_size(%11, %6) - -> (%10, %12) - %grad_self.2 : Tensor, %grad_other.2 : Tensor = prim::GradOf[name="aten::mul"](%8) - block0(): - %15 : Tensor = aten::mul(%8, %2) - %16 : int[] = aten::size(%4) - %grad_self.1 : Tensor = aten::_grad_sum_to_size(%15, %16) - %18 : Tensor = aten::mul(%8, %4) - %19 : int[] = aten::size(%2) - %grad_other.1 : Tensor = aten::_grad_sum_to_size(%18, %19) - -> (%grad_self.1, %grad_other.1) - %21 : Tensor = prim::AutogradAdd(%1, %grad_self.2) - %grad_self : Tensor, %grad_other : Tensor = prim::GradOf[name="aten::mul"](%21) - block0(): - %24 : Tensor = aten::mul(%21, %3) - %25 : int[] = aten::size(%2) - %grad_self.3 : Tensor = aten::_grad_sum_to_size(%24, %25) - %27 : Tensor = aten::mul(%21, %2) - %28 : int[] = aten::size(%3) - %grad_other.3 : Tensor = aten::_grad_sum_to_size(%27, %28) - -> (%grad_self.3, %grad_other.3) - %30 : Tensor = prim::AutogradAdd(%grad_other.2, %grad_self) - %31 : Tensor = prim::AutogradAdd(%9, %grad_other) - return (%30, %31) - -testDifferentiateWithRequiresGrad -graph(%0 : Float(*), - %1 : Float(*)): - %2 : Float(*) = aten::mul(%1, %1) - %3 : int = prim::Constant[value=1]() - %4 : Float(*) = aten::add(%2, %1, %3) - %6 : Float(*) = aten::add(%4, %0, %3) - %7 : Float(*) = aten::mul(%6, %0) - %11 : int[] = aten::size(%7) - %9 : Float(*) = aten::add(%7, %1, %3) - return (%4, %9, %6, %11) -graph(%0 : Float(*), - %1 : Float(*), - %2 : Float(*), - %3 : Float(*), - %4 : int[]): - %6 : int = prim::Constant[value=1]() - %5 : int[] = aten::size(%2) - %7 : Tensor = prim::GradOf[name="aten::add"](%0) - block0(): - %8 : Tensor = aten::_grad_sum_to_size(%0, %4) - -> (%8) - %grad_self : Tensor, %grad_other : Tensor = prim::GradOf[name="aten::mul"](%7) - block0(): - %11 : Tensor = aten::mul(%7, %2) - %12 : int[] = aten::size(%3) - %grad_self.1 : Tensor = aten::_grad_sum_to_size(%11, %12) - %14 : Tensor = aten::mul(%7, %3) - %15 : int[] = aten::size(%2) - %grad_other.1 : Tensor = aten::_grad_sum_to_size(%14, %15) - -> (%grad_self.1, %grad_other.1) - %17 : Tensor = prim::AutogradAdd(%1, %grad_self) - %18 : Tensor = prim::GradOf[name="aten::add"](%17) - block0(): - %19 : Tensor = aten::mul(%17, %6) - %20 : Tensor = aten::_grad_sum_to_size(%19, %5) - -> (%20) - %21 : Tensor = prim::AutogradAdd(%grad_other, %18) - return (%21) - diff --git a/test/expect/TestJit.test_python_ir.expect b/test/expect/TestJit.test_python_ir.expect deleted file mode 100644 index b7fa4c1..0000000 --- a/test/expect/TestJit.test_python_ir.expect +++ /dev/null @@ -1,9 +0,0 @@ -graph(%0 : Tensor, - %1 : Tensor): - %2 : int = prim::Constant[value=1]() - %3 : Double(1) = 
aten::add(%0, %1, %2) - %4 : Double(1) = aten::mul(%0, %3) - %5 : Double(1) = aten::tanh(%4) - %6 : Double(1) = aten::sigmoid(%5) - %7 : Tensor = prim::TensorTest[a= 1 1 1 1 [ Variable[CPUDoubleType]{2,2} ]]() - return (%6) diff --git a/test/expect/TestScript.test_if_for_in_range.expect b/test/expect/TestScript.test_if_for_in_range.expect deleted file mode 100644 index a32202c..0000000 --- a/test/expect/TestScript.test_if_for_in_range.expect +++ /dev/null @@ -1,22 +0,0 @@ -graph(%a.1 : Dynamic, - %b.1 : Dynamic): - %d.1 : Long() = prim::Constant[value={3}]() - %3 : Long() = prim::Constant[value={20}]() - %4 : Byte() = prim::Constant[value={1}]() - %a : Dynamic, %d : Long(), %b : Dynamic = prim::Loop(%3, %4, %a.1, %d.1, %b.1) - block0(%_ : Dynamic, %6 : Dynamic, %10 : Long(), %14 : Dynamic): - %7 : Long() = prim::Constant[value={10}]() - %8 : Dynamic = aten::gt(%6, %7) - %a.3 : Dynamic, %b.3 : Dynamic, %d.3 : Long() = prim::If(%8) - block0(): - %9 : Long() = prim::Constant[value={3}]() - %a.2 : Dynamic = aten::add[alpha={1}](%9, %10) - -> (%a.2, %14, %10) - block1(): - %12 : Long() = prim::Constant[value={3}]() - %b.2 : Dynamic = aten::add[alpha={1}](%12, %10) - %d.2 : Long() = prim::Constant[value={4}]() - -> (%6, %b.2, %d.2) - %20 : Byte() = prim::Constant[value={1}]() - -> (%20, %a.3, %d.3, %b.3) - return (%d) diff --git a/test/expect/TestScript.test_loop_unroll_unused_counter.expect b/test/expect/TestScript.test_loop_unroll_unused_counter.expect deleted file mode 100644 index 4f66287..0000000 --- a/test/expect/TestScript.test_loop_unroll_unused_counter.expect +++ /dev/null @@ -1,26 +0,0 @@ -graph(%x : Tensor): - %1 : bool = prim::Constant[value=1]() - %y.1 : int = prim::Constant[value=0]() - %3 : int = prim::Constant[value=1]() - %4 : int = prim::Int(%x) - %5 : int = prim::Constant[value=8]() - %6 : int = aten::__round_to_zero_floordiv(%4, %5) - %7 : int = prim::Constant[value=8]() - %8 : int = aten::mul(%6, %7) - %9 : int = aten::sub(%4, %8) - %y.3 : int = prim::Loop(%6, %1, %y.1) - block0(%11 : int, %12 : int): - %y.12 : int = aten::add(%12, %3) - %y.5 : int = aten::add(%y.12, %3) - %y.6 : int = aten::add(%y.5, %3) - %y.7 : int = aten::add(%y.6, %3) - %y.8 : int = aten::add(%y.7, %3) - %y.9 : int = aten::add(%y.8, %3) - %y.10 : int = aten::add(%y.9, %3) - %y.11 : int = aten::add(%y.10, %3) - -> (%1, %y.11) - %y : int = prim::Loop(%9, %1, %y.3) - block0(%22 : int, %23 : int): - %y.4 : int = aten::add(%23, %3) - -> (%1, %y.4) - return (%y) diff --git a/test/expect/TestScript.test_math_tensor_number.expect b/test/expect/TestScript.test_math_tensor_number.expect deleted file mode 100644 index 13555dd..0000000 --- a/test/expect/TestScript.test_math_tensor_number.expect +++ /dev/null @@ -1,5 +0,0 @@ -graph(%x : Tensor): - %1 : int = prim::Constant[value=1]() - %2 : int = prim::Constant[value=7]() - %3 : Tensor = aten::add(%x, %2, %1) - return (%3) diff --git a/test/expect/TestScript.test_string_cu.expect b/test/expect/TestScript.test_string_cu.expect deleted file mode 100644 index 0604991..0000000 --- a/test/expect/TestScript.test_string_cu.expect +++ /dev/null @@ -1,6 +0,0 @@ -graph(%a : Tensor): - %3 : string = prim::Constant[value="aa"]() - %1 : string = prim::Constant[value="a\n\tb\n"]() - %2 : int = prim::Constant[value=2]() - = prim::Print(%a, %1, %2, %3) - return (%a) diff --git a/test/test_jit.py b/test/test_jit.py index e59c91e..5725201 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -1322,9 +1322,7 @@ class TestJit(JitTestCase): def test_cpp_cuda(self): from cpp.jit import 
tests_setup tests_setup.setup() - # rather than rebuild assertExpected in cpp, - # just glob all the cpp outputs into one file for now - self.assertExpected(torch._C._jit_run_cpp_tests()) + torch._C._jit_run_cpp_tests() tests_setup.shutdown() def test_batchnorm(self): @@ -3294,7 +3292,7 @@ class TestScript(JitTestCase): a") return a ''') - self.assertExpected(str(cu.foo.graph)) + FileCheck().check("aa").check("a\\n\\tb\\n").run(str(cu.foo.graph)) def test_string_ops(self): def foo(): @@ -3710,33 +3708,6 @@ a") y2 = torch.sum(x, dim=0) self.assertEqual(y, y2) - def test_constant_pooling(self): - def func(cond): - a = 1 - b = 4 - c = 0 - d = "abc" - e = "bcd" - f = "abc" - x = torch.ones([2]) - y = x * 4 - z = torch.ones([2]) - if bool(cond): - c = b - a - else: - y = torch.rand(0) - if bool(cond): - y = torch.rand(1) - print(d, e, f, x, y, z) - b = b - a - return a, b, c, x, y - - self.checkScript(func, torch.tensor([1])) - graph = torch.jit.script(func).graph - self.run_pass('constant_propagation', graph) - self.run_pass('constant_pooling', graph) - self.assertExpectedGraph(graph) - def test_constant_pooling_none(self): @torch.jit.script def typed_nones(a=None, b=None, c=None): diff --git a/torch/csrc/jit/init.cpp b/torch/csrc/jit/init.cpp index ad0e0d9..2fcaa68 100644 --- a/torch/csrc/jit/init.cpp +++ b/torch/csrc/jit/init.cpp @@ -80,11 +80,11 @@ bool loadPythonClasses() { } // anonymous namespace #if defined(_WIN32) -std::string runJITCPPTests() { +void runJITCPPTests() { AT_ERROR("JIT tests not yet supported on Windows"); } #else -std::string runJITCPPTests(); +void runJITCPPTests(); #endif void initJITBindings(PyObject* module) { diff --git a/torch/csrc/jit/irparser.cpp b/torch/csrc/jit/irparser.cpp index 0b64bc1..4af5c80 100644 --- a/torch/csrc/jit/irparser.cpp +++ b/torch/csrc/jit/irparser.cpp @@ -79,13 +79,8 @@ void parseIR(const std::string& str, torch::jit::Graph* graph) { } VarWithType IRParser::parseVarWithType() { - L.expect('%'); VarWithType r; - if (L.cur().kind == TK_IDENT) { - r.name = L.expect(TK_IDENT).text(); - } else { - r.name = L.expect(TK_NUMBER).text(); - } + r.name = parseVar(); r.type = TensorType::get(); if (L.nextIf(':')) { auto type_alias = type_parser.parseType(); @@ -98,9 +93,16 @@ VarWithType IRParser::parseVarWithType() { std::string IRParser::parseVar() { L.expect('%'); if (L.cur().kind == TK_IDENT) { - return L.expect(TK_IDENT).text(); + auto name = L.expect(TK_IDENT).text(); + if (L.cur().kind == TK_NUMBER) { + auto suffix = L.expect(TK_NUMBER).text(); + AT_ASSERT(suffix[0] == '.'); + name += suffix; + } + return name; + } else { + return L.expect(TK_NUMBER).text(); } - return L.expect(TK_NUMBER).text(); } void IRParser::parseOperatorOutputs(std::vector* outs) { diff --git a/torch/csrc/jit/script/init.cpp b/torch/csrc/jit/script/init.cpp index c9a83f3..79e7d97 100644 --- a/torch/csrc/jit/script/init.cpp +++ b/torch/csrc/jit/script/init.cpp @@ -269,7 +269,6 @@ struct VISIBILITY_HIDDEN ConstantParameterList : public SugaredValue { std::shared_ptr module_; }; - // defines how modules/methods behave inside the script subset. // for now this does not have any interaction with python. 
// in the future, we will add the ability to resolve `self.foo` to python @@ -774,8 +773,7 @@ void initJitScriptBindings(PyObject* module) { auto& p = parameters[i]; py::tuple r(2); result[i] = std::make_tuple( - p.key(), - autograd::as_variable_ref(p->slot()->toTensor())); + p.key(), autograd::as_variable_ref(p->slot()->toTensor())); } return result; }) @@ -789,9 +787,7 @@ void initJitScriptBindings(PyObject* module) { py::tuple r(3); IValue v = *buffer->slot(); result[i] = std::make_tuple( - buffer.key(), - buffer->type, - toPyObject(std::move(v))); + buffer.key(), buffer->type, toPyObject(std::move(v))); } return result; }) @@ -935,8 +931,7 @@ void initJitScriptBindings(PyObject* module) { std::shared_ptr orig) { std::vector member_inputs; for (auto& p : params) { - NamedIValue* np = - std::get<0>(p)->find_parameter(std::get<1>(p)); + NamedIValue* np = std::get<0>(p)->find_parameter(std::get<1>(p)); if (np == nullptr) { np = std::get<0>(p)->find_buffer(std::get<1>(p)); } @@ -945,8 +940,7 @@ void initJitScriptBindings(PyObject* module) { } Method* orig_method = orig->find_method(name); - m->create_method( - name, orig_method->graph()->copy(), member_inputs); + m->create_method(name, orig_method->graph()->copy(), member_inputs); }); py::class_(m, "ScriptMethod", py::dynamic_attr()) @@ -1074,7 +1068,14 @@ void initJitScriptBindings(PyObject* module) { py::arg("str"), py::arg("count"), py::arg("exactly") = false) - .def("run", &testing::FileCheck::run); + .def( + "run", + [](testing::FileCheck& f, const std::string& str) { + return f.run(str); + }) + .def("run", [](testing::FileCheck& f, const Graph& g) { + return f.run(g); + }); } } // namespace script } // namespace jit diff --git a/torch/csrc/jit/testing/file_check.cpp b/torch/csrc/jit/testing/file_check.cpp index d428699..3af8c5a 100644 --- a/torch/csrc/jit/testing/file_check.cpp +++ b/torch/csrc/jit/testing/file_check.cpp @@ -18,7 +18,7 @@ #include #include -#include +#include #include namespace torch { @@ -303,6 +303,12 @@ void FileCheck::run(const std::string& test_file) { fcImpl->run(test_file); }; +void FileCheck::run(const Graph& graph) { + std::stringstream graph_str; + graph_str << graph; + fcImpl->run(graph_str.str()); +}; + FileCheck* FileCheck::check(const std::string& str) { fcImpl->addCheck(CHECK, str); return this; diff --git a/torch/csrc/jit/testing/file_check.h b/torch/csrc/jit/testing/file_check.h index 20f987f..cf80575 100644 --- a/torch/csrc/jit/testing/file_check.h +++ b/torch/csrc/jit/testing/file_check.h @@ -5,6 +5,9 @@ namespace torch { namespace jit { + +struct Graph; + namespace testing { struct FileCheckImpl; @@ -17,6 +20,9 @@ struct FileCheck { // Run FileCheck against test string TORCH_API void run(const std::string& test_string); + // Run FileCheck against dump of graph IR + TORCH_API void run(const Graph& graph); + // Checks that the string occurs, starting at the end of the most recent match TORCH_API FileCheck* check(const std::string& str); -- 2.7.4
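
Note on the pattern these tests now follow: instead of dumping graphs to a stream and comparing against checked-in expect files with assertExpected, each test parses or builds a Graph and asserts on its IR dump with testing::FileCheck. A minimal sketch of that pattern is below; it is illustrative only (the function name, IR text, and exact include paths are assumptions, not code taken from this patch):

#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/irparser.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/testing/file_check.h>

namespace torch {
namespace jit {

// Sketch of the FileCheck-based test style; not part of the patch itself.
void exampleFileCheckUsage() {
  auto graph = std::make_shared<Graph>();
  // Build a small graph from IR text, as the new tests do via script::parseIR.
  script::parseIR(
      R"IR(
graph():
  %a : int = prim::Constant[value=1]()
  %b : int = prim::Constant[value=1]()
  return (%a, %b)
)IR",
      &*graph);
  // After pooling, the two identical constants collapse into a single node.
  ConstantPooling(graph);
  // The new FileCheck::run(const Graph&) overload dumps the graph and matches
  // the recorded annotations against it, so the test no longer needs to thread
  // a std::ostream through and compare the collected output with assertExpected.
  testing::FileCheck()
      .check_count("prim::Constant", 1, /*exactly*/ true)
      ->run(*graph);
}

} // namespace jit
} // namespace torch

The same FileCheck class is bound to Python in script/init.cpp (see the two "run" overloads added above), which is what test_string_cu in test_jit.py now uses in place of assertExpected.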