#include <gtest/gtest.h>
#include <test/cpp/jit/test_alias_analysis.h>
+#include <test/cpp/jit/test_constant_pooling.h>
#include <test/cpp/jit/test_irparser.h>
#include <test/cpp/jit/test_misc.h>
#include <test/cpp/jit/test_netdef_converter.h>
JIT_TEST(Wildcards)
JIT_TEST(MemoryDAG)
JIT_TEST(IRParser)
+JIT_TEST(ConstantPooling)
JIT_TEST(NetDefConverter)
#include <test/cpp/jit/test_alias_analysis.h>
+#include <test/cpp/jit/test_constant_pooling.h>
#include <test/cpp/jit/test_irparser.h>
#include <test/cpp/jit/test_misc.h>
#include <test/cpp/jit/test_netdef_converter.h>
namespace torch {
namespace jit {
-std::string runJITCPPTests() {
- std::stringstream out;
+void runJITCPPTests() {
testNoneSchemaMatch();
testAutogradProfiler();
testADFormulas();
testArgumentSpec();
testAttributes();
- testBlocks(out);
+ testBlocks();
testCodeTemplate();
testControlFlow();
- testCreateAutodiffSubgraphs(out);
+ testCreateAutodiffSubgraphs();
testCustomOperators();
- testDifferentiate(out);
- testDifferentiateWithRequiresGrad(out);
+ testDifferentiate();
+ testDifferentiateWithRequiresGrad();
testDynamicDAG();
testEvalModeForLoadedModule();
testFromQualString();
testWriteTracking();
testWildcards();
testMemoryDAG();
- testNetDefConverter(out);
- testIRParser(out);
- return out.str();
+ testNetDefConverter();
+ testIRParser();
+ testConstantPooling();
}
} // namespace jit
--- /dev/null
+#pragma once
+
+#include <torch/csrc/jit/ir.h>
+#include <torch/csrc/jit/irparser.h>
+#include <torch/csrc/jit/passes/constant_pooling.h>
+#include <torch/csrc/jit/passes/constant_propagation.h>
+#include <torch/csrc/jit/testing/file_check.h>
+#include "test/cpp/jit/test_base.h"
+
+#include <sstream>
+#include <string>
+
+namespace torch {
+namespace jit {
+
+void testConstantPooling() {
+ {
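+ // Two identical int constants should be merged into a single node.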
+ auto graph = std::make_shared<Graph>();
+ script::parseIR(
+ R"IR(
+graph():
+ %8 : int = prim::Constant[value=1]()
+ %10 : int = prim::Constant[value=1]()
+ return (%8, %10)
+ )IR",
+ &*graph);
+ ConstantPooling(graph);
+ testing::FileCheck()
+ .check_count("prim::Constant", 1, /*exactly*/ true)
+ ->run(*graph);
+ }
+ {
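+ // Equal string constants used in different if-branches should also be pooled into one node.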
+ auto graph = std::make_shared<Graph>();
+ script::parseIR(
+ R"IR(
+graph(%cond : Tensor):
+ %a : string = prim::Constant[value="bcd"]()
+ %3 : bool = prim::Bool(%cond)
+ %b : string = prim::If(%3)
+ block0():
+ %b.1 : string = prim::Constant[value="abc"]()
+ -> (%b.1)
+ block1():
+ %b.2 : string = prim::Constant[value="abc"]()
+ -> (%b.2)
+ %7 : (string, string) = prim::TupleConstruct(%a, %b)
+ return (%7)
+ )IR",
+ &*graph);
+ ConstantPooling(graph);
+ testing::FileCheck()
+ .check_count("prim::Constant[value=\"abc\"]", 1, /*exactly*/ true)
+ ->check_count("prim::Constant[value=\"bcd\"]", 1, /*exactly*/ true)
+ ->run(*graph);
+ }
+ {
+ auto graph = std::make_shared<Graph>();
+ script::parseIR(
+ R"IR(
+graph():
+ %2 : int = prim::Constant[value=2]()
+ %1 : int = prim::Constant[value=1]()
+ %5 : int? = prim::Constant()
+ %7 : Device? = prim::Constant()
+ %10 : int = prim::Constant[value=6]()
+ %3 : int[] = prim::ListConstruct(%1, %2)
+ %x : Tensor = aten::tensor(%3, %5, %7)
+ %y : Tensor = aten::tensor(%3, %10, %7)
+ %9 : int[] = prim::ListConstruct(%1, %2)
+ %z : Tensor = aten::tensor(%9, %10, %7)
+ %14 : (Tensor, Tensor) = prim::TupleConstruct(%x, %y)
+ return (%14)
+ )IR",
+ &*graph);
+ // Three tensors are created, with two different dtypes among them.
+ // The IR parser has no good support for tensor constants, so run constant
+ // propagation first to materialize them before pooling.
+ ConstantPropagation(graph);
+ ConstantPooling(graph);
+ testing::FileCheck()
+ .check_count("Float(2) = prim::Constant", 1, /*exactly*/ true)
+ ->check_count("Long(2) = prim::Constant", 1, /*exactly*/ true)
+ ->run(*graph);
+ }
+}
+
+} // namespace jit
+} // namespace torch
AT_ASSERT(original == parsed);
}
-void testIRParser(std::ostream& out = std::cout) {
+void testIRParser() {
{
auto graph = std::make_shared<Graph>();
script::parseIR(
#include "test/cpp/jit/test_base.h"
+#include <torch/csrc/jit/passes/canonicalize.h>
#include "ATen/core/interned_strings.h"
#include "torch/csrc/autograd/generated/variable_factories.h"
#include "torch/csrc/autograd/variable.h"
#include "torch/csrc/autograd/engine.h"
#include "torch/csrc/autograd/variable.h"
+#include <torch/csrc/jit/testing/file_check.h>
#include "ATen/core/ivalue.h"
#include "torch/csrc/jit/graph_executor.h"
#include "torch/csrc/jit/script/compiler.h"
}
}
-void testDifferentiate(std::ostream& out = std::cout) {
+void testDifferentiate() {
auto graph = std::make_shared<Graph>();
at::ScalarType s = at::ScalarType::Float;
auto type = CompleteTensorType::create(s, at::kCPU, {2, 3, 4}, {12, 4, 1});
ASSERT_EQ(grad_spec.df_input_captured_outputs, expected_captured_outputs);
ASSERT_EQ(grad_spec.df_input_vjps, expected_input_vjps);
ASSERT_EQ(grad_spec.df_output_vjps, expected_output_vjps);
- out << "testDifferentiate\n";
- out << *grad_spec.f;
- out << *grad_spec.df;
- out << "\n";
+ testing::FileCheck()
+ .check_count("aten::mul", 2)
+ ->check("aten::size")
+ ->check("aten::add")
+ ->run(*grad_spec.f);
+ testing::FileCheck()
+ .check("prim::GradOf[name=\"aten::add\"]")
+ ->check_count("prim::GradOf[name=\"aten::mul\"]", 2)
+ ->check_count("AutogradAdd", 2)
+ ->run(*grad_spec.df);
}
-void testDifferentiateWithRequiresGrad(std::ostream& out = std::cout) {
+void testDifferentiateWithRequiresGrad() {
// Build up a fake graph
auto graph = std::make_shared<Graph>();
auto a = SymbolicVariable::asNewInput(*graph);
ASSERT_EQ(grad_spec.df_input_captured_outputs, std::vector<size_t>({2, 3}));
ASSERT_EQ(grad_spec.df_input_vjps, expected_input_vjps);
ASSERT_EQ(grad_spec.df_output_vjps, expected_output_vjps);
- out << "testDifferentiateWithRequiresGrad\n";
- out << *grad_spec.f;
- out << *grad_spec.df;
- out << "\n";
+ testing::FileCheck()
+ .check("aten::mul")
+ ->check_count("aten::add", 2)
+ ->check("aten::mul")
+ ->check("aten::size")
+ ->check("aten::add")
+ ->run(*grad_spec.f);
+
+ testing::FileCheck()
+ .check_count("prim::GradOf[name=\"aten::mul\"]", 1, /*exactly*/ true)
+ ->run(*grad_spec.df);
}
void testRegisterFusionCachesKernel(std::ostream& out = std::cout) {
ASSERT_EQ(second_key, expected_key);
}
-void testCreateAutodiffSubgraphs(std::ostream& out = std::cout) {
+void testCreateAutodiffSubgraphs() {
auto graph = build_lstm();
CreateAutodiffSubgraphs(graph, /*threshold=*/2);
- out << "testCreateAutodiffSubgraphs\n";
- out << *graph << "\n";
+ // all of the ops are within the DifferentiableGraph
+ testing::FileCheck()
+ .check_not("aten::mm")
+ ->check_not("aten::sigmoid")
+ ->check_not("aten::tanh")
+ ->check_not("aten::mul")
+ ->check("DifferentiableGraph")
+ ->check_next("return")
+ ->run(*graph);
}
void testSubgraphUtils() {
}
void testBlocks(std::ostream& out = std::cout) {
- Graph g;
- auto a = Var::asNewInput(g, "a");
- auto b = Var::asNewInput(g, "b");
+ auto g = std::make_shared<Graph>();
+ auto a = Var::asNewInput(*g, "a");
+ auto b = Var::asNewInput(*g, "b");
auto c = a + b;
- auto r = g.appendNode(g.create(prim::If, {Var::asNewInput(g, "c").value()}));
+ auto r =
+ g->appendNode(g->create(prim::If, {Var::asNewInput(*g, "c").value()}));
auto then_block = r->addBlock();
auto else_block = r->addBlock();
{
auto e = d + c;
else_block->registerOutput(e.value());
}
- g.registerOutput((Var(r->output()) + c).value());
- g.lint();
- out << "testBlocks\n" << g << "\n";
+ g->registerOutput((Var(r->output()) + c).value());
+ g->lint();
+ testing::FileCheck()
+ .check("add")
+ ->check("prim::If")
+ ->check("block0")
+ ->check("aten::add")
+ ->check("block1")
+ ->check_count("aten::add", 3)
+ ->run(*g);
r->eraseBlock(0);
- out << g << "\n";
- g.lint();
+ testing::FileCheck()
+ .check("add")
+ ->check("prim::If")
+ ->check("block0")
+ ->check_not("block")
+ ->run(*g);
+ g->lint();
// test recursive copy of blocks works
- auto g2 = g.copy();
- out << *g2 << "\n";
+ auto g2 = g->copy();
+ testing::FileCheck()
+ .check("add")
+ ->check("prim::If")
+ ->check("block0")
+ ->check_not("block")
+ ->run(*g2);
}
const auto cf_examples = R"JIT(
Symbol::fromQualString("alias::b"),
};
const auto expectedAfter = std::unordered_set<Symbol>{
- Symbol::fromQualString("alias::b"),
- Symbol::fromQualString("alias::c")
- };
+ Symbol::fromQualString("alias::b"), Symbol::fromQualString("alias::c")};
ASSERT_TRUE(containedAliasInfo.beforeSets() == expectedBefore);
ASSERT_TRUE(containedAliasInfo.afterSets() == expectedAfter);
ASSERT_FALSE(containedAliasInfo.isWrite());
// checking that constant propagation ran without failure
AT_ASSERT(std::distance(nodes.begin(), nodes.end()) == 1);
}
-
} // namespace
} // namespace jit
} // namespace torch
return op;
}
-void testNetDefConverter(std::ostream& out = std::cout) {
+void testNetDefConverter() {
{
// Check a simple net conversion back and forth.
+++ /dev/null
-testBlocks
-graph(%a : Tensor,
- %b : Tensor,
- %c : Tensor):
- %2 : int = prim::Constant[value=1]()
- %3 : Tensor = aten::add(%a, %b, %2)
- %5 : Tensor = prim::If(%c)
- block0():
- %6 : int = prim::Constant[value=1]()
- %7 : Tensor = aten::add(%3, %3, %6)
- -> (%7)
- block1():
- %8 : int = prim::Constant[value=1]()
- %9 : Tensor = aten::add(%b, %3, %8)
- %10 : int = prim::Constant[value=1]()
- %11 : Tensor = aten::add(%9, %3, %10)
- -> (%11)
- %12 : int = prim::Constant[value=1]()
- %13 : Tensor = aten::add(%5, %3, %12)
- return (%13)
-
-graph(%a : Tensor,
- %b : Tensor,
- %c : Tensor):
- %2 : int = prim::Constant[value=1]()
- %3 : Tensor = aten::add(%a, %b, %2)
- %5 : Tensor = prim::If(%c)
- block0():
- %8 : int = prim::Constant[value=1]()
- %9 : Tensor = aten::add(%b, %3, %8)
- %10 : int = prim::Constant[value=1]()
- %11 : Tensor = aten::add(%9, %3, %10)
- -> (%11)
- %12 : int = prim::Constant[value=1]()
- %13 : Tensor = aten::add(%5, %3, %12)
- return (%13)
-
-graph(%a : Tensor,
- %b : Tensor,
- %c : Tensor):
- %3 : int = prim::Constant[value=1]()
- %4 : Tensor = aten::add(%a, %b, %3)
- %5 : Tensor = prim::If(%c)
- block0():
- %6 : int = prim::Constant[value=1]()
- %7 : Tensor = aten::add(%b, %4, %6)
- %8 : int = prim::Constant[value=1]()
- %9 : Tensor = aten::add(%7, %4, %8)
- -> (%9)
- %10 : int = prim::Constant[value=1]()
- %11 : Tensor = aten::add(%5, %4, %10)
- return (%11)
-
-testCreateAutodiffSubgraphs
-graph(%0 : Tensor,
- %1 : Tensor,
- %2 : Tensor,
- %3 : Tensor,
- %4 : Tensor):
- %7 : int = prim::Constant[value=1]()
- %23 : Tensor, %24 : Tensor = prim::DifferentiableGraph_0(%2, %1, %4, %0, %3)
- return (%23, %24)
-with prim::DifferentiableGraph_0 = graph(%13 : Tensor,
- %32 : Tensor,
- %33 : Tensor,
- %35 : Tensor,
- %36 : Tensor):
- %37 : Tensor = aten::mm(%35, %36)
- %34 : Tensor = aten::mm(%32, %33)
- %30 : int = prim::Constant[value=1]()
- %31 : Tensor = aten::add(%37, %34, %30)
- %24 : Tensor, %25 : Tensor, %26 : Tensor, %27 : Tensor = prim::ConstantChunk[chunks=4, dim=1](%31)
- %22 : Tensor = aten::sigmoid(%24)
- %20 : Tensor = aten::sigmoid(%27)
- %18 : Tensor = aten::tanh(%26)
- %16 : Tensor = aten::sigmoid(%25)
- %14 : Tensor = aten::mul(%16, %13)
- %11 : Tensor = aten::mul(%22, %18)
- %8 : Tensor = aten::add(%14, %11, %30)
- %4 : Tensor = aten::tanh(%8)
- %2 : Tensor = aten::mul(%20, %4)
- return (%2, %8)
-
-testDifferentiate
-graph(%0 : Float(2, 3, 4),
- %1 : Float(2, 3, 4)):
- %2 : Float(2, 3, 4) = aten::mul(%0, %1)
- %3 : Float(2, 3, 4) = aten::mul(%2, %0)
- %4 : int = prim::Constant[value=1]()
- %7 : int[] = aten::size(%3)
- %5 : Float(2, 3, 4) = aten::add(%3, %1, %4)
- return (%5, %2, %7)
-graph(%0 : Float(2, 3, 4),
- %1 : Float(2, 3, 4),
- %2 : Float(2, 3, 4),
- %3 : Float(2, 3, 4),
- %4 : Float(2, 3, 4),
- %5 : int[]):
- %7 : int = prim::Constant[value=1]()
- %6 : int[] = aten::size(%3)
- %8 : Tensor, %9 : Tensor = prim::GradOf[name="aten::add"](%0)
- block0():
- %10 : Tensor = aten::_grad_sum_to_size(%0, %5)
- %11 : Float(2, 3, 4) = aten::mul(%0, %7)
- %12 : Tensor = aten::_grad_sum_to_size(%11, %6)
- -> (%10, %12)
- %grad_self.2 : Tensor, %grad_other.2 : Tensor = prim::GradOf[name="aten::mul"](%8)
- block0():
- %15 : Tensor = aten::mul(%8, %2)
- %16 : int[] = aten::size(%4)
- %grad_self.1 : Tensor = aten::_grad_sum_to_size(%15, %16)
- %18 : Tensor = aten::mul(%8, %4)
- %19 : int[] = aten::size(%2)
- %grad_other.1 : Tensor = aten::_grad_sum_to_size(%18, %19)
- -> (%grad_self.1, %grad_other.1)
- %21 : Tensor = prim::AutogradAdd(%1, %grad_self.2)
- %grad_self : Tensor, %grad_other : Tensor = prim::GradOf[name="aten::mul"](%21)
- block0():
- %24 : Tensor = aten::mul(%21, %3)
- %25 : int[] = aten::size(%2)
- %grad_self.3 : Tensor = aten::_grad_sum_to_size(%24, %25)
- %27 : Tensor = aten::mul(%21, %2)
- %28 : int[] = aten::size(%3)
- %grad_other.3 : Tensor = aten::_grad_sum_to_size(%27, %28)
- -> (%grad_self.3, %grad_other.3)
- %30 : Tensor = prim::AutogradAdd(%grad_other.2, %grad_self)
- %31 : Tensor = prim::AutogradAdd(%9, %grad_other)
- return (%30, %31)
-
-testDifferentiateWithRequiresGrad
-graph(%0 : Float(*),
- %1 : Float(*)):
- %2 : Float(*) = aten::mul(%1, %1)
- %3 : int = prim::Constant[value=1]()
- %4 : Float(*) = aten::add(%2, %1, %3)
- %6 : Float(*) = aten::add(%4, %0, %3)
- %7 : Float(*) = aten::mul(%6, %0)
- %11 : int[] = aten::size(%7)
- %9 : Float(*) = aten::add(%7, %1, %3)
- return (%4, %9, %6, %11)
-graph(%0 : Float(*),
- %1 : Float(*),
- %2 : Float(*),
- %3 : Float(*),
- %4 : int[]):
- %6 : int = prim::Constant[value=1]()
- %5 : int[] = aten::size(%2)
- %7 : Tensor = prim::GradOf[name="aten::add"](%0)
- block0():
- %8 : Tensor = aten::_grad_sum_to_size(%0, %4)
- -> (%8)
- %grad_self : Tensor, %grad_other : Tensor = prim::GradOf[name="aten::mul"](%7)
- block0():
- %11 : Tensor = aten::mul(%7, %2)
- %12 : int[] = aten::size(%3)
- %grad_self.1 : Tensor = aten::_grad_sum_to_size(%11, %12)
- %14 : Tensor = aten::mul(%7, %3)
- %15 : int[] = aten::size(%2)
- %grad_other.1 : Tensor = aten::_grad_sum_to_size(%14, %15)
- -> (%grad_self.1, %grad_other.1)
- %17 : Tensor = prim::AutogradAdd(%1, %grad_self)
- %18 : Tensor = prim::GradOf[name="aten::add"](%17)
- block0():
- %19 : Tensor = aten::mul(%17, %6)
- %20 : Tensor = aten::_grad_sum_to_size(%19, %5)
- -> (%20)
- %21 : Tensor = prim::AutogradAdd(%grad_other, %18)
- return (%21)
-
+++ /dev/null
-graph(%0 : Tensor,
- %1 : Tensor):
- %2 : int = prim::Constant[value=1]()
- %3 : Double(1) = aten::add(%0, %1, %2)
- %4 : Double(1) = aten::mul(%0, %3)
- %5 : Double(1) = aten::tanh(%4)
- %6 : Double(1) = aten::sigmoid(%5)
- %7 : Tensor = prim::TensorTest[a= 1 1 1 1 [ Variable[CPUDoubleType]{2,2} ]]()
- return (%6)
+++ /dev/null
-graph(%cond : Tensor):
- %1 : int[] = prim::Constant[value=[1]]()
- %2 : int[] = prim::Constant[value=[0]]()
- %3 : int = prim::Constant[value=3]()
- %4 : Float(2) = prim::Constant[value= 4 4 [ Variable[CPUFloatType]{2} ]]()
- %5 : Float(2) = prim::Constant[value= 1 1 [ Variable[CPUFloatType]{2} ]]()
- %c.1 : int = prim::Constant[value=0]()
- %a : int = prim::Constant[value=1]()
- %d : string = prim::Constant[value="abc"]()
- %e : string = prim::Constant[value="bcd"]()
- %10 : int = prim::Constant[value=6]()
- %11 : Device = prim::Constant[value="cpu"]()
- %12 : bool = prim::Bool(%cond)
- %c : int, %y : Tensor = prim::If(%12)
- block0():
- -> (%3, %4)
- block1():
- %y.2 : Tensor = aten::rand(%2, %10, %c.1, %11)
- %16 : bool = prim::Bool(%cond)
- %y.4 : Tensor = prim::If(%16)
- block0():
- %y.3 : Tensor = aten::rand(%1, %10, %c.1, %11)
- -> (%y.3)
- block1():
- -> (%y.2)
- = prim::Print(%d, %e, %d, %5, %y.4, %5)
- -> (%c.1, %y.4)
- %19 : (int, int, int, Tensor, Tensor) = prim::TupleConstruct(%a, %3, %c, %5, %y)
- return (%19)
+++ /dev/null
-graph(%a.1 : Dynamic,
- %b.1 : Dynamic):
- %d.1 : Long() = prim::Constant[value={3}]()
- %3 : Long() = prim::Constant[value={20}]()
- %4 : Byte() = prim::Constant[value={1}]()
- %a : Dynamic, %d : Long(), %b : Dynamic = prim::Loop(%3, %4, %a.1, %d.1, %b.1)
- block0(%_ : Dynamic, %6 : Dynamic, %10 : Long(), %14 : Dynamic):
- %7 : Long() = prim::Constant[value={10}]()
- %8 : Dynamic = aten::gt(%6, %7)
- %a.3 : Dynamic, %b.3 : Dynamic, %d.3 : Long() = prim::If(%8)
- block0():
- %9 : Long() = prim::Constant[value={3}]()
- %a.2 : Dynamic = aten::add[alpha={1}](%9, %10)
- -> (%a.2, %14, %10)
- block1():
- %12 : Long() = prim::Constant[value={3}]()
- %b.2 : Dynamic = aten::add[alpha={1}](%12, %10)
- %d.2 : Long() = prim::Constant[value={4}]()
- -> (%6, %b.2, %d.2)
- %20 : Byte() = prim::Constant[value={1}]()
- -> (%20, %a.3, %d.3, %b.3)
- return (%d)
+++ /dev/null
-graph(%x : Tensor):
- %1 : bool = prim::Constant[value=1]()
- %y.1 : int = prim::Constant[value=0]()
- %3 : int = prim::Constant[value=1]()
- %4 : int = prim::Int(%x)
- %5 : int = prim::Constant[value=8]()
- %6 : int = aten::__round_to_zero_floordiv(%4, %5)
- %7 : int = prim::Constant[value=8]()
- %8 : int = aten::mul(%6, %7)
- %9 : int = aten::sub(%4, %8)
- %y.3 : int = prim::Loop(%6, %1, %y.1)
- block0(%11 : int, %12 : int):
- %y.12 : int = aten::add(%12, %3)
- %y.5 : int = aten::add(%y.12, %3)
- %y.6 : int = aten::add(%y.5, %3)
- %y.7 : int = aten::add(%y.6, %3)
- %y.8 : int = aten::add(%y.7, %3)
- %y.9 : int = aten::add(%y.8, %3)
- %y.10 : int = aten::add(%y.9, %3)
- %y.11 : int = aten::add(%y.10, %3)
- -> (%1, %y.11)
- %y : int = prim::Loop(%9, %1, %y.3)
- block0(%22 : int, %23 : int):
- %y.4 : int = aten::add(%23, %3)
- -> (%1, %y.4)
- return (%y)
+++ /dev/null
-graph(%x : Tensor):
- %1 : int = prim::Constant[value=1]()
- %2 : int = prim::Constant[value=7]()
- %3 : Tensor = aten::add(%x, %2, %1)
- return (%3)
+++ /dev/null
-graph(%a : Tensor):
- %3 : string = prim::Constant[value="aa"]()
- %1 : string = prim::Constant[value="a\n\tb\n"]()
- %2 : int = prim::Constant[value=2]()
- = prim::Print(%a, %1, %2, %3)
- return (%a)
def test_cpp_cuda(self):
from cpp.jit import tests_setup
tests_setup.setup()
- # rather than rebuild assertExpected in cpp,
- # just glob all the cpp outputs into one file for now
- self.assertExpected(torch._C._jit_run_cpp_tests())
+ torch._C._jit_run_cpp_tests()
tests_setup.shutdown()
def test_batchnorm(self):
a")
return a
''')
- self.assertExpected(str(cu.foo.graph))
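+ # Check for the escaped string constants directly in the graph dump.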
+ FileCheck().check("aa").check("a\\n\\tb\\n").run(str(cu.foo.graph))
def test_string_ops(self):
def foo():
y2 = torch.sum(x, dim=0)
self.assertEqual(y, y2)
- def test_constant_pooling(self):
- def func(cond):
- a = 1
- b = 4
- c = 0
- d = "abc"
- e = "bcd"
- f = "abc"
- x = torch.ones([2])
- y = x * 4
- z = torch.ones([2])
- if bool(cond):
- c = b - a
- else:
- y = torch.rand(0)
- if bool(cond):
- y = torch.rand(1)
- print(d, e, f, x, y, z)
- b = b - a
- return a, b, c, x, y
-
- self.checkScript(func, torch.tensor([1]))
- graph = torch.jit.script(func).graph
- self.run_pass('constant_propagation', graph)
- self.run_pass('constant_pooling', graph)
- self.assertExpectedGraph(graph)
-
def test_constant_pooling_none(self):
@torch.jit.script
def typed_nones(a=None, b=None, c=None):
} // anonymous namespace
#if defined(_WIN32)
-std::string runJITCPPTests() {
+void runJITCPPTests() {
AT_ERROR("JIT tests not yet supported on Windows");
}
#else
-std::string runJITCPPTests();
+void runJITCPPTests();
#endif
void initJITBindings(PyObject* module) {
}
VarWithType IRParser::parseVarWithType() {
- L.expect('%');
VarWithType r;
- if (L.cur().kind == TK_IDENT) {
- r.name = L.expect(TK_IDENT).text();
- } else {
- r.name = L.expect(TK_NUMBER).text();
- }
+ r.name = parseVar();
r.type = TensorType::get();
if (L.nextIf(':')) {
auto type_alias = type_parser.parseType();
std::string IRParser::parseVar() {
L.expect('%');
if (L.cur().kind == TK_IDENT) {
- return L.expect(TK_IDENT).text();
+ auto name = L.expect(TK_IDENT).text();
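+ // Names like "b.1" are lexed as an identifier followed by a number token
+ // (".1"), so append the numeric suffix back onto the name.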
+ if (L.cur().kind == TK_NUMBER) {
+ auto suffix = L.expect(TK_NUMBER).text();
+ AT_ASSERT(suffix[0] == '.');
+ name += suffix;
+ }
+ return name;
+ } else {
+ return L.expect(TK_NUMBER).text();
}
- return L.expect(TK_NUMBER).text();
}
void IRParser::parseOperatorOutputs(std::vector<VarWithType>* outs) {
std::shared_ptr<Module> module_;
};
-
// defines how modules/methods behave inside the script subset.
// for now this does not have any interaction with python.
// in the future, we will add the ability to resolve `self.foo` to python
auto& p = parameters[i];
py::tuple r(2);
result[i] = std::make_tuple(
- p.key(),
- autograd::as_variable_ref(p->slot()->toTensor()));
+ p.key(), autograd::as_variable_ref(p->slot()->toTensor()));
}
return result;
})
py::tuple r(3);
IValue v = *buffer->slot();
result[i] = std::make_tuple(
- buffer.key(),
- buffer->type,
- toPyObject(std::move(v)));
+ buffer.key(), buffer->type, toPyObject(std::move(v)));
}
return result;
})
std::shared_ptr<Module> orig) {
std::vector<IValue*> member_inputs;
for (auto& p : params) {
- NamedIValue* np =
- std::get<0>(p)->find_parameter(std::get<1>(p));
+ NamedIValue* np = std::get<0>(p)->find_parameter(std::get<1>(p));
if (np == nullptr) {
np = std::get<0>(p)->find_buffer(std::get<1>(p));
}
}
Method* orig_method = orig->find_method(name);
- m->create_method(
- name, orig_method->graph()->copy(), member_inputs);
+ m->create_method(name, orig_method->graph()->copy(), member_inputs);
});
py::class_<Method>(m, "ScriptMethod", py::dynamic_attr())
py::arg("str"),
py::arg("count"),
py::arg("exactly") = false)
- .def("run", &testing::FileCheck::run);
+ .def(
+ "run",
+ [](testing::FileCheck& f, const std::string& str) {
+ return f.run(str);
+ })
+ .def("run", [](testing::FileCheck& f, const Graph& g) {
+ return f.run(g);
+ });
}
} // namespace script
} // namespace jit
#include <sstream>
#include <string>
-#include <torch/csrc/jit/passes/python_print.h>
+#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/testing/file_check.h>
namespace torch {
fcImpl->run(test_file);
};
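+// Dump the graph's IR to a string and run the checks against it.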
+void FileCheck::run(const Graph& graph) {
+ std::stringstream graph_str;
+ graph_str << graph;
+ fcImpl->run(graph_str.str());
+};
+
FileCheck* FileCheck::check(const std::string& str) {
fcImpl->addCheck(CHECK, str);
return this;
namespace torch {
namespace jit {
+
+struct Graph;
+
namespace testing {
struct FileCheckImpl;
// Run FileCheck against test string
TORCH_API void run(const std::string& test_string);
+ // Run FileCheck against dump of graph IR
+ TORCH_API void run(const Graph& graph);
+
// Checks that the string occurs, starting at the end of the most recent match
TORCH_API FileCheck* check(const std::string& str);