From dd96c26066fd8e31dc768002e207477c38f86b7a Mon Sep 17 00:00:00 2001 From: Mikhail Zolotukhin Date: Tue, 24 Aug 2021 00:29:22 -0700 Subject: [PATCH] [TensorExpr] More NFC changes like Expr* -> ExprPtr. (#63778) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63778 This is a preparation for a switch from raw pointers to shared pointers as a memory model for TE expressions and statements. Test Plan: Imported from OSS Reviewed By: navahgar Differential Revision: D30487425 Pulled By: ZolotukhinM fbshipit-source-id: 9cbe817b7d4e5fc2f150b29bb9b3bf578868f20c --- benchmarks/cpp/tensorexpr/bench_approx.cpp | 26 +++---- benchmarks/cpp/tensorexpr/bench_batchnorm.cpp | 4 +- benchmarks/cpp/tensorexpr/bench_compile.cpp | 4 +- benchmarks/cpp/tensorexpr/bench_concat.cpp | 42 +++++++----- benchmarks/cpp/tensorexpr/bench_gemm.cpp | 80 +++++++++++----------- benchmarks/cpp/tensorexpr/bench_parallel.cpp | 4 +- benchmarks/cpp/tensorexpr/bench_reduce.cpp | 32 ++++----- test/cpp/tensorexpr/test_llvm.cpp | 2 +- test/cpp/tensorexpr/test_loopnest.cpp | 70 +++++++++---------- test/cpp/tensorexpr/test_memdependency.cpp | 6 +- test/cpp/tensorexpr/test_reductions.cpp | 24 +++---- test/cpp/tensorexpr/tutorial.cpp | 4 +- torch/csrc/jit/tensorexpr/cuda_codegen.cpp | 29 ++++---- torch/csrc/jit/tensorexpr/half_support.h | 3 +- torch/csrc/jit/tensorexpr/mem_dependency_checker.h | 4 +- torch/csrc/jit/tensorexpr/operators/reduction.h | 6 +- 16 files changed, 172 insertions(+), 168 deletions(-) diff --git a/benchmarks/cpp/tensorexpr/bench_approx.cpp b/benchmarks/cpp/tensorexpr/bench_approx.cpp index 1f09b1d..6e31697 100644 --- a/benchmarks/cpp/tensorexpr/bench_approx.cpp +++ b/benchmarks/cpp/tensorexpr/bench_approx.cpp @@ -12,19 +12,19 @@ using namespace torch::jit::tensorexpr; void vectorize(tensorexpr::LoopNest* ln, tensorexpr::Tensor* target, int width) { auto loops = ln->getLoopStmtsFor(target); - For *inner, *tail; + ForPtr inner, tail; ln->splitWithTail(loops[0], width, &inner, &tail); ln->vectorize(inner); } void optimizePointwise(tensorexpr::LoopNest* ln, tensorexpr::Tensor* target) { - std::vector loops = ln->getLoopStmtsFor(target); - For *inner, *tail; + std::vector loops = ln->getLoopStmtsFor(target); + ForPtr inner, tail; ln->splitWithTail(loops[0], 16 * 8, &inner, &tail); - For* outer = loops[0]; + ForPtr outer = loops[0]; ln->vectorize(inner); ln->splitWithTail(outer, 8, &inner, &tail); - Stmt* unrolled; + StmtPtr unrolled; LoopNest::unroll(inner, &unrolled); } @@ -44,7 +44,7 @@ static void relu_nnc(benchmark::State& state) { LoopNest ln({B}); optimizePointwise(&ln, B); ln.prepareForCodegen(); - Stmt* s = ln.root_stmt(); + StmtPtr s = ln.root_stmt(); s = torch::jit::tensorexpr::IRSimplifier::simplify(s); std::vector args; args.emplace_back(B); @@ -74,7 +74,7 @@ static void log_nnc_sleef(benchmark::State& state) { LoopNest ln({B}); ln.prepareForCodegen(); vectorize(&ln, B, 8); - Stmt* s = ln.root_stmt(); + StmtPtr s = ln.root_stmt(); s = torch::jit::tensorexpr::IRSimplifier::simplify(s); std::vector args; args.emplace_back(B); @@ -104,7 +104,7 @@ static void log_nnc_fast(benchmark::State& state) { LoopNest ln({B}); optimizePointwise(&ln, B); ln.prepareForCodegen(); - Stmt* s = ln.root_stmt(); + StmtPtr s = ln.root_stmt(); s = torch::jit::tensorexpr::IRSimplifier::simplify(s); std::vector args; args.emplace_back(B); @@ -134,7 +134,7 @@ static void log_nnc_vml(benchmark::State& state) { LoopNest ln({B}); vectorize(&ln, B, 8); ln.prepareForCodegen(); - Stmt* s = ln.root_stmt(); + StmtPtr s = 
ln.root_stmt(); s = torch::jit::tensorexpr::IRSimplifier::simplify(s); std::vector args; args.emplace_back(B); @@ -181,7 +181,7 @@ static void logit_nnc_sleef(benchmark::State& state) { LoopNest ln({B}); ln.prepareForCodegen(); optimizePointwise(&ln, B); - Stmt* s = ln.root_stmt(); + StmtPtr s = ln.root_stmt(); s = torch::jit::tensorexpr::IRSimplifier::simplify(s); std::vector args; args.emplace_back(B); @@ -218,7 +218,7 @@ static void logit_nnc_fast(benchmark::State& state) { LoopNest ln({B}); ln.prepareForCodegen(); optimizePointwise(&ln, B); - Stmt* s = ln.root_stmt(); + StmtPtr s = ln.root_stmt(); s = torch::jit::tensorexpr::IRSimplifier::simplify(s); std::vector args; args.emplace_back(B); @@ -255,7 +255,7 @@ static void logit_nnc_vml(benchmark::State& state) { LoopNest ln({B}); ln.prepareForCodegen(); vectorize(&ln, B, 16); - Stmt* s = ln.root_stmt(); + StmtPtr s = ln.root_stmt(); s = torch::jit::tensorexpr::IRSimplifier::simplify(s); std::vector args; args.emplace_back(B); @@ -326,7 +326,7 @@ static void tanh_nnc_fast(benchmark::State& state) { LoopNest ln({B}); optimizePointwise(&ln, B); ln.prepareForCodegen(); - Stmt* s = ln.root_stmt(); + StmtPtr s = ln.root_stmt(); s = torch::jit::tensorexpr::IRSimplifier::simplify(s); std::vector args; args.emplace_back(B); diff --git a/benchmarks/cpp/tensorexpr/bench_batchnorm.cpp b/benchmarks/cpp/tensorexpr/bench_batchnorm.cpp index 434cd6b..872594e 100644 --- a/benchmarks/cpp/tensorexpr/bench_batchnorm.cpp +++ b/benchmarks/cpp/tensorexpr/bench_batchnorm.cpp @@ -105,7 +105,7 @@ BENCHMARK_DEFINE_F(BatchNorm, NNC)(benchmark::State& state) { loops = nest.getLoopStmtsFor(output); loops[0]->set_parallel(); nest.prepareForCodegen(); - Stmt* s = IRSimplifier::simplify(nest.root_stmt()); + StmtPtr s = IRSimplifier::simplify(nest.root_stmt()); LLVMCodeGen cg(s, {input, weight, bias, mean, var, output, eps}); std::vector args; @@ -163,7 +163,7 @@ BENCHMARK_DEFINE_F(BatchNorm, NNCRelu)(benchmark::State& state) { }); LoopNest nest({output}); nest.prepareForCodegen(); - Stmt* s = IRSimplifier::simplify(nest.root_stmt()); + StmtPtr s = IRSimplifier::simplify(nest.root_stmt()); LLVMCodeGen cg(s, {input, weight, bias, mean, var, output, eps}); std::vector args; diff --git a/benchmarks/cpp/tensorexpr/bench_compile.cpp b/benchmarks/cpp/tensorexpr/bench_compile.cpp index cc84e65..245d5d8 100644 --- a/benchmarks/cpp/tensorexpr/bench_compile.cpp +++ b/benchmarks/cpp/tensorexpr/bench_compile.cpp @@ -33,7 +33,7 @@ static void BM_CompileSwish(benchmark::State& state) { nest.computeInline(tensor->buf()); } nest.prepareForCodegen(); - te::Stmt* s = te::IRSimplifier::simplify(nest.root_stmt()); + te::StmtPtr s = te::IRSimplifier::simplify(nest.root_stmt()); te::LLVMCodeGen cg(s, {A, sixth, n}); } } @@ -63,7 +63,7 @@ static void BM_CompileSwishLLVMOnly(benchmark::State& state) { nest.computeInline(tensor->buf()); } nest.prepareForCodegen(); - te::Stmt* s = te::IRSimplifier::simplify(nest.root_stmt()); + te::StmtPtr s = te::IRSimplifier::simplify(nest.root_stmt()); for (auto _ : state) { te::LLVMCodeGen cg(s, {A, sixth, n}); } diff --git a/benchmarks/cpp/tensorexpr/bench_concat.cpp b/benchmarks/cpp/tensorexpr/bench_concat.cpp index a437967..cb9aa84 100644 --- a/benchmarks/cpp/tensorexpr/bench_concat.cpp +++ b/benchmarks/cpp/tensorexpr/bench_concat.cpp @@ -83,7 +83,7 @@ class ConcatBench : public benchmark::Fixture { }); LoopNest nest({output}); nest.prepareForCodegen(); - Stmt* s = IRSimplifier::simplify(nest.root_stmt()); + StmtPtr s = 
IRSimplifier::simplify(nest.root_stmt()); std::vector buf_args(inputs.begin(), inputs.end()); buf_args.push_back(output); LLVMCodeGen cg(s, buf_args); @@ -108,47 +108,51 @@ class ConcatBench : public benchmark::Fixture { TORCH_INTERNAL_ASSERT(concat_dim_ == 1); - auto output_buf = new Buf( - new Var("aten_cat", kHandle), - {new IntImm(output_size_[0]), new IntImm(output_size_[1])}, + auto output_buf = alloc( + alloc("aten_cat", kHandle), + std::vector( + {alloc(output_size_[0]), alloc(output_size_[1])}), kFloat); std::vector inputs; - std::vector for_stmts(num_inputs); + std::vector for_stmts(num_inputs); int cumulative_input_sizes = 0; for (size_t i = 0; i < num_inputs; ++i) { inputs.emplace_back(Placeholder( "input" + std::to_string(i), kFloat, {input_sizes_[i][0], input_sizes_[i][1]})); - std::vector for_vars(num_inputs); + std::vector for_vars(num_inputs); for (size_t d = 0; d < num_dims; ++d) { for_vars[d] = - new Var("i" + std::to_string(i) + "_" + std::to_string(d), kInt); + alloc("i" + std::to_string(i) + "_" + std::to_string(d), kInt); } - auto store = new Store( + auto store = alloc( output_buf, - {for_vars[0], - new Add(for_vars[1], new IntImm(cumulative_input_sizes))}, - new Load(inputs[i].data(), {for_vars[0], for_vars[1]})); - auto for_st = new For( + std::vector( + {for_vars[0], + alloc(for_vars[1], alloc(cumulative_input_sizes))}), + alloc( + inputs[i].data(), + std::vector({for_vars[0], for_vars[1]}))); + auto for_st = alloc( for_vars[0], - new IntImm(0), - new IntImm(input_sizes_[i][0]), - new For( + alloc(0), + alloc(input_sizes_[i][0]), + alloc( for_vars[1], - new IntImm(0), - new IntImm(input_sizes_[i][1]), + alloc(0), + alloc(input_sizes_[i][1]), store)); for_stmts[i] = for_st; cumulative_input_sizes += input_sizes_[i][1]; } - auto output = new Tensor(output_buf, new Block(for_stmts)); + auto output = new Tensor(output_buf, alloc(for_stmts)); LoopNest nest({output}); nest.prepareForCodegen(); nest.vectorizeInnerLoops(); - Stmt* s = IRSimplifier::simplify(nest.root_stmt()); + StmtPtr s = IRSimplifier::simplify(nest.root_stmt()); std::vector buf_args(inputs.begin(), inputs.end()); buf_args.push_back(output); LLVMCodeGen cg(s, buf_args); diff --git a/benchmarks/cpp/tensorexpr/bench_gemm.cpp b/benchmarks/cpp/tensorexpr/bench_gemm.cpp index 792d457..7ebaa87 100644 --- a/benchmarks/cpp/tensorexpr/bench_gemm.cpp +++ b/benchmarks/cpp/tensorexpr/bench_gemm.cpp @@ -54,7 +54,7 @@ BENCHMARK_DEFINE_F(Gemm, TensorExprNoopt)(benchmark::State& state) { {{K, "K"}}); te::LoopNest loop({CT}); loop.prepareForCodegen(); - te::Stmt* s = loop.root_stmt(); + te::StmtPtr s = loop.root_stmt(); s = te::IRSimplifier::simplify(s); auto cg = CreateCodeGen("llvm_codegen", s, {AP, BP, CT}); @@ -80,41 +80,41 @@ BENCHMARK_DEFINE_F(Gemm, TensorExprTile32x32)(benchmark::State& state) { { auto const& loops = loop.getLoopStmtsFor(CT); - te::For* m = loops[0]; + te::ForPtr m = loops[0]; loop.splitWithMask(m, 32); } { auto const& loops = loop.getLoopStmtsFor(CT); - te::For* n = loops[2]; + te::ForPtr n = loops[2]; loop.splitWithMask(n, 32); } // mo, mi, no, ni, k -> // mo, no, mi, ni, k { auto const& loops = loop.getLoopStmtsFor(CT); - te::For* mi = loops[1]; - te::For* no = loops[2]; + te::ForPtr mi = loops[1]; + te::ForPtr no = loops[2]; loop.reorderAxis(mi, no); } // mo, no, mi, ni, k -> // mo, no, mi, k, ni { auto const& loops = loop.getLoopStmtsFor(CT); - te::For* ni = loops[3]; - te::For* k = loops[4]; + te::ForPtr ni = loops[3]; + te::ForPtr k = loops[4]; loop.reorderAxis(ni, k); } // mo, no, mi, k, 
ni -> // mo, no, k, mi, ni { auto const& loops = loop.getLoopStmtsFor(CT); - te::For* mi = loops[2]; - te::For* k = loops[3]; + te::ForPtr mi = loops[2]; + te::ForPtr k = loops[3]; loop.reorderAxis(mi, k); } loop.prepareForCodegen(); - te::Stmt* s = loop.root_stmt(); + te::StmtPtr s = loop.root_stmt(); s = te::IRSimplifier::simplify(s); auto cg = CreateCodeGen("llvm_codegen", s, {AP, BP, CT}); @@ -140,41 +140,41 @@ BENCHMARK_DEFINE_F(Gemm, TensorExprTile4x16)(benchmark::State& state) { { auto const& loops = loop.getLoopStmtsFor(CT); - te::For* m = loops[0]; + te::ForPtr m = loops[0]; loop.splitWithMask(m, 4); } { auto const& loops = loop.getLoopStmtsFor(CT); - te::For* n = loops[2]; + te::ForPtr n = loops[2]; loop.splitWithMask(n, 16); } // mo, mi, no, ni, k -> // mo, no, mi, ni, k { auto const& loops = loop.getLoopStmtsFor(CT); - te::For* mi = loops[1]; - te::For* no = loops[2]; + te::ForPtr mi = loops[1]; + te::ForPtr no = loops[2]; loop.reorderAxis(mi, no); } // mo, no, mi, ni, k -> // mo, no, mi, k, ni { auto const& loops = loop.getLoopStmtsFor(CT); - te::For* ni = loops[3]; - te::For* k = loops[4]; + te::ForPtr ni = loops[3]; + te::ForPtr k = loops[4]; loop.reorderAxis(ni, k); } // mo, no, mi, k, ni -> // mo, no, k, mi, ni { auto const& loops = loop.getLoopStmtsFor(CT); - te::For* mi = loops[2]; - te::For* k = loops[3]; + te::ForPtr mi = loops[2]; + te::ForPtr k = loops[3]; loop.reorderAxis(mi, k); } loop.prepareForCodegen(); - te::Stmt* s = loop.root_stmt(); + te::StmtPtr s = loop.root_stmt(); s = te::IRSimplifier::simplify(s); auto cg = CreateCodeGen("llvm_codegen", s, {AP, BP, CT}); @@ -200,49 +200,49 @@ BENCHMARK_DEFINE_F(Gemm, TensorExprTile4x16VecUnroll)(benchmark::State& state) { { auto const& loops = loop.getLoopStmtsFor(CT); - te::For* m = loops[0]; + te::ForPtr m = loops[0]; loop.splitWithMask(m, 4); } { auto const& loops = loop.getLoopStmtsFor(CT); - te::For* n = loops[2]; + te::ForPtr n = loops[2]; loop.splitWithMask(n, 16); } // mo, mi, no, ni, k -> // mo, no, mi, ni, k { auto const& loops = loop.getLoopStmtsFor(CT); - te::For* mi = loops[1]; - te::For* no = loops[2]; + te::ForPtr mi = loops[1]; + te::ForPtr no = loops[2]; loop.reorderAxis(mi, no); } // mo, no, mi, ni, k -> // mo, no, mi, k, ni { auto const& loops = loop.getLoopStmtsFor(CT); - te::For* ni = loops[3]; - te::For* k = loops[4]; + te::ForPtr ni = loops[3]; + te::ForPtr k = loops[4]; loop.reorderAxis(ni, k); } // mo, no, mi, k, ni -> // mo, no, k, mi, ni { auto const& loops = loop.getLoopStmtsFor(CT); - te::For* mi = loops[2]; - te::For* k = loops[3]; + te::ForPtr mi = loops[2]; + te::ForPtr k = loops[3]; loop.reorderAxis(mi, k); } { auto const& loops = loop.getLoopStmtsFor(CT); - te::For* mi = loops[3]; - te::For* ni = loops[4]; - te::Stmt* unrolled; + te::ForPtr mi = loops[3]; + te::ForPtr ni = loops[4]; + te::StmtPtr unrolled; loop.vectorize(ni); loop.unroll(mi, &unrolled); } loop.prepareForCodegen(); - te::Stmt* s = loop.root_stmt(); + te::StmtPtr s = loop.root_stmt(); s = te::IRSimplifier::simplify(s); auto cg = CreateCodeGen("llvm_codegen", s, {AP, BP, CT}); @@ -268,36 +268,36 @@ BENCHMARK_DEFINE_F(Gemm, TensorExprTile4x16Cache)(benchmark::State& state) { { auto const& loops = loop.getLoopStmtsFor(CT); - te::For* m = loops[0]; + te::ForPtr m = loops[0]; loop.splitWithMask(m, 4); } { auto const& loops = loop.getLoopStmtsFor(CT); - te::For* n = loops[2]; + te::ForPtr n = loops[2]; loop.splitWithMask(n, 16); } // mo, mi, no, ni, k -> // mo, no, mi, ni, k { auto const& loops = loop.getLoopStmtsFor(CT); - 
te::For* mi = loops[1]; - te::For* no = loops[2]; + te::ForPtr mi = loops[1]; + te::ForPtr no = loops[2]; loop.reorderAxis(mi, no); } // mo, no, mi, ni, k -> // mo, no, mi, k, ni { auto const& loops = loop.getLoopStmtsFor(CT); - te::For* ni = loops[3]; - te::For* k = loops[4]; + te::ForPtr ni = loops[3]; + te::ForPtr k = loops[4]; loop.reorderAxis(ni, k); } // mo, no, mi, k, ni -> // mo, no, k, mi, ni { auto const& loops = loop.getLoopStmtsFor(CT); - te::For* mi = loops[2]; - te::For* k = loops[3]; + te::ForPtr mi = loops[2]; + te::ForPtr k = loops[3]; loop.reorderAxis(mi, k); } { @@ -306,7 +306,7 @@ BENCHMARK_DEFINE_F(Gemm, TensorExprTile4x16Cache)(benchmark::State& state) { } loop.prepareForCodegen(); - te::Stmt* s = loop.root_stmt(); + te::StmtPtr s = loop.root_stmt(); s = te::IRSimplifier::simplify(s); auto cg = CreateCodeGen("llvm_codegen", s, {AP, BP, CT}); diff --git a/benchmarks/cpp/tensorexpr/bench_parallel.cpp b/benchmarks/cpp/tensorexpr/bench_parallel.cpp index fee326c..966c9e2 100644 --- a/benchmarks/cpp/tensorexpr/bench_parallel.cpp +++ b/benchmarks/cpp/tensorexpr/bench_parallel.cpp @@ -44,10 +44,10 @@ BENCHMARK_DEFINE_F(ParallelAdd, Simple)(benchmark::State& state) { }); LoopNest loop_nest({c_tensor}); auto const& loops = loop_nest.getLoopStmtsFor(c_tensor); - For* m = loops[0]; + ForPtr m = loops[0]; m->set_parallel(); loop_nest.prepareForCodegen(); - Stmt* stmt = loop_nest.root_stmt(); + StmtPtr stmt = loop_nest.root_stmt(); LLVMCodeGen cg(stmt, {c_tensor, a_buf, b_buf}); float* a_ptr = A.data_ptr(); diff --git a/benchmarks/cpp/tensorexpr/bench_reduce.cpp b/benchmarks/cpp/tensorexpr/bench_reduce.cpp index acd46ac..be5dcc8 100644 --- a/benchmarks/cpp/tensorexpr/bench_reduce.cpp +++ b/benchmarks/cpp/tensorexpr/bench_reduce.cpp @@ -233,7 +233,7 @@ BENCHMARK_DEFINE_F(Reduce1D, TeNaive)(benchmark::State& state) { te::LoopNest loop({BT}); loop.prepareForCodegen(); - te::Stmt* s = loop.root_stmt(); + te::StmtPtr s = loop.root_stmt(); s = te::IRSimplifier::simplify(s); auto cg = CreateCodeGen("llvm_codegen", s, {AP, BT}); @@ -269,12 +269,12 @@ BENCHMARK_DEFINE_F(Reduce1D, TeSplitTail)(benchmark::State& state) { { auto const& loops = loop.getLoopStmtsFor(BT); - te::For* m = loops[1]; + te::ForPtr m = loops[1]; loop.splitWithTail(m, kChunkSize); } loop.prepareForCodegen(); - te::Stmt* s = loop.root_stmt(); + te::StmtPtr s = loop.root_stmt(); s = te::IRSimplifier::simplify(s); auto cg = CreateCodeGen("llvm_codegen", s, {AP, BT}); @@ -310,12 +310,12 @@ BENCHMARK_DEFINE_F(Reduce1D, TeSplitMask)(benchmark::State& state) { { auto const& loops = loop.getLoopStmtsFor(BT); - te::For* m = loops[1]; + te::ForPtr m = loops[1]; loop.splitWithMask(m, kChunkSize); } loop.prepareForCodegen(); - te::Stmt* s = loop.root_stmt(); + te::StmtPtr s = loop.root_stmt(); s = te::IRSimplifier::simplify(s); auto cg = CreateCodeGen("llvm_codegen", s, {AP, BT}); @@ -349,17 +349,17 @@ BENCHMARK_DEFINE_F(Reduce1D, TeRfactorV1)(benchmark::State& state) { {{M, "M"}}); te::LoopNest loop({BT}); - te::Buf* rfac_buf; + te::BufPtr rfac_buf; auto loops = loop.getLoopStmtsFor(BT); TORCH_CHECK(loops.size() == 1); - te::For* mi; + te::ForPtr mi; loop.splitWithMask(loops.at(0), kChunkSize, &mi); - te::For* mo = loops.at(0); + te::ForPtr mo = loops.at(0); loop.reorderAxis(mo, mi); loops = loop.getLoopStmtsFor(BT); - auto bt_body = const_cast(loop.getAllWritesToBuf(BT->buf())[1]); + auto bt_body = loop.getAllWritesToBuf(BT->buf())[1]; TORCH_CHECK(loop.rfactor(bt_body, loops.at(0), &rfac_buf)); loop.reorderAxis(loops.at(0), 
loops.at(1)); @@ -368,7 +368,7 @@ BENCHMARK_DEFINE_F(Reduce1D, TeRfactorV1)(benchmark::State& state) { loop.vectorize(loops.at(1)); loop.prepareForCodegen(); - te::Stmt* s = loop.root_stmt(); + te::StmtPtr s = loop.root_stmt(); s = te::IRSimplifier::simplify(s); auto cg = CreateCodeGen("llvm_codegen", s, {AP, BT}); @@ -394,8 +394,8 @@ BENCHMARK_DEFINE_F(Reduce1D, Op)(benchmark::State& state) { te::LoopNest nest({b}); auto loops = nest.getLoopStmtsFor(b); - te::For *mi, *mo; - te::Buf *rf; + te::ForPtr mi, mo; + te::BufPtr rf; nest.splitWithMask(loops[0], kChunkSize, &mi); loops = nest.reorder({loops[0], mi}, {1, 0}); nest.rfactor(nest.getLoopBodyFor(b), loops[0], &rf); @@ -566,8 +566,8 @@ BENCHMARK_DEFINE_F(Reduce2DRow, OpSchedule)(benchmark::State& state) { auto sch = state.range(2); if (sch == 1) { auto loops = nest.getLoopStmtsFor(b); - te::For *mi, *mo; - te::Buf *rf; + te::ForPtr mi, mo; + te::BufPtr rf; nest.splitWithMask(loops[1], kChunkSize, &mi); loops = nest.reorder({loops[1], mi}, {1, 0}); TORCH_CHECK(nest.rfactor(nest.getLoopBodyFor(b), loops[0], &rf)); @@ -583,8 +583,8 @@ BENCHMARK_DEFINE_F(Reduce2DRow, OpSchedule)(benchmark::State& state) { nest.reorderAxis(loops[1], loops[2]); } else if (sch == 3) { auto loops = nest.getLoopStmtsFor(b); - te::For *mi, *mo; - te::Buf *rf; + te::ForPtr mi, mo; + te::BufPtr rf; nest.splitWithMask(loops[1], kChunkSize, &mi); loops = nest.reorder({loops[1], mi}, {1, 0}); TORCH_CHECK(nest.rfactor(nest.getLoopBodyFor(b), loops[0], &rf)); diff --git a/test/cpp/tensorexpr/test_llvm.cpp b/test/cpp/tensorexpr/test_llvm.cpp index 3776329..75e6a06 100644 --- a/test/cpp/tensorexpr/test_llvm.cpp +++ b/test/cpp/tensorexpr/test_llvm.cpp @@ -1642,7 +1642,7 @@ TEST(LLVM, CompositeParallel) { [=](const VarHandle& m, const VarHandle& n) { return t3->load(m, n) + m + n; }); - LoopNest loop_nest({t4}, {t1, t2, t3, t4}); + LoopNest loop_nest(std::vector({t4}), {t1, t2, t3, t4}); std::vector loop_list; { auto const& loops = loop_nest.getLoopStmtsFor(t1); diff --git a/test/cpp/tensorexpr/test_loopnest.cpp b/test/cpp/tensorexpr/test_loopnest.cpp index 898ee52..c80dd5f 100644 --- a/test/cpp/tensorexpr/test_loopnest.cpp +++ b/test/cpp/tensorexpr/test_loopnest.cpp @@ -1011,7 +1011,7 @@ TEST(LoopNest, ScheduleFunctionCall01) { return c->load(m, n, k) + 1; }); - LoopNest l({d}, {c, d}); + LoopNest l(std::vector({d}), {c, d}); l.prepareForCodegen(); StmtPtr stmt = l.root_stmt(); std::ostringstream oss; @@ -1071,7 +1071,7 @@ TEST(LoopNest, ScheduleInlineSimple) { return c_buf.load(m, n) * d_buf.load(m, k) + x->load(m, n, k); }); - LoopNest l1({y}, {x, y}); + LoopNest l1(std::vector({y}), {x, y}); LoopNest l2(l1); l2.computeInline(x->buf()); @@ -1158,7 +1158,7 @@ void InlineFunc01Helper(const std::vector& inline_order) { return x->load(m, n, k) + y->load(m, n, k); }); - LoopNest l({z}, {x, y, z}); + LoopNest l(std::vector({z}), {x, y, z}); for (const std::string& order : inline_order) { if (order == "x") { l.computeInline(x->buf()); @@ -1267,7 +1267,7 @@ TEST(LoopNest, ScheduleInlineRandom) { return x->load(m, n, k) + x->load(m, n, k); }); - LoopNest l1({y}, {x, y}); + LoopNest l1(std::vector({y}), {x, y}); l1.computeInline(x->buf()); // would normally compare results but Rand isn't implemented in the @@ -1304,7 +1304,7 @@ TEST(LoopNest, ScheduleInlineRandomUnrelated) { Intrinsics::make(kRand, kInt); }); - LoopNest l1({y}, {x, y}); + LoopNest l1(std::vector({y}), {x, y}); l1.computeInline(x->buf()); // would normally compare results but Rand isn't implemented in the @@ 
-1337,7 +1337,7 @@ TEST(LoopNest, ScheduleInlineRandomLowerDimensions) { return x->load(m) + x->load(m); }); - LoopNest l1({y}, {x, y}); + LoopNest l1(std::vector({y}), {x, y}); l1.computeInline(x->buf()); // would normally compare results but Rand isn't implemented in the @@ -1389,7 +1389,7 @@ TEST(LoopNest, ScheduleInlineIntrinsics) { } } - LoopNest l1({y}, {x, y}); + LoopNest l1(std::vector({y}), {x, y}); LoopNest l2(l1); l2.computeInline(x->buf()); @@ -1434,7 +1434,7 @@ TEST(LoopNest, ScheduleInlineRandWithIntrinsics) { return Intrinsics::make(kSqrt, x->load(m, n, k)); }); - LoopNest l1({y}, {x, y}); + LoopNest l1(std::vector({y}), {x, y}); l1.computeInline(x->buf()); StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt()); @@ -1457,7 +1457,7 @@ TEST(LoopNest, ScheduleSplitAThenInline) { return a->load(j + ExprHandle(8)); }); - LoopNest l({b}, {a, b}); + LoopNest l(std::vector({b}), {a, b}); std::vector loops = l.getAllLoopNestsWritingToBuf(a->buf()).at(0); LoopNest::splitWithMask(loops[0], 4); ASSERT_THROWS_WITH(l.computeInline(a->buf()), "compound indices"); @@ -1472,7 +1472,7 @@ TEST(LoopNest, ScheduleSplitBThenInline) { return a->load(j + ExprHandle(8)); }); - LoopNest l({b}, {a, b}); + LoopNest l(std::vector({b}), {a, b}); std::vector loops = l.getAllLoopNestsWritingToBuf(b->buf()).at(0); LoopNest::splitWithMask(loops[0], 3); l.computeInline(a->buf()); @@ -1499,7 +1499,7 @@ TEST(LoopNest, ScheduleSplitTwiceThenInline) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr i_inner; - LoopNest l({b}, {a, b}); + LoopNest l(std::vector({b}), {a, b}); std::vector loops = l.getAllLoopNestsWritingToBuf(a->buf()).at(0); LoopNest::splitWithMask(loops[0], 4, &i_inner); LoopNest::splitWithMask(i_inner, 2); @@ -1515,7 +1515,7 @@ TEST(LoopNest, ScheduleInlineThenSplit) { return a->load(j + ExprHandle(8)); }); - LoopNest l({b}, {a, b}); + LoopNest l(std::vector({b}), {a, b}); l.computeInline(a->buf()); std::vector loops = NodeFinder::find(l.root_stmt()); @@ -1540,7 +1540,7 @@ TEST(LoopNest, ScheduleSplitInlineThenSplit) { return a->load(j + ExprHandle(8)); }); - LoopNest l({b}, {a, b}); + LoopNest l(std::vector({b}), {a, b}); auto loops = NodeFinder::find(l.root_stmt()); LoopNest::splitWithMask(loops.back(), 2); l.computeInline(a->buf()); @@ -1568,7 +1568,7 @@ TEST(LoopNest, ScheduleSplitInlineSimplify) { return a->load(j) - ExprHandle(1); }); - LoopNest l({b}, {a, b}); + LoopNest l(std::vector({b}), {a, b}); std::vector loops = l.getAllLoopNestsWritingToBuf(a->buf()).at(0); LoopNest::splitWithMask(loops[0], 4); ASSERT_THROWS_WITH(l.computeInline(a->buf()), "compound indices"); @@ -1587,7 +1587,7 @@ TEST(LoopNest, ScheduleInlineThreeMixedOnce) { return a->load(k) * b->load(l); }); - LoopNest l({c}, {a, b, c}); + LoopNest l(std::vector({c}), {a, b, c}); std::vector loops = l.getAllLoopNestsWritingToBuf(a->buf()).at(0); l.computeInline(a->buf()); l.prepareForCodegen(); @@ -1617,7 +1617,7 @@ TEST(LoopNest, ScheduleInlineThreeMixedTwice) { return a->load(k) * b->load(l); }); - LoopNest l({c}, {a, b, c}); + LoopNest l(std::vector({c}), {a, b, c}); std::vector loops = l.getAllLoopNestsWritingToBuf(a->buf()).at(0); l.computeInline(a->buf()); l.computeInline(b->buf()); @@ -1648,7 +1648,7 @@ TEST(LoopNest, ScheduleInlineThreeMixedInner) { return a->load(k) * b->load(l); }); - LoopNest l({c}, {a, b, c}); + LoopNest l(std::vector({c}), {a, b, c}); std::vector loops = l.getAllLoopNestsWritingToBuf(a->buf()).at(0); l.computeInline(b->buf()); l.prepareForCodegen(); @@ -1678,7 +1678,7 @@ TEST(LoopNest, 
ScheduleInlineThreeMixedSplit) { return a->load(k) * b->load(l); }); - LoopNest l({c}, {a, b, c}); + LoopNest l(std::vector({c}), {a, b, c}); std::vector loops = l.getAllLoopNestsWritingToBuf(a->buf()).at(0); LoopNest::splitWithMask(loops[0], 4); loops = l.getAllLoopNestsWritingToBuf(b->buf()).at(0); @@ -1782,7 +1782,7 @@ TEST(LoopNest, ScheduleFuserThreeArg) { return f->load(i) + d.load(i); }); - LoopNest l({g}, {e, f, g}); + LoopNest l(std::vector({g}), {e, f, g}); l.computeInline(l.getLoopBodyFor(e)); l.computeInline(l.getLoopBodyFor(f)); l.prepareForCodegen(); @@ -1846,7 +1846,7 @@ TEST(LoopNest, LoopNestComputeAt_1) { "A", {{N, "i_a"}}, [&](const VarHandle& i_a) { return i_a * i_a; }); Tensor* B = Compute( "B", {{N, "i_b"}}, [&](const VarHandle& i_b) { return A->load(i_b); }); - LoopNest l({B}, {A, B}); + LoopNest l(std::vector({B}), {A, B}); std::vector loops = l.getAllLoopNestsWritingToBuf(B->buf()).at(0); LoopNest::computeAt(l.getLoopBodyFor(A), loops[0]); l.prepareForCodegen(); @@ -1909,7 +1909,7 @@ TEST(LoopNest, LoopNestComputeAt_2) { c_ref[y * kW + x] = y * x + (y + 1) * x + y * (x + 1) + (y + 1) * (x + 1); } } - LoopNest orig_loopnest({c}, {p, c}); + LoopNest orig_loopnest(std::vector({c}), {p, c}); { // First let's try to compute P at axis cy (the outer loop) @@ -2009,7 +2009,7 @@ TEST(LoopNest, LoopNestComputeAt_3) { } } - LoopNest orig_loopnest({D}, {A, B, C, D}); + LoopNest orig_loopnest(std::vector({D}), {A, B, C, D}); { // First let's try to compute A at axis dy (the outer loop) LoopNest l(orig_loopnest); @@ -2100,7 +2100,7 @@ TEST(LoopNest, Reduce2dComputeAt) { c_ref[y * kW + x] = y * x + (y + 1) * x + y * (x + 1) + (y + 1) * (x + 1); } } - LoopNest orig_loopnest({c}, {p, c}); + LoopNest orig_loopnest(std::vector({c}), {p, c}); checkIR(orig_loopnest.root_stmt(), R"IR( # CHECK: for (int py = 0; py < H + 1; py++) { # CHECK: for (int px = 0; px < W + 1; px++) { @@ -2771,7 +2771,7 @@ TEST(LoopNest, LoopNestReorderInternalLoopNest) { return x->load(m, n, k) + y->load(m, n, k); }); - LoopNest l({z}, {x, y, z}); + LoopNest l(std::vector({z}), {x, y, z}); ForPtr a = nullptr; ForPtr b = nullptr; auto fors = NodeFinder::find(l.root_stmt()); @@ -2983,7 +2983,7 @@ TEST(LoopNest, UnrollMultipleStatements) { Block::make( {Store::make(a_buf, {x}, x * 2), Store::make(b_buf, {x}, Load::make(a_buf, {x}))})); - Block::make({f}); + auto parent_block = Block::make({f}); StmtPtr unrolled = nullptr; LoopNest::unroll(f, &unrolled); checkIR(unrolled, R"IR( @@ -3069,7 +3069,7 @@ TEST(LoopNest, UnrollWithLet) { {Let::make(e, 7), Store::make(a_buf, {x}, e), Store::make(b_buf, {x}, e + 1)})); - Block::make({f}); + auto parent_block = Block::make({f}); StmtPtr unrolled = nullptr; LoopNest::unroll(f, &unrolled); std::ostringstream oss; @@ -3680,7 +3680,7 @@ TEST(LoopNest, DetectInlineRankMismatch) { "reshape", {{kTotalSize / 2, "i"}, {2, "j"}}, [&](const VarHandle& i, const VarHandle& j) { return a->load(i, j); }); - LoopNest l({reshape}, {a, reshape}); + LoopNest l(std::vector({reshape}), {a, reshape}); ASSERT_THROWS_WITH( l.computeInline(l.getLoopBodyFor(a)), "Placeholder indexed access is inconsistent with its rank"); @@ -3702,7 +3702,7 @@ TEST(LoopNest, CacheReadsSimple) { return A->load(i + 10, j + 20) + A->load(i + 30, j + 40); }); - LoopNest l({B, C}, {A, B, C}); + LoopNest l(std::vector({B, C}), {A, B, C}); StmtPtr j_loop = l.getAllLoopNestsWritingToBuf(B->buf())[0][1]; LoopNest::cacheAccesses(A->buf(), "A_local", j_loop); @@ -3770,7 +3770,7 @@ TEST(LoopNest, CacheReadsOuter) { return 
A->load(i + 10, j + 20) + A->load(i + 30, j + 40); }); - LoopNest l({B, C}, {A, B, C}); + LoopNest l(std::vector({B, C}), {A, B, C}); StmtPtr i_loop = l.getAllLoopNestsWritingToBuf(B->buf())[0][0]; LoopNest::cacheAccesses(A->buf(), "A_local", i_loop); @@ -3818,7 +3818,7 @@ TEST(LoopNest, CacheReadsInternal) { return A->load(i + 10, j + 20) + A->load(i + 30, j + 40); }); - LoopNest l({B, C}, {A, B, C}); + LoopNest l(std::vector({B, C}), {A, B, C}); StmtPtr j_loop = l.getAllLoopNestsWritingToBuf(B->buf())[0][1]; LoopNest::cacheAccesses(A->buf(), "A_local", j_loop); l.prepareForCodegen(); @@ -3866,7 +3866,7 @@ TEST(LoopNest, CacheReadsInner) { return A->load(i + 10, j + 20) + A->load(i + 30, j + 40); }); - LoopNest l({B, C}, {A, B, C}); + LoopNest l(std::vector({B, C}), {A, B, C}); StmtPtr body = l.getLoopBodyFor(B); LoopNest::cacheAccesses(A->buf(), "A_local", body); l.prepareForCodegen(); @@ -3913,7 +3913,7 @@ TEST(LoopNest, CacheWritesSimple) { return A->load(i + 10, j + 20) + A->load(i + 30, j + 40); }); - LoopNest l({B, C}, {A, B, C}); + LoopNest l(std::vector({B, C}), {A, B, C}); StmtPtr a_loop = l.getAllLoopNestsWritingToBuf(A->buf())[0][1]; LoopNest::cacheAccesses(A->buf(), "A_local", a_loop); @@ -4093,7 +4093,7 @@ TEST(LoopNest, InlineConstantIndex) { return y->load(m, n, o); }); - LoopNest l({z}, {y, z}); + LoopNest l(std::vector({z}), {y, z}); l.simplify(); ASSERT_TRUE(l.computeInline(y->buf())); } @@ -4121,7 +4121,7 @@ TEST(LoopNest, CompoundTensorUsed) { return A->load(i, j + 1) + A->load(i, j + 2); }); - LoopNest l({B}, {A, B}); + LoopNest l(std::vector({B}), {A, B}); ASSERT_FALSE(l.computeInline(A->buf())); l.prepareForCodegen(); @@ -4897,7 +4897,7 @@ TEST(LoopNest, VectorizeUse) { "b", {{N, "n"}}, [&](const VarHandle& n) { return a.load(n) + 1.0f; }); Tensor* c = Compute( "c", {{N, "n"}}, [&](const VarHandle& n) { return b->load(n) + 2.0f; }); - LoopNest nest({c}, {b, c}); + LoopNest nest(std::vector({c}), {b, c}); auto loops = nest.getAllLoopNestsWritingToBuf(b->buf())[0]; ASSERT_TRUE(LoopNest::vectorize(loops[0])); loops = nest.getAllLoopNestsWritingToBuf(c->buf())[0]; diff --git a/test/cpp/tensorexpr/test_memdependency.cpp b/test/cpp/tensorexpr/test_memdependency.cpp index 7f844c5..9503f9d 100644 --- a/test/cpp/tensorexpr/test_memdependency.cpp +++ b/test/cpp/tensorexpr/test_memdependency.cpp @@ -2739,7 +2739,7 @@ TEST(MemDependency, MemDependencyCheckerComputeAPI) { return c->load(m, n, k) + 1; }); - LoopNest l({d}, {c, d}); + LoopNest l(std::vector({d}), {c, d}); MemDependencyChecker analyzer({a_buf.data(), b_buf.data()}, {d->buf()}); @@ -2786,7 +2786,7 @@ TEST(MemDependency, MemDependencyCheckerComputeInline) { return c->load(m, n, k) + 1; }); - LoopNest l({d}, {c, d}); + LoopNest l(std::vector({d}), {c, d}); l.computeInline(c->buf()); MemDependencyChecker analyzer({a_buf.data(), b_buf.data()}, {d->buf()}); @@ -2935,7 +2935,7 @@ TEST(MemDependency, MemDependencyCheckerComputeReduce) { return b.load(l, n, m) * a.load(l, n, m); }); Tensor* d = Reduce("sum", {{2, "l1"}}, Sum(), c, {{3, "n1"}, {6, "m1"}}); - LoopNest l({d}, {c, d}); + LoopNest l(std::vector({d}), {c, d}); MemDependencyChecker analyzer({a.data(), b.data()}, {d->buf()}); diff --git a/test/cpp/tensorexpr/test_reductions.cpp b/test/cpp/tensorexpr/test_reductions.cpp index 0d033e0..449edac 100644 --- a/test/cpp/tensorexpr/test_reductions.cpp +++ b/test/cpp/tensorexpr/test_reductions.cpp @@ -533,7 +533,7 @@ TEST(Reductions, ReduceAsProducer) { [&](const VarHandle& l, const VarHandle& n) { return c->load(l, n) * 
a.load(l, n); }); - LoopNest loop({d}, {c, d}); + LoopNest loop(std::vector({d}), {c, d}); loop.prepareForCodegen(); StmtPtr s = loop.root_stmt(); s = IRSimplifier::simplify(s); @@ -578,7 +578,7 @@ TEST(Reductions, ReduceAsConsumer) { return b.load(l, n, m) * a.load(l, n, m); }); Tensor* d = Reduce("sum", {{2, "l1"}}, Sum(), c, {{3, "n1"}, {m, "m1"}}); - LoopNest loop({d}, {c, d}); + LoopNest loop(std::vector({d}), {c, d}); loop.prepareForCodegen(); StmtPtr s = loop.root_stmt(); s = IRSimplifier::simplify(s); @@ -1201,7 +1201,7 @@ TEST(Reductions, ReduceInlineReduction) { } } - LoopNest l1({y}, {x, y}); + LoopNest l1(std::vector({y}), {x, y}); // Cannot inline a reduction computation ASSERT_FALSE(l1.computeInline(x->buf())); } @@ -1235,7 +1235,7 @@ TEST(Reductions, ReduceInlineConsumer) { } } - LoopNest l1({y}, {x, y}); + LoopNest l1(std::vector({y}), {x, y}); LoopNest l2(l1); l2.computeInline(x->buf()); @@ -1293,7 +1293,7 @@ TEST(Reductions, ReduceInlineReducerInternal) { } } - LoopNest l1({y}, {x, y}); + LoopNest l1(std::vector({y}), {x, y}); LoopNest l2(l1); l2.computeInline(x->buf()); @@ -1340,7 +1340,7 @@ TEST(Reductions, ReductionCacheAccessesOperatorAxis) { return b.load(0, 0, l) * d->load(l); }); - LoopNest l({e}, {c, d, e}); + LoopNest l(std::vector({e}), {c, d, e}); LoopNest l_before(l); l_before.prepareForCodegen(); SimpleIREvaluator cg_before(l_before.root_stmt(), {a, b, e}); @@ -1417,7 +1417,7 @@ TEST(Reductions, ReductionCacheAccessesOuterReduceAxis) { return b.load(0, 0, l) * d->load(l); }); - LoopNest l({e}, {c, d, e}); + LoopNest l(std::vector({e}), {c, d, e}); LoopNest l_before(l); l_before.prepareForCodegen(); SimpleIREvaluator cg_before(l_before.root_stmt(), {a, b, e}); @@ -1492,7 +1492,7 @@ TEST(Reductions, ReductionCacheAccessesInnerReduceAxis) { return b.load(0, 0, l) * d->load(l); }); - LoopNest l({e}, {c, d, e}); + LoopNest l(std::vector({e}), {c, d, e}); LoopNest l_before(l); l_before.prepareForCodegen(); SimpleIREvaluator cg_before(l_before.root_stmt(), {a, b, e}); @@ -1563,7 +1563,7 @@ TEST(Reductions, ReductionCacheBodyAccess) { return b.load(0, 0, l) * d->load(l); }); - LoopNest l({e}, {c, d, e}); + LoopNest l(std::vector({e}), {c, d, e}); StmtPtr d_loop = l.getLoopStmtsFor(d)[1]; l.cacheAccesses(c->buf(), "scale_local", d_loop); @@ -1604,7 +1604,7 @@ TEST(Reductions, ReductionCacheConsumerAccess) { return b.load(0, 0, l) * d->load(l); }); - LoopNest l({e}, {c, d, e}); + LoopNest l(std::vector({e}), {c, d, e}); LoopNest::splitWithMask(l.getLoopStmtsFor(e)[0], 4); @@ -1645,7 +1645,7 @@ TEST(Reductions, ReductionSplitCacheConsumerAccess) { return b.load(0, 0, l) * d->load(l); }); - LoopNest l({e}, {c, d, e}); + LoopNest l(std::vector({e}), {c, d, e}); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr inner; @@ -1693,7 +1693,7 @@ TEST(Reductions, ReductionReorderCacheConsumerAccess) { return b.load(0, 0, l) * d->load(l); }); - LoopNest l({e}, {c, d, e}); + LoopNest l(std::vector({e}), {c, d, e}); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr inner; diff --git a/test/cpp/tensorexpr/tutorial.cpp b/test/cpp/tensorexpr/tutorial.cpp index 9320f47..5a6f257 100644 --- a/test/cpp/tensorexpr/tutorial.cpp +++ b/test/cpp/tensorexpr/tutorial.cpp @@ -256,7 +256,9 @@ int main(int argc, char* argv[]) { // Creating a loop nest is as quite simple, we just need to specify a list // of all and a list of output tensors: // NOLINTNEXTLINE(bugprone-argument-comment) - LoopNest loopnest(/*outputs=*/{Y}, /*all=*/{X, Y}); + std::vector outputs = {Y}; + std::vector 
all = {X, Y}; + LoopNest loopnest(outputs, all); // An IR used in LoopNest is based on tensor statements, represented by // `Stmt` class. Statements are used to specify the loop nest structure, and diff --git a/torch/csrc/jit/tensorexpr/cuda_codegen.cpp b/torch/csrc/jit/tensorexpr/cuda_codegen.cpp index 2d00b1e..b342f14 100644 --- a/torch/csrc/jit/tensorexpr/cuda_codegen.cpp +++ b/torch/csrc/jit/tensorexpr/cuda_codegen.cpp @@ -389,34 +389,33 @@ class AtomicAddFuser : public IRMutator { StmtPtr mutate(StorePtr v) override { BufPtr buf = v->buf(); - StorePtr orig = const_cast(v); // NOLINT // Thread locals never need to be atomic. if (thread_local_bufs_.count(buf->base_handle()) != 0) { - return orig; + return v; } ScalarType dtype = v->value()->dtype().scalar_type(); if (dtype != ScalarType::Float && dtype != ScalarType::Double) { - return orig; + return v; } AddPtr add_v = to(v->value()); if (!add_v) { - return orig; + return v; } LoadPtr load_v = to(add_v->lhs()); if (!load_v) { - return orig; + return v; } if (v->base_handle() != load_v->base_handle()) { - return orig; + return v; } if (v->indices().empty() && load_v->indices().empty()) { - return orig; + return v; } bool index_equal = CheckEqual(v->flat_index(), load_v->flat_index()); if (!index_equal) { - return orig; + return v; } // TODO: this checks that the metavars occur directly as an index, but this @@ -431,7 +430,7 @@ class AtomicAddFuser : public IRMutator { if (vars_to_find.empty()) { // All metavars accounted for. - return orig; + return v; } return alloc(buf, v->indices(), add_v->rhs()); @@ -609,23 +608,21 @@ class PrioritizeLoad : public IRMutator { } StmtPtr mutate(BlockPtr v) override { - BlockPtr v1 = const_cast(v); // NOLINT - assert(v1); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - std::list stmts = v1->stmts(); + std::list stmts = v->stmts(); for (StmtPtr stmt : stmts) { PushList(); StmtPtr stmt_new = stmt->accept_mutator(this); - AddMemLoadsFromList(v1, stmt); + AddMemLoadsFromList(v, stmt); PopList(); if (stmt_new == stmt) { continue; } - v1->replace_stmt(stmt, stmt_new); + v->replace_stmt(stmt, stmt_new); } - return v1; + return v; } ExprPtr mutate(IfThenElsePtr v) override { @@ -821,7 +818,7 @@ StmtPtr GPUMetaVarRewriter::mutate(BlockPtr v) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) std::vector stmts; for (auto& v : innerSegments) { - for (auto* s : v.stmts()) { + for (auto s : v.stmts()) { stmts.push_back(s); } } diff --git a/torch/csrc/jit/tensorexpr/half_support.h b/torch/csrc/jit/tensorexpr/half_support.h index 15d48cd..eaf74d3 100644 --- a/torch/csrc/jit/tensorexpr/half_support.h +++ b/torch/csrc/jit/tensorexpr/half_support.h @@ -72,7 +72,8 @@ class HalfRewriter : public IRMutator { inserted_half_casts_.insert(new_val); } - return alloc(v->buf(), v->indices(), new_val); + v->set_value(new_val); + return v; } ExprPtr mutate(HalfImmPtr v) override { diff --git a/torch/csrc/jit/tensorexpr/mem_dependency_checker.h b/torch/csrc/jit/tensorexpr/mem_dependency_checker.h index 5363d2f..1965b05 100644 --- a/torch/csrc/jit/tensorexpr/mem_dependency_checker.h +++ b/torch/csrc/jit/tensorexpr/mem_dependency_checker.h @@ -299,7 +299,7 @@ class TORCH_API MemDependencyChecker : public IRVisitor { DependencySet getAllReadsWithin(StmtOrExprPtr v) { DependencySet reads; auto insertAllReads = [&](const auto& nodes) { - for (auto* l : nodes) { + for (auto l : nodes) { auto bound = exprToAccess_.equal_range(l); for (auto it = bound.first; it != bound.second; ++it) { if (it->second->isRead()) { @@ -324,7 +324,7 
@@ class TORCH_API MemDependencyChecker : public IRVisitor { // writes just Store currently. auto stores = NodeFinder::find(v); - for (auto* s : stores) { + for (auto s : stores) { auto bound = stmtToAccess_.equal_range(s); for (auto it = bound.first; it != bound.second; ++it) { if (it->second->isWrite()) { diff --git a/torch/csrc/jit/tensorexpr/operators/reduction.h b/torch/csrc/jit/tensorexpr/operators/reduction.h index 29f051f..4335d7b 100644 --- a/torch/csrc/jit/tensorexpr/operators/reduction.h +++ b/torch/csrc/jit/tensorexpr/operators/reduction.h @@ -6,14 +6,14 @@ namespace torch { namespace jit { namespace tensorexpr { -Tensor* computeSum( +TORCH_API Tensor* computeSum( const std::vector& inputs, const c10::optional& outputType); -Tensor* computeMean( +TORCH_API Tensor* computeMean( const std::vector& inputs, const std::vector& outputShape, const c10::optional& outputType); -Tensor* computeAdaptiveAvgPool2d( +TORCH_API Tensor* computeAdaptiveAvgPool2d( const std::vector& inputs, const std::vector& outputShape, const c10::optional& outputType); -- 2.7.4
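For readers following the Expr* -> ExprPtr renames above, the sketch below illustrates the kind of indirection this patch prepares for. It is a minimal, self-contained example, not the actual torch/csrc/jit/tensorexpr headers: the class names Expr/Add and the exact signatures of alloc<T>() and to<T>() are assumptions made for illustration, chosen only because the patch's call sites already use XyzPtr aliases together with alloc(...) and to(...) helpers.

// Sketch (assumed, not the real NNC code): once every call site goes through
// XyzPtr aliases plus alloc<T>()/to<T>() helpers, the memory model can be
// switched from raw pointers to shared pointers by changing only these
// definitions, with no further edits at the use sites.
#include <memory>
#include <utility>

class Expr {
 public:
  virtual ~Expr() = default;  // polymorphic base so downcasts work
};
class Add : public Expr {};

// Before the switch, the same names could resolve to raw pointers, e.g.:
//   using ExprPtr = Expr*;
//   template <class T, class... Args>
//   T* alloc(Args&&... a) { return new T(std::forward<Args>(a)...); }
//   template <class T>
//   T* to(ExprPtr e) { return dynamic_cast<T*>(e); }

// After the switch, only these definitions change:
using ExprPtr = std::shared_ptr<Expr>;
using AddPtr = std::shared_ptr<Add>;

// Factory helper: the one place that knows how IR nodes are allocated.
template <class T, class... Args>
std::shared_ptr<T> alloc(Args&&... args) {
  return std::make_shared<T>(std::forward<Args>(args)...);
}

// Checked downcast helper matching the active memory model.
template <class T>
std::shared_ptr<T> to(const ExprPtr& e) {
  return std::dynamic_pointer_cast<T>(e);
}

int main() {
  ExprPtr e = alloc<Add>();       // call sites never spell `new Add(...)`
  if (AddPtr a = to<Add>(e)) {    // nor `dynamic_cast<Add*>(...)`
    (void)a;                      // downcast succeeded; use `a` here
  }
  return 0;
}

Under these assumptions, mechanical renames such as Expr* -> ExprPtr and For* -> ForPtr (the bulk of this NFC patch) are the only call-site changes needed before the aliases and helpers can be flipped to shared_ptr-based definitions in a follow-up change.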