From 1dc2b52764a81288cb0e31bb7a72fe8c4cc826a1 Mon Sep 17 00:00:00 2001
From: Mikhail Zolotukhin
Date: Tue, 17 Aug 2021 13:39:36 -0700
Subject: [PATCH] [TensorExpr] Add a wrapper for all expr and stmt pointers.
 (#63195)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63195

This helps us to later switch from using KernelArena with raw pointers to
shared pointers without having to change all our source files at once. The
changes are mechanical and should not affect any functionality.

With this PR, we're changing the following:
* `Add*` --> `AddPtr`
* `new Add(...)` --> `alloc<Add>(...)`
* `dynamic_cast<Add*>` --> `to<Add>`
* `static_cast<Add*>` --> `static_to<Add>`

Due to some complications with args forwarding, some places became more
verbose, e.g.:
* `new Block({})` --> `new Block(std::vector<StmtPtr>())`

Test Plan: Imported from OSS

Reviewed By: navahgar

Differential Revision: D30292779

Pulled By: ZolotukhinM

fbshipit-source-id: 150301c7d2df56b608b035827b6a9a87f5e2d9e9
---
 test/cpp/tensorexpr/test_approx.cpp | 4 +-
 test/cpp/tensorexpr/test_aten.cpp | 120 +-
 test/cpp/tensorexpr/test_boundsinference.cpp | 54 +-
 test/cpp/tensorexpr/test_conv.cpp | 2 +-
 test/cpp/tensorexpr/test_cpp_codegen.cpp | 21 +-
 test/cpp/tensorexpr/test_cuda.cpp | 219 ++--
 test/cpp/tensorexpr/test_expr.cpp | 50 +-
 test/cpp/tensorexpr/test_external_calls.cpp | 4 +-
 test/cpp/tensorexpr/test_ir_printer.cpp | 2 +-
 test/cpp/tensorexpr/test_ir_verifier.cpp | 98 +-
 test/cpp/tensorexpr/test_kernel.cpp | 36 +-
 test/cpp/tensorexpr/test_llvm.cpp | 86 +-
 test/cpp/tensorexpr/test_loopnest.cpp | 579 +++++-----
 test/cpp/tensorexpr/test_memdependency.cpp | 280 ++---
 test/cpp/tensorexpr/test_reductions.cpp | 180 ++-
 test/cpp/tensorexpr/test_registerizer.cpp | 138 +--
 test/cpp/tensorexpr/test_simplify.cpp | 474 ++++----
 test/cpp/tensorexpr/test_utils.h | 56 +-
 test/cpp/tensorexpr/tutorial.cpp | 27 +-
 torch/csrc/jit/runtime/static/ops.cpp | 6 +-
 torch/csrc/jit/tensorexpr/analysis.h | 102 +-
 torch/csrc/jit/tensorexpr/block_codegen.cpp | 54 +-
 torch/csrc/jit/tensorexpr/block_codegen.h | 62 +-
 torch/csrc/jit/tensorexpr/bounds_inference.cpp | 65 +-
 torch/csrc/jit/tensorexpr/bounds_inference.h | 32 +-
 torch/csrc/jit/tensorexpr/bounds_overlap.cpp | 33 +-
 torch/csrc/jit/tensorexpr/bounds_overlap.h | 8 +-
 torch/csrc/jit/tensorexpr/codegen.cpp | 4 +-
 torch/csrc/jit/tensorexpr/codegen.h | 26 +-
 torch/csrc/jit/tensorexpr/cpp_codegen.cpp | 8 +-
 torch/csrc/jit/tensorexpr/cpp_codegen.h | 6 +-
 torch/csrc/jit/tensorexpr/cuda_codegen.cpp | 232 ++--
 torch/csrc/jit/tensorexpr/cuda_codegen.h | 121 +-
 torch/csrc/jit/tensorexpr/eval.cpp | 119 +-
 torch/csrc/jit/tensorexpr/eval.h | 22 +-
 torch/csrc/jit/tensorexpr/exceptions.h | 25 +-
 torch/csrc/jit/tensorexpr/expr.cpp | 2 +-
 torch/csrc/jit/tensorexpr/expr.h | 81 +-
 torch/csrc/jit/tensorexpr/fwd_decls.h | 120 ++
 torch/csrc/jit/tensorexpr/half_support.h | 52 +-
 torch/csrc/jit/tensorexpr/hash_provider.cpp | 88 +-
 torch/csrc/jit/tensorexpr/hash_provider.h | 70 +-
 torch/csrc/jit/tensorexpr/ir.cpp | 66 +-
 torch/csrc/jit/tensorexpr/ir.h | 253 ++--
 torch/csrc/jit/tensorexpr/ir_cloner.cpp | 272 ++---
 torch/csrc/jit/tensorexpr/ir_cloner.h | 125 +-
 torch/csrc/jit/tensorexpr/ir_mutator.cpp | 310 +++--
 torch/csrc/jit/tensorexpr/ir_mutator.h | 134 +--
 torch/csrc/jit/tensorexpr/ir_printer.cpp | 122 +-
 torch/csrc/jit/tensorexpr/ir_printer.h | 87 +-
 torch/csrc/jit/tensorexpr/ir_simplifier.cpp | 1207 ++++++++++----------
 torch/csrc/jit/tensorexpr/ir_simplifier.h | 233 ++--
 torch/csrc/jit/tensorexpr/ir_verifier.cpp | 36 +-
torch/csrc/jit/tensorexpr/ir_verifier.h | 37 +- torch/csrc/jit/tensorexpr/ir_visitor.cpp | 110 +- torch/csrc/jit/tensorexpr/ir_visitor.h | 126 +- torch/csrc/jit/tensorexpr/kernel.cpp | 98 +- torch/csrc/jit/tensorexpr/kernel.h | 14 +- torch/csrc/jit/tensorexpr/llvm_codegen.cpp | 196 ++-- torch/csrc/jit/tensorexpr/llvm_codegen.h | 8 +- torch/csrc/jit/tensorexpr/loopnest.cpp | 1181 +++++++++---------- torch/csrc/jit/tensorexpr/loopnest.h | 137 +-- .../csrc/jit/tensorexpr/mem_dependency_checker.cpp | 221 ++-- torch/csrc/jit/tensorexpr/mem_dependency_checker.h | 130 +-- torch/csrc/jit/tensorexpr/operators/conv2d.cpp | 4 +- torch/csrc/jit/tensorexpr/operators/matmul.cpp | 10 +- torch/csrc/jit/tensorexpr/operators/softmax.cpp | 8 +- torch/csrc/jit/tensorexpr/reduction.cpp | 22 +- torch/csrc/jit/tensorexpr/reduction.h | 49 +- torch/csrc/jit/tensorexpr/registerizer.cpp | 121 +- torch/csrc/jit/tensorexpr/registerizer.h | 137 +-- torch/csrc/jit/tensorexpr/stmt.h | 327 +++--- torch/csrc/jit/tensorexpr/tensor.cpp | 76 +- torch/csrc/jit/tensorexpr/tensor.h | 83 +- torch/csrc/jit/tensorexpr/tensorexpr_init.cpp | 94 +- torch/csrc/jit/tensorexpr/unique_name_manager.cpp | 2 +- torch/csrc/jit/tensorexpr/unique_name_manager.h | 5 +- torch/csrc/jit/tensorexpr/var_substitutor.h | 22 +- 78 files changed, 4972 insertions(+), 4859 deletions(-) create mode 100644 torch/csrc/jit/tensorexpr/fwd_decls.h diff --git a/test/cpp/tensorexpr/test_approx.cpp b/test/cpp/tensorexpr/test_approx.cpp index 2005f1e..d761645 100644 --- a/test/cpp/tensorexpr/test_approx.cpp +++ b/test/cpp/tensorexpr/test_approx.cpp @@ -13,7 +13,7 @@ namespace te = torch::jit::tensorexpr; static void vectorize(te::LoopNest* ln, te::Tensor* target, int width) { auto loops = ln->getLoopStmtsFor(target); - te::For *inner, *tail; + te::ForPtr inner, tail; ln->splitWithTail(loops[0], width, &inner, &tail); ASSERT_TRUE(te::LoopNest::vectorize(inner)); } @@ -39,7 +39,7 @@ TEST(Approx, log_vml) { te::LoopNest ln({B}); ln.prepareForCodegen(); vectorize(&ln, B, 8); - te::Stmt* s = ln.root_stmt(); + te::StmtPtr s = ln.root_stmt(); s = te::IRSimplifier::simplify(s); te::LLVMCodeGen cg(s, {A, B, N}); diff --git a/test/cpp/tensorexpr/test_aten.cpp b/test/cpp/tensorexpr/test_aten.cpp index b61b84d..9eb1412 100644 --- a/test/cpp/tensorexpr/test_aten.cpp +++ b/test/cpp/tensorexpr/test_aten.cpp @@ -23,8 +23,8 @@ TEST(ATen, _cast_Float) { VarHandle index = VarHandle("index", kInt); ExprHandle load_a = a_buf.load(index); ExprHandle to_float = Cast::make(kFloat, load_a); - Stmt* store_b = b_buf.store({index}, to_float); - Stmt* stmt = For::make(index, 0, kTotalSize, store_b); + StmtPtr store_b = b_buf.store({index}, to_float); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); PaddedBuffer a_v(kTotalSize); PaddedBuffer b_v(kTotalSize); @@ -51,8 +51,8 @@ TEST(ATen, negInt) { VarHandle index = VarHandle("index", kInt); ExprHandle load_a = a_buf.load(index); ExprHandle to_float = Sub::make(0, load_a); - Stmt* store_b = b_buf.store({index}, to_float); - Stmt* stmt = For::make(index, 0, kTotalSize, store_b); + StmtPtr store_b = b_buf.store({index}, to_float); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); PaddedBuffer a_v(kTotalSize); PaddedBuffer b_v(kTotalSize); @@ -79,8 +79,8 @@ TEST(ATen, negFloat) { VarHandle index = VarHandle("index", kInt); ExprHandle load_a = a_buf.load(index); ExprHandle to_float = Sub::make(0, load_a); - Stmt* store_b = b_buf.store({index}, to_float); - Stmt* stmt = For::make(index, 0, kTotalSize, store_b); + StmtPtr store_b = 
b_buf.store({index}, to_float); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); PaddedBuffer a_v(kTotalSize); PaddedBuffer b_v(kTotalSize); @@ -110,8 +110,8 @@ TEST(ATen, addInt) { ExprHandle load_a = a_buf.load(index); ExprHandle load_b = b_buf.load(index); ExprHandle load_c = c_buf.load(index); - Stmt* store_d = d_buf.store({index}, load_a + load_b * load_c); - Stmt* stmt = For::make(index, 0, kTotalSize, store_d); + StmtPtr store_d = d_buf.store({index}, load_a + load_b * load_c); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_d); PaddedBuffer a_v(kTotalSize); PaddedBuffer b_v(kTotalSize); @@ -147,8 +147,8 @@ TEST(ATen, addFloat) { ExprHandle load_a = a_buf.load(index); ExprHandle load_b = b_buf.load(index); ExprHandle load_c = c_buf.load(index); - Stmt* store_d = d_buf.store({index}, load_a + load_b * load_c); - Stmt* stmt = For::make(index, 0, kTotalSize, store_d); + StmtPtr store_d = d_buf.store({index}, load_a + load_b * load_c); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_d); PaddedBuffer a_v(kTotalSize); PaddedBuffer b_v(kTotalSize); @@ -184,8 +184,8 @@ TEST(ATen, subInt) { ExprHandle load_a = a_buf.load(index); ExprHandle load_b = b_buf.load(index); ExprHandle load_c = c_buf.load(index); - Stmt* store_d = d_buf.store({index}, load_a - load_b * load_c); - Stmt* stmt = For::make(index, 0, kTotalSize, store_d); + StmtPtr store_d = d_buf.store({index}, load_a - load_b * load_c); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_d); PaddedBuffer a_v(kTotalSize); PaddedBuffer b_v(kTotalSize); @@ -221,8 +221,8 @@ TEST(ATen, subFloat) { ExprHandle load_a = a_buf.load(index); ExprHandle load_b = b_buf.load(index); ExprHandle load_c = c_buf.load(index); - Stmt* store_d = d_buf.store({index}, load_a - load_b * load_c); - Stmt* stmt = For::make(index, 0, kTotalSize, store_d); + StmtPtr store_d = d_buf.store({index}, load_a - load_b * load_c); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_d); PaddedBuffer a_v(kTotalSize); PaddedBuffer b_v(kTotalSize); @@ -258,8 +258,8 @@ TEST(ATen, lerp) { ExprHandle load_a = a_buf.load(index); ExprHandle load_b = b_buf.load(index); ExprHandle load_c = c_buf.load(index); - Stmt* store_d = d_buf.store({index}, load_a + load_c * (load_b - load_a)); - Stmt* stmt = For::make(index, 0, kTotalSize, store_d); + StmtPtr store_d = d_buf.store({index}, load_a + load_c * (load_b - load_a)); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_d); PaddedBuffer a_v(kTotalSize); PaddedBuffer b_v(kTotalSize); @@ -297,8 +297,8 @@ TEST(ATen, addcmulInt) { ExprHandle load_b = b_buf.load(index); ExprHandle load_c = c_buf.load(index); ExprHandle load_d = d_buf.load(index); - Stmt* store_e = e_buf.store({index}, load_a + load_b * load_c * load_d); - Stmt* stmt = For::make(index, 0, kTotalSize, store_e); + StmtPtr store_e = e_buf.store({index}, load_a + load_b * load_c * load_d); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_e); PaddedBuffer a_v(kTotalSize); PaddedBuffer b_v(kTotalSize); @@ -339,8 +339,8 @@ TEST(ATen, addcmulFloat) { ExprHandle load_b = b_buf.load(index); ExprHandle load_c = c_buf.load(index); ExprHandle load_d = d_buf.load(index); - Stmt* store_e = e_buf.store({index}, load_a + load_b * load_c * load_d); - Stmt* stmt = For::make(index, 0, kTotalSize, store_e); + StmtPtr store_e = e_buf.store({index}, load_a + load_b * load_c * load_d); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_e); PaddedBuffer a_v(kTotalSize); PaddedBuffer b_v(kTotalSize); @@ -377,8 +377,8 @@ TEST(ATen, mulInt) { VarHandle index 
= VarHandle("index", kInt); ExprHandle load_a = a_buf.load(index); ExprHandle load_b = b_buf.load(index); - Stmt* store_c = c_buf.store({index}, load_a * load_b); - Stmt* stmt = For::make(index, 0, kTotalSize, store_c); + StmtPtr store_c = c_buf.store({index}, load_a * load_b); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_c); PaddedBuffer a_v(kTotalSize); PaddedBuffer b_v(kTotalSize); @@ -409,8 +409,8 @@ TEST(ATen, mulFloat) { VarHandle index = VarHandle("index", kInt); ExprHandle load_a = a_buf.load(index); ExprHandle load_b = b_buf.load(index); - Stmt* store_c = c_buf.store({index}, load_a * load_b); - Stmt* stmt = For::make(index, 0, kTotalSize, store_c); + StmtPtr store_c = c_buf.store({index}, load_a * load_b); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_c); PaddedBuffer a_v(kTotalSize); PaddedBuffer b_v(kTotalSize); @@ -441,8 +441,8 @@ TEST(ATen, divInt) { VarHandle index = VarHandle("index", kInt); ExprHandle load_a = a_buf.load(index); ExprHandle load_b = b_buf.load(index); - Stmt* store_c = c_buf.store({index}, load_a / load_b); - Stmt* stmt = For::make(index, 0, kTotalSize, store_c); + StmtPtr store_c = c_buf.store({index}, load_a / load_b); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_c); PaddedBuffer a_v(kTotalSize); PaddedBuffer b_v(kTotalSize); @@ -473,8 +473,8 @@ TEST(ATen, divFloat) { VarHandle index = VarHandle("index", kInt); ExprHandle load_a = a_buf.load(index); ExprHandle load_b = b_buf.load(index); - Stmt* store_c = c_buf.store({index}, load_a / load_b); - Stmt* stmt = For::make(index, 0, kTotalSize, store_c); + StmtPtr store_c = c_buf.store({index}, load_a / load_b); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_c); PaddedBuffer a_v(kTotalSize); PaddedBuffer b_v(kTotalSize); @@ -505,8 +505,8 @@ TEST(ATen, maxInt) { VarHandle index = VarHandle("index", kInt); ExprHandle load_a = a_buf.load(index); ExprHandle load_b = b_buf.load(index); - Stmt* store_c = c_buf.store({index}, Max::make(load_a, load_b, true)); - Stmt* stmt = For::make(index, 0, kTotalSize, store_c); + StmtPtr store_c = c_buf.store({index}, Max::make(load_a, load_b, true)); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_c); PaddedBuffer a_v(kTotalSize); PaddedBuffer b_v(kTotalSize); @@ -537,8 +537,8 @@ TEST(ATen, maxFloat) { VarHandle index = VarHandle("index", kInt); ExprHandle load_a = a_buf.load(index); ExprHandle load_b = b_buf.load(index); - Stmt* store_c = c_buf.store({index}, Max::make(load_a, load_b, true)); - Stmt* stmt = For::make(index, 0, kTotalSize, store_c); + StmtPtr store_c = c_buf.store({index}, Max::make(load_a, load_b, true)); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_c); PaddedBuffer a_v(kTotalSize); PaddedBuffer b_v(kTotalSize); @@ -569,8 +569,8 @@ TEST(ATen, minInt) { VarHandle index = VarHandle("index", kInt); ExprHandle load_a = a_buf.load(index); ExprHandle load_b = b_buf.load(index); - Stmt* store_c = c_buf.store({index}, Min::make(load_a, load_b, true)); - Stmt* stmt = For::make(index, 0, kTotalSize, store_c); + StmtPtr store_c = c_buf.store({index}, Min::make(load_a, load_b, true)); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_c); PaddedBuffer a_v(kTotalSize); PaddedBuffer b_v(kTotalSize); @@ -601,8 +601,8 @@ TEST(ATen, minFloat) { VarHandle index = VarHandle("index", kInt); ExprHandle load_a = a_buf.load(index); ExprHandle load_b = b_buf.load(index); - Stmt* store_c = c_buf.store({index}, Min::make(load_a, load_b, true)); - Stmt* stmt = For::make(index, 0, kTotalSize, store_c); + StmtPtr store_c = 
c_buf.store({index}, Min::make(load_a, load_b, true)); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_c); PaddedBuffer a_v(kTotalSize); PaddedBuffer b_v(kTotalSize); @@ -631,8 +631,8 @@ void __ubsan_ignore_float_divide_by_zero__ testATenreciprocal() { VarHandle index = VarHandle("index", kInt); ExprHandle load_a = a_buf.load(index); - Stmt* store_b = b_buf.store({index}, FloatImm::make(1.0f) / load_a); - Stmt* stmt = For::make(index, 0, kTotalSize, store_b); + StmtPtr store_b = b_buf.store({index}, FloatImm::make(1.0f) / load_a); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); PaddedBuffer a_v(kTotalSize); PaddedBuffer b_v(kTotalSize); @@ -658,8 +658,8 @@ TEST(ATen, reluInt) { VarHandle index = VarHandle("index", kInt); ExprHandle load_a = a_buf.load(index); - Stmt* store_b = b_buf.store({index}, Max::make(load_a, 0, false)); - Stmt* stmt = For::make(index, 0, kTotalSize, store_b); + StmtPtr store_b = b_buf.store({index}, Max::make(load_a, 0, false)); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); PaddedBuffer a_v(kTotalSize); PaddedBuffer b_v(kTotalSize); @@ -685,10 +685,10 @@ TEST(ATen, reluFloat) { VarHandle index = VarHandle("index", kInt); ExprHandle load_a = a_buf.load(index); - Stmt* store_b = b_buf.store( + StmtPtr store_b = b_buf.store( {index}, Max::make(load_a, 0, false) // relu does not propagate nans ); - Stmt* stmt = For::make(index, 0, kTotalSize, store_b); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); PaddedBuffer a_v(kTotalSize); PaddedBuffer b_v(kTotalSize); @@ -714,8 +714,8 @@ TEST(ATen, logFloat) { VarHandle index = VarHandle("index", kInt); ExprHandle load_a = a_buf.load(index); - Stmt* store_b = b_buf.store({index}, log(load_a)); - Stmt* stmt = For::make(index, 0, kTotalSize, store_b); + StmtPtr store_b = b_buf.store({index}, log(load_a)); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); PaddedBuffer a_v(kTotalSize); PaddedBuffer b_v(kTotalSize); @@ -741,8 +741,8 @@ TEST(ATen, fastLogFloat) { VarHandle index = VarHandle("index", kInt); ExprHandle load_a = a_buf.load(index); - Stmt* store_b = b_buf.store({index}, fast_log(load_a)); - Stmt* stmt = For::make(index, 0, kTotalSize, store_b); + StmtPtr store_b = b_buf.store({index}, fast_log(load_a)); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); PaddedBuffer a_v(kTotalSize); PaddedBuffer b_v(kTotalSize); @@ -773,8 +773,8 @@ TEST(ATen, fastTanhFloat) { VarHandle index = VarHandle("index", kInt); ExprHandle load_a = a_buf.load(index); - Stmt* store_b = b_buf.store({index}, fast_tanh(load_a)); - Stmt* stmt = For::make(index, 0, kTotalSize, store_b); + StmtPtr store_b = b_buf.store({index}, fast_tanh(load_a)); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); PaddedBuffer a_v(kTotalSize); PaddedBuffer b_v(kTotalSize); @@ -805,8 +805,8 @@ TEST(ATen, fastSigmoidFloat) { VarHandle index = VarHandle("index", kInt); ExprHandle load_a = a_buf.load(index); - Stmt* store_b = b_buf.store({index}, fast_sigmoid(load_a)); - Stmt* stmt = For::make(index, 0, kTotalSize, store_b); + StmtPtr store_b = b_buf.store({index}, fast_sigmoid(load_a)); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); PaddedBuffer a_v(kTotalSize); PaddedBuffer b_v(kTotalSize); @@ -838,8 +838,8 @@ TEST(ATen, log10Float) { VarHandle index = VarHandle("index", kInt); ExprHandle load_a = a_buf.load(index); - Stmt* store_b = b_buf.store({index}, log10(load_a)); - Stmt* stmt = For::make(index, 0, kTotalSize, store_b); + StmtPtr store_b = b_buf.store({index}, log10(load_a)); + StmtPtr stmt 
= For::make(index, 0, kTotalSize, store_b); PaddedBuffer a_v(kTotalSize); PaddedBuffer b_v(kTotalSize); @@ -865,8 +865,8 @@ TEST(ATen, log2Float) { VarHandle index = VarHandle("index", kInt); ExprHandle load_a = a_buf.load(index); - Stmt* store_b = b_buf.store({index}, log2(load_a)); - Stmt* stmt = For::make(index, 0, kTotalSize, store_b); + StmtPtr store_b = b_buf.store({index}, log2(load_a)); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); PaddedBuffer a_v(kTotalSize); PaddedBuffer b_v(kTotalSize); @@ -892,8 +892,8 @@ TEST(ATen, expFloat) { VarHandle index = VarHandle("index", kInt); ExprHandle load_a = a_buf.load(index); - Stmt* store_b = b_buf.store({index}, exp(load_a)); - Stmt* stmt = For::make(index, 0, kTotalSize, store_b); + StmtPtr store_b = b_buf.store({index}, exp(load_a)); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); PaddedBuffer a_v(kTotalSize); PaddedBuffer b_v(kTotalSize); @@ -920,8 +920,8 @@ TEST(ATen, erfFloat) { VarHandle index = VarHandle("index", kInt); ExprHandle load_a = a_buf.load(index); - Stmt* store_b = b_buf.store({index}, erf(load_a)); - Stmt* stmt = For::make(index, 0, kTotalSize, store_b); + StmtPtr store_b = b_buf.store({index}, erf(load_a)); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); PaddedBuffer a_v(kTotalSize); PaddedBuffer b_v(kTotalSize); @@ -948,8 +948,8 @@ TEST(ATen, cosFloat) { VarHandle index = VarHandle("index", kInt); ExprHandle load_a = a_buf.load(index); - Stmt* store_b = b_buf.store({index}, cos(load_a)); - Stmt* stmt = For::make(index, 0, kTotalSize, store_b); + StmtPtr store_b = b_buf.store({index}, cos(load_a)); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); PaddedBuffer a_v(kTotalSize); PaddedBuffer b_v(kTotalSize); diff --git a/test/cpp/tensorexpr/test_boundsinference.cpp b/test/cpp/tensorexpr/test_boundsinference.cpp index d246bd7..fcfa8ce 100644 --- a/test/cpp/tensorexpr/test_boundsinference.cpp +++ b/test/cpp/tensorexpr/test_boundsinference.cpp @@ -138,8 +138,8 @@ TEST(BoundsInference, _4) { return a.load(y, x) * b->load(y, x); }); LoopNest l({c}); - std::vector loops = l.getLoopStmtsFor(c); - Stmt* body = l.getLoopBodyFor(c); + std::vector loops = l.getLoopStmtsFor(c); + StmtPtr body = l.getLoopBodyFor(c); { // Infer bounds on the top-level loop scope auto bounds_info = inferBounds(loops[0]); @@ -213,12 +213,12 @@ TEST(BoundsInference, _5) { LoopNest l({b}); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* inner; + ForPtr inner; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* tail; - std::vector loops = l.getLoopStmtsFor(b); + ForPtr tail; + std::vector loops = l.getLoopStmtsFor(b); LoopNest::splitWithTail(loops[0], 16, &inner, &tail); - For* outer = loops[0]; + ForPtr outer = loops[0]; { // Verify inferred bounds for the outer loop @@ -272,8 +272,8 @@ TEST(BoundsInference, _6) { return a.load(y + 100, x + 100) * b->load(y * 2, x * 5); }); LoopNest l({c}); - std::vector loops = l.getLoopStmtsFor(c); - Stmt* body = l.getLoopBodyFor(c); + std::vector loops = l.getLoopStmtsFor(c); + StmtPtr body = l.getLoopBodyFor(c); { // Infer bounds on the top-level loop scope auto bounds_info = inferBounds(loops[0]); @@ -336,7 +336,7 @@ TEST(BoundsInference, Adjacent) { Tensor* c = Compute( "c", {{H, "x"}}, [&](const VarHandle& x) { return a.load(x + H); }); LoopNest l({b, c}); - std::vector loops = NodeFinder::find(l.root_stmt()); + std::vector loops = NodeFinder::find(l.root_stmt()); { // Infer bounds on the top-level loop scope @@ -453,7 +453,7 @@ TEST(BoundsInference, 
MultipleTopLoopStore) { // Same as above but the offsets are on the Store now. // Can't do this through ComputeAPI without transforms we don't have yet. - Stmt* stmt = Block::make( + StmtPtr stmt = Block::make( {For::make(x, 0, 64, Store::make(b, {x}, Load::make(a, {x}))), For::make(x, 0, 32, Store::make(c, {x + 10}, Load::make(a, {x}))), For::make(x, 0, 96, Store::make(d, {x + 2}, Load::make(a, {x})))}); @@ -522,7 +522,7 @@ TEST(BoundsInference, CacheReads) { LoopNest l({B, C}); auto bounds_info_before = inferBounds(l.root_stmt()); - Stmt* j_loop = l.getLoopStmtsFor(B)[1]; + StmtPtr j_loop = l.getLoopStmtsFor(B)[1]; LoopNest::cacheAccesses(A->buf(), "A_local", j_loop); auto bounds_info_after = inferBounds(l.root_stmt()); @@ -592,8 +592,8 @@ TEST(BoundsInference, Flattened) { ASSERT_EQ(TABI.stop.size(), 1); // Bounds should be 0 -> (3*4*5)-1 - ASSERT_TRUE(exprEquals(TABI.start[0], new IntImm(0))); - ASSERT_TRUE(exprEquals(TABI.stop[0], new IntImm(3 * 4 * 5 - 1))); + ASSERT_TRUE(exprEquals(TABI.start[0], alloc(0))); + ASSERT_TRUE(exprEquals(TABI.stop[0], alloc(3 * 4 * 5 - 1))); } TEST(BoundsInference, GetPotentialHazards) { @@ -614,11 +614,11 @@ TEST(BoundsInference, GetPotentialHazards) { * C[0] = 5; */ - Store* store1 = Store::make(a, {0}, Load::make(b, {0})); - Store* store2 = Store::make(b, {0}, 3); - Store* store3 = Store::make(a, {0}, Load::make(b, {0})); - Store* store4 = Store::make(c, {0}, 5); - Stmt* stmt = Block::make({store1, store2, store3, store4}); + StorePtr store1 = Store::make(a, {0}, Load::make(b, {0})); + StorePtr store2 = Store::make(b, {0}, 3); + StorePtr store3 = Store::make(a, {0}, Load::make(b, {0})); + StorePtr store4 = Store::make(c, {0}, 5); + StmtPtr stmt = Block::make({store1, store2, store3, store4}); MemDependencyChecker analyzer; stmt->accept(&analyzer); @@ -667,8 +667,8 @@ TEST(BoundsInference, GetPotentialHazardsLoopNoHazard) { MemDependencyChecker analyzer; l.root_stmt()->accept(&analyzer); - For* loopRootA = l.getLoopStmtsFor(A)[0]; - For* loopRootB = l.getLoopStmtsFor(B)[0]; + ForPtr loopRootA = l.getLoopStmtsFor(A)[0]; + ForPtr loopRootB = l.getLoopStmtsFor(B)[0]; // No dependencies between loops. ASSERT_EQ( @@ -695,8 +695,8 @@ TEST(BoundsInference, GetPotentialHazardsLoopCall) { MemDependencyChecker analyzer; l.root_stmt()->accept(&analyzer); - For* loopRootA = l.getLoopStmtsFor(A)[0]; - For* loopRootB = l.getLoopStmtsFor(B)[0]; + ForPtr loopRootA = l.getLoopStmtsFor(A)[0]; + ForPtr loopRootB = l.getLoopStmtsFor(B)[0]; ASSERT_EQ( HazardKind::ReadAfterWrite, @@ -713,11 +713,11 @@ TEST(BoundsInference, GetPotentialHazardsLoopSplit) { LoopNest l({A}); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For *inner, *tail; + ForPtr inner, tail; // Splitting with tail by something offset creates a tail which also writes to // A. - For* outer = l.getLoopStmtsFor(A)[0]; + ForPtr outer = l.getLoopStmtsFor(A)[0]; // `outer` loop get transformed to the outer loop after splitting. 
LoopNest::splitWithTail(outer, 5, &inner, &tail); @@ -1066,10 +1066,8 @@ TEST(BoundsInference, IsOverlapping) { i, 0, 100, Block::make({storeA1, storeB, storeC, storeA2, storeA3})); tensorexpr::analysis::MemDependencyChecker analyzer; forI->accept(&analyzer); - ASSERT_TRUE( - isOverlapping(analyzer, storeA1, dynamic_cast(loadA1.node()))); - ASSERT_FALSE( - isOverlapping(analyzer, storeA1, dynamic_cast(loadA2.node()))); + ASSERT_TRUE(isOverlapping(analyzer, storeA1, to(loadA1.node()))); + ASSERT_FALSE(isOverlapping(analyzer, storeA1, to(loadA2.node()))); ASSERT_TRUE(isOverlapping(analyzer, storeA1, storeA2)); ASSERT_FALSE(isOverlapping(analyzer, storeA1, storeA3)); } diff --git a/test/cpp/tensorexpr/test_conv.cpp b/test/cpp/tensorexpr/test_conv.cpp index 83dab36..63881d0 100644 --- a/test/cpp/tensorexpr/test_conv.cpp +++ b/test/cpp/tensorexpr/test_conv.cpp @@ -223,7 +223,7 @@ TEST(Conv, Conv2D) { // LoopNest, IRSimplifier, etc. te::LoopNest loop({conv}); loop.prepareForCodegen(); - te::Stmt* s = loop.root_stmt(); + te::StmtPtr s = loop.root_stmt(); s = te::IRSimplifier::simplify(s); at::Tensor result = at::empty_like(ref); diff --git a/test/cpp/tensorexpr/test_cpp_codegen.cpp b/test/cpp/tensorexpr/test_cpp_codegen.cpp index 1a2a214..82ea40d 100644 --- a/test/cpp/tensorexpr/test_cpp_codegen.cpp +++ b/test/cpp/tensorexpr/test_cpp_codegen.cpp @@ -14,11 +14,11 @@ using namespace torch::jit::tensorexpr; TEST(CppPrinter, AllocateOnStackThenFree) { KernelScope kernel_scope; - std::vector dims = {new IntImm(2), new IntImm(3)}; - Buf* buf = new Buf("x", dims, kInt); - Allocate* alloc = new Allocate(buf); - Free* free = new Free(buf); - Block* block = Block::make({alloc, free}); + std::vector dims = {alloc(2), alloc(3)}; + BufPtr buf = alloc("x", dims, kInt); + AllocatePtr alloc_ = alloc(buf); + FreePtr free_ = alloc(buf); + BlockPtr block = Block::make({alloc_, free_}); std::stringstream ss; CppPrinter printer(&ss); @@ -33,11 +33,12 @@ TEST(CppPrinter, AllocateOnStackThenFree) { TEST(CppPrinter, AllocateOnHeapThenFree) { KernelScope kernel_scope; - std::vector dims = {new IntImm(20), new IntImm(50), new IntImm(3)}; - Buf* buf = new Buf("y", dims, kLong); - Allocate* alloc = new Allocate(buf); - Free* free = new Free(buf); - Block* block = Block::make({alloc, free}); + std::vector dims = { + alloc(20), alloc(50), alloc(3)}; + BufPtr buf = alloc("y", dims, kLong); + AllocatePtr alloc_ = alloc(buf); + FreePtr free_ = alloc(buf); + BlockPtr block = Block::make({alloc_, free_}); std::stringstream ss; CppPrinter printer(&ss); diff --git a/test/cpp/tensorexpr/test_cuda.cpp b/test/cpp/tensorexpr/test_cuda.cpp index 1c80e36..3ca6e0d 100644 --- a/test/cpp/tensorexpr/test_cuda.cpp +++ b/test/cpp/tensorexpr/test_cuda.cpp @@ -45,11 +45,11 @@ static void testCudaTestVectorAdd01_impl() { return a_buf.load(n, b_id, t_id) + b_buf.load(n, b_id, t_id); }); LoopNest l({c}); - std::vector loops = l.getLoopStmtsFor(c); + std::vector loops = l.getLoopStmtsFor(c); loops[1]->set_gpu_block_index(0); loops[2]->set_gpu_thread_index(0); l.prepareForCodegen(); - Stmt* stmt = l.root_stmt(); + StmtPtr stmt = l.root_stmt(); CudaCodeGen cuda_cg(stmt, c, a_buf, b_buf); const int N = block_count * block_size * num_iter; PaddedBuffer a_v(N); @@ -110,11 +110,11 @@ TEST(Cuda, Sigmoid_CUDA) { return sigmoid(sigmoid(a_buf.load(n, b_id, t_id))); }); LoopNest l({c}); - std::vector loops = l.getLoopStmtsFor(c); + std::vector loops = l.getLoopStmtsFor(c); loops[1]->set_gpu_block_index(0); loops[2]->set_gpu_thread_index(0); 
l.prepareForCodegen(); - Stmt* stmt = l.root_stmt(); + StmtPtr stmt = l.root_stmt(); CudaCodeGen cuda_cg(stmt, c, a_buf); const int N = block_count * block_size * num_iter; PaddedBuffer a_v(N); @@ -172,13 +172,13 @@ static void testCudaTestVectorAdd02_impl(int N, int block_size) { }, [&](const VarHandle& n) { return a_buf.load(n) + b_buf.load(n); }); LoopNest l({c}); - For* n_inner; - std::vector loops = l.getLoopStmtsFor(c); + ForPtr n_inner; + std::vector loops = l.getLoopStmtsFor(c); l.splitWithMask(loops[0], block_size, &n_inner); loops[0]->set_gpu_block_index(0); n_inner->set_gpu_thread_index(0); l.prepareForCodegen(); - Stmt* stmt = l.root_stmt(); + StmtPtr stmt = l.root_stmt(); CudaCodeGen cuda_cg(stmt, c, a_buf, b_buf); PaddedBuffer a_v(N); PaddedBuffer b_v(N); @@ -231,7 +231,7 @@ TEST(Cuda, HalfCast_CUDA) { LoopNest l({b}); l.prepareForCodegen(); - Stmt* s = l.root_stmt(); + StmtPtr s = l.root_stmt(); CudaCodeGen cg(s, {a, b}); std::vector aData(4, 2.0f); @@ -273,7 +273,7 @@ TEST(Cuda, DynamicShape2D_CUDA) { }); LoopNest l({c}); l.prepareForCodegen(); - Stmt* s = l.root_stmt(); + StmtPtr s = l.root_stmt(); CudaCodeGen cg(s, {a, b, c, m, n}); std::vector aData(M * N, 1.0f); @@ -339,11 +339,11 @@ TEST(Cuda, TestRand01_CUDA) { return Intrinsics::make(IntrinsicsOp::kRand, kFloat); }); LoopNest l({c}); - std::vector loops = l.getLoopStmtsFor(c); + std::vector loops = l.getLoopStmtsFor(c); loops[1]->set_gpu_block_index(0); loops[2]->set_gpu_thread_index(0); l.prepareForCodegen(); - Stmt* stmt = l.root_stmt(); + StmtPtr stmt = l.root_stmt(); CudaCodeGen cuda_cg(stmt, c); const int N = block_count * block_size * num_iter; PaddedBuffer c_v(N); @@ -390,12 +390,12 @@ TEST(Cuda, DynamicShapeSplit_CUDA) { Tensor* b = Compute( "b", {{n, "n"}}, [&](const VarHandle& i) { return a.load(i) * 2.0f; }); LoopNest l({b}); - For* inner; - std::vector loops = l.getLoopStmtsFor(b); + ForPtr inner; + std::vector loops = l.getLoopStmtsFor(b); l.splitWithMask(loops[0], 1024, &inner); loops[0]->set_gpu_block_index(0); inner->set_gpu_thread_index(0); - Stmt* s = l.root_stmt(); + StmtPtr s = l.root_stmt(); CudaCodeGen cg(s, {a, b, n}); std::vector aData(N, 1.0f); @@ -448,23 +448,23 @@ TEST(Cuda, OneBlockOneThreadGlobalReduce1_CUDA) { // } // } - Store* init_store = output_buf.store({0}, 0.f); + StorePtr init_store = output_buf.store({0}, 0.f); VarHandle i1("i1", kInt); ExprHandle load_data = Load::make(BufHandle(data_buf.data()), {i1}); ExprHandle load_output = Load::make(BufHandle(output_buf.data()), {0}); ExprHandle add_value = load_output + load_data; - Store* store_output = output_buf.store({0}, add_value); - For* for_output = For::make(i1, 0, N, store_output); - Stmt* reduce_block = Block::make({init_store, for_output}); + StorePtr store_output = output_buf.store({0}, add_value); + ForPtr for_output = For::make(i1, 0, N, store_output); + StmtPtr reduce_block = Block::make({init_store, for_output}); VarHandle thread_idx("tidx", kInt); LoopOptions thread_idx_options; thread_idx_options.set_gpu_thread_index(0); - For* thread_idx_loop = + ForPtr thread_idx_loop = For::make(thread_idx, 0, 1, reduce_block, thread_idx_options); VarHandle block_idx("bidx", kInt); LoopOptions block_idx_options; block_idx_options.set_gpu_block_index(0); - For* block_idx_loop = + ForPtr block_idx_loop = For::make(block_idx, 0, 1, thread_idx_loop, block_idx_options); CudaCodeGen cuda_cg(block_idx_loop, data_buf, output_buf); @@ -517,7 +517,7 @@ TEST(Cuda, OneBlockMultiThreadGlobalReduce1_CUDA) { Placeholder a_buf("a", kFloat, {N}); 
Placeholder b_buf("b", kFloat, {1}); - Store* init_store = b_buf.store({0}, 0.f); + StorePtr init_store = b_buf.store({0}, 0.f); VarHandle t("t", kInt); VarHandle b("b", kInt); @@ -526,25 +526,25 @@ TEST(Cuda, OneBlockMultiThreadGlobalReduce1_CUDA) { // b[0] = 0 ExprHandle cond_t_lt_1 = CompareSelect::make(t, 1, CompareSelectOperation::kLT); - Cond* masked_init_b = Cond::make(cond_t_lt_1, init_store, nullptr); + CondPtr masked_init_b = Cond::make(cond_t_lt_1, init_store, nullptr); LoopOptions thread_idx_options; thread_idx_options.set_gpu_thread_index(0); - For* for_init = For::make(t, 0, N, masked_init_b, thread_idx_options); + ForPtr for_init = For::make(t, 0, N, masked_init_b, thread_idx_options); // for t in 0..1024: // thread-idx // b[0] = b[0] + a[t] // implied atomic ExprHandle load_a = Load::make(BufHandle(a_buf.data()), {t}); ExprHandle load_b = Load::make(BufHandle(b_buf.data()), {0}); ExprHandle add_value = load_b + load_a; - Store* store_b = b_buf.store({0}, add_value); - For* for_b = For::make(t, 0, N, store_b, thread_idx_options); + StorePtr store_b = b_buf.store({0}, add_value); + ForPtr for_b = For::make(t, 0, N, store_b, thread_idx_options); - Stmt* reduce_block = Block::make({for_init, for_b}); + StmtPtr reduce_block = Block::make({for_init, for_b}); VarHandle block_idx("bidx", kInt); LoopOptions block_idx_options; block_idx_options.set_gpu_block_index(0); - For* block_idx_loop = + ForPtr block_idx_loop = For::make(block_idx, 0, 1, reduce_block, block_idx_options); CudaCodeGen cuda_cg(block_idx_loop, a_buf, b_buf); @@ -607,35 +607,35 @@ TEST(Cuda, NoThreadIdxWrite_1_CUDA) { // a[0] = 0 // for n in 0..2: // a[0] = a[0] + n - Store* store_a0_0 = a_buf.store({0}, 0.f); + StorePtr store_a0_0 = a_buf.store({0}, 0.f); ExprHandle load_a0 = Load::make(BufHandle(a_buf.data()), {0}); ExprHandle v1 = load_a0 + n; - Store* store_a0_v1 = a_buf.store({0}, v1); - For* loop_a_0 = For::make(n, 0, 2, store_a0_v1); + StorePtr store_a0_v1 = a_buf.store({0}, v1); + ForPtr loop_a_0 = For::make(n, 0, 2, store_a0_v1); // for m in 0..1024: // thread-idx // b[m] = m - Store* store_bm_m = b_buf.store({m}, m + 0.f); + StorePtr store_bm_m = b_buf.store({m}, m + 0.f); LoopOptions thread_idx_options; thread_idx_options.set_gpu_thread_index(0); - For* loop_b_1 = For::make(m, 0, N, store_bm_m, thread_idx_options); + ForPtr loop_b_1 = For::make(m, 0, N, store_bm_m, thread_idx_options); // a[1] = 1 // for l in 0..2: // a[1] = a[1] + l - Store* store_a1_1 = a_buf.store({1}, 1.f); + StorePtr store_a1_1 = a_buf.store({1}, 1.f); ExprHandle load_a1 = a_buf.load(1); ExprHandle v2 = load_a1 + l; - Store* store_a1_v2 = a_buf.store({1}, v2); - For* loop_a_1 = For::make(l, 0, 2, store_a1_v2); + StorePtr store_a1_v2 = a_buf.store({1}, v2); + ForPtr loop_a_1 = For::make(l, 0, 2, store_a1_v2); - Stmt* reduce_block = + StmtPtr reduce_block = Block::make({store_a0_0, loop_a_0, loop_b_1, store_a1_1, loop_a_1}); VarHandle block_idx("bidx", kInt); LoopOptions block_idx_options; block_idx_options.set_gpu_block_index(0); - For* block_idx_loop = + ForPtr block_idx_loop = For::make(block_idx, 0, 1, reduce_block, block_idx_options); CudaCodeGen cuda_cg(block_idx_loop, a_buf, b_buf); @@ -704,21 +704,21 @@ TEST(Cuda, SharedMemReduce_1_CUDA) { VarHandle m("m", kInt); VarHandle n("n", kInt); - std::vector block; - std::vector dims; + std::vector block; + std::vector dims; dims.push_back(ExprHandle(N).node()); - BufHandle c{new Buf("c", dims, kFloat)}; + BufHandle c{alloc("c", dims, kFloat)}; { // alloc(c, 64); - Allocate* alloc = 
Allocate::make(c); + AllocatePtr alloc = Allocate::make(c); block.push_back(alloc); } { // for n in 0..64: // thread-idx // c(n) = 0 - Store* store_cn_0 = Store::make(c, {n}, 0.f); - For* loop_n1 = For::make(n, 0, N, store_cn_0, thread_idx_opt); + StorePtr store_cn_0 = Store::make(c, {n}, 0.f); + ForPtr loop_n1 = For::make(n, 0, N, store_cn_0, thread_idx_opt); block.push_back(loop_n1); } @@ -730,9 +730,9 @@ TEST(Cuda, SharedMemReduce_1_CUDA) { ExprHandle a_kmn = Load::make(BufHandle(a.data()), {k * (M * N) + m * N + n}); ExprHandle v_add = load_cn + a_kmn; - Store* store_cn_v = Store::make(c, {n}, v_add); - For* loop_n2 = For::make(n, 0, N, store_cn_v, thread_idx_opt); - For* loop_m1 = For::make(m, 0, M, loop_n2); + StorePtr store_cn_v = Store::make(c, {n}, v_add); + ForPtr loop_n2 = For::make(n, 0, N, store_cn_v, thread_idx_opt); + ForPtr loop_m1 = For::make(m, 0, M, loop_n2); block.push_back(loop_m1); } @@ -740,24 +740,24 @@ TEST(Cuda, SharedMemReduce_1_CUDA) { // b(k) = 0 // for n in 0..64: // thread_idx // b(k) = b(k) + c(n) - Store* store_bk_0 = b.store({k}, 0.f); + StorePtr store_bk_0 = b.store({k}, 0.f); block.push_back(store_bk_0); ExprHandle load_bk = b.load(k); ExprHandle load_cn = Load::make(kFloat, c, {n}); ExprHandle v_add = load_bk + load_cn; - Store* store_bk = b.store({k}, v_add); - For* loop_n3 = For::make(n, 0, N, store_bk, thread_idx_opt); + StorePtr store_bk = b.store({k}, v_add); + ForPtr loop_n3 = For::make(n, 0, N, store_bk, thread_idx_opt); block.push_back(loop_n3); } { // free(c) - Free* free_stmt = Free::make(c); + FreePtr free_stmt = Free::make(c); block.push_back(free_stmt); } - Block* reduce_body = Block::make(block); - For* loop_k1 = For::make(k, 0, 1, reduce_body, block_idx_opt); + BlockPtr reduce_body = Block::make(block); + ForPtr loop_k1 = For::make(k, 0, 1, reduce_body, block_idx_opt); // TODO: check the generated code for correctness. 
CudaCodeGen cuda_cg(loop_k1, a, b); @@ -840,22 +840,23 @@ TEST(Cuda, LocalMemReduce_1_CUDA) { VarHandle m("m", kInt); VarHandle n("n", kInt); - BufHandle c{new Buf("c", {new IntImm(1)}, kFloat)}; - std::vector block_k; + BufHandle c{ + alloc("c", std::vector({alloc(1)}), kFloat)}; + std::vector block_k; { // b(k) = 0 - Store* store_bk_0 = b.store({k}, 0.f); + StorePtr store_bk_0 = b.store({k}, 0.f); block_k.push_back(store_bk_0); } - std::vector block_n; + std::vector block_n; { // alloc(c, 1); - Allocate* alloc = Allocate::make(c); + AllocatePtr alloc = Allocate::make(c); block_n.push_back(alloc); } { // c(0) = 0 - Store* store_c0_0 = Store::make(c, {0}, 0.f); + StorePtr store_c0_0 = Store::make(c, {0}, 0.f); block_n.push_back(store_c0_0); } { @@ -864,8 +865,8 @@ TEST(Cuda, LocalMemReduce_1_CUDA) { ExprHandle load_c0 = Load::make(kFloat, c, {0}); ExprHandle a_kmn = a.load(k * (M * N) + m * N + n); ExprHandle v_add = load_c0 + a_kmn; - Store* store_c0_v = Store::make(c, {0}, v_add); - For* loop_m = For::make(m, 0, M, store_c0_v); + StorePtr store_c0_v = Store::make(c, {0}, v_add); + ForPtr loop_m = For::make(m, 0, M, store_c0_v); block_n.push_back(loop_m); } { @@ -873,21 +874,21 @@ TEST(Cuda, LocalMemReduce_1_CUDA) { ExprHandle load_bk = b.load(k); ExprHandle load_c0 = Load::make(kFloat, c, {0}); ExprHandle v_add = load_bk + load_c0; - Store* store_bk = b.store({k}, v_add); + StorePtr store_bk = b.store({k}, v_add); block_n.push_back(store_bk); } { // free(c) - Free* free_stmt = Free::make(c); + FreePtr free_stmt = Free::make(c); block_n.push_back(free_stmt); } { - Block* block_n_stmt = Block::make(block_n); - For* for_n = For::make(n, 0, N, block_n_stmt, thread_idx_opt); + BlockPtr block_n_stmt = Block::make(block_n); + ForPtr for_n = For::make(n, 0, N, block_n_stmt, thread_idx_opt); block_k.push_back(for_n); } - Block* block_k_stmt = Block::make(block_k); - For* loop_k = For::make(k, 0, 1, block_k_stmt, block_idx_opt); + BlockPtr block_k_stmt = Block::make(block_k); + ForPtr loop_k = For::make(k, 0, 1, block_k_stmt, block_idx_opt); CudaCodeGen cuda_cg(loop_k, a, b); PaddedBuffer a_v(1, M, N, "a_v"); @@ -941,7 +942,7 @@ TEST(Cuda, HalfSupport_CUDA) { LoopNest l({b, c, d}); l.prepareForCodegen(); - Stmt* s = l.root_stmt(); + StmtPtr s = l.root_stmt(); CudaCodeGen cg(s, {a, b, c, d}); std::vector aData(4, 2.0f); @@ -986,12 +987,12 @@ TEST(Cuda, HalfPropagation_CUDA) { auto half = ToDtype(); Placeholder a("a", half, {4}); Tensor* relu = Compute("relu", {{4, "n"}}, [&](const VarHandle& i) { - return Max::make(a.load(i), ExprHandle(new HalfImm(0)), true); + return Max::make(a.load(i), ExprHandle(alloc(0)), true); }); LoopNest l({relu}); l.prepareForCodegen(); - Stmt* s = l.root_stmt(); + StmtPtr s = l.root_stmt(); CudaCodeGen cg(s, {a, relu}); std::ostringstream oss; @@ -1036,12 +1037,12 @@ TEST(Cuda, UnusedHalfArgument_CUDA) { auto half = ToDtype(); Placeholder b("b", half, {4}); Tensor* relu = Compute("relu", {{4, "n"}}, [&](const VarHandle& i) { - return Max::make(a.load(i), ExprHandle(new FloatImm(0)), true); + return Max::make(a.load(i), ExprHandle(alloc(0)), true); }); LoopNest l({relu}); l.prepareForCodegen(); - Stmt* s = l.root_stmt(); + StmtPtr s = l.root_stmt(); CudaCodeGen cg(s, {a, b, relu}); std::ostringstream oss; @@ -1109,7 +1110,7 @@ TEST(Cuda, PrioritizeDependents_CUDA) { ExprHandle cmp = CompareSelect::make(i, 10, CompareSelectOperation::kLT); ExprHandle ite = IfThenElse::make(cmp, Add::make(load_a, load_b), load_b); - For* loop = + ForPtr loop = For::make(i, 0, 12, 
Block::make({c.store({i}, ite)}), block_idx_opt); CudaCodeGen cuda_cg(loop, a, b, c); @@ -1175,13 +1176,13 @@ TEST(Cuda, MaskBlockDim_CUDA) { }); LoopNest l({c, d}); - std::vector loops = l.getLoopStmtsFor(c); + std::vector loops = l.getLoopStmtsFor(c); loops[0]->set_gpu_block_index(0); loops = l.getLoopStmtsFor(d); loops[0]->set_gpu_block_index(0); l.prepareForCodegen(); - Stmt* stmt = l.root_stmt(); + StmtPtr stmt = l.root_stmt(); CudaCodeGen cuda_cg(stmt, c, d, a_buf, b_buf); std::ostringstream oss; @@ -1199,8 +1200,8 @@ TEST(Cuda, MaskBlockDim_CUDA) { auto blockExtents = cuda_cg.gpu_block_extents(); auto threadExtents = cuda_cg.gpu_thread_extents(); - ASSERT_TRUE(exprEquals(blockExtents[0], new IntImm(A_SIZE))); - ASSERT_TRUE(exprEquals(threadExtents[0], new IntImm(1))); + ASSERT_TRUE(exprEquals(blockExtents[0], alloc(A_SIZE))); + ASSERT_TRUE(exprEquals(threadExtents[0], alloc(1))); // Sanity check that the kernel works. PaddedBuffer a_v(A_SIZE); @@ -1268,13 +1269,13 @@ TEST(Cuda, MaskThreadDim_CUDA) { }); LoopNest l({c, d}); - std::vector loops = l.getLoopStmtsFor(c); + std::vector loops = l.getLoopStmtsFor(c); loops[0]->set_gpu_thread_index(0); loops = l.getLoopStmtsFor(d); loops[0]->set_gpu_thread_index(0); l.prepareForCodegen(); - Stmt* stmt = l.root_stmt(); + StmtPtr stmt = l.root_stmt(); CudaCodeGen cuda_cg(stmt, c, d, a_buf, b_buf); std::ostringstream oss; @@ -1293,8 +1294,8 @@ TEST(Cuda, MaskThreadDim_CUDA) { auto blockExtents = cuda_cg.gpu_block_extents(); auto threadExtents = cuda_cg.gpu_thread_extents(); - ASSERT_TRUE(exprEquals(blockExtents[0], new IntImm(1))); - ASSERT_TRUE(exprEquals(threadExtents[0], new IntImm(B_SIZE))); + ASSERT_TRUE(exprEquals(blockExtents[0], alloc(1))); + ASSERT_TRUE(exprEquals(threadExtents[0], alloc(B_SIZE))); PaddedBuffer a_v(A_SIZE); PaddedBuffer b_v(B_SIZE); @@ -1363,13 +1364,13 @@ TEST(Cuda, MaskMultiBlockDim_CUDA) { }); LoopNest l({c, d}); - std::vector loops = l.getLoopStmtsFor(c); + std::vector loops = l.getLoopStmtsFor(c); loops[0]->set_gpu_block_index(0); loops = l.getLoopStmtsFor(d); loops[0]->set_gpu_block_index(1); l.prepareForCodegen(); - Stmt* stmt = l.root_stmt(); + StmtPtr stmt = l.root_stmt(); CudaCodeGen cuda_cg(stmt, c, d, a_buf, b_buf); std::ostringstream oss; @@ -1387,8 +1388,8 @@ TEST(Cuda, MaskMultiBlockDim_CUDA) { auto blockExtents = cuda_cg.gpu_block_extents(); auto threadExtents = cuda_cg.gpu_thread_extents(); - ASSERT_TRUE(exprEquals(blockExtents[0], new IntImm(A_SIZE))); - ASSERT_TRUE(exprEquals(blockExtents[1], new IntImm(B_SIZE))); + ASSERT_TRUE(exprEquals(blockExtents[0], alloc(A_SIZE))); + ASSERT_TRUE(exprEquals(blockExtents[1], alloc(B_SIZE))); PaddedBuffer a_v(A_SIZE); PaddedBuffer b_v(B_SIZE); @@ -1457,13 +1458,13 @@ TEST(Cuda, MaskBlockAndThreadDim_CUDA) { }); LoopNest l({c, d}); - std::vector loops = l.getLoopStmtsFor(c); + std::vector loops = l.getLoopStmtsFor(c); loops[0]->set_gpu_block_index(0); loops = l.getLoopStmtsFor(d); loops[0]->set_gpu_thread_index(0); l.prepareForCodegen(); - Stmt* stmt = l.root_stmt(); + StmtPtr stmt = l.root_stmt(); CudaCodeGen cuda_cg(stmt, c, d, a_buf, b_buf); std::ostringstream oss; @@ -1481,8 +1482,8 @@ TEST(Cuda, MaskBlockAndThreadDim_CUDA) { auto blockExtents = cuda_cg.gpu_block_extents(); auto threadExtents = cuda_cg.gpu_thread_extents(); - ASSERT_TRUE(exprEquals(blockExtents[0], new IntImm(A_SIZE))); - ASSERT_TRUE(exprEquals(threadExtents[0], new IntImm(B_SIZE))); + ASSERT_TRUE(exprEquals(blockExtents[0], alloc(A_SIZE))); + ASSERT_TRUE(exprEquals(threadExtents[0], 
alloc(B_SIZE))); PaddedBuffer a_v(A_SIZE); PaddedBuffer b_v(B_SIZE); @@ -1556,7 +1557,7 @@ TEST(Cuda, MaskMultiDim_CUDA) { }); LoopNest l({c, d}); - std::vector loops = l.getLoopStmtsFor(c); + std::vector loops = l.getLoopStmtsFor(c); loops[0]->set_gpu_block_index(0); loops[1]->set_gpu_thread_index(0); loops = l.getLoopStmtsFor(d); @@ -1564,7 +1565,7 @@ TEST(Cuda, MaskMultiDim_CUDA) { loops[1]->set_gpu_thread_index(0); l.prepareForCodegen(); - Stmt* stmt = l.root_stmt(); + StmtPtr stmt = l.root_stmt(); CudaCodeGen cuda_cg(stmt, c, d, a_buf, b_buf); std::ostringstream oss; @@ -1583,8 +1584,8 @@ TEST(Cuda, MaskMultiDim_CUDA) { auto blockExtents = cuda_cg.gpu_block_extents(); auto threadExtents = cuda_cg.gpu_thread_extents(); - ASSERT_TRUE(exprEquals(blockExtents[0], new IntImm(OUTER_SIZE))); - ASSERT_TRUE(exprEquals(threadExtents[0], new IntImm(A_SIZE))); + ASSERT_TRUE(exprEquals(blockExtents[0], alloc(OUTER_SIZE))); + ASSERT_TRUE(exprEquals(threadExtents[0], alloc(A_SIZE))); PaddedBuffer a_v(OUTER_SIZE, A_SIZE); PaddedBuffer b_v(OUTER_SIZE, B_SIZE); @@ -1686,7 +1687,7 @@ TEST(Cuda, MaskMultiDimSymbolic_CUDA) { }); LoopNest l({c, d}); - std::vector loops = l.getLoopStmtsFor(c); + std::vector loops = l.getLoopStmtsFor(c); loops[0]->set_gpu_block_index(0); loops[1]->set_gpu_thread_index(0); loops = l.getLoopStmtsFor(d); @@ -1694,7 +1695,7 @@ TEST(Cuda, MaskMultiDimSymbolic_CUDA) { loops[1]->set_gpu_thread_index(0); l.prepareForCodegen(); - Stmt* stmt = l.root_stmt(); + StmtPtr stmt = l.root_stmt(); CudaCodeGen cuda_cg(stmt, c, d, OUTER_SIZE, A_SIZE, B_SIZE, a_buf, b_buf); std::ostringstream oss; @@ -1715,7 +1716,7 @@ TEST(Cuda, MaskMultiDimSymbolic_CUDA) { auto threadExtents = cuda_cg.gpu_thread_extents(); ASSERT_TRUE(exprEquals(blockExtents[0], OUTER_SIZE.node())); ASSERT_TRUE(exprEquals( - threadExtents[0], new Max(A_SIZE.node(), B_SIZE.node(), true))); + threadExtents[0], alloc(A_SIZE.node(), B_SIZE.node(), true))); int OUTER_EXTENT = 10; int A_EXTENT = 100; @@ -1820,7 +1821,7 @@ TEST(Cuda, MaskCompoundInnerLoop_CUDA) { VarHandle j("j", kInt); VarHandle k("k", kInt); - Stmt* stmt = For::make( + StmtPtr stmt = For::make( i, 0, OUTER_SIZE, @@ -1860,8 +1861,8 @@ TEST(Cuda, MaskCompoundInnerLoop_CUDA) { auto blockExtents = cuda_cg.gpu_block_extents(); auto threadExtents = cuda_cg.gpu_thread_extents(); - ASSERT_TRUE(exprEquals(blockExtents[0], new IntImm(OUTER_SIZE))); - ASSERT_TRUE(exprEquals(threadExtents[0], new IntImm(A_SIZE))); + ASSERT_TRUE(exprEquals(blockExtents[0], alloc(OUTER_SIZE))); + ASSERT_TRUE(exprEquals(threadExtents[0], alloc(A_SIZE))); PaddedBuffer a_v(OUTER_SIZE, A_SIZE); PaddedBuffer b_v(OUTER_SIZE, B_SIZE); @@ -1959,7 +1960,7 @@ TEST(Cuda, MaskInnerLoopOneBlock_CUDA) { VarHandle j("j", kInt); VarHandle k("k", kInt); - Stmt* stmt = For::make( + StmtPtr stmt = For::make( i, 0, OUTER_SIZE, @@ -1999,8 +2000,8 @@ TEST(Cuda, MaskInnerLoopOneBlock_CUDA) { auto blockExtents = cuda_cg.gpu_block_extents(); auto threadExtents = cuda_cg.gpu_thread_extents(); - ASSERT_TRUE(exprEquals(blockExtents[0], new IntImm(1))); - ASSERT_TRUE(exprEquals(threadExtents[0], new IntImm(A_SIZE))); + ASSERT_TRUE(exprEquals(blockExtents[0], alloc(1))); + ASSERT_TRUE(exprEquals(threadExtents[0], alloc(A_SIZE))); PaddedBuffer a_v(OUTER_SIZE, A_SIZE); PaddedBuffer b_v(OUTER_SIZE, B_SIZE); @@ -2100,7 +2101,7 @@ TEST(Cuda, MaskMultiDimMultiAxis_CUDA) { }); LoopNest l({c, d}); - std::vector loops = l.getLoopStmtsFor(c); + std::vector loops = l.getLoopStmtsFor(c); loops[0]->set_gpu_block_index(0); 
loops[1]->set_gpu_thread_index(0); loops = l.getLoopStmtsFor(d); @@ -2108,7 +2109,7 @@ TEST(Cuda, MaskMultiDimMultiAxis_CUDA) { loops[1]->set_gpu_thread_index(1); l.prepareForCodegen(); - Stmt* stmt = l.root_stmt(); + StmtPtr stmt = l.root_stmt(); CudaCodeGen cuda_cg(stmt, c, d, a_buf, b_buf); std::ostringstream oss; @@ -2127,8 +2128,8 @@ TEST(Cuda, MaskMultiDimMultiAxis_CUDA) { auto blockExtents = cuda_cg.gpu_block_extents(); auto threadExtents = cuda_cg.gpu_thread_extents(); - ASSERT_TRUE(exprEquals(blockExtents[0], new IntImm(OUTER_SIZE))); - ASSERT_TRUE(exprEquals(threadExtents[0], new IntImm(A_SIZE))); + ASSERT_TRUE(exprEquals(blockExtents[0], alloc(OUTER_SIZE))); + ASSERT_TRUE(exprEquals(threadExtents[0], alloc(A_SIZE))); PaddedBuffer a_v(OUTER_SIZE, A_SIZE); PaddedBuffer b_v(OUTER_SIZE, B_SIZE); @@ -2231,7 +2232,7 @@ TEST(Cuda, MaskMultiDimMultiLevel_CUDA) { }); LoopNest l({c, d}); - std::vector loops = l.getLoopStmtsFor(c); + std::vector loops = l.getLoopStmtsFor(c); loops[0]->set_gpu_block_index(0); loops[1]->set_gpu_thread_index(0); loops = l.getLoopStmtsFor(d); @@ -2239,7 +2240,7 @@ TEST(Cuda, MaskMultiDimMultiLevel_CUDA) { loops[1]->set_gpu_thread_index(0); l.prepareForCodegen(); - Stmt* stmt = l.root_stmt(); + StmtPtr stmt = l.root_stmt(); CudaCodeGen cuda_cg(stmt, c, d, a_buf, b_buf); std::ostringstream oss; @@ -2259,8 +2260,8 @@ TEST(Cuda, MaskMultiDimMultiLevel_CUDA) { auto blockExtents = cuda_cg.gpu_block_extents(); auto threadExtents = cuda_cg.gpu_thread_extents(); - ASSERT_TRUE(exprEquals(blockExtents[0], new IntImm(OUTER_A_SIZE))); - ASSERT_TRUE(exprEquals(threadExtents[0], new IntImm(A_SIZE))); + ASSERT_TRUE(exprEquals(blockExtents[0], alloc(OUTER_A_SIZE))); + ASSERT_TRUE(exprEquals(threadExtents[0], alloc(A_SIZE))); PaddedBuffer a_v(OUTER_A_SIZE, A_SIZE); PaddedBuffer b_v(OUTER_B_SIZE, B_SIZE); diff --git a/test/cpp/tensorexpr/test_expr.cpp b/test/cpp/tensorexpr/test_expr.cpp index 88aa0ee..7c234fb 100644 --- a/test/cpp/tensorexpr/test_expr.cpp +++ b/test/cpp/tensorexpr/test_expr.cpp @@ -70,9 +70,9 @@ TEST(Expr, LetStmtTest01) { ExprHandle load_a = a_buf.load(0); VarHandle var = VarHandle("v", kFloat); - Stmt* let_store = Let::make(var, load_a); - Stmt* store_b = b_buf.store({0}, var); - Block* block = Block::make({let_store, store_b}); + StmtPtr let_store = Let::make(var, load_a); + StmtPtr store_b = b_buf.store({0}, var); + BlockPtr block = Block::make({let_store, store_b}); SimpleIREvaluator eval(block, {a_buf, b_buf}); @@ -189,9 +189,9 @@ TEST(Expr, VectorAdd01) { ExprHandle load_b = b_buf.load({Ramp::make(index * kVectorSize, 1, kVectorSize)}); ExprHandle value = load_a + load_b; - Stmt* store_c = + StmtPtr store_c = c_buf.store({Ramp::make(index * kVectorSize, 1, kVectorSize)}, value); - Stmt* stmt = For::make(index, 0, kVectorCount, store_c); + StmtPtr stmt = For::make(index, 0, kVectorCount, store_c); ASSERT_EQ(load_a.dtype(), Dtype(kFloat, kVectorSize)); ASSERT_EQ(load_b.dtype(), Dtype(kFloat, kVectorSize)); @@ -313,15 +313,16 @@ TEST(Expr, IntrinsicsDtypes) { TEST(Expr, Substitute01) { KernelScope kernel_scope; - Var* x = new Var("x", kFloat); - Var* y = new Var("y", kFloat); - Expr* e = new Mul(new Sub(x, new FloatImm(1.0f)), new Add(x, y)); - - Var* z = new Var("z", kFloat); - Expr* e2 = Substitute(e, {{x, new Add(z, new FloatImm(5.0f))}}); - Expr* e2_ref = new Mul( - new Sub(new Add(z, new FloatImm(5.0f)), new FloatImm(1.0f)), - new Add(new Add(z, new FloatImm(5.0f)), y)); + VarPtr x = alloc("x", kFloat); + VarPtr y = alloc("y", kFloat); + ExprPtr e = + 
alloc(alloc(x, alloc(1.0f)), alloc(x, y)); + + VarPtr z = alloc("z", kFloat); + ExprPtr e2 = Substitute(e, {{x, alloc(z, alloc(5.0f))}}); + ExprPtr e2_ref = alloc( + alloc(alloc(z, alloc(5.0f)), alloc(1.0f)), + alloc(alloc(z, alloc(5.0f)), y)); std::ostringstream oss; oss << *e2; std::string e2_str = oss.str(); @@ -568,7 +569,7 @@ TEST(Expr, DynamicShapeAdd) { Placeholder b(BufHandle("b", {n}, kFloat)); Placeholder c(BufHandle("c", {n}, kFloat)); VarHandle i("i", kInt); - Stmt* s = For::make(i, 0, n, c.store({i}, a.load(i) + b.load(i))); + StmtPtr s = For::make(i, 0, n, c.store({i}, a.load(i) + b.load(i))); std::vector aData(size, 1.0f); std::vector bData(size, 2.0f); std::vector cData(size, 0.0f); @@ -586,11 +587,11 @@ void testCond01() { PaddedBuffer a_v(N); Placeholder a_buf("a", kFloat, {N}); VarHandle index = VarHandle("index", kInt); - Stmt* assign_x2 = a_buf.store({index}, cast(index) * 2); - Stmt* assign_x3 = a_buf.store({index}, cast(index) * 3); + StmtPtr assign_x2 = a_buf.store({index}, cast(index) * 2); + StmtPtr assign_x3 = a_buf.store({index}, cast(index) * 3); ExprHandle even_cond = CompareSelect::make(Mod::make(index, 2), 0, kEQ); - Stmt* assign = Cond::make(even_cond, assign_x2, assign_x3); - Stmt* for_stmt = For::make(index, 0, N, assign); + StmtPtr assign = Cond::make(even_cond, assign_x2, assign_x3); + StmtPtr for_stmt = For::make(index, 0, N, assign); SimpleIREvaluator(for_stmt, {a_buf})(a_v); PaddedBuffer a_ref(N); @@ -647,10 +648,10 @@ void testStmtClone() { Placeholder a_buf("a", kInt, {N}); VarHandle index = VarHandle("index", kInt); - Stmt* body = a_buf.store({index}, 5); - Stmt* loop = For::make(index, 0, N, body); + StmtPtr body = a_buf.store({index}, 5); + StmtPtr loop = For::make(index, 0, N, body); - Stmt* cloned_loop = Stmt::clone(loop); + StmtPtr cloned_loop = Stmt::clone(loop); std::vector orig_loop_results(N); std::vector cloned_loop_results(N); SimpleIREvaluator(loop, {a_buf})(orig_loop_results); @@ -661,9 +662,8 @@ void testStmtClone() { // Let's add another assign to the body in the cloned loop and verify that the // original statement hasn't changed while the cloned one has. 
- Stmt* body_addition = a_buf.store({index}, 33); - Block* cloned_body = - static_cast(static_cast(cloned_loop)->body()); + StmtPtr body_addition = a_buf.store({index}, 33); + BlockPtr cloned_body = static_to(static_to(cloned_loop)->body()); cloned_body->append_stmt(body_addition); std::vector orig_loop_results_after_mutation(N); diff --git a/test/cpp/tensorexpr/test_external_calls.cpp b/test/cpp/tensorexpr/test_external_calls.cpp index bae54b8..9ae99ca 100644 --- a/test/cpp/tensorexpr/test_external_calls.cpp +++ b/test/cpp/tensorexpr/test_external_calls.cpp @@ -680,8 +680,8 @@ TEST(ExternalCall, Inlining) { return MatmulResult->load(i, j) + FloatImm::make(3.0f); }); - Stmt* root_stmt = - new Block({A->stmt(), B->stmt(), MatmulResult->stmt(), Result->stmt()}); + StmtPtr root_stmt = alloc(std::vector( + {A->stmt(), B->stmt(), MatmulResult->stmt(), Result->stmt()})); LoopNest l(root_stmt, {Result->buf()}); // Inlining should not inline anything here since all Bufs are either diff --git a/test/cpp/tensorexpr/test_ir_printer.cpp b/test/cpp/tensorexpr/test_ir_printer.cpp index 1be5f7e..76d9247 100644 --- a/test/cpp/tensorexpr/test_ir_printer.cpp +++ b/test/cpp/tensorexpr/test_ir_printer.cpp @@ -83,7 +83,7 @@ TEST(IRPrinter, FunctionName) { }); LoopNest l({chunk_0, chunk_1, consumer}); - auto* body = l.root_stmt(); + auto body = l.root_stmt(); std::stringstream ss; ss << *body; diff --git a/test/cpp/tensorexpr/test_ir_verifier.cpp b/test/cpp/tensorexpr/test_ir_verifier.cpp index 17763d3..2c91d8b 100644 --- a/test/cpp/tensorexpr/test_ir_verifier.cpp +++ b/test/cpp/tensorexpr/test_ir_verifier.cpp @@ -18,30 +18,30 @@ using namespace torch::jit::tensorexpr; TEST(IRVerifier, BitwiseOps) { KernelScope kernel_scope; - Var* X = new Var("x", kInt); - Var* Y = new Var("y", kFloat); + VarPtr X = alloc("x", kInt); + VarPtr Y = alloc("y", kFloat); { - auto a = new And(X, Y); + auto a = alloc(X, Y); // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) EXPECT_ANY_THROW(verify(a)); } { - auto a = new Or(X, Y); + auto a = alloc(X, Y); // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) EXPECT_ANY_THROW(verify(a)); } { - auto a = new Xor(X, Y); + auto a = alloc(X, Y); // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) EXPECT_ANY_THROW(verify(a)); } { - auto a = new Lshift(X, Y); + auto a = alloc(X, Y); // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) EXPECT_ANY_THROW(verify(a)); } { - auto a = new Rshift(X, Y); + auto a = alloc(X, Y); // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) EXPECT_ANY_THROW(verify(a)); } @@ -49,15 +49,15 @@ TEST(IRVerifier, BitwiseOps) { TEST(IRVerifier, CompareSelect) { KernelScope kernel_scope; - Expr* X = new IntImm(1); - Expr* Y = new FloatImm(3.14f); + ExprPtr X = alloc(1); + ExprPtr Y = alloc(3.14f); { - auto a = new CompareSelect(X, X, X, Y, kEQ); + auto a = alloc(X, X, X, Y, kEQ); // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) EXPECT_ANY_THROW(verify(a)); } { - auto a = new CompareSelect(X, Y, X, X, kEQ); + auto a = alloc(X, Y, X, X, kEQ); // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) EXPECT_ANY_THROW(verify(a)); } @@ -65,10 +65,10 @@ TEST(IRVerifier, CompareSelect) { TEST(IRVerifier, Ramp) { 
KernelScope kernel_scope; - Var* I = new Var("i", kInt); - Var* J = new Var("j", kFloat); + VarPtr I = alloc("i", kInt); + VarPtr J = alloc("j", kFloat); { - auto a = new Ramp(I, J, 4); + auto a = alloc(I, J, 4); // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) EXPECT_ANY_THROW(verify(a)); } @@ -76,26 +76,29 @@ TEST(IRVerifier, Ramp) { TEST(IRVerifier, Load) { KernelScope kernel_scope; - Var* I = new Var("i", kInt); - Var* J = new Var("j", kLong); - Var* K = new Var("k", kFloat); - Buf* B = new Buf("b", {new IntImm(10), new IntImm(20)}, kFloat); + VarPtr I = alloc("i", kInt); + VarPtr J = alloc("j", kLong); + VarPtr K = alloc("k", kFloat); + BufPtr B = alloc( + "b", + std::vector({alloc(10), alloc(20)}), + kFloat); { // Indices with different int dtypes (kInt, kLong) are ok - auto a = new Load(B, {I, J}); + auto a = alloc(B, std::vector({I, J})); // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) EXPECT_NO_THROW(verify(a)); } { // Float index - auto a = new Load(B, {K, K}); + auto a = alloc(B, std::vector({K, K})); // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) EXPECT_ANY_THROW(verify(a)); } { // Multilanes are only allowed in flattened indices - auto multilane_index = new Ramp(I, new IntImm(1), 4); - auto a = new Load(B, {I, multilane_index}); + auto multilane_index = alloc(I, alloc(1), 4); + auto a = alloc(B, std::vector({I, multilane_index})); // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) EXPECT_ANY_THROW(verify(a)); } @@ -103,24 +106,24 @@ TEST(IRVerifier, Load) { TEST(IRVerifier, IfThenElse) { KernelScope kernel_scope; - Var* I = new Var("i", kInt); - Var* J = new Var("j", kLong); - Var* K = new Var("k", kFloat); + VarPtr I = alloc("i", kInt); + VarPtr J = alloc("j", kLong); + VarPtr K = alloc("k", kFloat); { // Condition must be integral - auto a = new IfThenElse(K, I, I); + auto a = alloc(K, I, I); // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) EXPECT_ANY_THROW(verify(a)); } { // Dtypes of true and false exprs must match - auto a = new IfThenElse(I, I, J); + auto a = alloc(I, I, J); // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) EXPECT_ANY_THROW(verify(a)); } { // Can't have multiple lanes in condition expr - auto a = new IfThenElse(new Broadcast(I, 4), I, I); + auto a = alloc(alloc(I, 4), I, I); // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) EXPECT_ANY_THROW(verify(a)); } @@ -128,12 +131,12 @@ TEST(IRVerifier, IfThenElse) { TEST(IRVerifier, For) { KernelScope kernel_scope; - Var* I = new Var("i", kInt); - Var* J = new Var("j", kInt); - Stmt* body = new Block({}); + VarPtr I = alloc("i", kInt); + VarPtr J = alloc("j", kInt); + StmtPtr body = alloc(std::vector({})); { // Can't have nullptr as a Var - auto a = new For(nullptr, I, J, body); + auto a = alloc(nullptr, I, J, body); // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) EXPECT_ANY_THROW(verify(a)); } @@ -141,14 +144,14 @@ TEST(IRVerifier, For) { TEST(IRVerifier, Block) { KernelScope kernel_scope; - Var* I = new Var("i", kInt); - Buf* B = new Buf("B", {new IntImm(10)}, kInt); + VarPtr I = alloc("i", kInt); + BufPtr B = alloc("B", std::vector({alloc(10)}), kInt); { - Stmt* store = new Store(B, {I}, I); + StmtPtr store = 
alloc(B, std::vector({I}), I); // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) - Stmt* block1 = new Block({store}); + StmtPtr block1 = alloc(std::vector({store})); // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - Stmt* block2 = new Block({store}); + StmtPtr block2 = alloc(std::vector({store})); // Stmt can't have multiple parrents, thus inserting it into several blocks // is illegal // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) @@ -158,32 +161,35 @@ TEST(IRVerifier, Block) { TEST(IRVerifier, Store) { KernelScope kernel_scope; - Var* I = new Var("i", kInt); - Var* J = new Var("j", kLong); - Var* K = new Var("k", kFloat); - Buf* B = new Buf("b", {new IntImm(10), new IntImm(20)}, kFloat); + VarPtr I = alloc("i", kInt); + VarPtr J = alloc("j", kLong); + VarPtr K = alloc("k", kFloat); + BufPtr B = alloc( + "b", + std::vector({alloc(10), alloc(20)}), + kFloat); { // Indices with different int dtypes (kInt, kLong) are ok - auto a = new Store(B, {I, J}, K); + auto a = alloc(B, std::vector({I, J}), K); // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) EXPECT_NO_THROW(verify(a)); } { // Float index - auto a = new Store(B, {K, K}, K); + auto a = alloc(B, std::vector({K, K}), K); // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) EXPECT_ANY_THROW(verify(a)); } { // Multilanes are only allowed in flattened indices - auto multilane_index = new Ramp(I, new IntImm(1), 4); - auto a = new Store(B, {I, multilane_index}, K); + auto multilane_index = alloc(I, alloc(1), 4); + auto a = alloc(B, std::vector({I, multilane_index}), K); // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) EXPECT_ANY_THROW(verify(a)); } { // Value and buf dtypes mismatch - auto a = new Store(B, {I}, I); + auto a = alloc(B, std::vector({I}), I); // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) EXPECT_ANY_THROW(verify(a)); } diff --git a/test/cpp/tensorexpr/test_kernel.cpp b/test/cpp/tensorexpr/test_kernel.cpp index 67641c1..8f36f54 100644 --- a/test/cpp/tensorexpr/test_kernel.cpp +++ b/test/cpp/tensorexpr/test_kernel.cpp @@ -105,7 +105,7 @@ TEST_F(Kernel, _1) { auto ref = a * (a * b); TensorExprKernel k(graph); std::vector inputs = {a, b}; - Stmt* s = k.getCodeGenStmt(); + StmtPtr s = k.getCodeGenStmt(); std::ostringstream oss; oss << *s; @@ -145,7 +145,7 @@ TEST_F(Kernel, _2) { auto ref = a * (a * b); TensorExprKernel k(graph); std::vector inputs = {a, b}; - Stmt* s = k.getCodeGenStmt(); + StmtPtr s = k.getCodeGenStmt(); std::ostringstream oss; oss << *s; @@ -185,7 +185,7 @@ TEST_F(Kernel, _3) { auto ref = a * (a * b); TensorExprKernel k(graph); std::vector inputs = {a, b}; - Stmt* s = k.getCodeGenStmt(); + StmtPtr s = k.getCodeGenStmt(); std::ostringstream oss; oss << *s; @@ -230,7 +230,7 @@ TEST_F(Kernel, DISABLED_Shape_Inference) { auto ref = a * (a * b); TensorExprKernel k(graph); std::vector inputs = {a, b}; - Stmt* s = k.getCodeGenStmt(); + StmtPtr s = k.getCodeGenStmt(); std::ostringstream oss; oss << *s; @@ -270,7 +270,7 @@ TEST_F(Kernel, DISABLED_Shape_Inference) { auto ref = t[0] * t[1]; TensorExprKernel k(graph); std::vector inputs = {a, b}; - Stmt* s = k.getCodeGenStmt(); + StmtPtr s = k.getCodeGenStmt(); std::ostringstream oss; oss << *s; @@ -321,7 +321,7 @@ TEST_F(Kernel, DISABLED_Shape_Inference) { TensorExprKernel k(graph); 
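// Illustrative sketch, not part of the patch: the kernel tests above all share
// one recipe -- lower the graph with TensorExprKernel, take the lowered
// statement as a StmtPtr (previously a raw Stmt*), print it, and FileCheck the
// printed IR. A condensed helper capturing that recipe; the include paths,
// helper name, and FileCheck pattern argument are assumptions, not taken from
// the tests:

#include <memory>
#include <sstream>
#include <string>
#include <torch/csrc/jit/tensorexpr/kernel.h>
#include <torch/csrc/jit/testing/file_check.h>

static void checkKernelIR(
    const std::shared_ptr<torch::jit::Graph>& graph,
    const std::string& expected_ir) {
  torch::jit::tensorexpr::TensorExprKernel k(graph);
  torch::jit::tensorexpr::StmtPtr s = k.getCodeGenStmt();  // was: Stmt*
  std::ostringstream oss;
  oss << *s;  // dereferencing still prints the statement
  torch::jit::testing::FileCheck().run(expected_ir, oss.str());
}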
std::vector inputs = {a, b, c}; - Stmt* s = k.getCodeGenStmt(); + StmtPtr s = k.getCodeGenStmt(); std::ostringstream oss; oss << *s; @@ -376,7 +376,7 @@ TEST_F(Kernel, DISABLED_Shape_Inference) { TensorExprKernel k(graph); std::vector inputs = {a, b, c}; - Stmt* s = k.getCodeGenStmt(); + StmtPtr s = k.getCodeGenStmt(); std::ostringstream oss; oss << *s; @@ -478,7 +478,7 @@ TEST_F(Kernel, CatInputTypesPromotion) { TensorExprKernel k(graph); std::vector inputs = {a, b, c}; - Stmt* s = k.getCodeGenStmt(); + StmtPtr s = k.getCodeGenStmt(); std::ostringstream oss; oss << *s; @@ -527,7 +527,7 @@ TEST_F(Kernel, CatWoConditionals) { parseIR(graph_string, &*graph); TensorExprKernel k(graph); - Stmt* s = k.getCodeGenStmt(); + StmtPtr s = k.getCodeGenStmt(); std::ostringstream oss; oss << *s; @@ -592,7 +592,7 @@ TEST_F(Kernel, OptimizeConditionals) { parseIR(graph_string, &*graph); TensorExprKernel k(graph); - Stmt* s = k.getCodeGenStmt(); + StmtPtr s = k.getCodeGenStmt(); std::ostringstream oss; oss << *s; @@ -697,7 +697,7 @@ TEST_F(Kernel, SumAllAxes) { auto ref = a.sum(/*dtype=*/dtype); TensorExprKernel k(graph); std::vector inputs = {a}; - Stmt* s = k.getCodeGenStmt(); + StmtPtr s = k.getCodeGenStmt(); std::ostringstream oss; oss << *s; @@ -769,7 +769,7 @@ TEST_F(Kernel, SumOneAxis) { auto o = at::empty({}, TensorOptions(kCPU)); TensorExprKernel k(graph); std::vector inputs = {a}; - Stmt* s = k.getCodeGenStmt(); + StmtPtr s = k.getCodeGenStmt(); std::ostringstream oss; oss << *s; @@ -831,7 +831,7 @@ TEST_F(Kernel, SumMultipleAxes) { TensorExprKernel k(graph); std::vector inputs = {a}; - Stmt* s = k.getCodeGenStmt(); + StmtPtr s = k.getCodeGenStmt(); std::ostringstream oss; oss << *s; @@ -902,7 +902,7 @@ TEST_F(Kernel, Softmax2D) { TensorExprKernel k(graph); std::vector inputs = {a}; - Stmt* s = k.getCodeGenStmt(); + StmtPtr s = k.getCodeGenStmt(); std::ostringstream oss; oss << *s; @@ -978,7 +978,7 @@ TEST_F(Kernel, Softmax3D) { TensorExprKernel k(graph); std::vector inputs = {a}; - Stmt* s = k.getCodeGenStmt(); + StmtPtr s = k.getCodeGenStmt(); std::ostringstream oss; oss << *s; @@ -1060,7 +1060,7 @@ TEST_F(Kernel, Softmax4D) { TensorExprKernel k(graph); std::vector inputs = {a}; - Stmt* s = k.getCodeGenStmt(); + StmtPtr s = k.getCodeGenStmt(); std::ostringstream oss; oss << *s; @@ -1104,7 +1104,7 @@ TEST_F(Kernel, InlineProducerIntoReduction) { parseIR(graph_string, &*graph); TensorExprKernel k(graph); - Stmt* s = k.getCodeGenStmt(); + StmtPtr s = k.getCodeGenStmt(); std::ostringstream oss; oss << *s; @@ -1145,7 +1145,7 @@ TEST_F(Kernel, InlineReductionIntoConsumer) { parseIR(graph_string, &*graph); TensorExprKernel k(graph); - Stmt* s = k.getCodeGenStmt(); + StmtPtr s = k.getCodeGenStmt(); std::ostringstream oss; oss << *s; diff --git a/test/cpp/tensorexpr/test_llvm.cpp b/test/cpp/tensorexpr/test_llvm.cpp index eb83505..3776329 100644 --- a/test/cpp/tensorexpr/test_llvm.cpp +++ b/test/cpp/tensorexpr/test_llvm.cpp @@ -225,8 +225,8 @@ TEST(LLVM, fastLogFloat) { VarHandle index = VarHandle("index", kInt); ExprHandle load_a = a_buf.load(index); - Stmt* store_b = b_buf.store({index}, fast_log(load_a)); - Stmt* stmt = For::make(index, 0, kTotalSize, store_b); + StmtPtr store_b = b_buf.store({index}, fast_log(load_a)); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); PaddedBuffer a_v(kTotalSize); PaddedBuffer b_v(kTotalSize); @@ -478,7 +478,7 @@ TEST(LLVM, DirectVectorization) { BufHandle c("c", {M, N}, kFloat); VarHandle m("m", kInt); VarHandle n("n", kInt); - Stmt* s = For::make( + 
StmtPtr s = For::make( m, 0, M, @@ -598,11 +598,10 @@ TEST(LLVM, VectorizerLoadStoreTest) { Placeholder c_buf(BufHandle(c->buf())); LoopNest l({c}); - Stmt* s = l.root_stmt(); - ASSERT_TRUE(LoopNest::vectorize( - dynamic_cast(dynamic_cast(s)->front()))); + StmtPtr s = l.root_stmt(); + ASSERT_TRUE(LoopNest::vectorize(to(to(s)->front()))); - ASSERT_TRUE(dynamic_cast(dynamic_cast(s)->front()) == nullptr); + ASSERT_TRUE(to(to(s)->front()) == nullptr); LLVMCodeGen cg(s, {a, c_buf}); @@ -623,10 +622,9 @@ TEST(LLVM, VectorizeBitCast) { Placeholder c_buf(BufHandle(c->buf())); LoopNest l({c}); - Stmt* s = l.root_stmt(); - ASSERT_TRUE(LoopNest::vectorize( - dynamic_cast(dynamic_cast(s)->front()))); - ASSERT_TRUE(dynamic_cast(dynamic_cast(s)->front()) == nullptr); + StmtPtr s = l.root_stmt(); + ASSERT_TRUE(LoopNest::vectorize(to(to(s)->front()))); + ASSERT_TRUE(to(to(s)->front()) == nullptr); LLVMCodeGen cg(s, {a, c_buf}); @@ -1223,7 +1221,7 @@ TEST(LLVM, SimpleMath01) { return cast(i * i + 1); }); LoopNest l({tensor}); - Stmt* stmt = l.root_stmt(); + StmtPtr stmt = l.root_stmt(); Placeholder f_buf(BufHandle(tensor->buf())); LLVMCodeGen cg(stmt, {f_buf}); @@ -1249,7 +1247,7 @@ TEST(LLVM, ComputeMul) { Placeholder c_buf(BufHandle(c->buf())); LoopNest l({c}); - Stmt* s = l.root_stmt(); + StmtPtr s = l.root_stmt(); LLVMCodeGen cg(s, {a, b, c_buf}); @@ -1275,7 +1273,7 @@ TEST(LLVM, BroadcastAdd) { Placeholder c_buf(BufHandle(c->buf())); LoopNest l({c}); l.prepareForCodegen(); - Stmt* s = l.root_stmt(); + StmtPtr s = l.root_stmt(); LLVMCodeGen cg(s, {a, b, c_buf}); @@ -1333,7 +1331,7 @@ TEST(LLVM, DynamicShapeAdd) { Placeholder b(BufHandle("b", {n}, kFloat)); Placeholder c(BufHandle("c", {n}, kFloat)); VarHandle i("i", kInt); - Stmt* s = For::make(i, 0, n, c.store({i}, a.load(i) + b.load(i))); + StmtPtr s = For::make(i, 0, n, c.store({i}, a.load(i) + b.load(i))); std::vector aData(size, 1.0f); std::vector bData(size, 2.0f); std::vector cData(size, 0.0f); @@ -1355,7 +1353,7 @@ TEST(LLVM, BindDynamicShapeAdd) { Placeholder b(BufHandle("b", {n}, kFloat)); Placeholder c(BufHandle("c", {n}, kFloat)); VarHandle i("i", kInt); - Stmt* s = For::make(i, 0, n, c.store({i}, a.load(i) + b.load(i))); + StmtPtr s = For::make(i, 0, n, c.store({i}, a.load(i) + b.load(i))); std::vector aData(size, 1.0f); std::vector bData(size, 2.0f); std::vector cData(size, 0.0f); @@ -1378,7 +1376,7 @@ TEST(LLVM, TensorDynamicShapeAdd) { return a.load(i) + b.load(i); }); LoopNest l({c}); - Stmt* s = l.root_stmt(); + StmtPtr s = l.root_stmt(); LLVMCodeGen cg(s, {a, b, c, n}); std::vector aData(size, 1.0f); std::vector bData(size, 2.0f); @@ -1404,7 +1402,7 @@ TEST(LLVM, DynamicShape2D) { }); LoopNest l({c}); l.prepareForCodegen(); - Stmt* s = l.root_stmt(); + StmtPtr s = l.root_stmt(); LLVMCodeGen cg(s, {a, b, c, m, n}); std::vector aData(M * N, 1.0f); std::vector bData(M * N, 2.0f); @@ -1419,7 +1417,7 @@ TEST(LLVM, DynamicShape2D) { TEST(LLVM, EmptyStmt) { KernelScope kernel_scope; - Stmt* s = new Block({}); + StmtPtr s = alloc(std::vector({})); LLVMCodeGen cg(s, {}); cg.call({}); @@ -1434,7 +1432,7 @@ TEST(LLVM, EliminatedStmt) { LoopNest l({c}); l.prepareForCodegen(); - Stmt* s = l.root_stmt(); + StmtPtr s = l.root_stmt(); s = IRSimplifier::simplify(s); LLVMCodeGen cg(s, {a, c}); std::vector aData(1, 1.0f); @@ -1458,7 +1456,7 @@ TEST(LLVM, SimpleReduction) { LoopNest loop({b}); loop.prepareForCodegen(); - Stmt* s = loop.root_stmt(); + StmtPtr s = loop.root_stmt(); s = IRSimplifier::simplify(s); LLVMCodeGen cg(s, {a, b}); @@ -1496,19 
+1494,19 @@ TEST(LLVM, RFactorReduction) { Tensor* b = Reduce("sum", axis, Sum(), a, reduce_axis); LoopNest loop({b}); - std::vector loops = loop.getLoopStmtsFor(b); - For* loop_m = loops.at(1); - For* loop_n = loops.at(2); + std::vector loops = loop.getLoopStmtsFor(b); + ForPtr loop_m = loops.at(1); + ForPtr loop_n = loops.at(2); loop.reorderAxis(loop_m, loop_n); loops = loop.getLoopStmtsFor(b); loop_m = loops.at(2); loop_n = loops.at(1); - auto b_body = const_cast(loop.getAllWritesToBuf(b->buf())[1]); + auto b_body = loop.getAllWritesToBuf(b->buf())[1]; ASSERT_TRUE(loop.rfactor(b_body, loop_n)); loop.prepareForCodegen(); - Stmt* s = loop.root_stmt(); + StmtPtr s = loop.root_stmt(); s = IRSimplifier::simplify(s); LLVMCodeGen cg(s, {a, b}); @@ -1542,10 +1540,10 @@ TEST(LLVM, RFactorVectorizedReduction) { Tensor* b = Reduce("sum", {{1, "K"}}, Sum(), a, {{M, "M"}, {N, "N"}}); LoopNest loopnest({b}); - std::vector loops = loopnest.getLoopStmtsFor(b); + std::vector loops = loopnest.getLoopStmtsFor(b); // Reorder n and m loops loopnest.reorderAxis(loops.at(1), loops.at(2)); - auto b_body = const_cast(loopnest.getAllWritesToBuf(b->buf()).at(1)); + auto b_body = loopnest.getAllWritesToBuf(b->buf()).at(1); auto all_loops = loopnest.getAllLoopNestsWritingToBuf(b->buf()); ASSERT_TRUE(all_loops.size() == 2 && all_loops[1].size() == 3); ASSERT_TRUE(loopnest.rfactor(b_body, all_loops[1][1])); @@ -1559,7 +1557,7 @@ TEST(LLVM, RFactorVectorizedReduction) { loopnest.prepareForCodegen(); - Stmt* s = IRSimplifier::simplify(loopnest.root_stmt()); + StmtPtr s = IRSimplifier::simplify(loopnest.root_stmt()); LLVMCodeGen cg(s, {a, b}); PaddedBuffer a_v(1, M, N, "a_v"); @@ -1593,8 +1591,8 @@ TEST(LLVM, SimpleParallel) { }); LoopNest loop_nest({f}); auto const& loops = loop_nest.getLoopStmtsFor(f); - For* m = loops[0]; - For* n = loops[1]; + ForPtr m = loops[0]; + ForPtr n = loops[1]; if (test_cfg & 0x1) { m->set_parallel(); } @@ -1602,7 +1600,7 @@ TEST(LLVM, SimpleParallel) { n->set_parallel(); } loop_nest.prepareForCodegen(); - Stmt* stmt = loop_nest.root_stmt(); + StmtPtr stmt = loop_nest.root_stmt(); LLVMCodeGen cg(stmt, {f}); PaddedBuffer f_v(M, N, "f_v"); @@ -1645,7 +1643,7 @@ TEST(LLVM, CompositeParallel) { return t3->load(m, n) + m + n; }); LoopNest loop_nest({t4}, {t1, t2, t3, t4}); - std::vector loop_list; + std::vector loop_list; { auto const& loops = loop_nest.getLoopStmtsFor(t1); loop_list.push_back(loops[0]); @@ -1671,7 +1669,7 @@ TEST(LLVM, CompositeParallel) { } } loop_nest.prepareForCodegen(); - Stmt* stmt = loop_nest.root_stmt(); + StmtPtr stmt = loop_nest.root_stmt(); LLVMCodeGen cg(stmt, {t4}); PaddedBuffer t4_v(M, N, "t4_v"); @@ -1709,36 +1707,36 @@ TEST(LLVM, VectorizedGEMM) { { auto const& loops = loop.getLoopStmtsFor(CT); - For* m = loops[0]; + ForPtr m = loops[0]; loop.splitWithMask(m, 16); } { auto const& loops = loop.getLoopStmtsFor(CT); - For* n = loops[2]; + ForPtr n = loops[2]; loop.splitWithMask(n, 16); } // mo, mi, no, ni, k -> // mo, no, mi, ni, k { auto const& loops = loop.getLoopStmtsFor(CT); - For* mi = loops[1]; - For* no = loops[2]; + ForPtr mi = loops[1]; + ForPtr no = loops[2]; loop.reorderAxis(mi, no); } // mo, no, mi, ni, k -> // mo, no, mi, k, ni { auto const& loops = loop.getLoopStmtsFor(CT); - For* ni = loops[3]; - For* k = loops[4]; + ForPtr ni = loops[3]; + ForPtr k = loops[4]; loop.reorderAxis(ni, k); } // mo, no, mi, k, ni -> // mo, no, k, mi, ni { auto const& loops = loop.getLoopStmtsFor(CT); - For* mi = loops[2]; - For* k = loops[3]; + ForPtr mi = loops[2]; + 
ForPtr k = loops[3]; loop.reorderAxis(mi, k); } { @@ -1749,7 +1747,7 @@ TEST(LLVM, VectorizedGEMM) { loop.prepareForCodegen(); - Stmt* s = loop.root_stmt(); + StmtPtr s = loop.root_stmt(); s = IRSimplifier::simplify(s); LLVMCodeGen cg(s, {AP, BP, CT}); @@ -1785,7 +1783,7 @@ TEST(LLVM, CallRaw) { LoopNest l({c}); l.prepareForCodegen(); - Stmt* s = l.root_stmt(); + StmtPtr s = l.root_stmt(); int32_t N_value = 1024; std::vector av(M * N_value); diff --git a/test/cpp/tensorexpr/test_loopnest.cpp b/test/cpp/tensorexpr/test_loopnest.cpp index e46aaee..440b169 100644 --- a/test/cpp/tensorexpr/test_loopnest.cpp +++ b/test/cpp/tensorexpr/test_loopnest.cpp @@ -23,7 +23,7 @@ namespace jit { using namespace torch::jit::tensorexpr; -void checkIR(Stmt* s, const std::string& pattern) { +void checkIR(StmtPtr s, const std::string& pattern) { std::ostringstream oss; oss << *s; torch::jit::testing::FileCheck().run(pattern, oss.str()); @@ -36,7 +36,8 @@ TEST(LoopNest, ExprSimple01) { return ExprHandle(1.0f) + cast(x) * x + cast(y) * y; }); LoopNest l({tensor}); - std::vector loops = l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); + std::vector loops = + l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); LoopNest::splitWithTail(loops[0], 2); LoopNest::splitWithTail(loops[0], 2); @@ -49,7 +50,7 @@ TEST(LoopNest, ExprLower01) { return ExprHandle(1.0f) + cast(x) * x + cast(y) * y; }); LoopNest l({tensor}); - Stmt* stmt = l.root_stmt(); + StmtPtr stmt = l.root_stmt(); std::ostringstream oss; oss << *stmt; ASSERT_GT(oss.str().size(), 20); @@ -63,11 +64,12 @@ TEST(LoopNest, ExprSimple02) { }; Tensor* tensor = Compute("f", {{26, "x"}, {5, "y"}}, func); LoopNest l({tensor}); - std::vector loops = l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); + std::vector loops = + l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); LoopNest::splitWithTail(loops[0], 4); - Stmt* stmt = l.root_stmt(); + StmtPtr stmt = l.root_stmt(); std::ostringstream oss; oss << *stmt; ASSERT_GT(oss.str().size(), 200); @@ -82,7 +84,7 @@ TEST(LoopNest, ExprSimple02) { BufHandle f("f", {26, 5}, kFloat); ExprHandle x_1 = x_outer * 4 + x_inner; ExprHandle x_outer_end = (ExprHandle(26) - 0) / 4; - For* stmt1 = For::make( + ForPtr stmt1 = For::make( x_outer, 0, x_outer_end, @@ -92,12 +94,12 @@ TEST(LoopNest, ExprSimple02) { 4, For::make(y, 0, 5, Store::make(f, {x_1, y}, func(x_1, y))))); ExprHandle x_2 = x_tail + x_outer_end * 4; - For* stmt2 = For::make( + ForPtr stmt2 = For::make( x_tail, 0, (ExprHandle(26) - 0) % 4, For::make(y, 0, 5, Store::make(f, {x_2, y}, func(x_2, y)))); - Stmt* stmt = Block::make({stmt1, stmt2}); + StmtPtr stmt = Block::make({stmt1, stmt2}); std::ostringstream oss_ref; oss_ref << *stmt; @@ -122,30 +124,30 @@ TEST(LoopNest, ExprSimple02) { } } -Block* getSimplifiedBody(const LoopNest& l) { - Stmt* stmt = l.root_stmt(); - Stmt* simplified = IRSimplifier::simplify(stmt); - return dynamic_cast(simplified); +BlockPtr getSimplifiedBody(const LoopNest& l) { + StmtPtr stmt = l.root_stmt(); + StmtPtr simplified = IRSimplifier::simplify(stmt); + return to(simplified); } -void assertForRange(For* f, int expected_start, int expected_stop) { +void assertForRange(ForPtr f, int expected_start, int expected_stop) { ASSERT_NE(f, nullptr); - const IntImm* start = dynamic_cast(f->start()); + IntImmPtr start = to(f->start()); ASSERT_NE(start, nullptr); ASSERT_EQ(start->value(), expected_start); - const IntImm* stop = dynamic_cast(f->stop()); + IntImmPtr stop = to(f->stop()); ASSERT_NE(stop, nullptr); ASSERT_EQ(stop->value(), expected_stop); 
} void assertForRanges( - Block* body, + BlockPtr body, const std::vector>& start_stops) { ASSERT_EQ(body->nstmts(), start_stops.size()); auto it = body->begin(); for (size_t i = 0; i < start_stops.size(); i++, it++) { - For* loop = dynamic_cast(*it); + ForPtr loop = to(*it); assertForRange(loop, start_stops[i].first, start_stops[i].second); } } @@ -158,14 +160,15 @@ TEST(LoopNest, ExprSliceHeadWithLoopOptions) { Tensor* tensor = Compute("f", {{10, "x"}}, func); LoopNest l({tensor}); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* head; + ForPtr head; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* tail; - std::vector loops = l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); + ForPtr tail; + std::vector loops = + l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); loops[0]->set_gpu_block_index(LoopOptions::IDX_Y); LoopNest::sliceHead(loops[0], 2, &head, &tail); - Block* body = getSimplifiedBody(l); + BlockPtr body = getSimplifiedBody(l); assertForRanges(body, {{0, 2}, {0, 8}}); ASSERT_TRUE(tail->loop_options().is_gpu_block_index()); @@ -182,20 +185,21 @@ TEST(LoopNest, ExprSliceTailWithLoopOptions) { Tensor* tensor = Compute("f", {{10, "x"}}, func); LoopNest l({tensor}); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* head; + ForPtr head; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* tail; - std::vector loops = l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); + ForPtr tail; + std::vector loops = + l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); LoopNest::sliceTail(loops[0], 4, &head, &tail); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* tail_head; + ForPtr tail_head; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* tail_tail; + ForPtr tail_tail; tail->set_gpu_block_index(LoopOptions::IDX_Y); LoopNest::sliceTail(tail, 2, &tail_head, &tail_tail); - Block* body = getSimplifiedBody(l); + BlockPtr body = getSimplifiedBody(l); assertForRanges(body, {{0, 6}, {0, 2}, {8, 10}}); ASSERT_TRUE(tail_head->loop_options().is_gpu_block_index()); @@ -215,16 +219,17 @@ TEST(LoopNest, ExprSliceHeadWhenFactorEqualsSize) { Tensor* tensor = Compute("f", {{10, "x"}}, func); LoopNest l({tensor}); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* head; + ForPtr head; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* tail; - std::vector loops = l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); + ForPtr tail; + std::vector loops = + l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); LoopNest::sliceHead(loops[0], 10, &head, &tail); ASSERT_EQ(head, loops[0]); ASSERT_EQ(tail, nullptr); - Block* body = getSimplifiedBody(l); + BlockPtr body = getSimplifiedBody(l); assertForRanges(body, {{0, 10}}); } @@ -236,16 +241,17 @@ TEST(LoopNest, ExprSliceHeadWhenFactorLargerThanSize) { Tensor* tensor = Compute("f", {{10, "x"}}, func); LoopNest l({tensor}); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* head; + ForPtr head; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* tail; - std::vector loops = l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); + ForPtr tail; + std::vector loops = + l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); LoopNest::sliceHead(loops[0], 100, &head, &tail); ASSERT_EQ(head, loops[0]); ASSERT_EQ(tail, nullptr); - Block* body = getSimplifiedBody(l); + BlockPtr body = getSimplifiedBody(l); assertForRanges(body, {{0, 10}}); } @@ -257,10 +263,11 @@ TEST(LoopNest, ExprSliceHead) { Tensor* tensor = Compute("f", {{10, "x"}}, func); LoopNest l({tensor}); // 
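// Illustrative sketch, not part of the patch: assertForRange/assertForRanges
// above now call to<Node>() where the old code used dynamic_cast, and other
// hunks use static_to<Node>() where the old code used static_cast. Assuming
// the *Ptr aliases are still plain pointers at this point, one plausible shape
// for these helpers is:

template <class To, class From>
To* to(From* x) {
  // Checked downcast: yields nullptr when x is not actually a To.
  return dynamic_cast<To*>(x);
}

template <class To, class From>
To* static_to(From* x) {
  // Unchecked downcast: the caller guarantees the dynamic type.
  return static_cast<To*>(x);
}

// e.g. ForPtr loop = to<For>(*it);               // was dynamic_cast on For*
//      IntImmPtr start = to<IntImm>(f->start());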
NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* head; + ForPtr head; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* tail; - std::vector loops = l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); + ForPtr tail; + std::vector loops = + l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); LoopNest::sliceHead(loops[0], 4, &head, &tail); ASSERT_NE(head, nullptr); @@ -268,7 +275,7 @@ TEST(LoopNest, ExprSliceHead) { ASSERT_NE(tail, nullptr); ASSERT_NE(tail, loops[0]); - Block* body = getSimplifiedBody(l); + BlockPtr body = getSimplifiedBody(l); assertForRanges(body, {{0, 4}, {4, 10}}); } @@ -279,12 +286,13 @@ TEST(LoopNest, ExprSliceHeadWithNonZeroStart) { }; Tensor* tensor = Compute("f", {{10, "x"}}, func); LoopNest l({tensor}); - std::vector loops = l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); + std::vector loops = + l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* head; + ForPtr head; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* tail; + ForPtr tail; LoopNest::sliceTail(loops[0], 4, &head, &tail); // head: [0, 6) // tail: [6, 10) @@ -293,7 +301,7 @@ TEST(LoopNest, ExprSliceHeadWithNonZeroStart) { // tail_head: [6, 8) // tail_tail: [8, 10) - Block* body = getSimplifiedBody(l); + BlockPtr body = getSimplifiedBody(l); assertForRanges(body, {{0, 6}, {6, 8}, {8, 10}}); } @@ -307,16 +315,17 @@ TEST(LoopNest, ExprSliceTailWhenFactorEqualsSize) { Tensor* tensor = Compute("f", {{10, "x"}}, func); LoopNest l({tensor}); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* head; + ForPtr head; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* tail; - std::vector loops = l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); + ForPtr tail; + std::vector loops = + l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); LoopNest::sliceTail(loops[0], 10, &head, &tail); ASSERT_EQ(head, nullptr); ASSERT_EQ(tail, loops[0]); - Block* body = getSimplifiedBody(l); + BlockPtr body = getSimplifiedBody(l); assertForRanges(body, {{0, 10}}); } @@ -330,16 +339,17 @@ TEST(LoopNest, ExprSliceTailWhenFactorLargerThanSize) { Tensor* tensor = Compute("f", {{10, "x"}}, func); LoopNest l({tensor}); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* head; + ForPtr head; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* tail; - std::vector loops = l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); + ForPtr tail; + std::vector loops = + l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); LoopNest::sliceTail(loops[0], 100, &head, &tail); ASSERT_EQ(head, nullptr); ASSERT_EQ(tail, loops[0]); - Block* body = getSimplifiedBody(l); + BlockPtr body = getSimplifiedBody(l); assertForRanges(body, {{0, 10}}); } @@ -351,10 +361,11 @@ TEST(LoopNest, ExprSliceTail) { Tensor* tensor = Compute("f", {{10, "x"}}, func); LoopNest l({tensor}); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* head; + ForPtr head; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* tail; - std::vector loops = l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); + ForPtr tail; + std::vector loops = + l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); LoopNest::sliceTail(loops[0], 4, &head, &tail); ASSERT_NE(head, nullptr); @@ -362,7 +373,7 @@ TEST(LoopNest, ExprSliceTail) { ASSERT_NE(tail, nullptr); ASSERT_NE(tail, loops[0]); - Block* body = getSimplifiedBody(l); + BlockPtr body = getSimplifiedBody(l); assertForRanges(body, {{0, 6}, {6, 10}}); } @@ -378,10 +389,11 @@ TEST(LoopNest, ExprSplitAndSlice) { LoopNest l({tensor}); // 
NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* inner; + ForPtr inner; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* tail; - std::vector loops = l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); + ForPtr tail; + std::vector loops = + l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); // outer: [0, 4) // inner: [0, 21) // tail: [84, 100) @@ -408,15 +420,15 @@ TEST(LoopNest, ExprSplitAndSlice) { // for (int x_tail = 0; x_tail < 16; x_tail++) { // f[x_tail + 84] = 1.f + float(x_tail + 84); // } - Block* body = getSimplifiedBody(l); + BlockPtr body = getSimplifiedBody(l); assertForRanges(body, {{0, 2}, {2, 4}, {0, 16}}); auto biter = body->begin(); - For* loop = dynamic_cast(*biter++); + ForPtr loop = to(*biter++); assertForRanges(loop->body(), {{0, 19}, {19, 21}}); - loop = dynamic_cast(*biter); + loop = to(*biter); assertForRanges(loop->body(), {{0, 19}, {19, 21}}); } @@ -429,12 +441,13 @@ TEST(LoopNest, ExprSliceAndNormalize) { }; Tensor* tensor = Compute("f", {{10, "x"}}, func); LoopNest l({tensor}); - std::vector loops = l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); + std::vector loops = + l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* head; + ForPtr head; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* tail; + ForPtr tail; LoopNest::sliceHead(loops[0], 2, &head, &tail); // head: [0, 2) // tail: [2, 10) @@ -442,7 +455,7 @@ TEST(LoopNest, ExprSliceAndNormalize) { LoopNest::normalize(tail); // normalized_tail: [0, 8) - Block* body = getSimplifiedBody(l); + BlockPtr body = getSimplifiedBody(l); assertForRanges(body, {{0, 2}, {0, 8}}); } @@ -461,22 +474,22 @@ TEST(LoopNest, ExprSliceWithVariableDimension) { Tensor* tensor = Compute("f", {{dim, "x"}}, [](const ExprHandle& x) { return x; }); LoopNest l({tensor}); - std::vector loops = + std::vector loops = l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* head; + ForPtr head; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* tail; + ForPtr tail; LoopNest::sliceHead(loops[0], 2, &head, &tail); LoopNest::sliceTail(tail, 2); - Block* body = getSimplifiedBody(l); + BlockPtr body = getSimplifiedBody(l); ASSERT_EQ(expected_for_ranges.size(), 3); auto it = body->begin(); for (auto& start_stop : expected_for_ranges) { - For* loop = dynamic_cast(*it++); + ForPtr loop = to(*it++); int start = evalExpr(ExprHandle(loop->start()), dim, dimension); int stop = evalExpr(ExprHandle(loop->stop()), dim, dimension); ASSERT_EQ(start, start_stop.first); @@ -499,26 +512,27 @@ TEST(LoopNest, ExprSplitWithTail) { }; Tensor* tensor = Compute("f", {{199, "x"}}, func); LoopNest l({tensor}); - std::vector loops = l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); + std::vector loops = + l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) LoopNest::splitWithTail(loops[0], 17); // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) LoopNest::splitWithTail(loops[0], 7); - Stmt* stmt = l.root_stmt(); - Stmt* simplified = IRSimplifier::simplify(stmt); - Block* body = dynamic_cast(simplified); + StmtPtr stmt = l.root_stmt(); + StmtPtr simplified = IRSimplifier::simplify(stmt); + BlockPtr body = to(simplified); ASSERT_EQ(body->nstmts(), 3); auto biter = body->begin(); // Verify that the split loops are ordered correctly. 
- For* loop = dynamic_cast(*biter++); + ForPtr loop = to(*biter++); assertForRange(loop, 0, 7); - loop = dynamic_cast(*biter++); + loop = to(*biter++); assertForRange(loop, 0, 4); - loop = dynamic_cast(*biter); + loop = to(*biter); assertForRange(loop, 0, 12); } @@ -529,10 +543,11 @@ TEST(LoopNest, ExprSplitWithTailNone) { }; Tensor* tensor = Compute("f", {{24, "x"}, {5, "y"}}, func); LoopNest l({tensor}); - std::vector loops = l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); + std::vector loops = + l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); LoopNest::splitWithTail(loops[0], 4); - Stmt* stmt = l.root_stmt(); + StmtPtr stmt = l.root_stmt(); std::ostringstream oss; oss << *stmt; ASSERT_GT(oss.str().size(), 200); @@ -548,7 +563,7 @@ TEST(LoopNest, ExprSplitWithTailNone) { BufHandle f("f", {24, 5}, kFloat); ExprHandle x_1 = x_outer * 4 + x_inner; ExprHandle x_outer_end = (ExprHandle(24) - 0) / 4; - Stmt* stmt = new Block({For::make( + StmtPtr stmt = alloc(std::vector({For::make( x_outer, 0, x_outer_end, @@ -556,7 +571,7 @@ TEST(LoopNest, ExprSplitWithTailNone) { x_inner, 0, 4, - For::make(y, 0, 5, Store::make(f, {x_1, y}, func(x_1, y)))))}); + For::make(y, 0, 5, Store::make(f, {x_1, y}, func(x_1, y)))))})); std::ostringstream oss_ref; oss_ref << *stmt; @@ -592,10 +607,11 @@ TEST(LoopNest, ExprSplitWithMask01) { }); LoopNest l({tensor}); - std::vector loops = l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); + std::vector loops = + l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); LoopNest::splitWithMask(loops[1], 4); - Stmt* stmt = l.root_stmt(); + StmtPtr stmt = l.root_stmt(); PaddedBuffer a_v(M, N, "a"); PaddedBuffer b_v(M, N, "b"); @@ -626,11 +642,12 @@ TEST(LoopNest, ExprSplitWithMaskRepeatedNoMask) { }); LoopNest l({tensor}); - std::vector loops = l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); + std::vector loops = + l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); LoopNest::splitWithMask(loops[0], 4); LoopNest::splitWithMask(loops[0], 4); - Stmt* stmt1 = IRSimplifier::simplify(l.root_stmt()); + StmtPtr stmt1 = IRSimplifier::simplify(l.root_stmt()); // Two splits mean 3 loops, but should need no masks in this case. 
checkIR(stmt1, R"IR( @@ -658,11 +675,20 @@ TEST(LoopNest, getLoopAt) { // } // } // } - Buf* A = new Buf("A", {new IntImm(100), new IntImm(100)}, kInt); - Buf* B = - new Buf("B", {new IntImm(100), new IntImm(100), new IntImm(200)}, kInt); - Buf* C = - new Buf("C", {new IntImm(100), new IntImm(100), new IntImm(300)}, kInt); + BufPtr A = alloc( + "A", + std::vector({alloc(100), alloc(100)}), + kInt); + BufPtr B = alloc( + "B", + std::vector( + {alloc(100), alloc(100), alloc(200)}), + kInt); + BufPtr C = alloc( + "C", + std::vector( + {alloc(100), alloc(100), alloc(300)}), + kInt); BufHandle a_buf(A); BufHandle b_buf(B); BufHandle c_buf(C); @@ -705,12 +731,13 @@ TEST(LoopNest, TileSimple) { }); LoopNest l({tensor}); - std::vector loops = l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); + std::vector loops = + l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) l.tile(loops[0], loops[1], 4, 8); // IR check - Stmt* stmt = IRSimplifier::simplify(l.root_stmt()); + StmtPtr stmt = IRSimplifier::simplify(l.root_stmt()); checkIR(stmt, R"IR( # CHECK: for (int m_outer # CHECK: for (int n_outer @@ -751,12 +778,13 @@ TEST(LoopNest, TileWithTails) { }); LoopNest l({tensor}); - std::vector loops = l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); + std::vector loops = + l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) l.tile(loops[0], loops[1], 5, 9); // IR check - Stmt* stmt = IRSimplifier::simplify(l.root_stmt()); + StmtPtr stmt = IRSimplifier::simplify(l.root_stmt()); checkIR(stmt, R"IR( # CHECK: for (int m_outer # CHECK: for (int n_outer @@ -803,13 +831,13 @@ TEST(LoopNest, TileInMiddle) { }); LoopNest nest({tensor}); - std::vector loops = + std::vector loops = nest.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) nest.tile(loops[1], loops[2], 3, 3); // IR check - Stmt* stmt = IRSimplifier::simplify(nest.root_stmt()); + StmtPtr stmt = IRSimplifier::simplify(nest.root_stmt()); checkIR(stmt, R"IR( # CHECK: for (int m # CHECK: for (int n_outer @@ -856,7 +884,7 @@ TEST(LoopNest, SplitWithTailWithLoopOptions) { return a_buf.load(m) + b_buf.load(m) + 1.0f; }); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For *inner, *tail; + ForPtr inner, tail; LoopNest l({tensor}); auto loops = NodeFinder::find(l.root_stmt()); @@ -865,7 +893,7 @@ TEST(LoopNest, SplitWithTailWithLoopOptions) { LoopNest::splitWithTail(loops[0], 4, &inner, &tail); ASSERT_NE(inner, nullptr); ASSERT_NE(tail, nullptr); - For* outer = loops[0]; + ForPtr outer = loops[0]; // Outer loop carries loop axis bindings. ASSERT_TRUE(outer->loop_options().is_gpu_block_index()); @@ -887,13 +915,13 @@ TEST(LoopNest, SplitWithMaskWithLoopOptions) { return a_buf.load(m) + b_buf.load(m) + 1.0f; }); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* inner; + ForPtr inner; LoopNest l({tensor}); auto loops = NodeFinder::find(l.root_stmt()); loops[0]->set_gpu_block_index(LoopOptions::IDX_Y); LoopNest::splitWithMask(loops[0], 4, &inner); - For* outer = loops[0]; + ForPtr outer = loops[0]; // Outer loop carries loop axis bindings. 
ASSERT_TRUE(outer->loop_options().is_gpu_block_index()); @@ -917,7 +945,7 @@ TEST(LoopNest, ScheduleBroadcastAddBuffer) { return a_buf.load(m, n) + b_buf.load(n, k); }); LoopNest l({c}); - Stmt* stmt = l.root_stmt(); + StmtPtr stmt = l.root_stmt(); PaddedBuffer a_v(M, N, "a_v"); for (int m = 0; m < M; m++) { @@ -974,7 +1002,7 @@ TEST(LoopNest, ScheduleFunctionCall01) { LoopNest l({d}, {c, d}); l.prepareForCodegen(); - Stmt* stmt = l.root_stmt(); + StmtPtr stmt = l.root_stmt(); std::ostringstream oss; oss << *stmt; ASSERT_GT(oss.str().size(), 100); @@ -1039,8 +1067,8 @@ TEST(LoopNest, ScheduleInlineSimple) { l1.prepareForCodegen(); l2.prepareForCodegen(); - Stmt* stmt1 = IRSimplifier::simplify(l1.root_stmt()); - Stmt* stmt2 = IRSimplifier::simplify(l2.root_stmt()); + StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt()); + StmtPtr stmt2 = IRSimplifier::simplify(l2.root_stmt()); SimpleIREvaluator eval1(stmt1, {a_buf, b_buf, c_buf, d_buf, y}); SimpleIREvaluator eval2(stmt2, {a_buf, b_buf, c_buf, d_buf, y}); @@ -1130,7 +1158,7 @@ void InlineFunc01Helper(const std::vector& inline_order) { } } l.prepareForCodegen(); - Stmt* stmt = l.root_stmt(); + StmtPtr stmt = l.root_stmt(); std::ostringstream oss; oss << *stmt; @@ -1189,7 +1217,7 @@ void InlineFunc01Helper(const std::vector& inline_order) { }); LoopNest l2({z2}); l2.prepareForCodegen(); - Stmt* stmt2 = l2.root_stmt(); + StmtPtr stmt2 = l2.root_stmt(); std::ostringstream oss2; oss2 << *stmt2; @@ -1233,7 +1261,7 @@ TEST(LoopNest, ScheduleInlineRandom) { // would normally compare results but Rand isn't implemented in the // SimpleIREvaluator, even if we could seed it. - Stmt* stmt1 = IRSimplifier::simplify(l1.root_stmt()); + StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt()); // Check the IR we produced checkIR(stmt1, R"IR( @@ -1270,7 +1298,7 @@ TEST(LoopNest, ScheduleInlineRandomUnrelated) { // would normally compare results but Rand isn't implemented in the // SimpleIREvaluator, even if we could seed it. - Stmt* stmt1 = IRSimplifier::simplify(l1.root_stmt()); + StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt()); // Check the IR we produced checkIR(stmt1, R"IR( @@ -1303,7 +1331,7 @@ TEST(LoopNest, ScheduleInlineRandomLowerDimensions) { // would normally compare results but Rand isn't implemented in the // SimpleIREvaluator, even if we could seed it. 
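// Illustrative sketch, not part of the patch: the loop-nest tests above now
// receive their loops as std::vector<ForPtr> (from getAllLoopNestsWritingToBuf
// or getLoopStmtsFor) and hand those ForPtr values to the same static
// transform entry points. The common flow, condensed into a hypothetical
// helper (the helper name and split factor are illustrative; assumes the
// tensorexpr namespace as in these tests):

static StmtPtr splitAndSimplify(Tensor* tensor) {
  LoopNest l({tensor});
  std::vector<ForPtr> loops =
      l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0);
  LoopNest::splitWithMask(loops[0], 4);  // takes a ForPtr, as in the hunks above
  l.prepareForCodegen();
  return IRSimplifier::simplify(l.root_stmt());  // simplify returns a StmtPtr
}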
- Stmt* stmt1 = IRSimplifier::simplify(l1.root_stmt()); + StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt()); // Check the IR we produced checkIR(stmt1, R"IR( @@ -1357,8 +1385,8 @@ TEST(LoopNest, ScheduleInlineIntrinsics) { l1.prepareForCodegen(); l2.prepareForCodegen(); - Stmt* stmt1 = IRSimplifier::simplify(l1.root_stmt()); - Stmt* stmt2 = IRSimplifier::simplify(l2.root_stmt()); + StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt()); + StmtPtr stmt2 = IRSimplifier::simplify(l2.root_stmt()); SimpleIREvaluator eval1(stmt1, {a_buf, b_buf, y}); SimpleIREvaluator eval2(stmt2, {a_buf, b_buf, y}); @@ -1398,7 +1426,7 @@ TEST(LoopNest, ScheduleInlineRandWithIntrinsics) { LoopNest l1({y}, {x, y}); l1.computeInline(x->buf()); - Stmt* stmt1 = IRSimplifier::simplify(l1.root_stmt()); + StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt()); // Check the IR we produced checkIR(stmt1, R"IR( @@ -1419,7 +1447,7 @@ TEST(LoopNest, ScheduleSplitAThenInline) { }); LoopNest l({b}, {a, b}); - std::vector loops = l.getAllLoopNestsWritingToBuf(a->buf()).at(0); + std::vector loops = l.getAllLoopNestsWritingToBuf(a->buf()).at(0); LoopNest::splitWithMask(loops[0], 4); ASSERT_THROWS_WITH(l.computeInline(a->buf()), "compound indices"); } @@ -1434,11 +1462,11 @@ TEST(LoopNest, ScheduleSplitBThenInline) { }); LoopNest l({b}, {a, b}); - std::vector loops = l.getAllLoopNestsWritingToBuf(b->buf()).at(0); + std::vector loops = l.getAllLoopNestsWritingToBuf(b->buf()).at(0); LoopNest::splitWithMask(loops[0], 3); l.computeInline(a->buf()); l.prepareForCodegen(); - Stmt* s = IRSimplifier::simplify(l.root_stmt()); + StmtPtr s = IRSimplifier::simplify(l.root_stmt()); std::vector output(6, 0); SimpleIREvaluator eval(s, {b}); @@ -1458,10 +1486,10 @@ TEST(LoopNest, ScheduleSplitTwiceThenInline) { return a->load(j + ExprHandle(8)); }); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* i_inner; + ForPtr i_inner; LoopNest l({b}, {a, b}); - std::vector loops = l.getAllLoopNestsWritingToBuf(a->buf()).at(0); + std::vector loops = l.getAllLoopNestsWritingToBuf(a->buf()).at(0); LoopNest::splitWithMask(loops[0], 4, &i_inner); LoopNest::splitWithMask(i_inner, 2); ASSERT_THROWS_WITH(l.computeInline(a->buf()), "compound indices"); @@ -1479,10 +1507,10 @@ TEST(LoopNest, ScheduleInlineThenSplit) { LoopNest l({b}, {a, b}); l.computeInline(a->buf()); - std::vector loops = NodeFinder::find(l.root_stmt()); + std::vector loops = NodeFinder::find(l.root_stmt()); LoopNest::splitWithMask(loops.back(), 3); l.prepareForCodegen(); - Stmt* s = IRSimplifier::simplify(l.root_stmt()); + StmtPtr s = IRSimplifier::simplify(l.root_stmt()); std::vector output(6, 0); SimpleIREvaluator eval(s, {b}); eval(output); @@ -1509,7 +1537,7 @@ TEST(LoopNest, ScheduleSplitInlineThenSplit) { loops = NodeFinder::find(l.root_stmt()); LoopNest::splitWithMask(loops.front(), 2); l.prepareForCodegen(); - Stmt* s = IRSimplifier::simplify(l.root_stmt()); + StmtPtr s = IRSimplifier::simplify(l.root_stmt()); std::vector output(16, 0); SimpleIREvaluator eval(s, {b}); eval(output); @@ -1530,7 +1558,7 @@ TEST(LoopNest, ScheduleSplitInlineSimplify) { }); LoopNest l({b}, {a, b}); - std::vector loops = l.getAllLoopNestsWritingToBuf(a->buf()).at(0); + std::vector loops = l.getAllLoopNestsWritingToBuf(a->buf()).at(0); LoopNest::splitWithMask(loops[0], 4); ASSERT_THROWS_WITH(l.computeInline(a->buf()), "compound indices"); } @@ -1549,11 +1577,11 @@ TEST(LoopNest, ScheduleInlineThreeMixedOnce) { }); LoopNest l({c}, {a, b, c}); - std::vector loops = 
l.getAllLoopNestsWritingToBuf(a->buf()).at(0); + std::vector loops = l.getAllLoopNestsWritingToBuf(a->buf()).at(0); l.computeInline(a->buf()); l.prepareForCodegen(); - Stmt* s = IRSimplifier::simplify(l.root_stmt()); + StmtPtr s = IRSimplifier::simplify(l.root_stmt()); std::vector output(4 * 3, 0); SimpleIREvaluator eval(s, {c}); eval(output); @@ -1579,12 +1607,12 @@ TEST(LoopNest, ScheduleInlineThreeMixedTwice) { }); LoopNest l({c}, {a, b, c}); - std::vector loops = l.getAllLoopNestsWritingToBuf(a->buf()).at(0); + std::vector loops = l.getAllLoopNestsWritingToBuf(a->buf()).at(0); l.computeInline(a->buf()); l.computeInline(b->buf()); l.prepareForCodegen(); - Stmt* s = IRSimplifier::simplify(l.root_stmt()); + StmtPtr s = IRSimplifier::simplify(l.root_stmt()); std::vector output(4 * 3, 0); SimpleIREvaluator eval(s, {c}); eval(output); @@ -1610,11 +1638,11 @@ TEST(LoopNest, ScheduleInlineThreeMixedInner) { }); LoopNest l({c}, {a, b, c}); - std::vector loops = l.getAllLoopNestsWritingToBuf(a->buf()).at(0); + std::vector loops = l.getAllLoopNestsWritingToBuf(a->buf()).at(0); l.computeInline(b->buf()); l.prepareForCodegen(); - Stmt* s = IRSimplifier::simplify(l.root_stmt()); + StmtPtr s = IRSimplifier::simplify(l.root_stmt()); std::vector output(4 * 3, 0); SimpleIREvaluator eval(s, {c}); eval(output); @@ -1640,7 +1668,7 @@ TEST(LoopNest, ScheduleInlineThreeMixedSplit) { }); LoopNest l({c}, {a, b, c}); - std::vector loops = l.getAllLoopNestsWritingToBuf(a->buf()).at(0); + std::vector loops = l.getAllLoopNestsWritingToBuf(a->buf()).at(0); LoopNest::splitWithMask(loops[0], 4); loops = l.getAllLoopNestsWritingToBuf(b->buf()).at(0); LoopNest::splitWithMask(loops[0], 3); @@ -1675,7 +1703,7 @@ TEST(LoopNest, ScheduleInlineOutputTensors) { // would normally compare results but Rand isn't implemented in the // SimpleIREvaluator, even if we could seed it. 
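// Illustrative sketch, not part of the patch: the inlining tests above keep
// the same shape after the alias change -- inline a producer buffer, lower,
// simplify to a StmtPtr, and evaluate it with SimpleIREvaluator. Condensed
// into a hypothetical helper (the helper name and output size are
// illustrative; every call appears in the hunks above):

static std::vector<int> inlineAndRun(Tensor* a, Tensor* c, int size) {
  LoopNest l({c}, {a, c});
  l.computeInline(a->buf());  // Buf handles are unchanged by this patch
  l.prepareForCodegen();
  StmtPtr s = IRSimplifier::simplify(l.root_stmt());
  std::vector<int> output(size, 0);
  SimpleIREvaluator eval(s, {c});
  eval(output);
  return output;
}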
- Stmt* stmt1 = IRSimplifier::simplify(l1.root_stmt()); + StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt()); // Check the IR we produced checkIR(stmt1, R"IR( @@ -1709,7 +1737,7 @@ TEST(LoopNest, ScheduleFuserStyle) { LoopNest l({b, c}); l.prepareForCodegen(); - Stmt* s = l.root_stmt(); + StmtPtr s = l.root_stmt(); std::vector a_data(kTotalSize, 7.0f); std::vector b_data(kTotalSize, 0.0f); @@ -1747,7 +1775,7 @@ TEST(LoopNest, ScheduleFuserThreeArg) { l.computeInline(l.getLoopBodyFor(e)); l.computeInline(l.getLoopBodyFor(f)); l.prepareForCodegen(); - Stmt* s = l.root_stmt(); + StmtPtr s = l.root_stmt(); std::vector a_data(kTotalSize, 1.0f); std::vector b_data(kTotalSize, 2.0f); @@ -1773,7 +1801,7 @@ TEST(LoopNest, ScheduleDynamicShape2D) { return a.load(i, j) + b.load(i, j); }); LoopNest l({c}); - Stmt* s = l.root_stmt(); + StmtPtr s = l.root_stmt(); SimpleIREvaluator cg(s, {a, b, c, m, n}); std::vector aData(M * N, 1.0f); std::vector bData(M * N, 2.0f); @@ -1808,10 +1836,10 @@ TEST(LoopNest, LoopNestComputeAt_1) { Tensor* B = Compute( "B", {{N, "i_b"}}, [&](const VarHandle& i_b) { return A->load(i_b); }); LoopNest l({B}, {A, B}); - std::vector loops = l.getAllLoopNestsWritingToBuf(B->buf()).at(0); + std::vector loops = l.getAllLoopNestsWritingToBuf(B->buf()).at(0); LoopNest::computeAt(l.getLoopBodyFor(A), loops[0]); l.prepareForCodegen(); - Stmt* s = l.root_stmt(); + StmtPtr s = l.root_stmt(); checkIR(s, R"IR( # CHECK: Allocate(temp); // dtype=int, dims=[1] @@ -1875,10 +1903,10 @@ TEST(LoopNest, LoopNestComputeAt_2) { { // First let's try to compute P at axis cy (the outer loop) LoopNest l(orig_loopnest); - std::vector loops = l.getAllLoopNestsWritingToBuf(c->buf()).at(0); + std::vector loops = l.getAllLoopNestsWritingToBuf(c->buf()).at(0); LoopNest::computeAt(l.getLoopBodyFor(p), loops[0]); l.prepareForCodegen(); - Stmt* s = l.root_stmt(); + StmtPtr s = l.root_stmt(); // Check the IR we produced checkIR(s, R"IR( @@ -1901,10 +1929,10 @@ TEST(LoopNest, LoopNestComputeAt_2) { { // Now let's try to compute P at axis cx (the inner loop) LoopNest l(orig_loopnest); - std::vector loops = l.getAllLoopNestsWritingToBuf(c->buf()).at(0); + std::vector loops = l.getAllLoopNestsWritingToBuf(c->buf()).at(0); LoopNest::computeAt(l.getLoopBodyFor(p), loops[1]); l.prepareForCodegen(); - Stmt* s = l.root_stmt(); + StmtPtr s = l.root_stmt(); // Check the IR we produced checkIR(s, R"IR( @@ -1974,10 +2002,10 @@ TEST(LoopNest, LoopNestComputeAt_3) { { // First let's try to compute A at axis dy (the outer loop) LoopNest l(orig_loopnest); - std::vector loops = l.getAllLoopNestsWritingToBuf(D->buf()).at(0); + std::vector loops = l.getAllLoopNestsWritingToBuf(D->buf()).at(0); LoopNest::computeAt(l.getLoopBodyFor(A), loops[0]); l.prepareForCodegen(); - Stmt* s = l.root_stmt(); + StmtPtr s = l.root_stmt(); // Check the IR we produced checkIR(s, R"IR( @@ -2005,10 +2033,10 @@ TEST(LoopNest, LoopNestComputeAt_3) { { // Now let's try to compute A at axis dx (the inner loop) LoopNest l(orig_loopnest); - std::vector loops = l.getAllLoopNestsWritingToBuf(D->buf()).at(0); + std::vector loops = l.getAllLoopNestsWritingToBuf(D->buf()).at(0); LoopNest::computeAt(l.getLoopBodyFor(A), loops[1]); l.prepareForCodegen(); - Stmt* s = l.root_stmt(); + StmtPtr s = l.root_stmt(); // Check the IR we produced checkIR(s, R"IR( @@ -2109,7 +2137,7 @@ TEST(LoopNest, Reduce2dComputeAt) { # CHECK: } # CHECK: Free(temp); )IR"); - Stmt* s = l.root_stmt(); + StmtPtr s = l.root_stmt(); // Now check that the loop still produces the correct 
result. std::vector c_data(kW * kH, 0); @@ -2120,7 +2148,7 @@ TEST(LoopNest, Reduce2dComputeAt) { { // Now let's try to compute P at axis cx (the inner loop) LoopNest l(orig_loopnest); - std::vector loops = l.getAllLoopNestsWritingToBuf(c->buf()).at(0); + std::vector loops = l.getAllLoopNestsWritingToBuf(c->buf()).at(0); LoopNest::computeAt(l.getLoopBodyFor(p), loops[1]); l.simplify(); l.eliminateDeadStores(); @@ -2144,7 +2172,7 @@ TEST(LoopNest, Reduce2dComputeAt) { # CHECK: } # CHECK: Free(temp); )IR"); - Stmt* s = l.root_stmt(); + StmtPtr s = l.root_stmt(); // Now check that the loop still produces the correct result. std::vector c_data(kW * kH, 0); @@ -2194,7 +2222,7 @@ TEST(LoopNest, DISABLED_Conv1d_NH) { # CHECK: } # CHECK: } )IR"); - std::vector loops = l.getAllLoopNestsWritingToBuf(B->buf()).at(0); + std::vector loops = l.getAllLoopNestsWritingToBuf(B->buf()).at(0); LoopNest::computeAt(l.getLoopBodyFor(A), loops[0]); // FIXME: The current IR is totally broken. The body of the inlined loop is: @@ -2222,7 +2250,7 @@ TEST(LoopNest, DISABLED_Conv1d_NH) { l.simplify(); l.prepareForCodegen(); - Stmt* s = l.root_stmt(); + StmtPtr s = l.root_stmt(); SimpleIREvaluator cg(s, {IP, B}); // auto At = at::ones({N, H}, at::kFloat); @@ -2238,14 +2266,14 @@ class LoopOrderHelper : public IRVisitor { std::stringstream ordering; public: - std::string getOrder(Stmt* s) { + std::string getOrder(StmtPtr s) { ordering.str(""); s->accept(this); return ordering.str(); } // NOLINTNEXTLINE(cppcoreguidelines-explicit--functions,modernize-use-override) - void visit(For* v) { + void visit(ForPtr v) { ordering << v->var()->name_hint() << ","; IRVisitor::visit(v); } @@ -2258,7 +2286,7 @@ TEST(LoopNest, LoopNestReorderAxis1) { return ExprHandle(1.0f) + cast(x) * x + cast(y) * y; }); LoopNest l({tensor}); - Stmt* stmt1 = Stmt::clone(l.root_stmt()); + StmtPtr stmt1 = Stmt::clone(l.root_stmt()); std::vector stmt1_output(6, 0); SimpleIREvaluator cg(stmt1, {tensor}); @@ -2266,7 +2294,7 @@ TEST(LoopNest, LoopNestReorderAxis1) { auto loops = l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); LoopNest::reorderAxis(loops[0], loops[1]); - Stmt* stmt2 = Stmt::clone(l.root_stmt()); + StmtPtr stmt2 = Stmt::clone(l.root_stmt()); ASSERT_NE(stmt1, stmt2); LoopOrderHelper loopOrderHelper; @@ -2287,7 +2315,7 @@ TEST(LoopNest, LoopNestReorderAxis1) { // Reorder them back. 
loops = l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); LoopNest::reorderAxis(loops[0], loops[1]); - Stmt* stmt3 = l.root_stmt(); + StmtPtr stmt3 = l.root_stmt(); std::string order3 = loopOrderHelper.getOrder(stmt3); ASSERT_EQ(order3, order1); @@ -2312,7 +2340,7 @@ TEST(LoopNest, LoopNestReorderPartialAxes) { LoopNest l({tensor}); LoopOrderHelper loopOrderHelper; - Stmt* stmt1 = Stmt::clone(l.root_stmt()); + StmtPtr stmt1 = Stmt::clone(l.root_stmt()); ASSERT_EQ(loopOrderHelper.getOrder(stmt1), "x,y,z,"); std::vector stmt1_output(24, 0); @@ -2323,7 +2351,7 @@ TEST(LoopNest, LoopNestReorderPartialAxes) { LoopNest::reorderAxis(loops[0], loops[1]); ASSERT_EQ(loopOrderHelper.getOrder(l.root_stmt()), "y,x,z,"); - Stmt* stmt2 = Stmt::clone(l.root_stmt()); + StmtPtr stmt2 = Stmt::clone(l.root_stmt()); std::vector stmt2_output(24, 0); SimpleIREvaluator cg2(stmt2, {tensor}); @@ -2337,7 +2365,7 @@ TEST(LoopNest, LoopNestReorderPartialAxes) { LoopNest::reorderAxis(loops[1], loops[2]); ASSERT_EQ(loopOrderHelper.getOrder(l.root_stmt()), "y,z,x,"); - Stmt* stmt3 = Stmt::clone(l.root_stmt()); + StmtPtr stmt3 = Stmt::clone(l.root_stmt()); std::vector stmt3_output(24, 0); SimpleIREvaluator cg3(stmt3, {tensor}); @@ -2363,7 +2391,7 @@ TEST(LoopNest, LoopNestReorderInternalAxis) { LoopNest l({tensor}); LoopOrderHelper loopOrderHelper; - Stmt* stmt1 = Stmt::clone(l.root_stmt()); + StmtPtr stmt1 = Stmt::clone(l.root_stmt()); ASSERT_EQ(loopOrderHelper.getOrder(stmt1), "w,x,y,z,"); std::vector stmt1_output(24, 0); @@ -2374,7 +2402,7 @@ TEST(LoopNest, LoopNestReorderInternalAxis) { LoopNest::reorderAxis(loops[2], loops[1]); ASSERT_EQ(loopOrderHelper.getOrder(l.root_stmt()), "w,y,x,z,"); - Stmt* stmt2 = l.root_stmt(); + StmtPtr stmt2 = l.root_stmt(); std::vector stmt2_output(24, 0); SimpleIREvaluator cg2(stmt2, {tensor}); @@ -2400,7 +2428,7 @@ TEST(LoopNest, LoopNestReorderEnclosingAxis) { LoopNest l({tensor}); LoopOrderHelper loopOrderHelper; - Stmt* stmt1 = Stmt::clone(l.root_stmt()); + StmtPtr stmt1 = Stmt::clone(l.root_stmt()); std::vector stmt1_output(24, 0); SimpleIREvaluator cg(stmt1, {tensor}); @@ -2410,7 +2438,7 @@ TEST(LoopNest, LoopNestReorderEnclosingAxis) { LoopNest::reorderAxis(loops[0], loops[3]); ASSERT_EQ(loopOrderHelper.getOrder(l.root_stmt()), "z,x,y,w,"); - Stmt* stmt2 = l.root_stmt(); + StmtPtr stmt2 = l.root_stmt(); std::vector stmt2_output(24, 0); SimpleIREvaluator cg2(stmt2, {tensor}); @@ -2428,11 +2456,11 @@ TEST(LoopNest, LoopNestReorderSameAxis) { return ExprHandle(1.0f) + cast(x) * x + cast(y) * y; }); LoopNest l({tensor}); - Stmt* stmt1 = Stmt::clone(l.root_stmt()); + StmtPtr stmt1 = Stmt::clone(l.root_stmt()); auto loops = l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); LoopNest::reorderAxis(loops[1], loops[1]); - Stmt* stmt2 = Stmt::clone(l.root_stmt()); + StmtPtr stmt2 = Stmt::clone(l.root_stmt()); std::ostringstream oss, oss2; oss << *stmt1; @@ -2468,15 +2496,18 @@ TEST(LoopNest, LoopNestReorderExtraStatements) { VarHandle i = VarHandle(loops[0]->var()); - Stmt* store_1 = Store::make(BufHandle(extra.data()), {i, 0}, ExprHandle(1.f)); - Stmt* store_2 = Store::make(BufHandle(extra.data()), {i, 1}, ExprHandle(2.f)); + StmtPtr store_1 = + Store::make(BufHandle(extra.data()), {i, 0}, ExprHandle(1.f)); + StmtPtr store_2 = + Store::make(BufHandle(extra.data()), {i, 1}, ExprHandle(2.f)); // stmt 3 is the Function body. 
- Stmt* store_3 = Store::make(BufHandle(extra.data()), {i, 2}, ExprHandle(4.f)); + StmtPtr store_3 = + Store::make(BufHandle(extra.data()), {i, 2}, ExprHandle(4.f)); loops[0]->body()->prepend_stmt(store_1); loops[1]->body()->prepend_stmt(store_2); loops[1]->body()->append_stmt(store_3); - Stmt* stmt1 = Stmt::clone(l.root_stmt()); + StmtPtr stmt1 = Stmt::clone(l.root_stmt()); std::vector extra1(6, 0); std::vector res1(24, 0); @@ -2501,7 +2532,7 @@ TEST(LoopNest, LoopNestReorderExtraStatements) { */ LoopNest::reorderAxis(loops[1], loops[2]); - Stmt* stmt2 = Stmt::clone(l.root_stmt()); + StmtPtr stmt2 = Stmt::clone(l.root_stmt()); // Check the IR we produced checkIR(stmt2, R"IR( @@ -2549,7 +2580,7 @@ TEST(LoopNest, LoopNestReorderExtraStatements) { */ loops = l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); LoopNest::reorderAxis(loops[0], loops[2]); - Stmt* stmt3 = Stmt::clone(l.root_stmt()); + StmtPtr stmt3 = Stmt::clone(l.root_stmt()); // Check the IR we produced checkIR(stmt3, R"IR( @@ -2596,12 +2627,14 @@ void LoopNestReorderTestHelper( auto loops = l.getAllLoopNestsWritingToBuf(c->buf()).at(0); int j = 0; - for (auto* l : loops) { + for (auto l : loops) { // Add an increment at each layer of the loop which counts the number of // times the loop executes. - Load* load = new Load(extra.data(), {new IntImm(j)}); - Add* add = new Add(load, new IntImm(1)); - Stmt* store = new Store(extra.data(), {new IntImm(j)}, add); + LoadPtr load = + alloc(extra.data(), std::vector({alloc(j)})); + AddPtr add = alloc(load, alloc(1)); + StmtPtr store = alloc( + extra.data(), std::vector({alloc(j)}), add); if (prepend) { l->body()->prepend_stmt(store); } @@ -2612,7 +2645,7 @@ void LoopNestReorderTestHelper( j++; } - Stmt* stmt1 = Stmt::clone(l.root_stmt()); + StmtPtr stmt1 = Stmt::clone(l.root_stmt()); std::vector extra1(5, 0); std::vector res1(2 * 3 * 2 * 3 * 2, 0); @@ -2635,7 +2668,7 @@ void LoopNestReorderTestHelper( loops = l.getAllLoopNestsWritingToBuf(c->buf()).at(0); LoopNest::reorderAxis(loops[index1], loops[index2]); - Stmt* stmt2 = Stmt::clone(l.root_stmt()); + StmtPtr stmt2 = Stmt::clone(l.root_stmt()); std::ostringstream oss, oss2; oss << *stmt1; @@ -2728,10 +2761,10 @@ TEST(LoopNest, LoopNestReorderInternalLoopNest) { }); LoopNest l({z}, {x, y, z}); - For* a = nullptr; - For* b = nullptr; + ForPtr a = nullptr; + ForPtr b = nullptr; auto fors = NodeFinder::find(l.root_stmt()); - for (auto* f : fors) { + for (auto f : fors) { if (f->var()->name_hint() == "m2") { a = f; } else if (f->var()->name_hint() == "k2") { @@ -2741,7 +2774,7 @@ TEST(LoopNest, LoopNestReorderInternalLoopNest) { LoopNest::reorderAxis(a, b); l.prepareForCodegen(); - Stmt* stmt = IRSimplifier::simplify(l.root_stmt()); + StmtPtr stmt = IRSimplifier::simplify(l.root_stmt()); // Check the IR we produced has the 3 nests in the right order, but k and m // swapped in the middle. @@ -2810,21 +2843,21 @@ TEST(LoopNest, OuterLoopVectorization) { ASSERT_TRUE( LoopNest::vectorize(l.getAllLoopNestsWritingToBuf(tensor->buf())[0][0])); - Stmt* root_stmt = l.root_stmt(); - Block* outer_block = dynamic_cast(root_stmt); + StmtPtr root_stmt = l.root_stmt(); + BlockPtr outer_block = to(root_stmt); ASSERT_NE(outer_block, nullptr); - while (Block* inner_block = dynamic_cast(outer_block->front())) { + while (BlockPtr inner_block = to(outer_block->front())) { outer_block = inner_block; } // Verify that we have only a single loop level remaining after // vectorization. 
ASSERT_EQ(outer_block->nstmts(), 1); - For* for_loop = dynamic_cast(outer_block->front()); + ForPtr for_loop = to(outer_block->front()); ASSERT_NE(for_loop, nullptr); - Block* for_body = for_loop->body(); + BlockPtr for_body = for_loop->body(); ASSERT_EQ(for_body->nstmts(), 1); - ASSERT_EQ(dynamic_cast(for_body->front()), nullptr); + ASSERT_EQ(to(for_body->front()), nullptr); } TEST(LoopNest, VectorizeLoopNotNormalized) { @@ -2847,7 +2880,7 @@ TEST(LoopNest, VectorizeLoopNotNormalized) { ASSERT_TRUE(LoopNest::vectorize(inner_for)); ASSERT_EQ(outer_for->body()->nstmts(), 1); - ASSERT_EQ(dynamic_cast(outer_for->body()->front()), nullptr); + ASSERT_EQ(to(outer_for->body()->front()), nullptr); } namespace { @@ -2858,8 +2891,8 @@ std::string constantUpperBoundLoopIR(int upper_bound_val) { Tensor* A = Compute( "A", {{upper_bound, "x"}}, [&](const VarHandle& x) { return x * 2; }); LoopNest l({A}); - std::vector loops = l.getAllLoopNestsWritingToBuf(A->buf())[0]; - Stmt* unrolled = nullptr; + std::vector loops = l.getAllLoopNestsWritingToBuf(A->buf())[0]; + StmtPtr unrolled = nullptr; LoopNest::unroll(loops[0], &unrolled); std::ostringstream oss; oss << *unrolled; @@ -2888,8 +2921,8 @@ TEST(LoopNest, UnrollOuter) { {{outer_bound, "x"}, {inner_bound, "y"}}, [&](const VarHandle& x, const VarHandle& y) { return x + y; }); LoopNest l({A}); - std::vector loops = l.getAllLoopNestsWritingToBuf(A->buf())[0]; - Stmt* unrolled = nullptr; + std::vector loops = l.getAllLoopNestsWritingToBuf(A->buf())[0]; + StmtPtr unrolled = nullptr; LoopNest::unroll(loops[0], &unrolled); checkIR(unrolled, R"IR( # CHECK: for (int y = 0; y < 4; y++) { @@ -2912,10 +2945,10 @@ TEST(LoopNest, UnrollInner) { {{outer_bound, "x"}, {inner_bound, "y"}}, [&](const VarHandle& x, const VarHandle& y) { return x + y; }); LoopNest l({A}); - std::vector loops = l.getAllLoopNestsWritingToBuf(A->buf())[0]; - Stmt* unrolled = nullptr; + std::vector loops = l.getAllLoopNestsWritingToBuf(A->buf())[0]; + StmtPtr unrolled = nullptr; LoopNest::unroll( - static_cast(loops[0]->body()->stmts().front()), &unrolled); + static_to(loops[0]->body()->stmts().front()), &unrolled); checkIR(loops[0], R"IR( # CHECK: for (int x = 0; x < 3; x++) { # CHECK: A[x, 0] = x; @@ -2940,7 +2973,7 @@ TEST(LoopNest, UnrollMultipleStatements) { {Store::make(a_buf, {x}, x * 2), Store::make(b_buf, {x}, Load::make(a_buf, {x}))})); Block::make({f}); - Stmt* unrolled = nullptr; + StmtPtr unrolled = nullptr; LoopNest::unroll(f, &unrolled); checkIR(unrolled, R"IR( # CHECK: A[0] = 0; @@ -2973,8 +3006,8 @@ TEST(LoopNest, UnrollNonLiteralConstantBounds) { // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) auto b = Block::make({outer_for}); - std::vector loops = {outer_for, inner_for}; - Stmt* unrolled = nullptr; + std::vector loops = {outer_for, inner_for}; + StmtPtr unrolled = nullptr; LoopNest::unroll(loops[0], &unrolled); checkIR(unrolled, R"IR( # CHECK: for (int j = 0; j < 4; j++) { @@ -3003,8 +3036,8 @@ TEST(LoopNest, NoUnroll) { Tensor* A = Compute( "A", {{upper_bound, "x"}}, [&](const VarHandle& x) { return x * 2; }); LoopNest l({A}); - std::vector loops = l.getAllLoopNestsWritingToBuf(A->buf())[0]; - Stmt* unrolled = nullptr; + std::vector loops = l.getAllLoopNestsWritingToBuf(A->buf())[0]; + StmtPtr unrolled = nullptr; ASSERT_THROWS_WITH( LoopNest::unroll(loops[0], &unrolled), "non-constant loop"); } @@ -3026,7 +3059,7 @@ TEST(LoopNest, UnrollWithLet) { Store::make(a_buf, {x}, e), Store::make(b_buf, {x}, e + 1)})); Block::make({f}); - Stmt* unrolled = nullptr; + StmtPtr 
unrolled = nullptr; LoopNest::unroll(f, &unrolled); std::ostringstream oss; oss << *unrolled; @@ -3068,7 +3101,7 @@ TEST(LoopNest, IsNormalized) { Block::make({for_stmt}); ASSERT_FALSE(LoopNest::isNormalized(for_stmt)); - for_stmt->set_start(new IntImm(0)); + for_stmt->set_start(alloc(0)); ASSERT_TRUE(LoopNest::isNormalized(for_stmt)); VarHandle N("N", kInt); @@ -3299,9 +3332,9 @@ TEST(LoopNest, NormalizeAndSplitWithTail) { LoopNest::normalize(for_stmt); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* x_inner; + ForPtr x_inner; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* x_tail; + ForPtr x_tail; LoopNest::splitWithTail(for_stmt, 10, &x_inner, &x_tail); auto x_outer_result = IRSimplifier::simplify(for_stmt); @@ -3342,8 +3375,8 @@ TEST(LoopNest, FlattenSimpleLoopNest2D) { auto outer_for = For::make(i, 0, 10, inner_for); Block::make({outer_for}); - std::vector loops = {outer_for, inner_for}; - For* flattened = nullptr; + std::vector loops = {outer_for, inner_for}; + ForPtr flattened = nullptr; ASSERT_TRUE(LoopNest::flatten(loops, &flattened)); ASSERT_EQ(flattened, loops.front()); @@ -3389,8 +3422,8 @@ TEST(LoopNest, FlattenSimpleLoopNest3D) { auto for3 = For::make(i, 0, 10, for2); Block::make({for3}); - std::vector loops = {for3, for2, for1}; - For* flattened = nullptr; + std::vector loops = {for3, for2, for1}; + ForPtr flattened = nullptr; ASSERT_TRUE(LoopNest::flatten(loops, &flattened)); ASSERT_EQ(flattened, loops.front()); @@ -3432,8 +3465,8 @@ TEST(LoopNest, FlattenLoopNestAfterNormalize) { auto outer_for = For::make(i, 2, 10, inner_for); Block::make({outer_for}); - std::vector loops = {outer_for, inner_for}; - For* flattened = nullptr; + std::vector loops = {outer_for, inner_for}; + ForPtr flattened = nullptr; ASSERT_TRUE(LoopNest::flatten(loops, &flattened)); ASSERT_EQ(flattened, loops.front()); @@ -3478,8 +3511,8 @@ TEST(LoopNest, FlattenLoopNestWithNonLiteralConstantBounds) { // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) auto b = Block::make({outer_for}); - std::vector loops = {outer_for, inner_for}; - For* flattened = nullptr; + std::vector loops = {outer_for, inner_for}; + ForPtr flattened = nullptr; ASSERT_TRUE(LoopNest::flatten(loops, &flattened)); ASSERT_EQ(flattened, loops.front()); @@ -3523,8 +3556,8 @@ TEST(LoopNest, FlattenImperfectLoopNest) { HashProvider hasher; auto hash_before = hasher.hash(par); - std::vector loops = {outer_for, inner_for}; - For* flattened = nullptr; + std::vector loops = {outer_for, inner_for}; + ForPtr flattened = nullptr; ASSERT_FALSE(LoopNest::flatten(loops, &flattened)); ASSERT_EQ(flattened, nullptr); auto hash_after = hasher.hash(par); @@ -3556,8 +3589,8 @@ TEST(LoopNest, FlattenReductionLoopNest) { HashProvider hasher; auto hash_before = hasher.hash(par); - std::vector loops = {outer_for, inner_for}; - For* flattened = nullptr; + std::vector loops = {outer_for, inner_for}; + ForPtr flattened = nullptr; ASSERT_FALSE(LoopNest::flatten(loops, &flattened)); ASSERT_EQ(flattened, nullptr); auto hash_after = hasher.hash(par); @@ -3577,7 +3610,7 @@ TEST(LoopNest, FlattenReductionLoopNestFromTensor) { auto hash_before = hasher.hash(loop.root_stmt()); auto loops = loop.getAllLoopNestsWritingToBuf(c->buf())[1]; - For* flattened = nullptr; + ForPtr flattened = nullptr; ASSERT_FALSE(LoopNest::flatten(loops, &flattened)); ASSERT_EQ(flattened, nullptr); auto hash_after = hasher.hash(loop.root_stmt()); @@ -3616,8 +3649,8 @@ TEST(LoopNest, FlattenIncorrectLoopsAsInput) { HashProvider hasher; auto hash_before = hasher.hash(par); 
- std::vector loops = {outer_for1, inner_for2}; - For* flattened = nullptr; + std::vector loops = {outer_for1, inner_for2}; + ForPtr flattened = nullptr; ASSERT_FALSE(LoopNest::flatten(loops, &flattened)); ASSERT_EQ(flattened, nullptr); auto hash_after = hasher.hash(par); @@ -3659,11 +3692,11 @@ TEST(LoopNest, CacheReadsSimple) { }); LoopNest l({B, C}, {A, B, C}); - Stmt* j_loop = l.getAllLoopNestsWritingToBuf(B->buf())[0][1]; + StmtPtr j_loop = l.getAllLoopNestsWritingToBuf(B->buf())[0][1]; LoopNest::cacheAccesses(A->buf(), "A_local", j_loop); l.prepareForCodegen(); - Stmt* result = IRSimplifier::simplify(l.root_stmt()); + StmtPtr result = IRSimplifier::simplify(l.root_stmt()); // just this once: verify the whole thing. checkIR(result, R"IR( @@ -3727,11 +3760,11 @@ TEST(LoopNest, CacheReadsOuter) { }); LoopNest l({B, C}, {A, B, C}); - Stmt* i_loop = l.getAllLoopNestsWritingToBuf(B->buf())[0][0]; + StmtPtr i_loop = l.getAllLoopNestsWritingToBuf(B->buf())[0][0]; LoopNest::cacheAccesses(A->buf(), "A_local", i_loop); l.prepareForCodegen(); - Stmt* result = IRSimplifier::simplify(l.root_stmt()); + StmtPtr result = IRSimplifier::simplify(l.root_stmt()); checkIR(result, R"IR( #CHECK: Allocate(A_local); // dtype=int, dims=[21, 11] @@ -3775,10 +3808,10 @@ TEST(LoopNest, CacheReadsInternal) { }); LoopNest l({B, C}, {A, B, C}); - Stmt* j_loop = l.getAllLoopNestsWritingToBuf(B->buf())[0][1]; + StmtPtr j_loop = l.getAllLoopNestsWritingToBuf(B->buf())[0][1]; LoopNest::cacheAccesses(A->buf(), "A_local", j_loop); l.prepareForCodegen(); - Stmt* result = IRSimplifier::simplify(l.root_stmt()); + StmtPtr result = IRSimplifier::simplify(l.root_stmt()); checkIR(result, R"IR( #CHECK: Allocate(A_local); // dtype=int, dims=[2, 11] @@ -3823,10 +3856,10 @@ TEST(LoopNest, CacheReadsInner) { }); LoopNest l({B, C}, {A, B, C}); - Stmt* body = l.getLoopBodyFor(B); + StmtPtr body = l.getLoopBodyFor(B); LoopNest::cacheAccesses(A->buf(), "A_local", body); l.prepareForCodegen(); - Stmt* result = IRSimplifier::simplify(l.root_stmt()); + StmtPtr result = IRSimplifier::simplify(l.root_stmt()); checkIR(result, R"IR( #CHECK: Allocate(A_local); // dtype=int, dims=[5, 2] @@ -3870,11 +3903,11 @@ TEST(LoopNest, CacheWritesSimple) { }); LoopNest l({B, C}, {A, B, C}); - Stmt* a_loop = l.getAllLoopNestsWritingToBuf(A->buf())[0][1]; + StmtPtr a_loop = l.getAllLoopNestsWritingToBuf(A->buf())[0][1]; LoopNest::cacheAccesses(A->buf(), "A_local", a_loop); l.prepareForCodegen(); - Stmt* result = IRSimplifier::simplify(l.root_stmt()); + StmtPtr result = IRSimplifier::simplify(l.root_stmt()); checkIR(result, R"IR( #CHECK: Allocate(A_local); // dtype=int, dims=[1, 64] @@ -3913,7 +3946,7 @@ TEST(LoopNest, DeadStoreElimination) { BufHandle g("g", {26, 5}, kInt); ExprHandle x_outer_end = 5; ExprHandle x_2 = x + x_outer_end * 4; - For* stmt1 = For::make( + ForPtr stmt1 = For::make( x, 0, 5, @@ -3925,7 +3958,7 @@ TEST(LoopNest, DeadStoreElimination) { Store::make(f, {x_2, y}, (x_2 + y)), Store::make(g, {x_2, y}, (x_2 * y)), }))); - Stmt* stmt = Block::make({stmt1}); + StmtPtr stmt = Block::make({stmt1}); // Will eliminate if not used by an output. 
LoopNest loop(Stmt::clone(stmt), {f.node()}); @@ -3956,9 +3989,9 @@ TEST(LoopNest, DeadStoreEliminationWithIntermediates) { BufHandle h("h", {26, 5}, kInt); ExprHandle x_outer_end = 5; ExprHandle x_2 = x + x_outer_end * 4; - For* stmt1 = For::make(x, 0, 26 * 5, Store::make(f, {x}, x)); - For* stmt2 = For::make(z, 0, 26 * 5, Store::make(g, {z}, z + 1)); - For* stmt3 = For::make( + ForPtr stmt1 = For::make(x, 0, 26 * 5, Store::make(f, {x}, x)); + ForPtr stmt2 = For::make(z, 0, 26 * 5, Store::make(g, {z}, z + 1)); + ForPtr stmt3 = For::make( x, 0, 5, @@ -3969,7 +4002,7 @@ TEST(LoopNest, DeadStoreEliminationWithIntermediates) { Block::make({ Store::make(h, {x, y}, Load::make(f, {x * y})), }))); - Stmt* stmt = Block::make({stmt1, stmt2, stmt3}); + StmtPtr stmt = Block::make({stmt1, stmt2, stmt3}); // Will eliminate the write to g, but not f since it used by the producer of // h. @@ -4008,7 +4041,7 @@ TEST(LoopNest, CompoundTensorSimple) { {Store::make(a_buf, {x, y}, Load::make(a_buf, {x, y}) + x + y)}); auto inner_for2 = For::make(y, 0, 5, for_body2); auto outer_for2 = For::make(x, 0, 10, inner_for2); - Block* body = Block::make({outer_for1, outer_for2}); + BlockPtr body = Block::make({outer_for1, outer_for2}); Tensor* A = new Tensor(a_buf.node(), body); @@ -4017,7 +4050,7 @@ TEST(LoopNest, CompoundTensorSimple) { std::vector a_data(50, 0); - Stmt* s = IRSimplifier::simplify(l.root_stmt()); + StmtPtr s = IRSimplifier::simplify(l.root_stmt()); SimpleIREvaluator cg(s, {A}); std::vector a_ref(50, 0); @@ -4069,7 +4102,7 @@ TEST(LoopNest, CompoundTensorUsed) { {Store::make(a_buf, {x, y}, Load::make(a_buf, {x, y}) + x + y)}); auto inner_for2 = For::make(y, 0, 5, for_body2); auto outer_for2 = For::make(x, 0, 10, inner_for2); - Block* body = Block::make({outer_for1, outer_for2}); + BlockPtr body = Block::make({outer_for1, outer_for2}); Tensor* A = new Tensor(a_buf.node(), body); Tensor* B = Compute( @@ -4084,7 +4117,7 @@ TEST(LoopNest, CompoundTensorUsed) { std::vector a_data(50, 0); std::vector b_data(50, 0); - Stmt* s = IRSimplifier::simplify(l.root_stmt()); + StmtPtr s = IRSimplifier::simplify(l.root_stmt()); SimpleIREvaluator cg(s, {B}); std::vector b_ref(50, 0); @@ -4677,7 +4710,7 @@ static std::pair, Tensor*> colReduce( return {std::move(a), t}; } -static Stmt* splitTailReorder(Tensor* b) { +static StmtPtr splitTailReorder(Tensor* b) { constexpr int kVectorWidth = 8; LoopNest nest({b}); auto loops = nest.getAllLoopNestsWritingToBuf(b->buf())[0]; @@ -4707,7 +4740,7 @@ static Stmt* splitTailReorder(Tensor* b) { return nest.root_stmt(); } -static Stmt* splitMaskReorder(Tensor* b) { +static StmtPtr splitMaskReorder(Tensor* b) { constexpr int kVectorWidth = 8; LoopNest nest({b}); auto loops = nest.getAllLoopNestsWritingToBuf(b->buf())[1]; @@ -4718,7 +4751,7 @@ static Stmt* splitMaskReorder(Tensor* b) { return nest.root_stmt(); } -static void checkColReduce(Stmt* s, Placeholder& p, Tensor* t) { +static void checkColReduce(StmtPtr s, Placeholder& p, Tensor* t) { int M = immediateAs(p.dim(0)); int N = immediateAs(p.dim(1)); PaddedBuffer a(M, N); @@ -4743,7 +4776,7 @@ TEST(LoopNest, ColReduceSplitTailEvenReorder) { KernelScope kernel_scope; constexpr int M = 76, N = 128; auto p = colReduce(M, N); - Stmt* s = splitTailReorder(p.second); + StmtPtr s = splitTailReorder(p.second); std::ostringstream oss; oss << *s; @@ -4766,7 +4799,7 @@ TEST(LoopNest, ColReduceSplitTailUnevenReorder) { KernelScope kernel_scope; constexpr int M = 76, N = 100; auto p = colReduce(M, N); - Stmt* s = splitTailReorder(p.second); + 
StmtPtr s = splitTailReorder(p.second); std::ostringstream oss; oss << *s; @@ -4792,7 +4825,7 @@ TEST(LoopNest, ColReduceSplitMaskEvenReorder) { KernelScope kernel_scope; constexpr int M = 76, N = 128; auto p = colReduce(M, N); - Stmt* s = splitMaskReorder(p.second); + StmtPtr s = splitMaskReorder(p.second); checkColReduce(s, *p.first, p.second); } @@ -4800,7 +4833,7 @@ TEST(LoopNest, ColReduceSplitMaskUnevenReorder) { KernelScope kernel_scope; constexpr int M = 76, N = 100; auto p = colReduce(M, N); - Stmt* s = splitMaskReorder(p.second); + StmtPtr s = splitMaskReorder(p.second); checkColReduce(s, *p.first, p.second); } @@ -4825,7 +4858,7 @@ TEST(LoopNest, ReorderAxisWithMultipleConds) { auto outer_cond = Cond::make(CompareSelect::make(i, 5, kGT), inner_cond, nullptr); auto forI = For::make(i, 0, 20, outer_cond); - Stmt* par = Block::make({forI}); + StmtPtr par = Block::make({forI}); LoopNest l(par, {a_buf.node()}); LoopNest::reorderAxis(forI, forJ); ASSERT_EQ(par, l.root_stmt()); @@ -4860,7 +4893,7 @@ TEST(LoopNest, VectorizeUse) { ASSERT_TRUE(LoopNest::vectorize(loops[0])); nest.prepareForCodegen(); // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) - Stmt* s = nest.root_stmt(); + StmtPtr s = nest.root_stmt(); std::ostringstream oss; oss << *nest.root_stmt(); torch::jit::testing::FileCheck().run( @@ -4883,7 +4916,7 @@ TEST(LoopNest, Int64Direct) { Placeholder a("a", kLong, {N}); Placeholder b("b", kLong, {N}); VarHandle n("n", kLong); - Stmt* s = For::make(n, 0, N, b.store({n}, a.load({n}) + LongImm::make(1l))); + StmtPtr s = For::make(n, 0, N, b.store({n}, a.load({n}) + LongImm::make(1l))); s = IRSimplifier::simplify(s); std::ostringstream oss; oss << *s; @@ -5212,7 +5245,7 @@ TEST(LoopNest, DistributeLoopAndParentsWithoutAnyPivot) { # CHECK-NOT: for ( )IR"; - auto newForI = dynamic_cast(Stmt::clone(forI)); + auto newForI = to(Stmt::clone(forI)); auto forM = For::make(m, 0, 50, newForI); auto par = Block::make({forM}); LoopNest nest(par, {a_buf.node(), b_buf.node()}); @@ -5244,7 +5277,7 @@ TEST(LoopNest, DistributeLoopAndParentsWithoutAnyPivot) { # CHECK-NOT: for ( )IR"; - auto newForI = dynamic_cast(Stmt::clone(forI)); + auto newForI = to(Stmt::clone(forI)); auto forM = For::make(m, 0, 50, newForI); auto par = Block::make({forM}); LoopNest nest(par, {a_buf.node(), b_buf.node()}); @@ -5277,7 +5310,7 @@ TEST(LoopNest, fuseLoopsSimple) { auto forK = For::make(k, 0, 100, Store::make(b_buf, {k}, Mul::make(20, k))); auto par = Block::make({forJ, forK}); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* fused_loop; + ForPtr fused_loop; ASSERT_TRUE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); std::ostringstream oss; @@ -5319,7 +5352,7 @@ TEST(LoopNest, fuseLoopsMultiple) { auto forK = For::make(k, 0, 100, Store::make(b_buf, {k}, Mul::make(20, k))); auto par = Block::make({forI, forJ, forK}); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* fused_loop; + ForPtr fused_loop; ASSERT_TRUE(LoopNest::fuseLoops({forI, forJ, forK}, &fused_loop)); std::ostringstream oss; @@ -5378,7 +5411,7 @@ TEST(LoopNest, fuseLoopsNested) { auto forN = For::make(n, 0, 20, Block::make({initB, forK})); auto par = Block::make({forM, forN}); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* fused_loop; + ForPtr fused_loop; ASSERT_TRUE(LoopNest::fuseLoops({forM, forN}, &fused_loop)); std::ostringstream oss; @@ -5440,7 +5473,7 @@ TEST(LoopNest, fuseLoopsNested2D) { Store::make(b_buf, {m, n}, Add::make(m, Mul::make(n, 100))))); auto par = Block::make({forI, forM}); // 
NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* fused_loop; + ForPtr fused_loop; ASSERT_TRUE(LoopNest::fuseLoops({forI, forM}, &fused_loop)); std::ostringstream oss; @@ -5483,7 +5516,7 @@ TEST(LoopNest, fuseLoopsNested2DInner) { n, 0, 100, Store::make(b_buf, {i, n}, Add::make(i, Mul::make(n, 100)))); auto forI = For::make(i, 0, 20, Block::make({forJ, forN})); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* fused_loop; + ForPtr fused_loop; ASSERT_TRUE(LoopNest::fuseLoops({forJ, forN}, &fused_loop)); std::ostringstream oss; @@ -5521,7 +5554,7 @@ TEST(LoopNest, fuseLoopsDifferentStopBounds) { // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) auto par = Block::make({forJ, forK}); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* fused_loop; + ForPtr fused_loop; ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); } @@ -5544,7 +5577,7 @@ TEST(LoopNest, fuseLoopsDifferentStartBounds) { // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) auto par = Block::make({forJ, forK}); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* fused_loop; + ForPtr fused_loop; ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); } @@ -5569,7 +5602,7 @@ TEST(LoopNest, fuseLoopsNotContiguous) { // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) auto par = Block::make({forJ, initB, forK}); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* fused_loop; + ForPtr fused_loop; ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); } @@ -5598,7 +5631,7 @@ TEST(LoopNest, fuseLoopsWithDifferentParents) { // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) auto par = Block::make({forI, initB, forK}); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* fused_loop; + ForPtr fused_loop; ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); } @@ -5622,7 +5655,7 @@ TEST(LoopNest, fuseLoopsWithVariableBounds) { auto forK = For::make(k, 0, N, Store::make(b_buf, {j}, Mul::make(20, k))); auto par = Block::make({forJ, forK}); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* fused_loop; + ForPtr fused_loop; ASSERT_TRUE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); std::ostringstream oss; @@ -5660,7 +5693,7 @@ TEST(LoopNest, fuseLoopsWithExprBounds) { auto forK = For::make(k, 0, M + N, Store::make(b_buf, {j}, Mul::make(20, k))); auto par = Block::make({forJ, forK}); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* fused_loop; + ForPtr fused_loop; ASSERT_TRUE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); std::ostringstream oss; @@ -5699,7 +5732,7 @@ TEST(LoopNest, fuseLoopsWithDifferentExprBounds) { auto forK = For::make(k, M, N + N, Store::make(b_buf, {j}, Mul::make(20, k))); auto par = Block::make({forJ, forK}); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* fused_loop; + ForPtr fused_loop; ASSERT_TRUE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); std::ostringstream oss; @@ -5736,7 +5769,7 @@ TEST(LoopNest, fuseLoopsWithNonOverlappingBufferAccesses) { auto par = Block::make({forJ, forK}); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* fused_loop; + ForPtr fused_loop; ASSERT_TRUE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); std::ostringstream oss; @@ -5784,7 +5817,7 @@ TEST(LoopNest, fuseLoopsWithNonOverlapping2DBufferAccesses) { auto par = Block::make({forI, forM}); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* fused_loop; + ForPtr fused_loop; ASSERT_TRUE(LoopNest::fuseLoops({forI, forM}, &fused_loop)); std::ostringstream oss; @@ -5832,7 +5865,7 @@ TEST(LoopNest, 
fuseLoopsWithReductions) { For::make(m, 0, 20, Store::make(c_buf, {m}, Load::make(a_buf, {m}))); auto par = Block::make({forI, forM}); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* fused_loop; + ForPtr fused_loop; ASSERT_TRUE(LoopNest::fuseLoops({forI, forM}, &fused_loop)); std::ostringstream oss; @@ -5890,7 +5923,7 @@ TEST(LoopNest, fuseLoopsWith2DReductions) { auto par = Block::make({forI, forM}); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* fused_loop; + ForPtr fused_loop; ASSERT_TRUE(LoopNest::fuseLoops({forI, forM}, &fused_loop)); std::ostringstream oss; @@ -5940,7 +5973,7 @@ TEST(LoopNest, fuseLoopsWithComplexIndices) { auto par = Block::make({forI, forM}); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* fused_loop; + ForPtr fused_loop; ASSERT_TRUE(LoopNest::fuseLoops({forI, forM}, &fused_loop)); std::ostringstream oss; @@ -5987,7 +6020,7 @@ TEST(LoopNest, fuseLoopsWithMixedLoopVarsAsIndices) { auto par = Block::make({forI, forM}); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* fused_loop; + ForPtr fused_loop; ASSERT_FALSE(LoopNest::fuseLoops({forI, forM}, &fused_loop)); } @@ -6018,7 +6051,7 @@ TEST(LoopNest, fuseLoopsWithTranspose) { auto par = Block::make({forI, forM}); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* fused_loop; + ForPtr fused_loop; ASSERT_FALSE(LoopNest::fuseLoops({forI, forM}, &fused_loop)); } @@ -6041,7 +6074,7 @@ TEST(LoopNest, fuseLoopsThatViolateDependencies1) { // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) auto par = Block::make({forJ, forK}); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* fused_loop; + ForPtr fused_loop; ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); } @@ -6064,7 +6097,7 @@ TEST(LoopNest, fuseLoopsThatViolateDependencies2) { // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) auto par = Block::make({forJ, forK}); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* fused_loop; + ForPtr fused_loop; ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); } @@ -6109,7 +6142,7 @@ TEST(LoopNest, fuseLoopsThatViolateDependencies3) { // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) auto par = Block::make({forM, forN}); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* fused_loop; + ForPtr fused_loop; ASSERT_FALSE(LoopNest::fuseLoops({forM, forN}, &fused_loop)); } @@ -6153,7 +6186,7 @@ TEST(LoopNest, fuseLoopsThatViolateDependencies4) { // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) auto par = Block::make({forI, forM}); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* fused_loop; + ForPtr fused_loop; ASSERT_FALSE(LoopNest::fuseLoops({forI, forM}, &fused_loop)); } @@ -6183,7 +6216,7 @@ TEST(LoopNest, fuseLoopsThatViolateDependencies5) { // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores,cppcoreguidelines-avoid-magic-numbers) auto forI = For::make(i, 0, 20, Block::make({forJ, forN})); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* fused_loop; + ForPtr fused_loop; ASSERT_FALSE(LoopNest::fuseLoops({forJ, forN}, &fused_loop)); } @@ -6211,7 +6244,7 @@ TEST(LoopNest, fuseLoopsThatViolateDependencies6) { // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) auto par = Block::make({forJ, forK}); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* fused_loop; + ForPtr fused_loop; ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); } @@ -6239,7 +6272,7 @@ TEST(LoopNest, fuseLoopsThatViolateDependencies7) { // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) auto par = Block::make({forK, 
forJ}); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* fused_loop; + ForPtr fused_loop; ASSERT_FALSE(LoopNest::fuseLoops({forK, forJ}, &fused_loop)); } diff --git a/test/cpp/tensorexpr/test_memdependency.cpp b/test/cpp/tensorexpr/test_memdependency.cpp index 08e93cd..7f844c5 100644 --- a/test/cpp/tensorexpr/test_memdependency.cpp +++ b/test/cpp/tensorexpr/test_memdependency.cpp @@ -23,7 +23,9 @@ TEST(MemDependency, BoundOverlap) { using namespace analysis; - auto CB = [](int s, int e) { return Bound(new IntImm(s), new IntImm(e)); }; + auto CB = [](int s, int e) { + return Bound(alloc(s), alloc(e)); + }; // Sanity check 3 overlap cases. ASSERT_EQ(ContainedOrEqual, boundOverlap(CB(0, 0), CB(0, 0))); @@ -118,7 +120,9 @@ TEST(MemDependency, BoundOverlapMultiDim) { using namespace analysis; - auto CB = [](int s, int e) { return Bound(new IntImm(s), new IntImm(e)); }; + auto CB = [](int s, int e) { + return Bound(alloc(s), alloc(e)); + }; // Sanity check one dimensional cases. ASSERT_EQ(ContainedOrEqual, overlaps({CB(0, 0)}, {CB(0, 0)})); @@ -189,7 +193,9 @@ TEST(MemDependency, BoundSubtract) { using namespace analysis; - auto CB = [](int s, int e) { return Bound(new IntImm(s), new IntImm(e)); }; + auto CB = [](int s, int e) { + return Bound(alloc(s), alloc(e)); + }; auto EQ = [](const IndexBounds& x, const IndexBounds& y) { return indexBoundsEquals(x, y); }; @@ -271,7 +277,9 @@ TEST(MemDependency, BoundSubtractMultiDim) { using namespace analysis; - auto CB = [](int s, int e) { return Bound(new IntImm(s), new IntImm(e)); }; + auto CB = [](int s, int e) { + return Bound(alloc(s), alloc(e)); + }; auto EQ = [](std::vector x, std::vector y) { if (x.size() != y.size()) { return false; @@ -406,10 +414,10 @@ TEST(MemDependency, MemDependencyCheckerSimple) { * B[0] = A[0] + 1; */ - Store* aStore = Store::make(a, {0}, 3); - Store* bStore = Store::make(b, {0}, Add::make(Load::make(a, {0}), 1)); + StorePtr aStore = Store::make(a, {0}, 3); + StorePtr bStore = Store::make(b, {0}, Add::make(Load::make(a, {0}), 1)); - Stmt* stmt = Block::make({aStore, bStore}); + StmtPtr stmt = Block::make({aStore, bStore}); stmt->accept(&analyzer); @@ -434,11 +442,11 @@ TEST(MemDependency, MemDependencyCheckerMultiStmt) { * C[0] = B[0] + 1; */ - Store* aStore = Store::make(a, {0}, 3); - Store* bStore = Store::make(b, {0}, Load::make(a, {0})); - Store* cStore = Store::make(c, {0}, Add::make(Load::make(b, {0}), 1)); + StorePtr aStore = Store::make(a, {0}, 3); + StorePtr bStore = Store::make(b, {0}, Load::make(a, {0})); + StorePtr cStore = Store::make(c, {0}, Add::make(Load::make(b, {0}), 1)); - Stmt* stmt = Block::make({aStore, bStore, cStore}); + StmtPtr stmt = Block::make({aStore, bStore, cStore}); stmt->accept(&analyzer); @@ -470,11 +478,11 @@ TEST(MemDependency, MemDependencyCheckerOverlap) { * B[0] = A[0] + 1; */ - Store* aStore = Store::make(a, {0}, 3); - Store* a2Store = Store::make(a, {0}, 6); - Store* bStore = Store::make(b, {0}, Add::make(Load::make(a, {0}), 1)); + StorePtr aStore = Store::make(a, {0}, 3); + StorePtr a2Store = Store::make(a, {0}, 6); + StorePtr bStore = Store::make(b, {0}, Add::make(Load::make(a, {0}), 1)); - Stmt* stmt = Block::make({aStore, a2Store, bStore}); + StmtPtr stmt = Block::make({aStore, a2Store, bStore}); stmt->accept(&analyzer); @@ -507,11 +515,11 @@ TEST(MemDependency, MemDependencyCheckerLoop) { * B[0] = A[0] + 1; */ - Store* aStore = Store::make(a, {x}, x); - Stmt* loop = For::make(x, 0, 10, aStore); - Store* bStore = Store::make(b, {0}, Add::make(Load::make(a, {4}), 1)); 
+ StorePtr aStore = Store::make(a, {x}, x); + StmtPtr loop = For::make(x, 0, 10, aStore); + StorePtr bStore = Store::make(b, {0}, Add::make(Load::make(a, {4}), 1)); - Stmt* stmt = Block::make({loop, bStore}); + StmtPtr stmt = Block::make({loop, bStore}); stmt->accept(&analyzer); @@ -528,7 +536,7 @@ TEST(MemDependency, MemDependencyCheckerLoop) { // It should have bounds covering the range of x: 0 <= x < 10. ASSERT_TRUE(indexBoundsEquals( - aStoreAccess->bounds(), {Bound(new IntImm(0), new IntImm(9))})); + aStoreAccess->bounds(), {Bound(alloc(0), alloc(9))})); } // Reductions should promote dependencies as well. @@ -550,14 +558,14 @@ TEST(MemDependency, MemDependencyCheckerLoopReduce) { * B[0] = A[0]; */ - Store* aInit = Store::make(a, {0}, 0); + StorePtr aInit = Store::make(a, {0}, 0); ExprHandle reduce = ExprHandle(Sum()(a.node(), ExprHandle(1), {x.node()}, {x.node()})); - Store* aReduce = Store::make(a, {0}, reduce); - Stmt* loop = For::make(x, 0, 10, aReduce); - Store* bStore = Store::make(b, {0}, Load::make(a, {0})); + StorePtr aReduce = Store::make(a, {0}, reduce); + StmtPtr loop = For::make(x, 0, 10, aReduce); + StorePtr bStore = Store::make(b, {0}, Load::make(a, {0})); - Stmt* stmt = Block::make({aInit, loop, bStore}); + StmtPtr stmt = Block::make({aInit, loop, bStore}); stmt->accept(&analyzer); @@ -582,11 +590,11 @@ TEST(MemDependency, MemDependencyCheckerLoopReduce) { // Find loads within the reduction: auto reduceLoads = NodeFinder::find(reduce.node()); // Pull out the access for the load inside the loop. - for (auto* load : reduceLoads) { + for (auto load : reduceLoads) { auto loopLoad = analyzer.accessFor(load); // It should have 10 element long bounds. ASSERT_TRUE(indexBoundsEquals( - loopLoad->bounds(), {Bound(new IntImm(0), new IntImm(9))})); + loopLoad->bounds(), {Bound(alloc(0), alloc(9))})); } } @@ -609,13 +617,13 @@ TEST(MemDependency, MemDependencyCheckerLoopReduceExpanded) { * B[0] = A[0]; */ - Store* aInit = Store::make(a, {0}, 0); + StorePtr aInit = Store::make(a, {0}, 0); ExprHandle aLoad = Load::make(a, {x}); - Store* aReduce = Store::make(a, {0}, Add::make(aLoad, 1)); - Stmt* loop = For::make(x, 0, 10, aReduce); - Store* bStore = Store::make(b, {0}, Load::make(a, {0})); + StorePtr aReduce = Store::make(a, {0}, Add::make(aLoad, 1)); + StmtPtr loop = For::make(x, 0, 10, aReduce); + StorePtr bStore = Store::make(b, {0}, Load::make(a, {0})); - Stmt* stmt = Block::make({aInit, loop, bStore}); + StmtPtr stmt = Block::make({aInit, loop, bStore}); stmt->accept(&analyzer); @@ -641,7 +649,7 @@ TEST(MemDependency, MemDependencyCheckerLoopReduceExpanded) { auto loopLoad = analyzer.accessFor(aLoad.node()); // It should have 10 element long bounds. ASSERT_TRUE(indexBoundsEquals( - loopLoad->bounds(), {Bound(new IntImm(0), new IntImm(9))})); + loopLoad->bounds(), {Bound(alloc(0), alloc(9))})); } // Can determine dependencies of outputs, through to inputs. 
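// --- Illustrative sketch, not part of this patch ---------------------------
// A standalone helper mirroring the CB lambdas these mem-dependency tests
// build: constant bounds are now assembled with alloc<IntImm>(...) instead of
// new IntImm(...), and held through the ExprPtr-based Bound wrapper. The
// header path and the explicit <IntImm> template argument are assumptions
// inferred from the surrounding hunks, not text taken from the patch.
#include <torch/csrc/jit/tensorexpr/bounds_overlap.h>

namespace sketch {
using namespace torch::jit::tensorexpr;
using namespace torch::jit::tensorexpr::analysis;

// Call inside an active KernelScope, as the tests do.
Bound constantBound(int start, int end) {
  // Old form: Bound(new IntImm(start), new IntImm(end))
  return Bound(alloc<IntImm>(start), alloc<IntImm>(end));
}
} // namespace sketch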
@@ -662,10 +670,10 @@ TEST(MemDependency, MemDependencyCheckerInputsOutputs) { */ ExprHandle aLoad = Load::make(a, {x}); - Store* bStore = Store::make(b, {x}, Max::make(aLoad, 0, true)); - Stmt* loop = For::make(x, 0, 10, bStore); + StorePtr bStore = Store::make(b, {x}, Max::make(aLoad, 0, true)); + StmtPtr loop = For::make(x, 0, 10, bStore); - Stmt* stmt = Block::make({loop}); + StmtPtr stmt = Block::make({loop}); stmt->accept(&analyzer); @@ -714,10 +722,10 @@ TEST(MemDependency, MemDependencyCheckerOutputDoesntDepend) { * } */ - Store* bStore = Store::make(b, {x}, Max::make(x, 0, true)); - Stmt* loop = For::make(x, 0, 10, bStore); + StorePtr bStore = Store::make(b, {x}, Max::make(x, 0, true)); + StmtPtr loop = For::make(x, 0, 10, bStore); - Stmt* stmt = Block::make({loop}); + StmtPtr stmt = Block::make({loop}); stmt->accept(&analyzer); @@ -766,14 +774,14 @@ TEST(MemDependency, MemDependencyCheckerLoopBounds) { * } */ - std::vector stmts( + std::vector stmts( {For::make(x, 1, 10, Store::make(b, {x}, Load::make(a, {x}))), For::make( x, 1, 9, Store::make(b, {x}, Mul::make(Load::make(b, {x}), 2))), For::make(x, 3, 4, Store::make(c, {x}, Load::make(a, {x}))), For::make(x, 0, 10, Store::make(c, {x}, Load::make(b, {x})))}); - Stmt* stmt = Block::make(stmts); + StmtPtr stmt = Block::make(stmts); stmt->accept(&analyzer); @@ -793,7 +801,9 @@ TEST(MemDependency, MemDependencyCheckerLoopBounds) { // The last write to C does not depend on the other write to C. ASSERT_FALSE(analyzer.dependsIndirectly(stmts[3], stmts[2])); - auto CB = [](int s, int e) { return Bound(new IntImm(s), new IntImm(e)); }; + auto CB = [](int s, int e) { + return Bound(alloc(s), alloc(e)); + }; auto EQ = [](const IndexBounds& x, const IndexBounds& y) { return indexBoundsEquals(x, y); }; @@ -815,9 +825,9 @@ TEST(MemDependency, MemDependencyCheckerLoopBounds) { // much. auto history = analyzer.getHistory(); ASSERT_EQ(history.size(), 10); - Var* aVar = a.node()->base_handle(); - Var* bVar = b.node()->base_handle(); - Var* cVar = c.node()->base_handle(); + VarPtr aVar = a.node()->base_handle(); + VarPtr bVar = b.node()->base_handle(); + VarPtr cVar = c.node()->base_handle(); // The first access is the input A. ASSERT_EQ(history[0]->type(), AccessType::Input); @@ -949,7 +959,7 @@ TEST(MemDependency, MemDependencyCheckerLoopBoundsIndexShift) { * } */ - Stmt* stmt = Block::make( + StmtPtr stmt = Block::make( {For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1}))), For::make(x, 0, 9, Store::make(a, {x}, Load::make(a, {x + 1}))), For::make( @@ -967,7 +977,9 @@ TEST(MemDependency, MemDependencyCheckerLoopBoundsIndexShift) { // Sanity check output depends on Input. ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node())); - auto CB = [](int s, int e) { return Bound(new IntImm(s), new IntImm(e)); }; + auto CB = [](int s, int e) { + return Bound(alloc(s), alloc(e)); + }; auto EQ = [](const IndexBounds& x, const IndexBounds& y) { return indexBoundsEquals(x, y); }; @@ -989,8 +1001,8 @@ TEST(MemDependency, MemDependencyCheckerLoopBoundsIndexShift) { // Now let's look at the bounds of each access. auto history = analyzer.getHistory(); ASSERT_EQ(history.size(), 12); - Var* aVar = a.node()->base_handle(); - Var* bVar = b.node()->base_handle(); + VarPtr aVar = a.node()->base_handle(); + VarPtr bVar = b.node()->base_handle(); // The first access is the input A. 
ASSERT_EQ(history[0]->type(), AccessType::Input); @@ -1123,7 +1135,7 @@ TEST(MemDependency, MemDependencyCheckerLoopSelfDependency) { // Not self dependent since all loop iterations use a different y. MemDependencyChecker analyzer; - Stmt* stmt = For::make( + StmtPtr stmt = For::make( y, 0, 10, @@ -1143,7 +1155,7 @@ TEST(MemDependency, MemDependencyCheckerLoopSelfDependency) { // Not self dependent due to different y (with offset). MemDependencyChecker analyzer; - Stmt* stmt = For::make( + StmtPtr stmt = For::make( y, 0, 10, @@ -1164,7 +1176,7 @@ TEST(MemDependency, MemDependencyCheckerLoopSelfDependency) { // Is self dependent since all loops use a common constant element of A. MemDependencyChecker analyzer; - Stmt* stmt = For::make( + StmtPtr stmt = For::make( x, 0, 10, @@ -1184,7 +1196,7 @@ TEST(MemDependency, MemDependencyCheckerLoopSelfDependency) { // read. MemDependencyChecker analyzer; - Stmt* stmt = For::make( + StmtPtr stmt = For::make( x, 0, 10, @@ -1203,7 +1215,7 @@ TEST(MemDependency, MemDependencyCheckerLoopSelfDependency) { // Is self dependent since all loops use a common symbolic element of A. MemDependencyChecker analyzer; - Stmt* stmt = For::make( + StmtPtr stmt = For::make( x, 0, 10, @@ -1223,7 +1235,7 @@ TEST(MemDependency, MemDependencyCheckerLoopSelfDependency) { MemDependencyChecker analyzer; - Stmt* stmt = + StmtPtr stmt = For::make(x, 0, 10, Store::make(a, {x}, Load::make(a, {x + 1}))); stmt->accept(&analyzer); @@ -1241,7 +1253,7 @@ TEST(MemDependency, MemDependencyCheckerLoopSelfDependency) { MemDependencyChecker analyzer; analyzer.allowLoopExecutionOrderAnalysis(); - Stmt* stmt = + StmtPtr stmt = For::make(x, 0, 10, Store::make(a, {x}, Load::make(a, {x + 1}))); stmt->accept(&analyzer); @@ -1258,7 +1270,7 @@ TEST(MemDependency, MemDependencyCheckerLoopSelfDependency) { MemDependencyChecker analyzer; - Stmt* stmt = + StmtPtr stmt = For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1}))); stmt->accept(&analyzer); @@ -1274,7 +1286,7 @@ TEST(MemDependency, MemDependencyCheckerLoopSelfDependency) { MemDependencyChecker analyzer; analyzer.allowLoopExecutionOrderAnalysis(); - Stmt* stmt = + StmtPtr stmt = For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1}))); stmt->accept(&analyzer); @@ -1295,7 +1307,7 @@ TEST(MemDependency, MemDependencyCheckerLoopSelfDependency) { MemDependencyChecker analyzer; analyzer.allowLoopExecutionOrderAnalysis(); - Stmt* stmt = For::make( + StmtPtr stmt = For::make( x, 3, 10, @@ -1319,7 +1331,7 @@ TEST(MemDependency, MemDependencyCheckerLoopSelfDependency) { MemDependencyChecker analyzer; analyzer.allowLoopExecutionOrderAnalysis(); - Stmt* stmt = For::make( + StmtPtr stmt = For::make( x, 3, 10, @@ -1340,7 +1352,7 @@ TEST(MemDependency, MemDependencyCheckerLoopSelfDependency) { MemDependencyChecker analyzer; - Stmt* stmt = For::make( + StmtPtr stmt = For::make( x, 3, 10, @@ -1362,7 +1374,7 @@ TEST(MemDependency, MemDependencyCheckerLoopSelfDependency) { MemDependencyChecker analyzer; analyzer.allowLoopExecutionOrderAnalysis(); - Stmt* stmt = + StmtPtr stmt = For::make(x, 3, 10, Store::make(a, {x - 2}, Load::make(a, {x - 1}))); stmt->accept(&analyzer); @@ -1383,7 +1395,7 @@ TEST(MemDependency, MemDependencyCheckerLoopSelfDependency) { // Execution order doesn't matter since the read and the write are totally // distinct. 
- Stmt* stmt = + StmtPtr stmt = For::make(x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2}))); stmt->accept(&analyzer); @@ -1405,7 +1417,7 @@ TEST(MemDependency, MemDependencyCheckerLoopSelfDependency) { // Execution order doesn't matter since the read and the write are totally // distinct. - Stmt* stmt = For::make( + StmtPtr stmt = For::make( x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 + 1}))); stmt->accept(&analyzer); @@ -1421,7 +1433,7 @@ TEST(MemDependency, MemDependencyCheckerLoopSelfDependency) { // same if the read is behind the write so long as they are distinct. MemDependencyChecker analyzer; - Stmt* stmt = For::make( + StmtPtr stmt = For::make( x, 1, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 - 1}))); stmt->accept(&analyzer); @@ -1437,7 +1449,7 @@ TEST(MemDependency, MemDependencyCheckerLoopSelfDependency) { // But not if the offset is in the stride. MemDependencyChecker analyzer; - Stmt* stmt = For::make( + StmtPtr stmt = For::make( x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 + 2}))); stmt->accept(&analyzer); @@ -1453,7 +1465,7 @@ TEST(MemDependency, MemDependencyCheckerLoopSelfDependency) { // Works with negative offsets too. MemDependencyChecker analyzer; - Stmt* stmt = For::make( + StmtPtr stmt = For::make( x, 1, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 - 2}))); stmt->accept(&analyzer); @@ -1469,7 +1481,7 @@ TEST(MemDependency, MemDependencyCheckerLoopSelfDependency) { // Detects accesses are distinct when offset is large but not a multiple // of stride. MemDependencyChecker analyzer; - Stmt* stmt = For::make( + StmtPtr stmt = For::make( x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 + 7}))); stmt->accept(&analyzer); @@ -1484,7 +1496,7 @@ TEST(MemDependency, MemDependencyCheckerLoopSelfDependency) { // Works with offsets which are multiples of the stride. MemDependencyChecker analyzer; - Stmt* stmt = For::make( + StmtPtr stmt = For::make( x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 + 4}))); stmt->accept(&analyzer); @@ -1501,7 +1513,7 @@ TEST(MemDependency, MemDependencyCheckerLoopSelfDependency) { // within. MemDependencyChecker analyzer; - Stmt* stmt = For::make( + StmtPtr stmt = For::make( x, 0, 10, Store::make(a, {x * 6}, Load::make(a, {x * 6 + 5}))); stmt->accept(&analyzer); @@ -1518,7 +1530,7 @@ TEST(MemDependency, MemDependencyCheckerLoopSelfDependency) { // multiple. MemDependencyChecker analyzer; - Stmt* stmt = + StmtPtr stmt = For::make(x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 6}))); stmt->accept(&analyzer); @@ -1534,7 +1546,7 @@ TEST(MemDependency, MemDependencyCheckerLoopSelfDependency) { // still works when the read axis is the smaller stride. MemDependencyChecker analyzer; - Stmt* stmt = + StmtPtr stmt = For::make(x, 0, 10, Store::make(a, {x * 4}, Load::make(a, {x * 2}))); stmt->accept(&analyzer); @@ -1551,7 +1563,7 @@ TEST(MemDependency, MemDependencyCheckerLoopSelfDependency) { // and there is an offset. MemDependencyChecker analyzer; - Stmt* stmt = For::make( + StmtPtr stmt = For::make( x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 6 + 1}))); stmt->accept(&analyzer); @@ -1567,7 +1579,7 @@ TEST(MemDependency, MemDependencyCheckerLoopSelfDependency) { // The smaller stride determines whether there is overlap. 
MemDependencyChecker analyzer; - Stmt* stmt = For::make( + StmtPtr stmt = For::make( x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 6 + 4}))); stmt->accept(&analyzer); @@ -1583,7 +1595,7 @@ TEST(MemDependency, MemDependencyCheckerLoopSelfDependency) { // The smaller stride determines whether there is overlap, not the larger. MemDependencyChecker analyzer; - Stmt* stmt = For::make( + StmtPtr stmt = For::make( x, 0, 10, Store::make(a, {x * 2 + 3}, Load::make(a, {x * 6}))); stmt->accept(&analyzer); @@ -1598,7 +1610,7 @@ TEST(MemDependency, MemDependencyCheckerLoopSelfDependency) { // If they have strides with no common muliple > 1, they overlap. MemDependencyChecker analyzer; - Stmt* stmt = For::make( + StmtPtr stmt = For::make( x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 3 + 1}))); stmt->accept(&analyzer); @@ -1614,7 +1626,7 @@ TEST(MemDependency, MemDependencyCheckerLoopSelfDependency) { // If the offset is greater than the size of the loop, they can't overlap. MemDependencyChecker analyzer; - Stmt* stmt = + StmtPtr stmt = For::make(x, 0, 10, Store::make(a, {x}, Load::make(a, {x + 10}))); stmt->accept(&analyzer); @@ -1629,7 +1641,7 @@ TEST(MemDependency, MemDependencyCheckerLoopSelfDependency) { // If they have different execution orders they may overlap. MemDependencyChecker analyzer; - Stmt* stmt = For::make( + StmtPtr stmt = For::make( x, 0, 10, Store::make(a, {x}, Load::make(a, {ExprHandle(9) - x}))); stmt->accept(&analyzer); @@ -1644,7 +1656,7 @@ TEST(MemDependency, MemDependencyCheckerLoopSelfDependency) { // Or they may not, depending on their start offset and strides. MemDependencyChecker analyzer; - Stmt* stmt = For::make( + StmtPtr stmt = For::make( x, 0, 10, @@ -1663,7 +1675,7 @@ TEST(MemDependency, MemDependencyCheckerLoopSelfDependency) { // If the stride is not monotonic, they overlap. MemDependencyChecker analyzer; - Stmt* stmt = + StmtPtr stmt = For::make(x, 0, 10, Store::make(a, {x / 2}, Load::make(a, {x / 2}))); stmt->accept(&analyzer); @@ -1678,7 +1690,7 @@ TEST(MemDependency, MemDependencyCheckerLoopSelfDependency) { // If the stride is not monotonic, they overlap - even with an offset. MemDependencyChecker analyzer; - Stmt* stmt = For::make( + StmtPtr stmt = For::make( x, 0, 10, Store::make(a, {x / 2}, Load::make(a, {x / 2 + 1}))); stmt->accept(&analyzer); @@ -1694,7 +1706,7 @@ TEST(MemDependency, MemDependencyCheckerLoopSelfDependency) { // Mod too... 
analysis::MemDependencyChecker analyzer; - Stmt* stmt = For::make( + StmtPtr stmt = For::make( x, 0, 10, @@ -1714,7 +1726,7 @@ TEST(MemDependency, MemDependencyCheckerLoopSelfDependency) { { MemDependencyChecker analyzer; - Stmt* stmt = + StmtPtr stmt = For::make(x, y, z, Store::make(a, {x}, Load::make(a, {x + 1}))); stmt->accept(&analyzer); @@ -1724,7 +1736,7 @@ TEST(MemDependency, MemDependencyCheckerLoopSelfDependency) { { MemDependencyChecker analyzer; analyzer.allowLoopExecutionOrderAnalysis(); - Stmt* stmt = + StmtPtr stmt = For::make(x, y, z, Store::make(a, {x}, Load::make(a, {x + 1}))); stmt->accept(&analyzer); @@ -1745,7 +1757,7 @@ TEST(MemDependency, MemDependencyCheckerLoopDistinctStrides) { using namespace analysis; MemDependencyChecker analyzer({a.node()}, {b.node()}); - Stmt* stmt = Block::make( + StmtPtr stmt = Block::make( {For::make( x, 0, 10, Store::make(b, {x * 2 + 1}, Load::make(a, {x * 2 + 1}))), For::make(x, 0, 10, Store::make(b, {x * 2}, Load::make(a, {x * 2}))) @@ -1772,7 +1784,7 @@ TEST(MemDependency, MemDependencyCheckerLoopDistinctStrides) { { analysis::MemDependencyChecker analyzer({a.node()}, {c.node()}); - Stmt* stmt = Block::make( + StmtPtr stmt = Block::make( {For::make( x, 0, @@ -1817,7 +1829,7 @@ TEST(MemDependency, MemDependencyCheckerLoopBoundsCond) { // Future usages may depend on accesses in both branches of a condition. MemDependencyChecker analyzer({a, b}, {c}); - Stmt* stmt = Block::make( + StmtPtr stmt = Block::make( {For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))), Cond::make( CompareSelect::make(y, 5, CompareSelectOperation::kLT), @@ -1854,7 +1866,7 @@ TEST(MemDependency, MemDependencyCheckerLoopBoundsCond) { // Future usages may depend on accesses in both branches of a condition. MemDependencyChecker analyzer({a, b}, {c}); - Stmt* stmt = Block::make( + StmtPtr stmt = Block::make( {For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))), Cond::make( CompareSelect::make(y, 5, CompareSelectOperation::kLT), @@ -1895,7 +1907,7 @@ TEST(MemDependency, MemDependencyCheckerLoopBoundsCond) { // Only has true branch. MemDependencyChecker analyzer({a, b}, {c}); - Stmt* stmt = Block::make( + StmtPtr stmt = Block::make( {For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))), Cond::make( CompareSelect::make(y, 5, CompareSelectOperation::kLT), @@ -1933,7 +1945,7 @@ TEST(MemDependency, MemDependencyCheckerLoopBoundsCond) { // Only has false branch. MemDependencyChecker analyzer({a, b}, {c}); - Stmt* stmt = Block::make( + StmtPtr stmt = Block::make( {For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))), Cond::make( CompareSelect::make(y, 5, CompareSelectOperation::kLT), @@ -1968,9 +1980,9 @@ TEST(MemDependency, MemDependencyCheckerLoopBoundsCond) { // Cond's Condition depends on a previous access. MemDependencyChecker analyzer({a}, {c}); - Store* initStore = Store::make(c, {x}, Load::make(a, {x})); + StorePtr initStore = Store::make(c, {x}, Load::make(a, {x})); ExprHandle conditionalLoad = Load::make(c, {0}); - Stmt* stmt = Block::make( + StmtPtr stmt = Block::make( {For::make(x, 0, 10, initStore), Cond::make( CompareSelect::make( @@ -2009,14 +2021,14 @@ TEST(MemDependency, MemDependencyCheckerIfThenElse) { // Future usages may depend on accesses in both branches of a condition. 
MemDependencyChecker analyzer({a, b}, {c}); - Store* ifStore = Store::make( + StorePtr ifStore = Store::make( c, {0}, IfThenElse::make( CompareSelect::make(y, 5, CompareSelectOperation::kLT), Add::make(Load::make(b, {0}), 1), Add::make(Load::make(b, {1}), 1))); - Stmt* stmt = Block::make( + StmtPtr stmt = Block::make( {For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))), ifStore}); @@ -2049,14 +2061,14 @@ TEST(MemDependency, MemDependencyCheckerIfThenElse) { // dependent on it. MemDependencyChecker analyzer({a, b}, {c}); - Store* ifStore = Store::make( + StorePtr ifStore = Store::make( c, {0}, IfThenElse::make( CompareSelect::make(y, 5, CompareSelectOperation::kLT), Add::make(Load::make(b, {0}), 1), 42)); - Stmt* stmt = Block::make( + StmtPtr stmt = Block::make( {For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))), ifStore}); @@ -2081,14 +2093,14 @@ TEST(MemDependency, MemDependencyCheckerIfThenElse) { // uncertain if this would be helpful. MemDependencyChecker analyzer({a, b}, {c}); - Store* ifStore = Store::make( + StorePtr ifStore = Store::make( c, {0}, IfThenElse::make( CompareSelect::make(y, 5, CompareSelectOperation::kLT), Load::make(b, {x}), Load::make(a, {x}))); - Stmt* stmt = Block::make({For::make(x, 0, 10, ifStore)}); + StmtPtr stmt = Block::make({For::make(x, 0, 10, ifStore)}); stmt->accept(&analyzer); @@ -2117,7 +2129,7 @@ TEST(MemDependency, MemDependencyCheckerCutLoop) { // Cutting a loop with single element writes. MemDependencyChecker analyzer({a}, {b}); - Stmt* stmt = Block::make( + StmtPtr stmt = Block::make( {For::make(x, 0, 10, Store::make(b, {x}, Load::make(a, {x}))), Store::make(b, {5}, 100)}); @@ -2148,12 +2160,13 @@ TEST(MemDependency, MemDependencyCheckerCutLoop) { // loop with one element writes. MemDependencyChecker analyzer({a}, {b}); - For* firstLoop = + ForPtr firstLoop = For::make(x, 0, 10, Store::make(b, {x}, Load::make(a, {x}))); - Store* secondStore = Store::make(b, {x}, Add::make(Load::make(b, {x}), 1)); - For* secondLoop = For::make(x, 4, 7, secondStore); + StorePtr secondStore = + Store::make(b, {x}, Add::make(Load::make(b, {x}), 1)); + ForPtr secondLoop = For::make(x, 4, 7, secondStore); - Stmt* stmt = Block::make( + StmtPtr stmt = Block::make( {firstLoop, secondLoop, Store::make(b, {4}, 100), @@ -2203,7 +2216,7 @@ TEST(MemDependency, MemDependencyCheckerDynamicShapes) { * } */ MemDependencyChecker analyzer({a, b}, {c}); - Stmt* stmt = Block::make({For::make( + StmtPtr stmt = Block::make({For::make( x, 0, Load::make(b, {0}), Store::make(c, {x}, Load::make(a, {x})))}); stmt->accept(&analyzer); @@ -2241,7 +2254,7 @@ TEST(MemDependency, MemDependencyCheckerDynamicShapes) { * } */ MemDependencyChecker analyzer({a, b}, {c}); - Stmt* stmt = Block::make({For::make( + StmtPtr stmt = Block::make({For::make( x, Load::make(b, {0}), Load::make(b, {1}), @@ -2286,7 +2299,7 @@ TEST(MemDependency, MemDependencyCheckerDynamicShapes) { * } */ MemDependencyChecker analyzer({a, b}, {c}); - Stmt* stmt = Block::make({For::make( + StmtPtr stmt = Block::make({For::make( x, 0, 10, Store::make(c, {x}, Load::make(a, {Load::make(b, {x})})))}); stmt->accept(&analyzer); @@ -2331,7 +2344,7 @@ TEST(MemDependency, MemDependencyCheckerDynamicShapes) { * } */ MemDependencyChecker analyzer({a, b}, {c}); - Stmt* stmt = Block::make({For::make( + StmtPtr stmt = Block::make({For::make( x, 0, 10, Store::make(c, {Load::make(b, {x})}, Load::make(a, {x})))}); stmt->accept(&analyzer); @@ -2375,7 +2388,7 @@ TEST(MemDependency, MemDependencyCheckerDynamicShapes) { * } */ 
MemDependencyChecker analyzer({a, b}, {c}); - Stmt* stmt = Block::make({For::make( + StmtPtr stmt = Block::make({For::make( x, 0, 10, Store::make(c, {Load::make(b, {Load::make(a, {x})})}, x))}); stmt->accept(&analyzer); @@ -2454,7 +2467,7 @@ TEST(MemDependency, MemDependencyCheckerMultiDim) { // Full range. MemDependencyChecker analyzer({a}, {b}); - Stmt* stmt = Block::make({For::make( + StmtPtr stmt = Block::make({For::make( x, 0, M, @@ -2500,7 +2513,7 @@ TEST(MemDependency, MemDependencyCheckerMultiDim) { // Partial range. MemDependencyChecker analyzer({a}, {b}); - Stmt* stmt = Block::make({For::make( + StmtPtr stmt = Block::make({For::make( x, 0, 5, @@ -2543,7 +2556,7 @@ TEST(MemDependency, MemDependencyCheckerMultiDim) { // Partial loops. MemDependencyChecker analyzer({a}, {b}); - Stmt* stmt = Block::make({For::make( + StmtPtr stmt = Block::make({For::make( x, 0, N, @@ -2584,7 +2597,7 @@ TEST(MemDependency, MemDependencyCheckerMultiDim) { // dimensionality. MemDependencyChecker analyzer({a, c}, {b}); - Stmt* stmt = Block::make({For::make( + StmtPtr stmt = Block::make({For::make( x, 0, M, @@ -2643,7 +2656,7 @@ TEST(MemDependency, MemDependencyCheckerMultiDim) { // Multi-dim reductions. MemDependencyChecker analyzer({a}, {b}); - Stmt* stmt = Block::make({For::make( + StmtPtr stmt = Block::make({For::make( x, 0, M, @@ -2737,8 +2750,8 @@ TEST(MemDependency, MemDependencyCheckerComputeAPI) { ASSERT_TRUE(analyzer.dependsIndirectly(d->buf(), b_buf.data())); // Second loop depends on first loop. - auto* c_loop = l.getLoopStmtsFor(c)[0]; - auto* d_loop = l.getLoopStmtsFor(d)[0]; + auto c_loop = l.getLoopStmtsFor(c)[0]; + auto d_loop = l.getLoopStmtsFor(d)[0]; ASSERT_TRUE(analyzer.dependsDirectly(d_loop, c_loop)); } @@ -2813,7 +2826,7 @@ TEST(MemDependency, MemDependencyCheckerComputeSplit) { l.splitWithTail(l.getLoopStmtsFor(c)[0], 2); MemDependencyChecker analyzer_after({a_buf.data(), b_buf.data()}, {c->buf()}); - Stmt* stmt = IRSimplifier::simplify(l.root_stmt()); + StmtPtr stmt = IRSimplifier::simplify(l.root_stmt()); stmt->accept(&analyzer_after); // Splitting should not change accesses at all. @@ -2863,7 +2876,7 @@ TEST(MemDependency, MemDependencyCheckerComputeReorder) { l.reorderAxis(loops[0], loops[1]); MemDependencyChecker analyzer_after({a_buf.data(), b_buf.data()}, {c->buf()}); - Stmt* stmt = IRSimplifier::simplify(l.root_stmt()); + StmtPtr stmt = IRSimplifier::simplify(l.root_stmt()); stmt->accept(&analyzer_after); // Reordering should not change accesses at all. @@ -2933,8 +2946,8 @@ TEST(MemDependency, MemDependencyCheckerComputeReduce) { ASSERT_TRUE(analyzer.dependsIndirectly(d->buf(), b.data())); // Second loop depends on first loop. - auto* c_loop = l.getLoopStmtsFor(c)[0]; - auto* d_loop = l.getLoopStmtsFor(d)[0]; + auto c_loop = l.getLoopStmtsFor(c)[0]; + auto d_loop = l.getLoopStmtsFor(d)[0]; ASSERT_TRUE(analyzer.dependsDirectly(d_loop, c_loop)); // Reduction depends on both inputs. 
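// --- Illustrative sketch, not part of this patch ---------------------------
// Building and folding a small expression tree directly from IR nodes with
// alloc<T>(...), the way the bounds checks later in this file accumulate a
// flattened extent: flat = flat * (end_i + 1). Header paths and the explicit
// template arguments are assumptions; run inside an active KernelScope, as
// the tests do.
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <vector>

namespace sketch {
using namespace torch::jit::tensorexpr;

ExprPtr flatExtent(const std::vector<ExprPtr>& ends) {
  ExprPtr flat = alloc<IntImm>(1);
  for (const ExprPtr& end : ends) {
    // Old form: flat = new Mul(flat, new Add(end, new IntImm(1)))
    flat = alloc<Mul>(flat, alloc<Add>(end, alloc<IntImm>(1)));
  }
  return IRSimplifier::simplify(flat);
}
} // namespace sketch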
@@ -2964,36 +2977,36 @@ TEST(MemDependency, MemDependencyCheckerComputeGEMM) { { auto const& loops = loop.getLoopStmtsFor(CT); - For* m = loops[0]; + ForPtr m = loops[0]; loop.splitWithMask(m, 4); } { auto const& loops = loop.getLoopStmtsFor(CT); - For* n = loops[2]; + ForPtr n = loops[2]; loop.splitWithMask(n, 16); } // mo, mi, no, ni, k -> // mo, no, mi, ni, k { auto const& loops = loop.getLoopStmtsFor(CT); - For* mi = loops[1]; - For* no = loops[2]; + ForPtr mi = loops[1]; + ForPtr no = loops[2]; loop.reorderAxis(mi, no); } // mo, no, mi, ni, k -> // mo, no, mi, k, ni { auto const& loops = loop.getLoopStmtsFor(CT); - For* ni = loops[3]; - For* k = loops[4]; + ForPtr ni = loops[3]; + ForPtr k = loops[4]; loop.reorderAxis(ni, k); } // mo, no, mi, k, ni -> // mo, no, k, mi, ni { auto const& loops = loop.getLoopStmtsFor(CT); - For* mi = loops[2]; - For* k = loops[3]; + ForPtr mi = loops[2]; + ForPtr k = loops[3]; loop.reorderAxis(mi, k); } { @@ -3009,7 +3022,7 @@ TEST(MemDependency, MemDependencyCheckerComputeGEMM) { // Test both unlowered and lowered form. { - Stmt* stmt = IRSimplifier::simplify(loop.root_stmt()); + StmtPtr stmt = IRSimplifier::simplify(loop.root_stmt()); stmt->accept(&analyzer_unlowered); // Outputs depend on inputs. @@ -3053,7 +3066,7 @@ TEST(MemDependency, MemDependencyCheckerComputeGEMM) { // now check lowered dependency graph. { - Stmt* stmt = IRSimplifier::simplify(loop.root_stmt()); + StmtPtr stmt = IRSimplifier::simplify(loop.root_stmt()); stmt->accept(&analyzer_lowered); // Lowering will change the dimensionality of all bounds due to index @@ -3119,18 +3132,19 @@ TEST(MemDependency, MemDependencyCheckerComputeGEMM) { history_before[i]->bounds(), history_after[i]->bounds())); } else { ASSERT_EQ(history_after[i]->bounds().size(), 1); - Expr* flat_bounds = new IntImm(1); + ExprPtr flat_bounds = alloc(1); for (auto& b : history_before[i]->bounds()) { - flat_bounds = new Mul(flat_bounds, new Add(b.end, new IntImm(1))); + flat_bounds = + alloc(flat_bounds, alloc(b.end, alloc(1))); // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) ASSERT_TRUE(exprEquals(b.start, history_after[i]->bounds()[0].start)); } flat_bounds = IRSimplifier::simplify(flat_bounds); - Expr* after_bounds = IRSimplifier::simplify( - new Add(history_after[i]->bounds()[0].end, new IntImm(1))); + ExprPtr after_bounds = IRSimplifier::simplify( + alloc(history_after[i]->bounds()[0].end, alloc(1))); ASSERT_TRUE(exprEquals(flat_bounds, after_bounds)); } } diff --git a/test/cpp/tensorexpr/test_reductions.cpp b/test/cpp/tensorexpr/test_reductions.cpp index 9e221da..bd71a4f 100644 --- a/test/cpp/tensorexpr/test_reductions.cpp +++ b/test/cpp/tensorexpr/test_reductions.cpp @@ -38,7 +38,7 @@ TEST(Reductions, ReduceSum0D_1) { Tensor* c = Reduce("sum", {{M, "m"}}, Sum(), b, {}); LoopNest loop({c}); loop.prepareForCodegen(); - Stmt* s = loop.root_stmt(); + StmtPtr s = loop.root_stmt(); s = IRSimplifier::simplify(s); SimpleIREvaluator cg(s, {b, c}); @@ -62,7 +62,7 @@ TEST(Reductions, ReduceSum0D_2) { Tensor* c = Reduce("sum", {}, Sum(), b, {}); LoopNest loop({c}); loop.prepareForCodegen(); - Stmt* s = loop.root_stmt(); + StmtPtr s = loop.root_stmt(); s = IRSimplifier::simplify(s); SimpleIREvaluator cg(s, {b, c}); @@ -86,7 +86,7 @@ TEST(Reductions, ReduceSum1D) { Tensor* c = Reduce("sum", {}, Sum(), b, {{10, "m"}}); LoopNest loop({c}); loop.prepareForCodegen(); - Stmt* s = loop.root_stmt(); + StmtPtr s = loop.root_stmt(); s = IRSimplifier::simplify(s); SimpleIREvaluator cg(s, {b, c}); @@ -117,7 +117,7 @@ 
TEST(Reductions, ReduceSum2D) { Tensor* c = Reduce("sum", {{M, "m"}}, Sum(), b, {{N, "n"}}); LoopNest loop({c}); loop.prepareForCodegen(); - Stmt* s = loop.root_stmt(); + StmtPtr s = loop.root_stmt(); s = IRSimplifier::simplify(s); SimpleIREvaluator cg(s, {b, c, n, m}); @@ -148,7 +148,7 @@ TEST(Reductions, ReduceSum3D) { Tensor* c = Reduce("sum", {{2, "l"}, {3, "n"}}, Sum(), b, {{m, "m"}}); LoopNest loop({c}); loop.prepareForCodegen(); - Stmt* s = loop.root_stmt(); + StmtPtr s = loop.root_stmt(); s = IRSimplifier::simplify(s); SimpleIREvaluator cg(s, {b, c, m}); @@ -178,7 +178,7 @@ TEST(Reductions, ReduceSum3D) { Tensor* d = Reduce("sum2", {{2, "l"}}, Sum(), b, {{3, "n"}, {m, "m"}}); LoopNest loop2({d}); loop2.prepareForCodegen(); - Stmt* s2 = loop2.root_stmt(); + StmtPtr s2 = loop2.root_stmt(); s2 = IRSimplifier::simplify(s2); SimpleIREvaluator cg2(s2, {b, d, m}); @@ -196,7 +196,7 @@ TEST(Reductions, ReduceSum3D) { Tensor* e = Reduce("sum3", {{2, "l"}}, Sum(), c_buf, {{3, "m"}}); LoopNest loop3({e}); loop3.prepareForCodegen(); - Stmt* s3 = loop3.root_stmt(); + StmtPtr s3 = loop3.root_stmt(); s3 = IRSimplifier::simplify(s3); SimpleIREvaluator cg3(s3, {c, e}); @@ -227,7 +227,7 @@ TEST(Reductions, ReduceSum10D) { {{3, "f"}, {2, "g"}, {3, "h"}, {2, "i"}, {3, "j"}}); LoopNest loop({c}); loop.prepareForCodegen(); - Stmt* s = loop.root_stmt(); + StmtPtr s = loop.root_stmt(); s = IRSimplifier::simplify(s); SimpleIREvaluator cg(s, {in_, c}); @@ -264,7 +264,7 @@ TEST(Reductions, ReduceProduct) { Tensor* c = Reduce("product", {{M, "m"}}, product, b, {{N, "n"}}); LoopNest loop({c}); loop.prepareForCodegen(); - Stmt* s = loop.root_stmt(); + StmtPtr s = loop.root_stmt(); s = IRSimplifier::simplify(s); SimpleIREvaluator cg(s, {b, c}); @@ -298,7 +298,7 @@ TEST(Reductions, ReduceMax) { LoopNest loop({dm1}); loop.prepareForCodegen(); - Stmt* s = loop.root_stmt(); + StmtPtr s = loop.root_stmt(); s = IRSimplifier::simplify(s); SimpleIREvaluator cg(s, {in_, dm1}); @@ -345,7 +345,7 @@ TEST(Reductions, ReduceMinCustomInitializer) { LoopNest loop({min}); loop.prepareForCodegen(); - Stmt* s = loop.root_stmt(); + StmtPtr s = loop.root_stmt(); s = IRSimplifier::simplify(s); SimpleIREvaluator cg(s, {in_, min, minInit}); @@ -383,7 +383,7 @@ TEST(Reductions, ReduceAnyAll) { LoopNest loop({any}); loop.prepareForCodegen(); - Stmt* s = loop.root_stmt(); + StmtPtr s = loop.root_stmt(); s = IRSimplifier::simplify(s); SimpleIREvaluator cg(s, {b, any, searchValue}); @@ -476,7 +476,7 @@ TEST(Reductions, ReduceMatmul2D) { LoopNest loop({mm}); loop.prepareForCodegen(); - Stmt* s = loop.root_stmt(); + StmtPtr s = loop.root_stmt(); s = IRSimplifier::simplify(s); SimpleIREvaluator cg(s, {tA, tB, mm}); @@ -508,7 +508,7 @@ TEST(Reductions, ReduceRfactorLike) { LoopNest loop({l1, l2}); loop.prepareForCodegen(); - Stmt* s = loop.root_stmt(); + StmtPtr s = loop.root_stmt(); s = IRSimplifier::simplify(s); SimpleIREvaluator cg(s, {in, l1, l2}); @@ -535,7 +535,7 @@ TEST(Reductions, ReduceAsProducer) { }); LoopNest loop({d}, {c, d}); loop.prepareForCodegen(); - Stmt* s = loop.root_stmt(); + StmtPtr s = loop.root_stmt(); s = IRSimplifier::simplify(s); SimpleIREvaluator cg(s, {a, b, d, m}); @@ -580,7 +580,7 @@ TEST(Reductions, ReduceAsConsumer) { Tensor* d = Reduce("sum", {{2, "l1"}}, Sum(), c, {{3, "n1"}, {m, "m1"}}); LoopNest loop({d}, {c, d}); loop.prepareForCodegen(); - Stmt* s = loop.root_stmt(); + StmtPtr s = loop.root_stmt(); s = IRSimplifier::simplify(s); SimpleIREvaluator cg(s, {a, b, d, m}); @@ -628,12 +628,12 @@ TEST(Reductions, 
SplitReduceAxis) { Tensor* tensor = Reduce("sum", {{16, "m"}}, Sum(), in, {{8, "n"}}); LoopNest l({tensor}); - std::vector loops = l.getLoopStmtsFor(tensor); + std::vector loops = l.getLoopStmtsFor(tensor); LoopNest::splitWithTail(loops[1], 2); l.prepareForCodegen(); - Stmt* s = l.root_stmt(); + StmtPtr s = l.root_stmt(); s = IRSimplifier::simplify(s); SimpleIREvaluator cg(s, {in, tensor}); @@ -658,13 +658,13 @@ TEST(Reductions, SplitNonReduceAxis) { std::vector out(16, -1.f); Tensor* tensor = Reduce("sum", {{16, "m"}}, Sum(), in, {{8, "n"}}); LoopNest l({tensor}); - std::vector loops = l.getLoopStmtsFor(tensor); + std::vector loops = l.getLoopStmtsFor(tensor); LoopNest::splitWithTail(loops[0], 2); LoopNest::splitWithTail(loops[0], 2); l.prepareForCodegen(); - Stmt* s = l.root_stmt(); + StmtPtr s = l.root_stmt(); s = IRSimplifier::simplify(s); SimpleIREvaluator cg(s, {in, tensor}); @@ -691,7 +691,7 @@ TEST(Reductions, ReorderedReductionInitializer) { LoopNest l_({tensor_}); l_.prepareForCodegen(); - Stmt* s_ = Stmt::clone(l_.root_stmt()); + StmtPtr s_ = Stmt::clone(l_.root_stmt()); s_ = IRSimplifier::simplify(s_); Tensor* tensor = Reduce("sum", {{1, "k"}, {12, "n"}}, Sum(), in, {{6, "m"}}); @@ -703,7 +703,7 @@ TEST(Reductions, ReorderedReductionInitializer) { LoopNest::reorderAxis(loops[1], loops[2]); - Stmt* s = l.root_stmt(); + StmtPtr s = l.root_stmt(); // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) s = IRSimplifier::simplify(s); @@ -743,14 +743,13 @@ TEST(Reductions, ReduceRfactor) { Tensor* c = Reduce("sum", {}, Sum(), b, {{m, "m"}, {n, "n"}}); LoopNest loop({c}); - std::vector loops = loop.getLoopStmtsFor(c); - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) - auto c_body = const_cast(loop.getAllWritesToBuf(c->buf())[1]); + std::vector loops = loop.getLoopStmtsFor(c); + auto c_body = loop.getAllWritesToBuf(c->buf())[1]; ASSERT_TRUE(loop.rfactor(c_body, loops.at(0))); auto rc = NodeFinder::find(loop.root_stmt()); ASSERT_EQ(rc.size(), 2); loop.prepareForCodegen(); - Stmt* s = loop.root_stmt(); + StmtPtr s = loop.root_stmt(); s = IRSimplifier::simplify(s); SimpleIREvaluator cg(s, {b, c, m, n}); @@ -779,14 +778,13 @@ TEST(Reductions, Reduce3DRfactorInner) { Tensor* c = Reduce("sum", {}, Sum(), b, {{m, "m"}, {n, "n"}, {k, "k"}}); LoopNest loop({c}); - std::vector loops = loop.getLoopStmtsFor(c); - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) - auto c_body = const_cast(loop.getAllWritesToBuf(c->buf())[1]); + std::vector loops = loop.getLoopStmtsFor(c); + auto c_body = loop.getAllWritesToBuf(c->buf())[1]; ASSERT_FALSE(loop.rfactor(c_body, loops.at(2))); auto rc = NodeFinder::find(loop.root_stmt()); ASSERT_EQ(rc.size(), 1); loop.prepareForCodegen(); - Stmt* s = loop.root_stmt(); + StmtPtr s = loop.root_stmt(); s = IRSimplifier::simplify(s); SimpleIREvaluator cg(s, {b, c, m, n, k}); @@ -815,14 +813,13 @@ TEST(Reductions, Reduce3DRfactorOuter) { Tensor* c = Reduce("sum", {}, Sum(), b, {{m, "m"}, {n, "n"}, {k, "k"}}); LoopNest loop({c}); - std::vector loops = loop.getLoopStmtsFor(c); - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) - auto c_body = const_cast(loop.getAllWritesToBuf(c->buf())[1]); + std::vector loops = loop.getLoopStmtsFor(c); + auto c_body = loop.getAllWritesToBuf(c->buf())[1]; ASSERT_TRUE(loop.rfactor(c_body, loops.at(0))); auto rc = NodeFinder::find(loop.root_stmt()); ASSERT_EQ(rc.size(), 2); loop.prepareForCodegen(); - Stmt* s = loop.root_stmt(); + StmtPtr s = loop.root_stmt(); s = IRSimplifier::simplify(s); SimpleIREvaluator cg(s, {b, c, 
m, n, k}); @@ -857,18 +854,16 @@ TEST(Reductions, ReduceRepeatedInternalRfactor) { IRSimplifier::simplify(refloop.root_stmt()), {in_, c}); ref_cg.call({in, ref}); - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) - Buf* tmp_buf = const_cast(c->buf()); + BufPtr tmp_buf = c->buf(); for (int idx = 0; idx < rfac_number; idx++) { - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) - auto reduce = const_cast(loop.getAllWritesToBuf(tmp_buf)[1]); + auto reduce = loop.getAllWritesToBuf(tmp_buf)[1]; ASSERT_TRUE(loop.rfactor( reduce, loop.getLoopStmtsFor(tmp_buf).at(idx), &tmp_buf)); } loop.prepareForCodegen(); - Stmt* s = loop.root_stmt(); + StmtPtr s = loop.root_stmt(); s = IRSimplifier::simplify(s); SimpleIREvaluator cg(s, {in_, c}); @@ -897,11 +892,11 @@ TEST(Reductions, ReduceSplitTail) { Tensor* c = Reduce("sum", {{M, "m"}}, Sum(), b, {{N, "n"}, {K, "k"}}); LoopNest loop({c}); - std::vector loops = loop.getLoopStmtsFor(c); + std::vector loops = loop.getLoopStmtsFor(c); LoopNest::splitWithTail(loops[i], 8); loop.prepareForCodegen(); - Stmt* s = loop.root_stmt(); + StmtPtr s = loop.root_stmt(); s = IRSimplifier::simplify(s); SimpleIREvaluator cg(s, {b, c}); @@ -929,11 +924,11 @@ TEST(Reductions, ReduceSplitNoTail) { Tensor* c = Reduce("sum", {{M, "m"}}, Sum(), b, {{N, "n"}, {K, "k"}}); LoopNest loop({c}); - std::vector loops = loop.getLoopStmtsFor(c); + std::vector loops = loop.getLoopStmtsFor(c); LoopNest::splitWithTail(loops[i], 5); loop.prepareForCodegen(); - Stmt* s = loop.root_stmt(); + StmtPtr s = loop.root_stmt(); s = IRSimplifier::simplify(s); SimpleIREvaluator cg(s, {b, c}); @@ -963,11 +958,11 @@ TEST(Reductions, ReduceOverSplitTail) { Tensor* c = Reduce("sum", {{M, "m"}}, Sum(), b, {{N, "n"}, {K, "k"}}); LoopNest loop({c}); - std::vector loops = loop.getLoopStmtsFor(c); + std::vector loops = loop.getLoopStmtsFor(c); LoopNest::splitWithTail(loops[i], 16); loop.prepareForCodegen(); - Stmt* s = loop.root_stmt(); + StmtPtr s = loop.root_stmt(); s = IRSimplifier::simplify(s); SimpleIREvaluator cg(s, {b, c}); @@ -996,11 +991,11 @@ TEST(Reductions, ReduceSplitMask) { Tensor* c = Reduce("sum", {{M, "m"}}, Sum(), b, {{N, "n"}, {K, "k"}}); LoopNest loop({c}); - std::vector loops = loop.getLoopStmtsFor(c); + std::vector loops = loop.getLoopStmtsFor(c); LoopNest::splitWithMask(loops[i], 8); loop.prepareForCodegen(); - Stmt* s = loop.root_stmt(); + StmtPtr s = loop.root_stmt(); s = IRSimplifier::simplify(s); SimpleIREvaluator cg(s, {b, c}); @@ -1028,11 +1023,11 @@ TEST(Reductions, ReduceSplitNoMask) { Tensor* c = Reduce("sum", {{M, "m"}}, Sum(), b, {{N, "n"}, {K, "k"}}); LoopNest loop({c}); - std::vector loops = loop.getLoopStmtsFor(c); + std::vector loops = loop.getLoopStmtsFor(c); LoopNest::splitWithMask(loops[i], 5); loop.prepareForCodegen(); - Stmt* s = loop.root_stmt(); + StmtPtr s = loop.root_stmt(); s = IRSimplifier::simplify(s); SimpleIREvaluator cg(s, {b, c}); @@ -1061,11 +1056,11 @@ TEST(Reductions, ReduceOverSplitMask) { Tensor* c = Reduce("sum", {{M, "m"}}, Sum(), b, {{N, "n"}, {K, "k"}}); LoopNest loop({c}); - std::vector loops = loop.getLoopStmtsFor(c); + std::vector loops = loop.getLoopStmtsFor(c); LoopNest::splitWithMask(loops[i], 16); loop.prepareForCodegen(); - Stmt* s = loop.root_stmt(); + StmtPtr s = loop.root_stmt(); s = IRSimplifier::simplify(s); SimpleIREvaluator cg(s, {b, c}); @@ -1097,11 +1092,10 @@ TEST(Reductions, ReduceSplitRfactor) { Tensor* c = Reduce("sum", {{M, "m"}}, Sum(), b, {{N, "n"}, {K, "k"}}); LoopNest loop({c}); - std::vector loops = 
loop.getLoopStmtsFor(c); + std::vector loops = loop.getLoopStmtsFor(c); LoopNest::splitWithTail(loops[2], SPLIT_FACTOR); - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) - auto c_body = const_cast(loop.getAllWritesToBuf(c->buf())[2]); + auto c_body = loop.getAllWritesToBuf(c->buf())[2]; auto all_loops = loop.getAllLoopNestsWritingToBuf(c->buf()); ASSERT_TRUE(all_loops.size() == 3 && all_loops.at(2).size() == 3); LoopNest::reorderAxis(all_loops[2][1], all_loops[2][2]); @@ -1110,7 +1104,7 @@ TEST(Reductions, ReduceSplitRfactor) { ASSERT_TRUE(loop.rfactor(c_body, all_loops[2][1])); loop.prepareForCodegen(); loop.simplify(); - Stmt* s = loop.root_stmt(); + StmtPtr s = loop.root_stmt(); SimpleIREvaluator cg(s, {b, c}); @@ -1139,22 +1133,21 @@ TEST(Reductions, ReduceOverSplitRfactor) { Tensor* c = Reduce("sum", {}, Sum(), b, {{N, "n"}, {K, "k"}}); LoopNest loop({c}); - std::vector loops = loop.getLoopStmtsFor(c); + std::vector loops = loop.getLoopStmtsFor(c); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For *i, *t; + ForPtr i, t; LoopNest::splitWithTail(loops[1], SPLIT_FACTOR, &i, &t); LoopNest::reorderAxis(loops[0], i); auto all_loops = loop.getAllLoopNestsWritingToBuf(c->buf()); ASSERT_TRUE(all_loops.size() == 3 && all_loops.at(1).size() == 3); - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) - auto c_body = const_cast(loop.getAllWritesToBuf(c->buf())[1]); + auto c_body = loop.getAllWritesToBuf(c->buf())[1]; ASSERT_TRUE(loop.rfactor(c_body, all_loops[1][0])); LoopNest::reorderAxis(all_loops[1][0], all_loops[1][2]); loop.prepareForCodegen(); loop.simplify(); - Stmt* s = loop.root_stmt(); + StmtPtr s = loop.root_stmt(); SimpleIREvaluator cg(s, {b, c}); @@ -1249,8 +1242,8 @@ TEST(Reductions, ReduceInlineConsumer) { l1.prepareForCodegen(); l2.prepareForCodegen(); - Stmt* stmt1 = IRSimplifier::simplify(l1.root_stmt()); - Stmt* stmt2 = IRSimplifier::simplify(l2.root_stmt()); + StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt()); + StmtPtr stmt2 = IRSimplifier::simplify(l2.root_stmt()); SimpleIREvaluator eval1(stmt1, {a_buf, b_buf, y}); SimpleIREvaluator eval2(stmt2, {a_buf, b_buf, y}); @@ -1307,8 +1300,8 @@ TEST(Reductions, ReduceInlineReducerInternal) { l1.prepareForCodegen(); l2.prepareForCodegen(); - Stmt* stmt1 = IRSimplifier::simplify(l1.root_stmt()); - Stmt* stmt2 = IRSimplifier::simplify(l2.root_stmt()); + StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt()); + StmtPtr stmt2 = IRSimplifier::simplify(l2.root_stmt()); SimpleIREvaluator eval1(stmt1, {a_buf, b_buf, y}); SimpleIREvaluator eval2(stmt2, {a_buf, b_buf, y}); @@ -1352,11 +1345,11 @@ TEST(Reductions, ReductionCacheAccessesOperatorAxis) { l_before.prepareForCodegen(); SimpleIREvaluator cg_before(l_before.root_stmt(), {a, b, e}); - Stmt* d_loop = l.getLoopStmtsFor(d)[0]; + StmtPtr d_loop = l.getLoopStmtsFor(d)[0]; l.cacheAccesses(d->buf(), "d_local", d_loop); l.prepareForCodegen(); - Stmt* result = IRSimplifier::simplify(l.root_stmt()); + StmtPtr result = IRSimplifier::simplify(l.root_stmt()); SimpleIREvaluator cg_after(result, {a, b, e}); std::ostringstream oss; @@ -1429,11 +1422,11 @@ TEST(Reductions, ReductionCacheAccessesOuterReduceAxis) { l_before.prepareForCodegen(); SimpleIREvaluator cg_before(l_before.root_stmt(), {a, b, e}); - Stmt* d_loop = l.getLoopStmtsFor(d)[1]; + StmtPtr d_loop = l.getLoopStmtsFor(d)[1]; l.cacheAccesses(d->buf(), "d_local", d_loop); l.prepareForCodegen(); - Stmt* result = IRSimplifier::simplify(l.root_stmt()); + StmtPtr result = IRSimplifier::simplify(l.root_stmt()); 
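The rfactor hunks above all follow the same updated idiom: loop handles come back as ForPtr, writes come back as mutable statement pointers, and the const_cast plus its NOLINT suppression disappear. A minimal sketch in the style of these tests, with the vector element type assumed (it is not spelled out in the hunks):

    std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
    auto c_body = loop.getAllWritesToBuf(c->buf())[1];   // no const_cast needed
    ASSERT_TRUE(loop.rfactor(c_body, loops.at(0)));
    StmtPtr s = IRSimplifier::simplify(loop.root_stmt());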
SimpleIREvaluator cg_after(result, {a, b, e}); std::ostringstream oss; @@ -1504,11 +1497,11 @@ TEST(Reductions, ReductionCacheAccessesInnerReduceAxis) { l_before.prepareForCodegen(); SimpleIREvaluator cg_before(l_before.root_stmt(), {a, b, e}); - Stmt* d_loop = l.getLoopStmtsFor(d)[2]; + StmtPtr d_loop = l.getLoopStmtsFor(d)[2]; l.cacheAccesses(d->buf(), "d_local", d_loop); l.prepareForCodegen(); - Stmt* result = IRSimplifier::simplify(l.root_stmt()); + StmtPtr result = IRSimplifier::simplify(l.root_stmt()); SimpleIREvaluator cg_after(result, {a, b, e}); std::ostringstream oss; @@ -1572,11 +1565,11 @@ TEST(Reductions, ReductionCacheBodyAccess) { LoopNest l({e}, {c, d, e}); - Stmt* d_loop = l.getLoopStmtsFor(d)[1]; + StmtPtr d_loop = l.getLoopStmtsFor(d)[1]; l.cacheAccesses(c->buf(), "scale_local", d_loop); l.prepareForCodegen(); - Stmt* result = IRSimplifier::simplify(l.root_stmt()); + StmtPtr result = IRSimplifier::simplify(l.root_stmt()); std::ostringstream oss; oss << *result; @@ -1615,11 +1608,11 @@ TEST(Reductions, ReductionCacheConsumerAccess) { LoopNest::splitWithMask(l.getLoopStmtsFor(e)[0], 4); - Stmt* e_loop = l.getLoopStmtsFor(e)[1]; + StmtPtr e_loop = l.getLoopStmtsFor(e)[1]; l.cacheAccesses(d->buf(), "sum_local", e_loop); l.prepareForCodegen(); - Stmt* result = IRSimplifier::simplify(l.root_stmt()); + StmtPtr result = IRSimplifier::simplify(l.root_stmt()); std::ostringstream oss; oss << *result; @@ -1655,7 +1648,7 @@ TEST(Reductions, ReductionSplitCacheConsumerAccess) { LoopNest l({e}, {c, d, e}); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* inner; + ForPtr inner; // Split outer reduction axis. LoopNest::splitWithMask(l.getLoopStmtsFor(d)[0], 4, &inner); @@ -1666,7 +1659,7 @@ TEST(Reductions, ReductionSplitCacheConsumerAccess) { l.cacheAccesses(d->buf(), "sum_local", inner); l.prepareForCodegen(); - Stmt* result = IRSimplifier::simplify(l.root_stmt()); + StmtPtr result = IRSimplifier::simplify(l.root_stmt()); // reduction changes but cache does not. std::ostringstream oss; @@ -1703,7 +1696,7 @@ TEST(Reductions, ReductionReorderCacheConsumerAccess) { LoopNest l({e}, {c, d, e}); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* inner; + ForPtr inner; // reorder outer reduction axes. auto loops = l.getLoopStmtsFor(d); @@ -1715,7 +1708,7 @@ TEST(Reductions, ReductionReorderCacheConsumerAccess) { l.cacheAccesses(d->buf(), "sum_local", inner); l.prepareForCodegen(); - Stmt* result = IRSimplifier::simplify(l.root_stmt()); + StmtPtr result = IRSimplifier::simplify(l.root_stmt()); // neither reduction body not cache changes. 
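Transforms that hand back a newly created loop keep their out-parameter shape; only the type changes from For* to ForPtr. A sketch of the consumer-cache pattern used just above (splitWithMask and cacheAccesses are taken from those hunks, the surrounding names are illustrative):

    ForPtr inner;                                            // was: For* inner;
    LoopNest::splitWithMask(l.getLoopStmtsFor(d)[0], 4, &inner);
    l.cacheAccesses(d->buf(), "sum_local", inner);
    StmtPtr result = IRSimplifier::simplify(l.root_stmt());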
std::ostringstream oss; @@ -1752,13 +1745,12 @@ TEST(Reductions, ReductionRfactorCacheTempOuter) { Tensor* c = Reduce("sum", {}, Sum(), b, {{m, "a"}, {n, "b"}, {k, "c"}}); LoopNest loop({c}); - std::vector loops = loop.getLoopStmtsFor(c); + std::vector loops = loop.getLoopStmtsFor(c); LoopNest::reorderAxis(loops.at(0), loops.at(1)); loops = loop.getLoopStmtsFor(c); - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) - auto c_body = const_cast(loop.getAllWritesToBuf(c->buf())[1]); + auto c_body = loop.getAllWritesToBuf(c->buf())[1]; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - Buf* rfac_buf; + BufPtr rfac_buf; ASSERT_TRUE(loop.rfactor(c_body, loops.at(0), &rfac_buf)); loop.distributeLoop(loops.at(0)); @@ -1770,7 +1762,7 @@ TEST(Reductions, ReductionRfactorCacheTempOuter) { LoopNest::cacheAccesses(rfac_buf, "tmp", all_loops[1][1]); loop.simplify(); loop.prepareForCodegen(); - Stmt* s = loop.root_stmt(); + StmtPtr s = loop.root_stmt(); std::ostringstream oss; oss << *s; @@ -1821,14 +1813,13 @@ TEST(Reductions, ReductionRfactorCacheTempInner) { Tensor* c = Reduce("sum", {}, Sum(), b, {{m, "a"}, {n, "b"}, {k, "c"}}); LoopNest loop({c}); - std::vector loops = loop.getLoopStmtsFor(c); - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) - auto c_body = const_cast(loop.getAllWritesToBuf(c->buf())[1]); + std::vector loops = loop.getLoopStmtsFor(c); + auto c_body = loop.getAllWritesToBuf(c->buf())[1]; LoopNest::reorderAxis(loops.at(0), loops.at(1)); loops = loop.getLoopStmtsFor(c); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - Buf* rfac_buf; + BufPtr rfac_buf; ASSERT_TRUE(loop.rfactor(c_body, loops.at(0), &rfac_buf)); loop.distributeLoop(loops.at(0)); auto all_loops = loop.getAllLoopNestsWritingToBuf(rfac_buf); @@ -1840,7 +1831,7 @@ TEST(Reductions, ReductionRfactorCacheTempInner) { LoopNest::cacheAccesses(rfac_buf, "tmp", all_loops[1][2]); loop.prepareForCodegen(); loop.simplify(); - Stmt* s = loop.root_stmt(); + StmtPtr s = loop.root_stmt(); std::ostringstream oss; oss << *s; @@ -1889,7 +1880,7 @@ TEST(Reductions, ReductionVectorize) { ASSERT_TRUE(LoopNest::vectorize(l.getLoopStmtsFor(tensor)[0])); - Stmt* s = l.root_stmt(); + StmtPtr s = l.root_stmt(); s = IRSimplifier::simplify(s); std::ostringstream oss; @@ -1950,12 +1941,11 @@ TEST(Reductions, ReductionVectorizeRfactor) { // But if we rfactor this so it's not a reduce axis we can vectorize that // loop. 
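Likewise, rfactor's optional out-parameter becomes a BufPtr rather than a raw Buf*, and the rfactored buffer can be fed straight into cacheAccesses. A hedged sketch of that sequence, assembled from the rfactor-cache hunks above:

    BufPtr rfac_buf;
    ASSERT_TRUE(loop.rfactor(c_body, loops.at(0), &rfac_buf));
    loop.distributeLoop(loops.at(0));
    auto all_loops = loop.getAllLoopNestsWritingToBuf(rfac_buf);
    LoopNest::cacheAccesses(rfac_buf, "tmp", all_loops[1][1]);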
- std::vector loops = l.getLoopStmtsFor(tensor); + std::vector loops = l.getLoopStmtsFor(tensor); LoopNest::reorderAxis(loops[0], loops[1]); loops = l.getLoopStmtsFor(tensor); - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) - auto tensor_body = const_cast(l.getAllWritesToBuf(tensor->buf())[1]); - Buf* rfac_buf = nullptr; + auto tensor_body = l.getAllWritesToBuf(tensor->buf())[1]; + BufPtr rfac_buf = nullptr; ASSERT_TRUE(LoopNest::rfactor(tensor_body, loops.at(0), &rfac_buf)); LoopNest::distributeLoop(loops.at(0)); @@ -1964,7 +1954,7 @@ TEST(Reductions, ReductionVectorizeRfactor) { ASSERT_TRUE(LoopNest::vectorize(rfac_loops[1][0])); l.simplify(); - Stmt* s = l.root_stmt(); + StmtPtr s = l.root_stmt(); std::ostringstream oss; oss << *s; @@ -2007,7 +1997,7 @@ TEST(Reductions, InitFunction) { {{M, "m"}}); LoopNest nest({C}); nest.prepareForCodegen(); - Stmt* s = IRSimplifier::simplify(nest.root_stmt()); + StmtPtr s = IRSimplifier::simplify(nest.root_stmt()); std::ostringstream oss; oss << *s << "\n"; const std::string& expected_ir = diff --git a/test/cpp/tensorexpr/test_registerizer.cpp b/test/cpp/tensorexpr/test_registerizer.cpp index 2c771c7..a0ac095 100644 --- a/test/cpp/tensorexpr/test_registerizer.cpp +++ b/test/cpp/tensorexpr/test_registerizer.cpp @@ -16,7 +16,7 @@ TEST(Registerizer, RegisterizerSimple) { KernelScope kernel_scope; BufHandle a("A", {1}, kInt); VarHandle x("x", kInt); - Stmt* stmt = Block::make( + StmtPtr stmt = Block::make( {Store::make(a, {0}, 0), For::make( x, @@ -61,7 +61,7 @@ TEST(Registerizer, RegisterizerLoop) { KernelScope kernel_scope; BufHandle a("A", {10}, kInt); VarHandle x("x", kInt); - Stmt* stmt = Block::make( + StmtPtr stmt = Block::make( {Store::make(a, {0}, 0), For::make( x, @@ -108,7 +108,7 @@ TEST(Registerizer, RegisterizerLoopFixedLoad) { KernelScope kernel_scope; BufHandle a("A", {1}, kInt); VarHandle x("x", kInt); - Stmt* stmt = Block::make( + StmtPtr stmt = Block::make( {Store::make(a, {0}, 0), For::make( x, @@ -155,7 +155,7 @@ TEST(Registerizer, RegisterizerLoopInternal) { KernelScope kernel_scope; BufHandle a("A", {1}, kInt); VarHandle x("x", kInt); - Stmt* stmt = Block::make({For::make( + StmtPtr stmt = Block::make({For::make( x, 0, 10, @@ -209,7 +209,7 @@ TEST(Registerizer, RegisterizerLoopInternalLoadOverlap) { VarHandle x("x", kInt); VarHandle y("y", kInt); VarHandle z("z", kInt); - Stmt* stmt = Block::make({For::make( + StmtPtr stmt = Block::make({For::make( x, 0, 10, @@ -238,7 +238,7 @@ TEST(Registerizer, RegisterizerLoopInternalRepeated) { KernelScope kernel_scope; BufHandle a("A", {1}, kInt); VarHandle x("x", kInt); - Stmt* stmt = Block::make( + StmtPtr stmt = Block::make( {For::make( x, 0, @@ -310,7 +310,7 @@ TEST(Registerizer, RegisterizerLoopInternalRepeatedOverlapLoopVar) { KernelScope kernel_scope; BufHandle a("A", {1}, kInt); VarHandle x("x", kInt); - Stmt* stmt = Block::make( + StmtPtr stmt = Block::make( {For::make( x, 0, @@ -357,7 +357,7 @@ TEST(Registerizer, RegisterizerLoopInternalRepeatedOverlapOther) { BufHandle a("A", {1}, kInt); VarHandle x("x", kInt); VarHandle y("y", kInt); - Stmt* stmt = Block::make( + StmtPtr stmt = Block::make( {For::make( x, 0, @@ -403,7 +403,7 @@ TEST(Registerizer, RegisterizerMultiVar) { KernelScope kernel_scope; BufHandle a("A", {2}, kInt); VarHandle x("x", kInt); - Stmt* stmt = Block::make({ + StmtPtr stmt = Block::make({ Store::make(a, {0}, 0), Store::make(a, {1}, 0), For::make( @@ -461,7 +461,7 @@ TEST(Registerizer, RegisterizerVariableLoad) { BufHandle b("B", {10}, kInt); VarHandle 
x("x", kInt); VarHandle x2("x", kInt); - Stmt* stmt = Block::make( + StmtPtr stmt = Block::make( {Store::make(a, {0}, 0), For::make(x, 0, 10, Store::make(b, {x}, x)), For::make( @@ -517,7 +517,7 @@ TEST(Registerizer, RegisterizerSymbolicIndices) { VarHandle N("N", kInt); BufHandle a("A", {N}, kInt); VarHandle x("x", kInt); - Stmt* stmt = Block::make( + StmtPtr stmt = Block::make( {Store::make(a, {i}, 0), For::make( x, @@ -563,7 +563,7 @@ TEST(Registerizer, RegisterizerMultiLoop) { BufHandle a("A", {1}, kInt); VarHandle x("x", kInt); VarHandle y("y", kInt); - Stmt* stmt = Block::make( + StmtPtr stmt = Block::make( {Store::make(a, {0}, 0), For::make( x, @@ -619,7 +619,7 @@ TEST(Registerizer, RegisterizerRepeated) { KernelScope kernel_scope; BufHandle a("A", {2}, kInt); VarHandle x("x", kInt); - Stmt* stmt = Block::make({ + StmtPtr stmt = Block::make({ Store::make(a, {0}, 0), Store::make(a, {1}, 0), For::make( @@ -676,7 +676,7 @@ TEST(Registerizer, RegisterizerNoLoads) { KernelScope kernel_scope; BufHandle a("A", {1}, kInt); VarHandle x("x", kInt); - Stmt* stmt = Block::make( + StmtPtr stmt = Block::make( {Store::make(a, {0}, 0), For::make( x, 0, 10, Block::make({Store::make(a, {0}, Add::make(x, 1))}))}); @@ -718,7 +718,7 @@ TEST(Registerizer, RegisterizerNoRepeatedStores) { BufHandle a("A", {1}, kInt); BufHandle b("B", {10}, kInt); VarHandle x("x", kInt); - Stmt* stmt = Block::make( + StmtPtr stmt = Block::make( {Store::make(a, {0}, 0), For::make( x, @@ -766,7 +766,7 @@ TEST(Registerizer, RegisterizerMultiVarOverlap) { KernelScope kernel_scope; BufHandle a("A", {2}, kInt); VarHandle x("x", kInt); - Stmt* stmt = Block::make({ + StmtPtr stmt = Block::make({ Store::make(a, {0}, 0), Store::make(a, {1}, 0), For::make( @@ -800,7 +800,7 @@ TEST(Registerizer, RegisterizerAllocs) { BufHandle b("B", {Load::make(c, {0})}, kInt); - Stmt* stmt = Block::make( + StmtPtr stmt = Block::make( {Allocate::make(b), Store::make(a, {0}, Load::make(c, {0})), Store::make(b, {0}, 0), @@ -863,7 +863,7 @@ TEST(Registerizer, RegisterizerNoInitializer) { KernelScope kernel_scope; BufHandle a("A", {1}, kInt); VarHandle x("x", kInt); - Stmt* stmt = Block::make({For::make( + StmtPtr stmt = Block::make({For::make( x, 0, 10, @@ -903,7 +903,7 @@ TEST(Registerizer, RegisterizerNoInitializerLoopVar) { KernelScope kernel_scope; BufHandle a("A", {1}, kInt); VarHandle x("x", kInt); - Stmt* stmt = Block::make({For::make( + StmtPtr stmt = Block::make({For::make( x, 0, 10, @@ -933,7 +933,7 @@ TEST(Registerizer, RegisterizerLoadThenStore) { BufHandle a("A", {1}, kInt); BufHandle b("B", {1}, kInt); VarHandle x("x", kInt); - Stmt* stmt = Block::make({For::make( + StmtPtr stmt = Block::make({For::make( x, 0, 10, @@ -985,7 +985,7 @@ TEST(Registerizer, RegisterizerParallelized) { VarHandle x("x", kInt); LoopOptions loopOpts; loopOpts.set_gpu_block_index(0); - Stmt* stmt = Block::make( + StmtPtr stmt = Block::make( {Store::make(a, {0}, 0), For::make( x, @@ -1015,7 +1015,7 @@ TEST(Registerizer, RegisterizerConditionAfter) { BufHandle c("C", {5}, kInt); VarHandle x("x", kInt); - Stmt* stmt = Block::make( + StmtPtr stmt = Block::make( {Store::make(a, {x}, Load::make(b, {x})), Store::make(c, {x}, Load::make(a, {x})), Cond::make( @@ -1065,7 +1065,7 @@ TEST(Registerizer, RegisterizerConditionBefore) { BufHandle c("C", {5}, kInt); VarHandle x("x", kInt); - Stmt* stmt = Block::make( + StmtPtr stmt = Block::make( {Cond::make( CompareSelect::make(x, 5, CompareSelectOperation::kLT), Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), @@ -1117,7 
+1117,7 @@ TEST(Registerizer, RegisterizerConditionInside) { BufHandle c("C", {5}, kInt); VarHandle x("x", kInt); - Stmt* stmt = Block::make( + StmtPtr stmt = Block::make( {Store::make(a, {x}, Load::make(b, {x})), Store::make(c, {x}, Load::make(a, {x})), Cond::make( @@ -1178,7 +1178,7 @@ TEST(Registerizer, RegisterizerConditionInsideOverlap1) { VarHandle x("x", kInt); VarHandle y("y", kInt); - Stmt* stmt = Block::make( + StmtPtr stmt = Block::make( // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) {Store::make(a, {x}, Load::make(b, {x})), Store::make(c, {x}, Load::make(a, {x})), @@ -1238,7 +1238,7 @@ TEST(Registerizer, RegisterizerConditionInsideOverlap2) { VarHandle x("x", kInt); VarHandle y("y", kInt); - Stmt* stmt = Block::make( + StmtPtr stmt = Block::make( // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) {Store::make(a, {x}, Load::make(b, {x})), Store::make(a, {x}, Load::make(b, {x + 1})), @@ -1323,7 +1323,7 @@ TEST(Registerizer, RegisterizerConditionHidden) { BufHandle c("C", {5}, kInt); VarHandle x("x", kInt); - Stmt* stmt = Block::make( + StmtPtr stmt = Block::make( {Cond::make( CompareSelect::make(x, 5, CompareSelectOperation::kLT), Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), @@ -1365,7 +1365,7 @@ TEST(Registerizer, RegisterizerConditionUnhidden) { BufHandle c("C", {5}, kInt); VarHandle x("x", kInt); - Stmt* stmt = Block::make( + StmtPtr stmt = Block::make( {Cond::make( CompareSelect::make(x, 5, CompareSelectOperation::kLT), Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), @@ -1426,7 +1426,7 @@ TEST(Registerizer, RegisterizerCondCondition) { BufHandle c("C", {5}, kInt); VarHandle x("x", kInt); - Stmt* stmt = Block::make( + StmtPtr stmt = Block::make( {Store::make(a, {x}, Load::make(b, {x})), Store::make(c, {x}, Load::make(a, {x})), Cond::make( @@ -1477,7 +1477,7 @@ TEST(Registerizer, RegisterizerCondConditionUnhidden) { BufHandle c("C", {5}, kInt); VarHandle x("x", kInt); - Stmt* stmt = Block::make({Cond::make( + StmtPtr stmt = Block::make({Cond::make( CompareSelect::make(Load::make(a, {x}), 5, CompareSelectOperation::kLT), Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), Store::make(a, {x}, Add::make(Load::make(a, {x}), 10)))}); @@ -1527,7 +1527,7 @@ TEST(Registerizer, RegisterizerIfThenElseHidden) { VarHandle x("x", kInt); VarHandle y("y", kInt); - Stmt* stmt = Block::make( + StmtPtr stmt = Block::make( {Store::make( b, {y}, @@ -1569,7 +1569,7 @@ TEST(Registerizer, RegisterizerIfThenElseUnhidden) { VarHandle x("x", kInt); VarHandle y("y", kInt); - Stmt* stmt = Block::make({ + StmtPtr stmt = Block::make({ Store::make(a, {x}, 0), Store::make( b, @@ -1624,7 +1624,7 @@ TEST(Registerizer, RegisterizerIfThenElseNested) { BufHandle d("D", {5}, kInt); VarHandle x("x", kInt); - Stmt* stmt = Block::make({Store::make( + StmtPtr stmt = Block::make({Store::make( a, {x}, IfThenElse::make( @@ -1667,7 +1667,7 @@ TEST(Registerizer, RegisterizerIfThenElseInternal) { BufHandle b("B", {5}, kFloat); VarHandle x("x", kInt); - Stmt* stmt = Block::make({Store::make( + StmtPtr stmt = Block::make({Store::make( a, {x}, IfThenElse::make( @@ -1746,7 +1746,7 @@ TEST(Registerizer, RegisterizerIfThenElseCondition) { BufHandle c("C", {5}, kInt); VarHandle x("x", kInt); - Stmt* stmt = Block::make( + StmtPtr stmt = Block::make( {Store::make(a, {x}, Load::make(a, {x})), Store::make( a, @@ -1792,7 +1792,7 @@ TEST(Registerizer, RegisterizerIfThenElseConditionUnhidden) { BufHandle c("C", {5}, kInt); VarHandle x("x", kInt); - Stmt* stmt = Block::make({Store::make( + StmtPtr 
stmt = Block::make({Store::make( b, {x}, IfThenElse::make( @@ -1829,7 +1829,7 @@ TEST(Registerizer, RegisterizerConditionBranchOnly) { KernelScope kernel_scope; BufHandle a("A", {5}, kInt); VarHandle x("x", kInt); - Stmt* stmt = Block::make({For::make( + StmtPtr stmt = Block::make({For::make( x, 0, 10, @@ -1883,7 +1883,7 @@ TEST(Registerizer, RegisterizerCondIfThenElse) { BufHandle c("C", {5}, kInt); VarHandle x("x", kInt); - Stmt* stmt = Block::make({Cond::make( + StmtPtr stmt = Block::make({Cond::make( CompareSelect::make( IfThenElse::make( CompareSelect::make( @@ -1933,7 +1933,7 @@ TEST(Registerizer, RegisterizerIfThenElseLoop) { VarHandle x("x", kInt); VarHandle y("y", kInt); - Stmt* stmt = For::make( + StmtPtr stmt = For::make( y, 0, 10, @@ -1983,7 +1983,7 @@ TEST(Registerizer, RegisterizerIfThenElseLoopCut) { VarHandle x("x", kInt); VarHandle y("y", kInt); - Stmt* stmt = Block::make({For::make( + StmtPtr stmt = Block::make({For::make( y, 0, 10, @@ -2019,7 +2019,7 @@ TEST(Registerizer, RegisterizerPartialAfter) { KernelScope kernel_scope; BufHandle a("A", {1}, kInt); VarHandle x("x", kInt); - Stmt* stmt = Block::make( + StmtPtr stmt = Block::make( {Store::make(a, {0}, 0), For::make( x, @@ -2076,7 +2076,7 @@ TEST(Registerizer, RegisterizerPartialBefore) { KernelScope kernel_scope; BufHandle a("A", {1}, kInt); VarHandle x("x", kInt); - Stmt* stmt = Block::make( + StmtPtr stmt = Block::make( {For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1}))), Store::make(a, {0}, 0), For::make( @@ -2135,7 +2135,7 @@ TEST(Registerizer, RegisterizerPartialInside) { VarHandle x1("x1", kInt); VarHandle x2("x2", kInt); VarHandle x3("x3", kInt); - Stmt* stmt = Block::make( + StmtPtr stmt = Block::make( {Store::make(a, {0}, 2), For::make( x1, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), x1))), @@ -2203,7 +2203,7 @@ TEST(Registerizer, RegisterizerPartialCondition) { KernelScope kernel_scope; BufHandle a("A", {1}, kInt); VarHandle x("x", kInt); - Stmt* stmt = Block::make( + StmtPtr stmt = Block::make( {Store::make(a, {0}, 2), For::make( x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), x))), @@ -2273,7 +2273,7 @@ TEST(Registerizer, RegisterizerPartialConditionInternalCut) { KernelScope kernel_scope; BufHandle a("A", {1}, kInt); VarHandle x("x", kInt); - Stmt* stmt = Block::make( + StmtPtr stmt = Block::make( {Store::make(a, {0}, 1), Store::make(a, {0}, 3), Cond::make( @@ -2336,7 +2336,7 @@ TEST(Registerizer, RegisterizerPartialConditionInternalStart) { KernelScope kernel_scope; BufHandle a("A", {1}, kInt); VarHandle x("x", kInt); - Stmt* stmt = Block::make( + StmtPtr stmt = Block::make( {Store::make(a, {0}, 1), Store::make(a, {0}, 3), Cond::make( @@ -2400,7 +2400,7 @@ TEST(Registerizer, RegisterizerPartialOverlapsTwo) { KernelScope kernel_scope; BufHandle a("A", {1}, kInt); VarHandle x("x", kInt); - Stmt* stmt = Block::make( + StmtPtr stmt = Block::make( {Store::make(a, {1}, Load::make(a, {0})), Store::make(a, {0}, Load::make(a, {1})), Store::make(a, {0}, Load::make(a, {1})), @@ -2471,7 +2471,7 @@ TEST(Registerizer, RegisterizerNestedBlocks) { KernelScope kernel_scope; BufHandle a("A", {1}, kInt); VarHandle x("x", kInt); - Stmt* stmt = Block::make( + StmtPtr stmt = Block::make( // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) {Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), Block::make({Store::make(a, {0}, Add::make(Load::make(a, {0}), 2))}), @@ -2525,7 +2525,7 @@ TEST(Registerizer, RegisterizerNestedConditions) { KernelScope kernel_scope; BufHandle a("A", {1}, 
kInt); VarHandle x("x", kInt); - Stmt* stmt = Block::make({Cond::make( + StmtPtr stmt = Block::make({Cond::make( CompareSelect::make(x, 5, CompareSelectOperation::kLT), Block::make( {Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), @@ -2581,7 +2581,7 @@ TEST(Registerizer, RegisterizerNestedConditionsUnhidden) { KernelScope kernel_scope; BufHandle a("A", {1}, kInt); VarHandle x("x", kInt); - Stmt* stmt = Block::make( + StmtPtr stmt = Block::make( {Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), Cond::make( CompareSelect::make(x, 5, CompareSelectOperation::kLT), @@ -2637,7 +2637,7 @@ TEST(Registerizer, RegisterizerNestedConditionsHiddenFirst) { KernelScope kernel_scope; BufHandle a("A", {1}, kInt); VarHandle x("x", kInt); - Stmt* stmt = Block::make( + StmtPtr stmt = Block::make( {Cond::make( CompareSelect::make(x, 2, CompareSelectOperation::kEQ), Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), @@ -2680,7 +2680,7 @@ TEST(Registerizer, RegisterizerNestedConditionsHiddenSecond) { KernelScope kernel_scope; BufHandle a("A", {1}, kInt); VarHandle x("x", kInt); - Stmt* stmt = Block::make( + StmtPtr stmt = Block::make( {Cond::make( CompareSelect::make(x, 5, CompareSelectOperation::kLT), Block::make({Cond::make( @@ -2725,7 +2725,7 @@ TEST(Registerizer, RegisterizerNestedConditionsCut) { KernelScope kernel_scope; BufHandle a("A", {1}, kInt); VarHandle x("x", kInt); - Stmt* stmt = Block::make( + StmtPtr stmt = Block::make( {Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), Cond::make( CompareSelect::make(x, 5, CompareSelectOperation::kLT), @@ -2765,7 +2765,7 @@ TEST(Registerizer, RegisterizerNestedConditionLoopHidden) { BufHandle a("A", {10}, kInt); BufHandle b("B", {10}, kInt); VarHandle x("x", kInt); - Stmt* stmt = Block::make( + StmtPtr stmt = Block::make( {Cond::make( CompareSelect::make(x, 2, CompareSelectOperation::kEQ), Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), @@ -2812,7 +2812,7 @@ TEST(Registerizer, RegisterizerNestedConditionThreeDeep) { BufHandle a("A", {10}, kInt); BufHandle b("B", {10}, kInt); VarHandle x("x", kInt); - Stmt* stmt = Block::make( + StmtPtr stmt = Block::make( {Store::make(a, {4}, 0), Cond::make( CompareSelect::make(x, 2, CompareSelectOperation::kGT), @@ -2912,7 +2912,7 @@ TEST(Registerizer, RegisterizerNestedLoopSimple) { BufHandle a("A", {1}, kInt); VarHandle x("x", kInt); VarHandle y("y", kInt); - Stmt* stmt = Block::make({For::make( + StmtPtr stmt = Block::make({For::make( y, 0, 10, @@ -2968,7 +2968,7 @@ TEST(Registerizer, RegisterizerHiddenAccessYes) { BufHandle b("B", {10}, kInt); VarHandle x("x", kInt); VarHandle y("y", kInt); - Stmt* stmt = Block::make({Cond::make( + StmtPtr stmt = Block::make({Cond::make( CompareSelect::make(x, 2, CompareSelectOperation::kEQ), Block::make( {Store::make(a, {0}, 0), @@ -3051,7 +3051,7 @@ TEST(Registerizer, RegisterizerHiddenAccessNo) { BufHandle b("B", {10}, kInt); VarHandle x("x", kInt); VarHandle y("y", kInt); - Stmt* stmt = Block::make({Cond::make( + StmtPtr stmt = Block::make({Cond::make( CompareSelect::make(x, 2, CompareSelectOperation::kEQ), Block::make({For::make( x, @@ -3131,7 +3131,7 @@ TEST(Registerizer, RegisterizerHiddenAccessMultiLoop) { BufHandle b("B", {10}, kInt); VarHandle x("x", kInt); VarHandle y("y", kInt); - Stmt* stmt = Block::make({Cond::make( + StmtPtr stmt = Block::make({Cond::make( CompareSelect::make(x, 2, CompareSelectOperation::kEQ), Block::make( {Store::make(a, {0}, 0), @@ -3211,7 +3211,7 @@ TEST(Registerizer, RegisterizerTwoConditionalLoops) { KernelScope 
kernel_scope; BufHandle a("A", {1}, kInt); VarHandle x("x", kInt); - Stmt* stmt = Block::make( + StmtPtr stmt = Block::make( {Cond::make( CompareSelect::make(x, 5, CompareSelectOperation::kLT), For::make( @@ -3283,7 +3283,7 @@ TEST(Registerizer, RegisterizerTwoConditionalLoopsCut) { KernelScope kernel_scope; BufHandle a("A", {1}, kInt); VarHandle x("x", kInt); - Stmt* stmt = Block::make( + StmtPtr stmt = Block::make( {Cond::make( CompareSelect::make(x, 5, CompareSelectOperation::kLT), For::make( @@ -3366,7 +3366,7 @@ TEST(Registerizer, RegisterizerLoopLetVar) { BufHandle a("A", {10}, kInt); VarHandle x("x", kInt); VarHandle y("y", kInt); - Stmt* stmt = Block::make({For::make( + StmtPtr stmt = Block::make({For::make( x, 0, 10, @@ -3400,7 +3400,7 @@ TEST(Registerizer, RegisterizerLoopLetVarOuter) { BufHandle a("A", {10}, kInt); VarHandle x("x", kInt); VarHandle y("y", kInt); - Stmt* stmt = Block::make( + StmtPtr stmt = Block::make( {Let::make(y, 30), For::make( x, @@ -3447,7 +3447,7 @@ TEST(Registerizer, RegisterizerMultiDim) { KernelScope kernel_scope; BufHandle a("A", {3, 4, 5}, kInt); VarHandle x("x", kInt); - Stmt* stmt = Block::make( + StmtPtr stmt = Block::make( {Store::make(a, {0, 1, 2}, 0), For::make( x, @@ -3493,7 +3493,7 @@ TEST(Registerizer, RegisterizerMultiDimPartial) { KernelScope kernel_scope; BufHandle a("A", {3, 4, 5}, kInt); VarHandle x("x", kInt); - Stmt* stmt = Block::make( + StmtPtr stmt = Block::make( {Store::make(a, {0, 1, 2}, 0), For::make( x, @@ -3542,7 +3542,7 @@ TEST(Registerizer, RegisterizerMultiDimOverlap) { BufHandle a("A", {3, 4, 5}, kInt); VarHandle x("x", kInt); VarHandle y("y", kInt); - Stmt* stmt = Block::make( + StmtPtr stmt = Block::make( {Store::make(a, {0, 1, 2}, 0), For::make( x, @@ -3577,7 +3577,7 @@ TEST(Registerizer, RegisterizerMultiDimPartialOverlap) { BufHandle a("A", {3, 4, 5}, kInt); VarHandle x("x", kInt); VarHandle y("y", kInt); - Stmt* stmt = Block::make( + StmtPtr stmt = Block::make( {Store::make(a, {0, 1, 2}, 0), For::make( x, @@ -3626,7 +3626,7 @@ TEST(Registerizer, RegisterizerMultiDim3DReduction1) { VarHandle x("x", kInt); VarHandle y("y", kInt); VarHandle z("z", kInt); - Stmt* stmt = For::make( + StmtPtr stmt = For::make( x, 0, 10, @@ -3698,7 +3698,7 @@ TEST(Registerizer, RegisterizerMultiDim3DReduction2) { VarHandle x("x", kInt); VarHandle y("y", kInt); VarHandle z("z", kInt); - Stmt* stmt = For::make( + StmtPtr stmt = For::make( x, 0, 10, diff --git a/test/cpp/tensorexpr/test_simplify.cpp b/test/cpp/tensorexpr/test_simplify.cpp index 83da919..a08d4ca 100644 --- a/test/cpp/tensorexpr/test_simplify.cpp +++ b/test/cpp/tensorexpr/test_simplify.cpp @@ -149,9 +149,9 @@ TEST(Simplify, ConstantFoldWithVar) { ExprHandle body = x * (ExprHandle(2) + ExprHandle(4)); ExprHandle newF = IRSimplifier::simplify(body); - Mul* root = newF.AsNode(); + MulPtr root = newF.AsNode(); ASSERT_NE(root, nullptr); - ASSERT_NE(dynamic_cast(root->lhs()), nullptr); + ASSERT_NE(to(root->lhs()), nullptr); SimpleIRExprEval eval(newF); eval.bindVar(x, ExprHandle(3)); @@ -163,9 +163,9 @@ TEST(Simplify, ConstantFoldWithVar) { ExprHandle body = x * (ExprHandle(2.f) + ExprHandle(4.f)); ExprHandle newF = IRSimplifier::simplify(body); - Mul* root = newF.AsNode(); + MulPtr root = newF.AsNode(); ASSERT_NE(root, nullptr); - ASSERT_NE(dynamic_cast(root->rhs()), nullptr); + ASSERT_NE(to(root->rhs()), nullptr); SimpleIRExprEval eval(newF); eval.bindVar(x, ExprHandle(3.f)); @@ -274,7 +274,7 @@ TEST(Simplify, ConditionalSelectFoldWithVar) { ExprHandle f = x < 4.f; ExprHandle newF = 
IRSimplifier::simplify(f); - const IntImm* folded = newF.AsNode(); + IntImmPtr folded = newF.AsNode(); ASSERT_EQ(folded, nullptr); { @@ -296,10 +296,10 @@ TEST(Simplify, UnFoldableExpr) { ExprHandle body = (ExprHandle(3) * x) + (ExprHandle(5) * y); ExprHandle newF = IRSimplifier::simplify(body); - Add* root = newF.AsNode(); + AddPtr root = newF.AsNode(); ASSERT_NE(root, nullptr); - ASSERT_EQ(dynamic_cast(root->lhs()), nullptr); - ASSERT_EQ(dynamic_cast(root->rhs()), nullptr); + ASSERT_EQ(to(root->lhs()), nullptr); + ASSERT_EQ(to(root->rhs()), nullptr); SimpleIRExprEval eval(newF); eval.bindVar(x, ExprHandle(3.f)); @@ -334,7 +334,7 @@ TEST(Simplify, HashEquivalence) { VarHandle y("y", kFloat); ExprHandle f = (x * y) + (x * y); - Add* root = f.AsNode(); + AddPtr root = f.AsNode(); ASSERT_NE(root, nullptr); HashProvider hasher; @@ -370,7 +370,7 @@ TEST(Simplify, HashEquivalenceRand) { ExprHandle f = Intrinsics::make(kRand, kFloat) + Intrinsics::make(kRand, kInt); - Add* root = f.AsNode(); + AddPtr root = f.AsNode(); ASSERT_NE(root, nullptr); HashProvider hasher; @@ -415,18 +415,18 @@ TEST(Simplify, HashDifferenceTypes) { KernelScope kernel_scope; HashProvider hasher; - std::vector immediates; + std::vector immediates; - immediates.push_back(new DoubleImm(1)); - immediates.push_back(new FloatImm(1)); - immediates.push_back(new HalfImm(1)); + immediates.push_back(alloc(1)); + immediates.push_back(alloc(1)); + immediates.push_back(alloc(1)); // NOLINTNEXTLINE(modernize-use-bool-literals) - immediates.push_back(new BoolImm(1)); - immediates.push_back(new CharImm(1)); - immediates.push_back(new ByteImm(1)); - immediates.push_back(new ShortImm(1)); - immediates.push_back(new IntImm(1)); - immediates.push_back(new LongImm(1)); + immediates.push_back(alloc(1)); + immediates.push_back(alloc(1)); + immediates.push_back(alloc(1)); + immediates.push_back(alloc(1)); + immediates.push_back(alloc(1)); + immediates.push_back(alloc(1)); // Immediates of different types are not equal. 
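The immediates test shows the allocation side of the change: each new XImm(1) becomes a call to the alloc helper, which presumably takes the node type as a template argument and forwards the constructor arguments. The explicit template arguments and the element type below are assumptions, since they are not spelled out in the hunks:

    std::vector<ExprPtr> immediates;
    immediates.push_back(alloc<DoubleImm>(1));   // was: new DoubleImm(1)
    immediates.push_back(alloc<IntImm>(1));      // was: new IntImm(1)
    // Downcasts go through to<T> instead of dynamic_cast<T*>:
    ASSERT_NE(to<DoubleImm>(immediates[0]), nullptr);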
for (unsigned int i = 0; i < immediates.size(); ++i) { @@ -546,12 +546,12 @@ TEST(Simplify, SimplifyAdd) { ExprHandle body = (ExprHandle(2) + x) + ExprHandle(4); ExprHandle simplified = IRSimplifier::simplify(body); - Add* root = simplified.AsNode(); + AddPtr root = simplified.AsNode(); ASSERT_NE(root, nullptr); - Var* lhs = dynamic_cast(root->lhs()); + VarPtr lhs = to(root->lhs()); ASSERT_NE(lhs, nullptr); ASSERT_EQ(lhs->name_hint(), "x"); - const IntImm* rhs = dynamic_cast(root->rhs()); + IntImmPtr rhs = to(root->rhs()); ASSERT_NE(rhs, nullptr); ASSERT_EQ(rhs->value(), 6.f); } @@ -563,12 +563,12 @@ TEST(Simplify, SimplifySub) { ExprHandle body = (ExprHandle(2) - x) - ExprHandle(4); ExprHandle simplified = IRSimplifier::simplify(body); - Sub* root = simplified.AsNode(); + SubPtr root = simplified.AsNode(); ASSERT_NE(root, nullptr); - const IntImm* lhs = dynamic_cast(root->lhs()); + IntImmPtr lhs = to(root->lhs()); ASSERT_NE(lhs, nullptr); ASSERT_EQ(lhs->value(), -2.f); - Var* rhs = dynamic_cast(root->rhs()); + VarPtr rhs = to(root->rhs()); ASSERT_NE(rhs, nullptr); ASSERT_EQ(rhs->name_hint(), "x"); } @@ -594,12 +594,12 @@ TEST(Simplify, SimplifyMultiTerm) { (ExprHandle(2) * ((ExprHandle(3) * x)) - (x * ExprHandle(4))); ExprHandle simplified = IRSimplifier::simplify(body); - Mul* root = simplified.AsNode(); + MulPtr root = simplified.AsNode(); ASSERT_NE(root, nullptr); - const IntImm* lhs = dynamic_cast(root->lhs()); + IntImmPtr lhs = to(root->lhs()); ASSERT_NE(lhs, nullptr); ASSERT_EQ(lhs->value(), 2); - Var* rhs = dynamic_cast(root->rhs()); + VarPtr rhs = to(root->rhs()); ASSERT_NE(rhs, nullptr); ASSERT_EQ(rhs->name_hint(), "x"); } @@ -612,12 +612,12 @@ TEST(Simplify, SimplifyCasts) { (ExprHandle(2) * ((ExprHandle(3) * x)) - (x * ExprHandle(4))); ExprHandle simplified = IRSimplifier::simplify(body); - Mul* root = simplified.AsNode(); + MulPtr root = simplified.AsNode(); ASSERT_NE(root, nullptr); - const LongImm* lhs = dynamic_cast(root->lhs()); + LongImmPtr lhs = to(root->lhs()); ASSERT_NE(lhs, nullptr); ASSERT_EQ(lhs->value(), 2); - Var* rhs = dynamic_cast(root->rhs()); + VarPtr rhs = to(root->rhs()); ASSERT_NE(rhs, nullptr); ASSERT_EQ(rhs->name_hint(), "x"); } @@ -629,7 +629,7 @@ TEST(Simplify, SimplifyEliminatesNoOps) { ExprHandle body = (x + ExprHandle(0)) * 1; ExprHandle simplified = IRSimplifier::simplify(body); - Var* root = simplified.AsNode(); + VarPtr root = simplified.AsNode(); ASSERT_NE(root, nullptr); ASSERT_EQ(root->name_hint(), "x"); } @@ -643,16 +643,16 @@ TEST(Simplify, SimplifyMultiVar) { ExprHandle simplified = IRSimplifier::simplify(body); - Add* root = simplified.AsNode(); + AddPtr root = simplified.AsNode(); ASSERT_NE(root, nullptr); - Mul* lhs = dynamic_cast(root->lhs()); + MulPtr lhs = to(root->lhs()); ASSERT_NE(lhs, nullptr); - Var* varX = dynamic_cast(lhs->rhs()); + VarPtr varX = to(lhs->rhs()); ASSERT_NE(varX, nullptr); ASSERT_EQ(varX->name_hint(), "y"); - Mul* rhs = dynamic_cast(root->rhs()); + MulPtr rhs = to(root->rhs()); ASSERT_NE(rhs, nullptr); - Var* varY = dynamic_cast(rhs->rhs()); + VarPtr varY = to(rhs->rhs()); ASSERT_NE(varY, nullptr); ASSERT_EQ(varY->name_hint(), "x"); } @@ -665,7 +665,7 @@ TEST(Simplify, DISABLED_SimplifyReorderings) { ExprHandle body = x + 2 + y; ExprHandle simplified = IRSimplifier::simplify(body); - Add* root = simplified.AsNode(); + AddPtr root = simplified.AsNode(); ASSERT_NE(root, nullptr); IS_NODE_WITH_NAME(Add, root->lhs(), rhs); @@ -1153,7 +1153,7 @@ TEST(Simplify, SimplifyDivWithLoopContext1) { BufHandle a_buf("A", {6}, 
kInt); auto for_stmt = For::make(i, 0, 6, Store::make(a_buf, {i}, (i + 24) / 6)); - const Stmt* simplified = IRSimplifier::simplify(for_stmt); + const StmtPtr simplified = IRSimplifier::simplify(for_stmt); std::ostringstream oss; oss << *(simplified); @@ -1175,7 +1175,7 @@ TEST(Simplify, SimplifyDivWithLoopContext2) { BufHandle a_buf("A", {5}, kInt); auto for_stmt = For::make(i, 0, 5, Store::make(a_buf, {i}, (i + 25) / 6)); - const Stmt* simplified = IRSimplifier::simplify(for_stmt); + const StmtPtr simplified = IRSimplifier::simplify(for_stmt); std::ostringstream oss; oss << *(simplified); @@ -1197,7 +1197,7 @@ TEST(Simplify, SimplifyDivWithLoopContext3) { BufHandle a_buf("A", {6}, kInt); auto for_stmt = For::make(i, 0, 6, Store::make(a_buf, {i}, (i + 24) / (-6))); - const Stmt* simplified = IRSimplifier::simplify(for_stmt); + const StmtPtr simplified = IRSimplifier::simplify(for_stmt); std::ostringstream oss; oss << *(simplified); @@ -1219,7 +1219,7 @@ TEST(Simplify, SimplifyDivWithLoopContext4) { BufHandle a_buf("A", {5}, kInt); auto for_stmt = For::make(i, 0, 5, Store::make(a_buf, {i}, (i + (-5)) / 6)); - const Stmt* simplified = IRSimplifier::simplify(for_stmt); + const StmtPtr simplified = IRSimplifier::simplify(for_stmt); std::ostringstream oss; oss << *(simplified); @@ -1245,7 +1245,7 @@ TEST(Simplify, SimplifyDivWithLoopContext5) { auto for_j = For::make(j, 0, 10, Store::make(a_buf, {i, j}, (i + j * 6) / 6)); auto for_i = For::make(i, 0, 6, for_j); - const Stmt* simplified = IRSimplifier::simplify(for_i); + const StmtPtr simplified = IRSimplifier::simplify(for_i); std::ostringstream oss; oss << *(simplified); @@ -1273,7 +1273,7 @@ TEST(Simplify, SimplifyDivWithLoopContext6) { For::make(j, -1, 9, Store::make(a_buf, {i, j + 1}, (i + j * 6) / 6)); auto for_i = For::make(i, 0, 6, for_j); - const Stmt* simplified = IRSimplifier::simplify(for_i); + const StmtPtr simplified = IRSimplifier::simplify(for_i); std::ostringstream oss; oss << *(simplified); @@ -1301,7 +1301,7 @@ TEST(Simplify, SimplifyDivWithLoopContext7) { For::make(j, 0, 10, Store::make(a_buf, {i, j}, (i + j * 6) / (-6))); auto for_i = For::make(i, 0, 6, for_j); - const Stmt* simplified = IRSimplifier::simplify(for_i); + const StmtPtr simplified = IRSimplifier::simplify(for_i); std::ostringstream oss; oss << *(simplified); @@ -1324,7 +1324,7 @@ TEST(Simplify, SimplifyModWithLoopContext0) { BufHandle a_buf("A", {100}, kInt); auto for_stmt = For::make(i, 0, 100, Store::make(a_buf, {i}, (i % 100))); - const Stmt* simplified = IRSimplifier::simplify(for_stmt); + const StmtPtr simplified = IRSimplifier::simplify(for_stmt); std::ostringstream oss; oss << *(simplified); @@ -1346,7 +1346,7 @@ TEST(Simplify, SimplifyModWithLoopContext1) { BufHandle a_buf("A", {6}, kInt); auto for_stmt = For::make(i, 0, 6, Store::make(a_buf, {i}, (i + 24) % 6)); - const Stmt* simplified = IRSimplifier::simplify(for_stmt); + const StmtPtr simplified = IRSimplifier::simplify(for_stmt); std::ostringstream oss; oss << *(simplified); @@ -1368,7 +1368,7 @@ TEST(Simplify, SimplifyModWithLoopContext2) { BufHandle a_buf("A", {5}, kInt); auto for_stmt = For::make(i, 0, 5, Store::make(a_buf, {i}, (i + 25) % 6)); - const Stmt* simplified = IRSimplifier::simplify(for_stmt); + const StmtPtr simplified = IRSimplifier::simplify(for_stmt); std::ostringstream oss; oss << *(simplified); @@ -1390,7 +1390,7 @@ TEST(Simplify, SimplifyModWithLoopContext3) { BufHandle a_buf("A", {6}, kInt); auto for_stmt = For::make(i, 0, 6, Store::make(a_buf, {i}, (i + 24) % (-6))); - const 
Stmt* simplified = IRSimplifier::simplify(for_stmt); + const StmtPtr simplified = IRSimplifier::simplify(for_stmt); std::ostringstream oss; oss << *(simplified); @@ -1412,7 +1412,7 @@ TEST(Simplify, SimplifyModWithLoopContext4) { BufHandle a_buf("A", {5}, kInt); auto for_stmt = For::make(i, 0, 5, Store::make(a_buf, {i}, (i + (-5)) % 6)); - const Stmt* simplified = IRSimplifier::simplify(for_stmt); + const StmtPtr simplified = IRSimplifier::simplify(for_stmt); std::ostringstream oss; oss << *(simplified); @@ -1438,7 +1438,7 @@ TEST(Simplify, SimplifyModWithLoopContext5) { auto for_j = For::make(j, 0, 10, Store::make(a_buf, {i, j}, (i + j * 6) % 6)); auto for_i = For::make(i, 0, 6, for_j); - const Stmt* simplified = IRSimplifier::simplify(for_i); + const StmtPtr simplified = IRSimplifier::simplify(for_i); std::ostringstream oss; oss << *(simplified); @@ -1466,7 +1466,7 @@ TEST(Simplify, SimplifyModWithLoopContext6) { For::make(j, -1, 9, Store::make(a_buf, {i, j + 1}, (i + j * 6) % 6)); auto for_i = For::make(i, 0, 6, for_j); - const Stmt* simplified = IRSimplifier::simplify(for_i); + const StmtPtr simplified = IRSimplifier::simplify(for_i); std::ostringstream oss; oss << *(simplified); @@ -1494,7 +1494,7 @@ TEST(Simplify, SimplifyModWithLoopContext7) { For::make(j, 0, 10, Store::make(a_buf, {i, j}, (i + j * 6) % (-6))); auto for_i = For::make(i, 0, 6, for_j); - const Stmt* simplified = IRSimplifier::simplify(for_i); + const StmtPtr simplified = IRSimplifier::simplify(for_i); std::ostringstream oss; oss << *(simplified); @@ -3512,12 +3512,12 @@ TEST(Simplify, SimplifyConstantCond) { BufHandle a("A", {1}, kInt); BufHandle b("B", {1}, kInt); ExprHandle condition(1); - Stmt* true_val = Store::make(a, {0}, 1); - Stmt* false_val = Store::make(b, {0}, 1); + StmtPtr true_val = Store::make(a, {0}, 1); + StmtPtr false_val = Store::make(b, {0}, 1); - Cond* body = new Cond(condition.node(), true_val, false_val); - Stmt* simplified = IRSimplifier::simplify(body); - Block* block = dynamic_cast(simplified); + CondPtr body = alloc(condition.node(), true_val, false_val); + StmtPtr simplified = IRSimplifier::simplify(body); + BlockPtr block = to(simplified); IS_NODE_WITH_NAME(Store, block->front(), store); IS_VAR_WITH_NAME(store->base_handle(), "A"); } @@ -3528,12 +3528,12 @@ TEST(Simplify, SimplifyConstantCond) { BufHandle a("A", {1}, kInt); BufHandle b("B", {1}, kInt); ExprHandle condition(0); - Stmt* true_val = Store::make(a, {0}, 1); - Stmt* false_val = Store::make(b, {0}, 1); + StmtPtr true_val = Store::make(a, {0}, 1); + StmtPtr false_val = Store::make(b, {0}, 1); - Stmt* body = new Cond(condition.node(), true_val, false_val); - Stmt* simplified = IRSimplifier::simplify(body); - Block* block = dynamic_cast(simplified); + StmtPtr body = alloc(condition.node(), true_val, false_val); + StmtPtr simplified = IRSimplifier::simplify(body); + BlockPtr block = to(simplified); IS_NODE_WITH_NAME(Store, block->front(), store); IS_VAR_WITH_NAME(store->base_handle(), "B"); } @@ -3545,12 +3545,12 @@ TEST(Simplify, SimplifyConstantCond) { BufHandle a("A", {1}, kInt); BufHandle b("B", {1}, kInt); ExprHandle condition(x - x); - Stmt* true_val = Store::make(a, {0}, 1); - Stmt* false_val = Store::make(b, {0}, 1); + StmtPtr true_val = Store::make(a, {0}, 1); + StmtPtr false_val = Store::make(b, {0}, 1); - Stmt* body = new Cond(condition.node(), true_val, false_val); - Stmt* simplified = IRSimplifier::simplify(body); - Block* block = dynamic_cast(simplified); + StmtPtr body = alloc(condition.node(), true_val, false_val); + 
StmtPtr simplified = IRSimplifier::simplify(body); + BlockPtr block = to(simplified); IS_NODE_WITH_NAME(Store, block->front(), store); IS_VAR_WITH_NAME(store->base_handle(), "B"); } @@ -3561,12 +3561,12 @@ TEST(Simplify, SimplifyConstantCond) { VarHandle x("x", kInt); BufHandle a("A", {1}, kInt); ExprHandle condition(x - x); - Stmt* true_val = Store::make(a, {0}, x); - Stmt* false_val = Store::make(a, {0}, x); + StmtPtr true_val = Store::make(a, {0}, x); + StmtPtr false_val = Store::make(a, {0}, x); - Stmt* body = new Cond(condition.node(), true_val, false_val); - Stmt* simplified = IRSimplifier::simplify(body); - Block* block = dynamic_cast(simplified); + StmtPtr body = alloc(condition.node(), true_val, false_val); + StmtPtr simplified = IRSimplifier::simplify(body); + BlockPtr block = to(simplified); IS_NODE_WITH_NAME(Store, block->front(), store); IS_VAR_WITH_NAME(store->base_handle(), "A"); } @@ -3577,12 +3577,12 @@ TEST(Simplify, SimplifyConstantCond) { VarHandle x("x", kInt); BufHandle a("A", {1}, kInt); ExprHandle condition(x - x); - Stmt* true_val = Store::make(a, {0}, ExprHandle(2) * x); - Stmt* false_val = Store::make(a, {0}, x + x); + StmtPtr true_val = Store::make(a, {0}, ExprHandle(2) * x); + StmtPtr false_val = Store::make(a, {0}, x + x); - Stmt* body = new Cond(condition.node(), true_val, false_val); - Stmt* simplified = IRSimplifier::simplify(body); - Block* block = dynamic_cast(simplified); + StmtPtr body = alloc(condition.node(), true_val, false_val); + StmtPtr simplified = IRSimplifier::simplify(body); + BlockPtr block = to(simplified); IS_NODE_WITH_NAME(Store, block->front(), store); IS_VAR_WITH_NAME(store->base_handle(), "A"); } @@ -3593,24 +3593,30 @@ TEST(Simplify, SimplifyConstantCond) { VarHandle x("x", kInt); BufHandle a("A", {1}, kInt); ExprHandle condition(x); - Stmt* true_val = Store::make(a, {0}, x); - Stmt* false_val = Store::make(a, {0}, ExprHandle(2) * x); + StmtPtr true_val = Store::make(a, {0}, x); + StmtPtr false_val = Store::make(a, {0}, ExprHandle(2) * x); - Stmt* body = new Cond(condition.node(), true_val, false_val); - Stmt* simplified = IRSimplifier::simplify(body); - Block* block = dynamic_cast(simplified); + StmtPtr body = alloc(condition.node(), true_val, false_val); + StmtPtr simplified = IRSimplifier::simplify(body); + BlockPtr block = to(simplified); ASSERT_EQ(block, nullptr); } { - Stmt* cond = new Cond(ExprHandle(false).node(), new Block({}), nullptr); - Stmt* simplified = IRSimplifier::simplify(cond); + StmtPtr cond = alloc( + ExprHandle(false).node(), + alloc(std::vector({})), + nullptr); + StmtPtr simplified = IRSimplifier::simplify(cond); ASSERT_EQ(simplified, nullptr); } { - Stmt* cond = new Cond(ExprHandle(true).node(), nullptr, new Block({})); - Stmt* simplified = IRSimplifier::simplify(cond); + StmtPtr cond = alloc( + ExprHandle(true).node(), + nullptr, + alloc(std::vector({}))); + StmtPtr simplified = IRSimplifier::simplify(cond); ASSERT_EQ(simplified, nullptr); } } @@ -3621,11 +3627,11 @@ TEST(Simplify, SimplifyEliminateEmptyCond) { { VarHandle x("x", kInt); ExprHandle condition(x); - Stmt* true_val = new Block({}); + StmtPtr true_val = alloc(std::vector({})); - Stmt* body = new Cond(condition.node(), true_val, nullptr); - Stmt* simplified = IRSimplifier::simplify(body); - Block* block = dynamic_cast(simplified); + StmtPtr body = alloc(condition.node(), true_val, nullptr); + StmtPtr simplified = IRSimplifier::simplify(body); + BlockPtr block = to(simplified); ASSERT_NE(block, nullptr); ASSERT_EQ(block->nstmts(), 0); } @@ -3633,11 
+3639,11 @@ TEST(Simplify, SimplifyEliminateEmptyCond) { { VarHandle x("x", kInt); ExprHandle condition(x); - Stmt* false_val = new Block({}); + StmtPtr false_val = alloc(std::vector({})); - Stmt* body = new Cond(condition.node(), nullptr, false_val); - Stmt* simplified = IRSimplifier::simplify(body); - Block* block = dynamic_cast(simplified); + StmtPtr body = alloc(condition.node(), nullptr, false_val); + StmtPtr simplified = IRSimplifier::simplify(body); + BlockPtr block = to(simplified); ASSERT_NE(block, nullptr); ASSERT_EQ(block->nstmts(), 0); } @@ -3836,8 +3842,8 @@ TEST(Simplify, SimplifyEliminateZeroLengthFor) { BufHandle c("C", {4}, kInt); VarHandle i("i", kInt); auto body = For::make(i, 0, 0, Store::make(c, {i}, Load::make(a, {i}))); - Stmt* simplified = IRSimplifier::simplify(body); - Block* block = dynamic_cast(simplified); + StmtPtr simplified = IRSimplifier::simplify(body); + BlockPtr block = to(simplified); ASSERT_EQ(block->nstmts(), 0); } @@ -3847,8 +3853,8 @@ TEST(Simplify, SimplifyEliminateZeroLengthFor) { BufHandle c("C", {4}, kInt); VarHandle i("i", kInt); auto body = For::make(i, 2, 2, Store::make(c, {i}, Load::make(a, {i}))); - Stmt* simplified = IRSimplifier::simplify(body); - Block* block = dynamic_cast(simplified); + StmtPtr simplified = IRSimplifier::simplify(body); + BlockPtr block = to(simplified); ASSERT_EQ(block->nstmts(), 0); } @@ -3859,8 +3865,8 @@ TEST(Simplify, SimplifyEliminateZeroLengthFor) { BufHandle c("C", {4}, kInt); VarHandle i("i", kInt); auto body = For::make(i, x, x, Store::make(c, {i}, Load::make(a, {i}))); - Stmt* simplified = IRSimplifier::simplify(body); - Block* block = dynamic_cast(simplified); + StmtPtr simplified = IRSimplifier::simplify(body); + BlockPtr block = to(simplified); ASSERT_EQ(block->nstmts(), 0); } @@ -3871,8 +3877,8 @@ TEST(Simplify, SimplifyEliminateZeroLengthFor) { BufHandle c("C", {4}, kInt); VarHandle i("i", kInt); auto body = For::make(i, 0, x - x, Store::make(c, {i}, Load::make(a, {i}))); - Stmt* simplified = IRSimplifier::simplify(body); - Block* block = dynamic_cast(simplified); + StmtPtr simplified = IRSimplifier::simplify(body); + BlockPtr block = to(simplified); ASSERT_EQ(block->nstmts(), 0); } @@ -3882,7 +3888,7 @@ TEST(Simplify, SimplifyEliminateZeroLengthFor) { BufHandle c("C", {4}, kInt); VarHandle i("i", kInt); auto body = For::make(i, 0, 3, Store::make(c, {i}, Load::make(a, {i}))); - Stmt* simplified = IRSimplifier::simplify(body); + StmtPtr simplified = IRSimplifier::simplify(body); IS_NODE(For, simplified); } } @@ -3896,8 +3902,8 @@ TEST(Simplify, SimplifyOneLoopFor) { BufHandle c("C", {4}, kInt); VarHandle i("i", kInt); auto body = For::make(i, 0, 1, Store::make(c, {i}, Load::make(a, {i}))); - Stmt* simplified = IRSimplifier::simplify(body); - Block* block = dynamic_cast(simplified); + StmtPtr simplified = IRSimplifier::simplify(body); + BlockPtr block = to(simplified); IS_NODE_WITH_NAME(Store, block->front(), store); IS_VAR_WITH_NAME(store->base_handle(), "C"); IS_IMM_WITH_VAL(Int, store->flat_index(), 0); @@ -3909,8 +3915,8 @@ TEST(Simplify, SimplifyOneLoopFor) { BufHandle c("C", {4}, kInt); VarHandle i("i", kInt); auto body = For::make(i, 2, 3, Store::make(c, {i}, Load::make(a, {i}))); - Stmt* simplified = IRSimplifier::simplify(body); - Block* block = dynamic_cast(simplified); + StmtPtr simplified = IRSimplifier::simplify(body); + BlockPtr block = to(simplified); IS_NODE_WITH_NAME(Store, block->front(), store); IS_VAR_WITH_NAME(store->base_handle(), "C"); IS_IMM_WITH_VAL(Int, store->flat_index(), 2); 
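Because alloc presumably perfect-forwards its arguments to the node constructor rather than accepting a braced initializer list, empty blocks are now spelled with an explicit std::vector, as in the hunks above. A sketch of that idiom, with the template arguments assumed:

    StmtPtr true_val = alloc<Block>(std::vector<StmtPtr>({}));   // was: new Block({})
    StmtPtr body = alloc<Cond>(condition.node(), true_val, nullptr);
    BlockPtr block = to<Block>(IRSimplifier::simplify(body));    // was: dynamic_cast
    ASSERT_NE(block, nullptr);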
@@ -3923,8 +3929,8 @@ TEST(Simplify, SimplifyOneLoopFor) { BufHandle c("C", {4}, kInt); VarHandle i("i", kInt); auto body = For::make(i, x, x + 1, Store::make(c, {i}, Load::make(a, {i}))); - Stmt* simplified = IRSimplifier::simplify(body); - Block* block = dynamic_cast(simplified); + StmtPtr simplified = IRSimplifier::simplify(body); + BlockPtr block = to(simplified); IS_NODE_WITH_NAME(Store, block->front(), store); IS_VAR_WITH_NAME(store->base_handle(), "C"); IS_VAR_WITH_NAME(store->flat_index(), "x"); @@ -3938,8 +3944,8 @@ TEST(Simplify, SimplifyOneLoopFor) { VarHandle i("i", kInt); auto body = For::make(i, 0, x - x + 1, Store::make(c, {i}, Load::make(a, {i}))); - Stmt* simplified = IRSimplifier::simplify(body); - Block* block = dynamic_cast(simplified); + StmtPtr simplified = IRSimplifier::simplify(body); + BlockPtr block = to(simplified); IS_NODE_WITH_NAME(Store, block->front(), store); IS_VAR_WITH_NAME(store->base_handle(), "C"); IS_IMM_WITH_VAL(Int, store->flat_index(), 0); @@ -3951,7 +3957,7 @@ TEST(Simplify, SimplifyOneLoopFor) { BufHandle c("C", {4}, kInt); VarHandle i("i", kInt); auto body = For::make(i, 0, 3, Store::make(c, {i}, Load::make(a, {i}))); - Stmt* simplified = IRSimplifier::simplify(body); + StmtPtr simplified = IRSimplifier::simplify(body); IS_NODE(For, simplified); } } @@ -3968,7 +3974,7 @@ TEST(Simplify, SimplifyForWontLoseLoopOptions) { options.set_gpu_block_index(12); auto body = For::make(i, 0, 1, Store::make(c, {i}, Load::make(a, {i})), options); - Stmt* simplified = IRSimplifier::simplify(body); + StmtPtr simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(For, simplified, for_); LoopOptions options2 = for_->loop_options(); ASSERT_EQ(options.gpu_block_index(), options2.gpu_block_index()); @@ -3984,10 +3990,10 @@ TEST(Simplify, SimplifyMultilevelFor) { BufHandle c("C", {4}, kInt); VarHandle i("i", kInt); VarHandle j("j", kInt); - auto* body = For::make(i, 0, 1, Store::make(c, {i}, Load::make(a, {i}))); + auto body = For::make(i, 0, 1, Store::make(c, {i}, Load::make(a, {i}))); auto outer = For::make(j, 0, 1, body); - Stmt* simplified = IRSimplifier::simplify(outer); - Block* block = dynamic_cast(simplified); + StmtPtr simplified = IRSimplifier::simplify(outer); + BlockPtr block = to(simplified); IS_NODE_WITH_NAME(Store, block->front(), store); IS_VAR_WITH_NAME(store->base_handle(), "C"); IS_IMM_WITH_VAL(Int, store->flat_index(), 0); @@ -3999,15 +4005,15 @@ TEST(Simplify, SimplifyMultilevelFor) { BufHandle c("C", {4}, kInt); VarHandle i("i", kInt); VarHandle j("j", kInt); - auto* body = For::make(i, 0, 1, Store::make(c, {i}, Load::make(a, {i}))); + auto body = For::make(i, 0, 1, Store::make(c, {i}, Load::make(a, {i}))); auto outer = For::make(j, 0, 2, body); - Stmt* simplified = IRSimplifier::simplify(outer); - For* for__ = static_cast(simplified); + StmtPtr simplified = IRSimplifier::simplify(outer); + ForPtr for__ = static_to(simplified); IS_NODE_WITH_NAME(For, for__, for_); IS_VAR_WITH_NAME(for_->var(), "j"); IS_IMM_WITH_VAL(Int, for_->start(), 0); IS_IMM_WITH_VAL(Int, for_->stop(), 2); - Block* block = dynamic_cast(for_->body()); + BlockPtr block = to(for_->body()); ASSERT_NE(block, nullptr); IS_NODE_WITH_NAME(Store, block->front(), store); IS_VAR_WITH_NAME(store->base_handle(), "C"); @@ -4020,10 +4026,10 @@ TEST(Simplify, SimplifyMultilevelFor) { BufHandle c("C", {4}, kInt); VarHandle i("i", kInt); VarHandle j("j", kInt); - auto* body = For::make(i, 0, 2, Store::make(c, {i}, Load::make(a, {i}))); + auto body = For::make(i, 0, 2, Store::make(c, {i}, 
Load::make(a, {i}))); auto outer = For::make(j, 0, 1, body); - Stmt* simplified = IRSimplifier::simplify(outer); - Block* block = dynamic_cast(simplified); + StmtPtr simplified = IRSimplifier::simplify(outer); + BlockPtr block = to(simplified); IS_NODE_WITH_NAME(For, block->front(), for_); IS_VAR_WITH_NAME(for_->var(), "i"); IS_IMM_WITH_VAL(Int, for_->start(), 0); @@ -4050,10 +4056,10 @@ TEST(Simplify, SimplifyForCleansUp) { LoopNest l({b}); l.prepareForCodegen(); - Stmt* body = l.root_stmt(); - Stmt* simplified = IRSimplifier::simplify(body); + StmtPtr body = l.root_stmt(); + StmtPtr simplified = IRSimplifier::simplify(body); - Block* block = dynamic_cast(simplified); + BlockPtr block = to(simplified); IS_NODE_WITH_NAME(For, block->front(), for_); // for is over "m". IS_VAR_WITH_NAME(for_->var(), "m"); @@ -4069,13 +4075,13 @@ TEST(Simplify, SimplifyEliminateEmptyFor) { { // Flatten many layers around an empty block to an empty block. - Stmt* last = new Block({}); + StmtPtr last = alloc(std::vector({})); for (int i = 0; i < 11; ++i) { VarHandle loopVar("loopVar", kInt); last = For::make(loopVar, 0, 10, last); } - Stmt* simplified = IRSimplifier::simplify(last); + StmtPtr simplified = IRSimplifier::simplify(last); IS_NODE_WITH_NAME(Block, simplified, block); ASSERT_EQ(block->nstmts(), 0); } @@ -4088,14 +4094,14 @@ TEST(Simplify, SimplifyFlattenBlock) { // Flatten multiple blocks down to one. // { { { stmt1, stmt2 } } } => { stmt1, stmt2 } BufHandle a("A", {1}, kInt); - Store* store1 = Store::make(a, {0}, 1); - Store* store2 = Store::make(a, {0}, 0); + StorePtr store1 = Store::make(a, {0}, 1); + StorePtr store2 = Store::make(a, {0}, 0); - Block* block1 = new Block({store1, store2}); - Block* block2 = new Block({block1}); + BlockPtr block1 = alloc(std::vector({store1, store2})); + BlockPtr block2 = alloc(std::vector({block1})); - Block* enclosing = new Block({block2}); - Stmt* simplified = IRSimplifier::simplify(enclosing); + BlockPtr enclosing = alloc(std::vector({block2})); + StmtPtr simplified = IRSimplifier::simplify(enclosing); IS_NODE_WITH_NAME(Block, simplified, block); ASSERT_EQ(block->nstmts(), 2); @@ -4111,14 +4117,14 @@ TEST(Simplify, SimplifyFlattenBlock) { // Flatten multiple sub blocks containing statements. // { { stmt1 }, { stmt2 } } => { stmt1, stmt2 } BufHandle a("A", {1}, kInt); - Store* store1 = Store::make(a, {0}, 1); - Store* store2 = Store::make(a, {0}, 0); + StorePtr store1 = Store::make(a, {0}, 1); + StorePtr store2 = Store::make(a, {0}, 0); - Block* block1 = new Block({store1}); - Block* block2 = new Block({store2}); + BlockPtr block1 = alloc(std::vector({store1})); + BlockPtr block2 = alloc(std::vector({store2})); - Block* enclosing = new Block({block1, block2}); - Stmt* simplified = IRSimplifier::simplify(enclosing); + BlockPtr enclosing = alloc(std::vector({block1, block2})); + StmtPtr simplified = IRSimplifier::simplify(enclosing); IS_NODE_WITH_NAME(Block, simplified, block); ASSERT_EQ(block->nstmts(), 2); @@ -4134,14 +4140,14 @@ TEST(Simplify, SimplifyFlattenBlock) { // Flatten sub blocks with different depths. 
// { stmt1 , { { stmt2 } } } => { stmt1, stmt2 } BufHandle a("A", {1}, kInt); - Store* store1 = Store::make(a, {0}, 1); - Store* store2 = Store::make(a, {0}, 0); + StorePtr store1 = Store::make(a, {0}, 1); + StorePtr store2 = Store::make(a, {0}, 0); - Block* block1 = new Block({store2}); - Block* block2 = new Block({block1}); + BlockPtr block1 = alloc<Block>(std::vector<StmtPtr>({store2})); + BlockPtr block2 = alloc<Block>(std::vector<StmtPtr>({block1})); - Block* enclosing = new Block({store1, block2}); - Stmt* simplified = IRSimplifier::simplify(enclosing); + BlockPtr enclosing = alloc<Block>(std::vector<StmtPtr>({store1, block2})); + StmtPtr simplified = IRSimplifier::simplify(enclosing); IS_NODE_WITH_NAME(Block, simplified, block); ASSERT_EQ(block->nstmts(), 2); @@ -4155,12 +4161,12 @@ TEST(Simplify, SimplifyFlattenBlock) { { // Flatten many layers around an empty block to an empty block. - Stmt* last = new Block({}); + StmtPtr last = alloc<Block>(std::vector<StmtPtr>({})); for (int i = 0; i < 11; ++i) { - last = new Block({last}); + last = alloc<Block>(std::vector<StmtPtr>({last})); } - Stmt* simplified = IRSimplifier::simplify(last); + StmtPtr simplified = IRSimplifier::simplify(last); IS_NODE_WITH_NAME(Block, simplified, block); ASSERT_EQ(block->nstmts(), 0); } @@ -4173,13 +4179,13 @@ TEST(Simplify, SimplifyEliminateZeroLengthAlloc) { // Simple positive case. BufHandle b("x", {0}, kInt); - Allocate* alloc = Allocate::make(b); - Free* free_ = Free::make(b); + AllocatePtr alloc_ = Allocate::make(b); + FreePtr free_ = Free::make(b); - Block* block1 = new Block({alloc, free_}); + BlockPtr block1 = alloc<Block>(std::vector<StmtPtr>({alloc_, free_})); ASSERT_EQ(block1->nstmts(), 2); - Stmt* simplified = IRSimplifier::simplify(block1); + StmtPtr simplified = IRSimplifier::simplify(block1); IS_NODE_WITH_NAME(Block, simplified, block2); ASSERT_EQ(block2->nstmts(), 0); } @@ -4188,13 +4194,13 @@ TEST(Simplify, SimplifyEliminateZeroLengthAlloc) { // Simple negative case.
BufHandle b("x", {2}, kInt); - Allocate* alloc = Allocate::make(b); - Free* free_ = Free::make(b); + AllocatePtr alloc_ = Allocate::make(b); + FreePtr free_ = Free::make(b); - Block* block1 = new Block({alloc, free_}); + BlockPtr block1 = alloc(std::vector({alloc_, free_})); ASSERT_EQ(block1->nstmts(), 2); - Stmt* simplified = IRSimplifier::simplify(block1); + StmtPtr simplified = IRSimplifier::simplify(block1); IS_NODE_WITH_NAME(Block, simplified, block2); ASSERT_EQ(block2->nstmts(), 2); } @@ -4204,15 +4210,16 @@ TEST(Simplify, SimplifyEliminateZeroLengthAlloc) { BufHandle b1("x", {0}, kInt); BufHandle b2("y", {2}, kInt); - Allocate* alloc1 = Allocate::make(b1); - Allocate* alloc2 = Allocate::make(b2); - Free* free2_ = Free::make(b2); - Free* free1_ = Free::make(b1); + AllocatePtr alloc1 = Allocate::make(b1); + AllocatePtr alloc2 = Allocate::make(b2); + FreePtr free2_ = Free::make(b2); + FreePtr free1_ = Free::make(b1); - Block* block1 = new Block({alloc1, alloc2, free2_, free1_}); + BlockPtr block1 = + alloc(std::vector({alloc1, alloc2, free2_, free1_})); ASSERT_EQ(block1->nstmts(), 4); - Stmt* simplified = IRSimplifier::simplify(block1); + StmtPtr simplified = IRSimplifier::simplify(block1); IS_NODE_WITH_NAME(Block, simplified, block2); ASSERT_EQ(block2->nstmts(), 2); IS_NODE_WITH_NAME(Allocate, block2->stmts().front(), simplified_alloc); @@ -4227,14 +4234,15 @@ TEST(Simplify, SimplifyEliminateZeroLengthAlloc) { BufHandle b1("x", {0}, kInt); BufHandle b2("y", {z}, kInt); - Allocate* alloc1 = Allocate::make(b1); - Allocate* alloc2 = Allocate::make(b2); - Free* free2_ = Free::make(b2); - Free* free1_ = Free::make(b1); + AllocatePtr alloc1 = Allocate::make(b1); + AllocatePtr alloc2 = Allocate::make(b2); + FreePtr free2_ = Free::make(b2); + FreePtr free1_ = Free::make(b1); - Block* block1 = new Block({alloc1, alloc2, free2_, free1_}); + BlockPtr block1 = + alloc(std::vector({alloc1, alloc2, free2_, free1_})); ASSERT_EQ(block1->nstmts(), 4); - Stmt* simplified = IRSimplifier::simplify(block1); + StmtPtr simplified = IRSimplifier::simplify(block1); IS_NODE_WITH_NAME(Block, simplified, block2); ASSERT_EQ(block2->nstmts(), 2); } @@ -4293,7 +4301,7 @@ TEST(Simplify, SimplifyReorderForCond) { Store::make(c, {i}, Load::make(a, {i})), nullptr)); - Stmt* simplified = IRSimplifier::simplify(body); + StmtPtr simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Cond, simplified, cond); IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_block); IS_NODE_WITH_NAME(For, true_block->front(), loop); @@ -4310,7 +4318,7 @@ TEST(Simplify, SimplifyReorderForCond) { Store::make(c, {i}, Load::make(a, {i})), nullptr)); - Stmt* simplified = IRSimplifier::simplify(body); + StmtPtr simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(For, simplified, loop); IS_NODE_WITH_NAME(Cond, loop->body()->front(), cond); } @@ -4328,7 +4336,7 @@ TEST(Simplify, SimplifyReorderForCond) { Store::make(c, {0}, Load::make(a, {i})), nullptr)); - Stmt* simplified = IRSimplifier::simplify(body); + StmtPtr simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(For, simplified, loop); IS_NODE_WITH_NAME(Cond, loop->body()->front(), cond); } @@ -4345,7 +4353,7 @@ TEST(Simplify, SimplifyReorderForCond) { Store::make(c, {0}, Load::make(a, {i})), nullptr)); - Stmt* simplified = IRSimplifier::simplify(body); + StmtPtr simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Cond, simplified, cond); IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_block); IS_NODE_WITH_NAME(For, true_block->front(), loop); @@ -4363,7 +4371,7 
@@ TEST(Simplify, SimplifyReorderForCond) { Store::make(c, {0}, Load::make(a, {i})), nullptr)); - Stmt* simplified = IRSimplifier::simplify(body); + StmtPtr simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Cond, simplified, cond); IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_block); IS_NODE_WITH_NAME(For, true_block->front(), loop); @@ -4382,7 +4390,7 @@ TEST(Simplify, SimplifyReorderForCond) { Store::make(c, {0}, Load::make(a, {i})), nullptr)})); - Stmt* simplified = IRSimplifier::simplify(body); + StmtPtr simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(For, simplified, loop); IS_NODE_WITH_NAME(Let, loop->body()->front(), let); IS_NODE_WITH_NAME(Cond, loop->body()->back(), cond); @@ -4404,7 +4412,7 @@ TEST(Simplify, SimplifyReorderForCond) { nullptr), nullptr)); - Stmt* simplified = IRSimplifier::simplify(body); + StmtPtr simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Cond, simplified, cond); IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_block); IS_NODE_WITH_NAME(Cond, true_block->front(), cond2); @@ -4428,7 +4436,7 @@ TEST(Simplify, SimplifyReorderForCond) { nullptr), nullptr)); - Stmt* simplified = IRSimplifier::simplify(body); + StmtPtr simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Cond, simplified, cond); IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_block); IS_NODE_WITH_NAME(For, true_block->front(), loop); @@ -4448,7 +4456,7 @@ TEST(Simplify, SimplifyReorderForCond) { Store::make(c, {0}, Load::make(a, {i})), Store::make(c, {0}, 0))); - Stmt* simplified = IRSimplifier::simplify(body); + StmtPtr simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(For, simplified, loop); IS_NODE_WITH_NAME(Cond, loop->body()->front(), cond); } @@ -4467,7 +4475,7 @@ TEST(Simplify, SimplifyReorderForCond) { Store::make(c, {1}, Load::make(a, {i})), nullptr)); - Stmt* simplified = IRSimplifier::simplify(body); + StmtPtr simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(For, simplified, loop); IS_NODE_WITH_NAME(Cond, loop->body()->front(), cond); } @@ -4493,7 +4501,7 @@ TEST(Simplify, SimplifyFuseConditions) { Store::make(a, {1}, i), nullptr)}); - Stmt* simplified = IRSimplifier::simplify(body); + StmtPtr simplified = IRSimplifier::simplify(body); // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) IS_NODE_WITH_NAME(Block, simplified, block); ASSERT_EQ(block->nstmts(), 1); @@ -4515,7 +4523,7 @@ TEST(Simplify, SimplifyFuseConditions) { Store::make(a, {1}, i), nullptr)}); - Stmt* simplified = IRSimplifier::simplify(body); + StmtPtr simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Block, simplified, block); ASSERT_EQ(block->nstmts(), 2); IS_NODE_WITH_NAME(Cond, block->front(), cond1); @@ -4541,7 +4549,7 @@ TEST(Simplify, SimplifyFuseConditions) { Store::make(a, {1}, i), nullptr)}); - Stmt* simplified = IRSimplifier::simplify(body); + StmtPtr simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Block, simplified, block); ASSERT_EQ(block->nstmts(), 2); IS_NODE_WITH_NAME(Cond, block->front(), cond1); @@ -4568,7 +4576,7 @@ TEST(Simplify, SimplifyFuseConditions) { Store::make(a, {1}, i), nullptr)}); - Stmt* simplified = IRSimplifier::simplify(body); + StmtPtr simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Block, simplified, block); ASSERT_EQ(block->nstmts(), 2); IS_NODE_WITH_NAME(Cond, block->front(), cond1); @@ -4589,25 +4597,15 @@ TEST(Simplify, SimplifyFuseConditions) { // TODO for later. 
auto body = Block::make( {Cond::make( - CompareSelect::make( - i, - 10, - new IntImm(1), - new IntImm(0), - CompareSelectOperation::kLT), + CompareSelect::make(i, 10, 1, 0, CompareSelectOperation::kLT), Store::make(a, {0}, i), nullptr), Cond::make( - CompareSelect::make( - j, - 10, - new IntImm(2), - new IntImm(0), - CompareSelectOperation::kLT), + CompareSelect::make(j, 10, 2, 0, CompareSelectOperation::kLT), Store::make(a, {1}, i), nullptr)}); - Stmt* simplified = IRSimplifier::simplify(body); + StmtPtr simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Block, simplified, block); ASSERT_EQ(block->nstmts(), 2); IS_NODE_WITH_NAME(Cond, block->front(), cond1); @@ -4634,7 +4632,7 @@ TEST(Simplify, SimplifyFuseConditions) { nullptr, Store::make(a, {1}, i))}); - Stmt* simplified = IRSimplifier::simplify(body); + StmtPtr simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Block, simplified, block); ASSERT_EQ(block->nstmts(), 1); IS_NODE_WITH_NAME(Cond, block->front(), cond); @@ -4655,7 +4653,7 @@ TEST(Simplify, SimplifyFuseConditions) { Store::make(a, {1}, i), Store::make(b, {1}, i))}); - Stmt* simplified = IRSimplifier::simplify(body); + StmtPtr simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Block, simplified, block); ASSERT_EQ(block->nstmts(), 1); IS_NODE_WITH_NAME(Cond, block->front(), cond); @@ -4677,7 +4675,7 @@ TEST(Simplify, SimplifyFuseConditions) { nullptr, Store::make(b, {1}, i))}); - Stmt* simplified = IRSimplifier::simplify(body); + StmtPtr simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Block, simplified, block); ASSERT_EQ(block->nstmts(), 1); IS_NODE_WITH_NAME(Cond, block->front(), cond); @@ -4723,7 +4721,7 @@ TEST(Simplify, SimplifyFuseConditions) { Store::make(a, {1}, j), nullptr), }); - Stmt* simplified = IRSimplifier::simplify(body); + StmtPtr simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Block, simplified, block); ASSERT_EQ(block->nstmts(), 3); auto it = block->begin(); @@ -4754,7 +4752,7 @@ TEST(Simplify, SimplifyFuseConditions) { Store::make(a, {1}, j), nullptr), }); - Stmt* simplified = IRSimplifier::simplify(body); + StmtPtr simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Block, simplified, block); ASSERT_EQ(block->nstmts(), 1); IS_NODE_WITH_NAME(Cond, block->front(), cond); @@ -4784,7 +4782,7 @@ TEST(Simplify, SimplifyFuseConditions) { Store::make(a, {1}, j), nullptr), }); - Stmt* simplified = IRSimplifier::simplify(body); + StmtPtr simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Block, simplified, block); ASSERT_EQ(block->nstmts(), 3); IS_NODE_WITH_NAME(Cond, block->front(), cond); @@ -4819,7 +4817,7 @@ TEST(Simplify, SimplifyFuseConditions) { CompareSelectOperation::kLT), Store::make(a, {1}, i), nullptr)}); - Stmt* simplified = IRSimplifier::simplify(body); + StmtPtr simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Block, simplified, block); ASSERT_EQ(block->nstmts(), 1); IS_NODE_WITH_NAME(Cond, block->front(), cond); @@ -4835,7 +4833,7 @@ TEST(Simplify, SimplifyFuseConditions) { {Cond::make(i, Store::make(a, {0}, i), nullptr), Cond::make(i, Store::make(a, {1}, i), nullptr)}); - Stmt* simplified = IRSimplifier::simplify(body); + StmtPtr simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Block, simplified, block); ASSERT_EQ(block->nstmts(), 1); IS_NODE_WITH_NAME(Cond, block->front(), cond); @@ -4850,7 +4848,7 @@ TEST(Simplify, SimplifyFuseConditions) { {Cond::make(i, Store::make(a, {0}, i), nullptr), Cond::make(j, Store::make(a, {1}, i), nullptr)}); - Stmt* 
simplified = IRSimplifier::simplify(body); + StmtPtr simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Block, simplified, block); ASSERT_EQ(block->nstmts(), 2); IS_NODE_WITH_NAME(Cond, block->front(), cond1); @@ -4863,7 +4861,7 @@ TEST(Simplify, SimplifyFuseConditions) { auto body = Block::make( {Cond::make(1, Store::make(a, {0}, i), nullptr), Cond::make(1, Store::make(a, {1}, i), nullptr)}); - Stmt* simplified = IRSimplifier::simplify(body); + StmtPtr simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Block, simplified, block); ASSERT_EQ(block->nstmts(), 2); IS_NODE_WITH_NAME(Store, block->front(), store1); @@ -4886,7 +4884,7 @@ TEST(Simplify, SimplifyFuseConditions) { Store::make(a, {2}, Load::make(b, {0})), nullptr)})); - Stmt* simplified = IRSimplifier::simplify(body); + StmtPtr simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Cond, simplified, cond); IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_block); IS_NODE_WITH_NAME(For, true_block->front(), loop); @@ -4903,10 +4901,10 @@ TEST(Simplify, SimplifySyncThreads) { auto body = Block::make( // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) {Store::make(a, {0}, 1), - new SyncThreads(), - new SyncThreads(), + alloc(), + alloc(), Store::make(a, {1}, 0)}); - Stmt* simplified = IRSimplifier::simplify(body); + StmtPtr simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Block, simplified, block); ASSERT_EQ(block->nstmts(), 3); auto it = block->begin(); @@ -4918,9 +4916,9 @@ TEST(Simplify, SimplifySyncThreads) { { // Eliminate outer SyncThreads. auto body = Block::make( - {new SyncThreads(), Store::make(a, {1}, 0), new SyncThreads()}); + {alloc(), Store::make(a, {1}, 0), alloc()}); - Stmt* simplified = IRSimplifier::simplify(body); + StmtPtr simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Block, simplified, block); ASSERT_EQ(block->nstmts(), 1); auto it = block->begin(); @@ -4931,14 +4929,14 @@ TEST(Simplify, SimplifySyncThreads) { // Merge many inner SyncThreads. auto body = Block::make( {Store::make(a, {0}, 1), - new SyncThreads(), - new SyncThreads(), - new SyncThreads(), - new SyncThreads(), - new SyncThreads(), + alloc(), + alloc(), + alloc(), + alloc(), + alloc(), Store::make(a, {1}, 0)}); - Stmt* simplified = IRSimplifier::simplify(body); + StmtPtr simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Block, simplified, block); ASSERT_EQ(block->nstmts(), 3); auto it = block->begin(); @@ -4950,15 +4948,15 @@ TEST(Simplify, SimplifySyncThreads) { { // Merge multiple outer SyncThreads. 
auto body = Block::make( - {new SyncThreads(), - new SyncThreads(), + {alloc(), + alloc(), Store::make(a, {1}, 0), - new SyncThreads(), - new SyncThreads(), - new SyncThreads(), - new SyncThreads()}); + alloc(), + alloc(), + alloc(), + alloc()}); - Stmt* simplified = IRSimplifier::simplify(body); + StmtPtr simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Block, simplified, block); ASSERT_EQ(block->nstmts(), 1); auto it = block->begin(); @@ -4969,16 +4967,16 @@ TEST(Simplify, SimplifySyncThreads) { // Merge multiple sections; auto body = Block::make( {Store::make(a, {0}, 1), - new SyncThreads(), - new SyncThreads(), + alloc(), + alloc(), Store::make(a, {1}, 0), Store::make(a, {2}, 0), - new SyncThreads(), - new SyncThreads(), - new SyncThreads(), + alloc(), + alloc(), + alloc(), Store::make(a, {3}, 0)}); - Stmt* simplified = IRSimplifier::simplify(body); + StmtPtr simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Block, simplified, block); ASSERT_EQ(block->nstmts(), 6); auto it = block->begin(); @@ -4997,7 +4995,7 @@ TEST(Simplify, SimplifyRampSubBroadcast) { ExprHandle ramp = Ramp::make(ExprHandle(0), ExprHandle(6), num_lanes); ExprHandle broadcast = Broadcast::make(ExprHandle(-5), num_lanes); ExprHandle simplified = IRSimplifier::simplify(ramp - broadcast); - Ramp* newRamp = simplified.AsNode(); + RampPtr newRamp = simplified.AsNode(); IS_NODE_WITH_NAME(IntImm, newRamp->base(), base); ASSERT_EQ(base->value(), 5); IS_NODE_WITH_NAME(IntImm, newRamp->stride(), stride); @@ -5040,7 +5038,7 @@ TEST(Simplify, DISABLED_CompareSelectCondAlwaysInLoopBounds) { constexpr int N = 8; Placeholder b("b", kFloat, {N}); VarHandle n("n", kInt); - Stmt* s = For::make( + StmtPtr s = For::make( n, 1, N, b.store({n}, CompareSelect::make(n, 1, 0.f, 1.0f, kLT))); s = IRSimplifier::simplify(s); std::ostringstream oss; @@ -5065,7 +5063,7 @@ TEST(Simplify, DISABLED_IfThenCondAlwaysInLoopBounds) { constexpr int N = 8; Placeholder b("b", kFloat, {N}); VarHandle n("n", kInt); - Stmt* s = + StmtPtr s = For::make(n, 1, N, b.store({n}, IfThenElse::make(n < 1, 0.f, 1.0f))); s = IRSimplifier::simplify(s); std::ostringstream oss; @@ -5099,7 +5097,7 @@ TEST(Simplify, DISABLED_MultiClauseCondAlwaysInLoopBounds) { csel = CompareSelect::make(j, 1, 1, csel, kLT); csel = CompareSelect::make(i, N - 1, 1, csel, kGE); csel = CompareSelect::make(j, N - 1, 1, csel, kGE); - Stmt* s = b.store({i, j}, IfThenElse::make(csel, 0.f, 1.0f)); + StmtPtr s = b.store({i, j}, IfThenElse::make(csel, 0.f, 1.0f)); s = For::make(j, 1, N - 1, s); s = For::make(i, 1, N - 1, s); s = IRSimplifier::simplify(s); @@ -5137,7 +5135,7 @@ TEST(Simplify, DISABLED_SimplifyLoopBounds) { csel = CompareSelect::make(j, 1, 1, csel, kLT); csel = CompareSelect::make(i, N - 1, 1, csel, kGE); csel = CompareSelect::make(j, N - 1, 1, csel, kGE); - Stmt* s = b.store( + StmtPtr s = b.store( {i, j}, b.load({i, j}) + IfThenElse::make(csel, 0.f, a.load({i, j}))); s = For::make(j, 0, K, s); s = For::make(i, 0, K, s); diff --git a/test/cpp/tensorexpr/test_utils.h b/test/cpp/tensorexpr/test_utils.h index de95df2..01b92a7 100644 --- a/test/cpp/tensorexpr/test_utils.h +++ b/test/cpp/tensorexpr/test_utils.h @@ -10,63 +10,63 @@ namespace torch { namespace jit { using namespace torch::jit::tensorexpr; -#define IS_NODE(T, node) \ - { \ - auto* node_ = dynamic_cast(node); \ - ASSERT_NE(nullptr, node_); \ +#define IS_NODE(T, node) \ + { \ + auto node_ = to(node); \ + ASSERT_NE(nullptr, node_); \ } -#define IS_NODE_WITH_NAME(T, node, name) \ - auto* name = 
dynamic_cast(node); \ +#define IS_NODE_WITH_NAME(T, node, name) \ + auto name = to(node); \ ASSERT_NE(nullptr, name); #define IS_NODE_WITH_NAME_AND_CAST(T, node, name, Type) \ - const T* name = nullptr; \ + NodePtr name = nullptr; \ { \ - auto* node_ = dynamic_cast(node); \ + auto node_ = to(node); \ ASSERT_NE(nullptr, node_); \ ASSERT_EQ(node_->dtype().scalar_type(), ScalarType::Type); \ - name = dynamic_cast(node_->src_value()); \ + name = to(node_->src_value()); \ } \ ASSERT_NE(nullptr, name); -#define IS_IMM_WITH_VAL(T, node, val) \ - { \ - auto* node_ = dynamic_cast(node); \ - ASSERT_NE(nullptr, node_); \ - ASSERT_EQ(node_->value(), val); \ +#define IS_IMM_WITH_VAL(T, node, val) \ + { \ + auto node_ = to(node); \ + ASSERT_NE(nullptr, node_); \ + ASSERT_EQ(node_->value(), val); \ } -#define IS_VAR_WITH_NAME(node, name) \ - { \ - auto* node_ = dynamic_cast(node); \ - ASSERT_NE(nullptr, node_); \ - ASSERT_EQ(node_->name_hint(), name); \ +#define IS_VAR_WITH_NAME(node, name) \ + { \ + auto node_ = to(node); \ + ASSERT_NE(nullptr, node_); \ + ASSERT_EQ(node_->name_hint(), name); \ } #define IS_BINOP_W_VARS(T, node, name, v1, v2) \ - const T* name = nullptr; \ + NodePtr name = nullptr; \ { \ - name = dynamic_cast(node); \ + name = to(node); \ ASSERT_NE(nullptr, name); \ IS_VAR_WITH_NAME(name->lhs(), v1); \ IS_VAR_WITH_NAME(name->rhs(), v2); \ } #define IS_BINOP_W_CONST(T, node, name, v, c) \ - const T* name = nullptr; \ + NodePtr name = nullptr; \ { \ - name = dynamic_cast(node); \ + name = to(node); \ ASSERT_NE(nullptr, name); \ IS_VAR_WITH_NAME(name->lhs(), v); \ IS_IMM_WITH_VAL(Int, name->rhs(), c); \ } -#define IS_RAND(node) \ - { \ - auto* node_ = dynamic_cast(node); \ - ASSERT_NE(nullptr, node_); \ - ASSERT_EQ(node_->op_type(), kRand); \ +#define IS_RAND(node) \ + { \ + auto node_ = to(node); \ + ASSERT_NE(nullptr, node_); \ + ASSERT_EQ(node_->op_type(), kRand); \ } } // namespace jit diff --git a/test/cpp/tensorexpr/tutorial.cpp b/test/cpp/tensorexpr/tutorial.cpp index 0d0be27..9320f47 100644 --- a/test/cpp/tensorexpr/tutorial.cpp +++ b/test/cpp/tensorexpr/tutorial.cpp @@ -71,9 +71,9 @@ int main(int argc, char* argv[]) { // also be a 'Mul' or some other expression. // // Let's construct a simple TE: - Expr* lhs = new IntImm(5); - Expr* rhs = new Var("x", kInt); - Expr* mul = new Mul(lhs, rhs); + ExprPtr lhs = alloc(5); + ExprPtr rhs = alloc("x", kInt); + ExprPtr mul = alloc(lhs, rhs); std::cout << "Tensor expression: " << *mul << std::endl; // Prints: Tensor expression: 5 * x @@ -127,13 +127,14 @@ int main(int argc, char* argv[]) { // Let's start with defining a domain. We do this by creating a Buf object. // First, let's specify the sizes: - std::vector dims = { - new IntImm(64), new IntImm(32)}; // IntImm stands for Integer Immediate + std::vector dims = { + alloc(64), + alloc(32)}; // IntImm stands for Integer Immediate // and represents an integer constant // Now we can create a Buf object by providing a name, dimensions, and a // data type of the elements: - Buf* buf = new Buf("X", dims, kInt); + BufPtr buf = alloc("X", dims, kInt); // Next we need to spefify the computation. We can do that by either // constructing a complete tensor statement for it (statements are @@ -144,14 +145,14 @@ int main(int argc, char* argv[]) { // Let's define two variables, i and j - they will be axis in our // computation. 
- Var* i = new Var("i", kInt); - Var* j = new Var("j", kInt); - std::vector args = {i, j}; + VarPtr i = alloc<Var>("i", kInt); + VarPtr j = alloc<Var>("j", kInt); + std::vector args = {i, j}; // Now we can define the body of the tensor computation using these // variables. What this means is that values in our tensor are: // X[i, j] = i * j - Expr* body = new Mul(i, j); + ExprPtr body = alloc<Mul>(i, j); // Finally, we pass all these pieces together to Tensor constructor: Tensor* X = new Tensor(buf, args, body); @@ -311,11 +312,11 @@ int main(int argc, char* argv[]) { // Loop transformations can be composed, so we can do something else with // our loop nest now. Let's split the inner loop with a factor of 9, for // instance. - std::vector<For*> loops = loopnest.getLoopStmtsFor(Y); + std::vector<ForPtr> loops = loopnest.getLoopStmtsFor(Y); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* j_inner; + ForPtr j_inner; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* j_tail; + ForPtr j_tail; int split_factor = 9; loopnest.splitWithTail( loops[1], // loops[0] is the outer loop, loops[1] is inner diff --git a/torch/csrc/jit/runtime/static/ops.cpp b/torch/csrc/jit/runtime/static/ops.cpp index bd74666..eef5595 100644 --- a/torch/csrc/jit/runtime/static/ops.cpp +++ b/torch/csrc/jit/runtime/static/ops.cpp @@ -486,8 +486,8 @@ void optimizePointwise( tensorexpr::Tensor* target, int width) { using namespace torch::jit::tensorexpr; - std::vector<For*> loops = ln->getLoopStmtsFor(target); - For *inner, *tail; + std::vector<ForPtr> loops = ln->getLoopStmtsFor(target); + ForPtr inner, tail; TORCH_CHECK(loops.size() > 0, "No loops created for pointwise op"); ln->splitWithTail(loops[0], width, &inner, &tail); ln->vectorize(inner); @@ -503,7 +503,7 @@ std::shared_ptr<TEWrapper> wrapTECompute( LoopNest ln({out}); optimizePointwise(&ln, out, width); ln.prepareForCodegen(); - Stmt* s = ln.root_stmt(); + StmtPtr s = ln.root_stmt(); s = tensorexpr::IRSimplifier::simplify(s); std::vector args; args.emplace_back(out); diff --git a/torch/csrc/jit/tensorexpr/analysis.h b/torch/csrc/jit/tensorexpr/analysis.h index 76b0e24..351eb87 100644 --- a/torch/csrc/jit/tensorexpr/analysis.h +++ b/torch/csrc/jit/tensorexpr/analysis.h @@ -10,7 +10,7 @@ namespace jit { namespace tensorexpr { class HasRand : public IRVisitor { public: - HasRand(Stmt* stmt) : stmt_(stmt) { + HasRand(StmtPtr stmt) : stmt_(stmt) { stmt_->accept(this); } @@ -19,146 +19,146 @@ class HasRand : public IRVisitor { } private: - void visit(Intrinsics* v) override { + void visit(IntrinsicsPtr v) override { if (v->op_type() == IntrinsicsOp::kRand) { has_rand_ = true; } else { IRVisitor::visit(v); } } - Stmt* stmt_; + StmtPtr stmt_; bool has_rand_ = false; }; -template <typename Node> +template <typename Node> // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) class NodeFinder : public IRVisitor { public: - void visit(Node* v) override { - nodes.push_back((Node*)v); + void visit(NodePtr<Node> v) override { + nodes.push_back((NodePtr<Node>)v); IRVisitor::visit(v); } - static std::vector<Node*> find(Stmt* s) { - NodeFinder nf; + static std::vector<NodePtr<Node>> find(StmtPtr s) { + NodeFinder<Node> nf; s->accept(&nf); return nf.nodes; } - static std::vector<Node*> find(Expr* e) { - NodeFinder nf; + static std::vector<NodePtr<Node>> find(ExprPtr e) { + NodeFinder<Node> nf; e->accept(&nf); return nf.nodes; } - std::vector<Node*> nodes; + std::vector<NodePtr<Node>> nodes; }; // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) class VarFinder : public IRVisitor { public: - void visit(Var* v) override { + void visit(VarPtr v) override { vars_.insert(v); IRVisitor::visit(v); } - static std::unordered_set<Var*>
find(Stmt* s) { + static std::unordered_set find(StmtPtr s) { VarFinder nf; s->accept(&nf); return nf.vars(); } - static std::unordered_set find(Expr* e) { + static std::unordered_set find(ExprPtr e) { VarFinder nf; e->accept(&nf); return nf.vars(); } - const std::unordered_set& vars() { + const std::unordered_set& vars() { return vars_; } private: - std::unordered_set vars_; + std::unordered_set vars_; }; class BufFinder : public IRVisitor { public: - void visit(Buf* v) override { + void visit(BufPtr v) override { bufs_.insert(v); IRVisitor::visit(v); } - static std::unordered_set find(Stmt* s) { + static std::unordered_set find(StmtPtr s) { BufFinder nf; s->accept(&nf); return nf.bufs(); } - static std::unordered_set find(Expr* e) { + static std::unordered_set find(ExprPtr e) { BufFinder nf; e->accept(&nf); return nf.bufs(); } - const std::unordered_set& bufs() { + const std::unordered_set& bufs() { return bufs_; } private: - std::unordered_set bufs_; + std::unordered_set bufs_; }; // Finds all kinds of write operations to the provided Buf. class WritesToBuf : public IRVisitor { public: // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) - WritesToBuf(Buf* target) : target_(target) {} + WritesToBuf(BufPtr target) : target_(target) {} - std::vector writes() { + std::vector writes() { return writes_; } - static std::vector find(Stmt* s, Buf* b) { + static std::vector find(StmtPtr s, BufPtr b) { WritesToBuf finder(b); s->accept(&finder); return finder.writes(); } private: - void visit(Store* v) override { + void visit(StorePtr v) override { if (v->buf() == target_) { writes_.push_back(v); } } - void visit(AtomicAdd* v) override { + void visit(AtomicAddPtr v) override { if (v->buf() == target_) { writes_.push_back(v); } } - Buf* target_; - std::vector writes_; + BufPtr target_; + std::vector writes_; }; class StmtsReadingBuf : public IRVisitor { public: // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) - StmtsReadingBuf(Buf* target) : target_(target) {} + StmtsReadingBuf(BufPtr target) : target_(target) {} - std::vector reads() { + std::vector reads() { return reads_; } - static std::vector find(Stmt* s, Buf* b) { + static std::vector find(StmtPtr s, BufPtr b) { StmtsReadingBuf finder(b); s->accept(&finder); return finder.reads(); } private: - bool readsBuffer(Stmt* s) { + bool readsBuffer(StmtPtr s) { auto loads = NodeFinder::find(s); for (auto l : loads) { if (l->buf() == target_) { @@ -168,40 +168,40 @@ class StmtsReadingBuf : public IRVisitor { return false; } - void visit(Store* v) override { + void visit(StorePtr v) override { if (readsBuffer(v)) { reads_.push_back(v); } } - void visit(Let* v) override { + void visit(LetPtr v) override { if (readsBuffer(v)) { reads_.push_back(v); } } - void visit(Cond* v) override { + void visit(CondPtr v) override { if (readsBuffer(v)) { reads_.push_back(v); } } - void visit(AtomicAdd* v) override { + void visit(AtomicAddPtr v) override { if (readsBuffer(v)) { reads_.push_back(v); } } - Buf* target_; - std::vector reads_; + BufPtr target_; + std::vector reads_; }; // Traverses the IR to determine if a particular Var is modified within it. 
class ModifiesVarChecker : public IRVisitor { public: - ModifiesVarChecker(Var* v) : var_(v) {} + ModifiesVarChecker(VarPtr v) : var_(v) {} - static bool check(Stmt* s, Var* v) { + static bool check(StmtPtr s, VarPtr v) { ModifiesVarChecker checker(v); s->accept(&checker); return checker.found(); @@ -212,7 +212,7 @@ class ModifiesVarChecker : public IRVisitor { } private: - void visit(Store* v) override { + void visit(StorePtr v) override { if (v->buf()->base_handle() == var_) { found_ = true; return; @@ -220,7 +220,7 @@ class ModifiesVarChecker : public IRVisitor { IRVisitor::visit(v); } - void visit(AtomicAdd* v) override { + void visit(AtomicAddPtr v) override { if (v->buf()->base_handle() == var_) { found_ = true; return; @@ -228,7 +228,7 @@ class ModifiesVarChecker : public IRVisitor { IRVisitor::visit(v); } - void visit(Let* v) override { + void visit(LetPtr v) override { if (v->var() == var_) { found_ = true; return; @@ -236,7 +236,7 @@ class ModifiesVarChecker : public IRVisitor { IRVisitor::visit(v); } - void visit(For* v) override { + void visit(ForPtr v) override { if (v->var() == var_) { found_ = true; return; @@ -244,7 +244,7 @@ class ModifiesVarChecker : public IRVisitor { IRVisitor::visit(v); } - Var* var_; + VarPtr var_; bool found_{false}; }; @@ -252,26 +252,26 @@ class ModifiesVarChecker : public IRVisitor { // It creates a map of multi dim buffers and their flat verions class CreateBufferMap : public IRVisitor { public: - const std::unordered_map& getBufferMap() const { + const std::unordered_map& getBufferMap() const { return map_input_to_tensor_bufs_; } private: - void visit(Store* v) override { - auto load_node = dynamic_cast(v->value()); + void visit(StorePtr v) override { + auto load_node = to(v->value()); if (load_node) { auto t_buf = load_node->buf(); map_input_to_tensor_bufs_.emplace(t_buf->name_hint(), v->buf()); } else { - auto add_node = dynamic_cast(v->value()); - auto mul_node = dynamic_cast(v->value()); + auto add_node = to(v->value()); + auto mul_node = to(v->value()); // This means for now, v->value() can be Add or Mul TORCH_INTERNAL_ASSERT((add_node || mul_node)); map_input_to_tensor_bufs_.emplace(v->buf()->name_hint(), v->buf()); } v->value()->accept(this); } - std::unordered_map map_input_to_tensor_bufs_; + std::unordered_map map_input_to_tensor_bufs_; }; } // namespace tensorexpr diff --git a/torch/csrc/jit/tensorexpr/block_codegen.cpp b/torch/csrc/jit/tensorexpr/block_codegen.cpp index 4e5cccd..1ae3330 100644 --- a/torch/csrc/jit/tensorexpr/block_codegen.cpp +++ b/torch/csrc/jit/tensorexpr/block_codegen.cpp @@ -32,7 +32,7 @@ std::string blockDtypeCppString(const Dtype& dtype) { } } -bool BlockAnalysis::areBufsInMap(const std::unordered_set& bufs) const { +bool BlockAnalysis::areBufsInMap(const std::unordered_set& bufs) const { for (auto const& arg : bufs) { auto got = map_input_to_tensor_bufs_.find(arg->name_hint()); if (got == map_input_to_tensor_bufs_.end()) { @@ -42,7 +42,7 @@ bool BlockAnalysis::areBufsInMap(const std::unordered_set& bufs) const { return true; } -Buf* BlockAnalysis::getMultiDimBuf(Buf* buf) const { +BufPtr BlockAnalysis::getMultiDimBuf(BufPtr buf) const { auto input_ = map_input_to_tensor_bufs_.find(buf->name_hint()); if (input_ != map_input_to_tensor_bufs_.end()) { return input_->second; @@ -51,7 +51,7 @@ Buf* BlockAnalysis::getMultiDimBuf(Buf* buf) const { } } -std::string BlockAnalysis::getInputName(Buf* buf) const { +std::string BlockAnalysis::getInputName(BufPtr buf) const { auto input_ = 
map_input_to_tensor_bufs_.find(buf->name_hint()); if (input_ != map_input_to_tensor_bufs_.end()) { return input_->second->name_hint(); @@ -60,23 +60,23 @@ std::string BlockAnalysis::getInputName(Buf* buf) const { } } -void BlockAnalysis::visit(Store* v) { +void BlockAnalysis::visit(StorePtr v) { store_targets_.insert(v->buf()); v->value()->accept(this); } -void BlockAnalysis::visit(Load* v) { +void BlockAnalysis::visit(LoadPtr v) { loads_.insert(v->buf()); } -void BlockAnalysis::visit(For* v) { +void BlockAnalysis::visit(ForPtr v) { const LoopOptions& loop_options = v->loop_options(); if (loop_options.is_gpu_block_index()) { map_input_to_tensor_bufs_ = loop_options.get_buffer_mapping(); v->body()->accept(this); } else if (loop_options.is_gpu_thread_index()) { auto block_size = v->stop(); - block_size_ = dynamic_cast(block_size)->value(); + block_size_ = to(block_size)->value(); v->body()->accept(this); } else { IRVisitor::visit(v); @@ -90,26 +90,26 @@ void BlockAnalysis::visit(For* v) { // TODO: When handling fused ops d = a + b + c, the correct // way would be to mutate the expression to Block version and print. -void BlockPrinter::visit(Add* v) { +void BlockPrinter::visit(AddPtr v) { emitIndent(); os() << "add("; v->lhs()->accept(this); v->rhs()->accept(this); } -void BlockPrinter::visit(Mul* v) { +void BlockPrinter::visit(MulPtr v) { emitIndent(); os() << "mul("; v->lhs()->accept(this); v->rhs()->accept(this); } -void BlockPrinter::visit(For* v) { +void BlockPrinter::visit(ForPtr v) { const LoopOptions& loop_options = v->loop_options(); auto buf_reads = block_analysis_->loads(); auto buf_writes = block_analysis_->stores(); - std::unordered_set bufs(buf_reads.begin(), buf_reads.end()); + std::unordered_set bufs(buf_reads.begin(), buf_reads.end()); bufs.insert(buf_writes.begin(), buf_writes.end()); if (loop_options.is_gpu_block_index()) { @@ -145,7 +145,7 @@ void BlockPrinter::visit(For* v) { } } -void BlockPrinter::PrintTensorInfo(const std::unordered_set& bufs) { +void BlockPrinter::PrintTensorInfo(const std::unordered_set& bufs) { os() << "tensors {"; for (auto& buf : bufs) { os() << std::endl; @@ -178,19 +178,19 @@ void BlockPrinter::PrintTensorInfo(const std::unordered_set& bufs) { os() << "}" << std::endl << std::endl; } -void BlockPrinter::PrintArguments(const std::unordered_set& bufs) { +void BlockPrinter::PrintArguments(const std::unordered_set& bufs) { for (auto& buf : bufs) { auto multidimbuf = block_analysis_->getMultiDimBuf(buf); auto num_dims = multidimbuf->dims().size(); // The dims for the multi-dim tensors for (unsigned long d = 0; d < num_dims; d++) { - auto dim_val = dynamic_cast(multidimbuf->dim(d)); + auto dim_val = to(multidimbuf->dim(d)); this->dim_values_map.emplace(this->dim_names[d], dim_val->value()); } // The dimensions for the flattened tensors - auto val = dynamic_cast(buf->dim(0)); + auto val = to(buf->dim(0)); if (block_analysis_->is_buf_store_target(buf)) { this->dim_values_map.emplace( this->flat_dim_names[num_dims - 1], val->value()); @@ -216,7 +216,7 @@ void BlockPrinter::PrintArguments(const std::unordered_set& bufs) { os() << "}" << std::endl << std::endl; } -void BlockPrinter::PrintBufferInfo(const std::unordered_set& bufs) { +void BlockPrinter::PrintBufferInfo(const std::unordered_set& bufs) { emitIndent(); os() << "buffers {"; for (auto& read : bufs) { @@ -233,7 +233,7 @@ void BlockPrinter::PrintBufferInfo(const std::unordered_set& bufs) { os() << "}" << std::endl << std::endl; } -void BlockPrinter::PrintDistribution(const std::unordered_set& bufs) 
{ +void BlockPrinter::PrintDistribution(const std::unordered_set& bufs) { emitIndent(); os() << "distribution {" << std::endl; for (auto& buf : bufs) { @@ -247,7 +247,7 @@ void BlockPrinter::PrintDistribution(const std::unordered_set& bufs) { } void BlockPrinter::PrintLoop( - const std::unordered_set& bufs, + const std::unordered_set& bufs, bool block_idx) { emitIndent(); os() << "loop ("; @@ -265,7 +265,7 @@ void BlockPrinter::PrintLoop( } void BlockPrinter::PrintReshapeInfo( - const std::unordered_set& bufs, + const std::unordered_set& bufs, bool reverse) { for (auto& buf : bufs) { emitIndent(); @@ -279,7 +279,7 @@ void BlockPrinter::PrintReshapeInfo( } } -void BlockPrinter::PrintDMAs(const std::unordered_set& bufs) { +void BlockPrinter::PrintDMAs(const std::unordered_set& bufs) { for (auto& read : bufs) { emitIndent(); os() << "dma_in("; @@ -287,7 +287,7 @@ void BlockPrinter::PrintDMAs(const std::unordered_set& bufs) { os() << ")" << std::endl; } } -void BlockPrinter::PrintAdjustBuffers(const std::unordered_set& bufs) { +void BlockPrinter::PrintAdjustBuffers(const std::unordered_set& bufs) { for (auto& read : bufs) { emitIndent(); os() << "adjust_buffer("; @@ -296,19 +296,19 @@ void BlockPrinter::PrintAdjustBuffers(const std::unordered_set& bufs) { } } -void BlockPrinter::visit(Load* v) { +void BlockPrinter::visit(LoadPtr v) { os() << block_analysis_->getFlatInputName(v->buf()) << ".buffer, "; } -void BlockPrinter::visit(Store* v) { +void BlockPrinter::visit(StorePtr v) { emitIndent(); os() << *v->value() << block_analysis_->getFlatInputName(v->buf()) << ".tensor)" << std::endl; } -void BlockPrinter::visit(Block* v) { +void BlockPrinter::visit(BlockPtr v) { os() << "{" << std::endl; indent_++; - for (Stmt* s : v->stmts()) { + for (StmtPtr s : v->stmts()) { s->accept(this); } indent_--; @@ -329,13 +329,13 @@ void BlockCodeGen::Initialize() { block_analysis_ = std::make_unique(); printer_ = std::make_unique(&oss_, block_analysis_.get()); - Stmt* stmt_v = stmt(); + StmtPtr stmt_v = stmt(); stmt_v->accept(block_analysis_.get()); auto buf_reads = block_analysis_->loads(); auto buf_writes = block_analysis_->stores(); // Ensure all Bufs in reads/writes are in the map - std::unordered_set bufs(buf_reads.begin(), buf_reads.end()); + std::unordered_set bufs(buf_reads.begin(), buf_reads.end()); bufs.insert(buf_writes.begin(), buf_writes.end()); if (!block_analysis_->areBufsInMap(bufs)) { throw std::runtime_error("BlockCodeGen: Entry not in input/Buffer map"); diff --git a/torch/csrc/jit/tensorexpr/block_codegen.h b/torch/csrc/jit/tensorexpr/block_codegen.h index 5c8f0f6..963c93c 100644 --- a/torch/csrc/jit/tensorexpr/block_codegen.h +++ b/torch/csrc/jit/tensorexpr/block_codegen.h @@ -20,15 +20,15 @@ namespace tensorexpr { // A class that analyzes the given program relevant for Block backend. 
class BlockAnalysis : public IRVisitor { public: - bool is_buf_store_target(Buf* buf) const { + bool is_buf_store_target(BufPtr buf) const { return store_targets_.count(buf) > 0; } - const std::unordered_set& loads() const { + const std::unordered_set& loads() const { return loads_; } - const std::unordered_set& stores() const { + const std::unordered_set& stores() const { return store_targets_; } @@ -36,28 +36,28 @@ class BlockAnalysis : public IRVisitor { return block_size_; } - bool areBufsInMap(const std::unordered_set& bufs) const; + bool areBufsInMap(const std::unordered_set& bufs) const; - Buf* getMultiDimBuf(Buf* buf) const; + BufPtr getMultiDimBuf(BufPtr buf) const; - std::string getInputName(Buf* buf) const; + std::string getInputName(BufPtr buf) const; - std::string getFlatInputName(Buf* buf) const { + std::string getFlatInputName(BufPtr buf) const { return getInputName(buf) + "_flat"; } - std::unordered_map getBufferMap() const { + std::unordered_map getBufferMap() const { return map_input_to_tensor_bufs_; } private: - void visit(Store* v) override; - void visit(Load* v) override; - void visit(For* v) override; + void visit(StorePtr v) override; + void visit(LoadPtr v) override; + void visit(ForPtr v) override; - std::unordered_map map_input_to_tensor_bufs_; - std::unordered_set store_targets_; - std::unordered_set loads_; + std::unordered_map map_input_to_tensor_bufs_; + std::unordered_set store_targets_; + std::unordered_set loads_; int block_size_ = 32; }; @@ -75,30 +75,30 @@ class BlockPrinter : public IRPrinter { std::unordered_map dim_values_map; std::vector dim_names = {"N", "H", "W", "C"}; std::vector flat_dim_names = {"N", "NH", "NHW", "NHWC"}; - void PrintTensorInfo(const std::unordered_set& bufs); - void PrintArguments(const std::unordered_set& bufs); - void PrintBufferInfo(const std::unordered_set& bufs); - void PrintDistribution(const std::unordered_set& bufs); - void PrintLoop(const std::unordered_set& bufs, bool block_idx = true); + void PrintTensorInfo(const std::unordered_set& bufs); + void PrintArguments(const std::unordered_set& bufs); + void PrintBufferInfo(const std::unordered_set& bufs); + void PrintDistribution(const std::unordered_set& bufs); + void PrintLoop(const std::unordered_set& bufs, bool block_idx = true); void PrintReshapeInfo( - const std::unordered_set& bufs, + const std::unordered_set& bufs, bool reverse = false); - void PrintDMAs(const std::unordered_set& bufs); - void PrintAdjustBuffers(const std::unordered_set& bufs); - - void visit(For* v) override; - void visit(Load* v) override; - void visit(Store* v) override; - void visit(Block* v) override; - void visit(Add* v) override; - void visit(Mul* v) override; + void PrintDMAs(const std::unordered_set& bufs); + void PrintAdjustBuffers(const std::unordered_set& bufs); + + void visit(ForPtr v) override; + void visit(LoadPtr v) override; + void visit(StorePtr v) override; + void visit(BlockPtr v) override; + void visit(AddPtr v) override; + void visit(MulPtr v) override; }; class TORCH_API BlockCodeGen : public CodeGen { public: template /* implicit */ - BlockCodeGen(Stmt* stmt, Ts... ts) + BlockCodeGen(StmtPtr stmt, Ts... 
ts) : CodeGen( stmt, std::vector({BufferArg(ts)...}), @@ -107,7 +107,7 @@ class TORCH_API BlockCodeGen : public CodeGen { } BlockCodeGen( - Stmt* stmt, + StmtPtr stmt, const std::vector& buffer_args, at::Device device = at::Device(at::kCPU), const std::string& kernel_func_name = "func") diff --git a/torch/csrc/jit/tensorexpr/bounds_inference.cpp b/torch/csrc/jit/tensorexpr/bounds_inference.cpp index cf66f23..55dbacf 100644 --- a/torch/csrc/jit/tensorexpr/bounds_inference.cpp +++ b/torch/csrc/jit/tensorexpr/bounds_inference.cpp @@ -19,7 +19,7 @@ using namespace analysis; template BoundsInfo mergeTensorAccesses( const Container& accesses, - const std::unordered_map& varToBuf, + const std::unordered_map& varToBuf, bool distinctAccessKinds) { BoundsInfo ret; for (auto& access : accesses) { @@ -30,7 +30,7 @@ BoundsInfo mergeTensorAccesses( auto vtbIt = varToBuf.find(access->var()); TORCH_INTERNAL_ASSERT(vtbIt != varToBuf.end()); - Buf* buf = vtbIt->second; + BufPtr buf = vtbIt->second; std::vector& infos = ret[buf]; bool added = false; @@ -42,9 +42,9 @@ BoundsInfo mergeTensorAccesses( TORCH_INTERNAL_ASSERT(TABI.stop.size() == access->bounds().size()); for (size_t i = 0; i < TABI.start.size(); ++i) { TABI.start[i] = IRSimplifier::simplify( - new Min(TABI.start[i], access->bounds()[i].start, true)); + alloc(TABI.start[i], access->bounds()[i].start, true)); TABI.stop[i] = IRSimplifier::simplify( - new Max(TABI.stop[i], access->bounds()[i].end, true)); + alloc(TABI.stop[i], access->bounds()[i].end, true)); added = true; if (kind != TABI.kind) { @@ -70,27 +70,27 @@ BoundsInfo mergeTensorAccesses( return ret; } -std::unordered_map getAllBufs(Stmt* s) { - std::unordered_map varToBuf; +std::unordered_map getAllBufs(StmtPtr s) { + std::unordered_map varToBuf; auto bufs = NodeFinder::find(s); - for (auto* b : bufs) { + for (auto b : bufs) { varToBuf[b->base_handle()] = b; } return varToBuf; } -std::unordered_map getAllBufs(Expr* e) { - std::unordered_map varToBuf; +std::unordered_map getAllBufs(ExprPtr e) { + std::unordered_map varToBuf; auto bufs = NodeFinder::find(e); - for (auto* b : bufs) { + for (auto b : bufs) { varToBuf[b->base_handle()] = b; } return varToBuf; } -BoundsInfo inferBounds(Stmt* s, bool distinctAccessKinds) { +BoundsInfo inferBounds(StmtPtr s, bool distinctAccessKinds) { auto varToBuf = getAllBufs(s); MemDependencyChecker checker; @@ -102,7 +102,7 @@ BoundsInfo inferBounds(Stmt* s, bool distinctAccessKinds) { BoundsInfo getInferredBounds( MemDependencyChecker& analyzer, - Stmt* s, + StmtPtr s, bool distinctAccessKinds) { return mergeTensorAccesses( analyzer.accessesWithin(s), getAllBufs(s), distinctAccessKinds); @@ -110,7 +110,7 @@ BoundsInfo getInferredBounds( BoundsInfo getInferredBounds( MemDependencyChecker& analyzer, - Expr* e, + ExprPtr e, bool distinctAccessKinds) { return mergeTensorAccesses( analyzer.accessesWithin(e), getAllBufs(e), distinctAccessKinds); @@ -157,10 +157,10 @@ void printBoundsInfo(const BoundsInfo& v) { std::cerr << "}\n"; } -std::vector getBoundExtents( +std::vector getBoundExtents( const std::vector& infos) { - std::vector starts; - std::vector stops; + std::vector starts; + std::vector stops; // Find the safe size of the temprorary buffer by determining the outer // extents of a union of all bounds. 
@@ -170,21 +170,22 @@ std::vector getBoundExtents( starts.push_back(p.start[i]); } else { starts[i] = - IRSimplifier::simplify(new Min(starts[i], p.start[i], true)); + IRSimplifier::simplify(alloc(starts[i], p.start[i], true)); } if (stops.size() <= i) { stops.push_back(p.stop[i]); } else { - stops[i] = IRSimplifier::simplify(new Max(stops[i], p.stop[i], true)); + stops[i] = + IRSimplifier::simplify(alloc(stops[i], p.stop[i], true)); } } } - std::vector extents; + std::vector extents; for (size_t i = 0; i < starts.size(); ++i) { - Expr* dim = IRSimplifier::simplify( - new Add(new Sub(stops[i], starts[i]), new IntImm(1))); + ExprPtr dim = IRSimplifier::simplify( + alloc(alloc(stops[i], starts[i]), alloc(1))); extents.push_back(dim); } @@ -210,7 +211,7 @@ BoundSet convertBounds( BoundSet convertBounds( BoundsInfo& bounds, - Buf* buf, + BufPtr buf, TensorAccessKind filter = kMutate) { auto it = bounds.find(buf); if (it == bounds.end()) { @@ -222,8 +223,8 @@ BoundSet convertBounds( HazardKind getPotentialHazards( MemDependencyChecker& analyzer, - Stmt* A, - Stmt* B) { + StmtPtr A, + StmtPtr B) { BoundsInfo aBounds = getInferredBounds(analyzer, A, true); BoundsInfo bBounds = getInferredBounds(analyzer, B, true); @@ -231,7 +232,7 @@ HazardKind getPotentialHazards( BoundSet aReads; for (auto& pair : bBounds) { - Buf* buf = pair.first; + BufPtr buf = pair.first; if (aBounds.find(buf) == aBounds.end()) { continue; } @@ -302,7 +303,7 @@ bool hasConflictingOverlap( const BoundsInfo& bBounds, TensorAccessKind aFilter = kMutate, TensorAccessKind bFilter = kMutate) { - using IndexBoundsInfo = std::unordered_map>; + using IndexBoundsInfo = std::unordered_map>; IndexBoundsInfo aIndexBoundsInfo; for (auto& aBound : aBounds) { aIndexBoundsInfo[aBound.first] = getIndexBounds(aBound.second, aFilter); @@ -340,8 +341,8 @@ bool hasConflictingOverlap( bool hasConflictingOverlap( analysis::MemDependencyChecker& analyzer, - Stmt* A, - Stmt* B) { + StmtPtr A, + StmtPtr B) { BoundsInfo aBounds = getInferredBounds(analyzer, A, true); BoundsInfo bBounds = getInferredBounds(analyzer, B, true); return hasConflictingOverlap(aBounds, bBounds); @@ -349,8 +350,8 @@ bool hasConflictingOverlap( bool isOverlapping( analysis::MemDependencyChecker& analyzer, - Store* S1, - Store* S2) { + StorePtr S1, + StorePtr S2) { BoundsInfo s1Bounds = getInferredBounds(analyzer, S1, true); BoundsInfo s2Bounds = getInferredBounds(analyzer, S2, true); return hasConflictingOverlap(s1Bounds, s2Bounds, kStore, kStore); @@ -358,8 +359,8 @@ bool isOverlapping( bool isOverlapping( analysis::MemDependencyChecker& analyzer, - Store* S, - Load* L) { + StorePtr S, + LoadPtr L) { BoundsInfo sBounds = getInferredBounds(analyzer, S, true); BoundsInfo lBounds = getInferredBounds(analyzer, L, true); return hasConflictingOverlap(sBounds, lBounds, kStore, kLoad); diff --git a/torch/csrc/jit/tensorexpr/bounds_inference.h b/torch/csrc/jit/tensorexpr/bounds_inference.h index 26821e2..8defed2 100644 --- a/torch/csrc/jit/tensorexpr/bounds_inference.h +++ b/torch/csrc/jit/tensorexpr/bounds_inference.h @@ -20,29 +20,29 @@ enum C10_API_ENUM TensorAccessKind { kLoad, kStore, kMutate }; // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) struct TORCH_API TensorAccessBoundsInfo { TensorAccessKind kind; - std::vector start; - std::vector stop; + std::vector start; + std::vector stop; }; using BoundsInfo = - std::unordered_map>; + std::unordered_map>; -TORCH_API BoundsInfo inferBounds(Stmt* s, bool distinctAccessKinds = true); +TORCH_API BoundsInfo inferBounds(StmtPtr 
s, bool distinctAccessKinds = true); // Bounds inference caching the analysis. The MemDependencyChecker must already // have been run. TORCH_API BoundsInfo getInferredBounds( analysis::MemDependencyChecker& analyzer, - Stmt* s, + StmtPtr s, bool distinctAccessKinds = true); TORCH_API BoundsInfo getInferredBounds( analysis::MemDependencyChecker& analyzer, - Expr* e, + ExprPtr e, bool distinctAccessKinds = true); TORCH_API void printBoundsInfo(const BoundsInfo& v); -TORCH_API std::vector getBoundExtents( +TORCH_API std::vector getBoundExtents( const std::vector& infos); // The kind of dependency found, in increasing order of exclusivity. @@ -52,26 +52,28 @@ enum class HazardKind { WriteAfterWrite, NoDependency, }; -TORCH_API HazardKind -getPotentialHazards(analysis::MemDependencyChecker& analyzer, Stmt* A, Stmt* B); +TORCH_API HazardKind getPotentialHazards( + analysis::MemDependencyChecker& analyzer, + StmtPtr A, + StmtPtr B); // Returns true if there is a conflicting overlap between accesses in // statements A and B. A conflicting overlap is an overlap in buffer accesses // where at least one of the accesses is a Store. TORCH_API bool hasConflictingOverlap( analysis::MemDependencyChecker& analyzer, - Stmt* A, - Stmt* B); + StmtPtr A, + StmtPtr B); // Same as above, between accesses in stores S1 and S2. TORCH_API bool isOverlapping( analysis::MemDependencyChecker& analyzer, - Store* S1, - Store* S2); + StorePtr S1, + StorePtr S2); // Same as above, between accesses in store S and load L. TORCH_API bool isOverlapping( analysis::MemDependencyChecker& analyzer, - Store* S, - Load* L); + StorePtr S, + LoadPtr L); } // namespace tensorexpr } // namespace jit diff --git a/torch/csrc/jit/tensorexpr/bounds_overlap.cpp b/torch/csrc/jit/tensorexpr/bounds_overlap.cpp index e980134..4ac5c6b 100644 --- a/torch/csrc/jit/tensorexpr/bounds_overlap.cpp +++ b/torch/csrc/jit/tensorexpr/bounds_overlap.cpp @@ -14,8 +14,8 @@ OverlapKind boundOverlap(Bound a, Bound b) { return ContainedOrEqual; } - Expr* lowDiff = IRSimplifier::simplify(new Sub(a.start, b.end)); - Expr* highDiff = IRSimplifier::simplify(new Sub(b.start, a.end)); + ExprPtr lowDiff = IRSimplifier::simplify(alloc(a.start, b.end)); + ExprPtr highDiff = IRSimplifier::simplify(alloc(b.start, a.end)); if (lowDiff->isConstant() && highDiff->isConstant()) { int low = immediateAs(lowDiff); @@ -26,8 +26,8 @@ OverlapKind boundOverlap(Bound a, Bound b) { } } - Expr* diff_start = IRSimplifier::simplify(new Sub(b.start, a.start)); - Expr* diff_end = IRSimplifier::simplify(new Sub(b.end, a.end)); + ExprPtr diff_start = IRSimplifier::simplify(alloc(b.start, a.start)); + ExprPtr diff_end = IRSimplifier::simplify(alloc(b.end, a.end)); // If one side fully encloses the other, they're adjacent. 
if (diff_start->isConstant() && diff_end->isConstant()) { @@ -68,8 +68,8 @@ Bound flattenBounds(const IndexBounds& a) { Bound ret = a[0]; for (size_t i = 1; i < a.size(); ++i) { - ret.start = new Mul(ret.start, a[i].start); - ret.end = new Mul(ret.end, a[i].end); + ret.start = alloc<Mul>(ret.start, a[i].start); + ret.end = alloc<Mul>(ret.end, a[i].end); } ret.start = IRSimplifier::simplify(ret.start); @@ -122,25 +122,25 @@ std::vector<Bound> subtractBound(Bound a, Bound b, OverlapKind overlap) { return {a}; } - Expr* lowDiff = IRSimplifier::simplify(new Sub(b.start, a.start)); - Expr* highDiff = IRSimplifier::simplify(new Sub(b.end, a.end)); + ExprPtr lowDiff = IRSimplifier::simplify(alloc<Sub>(b.start, a.start)); + ExprPtr highDiff = IRSimplifier::simplify(alloc<Sub>(b.end, a.end)); // If the diff has only a single var, we can try to guess sign. if (!lowDiff->isConstant()) { auto vars = VarFinder::find(lowDiff); if (vars.size() == 1) { - lowDiff = IRSimplifier::simplify(new Sub( - SubstituteInClone(b.start, {{*vars.begin(), new IntImm(1)}}), - SubstituteInClone(a.start, {{*vars.begin(), new IntImm(1)}}))); + lowDiff = IRSimplifier::simplify(alloc<Sub>( + SubstituteInClone(b.start, {{*vars.begin(), alloc<IntImm>(1)}}), + SubstituteInClone(a.start, {{*vars.begin(), alloc<IntImm>(1)}}))); } } if (!highDiff->isConstant()) { auto vars = VarFinder::find(highDiff); if (vars.size() == 1) { - highDiff = IRSimplifier::simplify(new Sub( - SubstituteInClone(b.end, {{*vars.begin(), new IntImm(1)}}), - SubstituteInClone(a.end, {{*vars.begin(), new IntImm(1)}}))); + highDiff = IRSimplifier::simplify(alloc<Sub>( + SubstituteInClone(b.end, {{*vars.begin(), alloc<IntImm>(1)}}), + SubstituteInClone(a.end, {{*vars.begin(), alloc<IntImm>(1)}}))); } } @@ -157,11 +157,12 @@ std::vector<Bound> subtractBound(Bound a, Bound b, OverlapKind overlap) { if (hasHead) { res.emplace_back( - a.start, IRSimplifier::simplify(new Sub(b.start, new IntImm(1)))); + a.start, IRSimplifier::simplify(alloc<Sub>(b.start, alloc<IntImm>(1)))); } if (hasTail) { - Expr* tailStart = IRSimplifier::simplify(new Add(b.end, new IntImm(1))); + ExprPtr tailStart = + IRSimplifier::simplify(alloc<Add>(b.end, alloc<IntImm>(1))); res.emplace_back(tailStart, a.end); } diff --git a/torch/csrc/jit/tensorexpr/bounds_overlap.h b/torch/csrc/jit/tensorexpr/bounds_overlap.h index dda6c20..482b786 100644 --- a/torch/csrc/jit/tensorexpr/bounds_overlap.h +++ b/torch/csrc/jit/tensorexpr/bounds_overlap.h @@ -13,8 +13,8 @@ namespace analysis { // A simple class containing the start and end of a range in a single dimension. struct TORCH_API Bound { - Expr* start{nullptr}; - Expr* end{nullptr}; + ExprPtr start{nullptr}; + ExprPtr end{nullptr}; // This stores whether or not the start and end of this Bound have previously // been swapped.
This occurs when the bound is in a loop with a negative @@ -22,7 +22,7 @@ struct TORCH_API Bound { bool swapped{false}; Bound() = default; - Bound(Expr* s, Expr* e) : start(s), end(e) {} + Bound(ExprPtr s, ExprPtr e) : start(s), end(e) {} void print() const { std::cout << "(" << *start << ", " << *end << ")"; @@ -44,7 +44,7 @@ struct TORCH_API Bound { struct BoundHash { size_t operator()(const Bound& b) const { - return std::hash()(b.start) ^ std::hash()(b.end); + return std::hash()(b.start) ^ std::hash()(b.end); } }; diff --git a/torch/csrc/jit/tensorexpr/codegen.cpp b/torch/csrc/jit/tensorexpr/codegen.cpp index e312f81..0bbc337 100644 --- a/torch/csrc/jit/tensorexpr/codegen.cpp +++ b/torch/csrc/jit/tensorexpr/codegen.cpp @@ -35,7 +35,7 @@ void RegisterCodeGenList::AddStmtFactoryMethod( std::unique_ptr CreateCodeGen( const std::string& name, - Stmt* stmt, + StmtPtr stmt, const std::vector& params, at::Device device, const std::string& kernel_func_name) { @@ -44,7 +44,7 @@ std::unique_ptr CreateCodeGen( return method(stmt, params, device, kernel_func_name); } -Expr* GenericIntrinsicsExpander::mutate(Intrinsics* v) { +ExprPtr GenericIntrinsicsExpander::mutate(IntrinsicsPtr v) { if (v->op_type() == kSigmoid) { auto x = v->param(0)->accept_mutator(this); auto one = expr_to_vec( diff --git a/torch/csrc/jit/tensorexpr/codegen.h b/torch/csrc/jit/tensorexpr/codegen.h index d303ad7..77ba8e1 100644 --- a/torch/csrc/jit/tensorexpr/codegen.h +++ b/torch/csrc/jit/tensorexpr/codegen.h @@ -18,12 +18,12 @@ class TORCH_API CodeGen { template // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) - CodeGen(Stmt* stmt, Ts... ts) + CodeGen(StmtPtr stmt, Ts... ts) : stmt_(stmt), buffer_args_({BufferArg(ts)...}) {} // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) CodeGen( - Stmt* stmt, + StmtPtr stmt, std::vector buffer_args, at::Device device = at::kCPU, std::string kernel_func_name = "func") @@ -34,11 +34,11 @@ class TORCH_API CodeGen { virtual ~CodeGen() = default; - Stmt* stmt() const { + StmtPtr stmt() const { return stmt_; } - void set_stmt(Stmt* s) { + void set_stmt(StmtPtr s) { stmt_ = s; } @@ -95,7 +95,7 @@ class TORCH_API CodeGen { static void* argToPtr(const BufferArg& bufferArg, const CallArg& callArg); private: - Stmt* stmt_; + StmtPtr stmt_; std::vector buffer_args_; at::Device device_ = at::kCPU; std::string kernel_func_name_ = "func"; @@ -108,11 +108,11 @@ class CodeGen::BufferArg { BufferArg(const VarHandle& var) : var_(var.node()), isVar_(true) {} BufferArg(const BufHandle& buf) : buf_(buf.node()) {} - Var* var() const { + VarPtr var() const { return isVar_ ? 
var_ : buf_->base_handle(); } - Buf* buf() const { + BufPtr buf() const { return buf_; } @@ -125,8 +125,8 @@ class CodeGen::BufferArg { } private: - Var* var_ = nullptr; - Buf* buf_ = nullptr; + VarPtr var_ = nullptr; + BufPtr buf_ = nullptr; bool isVar_ = false; }; @@ -177,7 +177,7 @@ class RegisterCodeGenList { } using StmtFactoryMethod = std::function( - Stmt* stmt, + StmtPtr stmt, const std::vector&, at::Device device, const std::string& kernel_func_name)>; @@ -205,7 +205,7 @@ class RegisterCodeGen { RegisterCodeGenList& codegen_list = RegisterCodeGenList::GetInstance(); codegen_list.AddStmtFactoryMethod( name, - [](Stmt* stmt, + [](StmtPtr stmt, const std::vector& params, at::Device device, const std::string& kernel_func_name) { @@ -219,14 +219,14 @@ class RegisterCodeGen { TORCH_API std::unique_ptr CreateCodeGen( const std::string& name, - Stmt* stmt, + StmtPtr stmt, const std::vector& params, at::Device device = at::kCPU, const std::string& kernel_func_name = "func"); class TORCH_API GenericIntrinsicsExpander : public IRMutator { protected: - Expr* mutate(Intrinsics* v) override; + ExprPtr mutate(IntrinsicsPtr v) override; }; } // namespace tensorexpr diff --git a/torch/csrc/jit/tensorexpr/cpp_codegen.cpp b/torch/csrc/jit/tensorexpr/cpp_codegen.cpp index 562ce69..39a5615 100644 --- a/torch/csrc/jit/tensorexpr/cpp_codegen.cpp +++ b/torch/csrc/jit/tensorexpr/cpp_codegen.cpp @@ -4,12 +4,12 @@ namespace torch { namespace jit { namespace tensorexpr { -void CppPrinter::visit(Allocate* alloc) { +void CppPrinter::visit(AllocatePtr alloc) { constexpr size_t kAllocOnStackThresholdSize = 512; size_t size = 1; for (auto dim : alloc->dims()) { - IntImm* v = dynamic_cast(dim); + IntImmPtr v = to(dim); if (v) { size *= v->value(); } else { @@ -30,8 +30,8 @@ void CppPrinter::visit(Allocate* alloc) { } } -void CppPrinter::visit(Free* free) { - Var* var = free->buffer_var(); +void CppPrinter::visit(FreePtr free) { + VarPtr var = free->buffer_var(); if (allocated_on_heap_.count(var)) { emitIndent(); os() << "free(" << name_manager()->get_unique_name(var) << ");" diff --git a/torch/csrc/jit/tensorexpr/cpp_codegen.h b/torch/csrc/jit/tensorexpr/cpp_codegen.h index 45ccabd..1cf1565 100644 --- a/torch/csrc/jit/tensorexpr/cpp_codegen.h +++ b/torch/csrc/jit/tensorexpr/cpp_codegen.h @@ -14,11 +14,11 @@ class TORCH_API CppPrinter : public IRPrinter { explicit CppPrinter(std::ostream* os) : IRPrinter(*os) {} using IRPrinter::visit; - void visit(Allocate*) override; - void visit(Free*) override; + void visit(AllocatePtr) override; + void visit(FreePtr) override; private: - std::unordered_set allocated_on_heap_; + std::unordered_set allocated_on_heap_; }; } // namespace tensorexpr diff --git a/torch/csrc/jit/tensorexpr/cuda_codegen.cpp b/torch/csrc/jit/tensorexpr/cuda_codegen.cpp index b630a40..2d00b1e 100644 --- a/torch/csrc/jit/tensorexpr/cuda_codegen.cpp +++ b/torch/csrc/jit/tensorexpr/cuda_codegen.cpp @@ -21,7 +21,7 @@ namespace tensorexpr { // TODO: move this to a more shared place. 
class ScopedVarName { public: - ScopedVarName(VarNameMap* mapping, Var* var, const std::string& name) + ScopedVarName(VarNameMap* mapping, VarPtr var, const std::string& name) : mapping_(mapping), var_(var) { auto iter = mapping->find(var); if (iter != mapping->end()) { @@ -30,7 +30,7 @@ class ScopedVarName { mapping->insert(std::make_pair(var, name)); } - ScopedVarName(UniqueNameManager* manager, Var* var, const std::string& name) + ScopedVarName(UniqueNameManager* manager, VarPtr var, const std::string& name) : ScopedVarName(&manager->unique_name_mapping_, var, name) {} ScopedVarName(const ScopedVarName&) = delete; @@ -42,11 +42,11 @@ class ScopedVarName { private: VarNameMap* mapping_ = nullptr; - Var* var_ = nullptr; + VarPtr var_ = nullptr; }; -static int as_int(Expr* expr) { - auto v = dynamic_cast(expr); +static int as_int(ExprPtr expr) { + auto v = to(expr); if (!v) { throw malformed_input( "cuda_codegen: non Int expr interpreted as int", expr); @@ -55,7 +55,7 @@ static int as_int(Expr* expr) { return v->value(); } -static bool is_zero(Expr* expr) { +static bool is_zero(ExprPtr expr) { return as_int(expr) == 0; } @@ -120,17 +120,17 @@ std::string CudaPrinter::dtypeToCppString(const Dtype& dtype) { } } -void CudaAnalysis::visit(Free* v) { +void CudaAnalysis::visit(FreePtr v) { if (thread_local_bufs_.count(v->buffer_var()) == 0 && cross_block_bufs_.count(v->buffer_var()) == 0) { throw std::runtime_error("Global free not supported yet"); } } -void CudaAnalysis::visit(Allocate* v) { - Stmt* p = v->get_parent(); +void CudaAnalysis::visit(AllocatePtr v) { + StmtPtr p = v->get_parent(); while (p) { - For* for_v = dynamic_cast(p); + ForPtr for_v = to(p); if (for_v) { // NOLINTNEXTLINE(bugprone-branch-clone) if (for_v->loop_options().is_gpu_block_index()) { @@ -148,7 +148,7 @@ void CudaAnalysis::visit(Allocate* v) { throw std::runtime_error("Global alloc not supported yet"); } -void CudaAnalysis::visit(For* v) { +void CudaAnalysis::visit(ForPtr v) { // Recurse first. 
v->body()->accept(this); @@ -158,7 +158,7 @@ void CudaAnalysis::visit(For* v) { if (gpu_block_index >= 3) { throw std::runtime_error("support only 3D gpu_block_index"); } - Expr* prev = nullptr; + ExprPtr prev = nullptr; // NOLINTNEXTLINE(clang-diagnostic-sign-compare) // NOLINTNEXTLINE(bugprone-branch-clone) if (gpu_block_extents_.size() <= gpu_block_index) { @@ -181,14 +181,14 @@ void CudaAnalysis::visit(For* v) { gpu_block_extents_[gpu_block_index] = v->stop(); } else { gpu_block_extents_[gpu_block_index] = - IRSimplifier::simplify(new Max(prev, v->stop(), true)); + IRSimplifier::simplify(alloc(prev, v->stop(), true)); } } else if (loop_options.is_gpu_thread_index()) { int gpu_thread_index = loop_options.gpu_thread_index(); if (gpu_thread_index >= 3) { throw std::runtime_error("support only 3D gpu_thread_index"); } - Expr* prev = nullptr; + ExprPtr prev = nullptr; // NOLINTNEXTLINE(clang-diagnostic-sign-compare) // NOLINTNEXTLINE(bugprone-branch-clone) if (gpu_thread_extents_.size() <= gpu_thread_index) { @@ -211,18 +211,18 @@ void CudaAnalysis::visit(For* v) { gpu_thread_extents_[gpu_thread_index] = v->stop(); } else { gpu_thread_extents_[gpu_thread_index] = - IRSimplifier::simplify(new Max(prev, v->stop(), true)); + IRSimplifier::simplify(alloc(prev, v->stop(), true)); } } } -void CudaPrinter::print_flat_alloc(Allocate* alloc) { +void CudaPrinter::print_flat_alloc(AllocatePtr alloc) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - std::vector dims = alloc->dims(); + std::vector dims = alloc->dims(); // TODO: this should be merged with the storage flattener. int64_t flat_size = 1; for (auto dim : dims) { - IntImm* dim_i = dynamic_cast(dim); + IntImmPtr dim_i = to(dim); if (dim_i) { flat_size *= dim_i->value(); } else { @@ -233,7 +233,7 @@ void CudaPrinter::print_flat_alloc(Allocate* alloc) { << "[" << flat_size << "];" << std::endl; } -void CudaPrinter::visit(Allocate* v) { +void CudaPrinter::visit(AllocatePtr v) { // TODO: handle dynamic shapes here. if (cuda_analysis_->cross_block_bufs().count(v->buffer_var()) != 0) { emitIndent(); @@ -251,15 +251,15 @@ void CudaPrinter::visit(Allocate* v) { throw std::runtime_error("Encountered Alloc not local to block or thread"); } -void CudaPrinter::visit(Free* v) { +void CudaPrinter::visit(FreePtr v) { // do nothing } -void CudaPrinter::visit(For* v) { +void CudaPrinter::visit(ForPtr v) { IRPrinter::visit(v); } -void CudaPrinter::visit(Cast* v) { +void CudaPrinter::visit(CastPtr v) { if (v->dtype().scalar_type() == ScalarType::Half) { os() << "__float2half("; v->src_value()->accept(this); @@ -278,7 +278,7 @@ void CudaPrinter::visit(Cast* v) { os() << ")"; } -void CudaPrinter::visit(Intrinsics* v) { +void CudaPrinter::visit(IntrinsicsPtr v) { if (v->op_type() == IntrinsicsOp::kRand) { os() << "Uint32ToFloat(" << *rand_func_ << "())"; return; @@ -314,11 +314,11 @@ void CudaPrinter::visit(Intrinsics* v) { os() << ")"; } -void CudaPrinter::visit(ExternalCall* v) { +void CudaPrinter::visit(ExternalCallPtr v) { throw unimplemented_lowering(v); } -void CudaPrinter::visit(Load* v) { +void CudaPrinter::visit(LoadPtr v) { // TODO: find a better metric in using ldg or not. Support different dtypes. // Detects whether the load target is also a store target. // TODO: this is currently too wide. It detects whether a store-target @@ -343,9 +343,9 @@ void CudaPrinter::visit(Load* v) { } // TODO: maybe this should be a more shared location? -// TODO: investigate how "Expr*" can be implicitly converted to "ExprHandle" as -// a bool. 
-static bool CheckEqual(Expr* lhs, Expr* rhs) { +// TODO: investigate how "ExprPtr" can be implicitly converted to "ExprHandle" +// as a bool. +static bool CheckEqual(ExprPtr lhs, ExprPtr rhs) { // The fast path. Checks if the pointers are the same. if (lhs == rhs) { return true; @@ -359,11 +359,11 @@ class AtomicAddFuser : public IRMutator { public: // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) AtomicAddFuser( - const std::unordered_set& thread_local_bufs, + const std::unordered_set& thread_local_bufs, const GPUMetaVarRewriter& metavars) : thread_local_bufs_(thread_local_bufs) { - const std::vector& block_extents = metavars.gpu_block_extents(); - const std::vector& block_vars = metavars.gpu_block_vars(); + const std::vector& block_extents = metavars.gpu_block_extents(); + const std::vector& block_vars = metavars.gpu_block_vars(); for (size_t i = 0; i < block_extents.size(); ++i) { MetaVarExtent extent{block_extents[i], false}; if (extent.expr->isConstant() && immediateEquals(extent.expr, 1)) { @@ -374,8 +374,8 @@ class AtomicAddFuser : public IRMutator { metavars_[block_vars[i]] = extent; } - const std::vector& thread_extents = metavars.gpu_thread_extents(); - const std::vector& thread_vars = metavars.gpu_thread_vars(); + const std::vector& thread_extents = metavars.gpu_thread_extents(); + const std::vector& thread_vars = metavars.gpu_thread_vars(); for (size_t i = 0; i < thread_extents.size(); ++i) { MetaVarExtent extent{thread_extents[i], false}; if (extent.expr->isConstant() && immediateEquals(extent.expr, 1)) { @@ -387,9 +387,9 @@ class AtomicAddFuser : public IRMutator { } } - Stmt* mutate(Store* v) override { - Buf* buf = v->buf(); - Store* orig = const_cast(v); // NOLINT + StmtPtr mutate(StorePtr v) override { + BufPtr buf = v->buf(); + StorePtr orig = const_cast(v); // NOLINT // Thread locals never need to be atomic. if (thread_local_bufs_.count(buf->base_handle()) != 0) { @@ -400,11 +400,11 @@ class AtomicAddFuser : public IRMutator { if (dtype != ScalarType::Float && dtype != ScalarType::Double) { return orig; } - Add* add_v = dynamic_cast(v->value()); + AddPtr add_v = to(v->value()); if (!add_v) { return orig; } - Load* load_v = dynamic_cast(add_v->lhs()); + LoadPtr load_v = to(add_v->lhs()); if (!load_v) { return orig; } @@ -422,9 +422,9 @@ class AtomicAddFuser : public IRMutator { // TODO: this checks that the metavars occur directly as an index, but this // is pessimistic, blockIdx.x + 1 is fine too if there is no overlapping. 
// NOLINTNEXTLINE(cppcoreguidelines-init-variables) - std::unordered_set vars_to_find = nontrivial_metavars_; - for (Expr* e : v->indices()) { - if (Var* v = dynamic_cast(e)) { + std::unordered_set vars_to_find = nontrivial_metavars_; + for (ExprPtr e : v->indices()) { + if (VarPtr v = to(e)) { vars_to_find.erase(v); } } @@ -434,20 +434,20 @@ class AtomicAddFuser : public IRMutator { return orig; } - return new AtomicAdd(buf, v->indices(), add_v->rhs()); + return alloc(buf, v->indices(), add_v->rhs()); } private: - const std::unordered_set& thread_local_bufs_; + const std::unordered_set& thread_local_bufs_; struct MetaVarExtent { - Expr* expr{nullptr}; + ExprPtr expr{nullptr}; bool trivial{false}; }; - std::unordered_map metavars_; - std::unordered_set nontrivial_metavars_; + std::unordered_map metavars_; + std::unordered_set nontrivial_metavars_; }; -void CudaPrinter::visit(Store* v) { +void CudaPrinter::visit(StorePtr v) { emitIndent(); if (v->indices().empty()) { os() << *v->base_handle() << " = "; @@ -458,7 +458,7 @@ void CudaPrinter::visit(Store* v) { os() << std::endl; } -void CudaPrinter::visit(AtomicAdd* v) { +void CudaPrinter::visit(AtomicAddPtr v) { emitIndent(); if (cuda_analysis_->thread_local_bufs().count(v->base_handle()) > 0) { // atomicAdd only works on global and shared memory @@ -471,7 +471,7 @@ void CudaPrinter::visit(AtomicAdd* v) { os() << std::endl; } -void CudaPrinter::visit(Max* v) { +void CudaPrinter::visit(MaxPtr v) { if (v->dtype().is_integral()) { os() << "max("; } else { @@ -483,7 +483,7 @@ void CudaPrinter::visit(Max* v) { os() << ")"; } -void CudaPrinter::visit(Min* v) { +void CudaPrinter::visit(MinPtr v) { if (v->dtype().is_integral()) { os() << "min("; } else { @@ -495,7 +495,7 @@ void CudaPrinter::visit(Min* v) { os() << ")"; } -void CudaPrinter::visit(IfThenElse* v) { +void CudaPrinter::visit(IfThenElsePtr v) { os() << "(("; v->condition()->accept(this); os() << ") ? "; @@ -505,11 +505,11 @@ void CudaPrinter::visit(IfThenElse* v) { os() << ")"; } -void CudaPrinter::visit(Block* v) { +void CudaPrinter::visit(BlockPtr v) { os() << "{" << std::endl; indent_++; - for (Stmt* s : v->stmts()) { + for (StmtPtr s : v->stmts()) { s->accept(this); } @@ -518,7 +518,7 @@ void CudaPrinter::visit(Block* v) { os() << "}"; } -void CudaPrinter::visit(Let* v) { +void CudaPrinter::visit(LetPtr v) { emitIndent(); os() << dtypeToCppString(v->dtype()); os() << " " << *v->var() << " = "; @@ -529,7 +529,7 @@ void CudaPrinter::visit(Let* v) { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) class PrioritizeLoad : public IRMutator { public: - Expr* mutate(Load* v) override { + ExprPtr mutate(LoadPtr v) override { // Look at the declaration of this variable for more details. 
if (nested_if_then_else_ > 0) { return IRMutator::mutate(v); @@ -564,19 +564,19 @@ class PrioritizeLoad : public IRMutator { } MemLoadList& load_list = load_stack_.back(); - Var* load_new_var = new Var("v", v->dtype()); - Expr* new_value = IRMutator::mutate(v); + VarPtr load_new_var = alloc("v", v->dtype()); + ExprPtr new_value = IRMutator::mutate(v); load_list.push_back(std::make_pair(load_new_var, new_value)); return load_new_var; } - Expr* mutate(Cast* v) override { - Load* src_load = dynamic_cast(v->src_value()); - Expr* new_src = v->src_value()->accept_mutator(this); - Var* new_var = dynamic_cast(new_src); + ExprPtr mutate(CastPtr v) override { + LoadPtr src_load = to(v->src_value()); + ExprPtr new_src = v->src_value()->accept_mutator(this); + VarPtr new_var = to(new_src); if (!src_load || !new_var) { - return new Cast(v->dtype(), new_src); + return alloc(v->dtype(), new_src); } // We just did the prioritize load, let's fold in the Cast. @@ -586,36 +586,36 @@ class PrioritizeLoad : public IRMutator { assert(pair.first == new_var); load_list.pop_back(); - new_var = new Var("v", v->dtype()); + new_var = alloc("v", v->dtype()); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - Expr* new_value = new Cast(v->dtype(), pair.second); + ExprPtr new_value = alloc(v->dtype(), pair.second); load_list.push_back(std::make_pair(new_var, new_value)); return new_var; } - Stmt* mutate(Store* v) override { - Store* last = nested_store_; + StmtPtr mutate(StorePtr v) override { + StorePtr last = nested_store_; nested_store_ = v; - Stmt* s = IRMutator::mutate(v); + StmtPtr s = IRMutator::mutate(v); nested_store_ = last; return s; } - Stmt* mutate(Let* v) override { + StmtPtr mutate(LetPtr v) override { nested_let_ = true; - Stmt* s = IRMutator::mutate(v); + StmtPtr s = IRMutator::mutate(v); nested_let_ = false; return s; } - Stmt* mutate(Block* v) override { - Block* v1 = const_cast(v); // NOLINT + StmtPtr mutate(BlockPtr v) override { + BlockPtr v1 = const_cast(v); // NOLINT assert(v1); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - std::list stmts = v1->stmts(); - for (Stmt* stmt : stmts) { + std::list stmts = v1->stmts(); + for (StmtPtr stmt : stmts) { PushList(); - Stmt* stmt_new = stmt->accept_mutator(this); + StmtPtr stmt_new = stmt->accept_mutator(this); AddMemLoadsFromList(v1, stmt); PopList(); @@ -628,15 +628,15 @@ class PrioritizeLoad : public IRMutator { return v1; } - Expr* mutate(IfThenElse* v) override { + ExprPtr mutate(IfThenElsePtr v) override { nested_if_then_else_++; - Expr* new_v = IRMutator::mutate(v); + ExprPtr new_v = IRMutator::mutate(v); nested_if_then_else_--; return new_v; } private: - using MemLoadEntry = std::pair; + using MemLoadEntry = std::pair; using MemLoadList = std::vector; using MemoryLoadStack = std::vector; @@ -648,14 +648,14 @@ class PrioritizeLoad : public IRMutator { load_stack_.pop_back(); } - void AddMemLoadsFromList(Block* block, Stmt* last) { + void AddMemLoadsFromList(BlockPtr block, StmtPtr last) { MemLoadList& load_list = load_stack_.back(); if (load_list.empty()) { return; } for (auto& pair : load_list) { - Stmt* news = new Let(pair.first, pair.second); + StmtPtr news = alloc(pair.first, pair.second); block->insert_stmt_before(news, last); } } @@ -673,9 +673,9 @@ class PrioritizeLoad : public IRMutator { // } // int v2 = v + 2; int nested_if_then_else_{0}; - Store* nested_store_{nullptr}; + StorePtr nested_store_{nullptr}; bool nested_let_{false}; - std::unordered_set thread_local_bufs_; + std::unordered_set thread_local_bufs_; }; std::string 
CudaCodeGen::GetUniqueFuncName(const std::string& func_prefix) { @@ -711,9 +711,9 @@ bool GPUMetaVarRewriter::isFullExtent() { return true; } -Stmt* GPUMetaVarRewriter::mutate(For* v) { - Stmt* body = v->body(); - Expr* old_reach = nullptr; +StmtPtr GPUMetaVarRewriter::mutate(ForPtr v) { + StmtPtr body = v->body(); + ExprPtr old_reach = nullptr; const LoopOptions& loop_options = v->loop_options(); if (loop_options.is_gpu_block_index()) { int gpu_block_index = loop_options.gpu_block_index(); @@ -728,11 +728,11 @@ Stmt* GPUMetaVarRewriter::mutate(For* v) { current_block_reach_[gpu_block_index] = v->stop(); } else { current_block_reach_[gpu_block_index] = - IRSimplifier::simplify(new Max(old_reach, v->stop(), true)); + IRSimplifier::simplify(alloc(old_reach, v->stop(), true)); } // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - Var* metaVar = gpu_block_vars_[gpu_block_index]; + VarPtr metaVar = gpu_block_vars_[gpu_block_index]; body = Substitute(Stmt::clone(body), {{v->var(), metaVar}}); } else if (loop_options.is_gpu_thread_index()) { int gpu_thread_index = loop_options.gpu_thread_index(); @@ -747,11 +747,11 @@ Stmt* GPUMetaVarRewriter::mutate(For* v) { current_thread_reach_[gpu_thread_index] = v->stop(); } else { current_thread_reach_[gpu_thread_index] = - IRSimplifier::simplify(new Max(old_reach, v->stop(), true)); + IRSimplifier::simplify(alloc(old_reach, v->stop(), true)); } // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - Var* metaVar = gpu_thread_vars_[gpu_thread_index]; + VarPtr metaVar = gpu_thread_vars_[gpu_thread_index]; body = Substitute(Stmt::clone(body), {{v->var(), metaVar}}); } @@ -771,7 +771,7 @@ Stmt* GPUMetaVarRewriter::mutate(For* v) { return v->cloneWithNewBody(body); } -Stmt* GPUMetaVarRewriter::mutate(Block* v) { +StmtPtr GPUMetaVarRewriter::mutate(BlockPtr v) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) std::vector innerSegments; Segment current; @@ -787,19 +787,19 @@ Stmt* GPUMetaVarRewriter::mutate(Block* v) { // the same launch reach. Segments are comprised of all statements that aren't // loops - which are their own segments. Some operations, such as threading // and memory ops should never be masked and so also get their own segment. - for (Stmt* stmt : *v) { - Stmt* stmt_new = stmt->accept_mutator(this); + for (StmtPtr stmt : *v) { + StmtPtr stmt_new = stmt->accept_mutator(this); if (stmt == stmt_new) { stmt_new = Stmt::clone(stmt_new); } // Likewise, Allocate and Free should never be masked. - if (dynamic_cast(stmt) || dynamic_cast(stmt)) { + if (to(stmt) || to(stmt)) { pushAndReset(false); } // If the current stmt *was* a loop, it's a segment boundary. - if (For* f = dynamic_cast(stmt)) { + if (ForPtr f = to(stmt)) { pushAndReset(false); } @@ -819,18 +819,18 @@ Stmt* GPUMetaVarRewriter::mutate(Block* v) { if (isFullExtent()) { // flatten inner segments. // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - std::vector stmts; + std::vector stmts; for (auto& v : innerSegments) { for (auto* s : v.stmts()) { stmts.push_back(s); } } - return new Block(stmts); + return alloc(stmts); } // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - std::vector stmts; + std::vector stmts; for (auto& segment : innerSegments) { bool need_sync = false; // We never mask loops, they'll mask their contents. @@ -842,15 +842,15 @@ Stmt* GPUMetaVarRewriter::mutate(Block* v) { // If we get here, we must mask since we're not full reach and our direct // child isn't a For. 
- Stmt* inner = new Block(segment.stmts()); + StmtPtr inner = alloc(segment.stmts()); // threads inside blocks. auto& thread_extents = cuda_analysis_->gpu_thread_extents(); for (size_t i = 0; i < gpu_thread_vars_.size(); ++i) { if (!exprEquals(current_thread_reach_[i], thread_extents[i])) { need_sync = true; // Mask it against the current dimensions. - inner = new Cond( - new CompareSelect( + inner = alloc( + alloc( gpu_thread_vars_[i], current_thread_reach_[i], CompareSelectOperation::kLT), @@ -862,8 +862,8 @@ Stmt* GPUMetaVarRewriter::mutate(Block* v) { for (size_t i = 0; i < gpu_block_vars_.size(); ++i) { if (!exprEquals(current_block_reach_[i], block_extents[i])) { // Mask it against the current dimensions. - inner = new Cond( - new CompareSelect( + inner = alloc( + alloc( gpu_block_vars_[i], current_block_reach_[i], CompareSelectOperation::kLT), @@ -873,20 +873,20 @@ Stmt* GPUMetaVarRewriter::mutate(Block* v) { } if (need_sync) { - stmts.push_back(new SyncThreads()); + stmts.push_back(alloc()); } stmts.push_back(inner); if (need_sync) { - stmts.push_back(new SyncThreads()); + stmts.push_back(alloc()); } } - return new Block(stmts); + return alloc(stmts); } static std::ostream& operator<<( std::ostream& out, - const std::vector& exprs) { + const std::vector& exprs) { size_t i = 0; for (auto expr : exprs) { if (i++ > 0) { @@ -935,7 +935,7 @@ void CudaCodeGen::Initialize() { // Check whether the statement uses the Half type, if so add the // half_support_literal. - Stmt* stmt_v = stmt(); + StmtPtr stmt_v = stmt(); HalfChecker halfChecker(buffer_args()); stmt_v->accept(&halfChecker); @@ -978,7 +978,7 @@ void CudaCodeGen::Initialize() { os() << ", "; } const BufferArg& buffer_arg = buffer_args[i]; - Var* var = buffer_arg.var(); + VarPtr var = buffer_arg.var(); Dtype dtype = buffer_arg.dtype(); os() << printer_->dtypeToCppString(dtype) @@ -986,13 +986,13 @@ void CudaCodeGen::Initialize() { << name_manager()->get_unique_name(var); } // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - Var* rand_seed; + VarPtr rand_seed; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - Var* rand_offset; + VarPtr rand_offset; if (has_random_) { // TODO: switch to kUint64 when it is available. - rand_seed = new Var("rand_seed", kInt); - rand_offset = new Var("rand_offset", kInt); + rand_seed = alloc("rand_seed", kInt); + rand_offset = alloc("rand_offset", kInt); std::string uint64_str = "unsigned long long"; os() << ", " << uint64_str << " " << *rand_seed << ", " << uint64_str << " " << *rand_offset; @@ -1001,11 +1001,11 @@ void CudaCodeGen::Initialize() { os() << std::endl; if (has_random_) { - Var* idx = new Var("idx", kInt); + VarPtr idx = alloc("idx", kInt); os() << "int " << *idx << " = blockIdx.x*blockDim.x + threadIdx.x;" << std::endl; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - Var* rand_func = printer_->rand_func(); + VarPtr rand_func = printer_->rand_func(); os() << "Philox " << *rand_func << "(" << *rand_seed << ", " << *idx << ", " << *rand_offset << ");" << std::endl; os() << std::endl; @@ -1036,7 +1036,7 @@ void CudaCodeGen::Initialize() { os() << "}"; // Check that all block extents had been set. - const std::vector& gpu_block_extents = + const std::vector& gpu_block_extents = metavar_rewriter_->gpu_block_extents(); for (size_t i = 0; i < gpu_block_extents.size(); i++) { if (!gpu_block_extents[i]) { @@ -1062,9 +1062,9 @@ void CudaCodeGen::call_raw(const std::vector& raw_args) { auto const& buffer_args = this->buffer_args(); // TODO: move as much of this into the constructors. 
- const std::vector& gpu_block_extents = + const std::vector& gpu_block_extents = metavar_rewriter_->gpu_block_extents(); - const std::vector& gpu_thread_extents = + const std::vector& gpu_thread_extents = metavar_rewriter_->gpu_thread_extents(); if (gpu_block_extents.size() > 3 || gpu_thread_extents.size() > 3) { throw malformed_input( diff --git a/torch/csrc/jit/tensorexpr/cuda_codegen.h b/torch/csrc/jit/tensorexpr/cuda_codegen.h index 1ed8f23..6fdf35b 100644 --- a/torch/csrc/jit/tensorexpr/cuda_codegen.h +++ b/torch/csrc/jit/tensorexpr/cuda_codegen.h @@ -23,44 +23,45 @@ namespace tensorexpr { class CudaAnalysis : public IRVisitor { public: CudaAnalysis() { - gpu_block_extents_ = {new IntImm(1), new IntImm(1), new IntImm(1)}; - gpu_thread_extents_ = {new IntImm(1), new IntImm(1), new IntImm(1)}; + gpu_block_extents_ = {alloc(1), alloc(1), alloc(1)}; + gpu_thread_extents_ = { + alloc(1), alloc(1), alloc(1)}; } - bool is_buf_store_target(Buf* buf) const { + bool is_buf_store_target(BufPtr buf) const { return store_targets_.count(buf) > 0; } - const std::unordered_set& thread_local_bufs() const { + const std::unordered_set& thread_local_bufs() const { return thread_local_bufs_; } - const std::unordered_set& cross_block_bufs() const { + const std::unordered_set& cross_block_bufs() const { return cross_block_bufs_; } - const std::vector& gpu_block_extents() const { + const std::vector& gpu_block_extents() const { return gpu_block_extents_; } - const std::vector& gpu_thread_extents() const { + const std::vector& gpu_thread_extents() const { return gpu_thread_extents_; } private: - void visit(Store* v) override { + void visit(StorePtr v) override { store_targets_.insert(v->buf()); } - void visit(Allocate* v) override; - void visit(Free* v) override; - void visit(For* v) override; + void visit(AllocatePtr v) override; + void visit(FreePtr v) override; + void visit(ForPtr v) override; - std::unordered_set store_targets_; - std::unordered_set thread_local_bufs_; - std::unordered_set cross_block_bufs_; + std::unordered_set store_targets_; + std::unordered_set thread_local_bufs_; + std::unordered_set cross_block_bufs_; - std::vector gpu_block_extents_; - std::vector gpu_thread_extents_; + std::vector gpu_block_extents_; + std::vector gpu_thread_extents_; }; // An IRMutator that replaces binding loop options with Cuda metavars, and masks @@ -75,34 +76,36 @@ class GPUMetaVarRewriter : public IRMutator { explicit GPUMetaVarRewriter(const CudaAnalysis* cuda_analysis) : cuda_analysis_(cuda_analysis) { gpu_block_vars_ = { - new Var("blockIdx.x", kInt), - new Var("blockIdx.y", kInt), - new Var("blockIdx.z", kInt)}; + alloc("blockIdx.x", kInt), + alloc("blockIdx.y", kInt), + alloc("blockIdx.z", kInt)}; gpu_thread_vars_ = { - new Var("threadIdx.x", kInt), - new Var("threadIdx.y", kInt), - new Var("threadIdx.z", kInt)}; - - current_block_reach_ = {new IntImm(1), new IntImm(1), new IntImm(1)}; - current_thread_reach_ = {new IntImm(1), new IntImm(1), new IntImm(1)}; + alloc("threadIdx.x", kInt), + alloc("threadIdx.y", kInt), + alloc("threadIdx.z", kInt)}; + + current_block_reach_ = { + alloc(1), alloc(1), alloc(1)}; + current_thread_reach_ = { + alloc(1), alloc(1), alloc(1)}; } - Stmt* mutate(For* v) override; - Stmt* mutate(Block* v) override; + StmtPtr mutate(ForPtr v) override; + StmtPtr mutate(BlockPtr v) override; - const std::vector& gpu_block_vars() const { + const std::vector& gpu_block_vars() const { return gpu_block_vars_; } - const std::vector& gpu_thread_vars() const { + const std::vector& 
gpu_thread_vars() const { return gpu_thread_vars_; } - const std::vector& gpu_block_extents() const { + const std::vector& gpu_block_extents() const { return cuda_analysis_->gpu_block_extents(); } - const std::vector& gpu_thread_extents() const { + const std::vector& gpu_thread_extents() const { return cuda_analysis_->gpu_thread_extents(); } @@ -120,7 +123,7 @@ class GPUMetaVarRewriter : public IRMutator { return stmts_.empty(); } - std::vector& stmts() { + std::vector& stmts() { return stmts_; } bool mask() { @@ -128,7 +131,7 @@ class GPUMetaVarRewriter : public IRMutator { } private: - std::vector stmts_; + std::vector stmts_; bool mask_{true}; }; @@ -136,11 +139,11 @@ class GPUMetaVarRewriter : public IRMutator { // parameters. bool isFullExtent(); - std::vector gpu_block_vars_; - std::vector gpu_thread_vars_; + std::vector gpu_block_vars_; + std::vector gpu_thread_vars_; - std::vector current_block_reach_; - std::vector current_thread_reach_; + std::vector current_block_reach_; + std::vector current_thread_reach_; const CudaAnalysis* cuda_analysis_; }; @@ -154,28 +157,28 @@ class CudaPrinter : public IRPrinter { bool has_random) : IRPrinter(*os), cuda_analysis_(cuda_analysis) { if (has_random) { - rand_func_ = new Var("rand", kHandle); + rand_func_ = alloc("rand", kHandle); } } - void visit(Cast* v) override; - void visit(Intrinsics* v) override; - void visit(For* v) override; + void visit(CastPtr v) override; + void visit(IntrinsicsPtr v) override; + void visit(ForPtr v) override; - void visit(Load* v) override; - void visit(Store* v) override; - void visit(AtomicAdd* v) override; - void visit(Max* v) override; - void visit(Min* v) override; - void visit(IfThenElse* v) override; - void visit(Block* v) override; - void visit(Allocate* v) override; - void visit(Free* v) override; - void visit(Let* v) override; + void visit(LoadPtr v) override; + void visit(StorePtr v) override; + void visit(AtomicAddPtr v) override; + void visit(MaxPtr v) override; + void visit(MinPtr v) override; + void visit(IfThenElsePtr v) override; + void visit(BlockPtr v) override; + void visit(AllocatePtr v) override; + void visit(FreePtr v) override; + void visit(LetPtr v) override; - void visit(ExternalCall* v) override; + void visit(ExternalCallPtr v) override; - Var* rand_func() const { + VarPtr rand_func() const { return rand_func_; } @@ -185,10 +188,10 @@ class CudaPrinter : public IRPrinter { using IRPrinter::visit; private: - Var* rand_func_; + VarPtr rand_func_; const CudaAnalysis* cuda_analysis_; - void print_flat_alloc(Allocate* alloc); + void print_flat_alloc(AllocatePtr alloc); }; // Construct Cuda C from the buffer and tensor input, and invoke the kernel @@ -197,7 +200,7 @@ class TORCH_CUDA_CU_API CudaCodeGen : public CodeGen { public: template // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) - CudaCodeGen(Stmt* stmt, Ts... ts) + CudaCodeGen(StmtPtr stmt, Ts... 
ts) : CodeGen( stmt, std::vector({BufferArg(ts)...}), @@ -207,7 +210,7 @@ class TORCH_CUDA_CU_API CudaCodeGen : public CodeGen { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) CudaCodeGen( - Stmt* stmt, + StmtPtr stmt, const std::vector& buffer_args, at::Device device = at::Device(at::kCUDA, at::cuda::current_device()), const std::string& kernel_func_name = "func") @@ -233,11 +236,11 @@ class TORCH_CUDA_CU_API CudaCodeGen : public CodeGen { c10::optional device_opt, c10::optional pin_memory_opt) override; - const std::vector& gpu_block_extents() const { + const std::vector& gpu_block_extents() const { return cuda_analysis_->gpu_block_extents(); } - const std::vector& gpu_thread_extents() const { + const std::vector& gpu_thread_extents() const { return cuda_analysis_->gpu_thread_extents(); } diff --git a/torch/csrc/jit/tensorexpr/eval.cpp b/torch/csrc/jit/tensorexpr/eval.cpp index 6dbf10c..c7a28bd 100644 --- a/torch/csrc/jit/tensorexpr/eval.cpp +++ b/torch/csrc/jit/tensorexpr/eval.cpp @@ -59,14 +59,14 @@ class SimpleIREvaluatorImpl : public IRVisitor { ~SimpleIREvaluatorImpl() override = default; - void bindBuf(Buf* buf, void* ptr) { + void bindBuf(BufPtr buf, void* ptr) { buffer_mapping_[buf] = ptr; } - void bindVar(Var* var, const Value& val) { + void bindVar(VarPtr var, const Value& val) { eval_context_[var] = val; } - Value evaluateExpr(Expr* e) { + Value evaluateExpr(ExprPtr e) { e->accept(this); return value_; } @@ -81,45 +81,45 @@ class SimpleIREvaluatorImpl : public IRVisitor { internal_buffers_.clear(); } - TORCH_API void visit(Add* v) override { + TORCH_API void visit(AddPtr v) override { visit_binary_op(v); } - TORCH_API void visit(Sub* v) override { + TORCH_API void visit(SubPtr v) override { visit_binary_op(v); } - TORCH_API void visit(Mul* v) override { + TORCH_API void visit(MulPtr v) override { visit_binary_op(v); } - TORCH_API void visit(Div* v) override { + TORCH_API void visit(DivPtr v) override { visit_binary_op(v); } - TORCH_API void visit(Mod* v) override { + TORCH_API void visit(ModPtr v) override { visit_binary_op(v); } - TORCH_API void visit(Max* v) override { + TORCH_API void visit(MaxPtr v) override { visit_binary_op(v, v->propagate_nans()); } - TORCH_API void visit(Min* v) override { + TORCH_API void visit(MinPtr v) override { visit_binary_op(v, v->propagate_nans()); } - TORCH_API void visit(And* v) override { + TORCH_API void visit(AndPtr v) override { visit_binary_op(v); } - TORCH_API void visit(Or* v) override { + TORCH_API void visit(OrPtr v) override { visit_binary_op(v); } - TORCH_API void visit(Xor* v) override { + TORCH_API void visit(XorPtr v) override { visit_binary_op(v); } - TORCH_API void visit(Lshift* v) override { + TORCH_API void visit(LshiftPtr v) override { visit_binary_op(v); } - TORCH_API void visit(Rshift* v) override { + TORCH_API void visit(RshiftPtr v) override { visit_binary_op(v); } - void visit(CompareSelect* v) override { + void visit(CompareSelectPtr v) override { visit_compare_select_op(v, v->compare_select_op()); } @@ -365,7 +365,7 @@ class SimpleIREvaluatorImpl : public IRVisitor { } void visit_compare_select_op( - CompareSelect* v, + CompareSelectPtr v, CompareSelectOperation cmp_op) { v->lhs()->accept(this); Value lhs_v = value_; @@ -394,23 +394,23 @@ class SimpleIREvaluatorImpl : public IRVisitor { } } -#define IMM_VISIT(Type, Name) \ - TORCH_API void visit(Name##Imm* v) override { \ - value_ = Value(v->value()); \ +#define IMM_VISIT(Type, Name) \ + TORCH_API void visit(Name##ImmPtr v) override { \ + value_ = 
Value(v->value()); \ } AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_VISIT); #undef IMM_VISIT - TORCH_API void visit(Block* v) override { - Block* last = scope_; + TORCH_API void visit(BlockPtr v) override { + BlockPtr last = scope_; scope_ = v; - for (Stmt* s : v->stmts()) { + for (StmtPtr s : v->stmts()) { s->accept(this); } auto it = var_by_scope_.find(v); if (it != var_by_scope_.end()) { - for (Expr* v : it->second) { + for (ExprPtr v : it->second) { eval_context_.erase(v); } var_by_scope_.erase(it); @@ -419,7 +419,7 @@ class SimpleIREvaluatorImpl : public IRVisitor { scope_ = last; } - TORCH_API void visit(Var* v) override { + TORCH_API void visit(VarPtr v) override { auto iter = eval_context_.find(v); if (iter == eval_context_.end()) { throw malformed_input("could not find Var in context", v); @@ -456,8 +456,8 @@ class SimpleIREvaluatorImpl : public IRVisitor { } } - TORCH_API void visit(Cast* v) override { - Expr* src_value = v->src_value(); + TORCH_API void visit(CastPtr v) override { + ExprPtr src_value = v->src_value(); src_value->accept(this); Dtype dst_dtype = v->dtype(); Dtype src_dtype = src_value->dtype(); @@ -507,8 +507,8 @@ class SimpleIREvaluatorImpl : public IRVisitor { } } - TORCH_API void visit(BitCast* v) override { - Expr* src_value = v->src_value(); + TORCH_API void visit(BitCastPtr v) override { + ExprPtr src_value = v->src_value(); src_value->accept(this); Dtype dst_dtype = v->dtype(); Dtype src_dtype = src_value->dtype(); @@ -530,8 +530,8 @@ class SimpleIREvaluatorImpl : public IRVisitor { } } - TORCH_API void visit(For* v) override { - Expr* var_node = v->var(); + TORCH_API void visit(ForPtr v) override { + ExprPtr var_node = v->var(); v->start()->accept(this); int start = value_.as(); v->stop()->accept(this); @@ -549,7 +549,7 @@ class SimpleIREvaluatorImpl : public IRVisitor { eval_context_.erase(var_node); } - TORCH_API void visit(Ramp* v) override { + TORCH_API void visit(RampPtr v) override { v->base()->accept(this); int base = value().as(); v->stride()->accept(this); @@ -564,7 +564,7 @@ class SimpleIREvaluatorImpl : public IRVisitor { value_ = Value(values); } - TORCH_API void visit(Broadcast* v) override { + TORCH_API void visit(BroadcastPtr v) override { v->value()->accept(this); Value value = this->value(); int lanes = v->lanes(); @@ -581,7 +581,7 @@ class SimpleIREvaluatorImpl : public IRVisitor { } } - TORCH_API void visit(IfThenElse* v) override { + TORCH_API void visit(IfThenElsePtr v) override { v->condition()->accept(this); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) bool cond_v; @@ -605,14 +605,14 @@ class SimpleIREvaluatorImpl : public IRVisitor { } } - TORCH_API void visit(Load* v) override { + TORCH_API void visit(LoadPtr v) override { auto iter = buffer_mapping_.find(v->buf()); if (iter == buffer_mapping_.end()) { throw malformed_input("could not find base node in Load", v); } void* ptr = iter->second; - Expr* flat_idx = flatten_index(v->buf()->dims(), v->indices()); + ExprPtr flat_idx = flatten_index(v->buf()->dims(), v->indices()); flat_idx->accept(this); std::vector index = value().as_vec(); ScalarType v_sdtype = v->dtype().scalar_type(); @@ -633,7 +633,7 @@ class SimpleIREvaluatorImpl : public IRVisitor { } } - TORCH_API void visit(Store* v) override { + TORCH_API void visit(StorePtr v) override { auto iter = buffer_mapping_.find(v->buf()); if (iter == buffer_mapping_.end()) { throw malformed_input("could not find base node in Store", v); @@ -641,7 +641,7 @@ class SimpleIREvaluatorImpl : public IRVisitor { void* ptr = iter->second; - 
Expr* flat_idx = flatten_index(v->buf()->dims(), v->indices()); + ExprPtr flat_idx = flatten_index(v->buf()->dims(), v->indices()); flat_idx->accept(this); std::vector index = value().as_vec(); ScalarType v_sdtype = v->value()->dtype().scalar_type(); @@ -666,13 +666,13 @@ class SimpleIREvaluatorImpl : public IRVisitor { } } - void visit(ExternalCall* v) override { + void visit(ExternalCallPtr v) override { auto& func_registry = getNNCFunctionRegistry(); if (!func_registry.count(v->func_name())) { throw unimplemented_lowering(v); } - std::vector bufs(v->buf_args()); + std::vector bufs(v->buf_args()); bufs.insert(bufs.begin(), v->buf()); std::vector buf_ptrs; @@ -681,7 +681,7 @@ class SimpleIREvaluatorImpl : public IRVisitor { std::vector buf_dtypes; std::vector extra_args; - for (Buf* b : bufs) { + for (BufPtr b : bufs) { auto iter = buffer_mapping_.find(b); if (iter == buffer_mapping_.end()) { throw malformed_input("could not find buf", v); @@ -690,12 +690,12 @@ class SimpleIREvaluatorImpl : public IRVisitor { buf_ptrs.push_back(iter->second); buf_ranks.push_back(b->dims().size()); buf_dtypes.push_back((int8_t)b->dtype().scalar_type()); - for (Expr* dim_expr : b->dims()) { + for (ExprPtr dim_expr : b->dims()) { dim_expr->accept(this); buf_dims.push_back(value().as()); } } - for (Expr* a : v->args()) { + for (ExprPtr a : v->args()) { a->accept(this); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) int64_t val; @@ -722,7 +722,7 @@ class SimpleIREvaluatorImpl : public IRVisitor { } template - void visit_intrinsics_helper(Intrinsics* v) { + void visit_intrinsics_helper(IntrinsicsPtr v) { std::vector values(v->nparams()); for (const auto i : c10::irange(v->nparams())) { v->param(i)->accept(this); @@ -757,7 +757,7 @@ class SimpleIREvaluatorImpl : public IRVisitor { value_ = Value(result); } - TORCH_API void visit(Intrinsics* v) override { + TORCH_API void visit(IntrinsicsPtr v) override { auto ty = v->dtype().scalar_type(); if (v->op_type() == kIsNan) { auto inp_dtype = v->params().at(0)->dtype().scalar_type(); @@ -782,9 +782,9 @@ class SimpleIREvaluatorImpl : public IRVisitor { } } - void visit(Allocate* v) override { - Buf* b = v->buf(); - std::vector dims = b->dims(); + void visit(AllocatePtr v) override { + BufPtr b = v->buf(); + std::vector dims = b->dims(); int total_byte_size = b->dtype().byte_size(); for (auto& dim : dims) { dim->accept(this); @@ -802,8 +802,8 @@ class SimpleIREvaluatorImpl : public IRVisitor { internal_buffers_.insert(std::make_pair(b, std::move(buffer))); } - void visit(Free* v) override { - Buf* b = v->buf(); + void visit(FreePtr v) override { + BufPtr b = v->buf(); int count = internal_buffers_.erase(b); if (count == 0) { throw std::runtime_error( @@ -813,12 +813,12 @@ class SimpleIREvaluatorImpl : public IRVisitor { buffer_mapping_.erase(b); } - void visit(Let* v) override { + void visit(LetPtr v) override { var_by_scope_[scope_].push_back(v->var()); bindVar(v->var(), evaluateExpr(v->value())); } - void visit(Cond* v) override { + void visit(CondPtr v) override { v->condition()->accept(this); if (value().as()) { if (v->true_stmt()) { @@ -950,15 +950,16 @@ class SimpleIREvaluatorImpl : public IRVisitor { } Value value_; - Block* scope_; - std::unordered_map eval_context_; - std::unordered_map> var_by_scope_; - std::unordered_map buffer_mapping_; - std::unordered_map>> internal_buffers_; + BlockPtr scope_; + std::unordered_map eval_context_; + std::unordered_map> var_by_scope_; + std::unordered_map buffer_mapping_; + std::unordered_map>> + internal_buffers_; }; 
SimpleIREvaluator::SimpleIREvaluator( - Stmt* stmt, + StmtPtr stmt, const std::vector& buffer_args, at::Device device, const std::string& kernel_func_name) @@ -1011,7 +1012,7 @@ void SimpleIREvaluator::bindArg(const BufferArg& bufArg, void* data) { } } -void SimpleIREvaluator::bindVar(Var* v, Expr* e) { +void SimpleIREvaluator::bindVar(VarPtr v, ExprPtr e) { impl_->bindVar(v, impl_->evaluateExpr(e)); } diff --git a/torch/csrc/jit/tensorexpr/eval.h b/torch/csrc/jit/tensorexpr/eval.h index 0d6e522..38ec99b 100644 --- a/torch/csrc/jit/tensorexpr/eval.h +++ b/torch/csrc/jit/tensorexpr/eval.h @@ -97,7 +97,7 @@ class SimpleIREvaluatorImpl; class TORCH_API SimpleIREvaluator : public CodeGen { public: SimpleIREvaluator( - Stmt* stmt, + StmtPtr stmt, const std::vector& buffer_args, at::Device device = at::kCPU, const std::string& kernel_func_name = "func"); @@ -114,7 +114,7 @@ class TORCH_API SimpleIREvaluator : public CodeGen { call(args); } - void bindVar(Var* v, Expr* e); + void bindVar(VarPtr v, ExprPtr e); Value value() const; private: @@ -145,15 +145,15 @@ class ExprEval { std::vector buffer_args_extended = buffer_args; Placeholder ret_buf("ret_val", dtype_, {1}); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - std::vector indices; - Expr* zero = new IntImm(0); + std::vector indices; + ExprPtr zero = alloc(0); for (size_t i = 0; i < ret_buf.data()->ndim(); i++) { indices.push_back(zero); } // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - Stmt* store_stmt = + StmtPtr store_stmt = // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - new Store(ret_buf.data(), indices, expr.node()); + alloc(ret_buf.data(), indices, expr.node()); buffer_args_extended.emplace_back(ret_buf); codegen_.reset(new CodeGenType(store_stmt, buffer_args_extended)); } @@ -167,7 +167,7 @@ class ExprEval { call(call_args); } - void bindVar(Var* v, Expr* e) { + void bindVar(VarPtr v, ExprPtr e) { codegen_->bindVar(v, e); } @@ -255,14 +255,14 @@ class ExprEval { // Substitutes the given vars with their corresponding expressions in the input // expression. -inline Expr* Substitute(Expr* expr, const VarMapping& var_mapping) { +inline ExprPtr Substitute(ExprPtr expr, const VarMapping& var_mapping) { VarSubMutator var_sub(var_mapping); return expr->accept_mutator(&var_sub); } // Substitutes the given vars with their corresponding expressions in the input // statement. -inline Stmt* Substitute(Stmt* stmt, const VarMapping& var_mapping) { +inline StmtPtr Substitute(StmtPtr stmt, const VarMapping& var_mapping) { VarSubMutator var_sub(var_mapping); return stmt->accept_mutator(&var_sub); } @@ -271,7 +271,7 @@ inline Stmt* Substitute(Stmt* stmt, const VarMapping& var_mapping) { // their corresponding expressions in the clone. // NOTE: This works because cloning reuses variables and does not create new // ones, and `VarMapping` input has variables as the key. -inline Expr* SubstituteInClone(Expr* expr, const VarMapping& var_mapping) { +inline ExprPtr SubstituteInClone(ExprPtr expr, const VarMapping& var_mapping) { VarSubMutator var_sub(var_mapping); return Expr::clone(expr)->accept_mutator(&var_sub); } @@ -280,7 +280,7 @@ inline Expr* SubstituteInClone(Expr* expr, const VarMapping& var_mapping) { // their corresponding expressions in the clone. // NOTE: This works because cloning reuses variables and does not create new // ones, and `VarMapping` input has variables as the key. 
-inline Stmt* SubstituteInClone(Stmt* stmt, const VarMapping& var_mapping) { +inline StmtPtr SubstituteInClone(StmtPtr stmt, const VarMapping& var_mapping) { VarSubMutator var_sub(var_mapping); return Stmt::clone(stmt)->accept_mutator(&var_sub); } diff --git a/torch/csrc/jit/tensorexpr/exceptions.h b/torch/csrc/jit/tensorexpr/exceptions.h index 3079748..cf23bbc 100644 --- a/torch/csrc/jit/tensorexpr/exceptions.h +++ b/torch/csrc/jit/tensorexpr/exceptions.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include @@ -17,8 +18,8 @@ class Stmt; // Forward declarations of functions namespace std { -TORCH_API std::string to_string(const torch::jit::tensorexpr::Expr*); -TORCH_API std::string to_string(const torch::jit::tensorexpr::Stmt*); +TORCH_API std::string to_string(const torch::jit::tensorexpr::ExprPtr); +TORCH_API std::string to_string(const torch::jit::tensorexpr::StmtPtr); } // namespace std namespace torch { @@ -43,9 +44,9 @@ class unimplemented_lowering : public std::runtime_error { public: explicit unimplemented_lowering() : std::runtime_error("UNIMPLEMENTED LOWERING") {} - explicit unimplemented_lowering(Expr* expr) + explicit unimplemented_lowering(ExprPtr expr) : std::runtime_error("UNIMPLEMENTED LOWERING: " + std::to_string(expr)) {} - explicit unimplemented_lowering(Stmt* stmt) + explicit unimplemented_lowering(StmtPtr stmt) : std::runtime_error("UNIMPLEMENTED LOWERING: " + std::to_string(stmt)) {} }; @@ -54,14 +55,14 @@ class malformed_input : public std::runtime_error { explicit malformed_input() : std::runtime_error("MALFORMED INPUT") {} explicit malformed_input(const std::string& err) : std::runtime_error("MALFORMED INPUT: " + err) {} - explicit malformed_input(Expr* expr) + explicit malformed_input(ExprPtr expr) : std::runtime_error("MALFORMED INPUT: " + std::to_string(expr)) {} - explicit malformed_input(const std::string& err, Expr* expr) + explicit malformed_input(const std::string& err, ExprPtr expr) : std::runtime_error( "MALFORMED INPUT: " + err + " - " + std::to_string(expr)) {} - explicit malformed_input(Stmt* stmt) + explicit malformed_input(StmtPtr stmt) : std::runtime_error("MALFORMED INPUT: " + std::to_string(stmt)) {} - explicit malformed_input(const std::string& err, Stmt* stmt) + explicit malformed_input(const std::string& err, StmtPtr stmt) : std::runtime_error( "MALFORMED INPUT: " + err + " - " + std::to_string(stmt)) {} }; @@ -71,14 +72,14 @@ class malformed_ir : public std::runtime_error { explicit malformed_ir() : std::runtime_error("MALFORMED IR") {} explicit malformed_ir(const std::string& err) : std::runtime_error("MALFORMED IR: " + err) {} - explicit malformed_ir(Expr* expr) + explicit malformed_ir(ExprPtr expr) : std::runtime_error("MALFORMED IR: " + std::to_string(expr)) {} - explicit malformed_ir(const std::string& err, Expr* expr) + explicit malformed_ir(const std::string& err, ExprPtr expr) : std::runtime_error( "MALFORMED IR: " + err + " - " + std::to_string(expr)) {} - explicit malformed_ir(Stmt* stmt) + explicit malformed_ir(StmtPtr stmt) : std::runtime_error("MALFORMED IR: " + std::to_string(stmt)) {} - explicit malformed_ir(const std::string& err, Stmt* stmt) + explicit malformed_ir(const std::string& err, StmtPtr stmt) : std::runtime_error( "MALFORMED IR: " + err + " - " + std::to_string(stmt)) {} }; diff --git a/torch/csrc/jit/tensorexpr/expr.cpp b/torch/csrc/jit/tensorexpr/expr.cpp index a812b49..cbf5ddd 100644 --- a/torch/csrc/jit/tensorexpr/expr.cpp +++ b/torch/csrc/jit/tensorexpr/expr.cpp @@ -360,7 +360,7 @@ ExprHandle 
Buf::make( const std::vector& dims, Dtype dtype) { return ExprHandle( - new Buf(name_hint, ExprHandleVectorToExprVector(dims), dtype)); + alloc(name_hint, ExprHandleVectorToExprVector(dims), dtype)); } ExprHandle Buf::make(const std::vector& dims, Dtype dtype) { diff --git a/torch/csrc/jit/tensorexpr/expr.h b/torch/csrc/jit/tensorexpr/expr.h index 99474c6..fae24ec 100644 --- a/torch/csrc/jit/tensorexpr/expr.h +++ b/torch/csrc/jit/tensorexpr/expr.h @@ -5,6 +5,7 @@ */ #pragma once +#include #include #include #include @@ -43,7 +44,7 @@ class TORCH_API Expr : public KernelScopedObject { return dtype_; } virtual void accept(IRVisitor* visitor) = 0; - virtual Expr* accept_mutator(IRMutator* mutator) = 0; + virtual ExprPtr accept_mutator(IRMutator* mutator) = 0; IRNodeType expr_type() const { return expr_type_; @@ -63,7 +64,7 @@ class TORCH_API Expr : public KernelScopedObject { * All sub-expressions inside the given expressions are also cloned. Note * that the variables are not deep-copied since they are immutable. */ - static Expr* clone(Expr* s); + static ExprPtr clone(ExprPtr s); private: Dtype dtype_; @@ -77,9 +78,9 @@ class ExprNode : public Base { public: using ExprNodeBase = ExprNode; void accept(IRVisitor* visitor) override { - visitor->visit(static_cast(this)); + visitor->visit(static_to(this)); } - Expr* accept_mutator(IRMutator* mutator) override; + ExprPtr accept_mutator(IRMutator* mutator) override; // pass the constructor to the base class using Base::Base; }; @@ -89,15 +90,13 @@ class ExprNode : public Base { class TORCH_API ExprHandle { public: ExprHandle() = default; - explicit ExprHandle(Expr* node) - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) - : base_expr_node_(const_cast(node)) {} + explicit ExprHandle(ExprPtr node) : base_expr_node_(node) {} - Expr* node() { + ExprPtr node() { return base_expr_node_; } - Expr* node() const { + ExprPtr node() const { return base_expr_node_; } @@ -110,12 +109,12 @@ class TORCH_API ExprHandle { #undef IMM_EXPR_DECLARE template - Op* AsNode() { - return dynamic_cast(this->node()); + NodePtr AsNode() { + return to(this->node()); } template - Op* AsNode() const { + NodePtr AsNode() const { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) return const_cast(this)->AsNode(); } @@ -145,7 +144,7 @@ class TORCH_API ExprHandle { ExprHandle operator>>(const ExprHandle& other) const; private: - Expr* base_expr_node_ = nullptr; + ExprPtr base_expr_node_ = nullptr; }; // The underlying representation node to a Var. 
@@ -154,10 +153,10 @@ class TORCH_API ExprHandle { class TORCH_API Var : public ExprNode { public: static ExprHandle make(const std::string& name_hint, Dtype dtype) { - return ExprHandle(new Var(name_hint, dtype)); + return ExprHandle(alloc(name_hint, dtype)); } static ExprHandle make(Dtype dtype) { - return ExprHandle(new Var("", dtype)); + return ExprHandle(alloc("", dtype)); } // TODO: unique_name @@ -185,10 +184,10 @@ class TORCH_API Buf : public ExprNode { static ExprHandle make(const std::vector& dims, Dtype dtype); // TODO: unique_name - Var* base_handle() const { + VarPtr base_handle() const { return base_handle_; } - void set_base_handle(Var* base_handle) { + void set_base_handle(VarPtr base_handle) { base_handle_ = base_handle; } @@ -201,16 +200,16 @@ class TORCH_API Buf : public ExprNode { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) Buf(const std::string& name_hint, - const std::vector& dims, + const std::vector& dims, Dtype dtype, - Expr* initializer = nullptr) - : Buf(new Var(name_hint, kHandle), dims, dtype, initializer) {} + ExprPtr initializer = nullptr) + : Buf(alloc(name_hint, kHandle), dims, dtype, initializer) {} // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) - Buf(Var* var, - std::vector dims, + Buf(VarPtr var, + std::vector dims, Dtype dtype, - Expr* initializer = nullptr) + ExprPtr initializer = nullptr) : ExprNodeBase(dtype, kPrimitive), base_handle_(var), dims_(std::move(dims)), @@ -221,20 +220,20 @@ class TORCH_API Buf : public ExprNode { size_t ndim() const { return dims_.size(); } - Expr* dim(size_t index) const { + ExprPtr dim(size_t index) const { if (index >= ndim()) { throw out_of_range_index(); } return dims_[index]; } - std::vector dims() const { + std::vector dims() const { return dims_; } - void set_dims(std::vector dims) { + void set_dims(std::vector dims) { dims_ = dims; }; - Expr* initializer() const { + ExprPtr initializer() const { return initializer_; }; @@ -248,9 +247,9 @@ class TORCH_API Buf : public ExprNode { } private: - Var* base_handle_; - std::vector dims_; - Expr* initializer_; + VarPtr base_handle_; + std::vector dims_; + ExprPtr initializer_; }; class TORCH_API BufHandle : public ExprHandle { @@ -266,12 +265,12 @@ class TORCH_API BufHandle : public ExprHandle { explicit BufHandle(Dtype dtype) : ExprHandle(Buf::make("_", {}, dtype)) {} - explicit BufHandle(Buf* node) : ExprHandle(node) {} - Buf* node() const { - return static_cast(ExprHandle::node()); + explicit BufHandle(BufPtr node) : ExprHandle(node) {} + BufPtr node() const { + return static_to(ExprHandle::node()); } - Buf* node() { - return static_cast(ExprHandle::node()); + BufPtr node() { + return static_to(ExprHandle::node()); } template @@ -315,9 +314,9 @@ class TORCH_API VarHandle : public ExprHandle { explicit VarHandle(Dtype dtype) : ExprHandle(Var::make(dtype)) {} VarHandle(const std::string& name_hint, Dtype dtype) : ExprHandle(Var::make(name_hint, dtype)) {} - explicit VarHandle(Var* node) : ExprHandle(node) {} - Var* node() const { - return static_cast(ExprHandle::node()); + explicit VarHandle(VarPtr node) : ExprHandle(node) {} + VarPtr node() const { + return static_to(ExprHandle::node()); } bool operator==(const VarHandle& other) const { return this->node() == other.node(); @@ -335,10 +334,8 @@ class TORCH_API VarHandle : public ExprHandle { }; template -Expr* ExprNode::accept_mutator(IRMutator* mutator) { - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) - ExprNode* this_mutable = const_cast(this); - return 
mutator->mutate(static_cast<Op*>(this_mutable));
+ExprPtr ExprNode<Op, Base>::accept_mutator(IRMutator* mutator) {
+  return mutator->mutate(static_to<Op>(this));
 }
 
 inline bool same_node(const ExprHandle& expr1, const ExprHandle& expr2) {
diff --git a/torch/csrc/jit/tensorexpr/fwd_decls.h b/torch/csrc/jit/tensorexpr/fwd_decls.h
new file mode 100644
index 0000000..01a7670
--- /dev/null
+++ b/torch/csrc/jit/tensorexpr/fwd_decls.h
@@ -0,0 +1,120 @@
+#pragma once
+#include <c10/core/ScalarType.h>
+
+namespace torch {
+namespace jit {
+namespace tensorexpr {
+
+template <typename Node>
+using NodePtr = Node*;
+
+template <typename To, typename From>
+NodePtr<To> to(NodePtr<From> x) {
+  return dynamic_cast<NodePtr<To>>(x);
+}
+
+template <typename To, typename From>
+NodePtr<To> static_to(NodePtr<From> x) {
+  return static_cast<NodePtr<To>>(x);
+}
+
+template <typename Node, typename... Args>
+NodePtr<Node> alloc(Args&&... args) {
+  return new Node(std::forward<Args>(args)...);
+}
+
+class Buf;
+class Expr;
+class Stmt;
+class Var;
+
+using BufPtr = NodePtr<Buf>;
+using ExprPtr = NodePtr<Expr>;
+using StmtPtr = NodePtr<Stmt>;
+using VarPtr = NodePtr<Var>;
+
+class ExprHandle;
+
+class Add;
+class And;
+class BitCast;
+class Broadcast;
+class Cast;
+class CompareSelect;
+class Div;
+class IfThenElse;
+class Intrinsics;
+class Let;
+class Load;
+class Lshift;
+class Max;
+class MaxTerm;
+class Min;
+class MinTerm;
+class Mod;
+class Mul;
+class Or;
+class Polynomial;
+class Ramp;
+class ReduceOp;
+class RoundOff;
+class Rshift;
+class Store;
+class Sub;
+class Term;
+class Xor;
+using AddPtr = NodePtr<Add>;
+using AndPtr = NodePtr<And>;
+using BitCastPtr = NodePtr<BitCast>;
+using BroadcastPtr = NodePtr<Broadcast>;
+using CastPtr = NodePtr<Cast>;
+using CompareSelectPtr = NodePtr<CompareSelect>;
+using DivPtr = NodePtr<Div>;
+using IfThenElsePtr = NodePtr<IfThenElse>;
+using IntrinsicsPtr = NodePtr<Intrinsics>;
+using LetPtr = NodePtr<Let>;
+using LoadPtr = NodePtr<Load>;
+using LshiftPtr = NodePtr<Lshift>;
+using MaxPtr = NodePtr<Max>;
+using MaxTermPtr = NodePtr<MaxTerm>;
+using MinPtr = NodePtr<Min>;
+using MinTermPtr = NodePtr<MinTerm>;
+using ModPtr = NodePtr<Mod>;
+using MulPtr = NodePtr<Mul>;
+using OrPtr = NodePtr<Or>;
+using PolynomialPtr = NodePtr<Polynomial>;
+using RampPtr = NodePtr<Ramp>;
+using ReduceOpPtr = NodePtr<ReduceOp>;
+using RoundOffPtr = NodePtr<RoundOff>;
+using RshiftPtr = NodePtr<Rshift>;
+using StorePtr = NodePtr<Store>;
+using SubPtr = NodePtr<Sub>;
+using TermPtr = NodePtr<Term>;
+using XorPtr = NodePtr<Xor>;
+
+class Allocate;
+class AtomicAdd;
+class Block;
+class Cond;
+class ExternalCall;
+class For;
+class Free;
+class SyncThreads;
+using AllocatePtr = NodePtr<Allocate>;
+using AtomicAddPtr = NodePtr<AtomicAdd>;
+using BlockPtr = NodePtr<Block>;
+using CondPtr = NodePtr<Cond>;
+using ExternalCallPtr = NodePtr<ExternalCall>;
+using ForPtr = NodePtr<For>;
+using FreePtr = NodePtr<Free>;
+using SyncThreadsPtr = NodePtr<SyncThreads>;
+
+#define IMM_DECLARE(Type, Name) \
+  class Name##Imm;              \
+  using Name##ImmPtr = NodePtr<Name##Imm>;
+AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_DECLARE);
+#undef IMM_DECLARE
+
+} // namespace tensorexpr
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/tensorexpr/half_support.h b/torch/csrc/jit/tensorexpr/half_support.h
index f7e5426..15d48cd 100644
--- a/torch/csrc/jit/tensorexpr/half_support.h
+++ b/torch/csrc/jit/tensorexpr/half_support.h
@@ -22,21 +22,21 @@ class HalfChecker : public IRVisitor {
     return hasHalf_;
   }
 
-  void visit(Load* v) override {
+  void visit(LoadPtr v) override {
     hasHalf_ |= v->dtype().scalar_type() == ScalarType::Half;
     IRVisitor::visit(v);
   }
 
-  void visit(Store* v) override {
+  void visit(StorePtr v) override {
     hasHalf_ |= v->buf()->dtype().scalar_type() == ScalarType::Half;
     IRVisitor::visit(v);
   }
 
-  void visit(HalfImm* v) override {
+  void visit(HalfImmPtr v) override {
     hasHalf_ = true;
   }
 
-  void visit(Cast* v) override {
+  void visit(CastPtr v) override {
     hasHalf_ |= v->dtype().scalar_type() == ScalarType::Half;
     IRVisitor::visit(v);
   }
@@ -47,40 +47,40 @@ class HalfChecker : public IRVisitor {
 
 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
 class HalfRewriter : public IRMutator {
-  Expr* mutate(Load* v) override {
-    Expr* child = IRMutator::mutate(v);
+  ExprPtr mutate(LoadPtr v) override {
+    ExprPtr child = IRMutator::mutate(v);
     if (child->dtype().scalar_type() != ScalarType::Half) {
       return child;
     }
 
-    Expr* ret =
-        new Cast(child->dtype().cloneWithScalarType(ScalarType::Float), child);
+    ExprPtr ret = alloc<Cast>(
+        child->dtype().cloneWithScalarType(ScalarType::Float), child);
 
     inserted_half_casts_.insert(ret);
     return ret;
   }
 
-  Stmt* mutate(Store* v) override {
+  StmtPtr mutate(StorePtr v) override {
     // Since mutation changes the `value()` expression in-place, we need to
    // get the dtype of the `value()` before that is mutated.
Dtype newType = v->value()->dtype(); - Expr* new_val = v->value()->accept_mutator(this); + ExprPtr new_val = v->value()->accept_mutator(this); if (newType.scalar_type() == ScalarType::Half) { new_val = - new Cast(newType.cloneWithScalarType(ScalarType::Half), new_val); + alloc(newType.cloneWithScalarType(ScalarType::Half), new_val); inserted_half_casts_.insert(new_val); } - return new Store(v->buf(), v->indices(), new_val); + return alloc(v->buf(), v->indices(), new_val); } - Expr* mutate(HalfImm* v) override { - return new Cast(kFloat, v); + ExprPtr mutate(HalfImmPtr v) override { + return alloc(kFloat, v); } - Expr* mutate(Cast* v) override { - Expr* child = v->src_value()->accept_mutator(this); + ExprPtr mutate(CastPtr v) override { + ExprPtr child = v->src_value()->accept_mutator(this); // just don't allow half casts we didn't insert. if (v->dtype().scalar_type() == ScalarType::Half) { @@ -90,11 +90,11 @@ class HalfRewriter : public IRMutator { } // Remove Half(Float()) and friends. - Cast* cast_child = dynamic_cast(child); + CastPtr cast_child = to(child); if (cast_child) { if (v->dtype().is_floating_point() && cast_child->dtype().is_floating_point()) { - return new Cast(v->dtype(), cast_child->src_value()); + return alloc(v->dtype(), cast_child->src_value()); } } @@ -102,23 +102,23 @@ class HalfRewriter : public IRMutator { return v; } - return new Cast(v->dtype(), child); + return alloc(v->dtype(), child); } - Stmt* mutate(Let* v) override { + StmtPtr mutate(LetPtr v) override { if (v->dtype().scalar_type() == ScalarType::Half) { - Var* load_new_var = new Var(v->var()->name_hint(), kFloat); - Expr* new_value = new Cast( + VarPtr load_new_var = alloc(v->var()->name_hint(), kFloat); + ExprPtr new_value = alloc( v->dtype().cloneWithScalarType(ScalarType::Float), v->value()->accept_mutator(this)); var_map[v->var()] = load_new_var; - return new Let(load_new_var, new_value); + return alloc(load_new_var, new_value); } return IRMutator::mutate(v); } - Expr* mutate(Var* v) override { + ExprPtr mutate(VarPtr v) override { auto it = var_map.find(v); if (it != var_map.end()) { return it->second; @@ -128,8 +128,8 @@ class HalfRewriter : public IRMutator { } private: - std::unordered_set inserted_half_casts_; - std::unordered_map var_map; + std::unordered_set inserted_half_casts_; + std::unordered_map var_map; }; } // namespace tensorexpr diff --git a/torch/csrc/jit/tensorexpr/hash_provider.cpp b/torch/csrc/jit/tensorexpr/hash_provider.cpp index facdc58..fbc257d 100644 --- a/torch/csrc/jit/tensorexpr/hash_provider.cpp +++ b/torch/csrc/jit/tensorexpr/hash_provider.cpp @@ -28,91 +28,91 @@ bool SimplifierHashType::operator!=(const size_t other) const { return _h != other; } -void HashProvider::visit(Add* v) { +void HashProvider::visit(AddPtr v) { CACHE_GUARD(); v->lhs()->accept(this); v->rhs()->accept(this); putHash(v, hash_combine(hashOf(v->lhs()), "+", hashOf(v->rhs()))); } -void HashProvider::visit(Sub* v) { +void HashProvider::visit(SubPtr v) { CACHE_GUARD(); v->lhs()->accept(this); v->rhs()->accept(this); putHash(v, hash_combine(hashOf(v->lhs()), "-", hashOf(v->rhs()))); } -void HashProvider::visit(Mul* v) { +void HashProvider::visit(MulPtr v) { CACHE_GUARD(); v->lhs()->accept(this); v->rhs()->accept(this); putHash(v, hash_combine(hashOf(v->lhs()), "*", hashOf(v->rhs()))); } -void HashProvider::visit(Div* v) { +void HashProvider::visit(DivPtr v) { CACHE_GUARD(); v->lhs()->accept(this); v->rhs()->accept(this); putHash(v, hash_combine(hashOf(v->lhs()), "/", hashOf(v->rhs()))); } -void 
HashProvider::visit(Mod* v) { +void HashProvider::visit(ModPtr v) { CACHE_GUARD(); v->lhs()->accept(this); v->rhs()->accept(this); putHash(v, hash_combine(hashOf(v->lhs()), "%", hashOf(v->rhs()))); } -void HashProvider::visit(Max* v) { +void HashProvider::visit(MaxPtr v) { CACHE_GUARD(); v->lhs()->accept(this); v->rhs()->accept(this); putHash(v, hash_combine(hashOf(v->lhs()), "Mx", hashOf(v->rhs()))); } -void HashProvider::visit(Min* v) { +void HashProvider::visit(MinPtr v) { CACHE_GUARD(); v->lhs()->accept(this); v->rhs()->accept(this); putHash(v, hash_combine(hashOf(v->lhs()), "Mn", hashOf(v->rhs()))); } -void HashProvider::visit(And* v) { +void HashProvider::visit(AndPtr v) { CACHE_GUARD(); v->lhs()->accept(this); v->rhs()->accept(this); putHash(v, hash_combine(hashOf(v->lhs()), "&", hashOf(v->rhs()))); } -void HashProvider::visit(Or* v) { +void HashProvider::visit(OrPtr v) { CACHE_GUARD(); v->lhs()->accept(this); v->rhs()->accept(this); putHash(v, hash_combine(hashOf(v->lhs()), "|", hashOf(v->rhs()))); } -void HashProvider::visit(Xor* v) { +void HashProvider::visit(XorPtr v) { CACHE_GUARD(); v->lhs()->accept(this); v->rhs()->accept(this); putHash(v, hash_combine(hashOf(v->lhs()), "^", hashOf(v->rhs()))); } -void HashProvider::visit(Lshift* v) { +void HashProvider::visit(LshiftPtr v) { CACHE_GUARD(); v->lhs()->accept(this); v->rhs()->accept(this); putHash(v, hash_combine(hashOf(v->lhs()), "<<", hashOf(v->rhs()))); } -void HashProvider::visit(Rshift* v) { +void HashProvider::visit(RshiftPtr v) { CACHE_GUARD(); v->lhs()->accept(this); v->rhs()->accept(this); putHash(v, hash_combine(hashOf(v->lhs()), ">>", hashOf(v->rhs()))); } -void HashProvider::visit(CompareSelect* v) { +void HashProvider::visit(CompareSelectPtr v) { CACHE_GUARD(); v->lhs()->accept(this); v->rhs()->accept(this); @@ -128,18 +128,18 @@ void HashProvider::visit(CompareSelect* v) { hashOf(v->ret_val2()))); } -void HashProvider::visit(Cast* v) { +void HashProvider::visit(CastPtr v) { CACHE_GUARD(); v->src_value()->accept(this); putHash(v, hash_combine("cast", v->dtype(), hashOf(v->src_value()))); } -void HashProvider::visit(Var* v) { +void HashProvider::visit(VarPtr v) { CACHE_GUARD(); putHash(v, hash_combine("var", name_manager_.get_unique_name(v))); } -void HashProvider::visit(Ramp* v) { +void HashProvider::visit(RampPtr v) { CACHE_GUARD(); v->base()->accept(this); v->stride()->accept(this); @@ -148,22 +148,22 @@ void HashProvider::visit(Ramp* v) { hash_combine("ramp", hashOf(v->base()), hashOf(v->stride()), v->lanes())); } -void HashProvider::visit(Load* v) { +void HashProvider::visit(LoadPtr v) { CACHE_GUARD(); v->base_handle()->accept(this); SimplifierHashType indices_hash; - for (Expr* ind : v->indices()) { + for (ExprPtr ind : v->indices()) { ind->accept(this); indices_hash = hash_combine(indices_hash, hashOf(ind)); } putHash(v, hash_combine("load", hashOf(v->base_handle()), indices_hash)); } -void HashProvider::visit(Store* v) { +void HashProvider::visit(StorePtr v) { CACHE_GUARD(); v->base_handle()->accept(this); SimplifierHashType indices_hash; - for (Expr* ind : v->indices()) { + for (ExprPtr ind : v->indices()) { ind->accept(this); indices_hash = hash_combine(indices_hash, hashOf(ind)); } @@ -174,18 +174,18 @@ void HashProvider::visit(Store* v) { "store", hashOf(v->base_handle()), indices_hash, hashOf(v->value()))); } -void HashProvider::visit(Block* v) { +void HashProvider::visit(BlockPtr v) { CACHE_GUARD(); SimplifierHashType hash; - for (Stmt* s : *v) { + for (StmtPtr s : *v) { s->accept(this); hash = 
hash_combine(hash, hashOf(s)); } putHash(v, hash); } -void HashProvider::visit(For* v) { +void HashProvider::visit(ForPtr v) { CACHE_GUARD(); v->var()->accept(this); v->start()->accept(this); @@ -202,13 +202,13 @@ void HashProvider::visit(For* v) { putHash(v, hash); } -void HashProvider::visit(Broadcast* v) { +void HashProvider::visit(BroadcastPtr v) { CACHE_GUARD(); v->value()->accept(this); putHash(v, hash_combine("broadcast", hashOf(v->value()), v->lanes())); } -void HashProvider::visit(IfThenElse* v) { +void HashProvider::visit(IfThenElsePtr v) { CACHE_GUARD(); v->condition()->accept(this); v->true_value()->accept(this); @@ -223,7 +223,7 @@ void HashProvider::visit(IfThenElse* v) { hashOf(v->false_value()))); } -void HashProvider::visit(Intrinsics* v) { +void HashProvider::visit(IntrinsicsPtr v) { CACHE_GUARD(); // calls to rand are not symbolic and have a different value each time, they // should not hash to anything and this is the best we can do. @@ -242,35 +242,35 @@ void HashProvider::visit(Intrinsics* v) { putHash(v, hash); } -void HashProvider::visit(Allocate* v) { +void HashProvider::visit(AllocatePtr v) { CACHE_GUARD(); - Var* buffer_var = v->buffer_var(); + VarPtr buffer_var = v->buffer_var(); buffer_var->accept(this); SimplifierHashType hash = hash_combine("allocate", hashOf(buffer_var), v->dtype()); - std::vector dims = v->dims(); - for (Expr* dim : dims) { + std::vector dims = v->dims(); + for (ExprPtr dim : dims) { dim->accept(this); hash = hash_combine(hash, hashOf(dim)); } putHash(v, hash); } -void HashProvider::visit(Free* v) { +void HashProvider::visit(FreePtr v) { CACHE_GUARD(); - Var* buffer_var = v->buffer_var(); + VarPtr buffer_var = v->buffer_var(); buffer_var->accept(this); putHash(v, hash_combine("free", hashOf(buffer_var))); } -void HashProvider::visit(Cond* v) { +void HashProvider::visit(CondPtr v) { CACHE_GUARD(); - Expr* condition = v->condition(); - Stmt* true_stmt = v->true_stmt(); - Stmt* false_stmt = v->false_stmt(); + ExprPtr condition = v->condition(); + StmtPtr true_stmt = v->true_stmt(); + StmtPtr false_stmt = v->false_stmt(); condition->accept(this); SimplifierHashType hash = hash_combine("cond", hashOf(condition)); @@ -286,12 +286,12 @@ void HashProvider::visit(Cond* v) { putHash(v, hash); } -void HashProvider::visit(Term* v) { +void HashProvider::visit(TermPtr v) { CACHE_GUARD(); v->scalar()->accept(this); SimplifierHashType hash = hash_combine("term", hashOf(v->scalar())); - for (auto* c : v->variables()) { + for (auto c : v->variables()) { c->accept(this); hash = hash_combine(hash, hashOf(c)); } @@ -299,12 +299,12 @@ void HashProvider::visit(Term* v) { putHash(v, hash); } -void HashProvider::visit(Polynomial* v) { +void HashProvider::visit(PolynomialPtr v) { CACHE_GUARD(); v->scalar()->accept(this); SimplifierHashType hash = hash_combine("term", hashOf(v->scalar())); - for (auto* c : v->variables()) { + for (auto c : v->variables()) { c->accept(this); hash = hash_combine(hash, hashOf(c)); } @@ -312,7 +312,7 @@ void HashProvider::visit(Polynomial* v) { putHash(v, hash); } -void HashProvider::visit(MaxTerm* v) { +void HashProvider::visit(MaxTermPtr v) { CACHE_GUARD(); SimplifierHashType hash = hash_combine("maxterm"); if (v->scalar()) { @@ -320,7 +320,7 @@ void HashProvider::visit(MaxTerm* v) { hash = hash_combine(hash, hashOf(v->scalar())); } - for (auto* c : v->variables()) { + for (auto c : v->variables()) { c->accept(this); hash = hash_combine(hash, hashOf(c)); } @@ -328,7 +328,7 @@ void HashProvider::visit(MaxTerm* v) { putHash(v, hash); } 
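// A standalone sketch (hypothetical names, not the TensorExpr API) of the
// pattern the binary-op visitors above follow: hash each child node, then
// fold the child hashes together with a short operator tag ("+", "Mx",
// "load", ...) so that both the operands and the operation shape the final
// hash. The mixing constant here is the common boost-style one; HashProvider
// does its own mixing on SimplifierHashType (see hash_provider.h).
#include <cstddef>
#include <functional>
#include <string>

namespace hash_sketch {

inline void combine(std::size_t& seed, std::size_t value) {
  // Perturb the running seed so that both the value and the order in which
  // values are folded in affect the result.
  seed ^= value + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}

inline std::size_t hashBinaryOp(
    std::size_t lhs_hash,
    const std::string& op_tag,
    std::size_t rhs_hash) {
  std::size_t seed = 0;
  combine(seed, lhs_hash);
  combine(seed, std::hash<std::string>{}(op_tag));
  combine(seed, rhs_hash);
  return seed;
}

} // namespace hash_sketch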
-void HashProvider::visit(MinTerm* v) { +void HashProvider::visit(MinTermPtr v) { CACHE_GUARD(); SimplifierHashType hash = hash_combine("minterm"); if (v->scalar()) { @@ -336,7 +336,7 @@ void HashProvider::visit(MinTerm* v) { hash = hash_combine(hash, hashOf(v->scalar())); } - for (auto* c : v->variables()) { + for (auto c : v->variables()) { c->accept(this); hash = hash_combine(hash, hashOf(c)); } diff --git a/torch/csrc/jit/tensorexpr/hash_provider.h b/torch/csrc/jit/tensorexpr/hash_provider.h index 4783e0c..5a33f04 100644 --- a/torch/csrc/jit/tensorexpr/hash_provider.h +++ b/torch/csrc/jit/tensorexpr/hash_provider.h @@ -53,7 +53,7 @@ class Polynomial; class TORCH_API HashProvider : public IRVisitor { public: template - SimplifierHashType hash(T* e) { + SimplifierHashType hash(T e) { // NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage) e->accept(this); return hashOf(e); @@ -67,46 +67,46 @@ class TORCH_API HashProvider : public IRVisitor { exprToHash_.clear(); } - void visit(Add* v) override; - void visit(Sub* v) override; - void visit(Mul* v) override; - void visit(Div* v) override; - void visit(Mod* v) override; - void visit(Max* v) override; - void visit(Min* v) override; - void visit(And* v) override; - void visit(Or* v) override; - void visit(Xor* v) override; - void visit(Lshift* v) override; - void visit(Rshift* v) override; - void visit(CompareSelect* v) override; + void visit(AddPtr v) override; + void visit(SubPtr v) override; + void visit(MulPtr v) override; + void visit(DivPtr v) override; + void visit(ModPtr v) override; + void visit(MaxPtr v) override; + void visit(MinPtr v) override; + void visit(AndPtr v) override; + void visit(OrPtr v) override; + void visit(XorPtr v) override; + void visit(LshiftPtr v) override; + void visit(RshiftPtr v) override; + void visit(CompareSelectPtr v) override; // NOLINTNEXTLINE #define IMM_VISIT(Type, Name) \ - void visit(Name##Imm* v) override { \ + void visit(Name##ImmPtr v) override { \ CACHE_GUARD(); \ putHash(v, hash_combine(#Name, v->value())); \ } AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_VISIT); #undef IMM_VISIT - void visit(Cast* v) override; - void visit(Var* v) override; - void visit(Ramp* v) override; - void visit(Load* v) override; - void visit(Store* v) override; - void visit(Block* v) override; - void visit(For* v) override; - void visit(Broadcast* v) override; - void visit(IfThenElse* v) override; - void visit(Intrinsics* v) override; - void visit(Allocate* v) override; - void visit(Free* v) override; - void visit(Cond* v) override; - void visit(Term* v) override; - void visit(Polynomial* v) override; - void visit(MaxTerm* v) override; - void visit(MinTerm* v) override; + void visit(CastPtr v) override; + void visit(VarPtr v) override; + void visit(RampPtr v) override; + void visit(LoadPtr v) override; + void visit(StorePtr v) override; + void visit(BlockPtr v) override; + void visit(ForPtr v) override; + void visit(BroadcastPtr v) override; + void visit(IfThenElsePtr v) override; + void visit(IntrinsicsPtr v) override; + void visit(AllocatePtr v) override; + void visit(FreePtr v) override; + void visit(CondPtr v) override; + void visit(TermPtr v) override; + void visit(PolynomialPtr v) override; + void visit(MaxTermPtr v) override; + void visit(MinTermPtr v) override; template SimplifierHashType hash_combine(const Types&... 
args) { @@ -116,7 +116,7 @@ class TORCH_API HashProvider : public IRVisitor { } private: - SimplifierHashType hashOf(Expr* e) { + SimplifierHashType hashOf(ExprPtr e) { auto it = exprToHash_.find(e); if (it != exprToHash_.end()) { return it->second; @@ -132,7 +132,7 @@ class TORCH_API HashProvider : public IRVisitor { return hash; } - SimplifierHashType hashOf(Stmt* s) { + SimplifierHashType hashOf(StmtPtr s) { auto it = exprToHash_.find(s); if (it != exprToHash_.end()) { return it->second; @@ -169,7 +169,7 @@ class TORCH_API HashProvider : public IRVisitor { (seed._h >> 4); } - void _hash_combine(SimplifierHashType& seed, Expr* e) { + void _hash_combine(SimplifierHashType& seed, ExprPtr e) { _hash_combine(seed, hash(e)); } diff --git a/torch/csrc/jit/tensorexpr/ir.cpp b/torch/csrc/jit/tensorexpr/ir.cpp index 64ad8fb..f66c0c5 100644 --- a/torch/csrc/jit/tensorexpr/ir.cpp +++ b/torch/csrc/jit/tensorexpr/ir.cpp @@ -12,7 +12,7 @@ static Dtype ChooseDtype(const Dtype& buffer_dtype, const Dtype& index_dtype) { return Dtype(buffer_dtype, index_dtype.lanes()); } -static Dtype dtypeOfIndices(const std::vector& indices) { +static Dtype dtypeOfIndices(const std::vector& indices) { if (!indices.size()) { // Return something so we can handle scalar buffers. return kInt; @@ -20,7 +20,7 @@ static Dtype dtypeOfIndices(const std::vector& indices) { return indices.at(0)->dtype(); } -void castIndicesToInts(std::vector& indices) { +void castIndicesToInts(std::vector& indices) { // Cast all indices to either Int or Long auto index_dtype = ScalarType::Int; for (auto& index : indices) { @@ -35,17 +35,17 @@ void castIndicesToInts(std::vector& indices) { const Dtype& dt = index->dtype(); if (c10::isIntegralType(dt.scalar_type(), true) && dt.scalar_type() != index_dtype) { - index = new Cast(Dtype(index_dtype, dt.lanes()), index); + index = alloc(Dtype(index_dtype, dt.lanes()), index); } } } -Load::Load(Dtype dtype, Buf* buf, std::vector indices) +Load::Load(Dtype dtype, BufPtr buf, std::vector indices) : ExprNodeBase(dtype), buf_(buf), indices_(std::move(indices)) { castIndicesToInts(indices_); } -Load::Load(Buf* buf, const std::vector& indices) +Load::Load(BufPtr buf, const std::vector& indices) : Load(ChooseDtype(buf->dtype(), dtypeOfIndices(indices)), buf, indices) {} ExprHandle Load::make( @@ -53,7 +53,7 @@ ExprHandle Load::make( const BufHandle& buf, const std::vector& indices) { return ExprHandle( - new Load(dtype, buf.node(), ExprHandleVectorToExprVector(indices))); + alloc(dtype, buf.node(), ExprHandleVectorToExprVector(indices))); } ExprHandle Load::make( @@ -62,22 +62,22 @@ ExprHandle Load::make( return Load::make(buf.dtype(), buf, indices); } -Store::Store(Buf* buf, std::vector indices, Expr* value) +Store::Store(BufPtr buf, std::vector indices, ExprPtr value) : buf_(buf), indices_(std::move(indices)), value_(value) { castIndicesToInts(indices_); } -Store* Store::make( +StorePtr Store::make( const BufHandle& buf, const std::vector& indices, const ExprHandle& value) { - return new Store( + return alloc( buf.node(), ExprHandleVectorToExprVector(indices), value.node()); } -Expr* flatten_index( - const std::vector& dims, - const std::vector& indices) { +ExprPtr flatten_index( + const std::vector& dims, + const std::vector& indices) { // Handle already flattened indices first if (indices.size() == 1) { return indices[0]; @@ -88,19 +88,19 @@ Expr* flatten_index( throw malformed_input("dimensions mismatch in flatten_index"); } if (ndim == 0) { - return new IntImm(0); + return alloc(0); } - std::vector 
strides(ndim); + std::vector strides(ndim); // stride[i] = stride[i+1]*dims[i+1], i < ndim-1 // stride[i] = 1, i = ndim-1 - strides[ndim - 1] = new IntImm(1); + strides[ndim - 1] = alloc(1); for (size_t i = 1; i < ndim; i++) { - strides[ndim - 1 - i] = new Mul(strides[ndim - i], dims[ndim - i]); + strides[ndim - 1 - i] = alloc(strides[ndim - i], dims[ndim - i]); } - Expr* total_index = new IntImm(0); + ExprPtr total_index = alloc(0); for (const auto i : c10::irange(ndim)) { - total_index = new Add(total_index, new Mul(indices[i], strides[i])); + total_index = alloc(total_index, alloc(indices[i], strides[i])); } return total_index; } @@ -120,7 +120,7 @@ Dtype Intrinsics::IntrinsicsDtype(IntrinsicsOp op_type, Dtype dt1, Dtype dt2) { Dtype Intrinsics::IntrinsicsDtype( IntrinsicsOp op_type, - const std::vector& params) { + const std::vector& params) { // TODO: check the op_type and make a real decision // Doesnt this fail with kRand? if (params.size() == 0) { @@ -176,23 +176,23 @@ int Intrinsics::OpArgCount(IntrinsicsOp op_type) { } } -ExternalCall* ExternalCall::make( +ExternalCallPtr ExternalCall::make( BufHandle buf, const std::string& func_name, const std::vector& buf_args, const std::vector& args) { - std::vector buf_arg_nodes; + std::vector buf_arg_nodes; buf_arg_nodes.reserve(buf_args.size()); for (const BufHandle& buf_arg : buf_args) { buf_arg_nodes.push_back(buf_arg.node()); } - return new ExternalCall( + return alloc( buf.node(), func_name, buf_arg_nodes, ExprHandleVectorToExprVector(args)); } -std::vector ExprHandleVectorToExprVector( +std::vector ExprHandleVectorToExprVector( const std::vector& v) { - std::vector result(v.size()); + std::vector result(v.size()); for (const auto i : c10::irange(v.size())) { result[i] = v[i].node(); } @@ -200,7 +200,7 @@ std::vector ExprHandleVectorToExprVector( } std::vector ExprVectorToExprHandleVector( - const std::vector& v) { + const std::vector& v) { std::vector result(v.size()); for (const auto i : c10::irange(v.size())) { result[i] = ExprHandle(v[i]); @@ -208,15 +208,17 @@ std::vector ExprVectorToExprHandleVector( return result; } -std::vector VarHandleVectorToVarVector(const std::vector& v) { - std::vector result(v.size()); +std::vector VarHandleVectorToVarVector( + const std::vector& v) { + std::vector result(v.size()); for (const auto i : c10::irange(v.size())) { result[i] = v[i].node(); } return result; } -std::vector VarVectorToVarHandleVector(const std::vector& v) { +std::vector VarVectorToVarHandleVector( + const std::vector& v) { std::vector result(v.size()); for (const auto i : c10::irange(v.size())) { result[i] = VarHandle(v[i]); @@ -224,10 +226,10 @@ std::vector VarVectorToVarHandleVector(const std::vector& v) { return result; } -bool immediateIsNegative(Expr* e) { -#define TYPE_CASE(Type, Name) \ - if (Name##Imm* imm = dynamic_cast(e)) { \ - return imm->value() < 0; \ +bool immediateIsNegative(ExprPtr e) { +#define TYPE_CASE(Type, Name) \ + if (Name##ImmPtr imm = to(e)) { \ + return imm->value() < 0; \ } AT_FORALL_SCALAR_TYPES_AND(Half, TYPE_CASE); #undef TYPE_CASE diff --git a/torch/csrc/jit/tensorexpr/ir.h b/torch/csrc/jit/tensorexpr/ir.h index 680a78a..761b233 100644 --- a/torch/csrc/jit/tensorexpr/ir.h +++ b/torch/csrc/jit/tensorexpr/ir.h @@ -6,6 +6,7 @@ #include #include #include +#include #include #include @@ -68,18 +69,18 @@ class Placeholder; class TORCH_API Cast : public ExprNode { public: - Expr* src_value() const { + ExprPtr src_value() const { return src_value_; } - void set_src_value(Expr* src_value) { + void 
set_src_value(ExprPtr src_value) { src_value_ = src_value; } static ExprHandle make(Dtype dtype, const ExprHandle& src_value) { - return ExprHandle(new Cast(dtype, src_value.node())); + return ExprHandle(alloc(dtype, src_value.node())); } - Cast(Dtype dtype, Expr* src_value) + Cast(Dtype dtype, ExprPtr src_value) : ExprNodeBase(dtype, kCast), src_value_(src_value) {} bool isConstant() const override { @@ -87,7 +88,7 @@ class TORCH_API Cast : public ExprNode { } private: - Expr* src_value_; + ExprPtr src_value_; }; template @@ -98,18 +99,18 @@ ExprHandle cast(const ExprHandle& src_value) { // This is a bitwise cast, akin to bitcast in LLVM class TORCH_API BitCast : public ExprNode { public: - Expr* src_value() const { + ExprPtr src_value() const { return src_value_; } - void set_src_value(Expr* src_value) { + void set_src_value(ExprPtr src_value) { src_value_ = src_value; } static ExprHandle make(Dtype dtype, const ExprHandle& src_value) { - return ExprHandle(new BitCast(dtype, src_value.node())); + return ExprHandle(alloc(dtype, src_value.node())); } - BitCast(Dtype dtype, Expr* src_value) + BitCast(Dtype dtype, ExprPtr src_value) : ExprNodeBase(dtype, kBitCast), src_value_(src_value) { TORCH_CHECK(src_value_->dtype().byte_size() == dtype.byte_size()); } @@ -119,7 +120,7 @@ class TORCH_API BitCast : public ExprNode { } private: - Expr* src_value_; + ExprPtr src_value_; }; template @@ -133,29 +134,29 @@ ExprHandle bitcast(const ExprHandle& src_value) { template class BinaryOpNode : public ExprNode { public: - Expr* lhs() const { + ExprPtr lhs() const { return this->lhs_; } - Expr* rhs() const { + ExprPtr rhs() const { return this->rhs_; } - void set_lhs(Expr* lhs) { + void set_lhs(ExprPtr lhs) { lhs_ = lhs; } - void set_rhs(Expr* rhs) { + void set_rhs(ExprPtr rhs) { rhs_ = rhs; } static ExprHandle make(const ExprHandle& lhs, const ExprHandle& rhs) { - return ExprHandle(new Op(lhs.node(), rhs.node())); + return ExprHandle(alloc(lhs.node(), rhs.node())); } // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) BinaryOpNode( - Expr* lhs_v, - Expr* rhs_v, + ExprPtr lhs_v, + ExprPtr rhs_v, IRNodeType expr_type, ScalarType ret_type = ScalarType::Undefined) : ExprNode( @@ -166,46 +167,46 @@ class BinaryOpNode : public ExprNode { rhs_(CastIfNeeded(rhs_v, ExprNode::dtype())) {} private: - static Expr* CastIfNeeded(Expr* expr, Dtype dst_dtype) { + static ExprPtr CastIfNeeded(ExprPtr expr, Dtype dst_dtype) { if (expr->dtype() == dst_dtype) { return expr; } return Cast::make(dst_dtype, ExprHandle(expr)).node(); } - Expr* lhs_; - Expr* rhs_; + ExprPtr lhs_; + ExprPtr rhs_; }; class TORCH_API Add : public BinaryOpNode { public: - Add(Expr* lhs, Expr* rhs) : BinaryOpNode(lhs, rhs, IRNodeType::kAdd) {} + Add(ExprPtr lhs, ExprPtr rhs) : BinaryOpNode(lhs, rhs, IRNodeType::kAdd) {} }; class TORCH_API Sub : public BinaryOpNode { public: - Sub(Expr* lhs, Expr* rhs) : BinaryOpNode(lhs, rhs, IRNodeType::kSub) {} + Sub(ExprPtr lhs, ExprPtr rhs) : BinaryOpNode(lhs, rhs, IRNodeType::kSub) {} }; class TORCH_API Mul : public BinaryOpNode { public: - Mul(Expr* lhs, Expr* rhs) : BinaryOpNode(lhs, rhs, IRNodeType::kMul) {} + Mul(ExprPtr lhs, ExprPtr rhs) : BinaryOpNode(lhs, rhs, IRNodeType::kMul) {} }; class TORCH_API Div : public BinaryOpNode
{ public: - Div(Expr* lhs, Expr* rhs) : BinaryOpNode(lhs, rhs, IRNodeType::kDiv) {} + Div(ExprPtr lhs, ExprPtr rhs) : BinaryOpNode(lhs, rhs, IRNodeType::kDiv) {} }; class TORCH_API Mod : public BinaryOpNode { public: - Mod(Expr* lhs, Expr* rhs) : BinaryOpNode(lhs, rhs, IRNodeType::kMod) {} + Mod(ExprPtr lhs, ExprPtr rhs) : BinaryOpNode(lhs, rhs, IRNodeType::kMod) {} }; template class BitwiseOpNode : public BinaryOpNode { public: - BitwiseOpNode(Expr* lhs, Expr* rhs, IRNodeType type) + BitwiseOpNode(ExprPtr lhs, ExprPtr rhs, IRNodeType type) : BinaryOpNode(lhs, rhs, type) {} static ExprHandle make(const ExprHandle& lhs, const ExprHandle& rhs) { @@ -221,27 +222,29 @@ class BitwiseOpNode : public BinaryOpNode { class TORCH_API And : public BitwiseOpNode { public: - And(Expr* lhs, Expr* rhs) : BitwiseOpNode(lhs, rhs, IRNodeType::kAnd) {} + And(ExprPtr lhs, ExprPtr rhs) : BitwiseOpNode(lhs, rhs, IRNodeType::kAnd) {} }; class TORCH_API Or : public BitwiseOpNode { public: - Or(Expr* lhs, Expr* rhs) : BitwiseOpNode(lhs, rhs, IRNodeType::kOr) {} + Or(ExprPtr lhs, ExprPtr rhs) : BitwiseOpNode(lhs, rhs, IRNodeType::kOr) {} }; class TORCH_API Xor : public BitwiseOpNode { public: - Xor(Expr* lhs, Expr* rhs) : BitwiseOpNode(lhs, rhs, IRNodeType::kXor) {} + Xor(ExprPtr lhs, ExprPtr rhs) : BitwiseOpNode(lhs, rhs, IRNodeType::kXor) {} }; class TORCH_API Lshift : public BitwiseOpNode { public: - Lshift(Expr* lhs, Expr* rhs) : BitwiseOpNode(lhs, rhs, IRNodeType::kLshift) {} + Lshift(ExprPtr lhs, ExprPtr rhs) + : BitwiseOpNode(lhs, rhs, IRNodeType::kLshift) {} }; class TORCH_API Rshift : public BitwiseOpNode { public: - Rshift(Expr* lhs, Expr* rhs) : BitwiseOpNode(lhs, rhs, IRNodeType::kRshift) {} + Rshift(ExprPtr lhs, ExprPtr rhs) + : BitwiseOpNode(lhs, rhs, IRNodeType::kRshift) {} }; // TODO: add TORCH_API @@ -251,7 +254,7 @@ class Max : public BinaryOpNode { bool propagate_nans_; public: - Max(Expr* lhs, Expr* rhs, bool propagate_nans) + Max(ExprPtr lhs, ExprPtr rhs, bool propagate_nans) : BinaryOpNode(lhs, rhs, IRNodeType::kMax), propagate_nans_(propagate_nans) {} @@ -264,7 +267,7 @@ class Max : public BinaryOpNode { const ExprHandle& lhs, const ExprHandle& rhs, bool propagate_nans) { - return ExprHandle(new Max(lhs.node(), rhs.node(), propagate_nans)); + return ExprHandle(alloc(lhs.node(), rhs.node(), propagate_nans)); } }; @@ -275,7 +278,7 @@ class Min : public BinaryOpNode { bool propagate_nans_; public: - Min(Expr* lhs, Expr* rhs, bool propagate_nans) + Min(ExprPtr lhs, ExprPtr rhs, bool propagate_nans) : BinaryOpNode(lhs, rhs, IRNodeType::kMin), propagate_nans_(propagate_nans) {} @@ -288,7 +291,7 @@ class Min : public BinaryOpNode { const ExprHandle& lhs, const ExprHandle& rhs, bool propagate_nans) { - return ExprHandle(new Min(lhs.node(), rhs.node(), propagate_nans)); + return ExprHandle(alloc(lhs.node(), rhs.node(), propagate_nans)); } }; @@ -305,7 +308,7 @@ class Min : public BinaryOpNode { return value_; \ } \ static ExprHandle make(Type value) { \ - return ExprHandle(new Name##Imm(value)); \ + return ExprHandle(alloc(value)); \ } \ \ private: \ @@ -316,11 +319,11 @@ AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_DECLARE); // Get immediate by ScalarType. 
template -Expr* getImmediateByType(ScalarType immType, T initialVal) { +ExprPtr getImmediateByType(ScalarType immType, T initialVal) { switch (immType) { #define TYPE_CASE(Type, Name) \ case ScalarType::Name: \ - return new Name##Imm(initialVal); + return alloc(initialVal); // NOLINTNEXTLINE(bugprone-branch-clone) AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, TYPE_CASE); #undef TYPE_CASE @@ -331,15 +334,15 @@ Expr* getImmediateByType(ScalarType immType, T initialVal) { } template -Expr* getImmediateByType(Dtype dtype, T initialVal) { +ExprPtr getImmediateByType(Dtype dtype, T initialVal) { return getImmediateByType(dtype.scalar_type(), initialVal); } template -T immediateAs(Expr* e) { -#define TYPE_CASE(Type, Name) \ - if (Name##Imm* imm = dynamic_cast(e)) { \ - return imm->value(); \ +T immediateAs(ExprPtr e) { +#define TYPE_CASE(Type, Name) \ + if (Name##ImmPtr imm = to(e)) { \ + return imm->value(); \ } AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, TYPE_CASE); #undef TYPE_CASE @@ -353,10 +356,10 @@ T immediateAs(ExprHandle e) { } template -bool immediateEquals(Expr* e, T val) { -#define TYPE_CASE(Type, Name) \ - if (Name##Imm* imm = dynamic_cast(e)) { \ - return imm->value() == val; \ +bool immediateEquals(ExprPtr e, T val) { +#define TYPE_CASE(Type, Name) \ + if (Name##ImmPtr imm = to(e)) { \ + return imm->value() == val; \ } AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, TYPE_CASE); #undef TYPE_CASE @@ -364,24 +367,24 @@ bool immediateEquals(Expr* e, T val) { return false; } -TORCH_API bool immediateIsNegative(Expr* e); +TORCH_API bool immediateIsNegative(ExprPtr e); // Represents a ramp vector node: // [base, base + 1 * stride, ... , base + (lanes - 1) * stride] class TORCH_API Ramp : public ExprNode { public: - Expr* base() const { + ExprPtr base() const { return base_; } - Expr* stride() const { + ExprPtr stride() const { return stride_; } - void set_base(Expr* base) { + void set_base(ExprPtr base) { base_ = base; } - void set_stride(Expr* stride) { + void set_stride(ExprPtr stride) { stride_ = stride; } @@ -392,45 +395,45 @@ class TORCH_API Ramp : public ExprNode { if (stride.dtype() != base.dtype()) { throw malformed_input("Bad stride in Ramp"); } - return ExprHandle(new Ramp(base.node(), stride.node(), lanes)); + return ExprHandle(alloc(base.node(), stride.node(), lanes)); } int lanes() const { return lanes_; } - Ramp(Expr* base, Expr* stride, int lanes) + Ramp(ExprPtr base, ExprPtr stride, int lanes) : ExprNodeBase(Dtype(base->dtype(), lanes)), base_(base), stride_(stride), lanes_(lanes) {} private: - Expr* base_; - Expr* stride_; + ExprPtr base_; + ExprPtr stride_; int lanes_; }; class TORCH_API Load : public ExprNode { public: - Var* base_handle() const { + VarPtr base_handle() const { return buf_->base_handle(); } - std::vector indices() const { + std::vector indices() const { return indices_; } - Expr* flat_index() const { + ExprPtr flat_index() const { TORCH_CHECK(indices_.size() == 1, "Indices haven't been flattened."); return indices_[0]; } - Buf* buf() const { + BufPtr buf() const { return buf_; } - void set_buf(Buf* buf) { + void set_buf(BufPtr buf) { buf_ = buf; } - void set_indices(std::vector indices) { + void set_indices(std::vector indices) { indices_ = indices; } @@ -442,21 +445,21 @@ class TORCH_API Load : public ExprNode { const BufHandle& buf, const std::vector& indices); - Load(Dtype dtype, Buf* base_handle, std::vector indices); - Load(Buf* base_handle, const std::vector& indices); + Load(Dtype dtype, BufPtr base_handle, std::vector indices); + Load(BufPtr base_handle, const 
std::vector& indices); private: - Buf* buf_; - std::vector indices_; + BufPtr buf_; + std::vector indices_; }; class TORCH_API Broadcast : public ExprNode { public: - Expr* value() const { + ExprPtr value() const { return value_; } - void set_value(Expr* value) { + void set_value(ExprPtr value) { value_ = value; } @@ -464,43 +467,43 @@ class TORCH_API Broadcast : public ExprNode { return lanes_; } static ExprHandle make(const ExprHandle& value, int lanes) { - return ExprHandle(new Broadcast(value.node(), lanes)); + return ExprHandle(alloc(value.node(), lanes)); } - Broadcast(Expr* value, int lanes) + Broadcast(ExprPtr value, int lanes) : ExprNodeBase(Dtype(value->dtype(), lanes)), value_(value), lanes_(lanes) {} private: - Expr* value_; + ExprPtr value_; int lanes_; }; class TORCH_API IfThenElse : public ExprNode { public: - Expr* condition() const { + ExprPtr condition() const { return condition_; } // Lazily evaluated only if condition is true - Expr* true_value() const { + ExprPtr true_value() const { return true_; } // Lazily evaluated only if condition is false - Expr* false_value() const { + ExprPtr false_value() const { return false_; } - void set_condition(Expr* condition) { + void set_condition(ExprPtr condition) { condition_ = condition; } - void set_true_value(Expr* true_value) { + void set_true_value(ExprPtr true_value) { true_ = true_value; } - void set_false_value(Expr* false_value) { + void set_false_value(ExprPtr false_value) { false_ = false_value; } @@ -517,16 +520,16 @@ class TORCH_API IfThenElse : public ExprNode { if (t.dtype() != f.dtype()) { throw malformed_input("Bad dtype in IfThenElse"); } - return ExprHandle(new IfThenElse(c.node(), t.node(), f.node())); + return ExprHandle(alloc(c.node(), t.node(), f.node())); } - IfThenElse(Expr* c, Expr* t, Expr* f) + IfThenElse(ExprPtr c, ExprPtr t, ExprPtr f) : ExprNodeBase(t->dtype()), condition_(c), true_(t), false_(f) {} private: - Expr* condition_; - Expr* true_; - Expr* false_; + ExprPtr condition_; + ExprPtr true_; + ExprPtr false_; }; class TORCH_API CompareSelect : public ExprNode { @@ -534,32 +537,32 @@ class TORCH_API CompareSelect : public ExprNode { CompareSelectOperation compare_select_op() const { return compare_op_; } - Expr* lhs() const { + ExprPtr lhs() const { return this->lhs_; } - Expr* rhs() const { + ExprPtr rhs() const { return this->rhs_; } - Expr* ret_val1() const { + ExprPtr ret_val1() const { return this->ret_val1_; } - Expr* ret_val2() const { + ExprPtr ret_val2() const { return this->ret_val2_; } - void set_lhs(Expr* lhs) { + void set_lhs(ExprPtr lhs) { lhs_ = lhs; } - void set_rhs(Expr* rhs) { + void set_rhs(ExprPtr rhs) { rhs_ = rhs; } - void set_ret_val1(Expr* ret_val1) { + void set_ret_val1(ExprPtr ret_val1) { ret_val1_ = ret_val1; } - void set_ret_val2(Expr* ret_val2) { + void set_ret_val2(ExprPtr ret_val2) { ret_val2_ = ret_val2; } @@ -575,7 +578,7 @@ class TORCH_API CompareSelect : public ExprNode { if (lhs.dtype() != rhs.dtype()) { throw malformed_input("bad dtype in CompareSelect"); } - return ExprHandle(new CompareSelect( + return ExprHandle(alloc( lhs.node(), rhs.node(), IntImm::make(1).node(), @@ -594,7 +597,7 @@ class TORCH_API CompareSelect : public ExprNode { if (lhs.dtype() != rhs.dtype() || ret_val1.dtype() != ret_val2.dtype()) { throw malformed_input("bad dtype in CompareSelect"); } - return ExprHandle(new CompareSelect( + return ExprHandle(alloc( lhs.node(), rhs.node(), ret_val1.node(), @@ -604,10 +607,10 @@ class TORCH_API CompareSelect : public ExprNode { } CompareSelect( - 
Expr* lhs, - Expr* rhs, - Expr* ret_val1, - Expr* ret_val2, + ExprPtr lhs, + ExprPtr rhs, + ExprPtr ret_val1, + ExprPtr ret_val2, CompareSelectOperation cmp_op, CompareSelectBias bias = kUnbiased) : ExprNodeBase(ret_val1->dtype()), @@ -620,23 +623,23 @@ class TORCH_API CompareSelect : public ExprNode { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) CompareSelect( - Expr* lhs, - Expr* rhs, + ExprPtr lhs, + ExprPtr rhs, CompareSelectOperation cmp_op, CompareSelectBias bias = kUnbiased) : ExprNodeBase(kInt), lhs_(lhs), rhs_(rhs), - ret_val1_(new IntImm(1)), - ret_val2_(new IntImm(0)), + ret_val1_(alloc(1)), + ret_val2_(alloc(0)), compare_op_(cmp_op), bias_(bias) {} private: - Expr* lhs_; - Expr* rhs_; - Expr* ret_val1_; - Expr* ret_val2_; + ExprPtr lhs_; + ExprPtr rhs_; + ExprPtr ret_val1_; + ExprPtr ret_val2_; CompareSelectOperation compare_op_; CompareSelectBias bias_; }; @@ -680,29 +683,29 @@ enum IntrinsicsOp { class TORCH_API Intrinsics : public ExprNode { public: static ExprHandle make(IntrinsicsOp op_type, const ExprHandle& v1) { - return ExprHandle(new Intrinsics(op_type, v1.node())); + return ExprHandle(alloc(op_type, v1.node())); } static ExprHandle make( IntrinsicsOp op_type, const ExprHandle& v1, const ExprHandle& v2) { - return ExprHandle(new Intrinsics(op_type, v1.node(), v2.node())); + return ExprHandle(alloc(op_type, v1.node(), v2.node())); } static ExprHandle make( IntrinsicsOp op_type, const std::vector& params) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - std::vector params_nodes(params.size()); + std::vector params_nodes(params.size()); for (size_t i = 0; i < params.size(); i++) { params_nodes[i] = params[i].node(); } - return ExprHandle(new Intrinsics(op_type, params_nodes)); + return ExprHandle(alloc(op_type, params_nodes)); } static ExprHandle make(IntrinsicsOp op_type, Dtype dtype) { - return ExprHandle(new Intrinsics(op_type, dtype)); + return ExprHandle(alloc(op_type, dtype)); } IntrinsicsOp op_type() const { @@ -794,7 +797,7 @@ class TORCH_API Intrinsics : public ExprNode { } // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) - Intrinsics(IntrinsicsOp op_type, Expr* v1) + Intrinsics(IntrinsicsOp op_type, ExprPtr v1) : ExprNodeBase(IntrinsicsDtype(op_type, v1->dtype())), params_({v1}), op_type_(op_type) { @@ -804,7 +807,7 @@ class TORCH_API Intrinsics : public ExprNode { } // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) - Intrinsics(IntrinsicsOp op_type, Expr* v1, Expr* v2) + Intrinsics(IntrinsicsOp op_type, ExprPtr v1, ExprPtr v2) : ExprNodeBase(IntrinsicsDtype(op_type, v1->dtype(), v2->dtype())), params_({v1, v2}), op_type_(op_type) { @@ -814,7 +817,7 @@ class TORCH_API Intrinsics : public ExprNode { } // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) - Intrinsics(IntrinsicsOp op_type, const std::vector& params) + Intrinsics(IntrinsicsOp op_type, const std::vector& params) : ExprNodeBase(IntrinsicsDtype(op_type, params)), params_(params), op_type_(op_type) { @@ -827,7 +830,7 @@ class TORCH_API Intrinsics : public ExprNode { Intrinsics( IntrinsicsOp op_type, Dtype dtype, - const std::vector& params) + const std::vector& params) : ExprNodeBase(IntrinsicsDtype(op_type, dtype)), params_(params), op_type_(op_type) { @@ -844,14 +847,14 @@ class TORCH_API Intrinsics : public ExprNode { return params_.size(); } - Expr* param(int index) const { + ExprPtr param(int index) const { return params_[index]; } - const std::vector& params() const { + const std::vector& params() const { return params_; } - void set_params(std::vector 
params) { + void set_params(std::vector params) { params_ = std::move(params); } @@ -861,9 +864,9 @@ class TORCH_API Intrinsics : public ExprNode { static Dtype IntrinsicsDtype(IntrinsicsOp op_type, Dtype dt1, Dtype dt2); static Dtype IntrinsicsDtype( IntrinsicsOp op_type, - const std::vector& params); + const std::vector& params); - std::vector params_; + std::vector params_; IntrinsicsOp op_type_; }; @@ -872,17 +875,17 @@ class Term; class MaxTerm; class MinTerm; -TORCH_API std::vector ExprHandleVectorToExprVector( +TORCH_API std::vector ExprHandleVectorToExprVector( const std::vector&); TORCH_API std::vector ExprVectorToExprHandleVector( - const std::vector&); -TORCH_API std::vector VarHandleVectorToVarVector( + const std::vector&); +TORCH_API std::vector VarHandleVectorToVarVector( const std::vector&); TORCH_API std::vector VarVectorToVarHandleVector( - const std::vector&); -TORCH_API Expr* flatten_index( - const std::vector& dims, - const std::vector& indices); + const std::vector&); +TORCH_API ExprPtr flatten_index( + const std::vector& dims, + const std::vector& indices); } // namespace tensorexpr } // namespace jit diff --git a/torch/csrc/jit/tensorexpr/ir_cloner.cpp b/torch/csrc/jit/tensorexpr/ir_cloner.cpp index e6a8db7..f724f2c 100644 --- a/torch/csrc/jit/tensorexpr/ir_cloner.cpp +++ b/torch/csrc/jit/tensorexpr/ir_cloner.cpp @@ -11,97 +11,97 @@ namespace jit { namespace tensorexpr { template -static Expr* mutate_binary_op( - BinaryOpNode* v, +static ExprPtr mutate_binary_op( + NodePtr> v, IRCloner* cloner, bool option = false) { - Expr* lhs_new = v->lhs()->accept_mutator(cloner); - Expr* rhs_new = v->rhs()->accept_mutator(cloner); + ExprPtr lhs_new = v->lhs()->accept_mutator(cloner); + ExprPtr rhs_new = v->rhs()->accept_mutator(cloner); IRNodeType expr_type = v->expr_type(); switch (expr_type) { case IRNodeType::kAdd: - return new Add(lhs_new, rhs_new); + return alloc(lhs_new, rhs_new); case IRNodeType::kSub: - return new Sub(lhs_new, rhs_new); + return alloc(lhs_new, rhs_new); case IRNodeType::kMul: - return new Mul(lhs_new, rhs_new); + return alloc(lhs_new, rhs_new); case IRNodeType::kDiv: - return new Div(lhs_new, rhs_new); + return alloc
(lhs_new, rhs_new); case IRNodeType::kMod: - return new Mod(lhs_new, rhs_new); + return alloc(lhs_new, rhs_new); case IRNodeType::kMax: - return new Max(lhs_new, rhs_new, option); + return alloc(lhs_new, rhs_new, option); case IRNodeType::kMin: - return new Min(lhs_new, rhs_new, option); + return alloc(lhs_new, rhs_new, option); case IRNodeType::kAnd: - return new And(lhs_new, rhs_new); + return alloc(lhs_new, rhs_new); case IRNodeType::kOr: - return new Or(lhs_new, rhs_new); + return alloc(lhs_new, rhs_new); case IRNodeType::kXor: - return new Xor(lhs_new, rhs_new); + return alloc(lhs_new, rhs_new); case IRNodeType::kLshift: - return new Lshift(lhs_new, rhs_new); + return alloc(lhs_new, rhs_new); case IRNodeType::kRshift: - return new Rshift(lhs_new, rhs_new); + return alloc(lhs_new, rhs_new); default: throw unimplemented_lowering(v); } } -Expr* IRCloner::mutate(Add* v) { +ExprPtr IRCloner::mutate(AddPtr v) { return mutate_binary_op(v, this); } -Expr* IRCloner::mutate(Sub* v) { +ExprPtr IRCloner::mutate(SubPtr v) { return mutate_binary_op(v, this); } -Expr* IRCloner::mutate(Mul* v) { +ExprPtr IRCloner::mutate(MulPtr v) { return mutate_binary_op(v, this); } -Expr* IRCloner::mutate(Div* v) { +ExprPtr IRCloner::mutate(DivPtr v) { return mutate_binary_op(v, this); } -Expr* IRCloner::mutate(Mod* v) { +ExprPtr IRCloner::mutate(ModPtr v) { return mutate_binary_op(v, this); } -Expr* IRCloner::mutate(And* v) { +ExprPtr IRCloner::mutate(AndPtr v) { return mutate_binary_op(v, this); } -Expr* IRCloner::mutate(Or* v) { +ExprPtr IRCloner::mutate(OrPtr v) { return mutate_binary_op(v, this); } -Expr* IRCloner::mutate(Xor* v) { +ExprPtr IRCloner::mutate(XorPtr v) { return mutate_binary_op(v, this); } -Expr* IRCloner::mutate(Lshift* v) { +ExprPtr IRCloner::mutate(LshiftPtr v) { return mutate_binary_op(v, this); } -Expr* IRCloner::mutate(Rshift* v) { +ExprPtr IRCloner::mutate(RshiftPtr v) { return mutate_binary_op(v, this); } -Expr* IRCloner::mutate(Max* v) { +ExprPtr IRCloner::mutate(MaxPtr v) { return mutate_binary_op(v, this, v->propagate_nans()); } -Expr* IRCloner::mutate(Min* v) { +ExprPtr IRCloner::mutate(MinPtr v) { return mutate_binary_op(v, this, v->propagate_nans()); } -Expr* IRCloner::mutate(CompareSelect* v) { - Expr* lhs_new = v->lhs()->accept_mutator(this); - Expr* rhs_new = v->rhs()->accept_mutator(this); - Expr* retval1_new = v->ret_val1()->accept_mutator(this); - Expr* retval2_new = v->ret_val2()->accept_mutator(this); - return new CompareSelect( +ExprPtr IRCloner::mutate(CompareSelectPtr v) { + ExprPtr lhs_new = v->lhs()->accept_mutator(this); + ExprPtr rhs_new = v->rhs()->accept_mutator(this); + ExprPtr retval1_new = v->ret_val1()->accept_mutator(this); + ExprPtr retval2_new = v->ret_val2()->accept_mutator(this); + return alloc( lhs_new, rhs_new, retval1_new, @@ -111,42 +111,42 @@ Expr* IRCloner::mutate(CompareSelect* v) { } // NOLINTNEXTLINE -#define IMM_MUTATE_DEFINE(_1, Name) \ - Expr* IRCloner::mutate(Name##Imm* v) { \ - return v; \ +#define IMM_MUTATE_DEFINE(_1, Name) \ + ExprPtr IRCloner::mutate(Name##ImmPtr v) { \ + return v; \ } AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_MUTATE_DEFINE); #undef IMM_MUTATE_DEFINE -Expr* IRCloner::mutate(Cast* v) { - Expr* src_value_new = v->src_value()->accept_mutator(this); - return new Cast(v->dtype(), src_value_new); +ExprPtr IRCloner::mutate(CastPtr v) { + ExprPtr src_value_new = v->src_value()->accept_mutator(this); + return alloc(v->dtype(), src_value_new); } -Expr* IRCloner::mutate(BitCast* v) { - Expr* src_value_new = 
v->src_value()->accept_mutator(this); - return new BitCast(v->dtype(), src_value_new); +ExprPtr IRCloner::mutate(BitCastPtr v) { + ExprPtr src_value_new = v->src_value()->accept_mutator(this); + return alloc(v->dtype(), src_value_new); } -Expr* IRCloner::mutate(Ramp* v) { - Expr* base_new = v->base()->accept_mutator(this); - Expr* stride_new = v->stride()->accept_mutator(this); - return new Ramp(base_new, stride_new, v->lanes()); +ExprPtr IRCloner::mutate(RampPtr v) { + ExprPtr base_new = v->base()->accept_mutator(this); + ExprPtr stride_new = v->stride()->accept_mutator(this); + return alloc(base_new, stride_new, v->lanes()); } -Expr* IRCloner::mutate(Load* v) { - std::vector indices_new; +ExprPtr IRCloner::mutate(LoadPtr v) { + std::vector indices_new; indices_new.reserve(v->indices().size()); - for (Expr* ind : v->indices()) { + for (ExprPtr ind : v->indices()) { indices_new.push_back(ind->accept_mutator(this)); } - Buf* buf_new = dynamic_cast(v->buf()->accept_mutator(this)); - return new Load(v->dtype(), buf_new, indices_new); + BufPtr buf_new = to(v->buf()->accept_mutator(this)); + return alloc(v->dtype(), buf_new, indices_new); } // We do not clone Vars since the original IR and cloned IR are expected to // share the underlying variables. -Expr* IRCloner::mutate(Var* v) { +ExprPtr IRCloner::mutate(VarPtr v) { return v; } @@ -155,188 +155,190 @@ Expr* IRCloner::mutate(Var* v) { // initializers, this is the expected usage of clone at this point. // // TODO: Revisit this if Bufs need to be cloned as well. -Expr* IRCloner::mutate(Buf* v) { +ExprPtr IRCloner::mutate(BufPtr v) { return v; } -Expr* IRCloner::mutate(Broadcast* v) { +ExprPtr IRCloner::mutate(BroadcastPtr v) { int lanes = v->lanes(); - Expr* value_new = v->value()->accept_mutator(this); - return new Broadcast(value_new, lanes); + ExprPtr value_new = v->value()->accept_mutator(this); + return alloc(value_new, lanes); } -Expr* IRCloner::mutate(IfThenElse* v) { - Expr* condition_new = v->condition()->accept_mutator(this); - Expr* true_value_new = v->true_value()->accept_mutator(this); - Expr* false_value_new = v->false_value()->accept_mutator(this); +ExprPtr IRCloner::mutate(IfThenElsePtr v) { + ExprPtr condition_new = v->condition()->accept_mutator(this); + ExprPtr true_value_new = v->true_value()->accept_mutator(this); + ExprPtr false_value_new = v->false_value()->accept_mutator(this); - return new IfThenElse(condition_new, true_value_new, false_value_new); + return alloc(condition_new, true_value_new, false_value_new); } -Expr* IRCloner::mutate(Intrinsics* v) { - std::vector params_new; +ExprPtr IRCloner::mutate(IntrinsicsPtr v) { + std::vector params_new; params_new.reserve(v->nparams()); for (auto param : v->params()) { params_new.push_back(param->accept_mutator(this)); } - return new Intrinsics(v->op_type(), v->dtype(), params_new); + return alloc(v->op_type(), v->dtype(), params_new); } -Expr* IRCloner::mutate(Term* v) { - Expr* scalar_new = v->scalar()->accept_mutator(this); +ExprPtr IRCloner::mutate(TermPtr v) { + ExprPtr scalar_new = v->scalar()->accept_mutator(this); - std::vector variables_new; + std::vector variables_new; variables_new.reserve(v->variables().size()); - for (auto* t : v->variables()) { + for (auto t : v->variables()) { variables_new.push_back(t->accept_mutator(this)); } - return new Term(v->hasher(), scalar_new, variables_new); + return alloc(v->hasher(), scalar_new, variables_new); } -Expr* IRCloner::mutate(Polynomial* v) { - Expr* scalar_new = v->scalar()->accept_mutator(this); +ExprPtr 
IRCloner::mutate(PolynomialPtr v) { + ExprPtr scalar_new = v->scalar()->accept_mutator(this); - std::vector variables_new; + std::vector variables_new; variables_new.reserve(v->variables().size()); - for (auto* t : v->variables()) { - variables_new.push_back(static_cast(t->accept_mutator(this))); + for (auto t : v->variables()) { + variables_new.push_back(static_to(t->accept_mutator(this))); } - return new Polynomial(v->hasher(), scalar_new, variables_new); + return alloc(v->hasher(), scalar_new, variables_new); } -Expr* IRCloner::mutate(RoundOff* v) { - return new RoundOff( +ExprPtr IRCloner::mutate(RoundOffPtr v) { + return alloc( v->lhs()->accept_mutator(this), v->rhs()->accept_mutator(this)); } -Expr* IRCloner::mutate(MaxTerm* v) { - Expr* scalar_new = v->scalar() ? v->scalar()->accept_mutator(this) : nullptr; +ExprPtr IRCloner::mutate(MaxTermPtr v) { + ExprPtr scalar_new = + v->scalar() ? v->scalar()->accept_mutator(this) : nullptr; - std::vector variables_new; + std::vector variables_new; variables_new.reserve(v->variables().size()); - for (auto* t : v->variables()) { + for (auto t : v->variables()) { variables_new.push_back(t->accept_mutator(this)); } - return new MaxTerm( + return alloc( v->hasher(), scalar_new, v->propagate_nans(), variables_new); } -Expr* IRCloner::mutate(MinTerm* v) { - Expr* scalar_new = v->scalar() ? v->scalar()->accept_mutator(this) : nullptr; +ExprPtr IRCloner::mutate(MinTermPtr v) { + ExprPtr scalar_new = + v->scalar() ? v->scalar()->accept_mutator(this) : nullptr; - std::vector variables_new; + std::vector variables_new; variables_new.reserve(v->variables().size()); - for (auto* t : v->variables()) { + for (auto t : v->variables()) { variables_new.push_back(t->accept_mutator(this)); } - return new MinTerm( + return alloc( v->hasher(), scalar_new, v->propagate_nans(), variables_new); } -Expr* IRCloner::mutate(ReduceOp* v) { - Expr* body_new = v->body()->accept_mutator(this); +ExprPtr IRCloner::mutate(ReduceOpPtr v) { + ExprPtr body_new = v->body()->accept_mutator(this); - std::vector reduce_args_new; + std::vector reduce_args_new; reduce_args_new.reserve(v->reduce_args().size()); - for (auto* r : v->reduce_args()) { - reduce_args_new.push_back(static_cast(r->accept_mutator(this))); + for (auto r : v->reduce_args()) { + reduce_args_new.push_back(static_to(r->accept_mutator(this))); } - return new ReduceOp(body_new, reduce_args_new, v->reducer()); + return alloc(body_new, reduce_args_new, v->reducer()); } -Stmt* IRCloner::mutate(For* v) { +StmtPtr IRCloner::mutate(ForPtr v) { auto start_new = v->start()->accept_mutator(this); auto stop_new = v->stop()->accept_mutator(this); auto body_new = v->body()->accept_mutator(this); - return new For(v->var(), start_new, stop_new, body_new, v->loop_options()); + return alloc(v->var(), start_new, stop_new, body_new, v->loop_options()); } -Stmt* IRCloner::mutate(Block* v) { - std::vector stmts_new; +StmtPtr IRCloner::mutate(BlockPtr v) { + std::vector stmts_new; stmts_new.reserve(v->nstmts()); - for (Stmt* stmt : *v) { + for (StmtPtr stmt : *v) { stmts_new.push_back(stmt->accept_mutator(this)); } - return new Block(stmts_new); + return alloc(stmts_new); } -Stmt* IRCloner::mutate(Store* v) { - std::vector indices_new; +StmtPtr IRCloner::mutate(StorePtr v) { + std::vector indices_new; indices_new.reserve(v->indices().size()); for (auto ind : v->indices()) { indices_new.push_back(ind->accept_mutator(this)); } auto value_new = v->value()->accept_mutator(this); - Buf* buf_new = dynamic_cast(v->buf()->accept_mutator(this)); - 
return new Store(buf_new, indices_new, value_new); + BufPtr buf_new = to(v->buf()->accept_mutator(this)); + return alloc(buf_new, indices_new, value_new); } -Stmt* IRCloner::mutate(AtomicAdd* v) { - std::vector indices_new; +StmtPtr IRCloner::mutate(AtomicAddPtr v) { + std::vector indices_new; indices_new.reserve(v->indices().size()); for (auto ind : v->indices()) { indices_new.push_back(ind->accept_mutator(this)); } auto value_new = v->value()->accept_mutator(this); - Buf* buf_new = dynamic_cast(v->buf()->accept_mutator(this)); - return new AtomicAdd(buf_new, indices_new, value_new); + BufPtr buf_new = to(v->buf()->accept_mutator(this)); + return alloc(buf_new, indices_new, value_new); } -Stmt* IRCloner::mutate(Allocate* v) { - Buf* buf_new = dynamic_cast(v->buf()->accept_mutator(this)); - return new Allocate(buf_new); +StmtPtr IRCloner::mutate(AllocatePtr v) { + BufPtr buf_new = to(v->buf()->accept_mutator(this)); + return alloc(buf_new); } -Stmt* IRCloner::mutate(Free* v) { - Buf* buf_new = dynamic_cast(v->buf()->accept_mutator(this)); - return new Free(buf_new); +StmtPtr IRCloner::mutate(FreePtr v) { + BufPtr buf_new = to(v->buf()->accept_mutator(this)); + return alloc(buf_new); } -Stmt* IRCloner::mutate(SyncThreads* v) { - return new SyncThreads(); +StmtPtr IRCloner::mutate(SyncThreadsPtr v) { + return alloc(); } -Stmt* IRCloner::mutate(ExternalCall* v) { - Buf* buf_new = dynamic_cast(v->buf()->accept_mutator(this)); +StmtPtr IRCloner::mutate(ExternalCallPtr v) { + BufPtr buf_new = to(v->buf()->accept_mutator(this)); - std::vector buf_args_new; + std::vector buf_args_new; buf_args_new.reserve(v->buf_args().size()); - for (Buf* buf_arg : v->buf_args()) { - buf_args_new.push_back(dynamic_cast(buf_arg->accept_mutator(this))); + for (BufPtr buf_arg : v->buf_args()) { + buf_args_new.push_back(to(buf_arg->accept_mutator(this))); } - std::vector args_new; + std::vector args_new; args_new.reserve(v->args().size()); - for (Expr* arg : v->args()) { + for (ExprPtr arg : v->args()) { args_new.push_back(arg->accept_mutator(this)); } - return new ExternalCall(buf_new, v->func_name(), buf_args_new, args_new); + return alloc(buf_new, v->func_name(), buf_args_new, args_new); } -Stmt* IRCloner::mutate(Let* v) { +StmtPtr IRCloner::mutate(LetPtr v) { auto value_new = v->value()->accept_mutator(this); - return new Let(v->var(), value_new); + return alloc(v->var(), value_new); } -Stmt* IRCloner::mutate(Cond* v) { +StmtPtr IRCloner::mutate(CondPtr v) { auto condition_new = v->condition()->accept_mutator(this); - Stmt* true_old = v->true_stmt(); - Stmt* false_old = v->false_stmt(); - Stmt* true_new = true_old ? true_old->accept_mutator(this) : true_old; - Stmt* false_new = false_old ? false_old->accept_mutator(this) : false_old; - return new Cond(condition_new, true_new, false_new); + StmtPtr true_old = v->true_stmt(); + StmtPtr false_old = v->false_stmt(); + StmtPtr true_new = true_old ? true_old->accept_mutator(this) : true_old; + StmtPtr false_new = false_old ? 
false_old->accept_mutator(this) : false_old; + return alloc(condition_new, true_new, false_new); } -Stmt* Stmt::clone(Stmt* s) { +StmtPtr Stmt::clone(StmtPtr s) { IRCloner cloner; - Stmt* cloned = s->accept_mutator(&cloner); + StmtPtr cloned = s->accept_mutator(&cloner); set_parent(cloned, nullptr); return cloned; } -Expr* Expr::clone(Expr* e) { +ExprPtr Expr::clone(ExprPtr e) { IRCloner cloner; return e->accept_mutator(&cloner); } diff --git a/torch/csrc/jit/tensorexpr/ir_cloner.h b/torch/csrc/jit/tensorexpr/ir_cloner.h index 2f25198..f03e128 100644 --- a/torch/csrc/jit/tensorexpr/ir_cloner.h +++ b/torch/csrc/jit/tensorexpr/ir_cloner.h @@ -9,101 +9,54 @@ namespace torch { namespace jit { namespace tensorexpr { -class Add; -class Sub; -class Mul; -class Div; -class Mod; -class Max; -class Min; -class And; -class Or; -class Xor; -class Lshift; -class Rshift; -class CompareSelect; - -#define IMM_DECLARE(Type, Name) class Name##Imm; -AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_DECLARE); -#undef IMM_DECLARE - -class Cast; -class BitCast; -class Var; -class Buf; -class Ramp; -class Load; -class For; -class Block; -class Store; -class Broadcast; -class IfThenElse; -class ExprHandle; -class Expr; -class Intrinsics; -class Allocate; -class Free; -class Let; -class Cond; -class Stmt; -class Term; -class Polynomial; -class RoundOff; -class MaxTerm; -class MinTerm; -class ReduceOp; -class AtomicAdd; -class SyncThreads; -class ExternalCall; - class TORCH_API IRCloner : public IRMutator { public: ~IRCloner() override = default; - Expr* mutate(Add* v) override; - Expr* mutate(Sub* v) override; - Expr* mutate(Mul* v) override; - Expr* mutate(Div* v) override; - Expr* mutate(Mod* v) override; - Expr* mutate(Max* v) override; - Expr* mutate(Min* v) override; - Expr* mutate(And* v) override; - Expr* mutate(Or* v) override; - Expr* mutate(Xor* v) override; - Expr* mutate(Lshift* v) override; - Expr* mutate(Rshift* v) override; - Expr* mutate(CompareSelect* v) override; -#define IMM_MUTATE_DECLARE(Type, Name) Expr* mutate(Name##Imm* v) override; + ExprPtr mutate(AddPtr v) override; + ExprPtr mutate(SubPtr v) override; + ExprPtr mutate(MulPtr v) override; + ExprPtr mutate(DivPtr v) override; + ExprPtr mutate(ModPtr v) override; + ExprPtr mutate(MaxPtr v) override; + ExprPtr mutate(MinPtr v) override; + ExprPtr mutate(AndPtr v) override; + ExprPtr mutate(OrPtr v) override; + ExprPtr mutate(XorPtr v) override; + ExprPtr mutate(LshiftPtr v) override; + ExprPtr mutate(RshiftPtr v) override; + ExprPtr mutate(CompareSelectPtr v) override; +#define IMM_MUTATE_DECLARE(Type, Name) ExprPtr mutate(Name##ImmPtr v) override; AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_MUTATE_DECLARE); #undef IMM_MUTATE_DECLARE - Expr* mutate(Cast* v) override; - Expr* mutate(BitCast* v) override; - Expr* mutate(Var* v) override; - Expr* mutate(Buf* v) override; - Expr* mutate(Ramp* v) override; - Expr* mutate(Load* v) override; - Expr* mutate(Broadcast* v) override; - Expr* mutate(IfThenElse* v) override; - Expr* mutate(Intrinsics* v) override; + ExprPtr mutate(CastPtr v) override; + ExprPtr mutate(BitCastPtr v) override; + ExprPtr mutate(VarPtr v) override; + ExprPtr mutate(BufPtr v) override; + ExprPtr mutate(RampPtr v) override; + ExprPtr mutate(LoadPtr v) override; + ExprPtr mutate(BroadcastPtr v) override; + ExprPtr mutate(IfThenElsePtr v) override; + ExprPtr mutate(IntrinsicsPtr v) override; - Expr* mutate(Term* v) override; - Expr* mutate(Polynomial* v) override; - Expr* mutate(RoundOff* v) override; - Expr* mutate(MaxTerm* v) 
override; - Expr* mutate(MinTerm* v) override; + ExprPtr mutate(TermPtr v) override; + ExprPtr mutate(PolynomialPtr v) override; + ExprPtr mutate(RoundOffPtr v) override; + ExprPtr mutate(MaxTermPtr v) override; + ExprPtr mutate(MinTermPtr v) override; - Expr* mutate(ReduceOp* v) override; + ExprPtr mutate(ReduceOpPtr v) override; - Stmt* mutate(For* v) override; - Stmt* mutate(Block* v) override; - Stmt* mutate(Store* v) override; - Stmt* mutate(AtomicAdd* v) override; - Stmt* mutate(SyncThreads* v) override; - Stmt* mutate(ExternalCall* v) override; + StmtPtr mutate(ForPtr v) override; + StmtPtr mutate(BlockPtr v) override; + StmtPtr mutate(StorePtr v) override; + StmtPtr mutate(AtomicAddPtr v) override; + StmtPtr mutate(SyncThreadsPtr v) override; + StmtPtr mutate(ExternalCallPtr v) override; - Stmt* mutate(Allocate* v) override; - Stmt* mutate(Free* v) override; - Stmt* mutate(Let* v) override; - Stmt* mutate(Cond* v) override; + StmtPtr mutate(AllocatePtr v) override; + StmtPtr mutate(FreePtr v) override; + StmtPtr mutate(LetPtr v) override; + StmtPtr mutate(CondPtr v) override; }; } // namespace tensorexpr diff --git a/torch/csrc/jit/tensorexpr/ir_mutator.cpp b/torch/csrc/jit/tensorexpr/ir_mutator.cpp index 8ebff4f..96635ac 100644 --- a/torch/csrc/jit/tensorexpr/ir_mutator.cpp +++ b/torch/csrc/jit/tensorexpr/ir_mutator.cpp @@ -12,14 +12,14 @@ namespace jit { namespace tensorexpr { template -static Expr* mutate_binary_op( +static ExprPtr mutate_binary_op( BinaryOpNode* v, IRMutator* mutator, bool option = false) { - Expr* lhs = v->lhs(); - Expr* rhs = v->rhs(); - Expr* lhs_new = lhs->accept_mutator(mutator); - Expr* rhs_new = rhs->accept_mutator(mutator); + ExprPtr lhs = v->lhs(); + ExprPtr rhs = v->rhs(); + ExprPtr lhs_new = lhs->accept_mutator(mutator); + ExprPtr rhs_new = rhs->accept_mutator(mutator); if (lhs != lhs_new) { v->set_lhs(lhs_new); } @@ -34,63 +34,63 @@ static Expr* mutate_binary_op( return v; } -Expr* IRMutator::mutate(Add* v) { +ExprPtr IRMutator::mutate(AddPtr v) { return mutate_binary_op(v, this); } -Expr* IRMutator::mutate(Sub* v) { +ExprPtr IRMutator::mutate(SubPtr v) { return mutate_binary_op(v, this); } -Expr* IRMutator::mutate(Mul* v) { +ExprPtr IRMutator::mutate(MulPtr v) { return mutate_binary_op(v, this); } -Expr* IRMutator::mutate(Div* v) { +ExprPtr IRMutator::mutate(DivPtr v) { return mutate_binary_op(v, this); } -Expr* IRMutator::mutate(Mod* v) { +ExprPtr IRMutator::mutate(ModPtr v) { return mutate_binary_op(v, this); } -Expr* IRMutator::mutate(And* v) { +ExprPtr IRMutator::mutate(AndPtr v) { return mutate_binary_op(v, this); } -Expr* IRMutator::mutate(Or* v) { +ExprPtr IRMutator::mutate(OrPtr v) { return mutate_binary_op(v, this); } -Expr* IRMutator::mutate(Xor* v) { +ExprPtr IRMutator::mutate(XorPtr v) { return mutate_binary_op(v, this); } -Expr* IRMutator::mutate(Lshift* v) { +ExprPtr IRMutator::mutate(LshiftPtr v) { return mutate_binary_op(v, this); } -Expr* IRMutator::mutate(Rshift* v) { +ExprPtr IRMutator::mutate(RshiftPtr v) { return mutate_binary_op(v, this); } -Expr* IRMutator::mutate(Max* v) { +ExprPtr IRMutator::mutate(MaxPtr v) { return mutate_binary_op(v, this, v->propagate_nans()); } -Expr* IRMutator::mutate(Min* v) { +ExprPtr IRMutator::mutate(MinPtr v) { return mutate_binary_op(v, this, v->propagate_nans()); } -Expr* IRMutator::mutate(CompareSelect* v) { - Expr* lhs = v->lhs(); - Expr* rhs = v->rhs(); - Expr* ret_val1 = v->ret_val1(); - Expr* ret_val2 = v->ret_val2(); - Expr* lhs_new = lhs->accept_mutator(this); - Expr* rhs_new = 
rhs->accept_mutator(this); - Expr* ret_val1_new = ret_val1->accept_mutator(this); - Expr* ret_val2_new = ret_val2->accept_mutator(this); +ExprPtr IRMutator::mutate(CompareSelectPtr v) { + ExprPtr lhs = v->lhs(); + ExprPtr rhs = v->rhs(); + ExprPtr ret_val1 = v->ret_val1(); + ExprPtr ret_val2 = v->ret_val2(); + ExprPtr lhs_new = lhs->accept_mutator(this); + ExprPtr rhs_new = rhs->accept_mutator(this); + ExprPtr ret_val1_new = ret_val1->accept_mutator(this); + ExprPtr ret_val2_new = ret_val2->accept_mutator(this); if (lhs != lhs_new) { v->set_lhs(lhs_new); } @@ -107,40 +107,40 @@ Expr* IRMutator::mutate(CompareSelect* v) { } // NOLINTNEXTLINE -#define IMM_MUTATE_DEFINE(_1, Name) \ - Expr* IRMutator::mutate(Name##Imm* v) { \ - return v; \ +#define IMM_MUTATE_DEFINE(_1, Name) \ + ExprPtr IRMutator::mutate(Name##ImmPtr v) { \ + return v; \ } AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_MUTATE_DEFINE); #undef IMM_MUTATE_DEFINE -Expr* IRMutator::mutate(Cast* v) { - Expr* src_value = v->src_value(); - Expr* src_value_new = src_value->accept_mutator(this); +ExprPtr IRMutator::mutate(CastPtr v) { + ExprPtr src_value = v->src_value(); + ExprPtr src_value_new = src_value->accept_mutator(this); if (src_value != src_value_new) { v->set_src_value(src_value_new); } return v; } -Expr* IRMutator::mutate(BitCast* v) { - Expr* src_value = v->src_value(); - Expr* src_value_new = src_value->accept_mutator(this); +ExprPtr IRMutator::mutate(BitCastPtr v) { + ExprPtr src_value = v->src_value(); + ExprPtr src_value_new = src_value->accept_mutator(this); if (src_value != src_value_new) { v->set_src_value(src_value_new); } return v; } -Expr* IRMutator::mutate(Var* v) { +ExprPtr IRMutator::mutate(VarPtr v) { return v; } -Expr* IRMutator::mutate(Ramp* v) { - Expr* base = v->base(); - Expr* stride = v->stride(); - Expr* base_new = base->accept_mutator(this); - Expr* stride_new = stride->accept_mutator(this); +ExprPtr IRMutator::mutate(RampPtr v) { + ExprPtr base = v->base(); + ExprPtr stride = v->stride(); + ExprPtr base_new = base->accept_mutator(this); + ExprPtr stride_new = stride->accept_mutator(this); if (base != base_new) { v->set_base(base_new); } @@ -150,20 +150,20 @@ Expr* IRMutator::mutate(Ramp* v) { return v; } -Expr* IRMutator::mutate(Load* v) { - Buf* buf = v->buf(); +ExprPtr IRMutator::mutate(LoadPtr v) { + BufPtr buf = v->buf(); bool any_index_changed = false; - std::vector indices_new; + std::vector indices_new; indices_new.reserve(v->indices().size()); - for (Expr* ind : v->indices()) { - Expr* new_ind = ind->accept_mutator(this); + for (ExprPtr ind : v->indices()) { + ExprPtr new_ind = ind->accept_mutator(this); if (new_ind != ind) { any_index_changed = true; } indices_new.push_back(new_ind); } - Buf* buf_new = dynamic_cast(buf->accept_mutator(this)); + BufPtr buf_new = to(buf->accept_mutator(this)); if (buf != buf_new) { v->set_buf(buf_new); @@ -174,18 +174,16 @@ Expr* IRMutator::mutate(Load* v) { return v; } -Expr* IRMutator::mutate(Buf* v) { - Var* var = v->base_handle(); - Var* var_new = - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) - dynamic_cast(const_cast(var->accept_mutator(this))); +ExprPtr IRMutator::mutate(BufPtr v) { + VarPtr var = v->base_handle(); + VarPtr var_new = to(var->accept_mutator(this)); if (!var_new) { return nullptr; } bool dims_changed = false; - std::vector dims_old = v->dims(); - std::vector dims_new(dims_old.size()); + std::vector dims_old = v->dims(); + std::vector dims_new(dims_old.size()); for (const auto i : c10::irange(dims_old.size())) { dims_new[i] = 
dims_old[i]->accept_mutator(this); dims_changed |= (dims_new[i] != dims_old[i]); @@ -201,22 +199,22 @@ Expr* IRMutator::mutate(Buf* v) { return v; } -Expr* IRMutator::mutate(Broadcast* v) { - Expr* value = v->value(); - Expr* value_new = value->accept_mutator(this); +ExprPtr IRMutator::mutate(BroadcastPtr v) { + ExprPtr value = v->value(); + ExprPtr value_new = value->accept_mutator(this); if (value != value_new) { v->set_value(value_new); } return v; } -Expr* IRMutator::mutate(IfThenElse* v) { - Expr* condition = v->condition(); - Expr* true_value = v->true_value(); - Expr* false_value = v->false_value(); - Expr* condition_new = condition->accept_mutator(this); - Expr* true_value_new = true_value->accept_mutator(this); - Expr* false_value_new = false_value->accept_mutator(this); +ExprPtr IRMutator::mutate(IfThenElsePtr v) { + ExprPtr condition = v->condition(); + ExprPtr true_value = v->true_value(); + ExprPtr false_value = v->false_value(); + ExprPtr condition_new = condition->accept_mutator(this); + ExprPtr true_value_new = true_value->accept_mutator(this); + ExprPtr false_value_new = false_value->accept_mutator(this); if (condition != condition_new) { v->set_condition(condition_new); @@ -230,12 +228,12 @@ Expr* IRMutator::mutate(IfThenElse* v) { return v; } -Expr* IRMutator::mutate(Intrinsics* v) { - std::vector params(v->nparams()); +ExprPtr IRMutator::mutate(IntrinsicsPtr v) { + std::vector params(v->nparams()); bool any_change = false; for (int i = 0; i < v->nparams(); i++) { - Expr* value = v->param(i); - Expr* value_new = value->accept_mutator(this); + ExprPtr value = v->param(i); + ExprPtr value_new = value->accept_mutator(this); if (value != value_new) { any_change = true; } @@ -247,79 +245,79 @@ Expr* IRMutator::mutate(Intrinsics* v) { return v; } -Expr* IRMutator::mutate(Term* v) { - Expr* newScalar = v->scalar()->accept_mutator(this); +ExprPtr IRMutator::mutate(TermPtr v) { + ExprPtr newScalar = v->scalar()->accept_mutator(this); - std::vector variables; - for (auto* t : v->variables()) { + std::vector variables; + for (auto t : v->variables()) { variables.push_back(t->accept_mutator(this)); } - return new Term(v->hasher(), newScalar, variables); + return alloc(v->hasher(), newScalar, variables); } -Expr* IRMutator::mutate(Polynomial* v) { - Expr* newScalar = v->scalar()->accept_mutator(this); +ExprPtr IRMutator::mutate(PolynomialPtr v) { + ExprPtr newScalar = v->scalar()->accept_mutator(this); - std::vector variables; - for (auto* t : v->variables()) { - variables.push_back(static_cast(t->accept_mutator(this))); + std::vector variables; + for (auto t : v->variables()) { + variables.push_back(static_to(t->accept_mutator(this))); } - return new Polynomial(v->hasher(), newScalar, variables); + return alloc(v->hasher(), newScalar, variables); } -Expr* IRMutator::mutate(RoundOff* v) { - return new RoundOff( +ExprPtr IRMutator::mutate(RoundOffPtr v) { + return alloc( v->lhs()->accept_mutator(this), v->rhs()->accept_mutator(this)); } -Expr* IRMutator::mutate(MaxTerm* v) { - Expr* newScalar = nullptr; +ExprPtr IRMutator::mutate(MaxTermPtr v) { + ExprPtr newScalar = nullptr; if (v->scalar()) { newScalar = v->scalar()->accept_mutator(this); } - std::vector variables; - for (auto* t : v->variables()) { + std::vector variables; + for (auto t : v->variables()) { variables.push_back(t->accept_mutator(this)); } - return new MaxTerm(v->hasher(), newScalar, v->propagate_nans(), variables); + return alloc(v->hasher(), newScalar, v->propagate_nans(), variables); } -Expr* 
IRMutator::mutate(MinTerm* v) { - Expr* newScalar = nullptr; +ExprPtr IRMutator::mutate(MinTermPtr v) { + ExprPtr newScalar = nullptr; if (v->scalar()) { newScalar = v->scalar()->accept_mutator(this); } - std::vector variables; - for (auto* t : v->variables()) { + std::vector variables; + for (auto t : v->variables()) { variables.push_back(t->accept_mutator(this)); } - return new MinTerm(v->hasher(), newScalar, v->propagate_nans(), variables); + return alloc(v->hasher(), newScalar, v->propagate_nans(), variables); } -Expr* IRMutator::mutate(ReduceOp* v) { - Expr* body_new = v->body()->accept_mutator(this); +ExprPtr IRMutator::mutate(ReduceOpPtr v) { + ExprPtr body_new = v->body()->accept_mutator(this); - std::vector new_reduce_args; - for (auto* r : v->reduce_args()) { - new_reduce_args.push_back(static_cast(r->accept_mutator(this))); + std::vector new_reduce_args; + for (auto r : v->reduce_args()) { + new_reduce_args.push_back(static_to(r->accept_mutator(this))); } - return new ReduceOp(body_new, new_reduce_args, v->reducer()); + return alloc(body_new, new_reduce_args, v->reducer()); } -Stmt* IRMutator::mutate(For* v) { - Expr* var = v->var(); - Expr* start = v->start(); - Expr* stop = v->stop(); - Stmt* body = v->body(); +StmtPtr IRMutator::mutate(ForPtr v) { + ExprPtr var = v->var(); + ExprPtr start = v->start(); + ExprPtr stop = v->stop(); + StmtPtr body = v->body(); LoopOptions loop_options = v->loop_options(); - Expr* var_new_expr = var->accept_mutator(this); - Var* var_new = dynamic_cast(var_new_expr); - Expr* start_new = start->accept_mutator(this); - Expr* stop_new = stop->accept_mutator(this); - Stmt* body_new = body->accept_mutator(this); + ExprPtr var_new_expr = var->accept_mutator(this); + VarPtr var_new = to(var_new_expr); + ExprPtr start_new = start->accept_mutator(this); + ExprPtr stop_new = stop->accept_mutator(this); + StmtPtr body_new = body->accept_mutator(this); if (!body_new) { return nullptr; } @@ -338,12 +336,12 @@ Stmt* IRMutator::mutate(For* v) { return v; } -Stmt* IRMutator::mutate(Block* v) { +StmtPtr IRMutator::mutate(BlockPtr v) { bool any_change = false; - std::vector stmts; - for (Stmt* stmt : *v) { - Stmt* stmt_new = stmt->accept_mutator(this); + std::vector stmts; + for (StmtPtr stmt : *v) { + StmtPtr stmt_new = stmt->accept_mutator(this); if (stmt != stmt_new) { any_change = true; } else { @@ -359,21 +357,21 @@ Stmt* IRMutator::mutate(Block* v) { return v; } -Stmt* IRMutator::mutate(Store* v) { - Buf* buf = v->buf(); +StmtPtr IRMutator::mutate(StorePtr v) { + BufPtr buf = v->buf(); bool any_index_changed = false; - std::vector indices_new; - for (Expr* ind : v->indices()) { - Expr* new_ind = ind->accept_mutator(this); + std::vector indices_new; + for (ExprPtr ind : v->indices()) { + ExprPtr new_ind = ind->accept_mutator(this); if (new_ind != ind) { any_index_changed = true; } indices_new.push_back(new_ind); } - Expr* value = v->value(); - Buf* buf_new = dynamic_cast(buf->accept_mutator(this)); - Expr* value_new = value->accept_mutator(this); + ExprPtr value = v->value(); + BufPtr buf_new = to(buf->accept_mutator(this)); + ExprPtr value_new = value->accept_mutator(this); if (buf != buf_new) { v->set_buf(buf_new); @@ -387,21 +385,21 @@ Stmt* IRMutator::mutate(Store* v) { return v; } -Stmt* IRMutator::mutate(AtomicAdd* v) { - Buf* buf = v->buf(); +StmtPtr IRMutator::mutate(AtomicAddPtr v) { + BufPtr buf = v->buf(); bool any_index_changed = false; - std::vector indices_new; - for (Expr* ind : v->indices()) { - Expr* new_ind = ind->accept_mutator(this); + 
std::vector indices_new; + for (ExprPtr ind : v->indices()) { + ExprPtr new_ind = ind->accept_mutator(this); if (new_ind != ind) { any_index_changed = true; } indices_new.push_back(new_ind); } - Expr* value = v->value(); - Buf* buf_new = dynamic_cast(buf->accept_mutator(this)); - Expr* value_new = value->accept_mutator(this); + ExprPtr value = v->value(); + BufPtr buf_new = to(buf->accept_mutator(this)); + ExprPtr value_new = value->accept_mutator(this); if (buf != buf_new) { v->set_buf(buf_new); @@ -415,30 +413,30 @@ Stmt* IRMutator::mutate(AtomicAdd* v) { return v; } -Stmt* IRMutator::mutate(SyncThreads* v) { - return new SyncThreads(); +StmtPtr IRMutator::mutate(SyncThreadsPtr v) { + return alloc(); } -Stmt* IRMutator::mutate(ExternalCall* v) { - Buf* buf = v->buf(); - Buf* buf_new = dynamic_cast(buf->accept_mutator(this)); +StmtPtr IRMutator::mutate(ExternalCallPtr v) { + BufPtr buf = v->buf(); + BufPtr buf_new = to(buf->accept_mutator(this)); TORCH_INTERNAL_ASSERT(buf_new); bool buf_args_changed = false; - std::vector buf_args_new; + std::vector buf_args_new; buf_args_new.reserve(v->buf_args().size()); - for (Buf* buf_arg : v->buf_args()) { - Buf* buf_arg_new = dynamic_cast(buf_arg->accept_mutator(this)); + for (BufPtr buf_arg : v->buf_args()) { + BufPtr buf_arg_new = to(buf_arg->accept_mutator(this)); TORCH_INTERNAL_ASSERT(buf_arg_new); buf_args_new.push_back(buf_arg_new); buf_args_changed |= buf_arg_new != buf_arg; } bool args_changed = false; - std::vector args_new; + std::vector args_new; args_new.reserve(v->args().size()); - for (Expr* arg : v->args()) { - Expr* arg_new = arg->accept_mutator(this); + for (ExprPtr arg : v->args()) { + ExprPtr arg_new = arg->accept_mutator(this); args_new.push_back(arg_new); args_changed |= arg_new != arg; } @@ -455,9 +453,9 @@ Stmt* IRMutator::mutate(ExternalCall* v) { return v; } -Stmt* IRMutator::mutate(Allocate* v) { - Buf* buf = v->buf(); - Buf* buf_new = dynamic_cast(buf->accept_mutator(this)); +StmtPtr IRMutator::mutate(AllocatePtr v) { + BufPtr buf = v->buf(); + BufPtr buf_new = to(buf->accept_mutator(this)); TORCH_INTERNAL_ASSERT(buf_new); if (buf != buf_new) { v->set_buf(buf_new); @@ -465,9 +463,9 @@ Stmt* IRMutator::mutate(Allocate* v) { return v; } -Stmt* IRMutator::mutate(Free* v) { - Buf* buf = v->buf(); - Buf* buf_new = dynamic_cast(buf->accept_mutator(this)); +StmtPtr IRMutator::mutate(FreePtr v) { + BufPtr buf = v->buf(); + BufPtr buf_new = to(buf->accept_mutator(this)); TORCH_INTERNAL_ASSERT(buf_new); if (buf != buf_new) { v->set_buf(buf_new); @@ -475,12 +473,12 @@ Stmt* IRMutator::mutate(Free* v) { return v; } -Stmt* IRMutator::mutate(Let* v) { - Var* var_old = v->var(); - Var* var_new = dynamic_cast(var_old->accept_mutator(this)); +StmtPtr IRMutator::mutate(LetPtr v) { + VarPtr var_old = v->var(); + VarPtr var_new = to(var_old->accept_mutator(this)); - Expr* val_old = v->value(); - Expr* val_new = val_old->accept_mutator(this); + ExprPtr val_old = v->value(); + ExprPtr val_new = val_old->accept_mutator(this); if (var_old != var_new) { v->set_var(var_new); @@ -491,14 +489,14 @@ Stmt* IRMutator::mutate(Let* v) { return v; } -Stmt* IRMutator::mutate(Cond* v) { - Expr* cond_old = v->condition(); - Stmt* true_old = v->true_stmt(); - Stmt* false_old = v->false_stmt(); +StmtPtr IRMutator::mutate(CondPtr v) { + ExprPtr cond_old = v->condition(); + StmtPtr true_old = v->true_stmt(); + StmtPtr false_old = v->false_stmt(); - Expr* cond_new = cond_old->accept_mutator(this); - Stmt* true_new = true_old ? 
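The mutator hunks above and below lean on three small helpers in place of raw casts and new: to<T>(...) for dynamic_cast, static_to<T>(...) for static_cast, and alloc<T>(...) for construction. Their definitions are not shown in this part of the diff, so the following is a hedged sketch of signatures consistent with how they are called here, reusing the assumed NodePtr alias from the sketch above:

#include <utility>

template <typename Node>
using NodePtr = Node*; // assumption: raw pointer at this stage

// dynamic_cast replacement: returns nullptr when the node is not a To.
template <typename To, typename From>
NodePtr<To> to(NodePtr<From> x) {
  return dynamic_cast<To*>(x);
}

// static_cast replacement for cases where the concrete type is already known.
template <typename To, typename From>
NodePtr<To> static_to(NodePtr<From> x) {
  return static_cast<To*>(x);
}

// Replacement for `new Node(...)`; can later become a make_shared call.
template <typename Node, typename... Args>
NodePtr<Node> alloc(Args&&... args) {
  return new Node(std::forward<Args>(args)...);
}

With helpers like these, a line such as `Buf* buf_new = dynamic_cast<Buf*>(buf->accept_mutator(this));` becomes `BufPtr buf_new = to<Buf>(buf->accept_mutator(this));`, which is the pattern repeated across the mutate methods in this file.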
true_old->accept_mutator(this) : true_old; - Stmt* false_new = false_old ? false_old->accept_mutator(this) : false_old; + ExprPtr cond_new = cond_old->accept_mutator(this); + StmtPtr true_new = true_old ? true_old->accept_mutator(this) : true_old; + StmtPtr false_new = false_old ? false_old->accept_mutator(this) : false_old; if (cond_old != cond_new) { v->set_condition(cond_new); diff --git a/torch/csrc/jit/tensorexpr/ir_mutator.h b/torch/csrc/jit/tensorexpr/ir_mutator.h index 42168fe..fb6c420 100644 --- a/torch/csrc/jit/tensorexpr/ir_mutator.h +++ b/torch/csrc/jit/tensorexpr/ir_mutator.h @@ -1,107 +1,61 @@ #pragma once #include #include +#include #include namespace torch { namespace jit { namespace tensorexpr { -class Add; -class Sub; -class Mul; -class Div; -class Mod; -class Max; -class Min; -class And; -class Or; -class Xor; -class Lshift; -class Rshift; -class CompareSelect; - -#define IMM_DECLARE(Type, Name) class Name##Imm; -AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_DECLARE); -#undef IMM_DECLARE - -class Cast; -class BitCast; -class Var; -class Buf; -class Ramp; -class Load; -class For; -class Block; -class Store; -class Broadcast; -class IfThenElse; -class ExprHandle; -class Expr; -class Intrinsics; -class Allocate; -class Free; -class Let; -class Cond; -class Stmt; -class Term; -class Polynomial; -class RoundOff; -class MaxTerm; -class MinTerm; -class ReduceOp; -class AtomicAdd; -class SyncThreads; -class ExternalCall; - class TORCH_API IRMutator { public: virtual ~IRMutator() = default; - virtual Expr* mutate(Add* v); - virtual Expr* mutate(Sub* v); - virtual Expr* mutate(Mul* v); - virtual Expr* mutate(Div* v); - virtual Expr* mutate(Mod* v); - virtual Expr* mutate(Max* v); - virtual Expr* mutate(Min* v); - virtual Expr* mutate(And* v); - virtual Expr* mutate(Or* v); - virtual Expr* mutate(Xor* v); - virtual Expr* mutate(Lshift* v); - virtual Expr* mutate(Rshift* v); - virtual Expr* mutate(CompareSelect* v); -#define IMM_MUTATE_DECLARE(Type, Name) virtual Expr* mutate(Name##Imm* v); + virtual ExprPtr mutate(AddPtr v); + virtual ExprPtr mutate(SubPtr v); + virtual ExprPtr mutate(MulPtr v); + virtual ExprPtr mutate(DivPtr v); + virtual ExprPtr mutate(ModPtr v); + virtual ExprPtr mutate(MaxPtr v); + virtual ExprPtr mutate(MinPtr v); + virtual ExprPtr mutate(AndPtr v); + virtual ExprPtr mutate(OrPtr v); + virtual ExprPtr mutate(XorPtr v); + virtual ExprPtr mutate(LshiftPtr v); + virtual ExprPtr mutate(RshiftPtr v); + virtual ExprPtr mutate(CompareSelectPtr v); +#define IMM_MUTATE_DECLARE(Type, Name) virtual ExprPtr mutate(Name##ImmPtr v); AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_MUTATE_DECLARE); #undef IMM_MUTATE_DECLARE - virtual Expr* mutate(Cast* v); - virtual Expr* mutate(BitCast* v); - virtual Expr* mutate(Var* v); - virtual Expr* mutate(Buf* v); - virtual Expr* mutate(Ramp* v); - virtual Expr* mutate(Load* v); - virtual Expr* mutate(Broadcast* v); - virtual Expr* mutate(IfThenElse* v); - virtual Expr* mutate(Intrinsics* v); - - virtual Expr* mutate(Term* v); - virtual Expr* mutate(Polynomial* v); - virtual Expr* mutate(RoundOff* v); - virtual Expr* mutate(MaxTerm* v); - virtual Expr* mutate(MinTerm* v); - - virtual Expr* mutate(ReduceOp* v); - - virtual Stmt* mutate(For* v); - virtual Stmt* mutate(Block* v); - virtual Stmt* mutate(Store* v); - virtual Stmt* mutate(AtomicAdd* v); - virtual Stmt* mutate(SyncThreads* v); - virtual Stmt* mutate(ExternalCall* v); - - virtual Stmt* mutate(Allocate* v); - virtual Stmt* mutate(Free* v); - virtual Stmt* mutate(Let* v); - virtual 
Stmt* mutate(Cond* v); + virtual ExprPtr mutate(CastPtr v); + virtual ExprPtr mutate(BitCastPtr v); + virtual ExprPtr mutate(VarPtr v); + virtual ExprPtr mutate(BufPtr v); + virtual ExprPtr mutate(RampPtr v); + virtual ExprPtr mutate(LoadPtr v); + virtual ExprPtr mutate(BroadcastPtr v); + virtual ExprPtr mutate(IfThenElsePtr v); + virtual ExprPtr mutate(IntrinsicsPtr v); + + virtual ExprPtr mutate(TermPtr v); + virtual ExprPtr mutate(PolynomialPtr v); + virtual ExprPtr mutate(RoundOffPtr v); + virtual ExprPtr mutate(MaxTermPtr v); + virtual ExprPtr mutate(MinTermPtr v); + + virtual ExprPtr mutate(ReduceOpPtr v); + + virtual StmtPtr mutate(ForPtr v); + virtual StmtPtr mutate(BlockPtr v); + virtual StmtPtr mutate(StorePtr v); + virtual StmtPtr mutate(AtomicAddPtr v); + virtual StmtPtr mutate(SyncThreadsPtr v); + virtual StmtPtr mutate(ExternalCallPtr v); + + virtual StmtPtr mutate(AllocatePtr v); + virtual StmtPtr mutate(FreePtr v); + virtual StmtPtr mutate(LetPtr v); + virtual StmtPtr mutate(CondPtr v); }; } // namespace tensorexpr diff --git a/torch/csrc/jit/tensorexpr/ir_printer.cpp b/torch/csrc/jit/tensorexpr/ir_printer.cpp index 1db9368..23466f3 100644 --- a/torch/csrc/jit/tensorexpr/ir_printer.cpp +++ b/torch/csrc/jit/tensorexpr/ir_printer.cpp @@ -58,43 +58,43 @@ void visitBinaryOp( } } -void IRPrinter::visit(Add* v) { +void IRPrinter::visit(AddPtr v) { visitBinaryOp(v, "+", this); } -void IRPrinter::visit(Sub* v) { +void IRPrinter::visit(SubPtr v) { visitBinaryOp(v, "-", this); } -void IRPrinter::visit(Mul* v) { +void IRPrinter::visit(MulPtr v) { visitBinaryOp(v, "*", this); } -void IRPrinter::visit(Div* v) { +void IRPrinter::visit(DivPtr v) { visitBinaryOp(v, "/", this); } -void IRPrinter::visit(And* v) { +void IRPrinter::visit(AndPtr v) { visitBinaryOp(v, "&", this); } -void IRPrinter::visit(Or* v) { +void IRPrinter::visit(OrPtr v) { visitBinaryOp(v, "|", this); } -void IRPrinter::visit(Xor* v) { +void IRPrinter::visit(XorPtr v) { visitBinaryOp(v, "^", this); } -void IRPrinter::visit(Lshift* v) { +void IRPrinter::visit(LshiftPtr v) { visitBinaryOp(v, "<<", this); } -void IRPrinter::visit(Rshift* v) { +void IRPrinter::visit(RshiftPtr v) { visitBinaryOp(v, ">>", this); } -void IRPrinter::visit(Mod* v) { +void IRPrinter::visit(ModPtr v) { if (v->dtype().is_integral()) { visitBinaryOp(v, "%", this); } else if (v->dtype().is_floating_point()) { @@ -104,7 +104,7 @@ void IRPrinter::visit(Mod* v) { } } -void IRPrinter::visit(Max* v) { +void IRPrinter::visit(MaxPtr v) { os() << "Max("; v->lhs()->accept(this); os() << ", "; @@ -112,7 +112,7 @@ void IRPrinter::visit(Max* v) { os() << ", " << (unsigned int)v->propagate_nans() << ")"; } -void IRPrinter::visit(Min* v) { +void IRPrinter::visit(MinPtr v) { os() << "Min("; v->lhs()->accept(this); os() << ", "; @@ -120,7 +120,7 @@ void IRPrinter::visit(Min* v) { os() << ", " << (unsigned int)v->propagate_nans() << ")"; } -void IRPrinter::visit(CompareSelect* v) { +void IRPrinter::visit(CompareSelectPtr v) { CompareSelectOperation cmp_op = v->compare_select_op(); int self_prec = getPrecedence(v->expr_type()); int lhs_prec = getPrecedence(v->lhs()->expr_type()); @@ -165,7 +165,7 @@ void IRPrinter::visit(CompareSelect* v) { } os() << " ? 
"; - auto withParens = [&](Expr* e) { + auto withParens = [&](ExprPtr e) { auto prec = getPrecedence(e->expr_type()); if (prec >= self_prec) { os() << "("; @@ -212,37 +212,37 @@ static void formatImm(std::ostream& os, T v) { } // NOLINTNEXTLINE -#define IMM_PRINT_VISIT(Type, Name) \ - void IRPrinter::visit(Name##Imm* v) { \ - formatImm(os(), v->value()); \ +#define IMM_PRINT_VISIT(Type, Name) \ + void IRPrinter::visit(Name##ImmPtr v) { \ + formatImm(os(), v->value()); \ } AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_PRINT_VISIT); #undef IMM_PRINT_VISIT -void IRPrinter::visit(Cast* v) { +void IRPrinter::visit(CastPtr v) { auto dtype = v->dtype(); os() << dtypeToCppString(dtype) << "("; v->src_value()->accept(this); os() << ")"; } -void IRPrinter::visit(Var* v) { +void IRPrinter::visit(VarPtr v) { os() << name_manager_.get_unique_name(v); } -void IRPrinter::visit(Ramp* v) { +void IRPrinter::visit(RampPtr v) { os() << "Ramp(" << *v->base() << ", " << *v->stride() << ", " << v->lanes() << ")"; } -void IRPrinter::visit(Load* v) { +void IRPrinter::visit(LoadPtr v) { // TODO: support the mask case if (v->indices().size() == 0) { os() << *v->base_handle(); } else { os() << *v->base_handle() << "["; size_t i = 0; - for (Expr* ind : v->indices()) { + for (ExprPtr ind : v->indices()) { if (i++) { os() << ", "; } @@ -255,16 +255,16 @@ void IRPrinter::visit(Load* v) { } } -void IRPrinter::visit(Broadcast* v) { +void IRPrinter::visit(BroadcastPtr v) { os() << "Broadcast(" << *v->value() << ", " << v->lanes() << ")"; } -void IRPrinter::visit(IfThenElse* v) { +void IRPrinter::visit(IfThenElsePtr v) { os() << "IfThenElse(" << *v->condition() << ", " << *v->true_value() << ", " << *v->false_value() << ")"; } -void IRPrinter::visit(Intrinsics* v) { +void IRPrinter::visit(IntrinsicsPtr v) { os() << v->func_name() << "("; for (const auto i : c10::irange(v->nparams())) { if (i > 0) { @@ -275,20 +275,20 @@ void IRPrinter::visit(Intrinsics* v) { os() << ")"; } -void IRPrinter::visit(Term* v) { +void IRPrinter::visit(TermPtr v) { os() << "Term("; v->scalar()->accept(this); - for (auto* t : v->variables()) { + for (auto t : v->variables()) { os() << ","; t->accept(this); } os() << ")"; } -void IRPrinter::visit(Polynomial* v) { +void IRPrinter::visit(PolynomialPtr v) { bool first = true; os() << "Polynomial("; - for (auto* t : v->variables()) { + for (auto t : v->variables()) { if (!first) { os() << " + "; } @@ -303,7 +303,7 @@ void IRPrinter::visit(Polynomial* v) { os() << ")"; } -void IRPrinter::visit(RoundOff* v) { +void IRPrinter::visit(RoundOffPtr v) { os() << "RoundOff("; v->lhs()->accept(this); os() << ", "; @@ -311,7 +311,7 @@ void IRPrinter::visit(RoundOff* v) { os() << ")"; } -void IRPrinter::visit(MaxTerm* v) { +void IRPrinter::visit(MaxTermPtr v) { os() << "MaxTerm("; if (v->scalar()) { v->scalar()->accept(this); @@ -326,7 +326,7 @@ void IRPrinter::visit(MaxTerm* v) { os() << ")"; } -void IRPrinter::visit(MinTerm* v) { +void IRPrinter::visit(MinTermPtr v) { os() << "MinTerm("; if (v->scalar()) { v->scalar()->accept(this); @@ -341,13 +341,13 @@ void IRPrinter::visit(MinTerm* v) { os() << ")"; } -void IRPrinter::visit(ReduceOp* v) { +void IRPrinter::visit(ReduceOpPtr v) { os() << "ReduceOp("; os() << *v->body() << ", "; bool first = true; os() << "reduce_args={"; - for (auto* d : v->reduce_args()) { + for (auto d : v->reduce_args()) { if (!first) { os() << ", "; } @@ -363,7 +363,7 @@ void IRPrinter::visit(ReduceOp* v) { // each statement in a `Block` the printer will insert indentation before // the 
statement and a newline after the statement. -void IRPrinter::visit(Store* v) { +void IRPrinter::visit(StorePtr v) { // TODO: handle the mask if (v->indices().size() == 0) { os() << *v->base_handle() << " = " << *v->value() << ";"; @@ -372,7 +372,7 @@ void IRPrinter::visit(Store* v) { os() << *v->base_handle() << "["; size_t i = 0; - for (Expr* ind : v->indices()) { + for (ExprPtr ind : v->indices()) { if (i++) { os() << ", "; } @@ -384,8 +384,8 @@ void IRPrinter::visit(Store* v) { os() << "] = " << *v->value() << ";"; } -void IRPrinter::visit(For* v) { - Var* var = v->var(); +void IRPrinter::visit(ForPtr v) { + VarPtr var = v->var(); VarHandle vv(var); os() << "for (" << dtypeToCppString(var->dtype()) << " " << vv << " = " << ExprHandle(v->start()) << "; " << vv << " < " << ExprHandle(v->stop()) @@ -401,11 +401,11 @@ void IRPrinter::visit(For* v) { } } -void IRPrinter::visit(Block* v) { +void IRPrinter::visit(BlockPtr v) { os() << "{\n"; indent_++; - for (Stmt* s : *v) { + for (StmtPtr s : *v) { emitIndent(); os() << *s << "\n"; } @@ -414,11 +414,11 @@ void IRPrinter::visit(Block* v) { os() << "}"; } -void IRPrinter::visit(Allocate* v) { +void IRPrinter::visit(AllocatePtr v) { os() << "Allocate(" << *v->buffer_var() << "); // dtype=" << dtypeToCppString(v->dtype()); os() << ", dims=["; - const std::vector& dims = v->dims(); + const std::vector& dims = v->dims(); for (const auto i : c10::irange(dims.size())) { if (i != 0) { os() << ", "; @@ -428,20 +428,20 @@ void IRPrinter::visit(Allocate* v) { os() << "]"; } -void IRPrinter::visit(Free* v) { +void IRPrinter::visit(FreePtr v) { os() << "Free(" << *v->buffer_var() << ");"; } -void IRPrinter::visit(Let* v) { +void IRPrinter::visit(LetPtr v) { os() << dtypeToCppString(v->dtype()) << " " << *v->var(); os() << " = " << *v->value(); os() << ";"; } -void IRPrinter::visit(Cond* v) { - Expr* cond = v->condition(); - Stmt* true_stmt = v->true_stmt(); - Stmt* false_stmt = v->false_stmt(); +void IRPrinter::visit(CondPtr v) { + ExprPtr cond = v->condition(); + StmtPtr true_stmt = v->true_stmt(); + StmtPtr false_stmt = v->false_stmt(); if (!true_stmt) { os() << "if (!" 
<< *cond << ") "; os() << *false_stmt; @@ -455,10 +455,10 @@ void IRPrinter::visit(Cond* v) { } } -void IRPrinter::visit(AtomicAdd* v) { +void IRPrinter::visit(AtomicAddPtr v) { os() << "atomicAdd(&" << *v->base_handle() << "["; size_t i = 0; - for (Expr* ind : v->indices()) { + for (ExprPtr ind : v->indices()) { if (i++) { os() << ", "; } @@ -470,16 +470,16 @@ void IRPrinter::visit(AtomicAdd* v) { os() << "], " << *v->value() << ");"; } -void IRPrinter::visit(SyncThreads* v) { +void IRPrinter::visit(SyncThreadsPtr v) { os() << "__syncthreads();"; } -void IRPrinter::visit(ExternalCall* v) { +void IRPrinter::visit(ExternalCallPtr v) { os() << *v->buf() << " = " << v->func_name() << "("; os() << "buf_args={"; int i = 0; - for (Buf* buf_arg : v->buf_args()) { + for (BufPtr buf_arg : v->buf_args()) { if (i++ > 0) { os() << ", "; } @@ -488,7 +488,7 @@ void IRPrinter::visit(ExternalCall* v) { os() << "}, args={"; i = 0; - for (Expr* arg : v->args()) { + for (ExprPtr arg : v->args()) { if (i++ > 0) { os() << ", "; } @@ -545,22 +545,20 @@ std::ostream& operator<<(std::ostream& stream, const Tensor& t) { return stream; } -void print(const Expr* expr) { +void print(ExprPtr expr) { if (expr) { - Expr* mutable_expr = const_cast(expr); IRPrinter p(std::cout); - p.print(*mutable_expr); + p.print(*expr); } else { std::cout << "(null expr)"; } std::cout << "\n"; } -void print(const Stmt* stmt) { +void print(StmtPtr stmt) { if (stmt) { - Stmt* mutable_stmt = const_cast(stmt); IRPrinter p(std::cout); - p.print(*mutable_stmt); + p.print(*stmt); } else { std::cout << "(null stmt)\n"; } @@ -575,13 +573,13 @@ void print(const Tensor* t) { } // namespace torch namespace std { -std::string to_string(const Expr* expr) { +std::string to_string(ExprPtr expr) { std::ostringstream oss; oss << *expr; return oss.str(); } -std::string to_string(const Stmt* stmt) { +std::string to_string(StmtPtr stmt) { std::ostringstream oss; oss << *stmt; return oss.str(); diff --git a/torch/csrc/jit/tensorexpr/ir_printer.h b/torch/csrc/jit/tensorexpr/ir_printer.h index b417d08..e76dcca 100644 --- a/torch/csrc/jit/tensorexpr/ir_printer.h +++ b/torch/csrc/jit/tensorexpr/ir_printer.h @@ -2,6 +2,7 @@ #include +#include #include #include #include @@ -19,46 +20,46 @@ class TORCH_API IRPrinter : public IRVisitor { void print(ExprHandle); void print(Expr&); void print(Stmt&); - void visit(Add* v) override; - void visit(Sub* v) override; - void visit(Mul* v) override; - void visit(Div* v) override; - void visit(Mod* v) override; - void visit(Max* v) override; - void visit(Min* v) override; - void visit(And* v) override; - void visit(Or* v) override; - void visit(Xor* v) override; - void visit(Lshift* v) override; - void visit(Rshift* v) override; - void visit(CompareSelect* v) override; -#define IMM_PRINT_VISIT(Type, Name) void visit(Name##Imm* v) override; + void visit(AddPtr v) override; + void visit(SubPtr v) override; + void visit(MulPtr v) override; + void visit(DivPtr v) override; + void visit(ModPtr v) override; + void visit(MaxPtr v) override; + void visit(MinPtr v) override; + void visit(AndPtr v) override; + void visit(OrPtr v) override; + void visit(XorPtr v) override; + void visit(LshiftPtr v) override; + void visit(RshiftPtr v) override; + void visit(CompareSelectPtr v) override; +#define IMM_PRINT_VISIT(Type, Name) void visit(Name##ImmPtr v) override; AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_PRINT_VISIT); #undef IMM_PRINT_VISIT - void visit(Cast* v) override; - void visit(Var* v) override; - void visit(Ramp* v) override; - void 
visit(Load* v) override; - void visit(Broadcast* v) override; - void visit(IfThenElse* v) override; - void visit(Intrinsics* v) override; - void visit(Term* v) override; - void visit(Polynomial* v) override; - void visit(RoundOff* v) override; - void visit(MaxTerm* v) override; - void visit(MinTerm* v) override; - void visit(ReduceOp* v) override; - - void visit(AtomicAdd* v) override; - void visit(SyncThreads* v) override; - void visit(ExternalCall* v) override; - void visit(Store* v) override; - void visit(For* v) override; - void visit(Cond* v) override; - void visit(Block* v) override; - void visit(Allocate* v) override; - void visit(Free* v) override; - void visit(Let* v) override; + void visit(CastPtr v) override; + void visit(VarPtr v) override; + void visit(RampPtr v) override; + void visit(LoadPtr v) override; + void visit(BroadcastPtr v) override; + void visit(IfThenElsePtr v) override; + void visit(IntrinsicsPtr v) override; + void visit(TermPtr v) override; + void visit(PolynomialPtr v) override; + void visit(RoundOffPtr v) override; + void visit(MaxTermPtr v) override; + void visit(MinTermPtr v) override; + void visit(ReduceOpPtr v) override; + + void visit(AtomicAddPtr v) override; + void visit(SyncThreadsPtr v) override; + void visit(ExternalCallPtr v) override; + void visit(StorePtr v) override; + void visit(ForPtr v) override; + void visit(CondPtr v) override; + void visit(BlockPtr v) override; + void visit(AllocatePtr v) override; + void visit(FreePtr v) override; + void visit(LetPtr v) override; // A child class may have a difference rule for generating dtype // string, e.g. CUDA needs int64_t to be generated as long long. @@ -100,8 +101,8 @@ TORCH_API std::ostream& operator<<(std::ostream& stream, const ExprHandle&); TORCH_API std::ostream& operator<<(std::ostream& stream, const Stmt&); TORCH_API std::ostream& operator<<(std::ostream& stream, const Tensor&); -TORCH_API void print(const Expr* expr); -TORCH_API void print(const Stmt* stmt); +TORCH_API void print(ExprPtr expr); +TORCH_API void print(StmtPtr stmt); TORCH_API void print(const Tensor* t); } // namespace tensorexpr @@ -111,10 +112,12 @@ TORCH_API void print(const Tensor* t); namespace std { using torch::jit::tensorexpr::Expr; +using torch::jit::tensorexpr::ExprPtr; using torch::jit::tensorexpr::Stmt; +using torch::jit::tensorexpr::StmtPtr; using torch::jit::tensorexpr::Tensor; -TORCH_API std::string to_string(const Expr* expr); -TORCH_API std::string to_string(const Stmt* stmt); +TORCH_API std::string to_string(ExprPtr expr); +TORCH_API std::string to_string(StmtPtr stmt); TORCH_API std::string to_string(const Tensor* t); } // namespace std diff --git a/torch/csrc/jit/tensorexpr/ir_simplifier.cpp b/torch/csrc/jit/tensorexpr/ir_simplifier.cpp index e8e1868..3d849fe 100644 --- a/torch/csrc/jit/tensorexpr/ir_simplifier.cpp +++ b/torch/csrc/jit/tensorexpr/ir_simplifier.cpp @@ -17,13 +17,13 @@ T gcd(T a, T b) { // Helper for determining if an Expr is a multi-lane primitive (e.g. Broadcast // or Ramp). 
-bool isMultilanePrimitive(Expr* e) { - return dynamic_cast(e) || dynamic_cast(e); +bool isMultilanePrimitive(ExprPtr e) { + return to(e) || to(e); } SimplifierHashType Term::hashVars() const { SimplifierHashType hash; - for (auto* v : variables_) { + for (auto v : variables_) { hash = hasher_.hash_combine(hash, hasher_.hash(v)); } @@ -35,14 +35,14 @@ void Term::sort() { if (dtype().is_floating_point()) { throw std::logic_error("reordering FP ops"); } - std::sort(variables_.begin(), variables_.end(), [&](Expr* a, Expr* b) { + std::sort(variables_.begin(), variables_.end(), [&](ExprPtr a, ExprPtr b) { return hasher_.hash(a) < hasher_.hash(b); }); } SimplifierHashType Polynomial::hashVars() const { SimplifierHashType hash; - for (auto* v : variables_) { + for (auto v : variables_) { hash = hasher_.hash_combine(hash, hasher_.hash(v)); } return hash; @@ -52,28 +52,28 @@ void Polynomial::sort() { if (dtype().is_floating_point()) { throw std::logic_error("reordering FP ops"); } - std::sort(variables_.begin(), variables_.end(), [&](Expr* a, Expr* b) { + std::sort(variables_.begin(), variables_.end(), [&](ExprPtr a, ExprPtr b) { return hasher_.hash(a) < hasher_.hash(b); }); } void MaxTerm::uniquefy() { - std::sort(variables_.begin(), variables_.end(), [&](Expr* a, Expr* b) { + std::sort(variables_.begin(), variables_.end(), [&](ExprPtr a, ExprPtr b) { return hasher_.hash(a) < hasher_.hash(b); }); - auto it = - std::unique(variables_.begin(), variables_.end(), [&](Expr* a, Expr* b) { + auto it = std::unique( + variables_.begin(), variables_.end(), [&](ExprPtr a, ExprPtr b) { return hasher_.hash(a) == hasher_.hash(b); }); variables_.resize(std::distance(variables_.begin(), it)); } void MinTerm::uniquefy() { - std::sort(variables_.begin(), variables_.end(), [&](Expr* a, Expr* b) { + std::sort(variables_.begin(), variables_.end(), [&](ExprPtr a, ExprPtr b) { return hasher_.hash(a) < hasher_.hash(b); }); - auto it = - std::unique(variables_.begin(), variables_.end(), [&](Expr* a, Expr* b) { + auto it = std::unique( + variables_.begin(), variables_.end(), [&](ExprPtr a, ExprPtr b) { return hasher_.hash(a) == hasher_.hash(b); }); variables_.resize(std::distance(variables_.begin(), it)); @@ -81,46 +81,46 @@ void MinTerm::uniquefy() { // Handles optimization cases for Broadcast/Ramp +/- Broadcast/Ramp template -Expr* combineMultilane(Expr* lhs, Expr* rhs) { - if (Broadcast* bc = dynamic_cast(lhs)) { - if (Broadcast* bcother = dynamic_cast(rhs)) { +ExprPtr combineMultilane(ExprPtr lhs, ExprPtr rhs) { + if (BroadcastPtr bc = to(lhs)) { + if (BroadcastPtr bcother = to(rhs)) { if (bc->lanes() != bcother->lanes()) { throw malformed_input("multilane lane mismatch"); } - Expr* ret = - new Broadcast(new Op(bc->value(), bcother->value()), bc->lanes()); + ExprPtr ret = alloc( + alloc(bc->value(), bcother->value()), bc->lanes()); return ret; } - if (Ramp* r = dynamic_cast(rhs)) { + if (RampPtr r = to(rhs)) { if (bc->lanes() != r->lanes()) { throw malformed_input("multilane lane mismatch"); } - Expr* ret = - new Ramp(new Op(bc->value(), r->base()), r->stride(), r->lanes()); + ExprPtr ret = alloc( + alloc(bc->value(), r->base()), r->stride(), r->lanes()); return ret; } - } else if (Ramp* ramp = dynamic_cast(lhs)) { - if (Ramp* rother = dynamic_cast(rhs)) { + } else if (RampPtr ramp = to(lhs)) { + if (RampPtr rother = to(rhs)) { if (ramp->lanes() != rother->lanes()) { throw malformed_input("multilane lane mismatch"); } - Expr* ret = new Ramp( - new Op(ramp->base(), rother->base()), - new Op(ramp->stride(), 
rother->stride()), + ExprPtr ret = alloc( + alloc(ramp->base(), rother->base()), + alloc(ramp->stride(), rother->stride()), ramp->lanes()); return ret; } - if (Broadcast* bc = dynamic_cast(rhs)) { + if (BroadcastPtr bc = to(rhs)) { if (ramp->lanes() != bc->lanes()) { throw malformed_input("multilane lane mismatch"); } - Expr* ret = new Ramp( - new Op(ramp->base(), bc->value()), ramp->stride(), ramp->lanes()); + ExprPtr ret = alloc( + alloc(ramp->base(), bc->value()), ramp->stride(), ramp->lanes()); return ret; } } @@ -129,50 +129,50 @@ Expr* combineMultilane(Expr* lhs, Expr* rhs) { } // Handles optimization cases for Broadcast/Ramp * Broadcast/Ramp -Expr* mulMultilane(Expr* lhs, Expr* rhs) { - if (Broadcast* bc = dynamic_cast(lhs)) { - if (Broadcast* bcother = dynamic_cast(rhs)) { +ExprPtr mulMultilane(ExprPtr lhs, ExprPtr rhs) { + if (BroadcastPtr bc = to(lhs)) { + if (BroadcastPtr bcother = to(rhs)) { if (bc->lanes() != bcother->lanes()) { throw malformed_input("multilane lane mismatch"); } - Expr* ret = - new Broadcast(new Mul(bc->value(), bcother->value()), bc->lanes()); + ExprPtr ret = alloc( + alloc(bc->value(), bcother->value()), bc->lanes()); return ret; } - if (Ramp* r = dynamic_cast(rhs)) { + if (RampPtr r = to(rhs)) { if (bc->lanes() != r->lanes()) { throw malformed_input("multilane lane mismatch"); } - Expr* ret = new Ramp( - new Mul(bc->value(), r->base()), - new Mul(bc->value(), r->stride()), + ExprPtr ret = alloc( + alloc(bc->value(), r->base()), + alloc(bc->value(), r->stride()), r->lanes()); return ret; } - } else if (Ramp* ramp = dynamic_cast(lhs)) { - if (Ramp* r = dynamic_cast(rhs)) { + } else if (RampPtr ramp = to(lhs)) { + if (RampPtr r = to(rhs)) { if (ramp->lanes() != r->lanes()) { throw malformed_input("multilane lane mismatch"); } - Expr* ret = new Ramp( - new Mul(ramp->base(), r->base()), - new Mul(ramp->stride(), r->stride()), + ExprPtr ret = alloc( + alloc(ramp->base(), r->base()), + alloc(ramp->stride(), r->stride()), r->lanes()); return ret; } - if (Broadcast* bc = dynamic_cast(rhs)) { + if (BroadcastPtr bc = to(rhs)) { if (ramp->lanes() != bc->lanes()) { throw malformed_input("multilane lane mismatch"); } - Expr* ret = new Ramp( - new Mul(bc->value(), ramp->base()), - new Mul(bc->value(), ramp->stride()), + ExprPtr ret = alloc( + alloc(bc->value(), ramp->base()), + alloc(bc->value(), ramp->stride()), ramp->lanes()); return ret; } @@ -182,13 +182,13 @@ Expr* mulMultilane(Expr* lhs, Expr* rhs) { } void PolynomialTransformer::addOrUpdateTerm( - std::unordered_map& varmap, - Term* term) { + std::unordered_map& varmap, + TermPtr term) { SimplifierHashType hash = term->hashVars(); auto insertRes = varmap.emplace(hash, term); if (insertRes.second == false) { - Term* lt = insertRes.first->second; - Expr* termScalar = evaluateOp(new Add(lt->scalar(), term->scalar())); + TermPtr lt = insertRes.first->second; + ExprPtr termScalar = evaluateOp(alloc(lt->scalar(), term->scalar())); // If the term is canceled out, remove from the map. 
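A worked example of the lane algebra behind combineMultilane and mulMultilane above, written out for Add with 4 lanes (the values a, b, s are illustrative):

// Per-lane values, 4 lanes:
//   Broadcast(a, 4) = {a, a, a, a}
//   Ramp(b, s, 4)   = {b, b + s, b + 2*s, b + 3*s}
//
// Adding them lane-wise:
//   Broadcast(a, 4) + Ramp(b, s, 4)
//     = {a + b, a + b + s, a + b + 2*s, a + b + 3*s}
//     = Ramp(Add(a, b), s, 4)
//
// which is the rewrite performed above, i.e.
//   alloc<Ramp>(alloc<Op>(bc->value(), r->base()), r->stride(), r->lanes())
// with Op = Add. Multiplication also distributes over the stride, so
// mulMultilane turns Broadcast(a) * Ramp(b, s) into Ramp(a*b, a*s).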
if (immediateEquals(termScalar, 0)) { @@ -196,43 +196,45 @@ void PolynomialTransformer::addOrUpdateTerm( return; } - varmap[hash] = new Term(hasher_, termScalar, lt->variables()); + varmap[hash] = alloc(hasher_, termScalar, lt->variables()); } } -Expr* PolynomialTransformer::addPolynomials(Polynomial* lhs, Polynomial* rhs) { +ExprPtr PolynomialTransformer::addPolynomials( + PolynomialPtr lhs, + PolynomialPtr rhs) { // simplify common components // The key here is the variable hash, not the term's hash since we do want // to combine terms that have the same vars but different scalar components. - std::unordered_map varmap; + std::unordered_map varmap; - for (auto* lt : lhs->variables()) { + for (auto lt : lhs->variables()) { addOrUpdateTerm(varmap, lt); } - for (auto* rt : rhs->variables()) { + for (auto rt : rhs->variables()) { addOrUpdateTerm(varmap, rt); } - Expr* newScalar = evaluateOp(new Add(lhs->scalar(), rhs->scalar())); - return new Polynomial(hasher_, newScalar, varmap); + ExprPtr newScalar = evaluateOp(alloc(lhs->scalar(), rhs->scalar())); + return alloc(hasher_, newScalar, varmap); } // Insert a new Term into the provided polynomial. If the new term has common // variables to an existing term it is combined. -Expr* PolynomialTransformer::insertTerm(Polynomial* poly, Term* term) { +ExprPtr PolynomialTransformer::insertTerm(PolynomialPtr poly, TermPtr term) { SimplifierHashType tHash = term->hashVars(); - std::vector newVars; + std::vector newVars; bool found = false; - for (auto* v : poly->variables()) { + for (auto v : poly->variables()) { if (v->hashVars() == tHash) { - Expr* newScalar = evaluateOp(new Add(term->scalar(), v->scalar())); + ExprPtr newScalar = evaluateOp(alloc(term->scalar(), v->scalar())); found = true; // Skip this term if we cancelled it out. if (immediateEquals(newScalar, 0)) { continue; } - auto* term = new Term(hasher_, newScalar, v->variables()); + auto term = alloc(hasher_, newScalar, v->variables()); newVars.push_back(term); } else { newVars.push_back(v); @@ -247,29 +249,29 @@ Expr* PolynomialTransformer::insertTerm(Polynomial* poly, Term* term) { return poly->scalar(); } - auto* Poly = new Polynomial(hasher_, poly->scalar(), newVars); + auto Poly = alloc(hasher_, poly->scalar(), newVars); return Poly; } -Expr* PolynomialTransformer::mutate(Add* v) { - Expr* lhs_new = v->lhs()->accept_mutator(this); - Expr* rhs_new = v->rhs()->accept_mutator(this); +ExprPtr PolynomialTransformer::mutate(AddPtr v) { + ExprPtr lhs_new = v->lhs()->accept_mutator(this); + ExprPtr rhs_new = v->rhs()->accept_mutator(this); // Constant Folding. if (lhs_new->isConstant() && rhs_new->isConstant()) { - Expr* result = evaluateOp(new Add(lhs_new, rhs_new)); + ExprPtr result = evaluateOp(alloc(lhs_new, rhs_new)); return result; } // Multilane folding. if (isMultilanePrimitive(lhs_new)) { - if (auto* ret = combineMultilane(lhs_new, rhs_new)) { + if (auto ret = combineMultilane(lhs_new, rhs_new)) { return ret->accept_mutator(this); } } - Expr* scalar = nullptr; - Expr* variable = nullptr; + ExprPtr scalar = nullptr; + ExprPtr variable = nullptr; if (lhs_new->isConstant()) { scalar = evaluateOp(lhs_new); variable = rhs_new; @@ -281,7 +283,7 @@ Expr* PolynomialTransformer::mutate(Add* v) { // If there is a scalar, and it's zero: short circuit and return the other // side. 
if (scalar && immediateEquals(scalar, 0)) { - auto* c = new Cast(v->dtype(), variable); + auto c = alloc(v->dtype(), variable); return c->accept_mutator(this); } @@ -289,18 +291,18 @@ Expr* PolynomialTransformer::mutate(Add* v) { // dont want to combine ops. if (lhs_new->dtype().is_floating_point() || rhs_new->dtype().is_floating_point()) { - return new Add(lhs_new, rhs_new); + return alloc(lhs_new, rhs_new); } - Polynomial* lhsPoly = dynamic_cast(lhs_new); - Polynomial* rhsPoly = dynamic_cast(rhs_new); + PolynomialPtr lhsPoly = to(lhs_new); + PolynomialPtr rhsPoly = to(rhs_new); if (lhsPoly && rhsPoly) { return addPolynomials(lhsPoly, rhsPoly); } - Term* lhsTerm = dynamic_cast(lhs_new); - Term* rhsTerm = dynamic_cast(rhs_new); + TermPtr lhsTerm = to(lhs_new); + TermPtr rhsTerm = to(rhs_new); if (lhsPoly && rhsTerm) { return insertTerm(lhsPoly, rhsTerm); @@ -313,54 +315,54 @@ Expr* PolynomialTransformer::mutate(Add* v) { if (lhsTerm && rhsTerm) { // If the terms refer to the same variables: combine them. if (lhsTerm->hashVars() == rhsTerm->hashVars()) { - Expr* newScalar = - evaluateOp(new Add(lhsTerm->scalar(), rhsTerm->scalar())); + ExprPtr newScalar = + evaluateOp(alloc(lhsTerm->scalar(), rhsTerm->scalar())); // If the terms cancelled out, return zero. if (immediateEquals(newScalar, 0)) { return newScalar->accept_mutator(this); } - return new Term(hasher_, newScalar, lhsTerm->variables()); + return alloc(hasher_, newScalar, lhsTerm->variables()); } // Otherwise this is a new polynomial with no scalar and two variable // terms. - return new Polynomial( + return alloc( hasher_, getImmediateByType(v->dtype(), 0), lhsTerm, rhsTerm); } // Adds are commutative. - Polynomial* poly = lhsPoly ? lhsPoly : rhsPoly; + PolynomialPtr poly = lhsPoly ? lhsPoly : rhsPoly; // Add to Polynomial->scalar(). if (scalar && poly) { - Expr* newScalar = evaluateOp(new Add(scalar, poly->scalar())); - return new Polynomial(hasher_, newScalar, poly->variables()); + ExprPtr newScalar = evaluateOp(alloc(scalar, poly->scalar())); + return alloc(hasher_, newScalar, poly->variables()); } // Simple Polynomial with a scalar and Term. - Term* term = lhsTerm ? lhsTerm : rhsTerm; + TermPtr term = lhsTerm ? lhsTerm : rhsTerm; if (scalar && term) { - return new Polynomial(hasher_, scalar, term); + return alloc(hasher_, scalar, term); } // Simple Term with a scalar and variable type. if (scalar) { - return new Polynomial( + return alloc( hasher_, scalar, - new Term(hasher_, getImmediateByType(v->dtype(), 1), variable)); + alloc(hasher_, getImmediateByType(v->dtype(), 1), variable)); } // If LHS is neither Term not Polynomial, wrap it in a Term. if (!lhsTerm && !lhsPoly) { - lhsTerm = new Term(hasher_, getImmediateByType(v->dtype(), 1), lhs_new); + lhsTerm = alloc(hasher_, getImmediateByType(v->dtype(), 1), lhs_new); } // Same for RHS. if (!rhsTerm && !rhsPoly) { - rhsTerm = new Term(hasher_, getImmediateByType(v->dtype(), 1), rhs_new); + rhsTerm = alloc(hasher_, getImmediateByType(v->dtype(), 1), rhs_new); } // If we now have a poly and a term, we can insert. @@ -369,37 +371,40 @@ Expr* PolynomialTransformer::mutate(Add* v) { } if (lhsTerm->hashVars() == rhsTerm->hashVars()) { - return new Term( + return alloc( hasher_, - evaluateOp(new Add(lhsTerm->scalar(), rhsTerm->scalar())), + evaluateOp(alloc(lhsTerm->scalar(), rhsTerm->scalar())), lhsTerm->variables()); } // If all else fails we have a new Polynomial with two new variable Terms. 
- return new Polynomial( + return alloc( hasher_, getImmediateByType(v->dtype(), 0), lhsTerm, rhsTerm); } -Expr* PolynomialTransformer::subTerms(Term* lhs, Term* rhs, bool negated) { +ExprPtr PolynomialTransformer::subTerms( + TermPtr lhs, + TermPtr rhs, + bool negated) { // If RHS not already negated, negate it. if (!negated) { - Expr* minusOne = getImmediateByType(rhs->dtype(), -1); - Expr* negateScalar = evaluateOp(new Mul(minusOne, rhs->scalar())); - rhs = new Term(hasher_, negateScalar, rhs->variables()); + ExprPtr minusOne = getImmediateByType(rhs->dtype(), -1); + ExprPtr negateScalar = evaluateOp(alloc(minusOne, rhs->scalar())); + rhs = alloc(hasher_, negateScalar, rhs->variables()); } if (lhs->hashVars() == rhs->hashVars()) { - Expr* newScalar = evaluateOp(new Add(lhs->scalar(), rhs->scalar())); + ExprPtr newScalar = evaluateOp(alloc(lhs->scalar(), rhs->scalar())); // If the terms cancel out, return zero. if (immediateEquals(newScalar, 0)) { return newScalar; } - return new Term(hasher_, newScalar, lhs->variables()); + return alloc(hasher_, newScalar, lhs->variables()); } - return new Polynomial( + return alloc( hasher_, getImmediateByType(promoteTypes(lhs->dtype(), rhs->dtype()), 0), lhs, @@ -408,25 +413,27 @@ Expr* PolynomialTransformer::subTerms(Term* lhs, Term* rhs, bool negated) { // Subtract the RHS Polynomial from the LHS Polynomial, cancelling out where // possible. -Expr* PolynomialTransformer::subPolynomials(Polynomial* lhs, Polynomial* rhs) { +ExprPtr PolynomialTransformer::subPolynomials( + PolynomialPtr lhs, + PolynomialPtr rhs) { // simplify common components // The key here is the variable hash, not the term's hash since we do want // to combine terms that have the same vars but different scalar components. - std::unordered_map varmap; + std::unordered_map varmap; - for (auto* lt : lhs->variables()) { + for (auto lt : lhs->variables()) { addOrUpdateTerm(varmap, lt); } - for (auto* rt : rhs->variables()) { + for (auto rt : rhs->variables()) { // Polynomials add their terms, so negate the RHS's Terms. - Expr* negated = - evaluateOp(new Mul(getImmediateByType(rt->dtype(), -1), rt->scalar())); - Term* newRHS = new Term(hasher_, negated, rt->variables()); + ExprPtr negated = evaluateOp( + alloc(getImmediateByType(rt->dtype(), -1), rt->scalar())); + TermPtr newRHS = alloc(hasher_, negated, rt->variables()); addOrUpdateTerm(varmap, newRHS); } - Expr* newScalar = evaluateOp(new Sub(lhs->scalar(), rhs->scalar())); + ExprPtr newScalar = evaluateOp(alloc(lhs->scalar(), rhs->scalar())); // No vars means this cancelled out to a scalar, return it unwrapped. if (varmap.empty()) { @@ -444,29 +451,29 @@ Expr* PolynomialTransformer::subPolynomials(Polynomial* lhs, Polynomial* rhs) { } // Wrap new variables in a Polynomial. - return new Polynomial(hasher_, newScalar, varmap); + return alloc(hasher_, newScalar, varmap); } -Expr* PolynomialTransformer::mutate(Sub* v) { - Expr* lhs_new = v->lhs()->accept_mutator(this); - Expr* rhs_new = v->rhs()->accept_mutator(this); +ExprPtr PolynomialTransformer::mutate(SubPtr v) { + ExprPtr lhs_new = v->lhs()->accept_mutator(this); + ExprPtr rhs_new = v->rhs()->accept_mutator(this); // Constant Folding. if (lhs_new->isConstant() && rhs_new->isConstant()) { - Expr* result = evaluateOp(new Sub(lhs_new, rhs_new)); + ExprPtr result = evaluateOp(alloc(lhs_new, rhs_new)); return result; } // Multilane folding. 
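For reference, the term merging done by addPolynomials and insertTerm above is ordinary collection of like terms keyed by the variable hash; a small worked case (the expression is illustrative only):

// (2*x + 3*y + 1) + (4*x - 3*y + 5)
//   - the x terms share a variable hash, so their scalars add: 2 + 4 = 6
//   - the y terms cancel (3 + (-3) = 0) and are dropped from the varmap
//   - the standalone scalars combine: 1 + 5 = 6
// result: Polynomial(scalar = 6, terms = {6*x}), i.e. 6*x + 6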
if (isMultilanePrimitive(lhs_new)) { - if (auto* ret = combineMultilane(lhs_new, rhs_new)) { + if (auto ret = combineMultilane(lhs_new, rhs_new)) { // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) return ret->accept_mutator(this); } } if (rhs_new->isConstant() && immediateEquals(rhs_new, 0)) { - auto* c = new Cast(v->dtype(), lhs_new); + auto c = alloc(v->dtype(), lhs_new); // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) return c->accept_mutator(this); } @@ -475,14 +482,14 @@ Expr* PolynomialTransformer::mutate(Sub* v) { // dont want to combine ops. if (lhs_new->dtype().is_floating_point() || rhs_new->dtype().is_floating_point()) { - return new Sub(lhs_new, rhs_new); + return alloc(lhs_new, rhs_new); } - Polynomial* lhsPoly = dynamic_cast(lhs_new); - Polynomial* rhsPoly = dynamic_cast(rhs_new); + PolynomialPtr lhsPoly = to(lhs_new); + PolynomialPtr rhsPoly = to(rhs_new); if (lhsPoly && rhsPoly) { - auto* ret = subPolynomials(lhsPoly, rhsPoly); + auto ret = subPolynomials(lhsPoly, rhsPoly); if (!ret) { // Cancelled out completely. return getImmediateByType(v->dtype(), 0); @@ -490,31 +497,31 @@ Expr* PolynomialTransformer::mutate(Sub* v) { return ret; } - Term* lhsTerm = dynamic_cast(lhs_new); - Term* rhsTerm = dynamic_cast(rhs_new); + TermPtr lhsTerm = to(lhs_new); + TermPtr rhsTerm = to(rhs_new); // Polynomial - Term. if (lhsPoly && rhsTerm) { // Negate the term. - Expr* negate = evaluateOp( - new Mul(getImmediateByType(rhsTerm->dtype(), -1), rhsTerm->scalar())); - Term* newTerm = new Term(hasher_, negate, rhsTerm->variables()); + ExprPtr negate = evaluateOp(alloc( + getImmediateByType(rhsTerm->dtype(), -1), rhsTerm->scalar())); + TermPtr newTerm = alloc(hasher_, negate, rhsTerm->variables()); return insertTerm(lhsPoly, newTerm); } // Term - Polynomial. if (rhsPoly && lhsTerm) { // Negate every part of the Polynomial. - Expr* minusOne = getImmediateByType(lhsTerm->dtype(), -1); - Expr* negateScalar = evaluateOp(new Mul(minusOne, rhsPoly->scalar())); + ExprPtr minusOne = getImmediateByType(lhsTerm->dtype(), -1); + ExprPtr negateScalar = evaluateOp(alloc(minusOne, rhsPoly->scalar())); - std::vector variables; - for (auto* t : rhsPoly->variables()) { - Expr* negate = evaluateOp(new Mul(minusOne, t->scalar())); - variables.push_back(new Term(hasher_, negate, t->variables())); + std::vector variables; + for (auto t : rhsPoly->variables()) { + ExprPtr negate = evaluateOp(alloc(minusOne, t->scalar())); + variables.push_back(alloc(hasher_, negate, t->variables())); } - Polynomial* newPoly = new Polynomial(hasher_, negateScalar, variables); + PolynomialPtr newPoly = alloc(hasher_, negateScalar, variables); return insertTerm(newPoly, lhsTerm); } @@ -527,68 +534,68 @@ Expr* PolynomialTransformer::mutate(Sub* v) { if (lhsPoly && rhsScalar) { // Easy path, just sub the scalar component. - Expr* newScalar = evaluateOp(new Sub(lhsPoly->scalar(), rhs_new)); - return new Polynomial(hasher_, newScalar, lhsPoly->variables()); + ExprPtr newScalar = evaluateOp(alloc(lhsPoly->scalar(), rhs_new)); + return alloc(hasher_, newScalar, lhsPoly->variables()); } if (lhsScalar && rhsPoly) { // Sub the scalar component. - Expr* newScalar = evaluateOp(new Sub(lhs_new, rhsPoly->scalar())); + ExprPtr newScalar = evaluateOp(alloc(lhs_new, rhsPoly->scalar())); // Negate each term in the Polynomial RHS. 
- Expr* minusOne = getImmediateByType(rhsPoly->dtype(), -1); - std::vector variables; - for (auto* t : rhsPoly->variables()) { - Expr* negate = evaluateOp(new Mul(minusOne, t->scalar())); - variables.push_back(new Term(hasher_, negate, t->variables())); + ExprPtr minusOne = getImmediateByType(rhsPoly->dtype(), -1); + std::vector variables; + for (auto t : rhsPoly->variables()) { + ExprPtr negate = evaluateOp(alloc(minusOne, t->scalar())); + variables.push_back(alloc(hasher_, negate, t->variables())); } - return new Polynomial(hasher_, newScalar, variables); + return alloc(hasher_, newScalar, variables); } if (lhsTerm && rhsScalar) { // Negate the constant. - Expr* negate = - evaluateOp(new Mul(getImmediateByType(rhs_new->dtype(), -1), rhs_new)); - return new Polynomial(hasher_, negate, lhsTerm); + ExprPtr negate = evaluateOp( + alloc(getImmediateByType(rhs_new->dtype(), -1), rhs_new)); + return alloc(hasher_, negate, lhsTerm); } if (lhsScalar && rhsTerm) { // Negate the RHS Term. - Expr* negate = evaluateOp(new Mul( + ExprPtr negate = evaluateOp(alloc( getImmediateByType(rhsTerm->scalar()->dtype(), -1), rhsTerm->scalar())); - return new Polynomial( - hasher_, lhs_new, new Term(hasher_, negate, rhsTerm->variables())); + return alloc( + hasher_, lhs_new, alloc(hasher_, negate, rhsTerm->variables())); } // simple term with a scalar and variable type. if (lhsScalar) { // Create a negated term. - return new Polynomial( + return alloc( hasher_, lhs_new, - new Term(hasher_, getImmediateByType(v->dtype(), -1), rhs_new)); + alloc(hasher_, getImmediateByType(v->dtype(), -1), rhs_new)); } if (rhsScalar) { // Negate the scalar. - Expr* negate = - evaluateOp(new Mul(getImmediateByType(rhs_new->dtype(), -1), rhs_new)); - return new Polynomial( + ExprPtr negate = evaluateOp( + alloc(getImmediateByType(rhs_new->dtype(), -1), rhs_new)); + return alloc( hasher_, negate, - new Term(hasher_, getImmediateByType(v->dtype(), 1), lhs_new)); + alloc(hasher_, getImmediateByType(v->dtype(), 1), lhs_new)); } // no scalar... if (!lhsTerm && !lhsPoly) { - lhsTerm = new Term(hasher_, getImmediateByType(v->dtype(), 1), lhs_new); + lhsTerm = alloc(hasher_, getImmediateByType(v->dtype(), 1), lhs_new); } bool createdRHSnegated = false; if (!rhsTerm && !rhsPoly) { - rhsTerm = new Term(hasher_, getImmediateByType(v->dtype(), -1), rhs_new); + rhsTerm = alloc(hasher_, getImmediateByType(v->dtype(), -1), rhs_new); createdRHSnegated = true; } @@ -605,44 +612,44 @@ Expr* PolynomialTransformer::mutate(Sub* v) { // Insert wrapper Term into negated RHS Poly. if (rhsPoly) { CHECK(lhsTerm); - Expr* minusOne = getImmediateByType(rhsPoly->dtype(), -1); - Expr* newScalar = evaluateOp(new Mul(minusOne, rhsPoly->scalar())); + ExprPtr minusOne = getImmediateByType(rhsPoly->dtype(), -1); + ExprPtr newScalar = evaluateOp(alloc(minusOne, rhsPoly->scalar())); // Negate each term in the Polynomial RHS. 
- std::vector variables; - for (auto* t : rhsPoly->variables()) { - Expr* negate = evaluateOp(new Mul(minusOne, t->scalar())); - variables.push_back(new Term(hasher_, negate, t->variables())); + std::vector variables; + for (auto t : rhsPoly->variables()) { + ExprPtr negate = evaluateOp(alloc(minusOne, t->scalar())); + variables.push_back(alloc(hasher_, negate, t->variables())); } - auto* poly = new Polynomial(hasher_, newScalar, variables); + auto poly = alloc(hasher_, newScalar, variables); return insertTerm(poly, lhsTerm); } - return new Polynomial( + return alloc( hasher_, getImmediateByType(v->dtype(), 0), lhsTerm, rhsTerm); } // Multiply two terms together, usually creating a new term with the variable // lists concatenated. -Term* PolynomialTransformer::mulTerms(Term* lhs, Term* rhs) { - Expr* scalar = evaluateOp(new Mul(lhs->scalar(), rhs->scalar())); +TermPtr PolynomialTransformer::mulTerms(TermPtr lhs, TermPtr rhs) { + ExprPtr scalar = evaluateOp(alloc(lhs->scalar(), rhs->scalar())); if (immediateEquals(scalar, 0)) { return nullptr; } // Can reorder here since floating point ops don't get put into Terms. - std::vector variables; - std::vector multilaneVariables; + std::vector variables; + std::vector multilaneVariables; // For now don't handle exponents. - for (auto* c : lhs->variables()) { + for (auto c : lhs->variables()) { if (isMultilanePrimitive(c)) { multilaneVariables.push_back(c); } else { variables.push_back(c); } } - for (auto* c : rhs->variables()) { + for (auto c : rhs->variables()) { if (isMultilanePrimitive(c)) { multilaneVariables.push_back(c); } else { @@ -651,12 +658,12 @@ Term* PolynomialTransformer::mulTerms(Term* lhs, Term* rhs) { } // Merge all the multilane vars: - Expr* lastNode{nullptr}; - for (auto* node : multilaneVariables) { + ExprPtr lastNode{nullptr}; + for (auto node : multilaneVariables) { if (lastNode == nullptr) { lastNode = node; } else { - if (auto* next = mulMultilane(lastNode, node)) { + if (auto next = mulMultilane(lastNode, node)) { lastNode = next->accept_mutator(this); } else { variables.push_back(lastNode); @@ -668,20 +675,20 @@ Term* PolynomialTransformer::mulTerms(Term* lhs, Term* rhs) { variables.push_back(lastNode); } - return new Term(hasher_, scalar, variables); + return alloc(hasher_, scalar, variables); } // Multiply a Polynomial by a Term. -Expr* PolynomialTransformer::polyByTerm(Polynomial* poly, Term* term) { +ExprPtr PolynomialTransformer::polyByTerm(PolynomialPtr poly, TermPtr term) { // poly * term // = (poly_terms + poly_scalar) * term // = poly_terms * term + poly_scalar * term // First, multiply all variables (terms) in the polynomial by the input // term. - std::vector newTerms; - for (auto* var : poly->variables()) { - Term* newTerm = mulTerms(var, term); + std::vector newTerms; + for (auto var : poly->variables()) { + TermPtr newTerm = mulTerms(var, term); if (newTerm) { newTerms.push_back(newTerm); } @@ -692,37 +699,37 @@ Expr* PolynomialTransformer::polyByTerm(Polynomial* poly, Term* term) { // polynomial. If there are variables in term, this becomes a new term in // the result polynomial. 
if (!immediateEquals(poly->scalar(), 0)) { - Expr* scalar = evaluateOp(new Mul(poly->scalar(), term->scalar())); + ExprPtr scalar = evaluateOp(alloc(poly->scalar(), term->scalar())); if (term->variables().empty()) { - return new Polynomial(hasher_, scalar, newTerms); + return alloc(hasher_, scalar, newTerms); } - newTerms.push_back(new Term(hasher_, scalar, term->variables())); + newTerms.push_back(alloc(hasher_, scalar, term->variables())); } // The only case when the result polynomial has a scalar is when the input // term does not have any variables and the input polynomial has a non-zero // scalar. That case is handled above. So, at this point, we do not have any // scalars in the result polynomial. - return new Polynomial(hasher_, std::move(newTerms)); + return alloc(hasher_, std::move(newTerms)); } // Does multiplying these two expressions make a Rounding Off operation. // e.g. LHS = (x/y), RHS = y => (x / y) * y => RoundOff(x, y). -Expr* PolynomialTransformer::isRoundOff(Expr* lhs, Expr* rhs) { - Div* div{nullptr}; - Expr* other{nullptr}; +ExprPtr PolynomialTransformer::isRoundOff(ExprPtr lhs, ExprPtr rhs) { + DivPtr div{nullptr}; + ExprPtr other{nullptr}; - if ((div = dynamic_cast(lhs))) { + if ((div = to
<Div>(lhs))) { other = rhs; - } else if ((div = dynamic_cast<Div*>(rhs))) { + } else if ((div = to<Div>
(rhs))) { other = lhs; } else { return nullptr; } - Expr* denom = div->rhs(); + ExprPtr denom = div->rhs(); - if (Term* denomTerm = dynamic_cast(denom)) { + if (TermPtr denomTerm = to(denom)) { if (immediateEquals(denomTerm->scalar(), 1) && denomTerm->variables().size() == 1) { denom = denomTerm->variables()[0]; @@ -731,7 +738,7 @@ Expr* PolynomialTransformer::isRoundOff(Expr* lhs, Expr* rhs) { if (hasher_.hash(denom) == hasher_.hash(other)) { // If the denominator is equal to the other, then yes it's a RoundOff. - return new RoundOff(div->lhs(), div->rhs()); + return alloc(div->lhs(), div->rhs()); } if (denom->isConstant() && other->isConstant()) { @@ -739,10 +746,11 @@ Expr* PolynomialTransformer::isRoundOff(Expr* lhs, Expr* rhs) { return nullptr; } // If they are both scalar we may be able to find a common factor. - if (immediateEquals(evaluateOp(new Mod(other, denom)), 0)) { - Expr* scalar = evaluateOp(new Div(other, denom)); - Expr* newDenom = evaluateOp(new Div(other, scalar)); - return new Term(hasher_, scalar, new RoundOff(div->lhs(), newDenom)); + if (immediateEquals(evaluateOp(alloc(other, denom)), 0)) { + ExprPtr scalar = evaluateOp(alloc
<Div>(other, denom)); + ExprPtr newDenom = evaluateOp(alloc<Div>
(other, scalar)); + return alloc( + hasher_, scalar, alloc(div->lhs(), newDenom)); } } @@ -750,13 +758,13 @@ Expr* PolynomialTransformer::isRoundOff(Expr* lhs, Expr* rhs) { } // Inserts a new component into a term, looking for opportunities to simplify. -Expr* PolynomialTransformer::insertIntoTerm(Term* term, Expr* expr) { - std::vector vars; +ExprPtr PolynomialTransformer::insertIntoTerm(TermPtr term, ExprPtr expr) { + std::vector vars; // Search for RoundOffs. bool merged{false}; - for (auto* component : term->variables()) { - if (auto* roundoff = isRoundOff(component, expr)) { + for (auto component : term->variables()) { + if (auto roundoff = isRoundOff(component, expr)) { vars.push_back(roundoff); merged = true; } else { @@ -772,29 +780,29 @@ Expr* PolynomialTransformer::insertIntoTerm(Term* term, Expr* expr) { return vars[0]; } - return new Term(hasher_, term->scalar(), vars); + return alloc(hasher_, term->scalar(), vars); } -Expr* PolynomialTransformer::mutate(Mul* v) { - Expr* lhs_new = v->lhs()->accept_mutator(this); - Expr* rhs_new = v->rhs()->accept_mutator(this); +ExprPtr PolynomialTransformer::mutate(MulPtr v) { + ExprPtr lhs_new = v->lhs()->accept_mutator(this); + ExprPtr rhs_new = v->rhs()->accept_mutator(this); // Constant Folding. if (lhs_new->isConstant() && rhs_new->isConstant()) { - return evaluateOp(new Mul(lhs_new, rhs_new)); + return evaluateOp(alloc(lhs_new, rhs_new)); } // Multilane folding. if (isMultilanePrimitive(lhs_new)) { - if (auto* ret = mulMultilane(lhs_new, rhs_new)) { + if (auto ret = mulMultilane(lhs_new, rhs_new)) { // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) return ret->accept_mutator(this); } } // Order doesn't matter. - Expr* scalar = nullptr; - Expr* variable = nullptr; + ExprPtr scalar = nullptr; + ExprPtr variable = nullptr; if (lhs_new->isConstant()) { scalar = lhs_new; variable = rhs_new; @@ -806,7 +814,7 @@ Expr* PolynomialTransformer::mutate(Mul* v) { // Handle special case mul by 1 since thats safe for floating point, even if // it's Nan/Inf. if (scalar && immediateEquals(scalar, 1)) { - auto* c = new Cast(v->dtype(), variable); + auto c = alloc(v->dtype(), variable); // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) return c->accept_mutator(this); } @@ -815,7 +823,7 @@ Expr* PolynomialTransformer::mutate(Mul* v) { // dont want to combine ops. if (lhs_new->dtype().is_floating_point() || rhs_new->dtype().is_floating_point()) { - return new Mul(lhs_new, rhs_new); + return alloc(lhs_new, rhs_new); } // Handle special case mul by 0. @@ -824,26 +832,26 @@ Expr* PolynomialTransformer::mutate(Mul* v) { } // Catch cases of rounding (Div(A/B) * B). - if (auto* ret = isRoundOff(lhs_new, rhs_new)) { + if (auto ret = isRoundOff(lhs_new, rhs_new)) { return ret; - } else if (auto* ret = isRoundOff(v->lhs(), v->rhs())) { + } else if (auto ret = isRoundOff(v->lhs(), v->rhs())) { // We can break the Round + Mod pattern via factorization of the Div, so // check whether it would have worked on the unsimplified tree. If so, we // need to simplify again. return ret->accept_mutator(this); } - Polynomial* lhsPoly = dynamic_cast(lhs_new); - Polynomial* rhsPoly = dynamic_cast(rhs_new); + PolynomialPtr lhsPoly = to(lhs_new); + PolynomialPtr rhsPoly = to(rhs_new); if (lhsPoly && rhsPoly) { // This expands to more terms that we can't generally fix without variable // factorization, it's more efficient to just leave these as Muls. 
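(A sketch of the Mul rules being ported in this hunk, same setup assumptions as the first sketch. Note that (x / 4) * 4 is only recognized internally as a RoundOff; on its own it expands back to a div/mul pair, and it pays off later when it can pair with a matching Mod.)

void mulIdentitiesAndRoundOff() {
  KernelScope kernel_scope;
  VarHandle x("x", kInt);
  std::cout << IRSimplifier::simplify(x * ExprHandle(0)) << "\n"; // x * 0 => 0
  std::cout << IRSimplifier::simplify(x * ExprHandle(1)) << "\n"; // x * 1 => x
  // Recognized as RoundOff(x, 4) internally, printed back as a div/mul pair.
  std::cout << IRSimplifier::simplify((x / ExprHandle(4)) * ExprHandle(4)) << "\n";
}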
- return new Mul(lhsPoly, rhsPoly); + return alloc(lhsPoly, rhsPoly); } - Term* lhsTerm = dynamic_cast(lhs_new); - Term* rhsTerm = dynamic_cast(rhs_new); + TermPtr lhsTerm = to(lhs_new); + TermPtr rhsTerm = to(rhs_new); if (lhsPoly && rhsTerm) { return polyByTerm(lhsPoly, rhsTerm); @@ -858,39 +866,39 @@ Expr* PolynomialTransformer::mutate(Mul* v) { } if (scalar && lhsTerm) { - Expr* newScalar = evaluateOp(new Mul(scalar, lhsTerm->scalar())); - return new Term(hasher_, newScalar, lhsTerm->variables()); + ExprPtr newScalar = evaluateOp(alloc(scalar, lhsTerm->scalar())); + return alloc(hasher_, newScalar, lhsTerm->variables()); } if (scalar && rhsTerm) { - Expr* newScalar = evaluateOp(new Mul(scalar, rhsTerm->scalar())); - return new Term(hasher_, newScalar, rhsTerm->variables()); + ExprPtr newScalar = evaluateOp(alloc(scalar, rhsTerm->scalar())); + return alloc(hasher_, newScalar, rhsTerm->variables()); } // If this is a scalar * a Polynomial, push the scalar term down. // We can wrap the scalar with a Term and use polyByTerm. if (scalar && lhsPoly) { - return polyByTerm(lhsPoly, new Term(hasher_, scalar)); + return polyByTerm(lhsPoly, alloc(hasher_, scalar)); } if (scalar && rhsPoly) { - return polyByTerm(rhsPoly, new Term(hasher_, scalar)); + return polyByTerm(rhsPoly, alloc(hasher_, scalar)); } // simple term with a scalar and variable type. if (scalar) { - return new Term(hasher_, scalar, variable); + return alloc(hasher_, scalar, variable); } // Multiplying Polynomial by variable can be wrapped in a term and handled // by polyByTerm also. if (lhsPoly) { - auto* term = - new Term(hasher_, getImmediateByType(rhs_new->dtype(), 1), rhs_new); + auto term = + alloc(hasher_, getImmediateByType(rhs_new->dtype(), 1), rhs_new); return polyByTerm(lhsPoly, term); } if (rhsPoly) { - auto* term = - new Term(hasher_, getImmediateByType(lhs_new->dtype(), 1), lhs_new); + auto term = + alloc(hasher_, getImmediateByType(lhs_new->dtype(), 1), lhs_new); return polyByTerm(rhsPoly, term); } @@ -904,19 +912,20 @@ Expr* PolynomialTransformer::mutate(Mul* v) { } // Two variables, create a new Term. - return new Term(hasher_, getImmediateByType(v->dtype(), 1), lhs_new, rhs_new); + return alloc( + hasher_, getImmediateByType(v->dtype(), 1), lhs_new, rhs_new); } -Expr* factorizeDivision(Expr* lhs_new, Expr* rhs_new) { +ExprPtr factorizeDivision(ExprPtr lhs_new, ExprPtr rhs_new) { if (!lhs_new || !rhs_new) { return nullptr; } - Expr* leftScalar = lhs_new->isConstant() ? lhs_new : nullptr; - Expr* rightScalar = rhs_new->isConstant() ? rhs_new : nullptr; + ExprPtr leftScalar = lhs_new->isConstant() ? lhs_new : nullptr; + ExprPtr rightScalar = rhs_new->isConstant() ? rhs_new : nullptr; - auto* lhsTerm = dynamic_cast(lhs_new); - auto* rhsTerm = dynamic_cast(rhs_new); + auto lhsTerm = to(lhs_new); + auto rhsTerm = to(rhs_new); if (lhsTerm) { leftScalar = lhsTerm->scalar(); } @@ -938,39 +947,39 @@ Expr* factorizeDivision(Expr* lhs_new, Expr* rhs_new) { } leftScalar = evaluateOp( - new Div(leftScalar, getImmediateByType(leftScalar->dtype(), GCD))); + alloc
<Div>(leftScalar, getImmediateByType(leftScalar->dtype(), GCD))); rightScalar = evaluateOp( - new Div(rightScalar, getImmediateByType(rightScalar->dtype(), GCD))); + alloc<Div>
(rightScalar, getImmediateByType(rightScalar->dtype(), GCD))); if (lhsTerm) { - lhs_new = new Term(lhsTerm->hasher(), leftScalar, lhsTerm->variables()); + lhs_new = alloc(lhsTerm->hasher(), leftScalar, lhsTerm->variables()); } else { lhs_new = leftScalar; } if (rhsTerm) { - rhs_new = new Term(rhsTerm->hasher(), rightScalar, rhsTerm->variables()); + rhs_new = alloc(rhsTerm->hasher(), rightScalar, rhsTerm->variables()); } else { rhs_new = rightScalar; } - return new Div(lhs_new, rhs_new); + return alloc
(lhs_new, rhs_new); } -Expr* PolynomialTransformer::mutate(Div* v) { - Expr* lhs_new = v->lhs()->accept_mutator(this); - Expr* rhs_new = v->rhs()->accept_mutator(this); +ExprPtr PolynomialTransformer::mutate(DivPtr v) { + ExprPtr lhs_new = v->lhs()->accept_mutator(this); + ExprPtr rhs_new = v->rhs()->accept_mutator(this); // Constant Folding. if (lhs_new->isConstant() && rhs_new->isConstant()) { - return evaluateOp(new Div(lhs_new, rhs_new)); + return evaluateOp(alloc
(lhs_new, rhs_new)); } // If this is a floating point Div then order of operations is important, we // dont want to combine ops. if (lhs_new->dtype().is_floating_point() || rhs_new->dtype().is_floating_point()) { - return new Div(lhs_new, rhs_new); + return alloc
(lhs_new, rhs_new); } // If the numerator is zero, so is the result. @@ -995,16 +1004,16 @@ Expr* PolynomialTransformer::mutate(Div* v) { return ret->accept_mutator(this); } - return new Div(lhs_new, rhs_new); + return alloc
(lhs_new, rhs_new); } -Expr* PolynomialTransformer::mutate(Mod* v) { - Expr* lhs_new = v->lhs()->accept_mutator(this); - Expr* rhs_new = v->rhs()->accept_mutator(this); +ExprPtr PolynomialTransformer::mutate(ModPtr v) { + ExprPtr lhs_new = v->lhs()->accept_mutator(this); + ExprPtr rhs_new = v->rhs()->accept_mutator(this); // Constant Folding. if (lhs_new->isConstant() && rhs_new->isConstant()) { - return evaluateOp(new Mod(lhs_new, rhs_new)); + return evaluateOp(alloc(lhs_new, rhs_new)); } // 0 % x => 0. @@ -1024,9 +1033,9 @@ Expr* PolynomialTransformer::mutate(Mod* v) { return getImmediateByType(v->dtype(), 0); } - Term* lhsTerm = dynamic_cast(lhs_new); + TermPtr lhsTerm = to(lhs_new); if (!lhsTerm) { - Polynomial* lhsPoly = dynamic_cast(lhs_new); + PolynomialPtr lhsPoly = to(lhs_new); if (lhsPoly) { // Can still optimize this out if we can factorize the polynomial. lhsTerm = factorizePolynomial(lhsPoly); @@ -1036,12 +1045,13 @@ Expr* PolynomialTransformer::mutate(Mod* v) { if (lhsTerm) { // ((C1 * C2) * x) % C1 => 0. if (rhs_new->isConstant() && - immediateEquals(evaluateOp(new Mod(lhsTerm->scalar(), rhs_new)), 0)) { + immediateEquals( + evaluateOp(alloc(lhsTerm->scalar(), rhs_new)), 0)) { return getImmediateByType(v->dtype(), 0); } // (x * y * z) % x => 0. - for (auto* component : lhsTerm->variables()) { + for (auto component : lhsTerm->variables()) { if (hasher_.hash(component) == hasher_.hash(rhs_new)) { return getImmediateByType(v->dtype(), 0); } @@ -1051,7 +1061,7 @@ Expr* PolynomialTransformer::mutate(Mod* v) { // also, (x * y * z) % (z * y) => 0. // This requires all variable terms found in the RHS to be present in the // LHS. - Term* rhsTerm = dynamic_cast(rhs_new); + TermPtr rhsTerm = to(rhs_new); if (rhsTerm) { auto& lVars = lhsTerm->variables(); auto& rVars = rhsTerm->variables(); @@ -1075,13 +1085,14 @@ Expr* PolynomialTransformer::mutate(Mod* v) { if (rLeft == 0 && immediateEquals( - evaluateOp(new Mod(lhsTerm->scalar(), rhsTerm->scalar())), 0)) { + evaluateOp(alloc(lhsTerm->scalar(), rhsTerm->scalar())), + 0)) { return getImmediateByType(v->dtype(), 0); } } } - return new Mod(lhs_new, rhs_new); + return alloc(lhs_new, rhs_new); } namespace { @@ -1090,14 +1101,14 @@ namespace { // The first type on the template refers to the op, as in Min or Max and the // second type refers to the corresponding term, as in MinTerm or MaxTerm. 
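(The Mod rules handled above, in test form; same setup assumptions as the first sketch, and the expected values in the comments are existing simplifier behavior rather than something this diff adds.)

void modIdentities() {
  KernelScope kernel_scope;
  VarHandle x("x", kInt), y("y", kInt);
  std::cout << IRSimplifier::simplify(ExprHandle(0) % x) << "\n"; // 0 % x => 0
  std::cout << IRSimplifier::simplify(x % ExprHandle(1)) << "\n"; // x % 1 => 0
  std::cout << IRSimplifier::simplify(x % x) << "\n";             // x % x => 0
  // ((C1 * C2) * x) % C1 => 0
  std::cout << IRSimplifier::simplify((x * ExprHandle(6)) % ExprHandle(3)) << "\n";
  // (x * y) % y => 0
  std::cout << IRSimplifier::simplify((x * y) % y) << "\n";
}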
template -Expr* combineMinMaxTerms( - Expr* lhs, - Expr* rhs, +ExprPtr combineMinMaxTerms( + ExprPtr lhs, + ExprPtr rhs, bool propagate_nans, HashProvider& hasher) { - auto combine_scalars = [&](Expr* c1, Expr* c2) -> Expr* { + auto combine_scalars = [&](ExprPtr c1, ExprPtr c2) -> ExprPtr { if (c1 && c2) { - return evaluateOp(new Op(c1, c2, propagate_nans)); + return evaluateOp(alloc(c1, c2, propagate_nans)); } if (c1) { return c1; @@ -1105,21 +1116,21 @@ Expr* combineMinMaxTerms( return c2; }; - auto combine_opterms = [&](OpTerm* m1, OpTerm* m2) { - Expr* scalar = combine_scalars(m1->scalar(), m2->scalar()); - std::vector variables; + auto combine_opterms = [&](NodePtr m1, NodePtr m2) { + ExprPtr scalar = combine_scalars(m1->scalar(), m2->scalar()); + std::vector variables; for (auto v : m1->variables()) { variables.push_back(v); } for (auto v : m2->variables()) { variables.push_back(v); } - return new OpTerm(hasher, scalar, propagate_nans, std::move(variables)); + return alloc(hasher, scalar, propagate_nans, std::move(variables)); }; - auto add_expr_to_opterm = [&](Expr* expr, OpTerm* opterm) { - Expr* scalar = nullptr; - std::vector variables; + auto add_expr_to_opterm = [&](ExprPtr expr, NodePtr opterm) { + ExprPtr scalar = nullptr; + std::vector variables; if (opterm) { scalar = opterm->scalar(); variables = opterm->variables(); @@ -1130,16 +1141,16 @@ Expr* combineMinMaxTerms( } else { variables.push_back(expr); } - return new OpTerm(hasher, scalar, propagate_nans, std::move(variables)); + return alloc(hasher, scalar, propagate_nans, std::move(variables)); }; - OpTerm* lhs_opterm = dynamic_cast(lhs); - OpTerm* rhs_opterm = dynamic_cast(rhs); + auto lhs_opterm = to(lhs); + auto rhs_opterm = to(rhs); if (lhs_opterm && lhs_opterm->propagate_nans() != propagate_nans) { - return new Op(lhs, rhs, propagate_nans); + return alloc(lhs, rhs, propagate_nans); } if (rhs_opterm && rhs_opterm->propagate_nans() != propagate_nans) { - return new Op(lhs, rhs, propagate_nans); + return alloc(lhs, rhs, propagate_nans); } if (lhs_opterm && rhs_opterm) { @@ -1156,10 +1167,10 @@ Expr* combineMinMaxTerms( // the other op of opterm in other_op. template bool isOperandInMinMaxTerm( - OpTerm* opterm, - Expr* op, + NodePtr opterm, + ExprPtr op, HashProvider& hasher, - Expr** other_op) { + ExprPtr* other_op) { if (opterm->variables().size() != 2) { return false; } @@ -1189,13 +1200,13 @@ bool isOperandInMinMaxTerm( // type corresponding to the expected inner op (e.g. MinTerm). 
template bool simplifyNestedMinMax( - Expr* lhs, - Expr* rhs, + ExprPtr lhs, + ExprPtr rhs, bool propagate_nans, HashProvider& hasher, - Expr** new_op) { - auto lhs_opterm = dynamic_cast(lhs); - auto rhs_opterm = dynamic_cast(rhs); + ExprPtr* new_op) { + auto lhs_opterm = to(lhs); + auto rhs_opterm = to(rhs); if (lhs_opterm && rhs_opterm && lhs_opterm->propagate_nans() == propagate_nans && rhs_opterm->propagate_nans() == propagate_nans) { @@ -1205,20 +1216,20 @@ bool simplifyNestedMinMax( auto rhs_v1 = rhs_opterm->variables()[0]; auto rhs_v2 = rhs_opterm->variables()[1]; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - Expr* new_op_lhs; + ExprPtr new_op_lhs; if (isOperandInMinMaxTerm( lhs_opterm, rhs_v1, hasher, &new_op_lhs)) { - auto inner_op = - new OpTerm(hasher, nullptr, propagate_nans, new_op_lhs, rhs_v2); - *new_op = new OtherOpTerm( + auto inner_op = alloc( + hasher, nullptr, propagate_nans, new_op_lhs, rhs_v2); + *new_op = alloc( hasher, nullptr, propagate_nans, rhs_v1, inner_op); return true; } if (isOperandInMinMaxTerm( lhs_opterm, rhs_v2, hasher, &new_op_lhs)) { - auto inner_op = - new OpTerm(hasher, nullptr, propagate_nans, new_op_lhs, rhs_v1); - *new_op = new OtherOpTerm( + auto inner_op = alloc( + hasher, nullptr, propagate_nans, new_op_lhs, rhs_v1); + *new_op = alloc( hasher, nullptr, propagate_nans, rhs_v2, inner_op); return true; } @@ -1230,17 +1241,17 @@ bool simplifyNestedMinMax( } // namespace -Expr* PolynomialTransformer::mutate(Max* v) { - Expr* lhs_new = v->lhs()->accept_mutator(this); - Expr* rhs_new = v->rhs()->accept_mutator(this); +ExprPtr PolynomialTransformer::mutate(MaxPtr v) { + ExprPtr lhs_new = v->lhs()->accept_mutator(this); + ExprPtr rhs_new = v->rhs()->accept_mutator(this); // Constant Folding. if (lhs_new->isConstant() && rhs_new->isConstant()) { - return evaluateOp(new Max(lhs_new, rhs_new, v->propagate_nans())); + return evaluateOp(alloc(lhs_new, rhs_new, v->propagate_nans())); } // If diff is constant, return the appropriate operand. - Expr* diff = new Sub(lhs_new, rhs_new); + ExprPtr diff = alloc(lhs_new, rhs_new); diff = diff->accept_mutator(this); if (diff->isConstant()) { if (immediateAs(diff) > 0) { @@ -1251,7 +1262,7 @@ Expr* PolynomialTransformer::mutate(Max* v) { // Max(Min(x, y), Min(x, z)) => Min(x, Max(y, z)) // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - Expr* new_op; + ExprPtr new_op; if (simplifyNestedMinMax( lhs_new, rhs_new, v->propagate_nans(), hasher_, &new_op)) { return new_op; @@ -1261,17 +1272,17 @@ Expr* PolynomialTransformer::mutate(Max* v) { lhs_new, rhs_new, v->propagate_nans(), hasher_); } -Expr* PolynomialTransformer::mutate(Min* v) { - Expr* lhs_new = v->lhs()->accept_mutator(this); - Expr* rhs_new = v->rhs()->accept_mutator(this); +ExprPtr PolynomialTransformer::mutate(MinPtr v) { + ExprPtr lhs_new = v->lhs()->accept_mutator(this); + ExprPtr rhs_new = v->rhs()->accept_mutator(this); // Constant Folding. if (lhs_new->isConstant() && rhs_new->isConstant()) { - return evaluateOp(new Min(lhs_new, rhs_new, v->propagate_nans())); + return evaluateOp(alloc(lhs_new, rhs_new, v->propagate_nans())); } // If diff is constant, return the appropriate operand. 
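(A sketch of the nested min/max rewrite done by simplifyNestedMinMax and of the constant-difference rule used in mutate(Max)/mutate(Min); same setup assumptions as the first sketch.)

void minMaxSimplifications() {
  KernelScope kernel_scope;
  VarHandle x("x", kInt), y("y", kInt), z("z", kInt);
  // Max(Min(x, y), Min(x, z)) => Min(x, Max(y, z)); propagate_nans is preserved.
  ExprHandle nested = Max::make(Min::make(x, y, true), Min::make(x, z, true), true);
  std::cout << IRSimplifier::simplify(nested) << "\n";
  // When the operands differ by a constant the winner is known: Min(x, x + 5) => x.
  std::cout << IRSimplifier::simplify(Min::make(x, x + ExprHandle(5), false)) << "\n";
}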
- Expr* diff = new Sub(lhs_new, rhs_new); + ExprPtr diff = alloc(lhs_new, rhs_new); diff = diff->accept_mutator(this); if (diff->isConstant()) { if (immediateAs(diff) < 0) { @@ -1282,7 +1293,7 @@ Expr* PolynomialTransformer::mutate(Min* v) { // Min(Max(x, y), Max(x, z)) => Max(x, Min(y, z)) // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - Expr* new_op; + ExprPtr new_op; if (simplifyNestedMinMax( lhs_new, rhs_new, v->propagate_nans(), hasher_, &new_op)) { return new_op; @@ -1292,15 +1303,15 @@ Expr* PolynomialTransformer::mutate(Min* v) { lhs_new, rhs_new, v->propagate_nans(), hasher_); } -Expr* PolynomialTransformer::mutate(CompareSelect* v) { - Expr* lhs_new = v->lhs()->accept_mutator(this); - Expr* rhs_new = v->rhs()->accept_mutator(this); - Expr* true_branch = v->ret_val1()->accept_mutator(this); - Expr* false_branch = v->ret_val2()->accept_mutator(this); +ExprPtr PolynomialTransformer::mutate(CompareSelectPtr v) { + ExprPtr lhs_new = v->lhs()->accept_mutator(this); + ExprPtr rhs_new = v->rhs()->accept_mutator(this); + ExprPtr true_branch = v->ret_val1()->accept_mutator(this); + ExprPtr false_branch = v->ret_val2()->accept_mutator(this); // Constant Folding. if (lhs_new->isConstant() && rhs_new->isConstant()) { - Expr* v_new = new CompareSelect( + ExprPtr v_new = alloc( lhs_new, rhs_new, true_branch, @@ -1314,7 +1325,7 @@ Expr* PolynomialTransformer::mutate(CompareSelect* v) { // since we can't correctly handle NaN. if (lhs_new->dtype().is_floating_point() || rhs_new->dtype().is_floating_point()) { - return new CompareSelect( + return alloc( lhs_new, rhs_new, true_branch, @@ -1324,12 +1335,12 @@ Expr* PolynomialTransformer::mutate(CompareSelect* v) { } // If diff is constant, we can determine it. - Expr* diff = new Sub(rhs_new, lhs_new); + ExprPtr diff = alloc(rhs_new, lhs_new); diff = diff->accept_mutator(this); // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) if (!diff->isConstant()) { - return new CompareSelect( + return alloc( lhs_new, rhs_new, true_branch, @@ -1358,7 +1369,7 @@ Expr* PolynomialTransformer::mutate(CompareSelect* v) { } // should not be possible but just in case. - return new CompareSelect( + return alloc( lhs_new, rhs_new, true_branch, @@ -1367,21 +1378,21 @@ Expr* PolynomialTransformer::mutate(CompareSelect* v) { v->bias()); } -Expr* PolynomialTransformer::mutate(Intrinsics* v) { - std::vector new_params; +ExprPtr PolynomialTransformer::mutate(IntrinsicsPtr v) { + std::vector new_params; bool changed = false; bool allConstant = true; - for (auto* p : v->params()) { - Expr* new_child = p->accept_mutator(this); + for (auto p : v->params()) { + ExprPtr new_child = p->accept_mutator(this); new_params.push_back(new_child); changed |= p != new_child; allConstant &= new_child->isConstant(); } - Expr* node = v; + ExprPtr node = v; if (changed) { - node = new Intrinsics(v->op_type(), new_params); + node = alloc(v->op_type(), new_params); } if (!allConstant || !v->isPure()) { @@ -1389,44 +1400,44 @@ Expr* PolynomialTransformer::mutate(Intrinsics* v) { } // we're evaluating, but the evaluator only supports float intrinsics. 
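(The same constant-difference idea drives the CompareSelect mutation above, for integer operands only, since floating point is left untouched because of NaN; a sketch with the same setup assumptions as the first sketch.)

void compareSelectConstantDifference() {
  KernelScope kernel_scope;
  VarHandle x("x", kInt);
  // (x + 10) > x has a constant difference, so this folds to the default
  // CompareSelect return value 1.
  std::cout << IRSimplifier::simplify(CompareSelect::make(x + ExprHandle(10), x, kGT)) << "\n";
}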
- std::vector const_params; + std::vector const_params; changed = false; - for (auto* p : new_params) { + for (auto p : new_params) { if (p->dtype().scalar_type() == ScalarType::Float) { const_params.push_back(p); } else { const_params.push_back( - new Cast(Dtype(ScalarType::Float, p->dtype().lanes()), p)); + alloc(Dtype(ScalarType::Float, p->dtype().lanes()), p)); changed = true; } } if (changed) { - node = new Intrinsics(v->op_type(), const_params); + node = alloc(v->op_type(), const_params); } return evaluateOp(node); } -Expr* PolynomialTransformer::mutate(Cast* v) { - Expr* node = v->src_value()->accept_mutator(this); +ExprPtr PolynomialTransformer::mutate(CastPtr v) { + ExprPtr node = v->src_value()->accept_mutator(this); if (node->isConstant()) { - return evaluateOp(new Cast(v->dtype(), node)); + return evaluateOp(alloc(v->dtype(), node)); } if (v->dtype() == node->dtype()) { return node; } - return new Cast(v->dtype(), node); + return alloc(v->dtype(), node); } -Expr* PolynomialTransformer::mutate(IfThenElse* v) { - Expr* condition = v->condition(); - Expr* true_value = v->true_value(); - Expr* false_value = v->false_value(); - Expr* condition_new = condition->accept_mutator(this); - Expr* true_value_new = true_value->accept_mutator(this); - Expr* false_value_new = false_value->accept_mutator(this); +ExprPtr PolynomialTransformer::mutate(IfThenElsePtr v) { + ExprPtr condition = v->condition(); + ExprPtr true_value = v->true_value(); + ExprPtr false_value = v->false_value(); + ExprPtr condition_new = condition->accept_mutator(this); + ExprPtr true_value_new = true_value->accept_mutator(this); + ExprPtr false_value_new = false_value->accept_mutator(this); // If the condition is constant then we can choose the right branch now. if (condition_new->isConstant()) { @@ -1447,17 +1458,17 @@ Expr* PolynomialTransformer::mutate(IfThenElse* v) { return v; } - return new IfThenElse(condition_new, true_value_new, false_value_new); + return alloc(condition_new, true_value_new, false_value_new); } -Stmt* PolynomialBase::mutate(Cond* v) { - Expr* cond_old = v->condition(); - Stmt* true_old = v->true_stmt(); - Stmt* false_old = v->false_stmt(); +StmtPtr PolynomialBase::mutate(CondPtr v) { + ExprPtr cond_old = v->condition(); + StmtPtr true_old = v->true_stmt(); + StmtPtr false_old = v->false_stmt(); - Expr* cond_new = cond_old->accept_mutator(this); - Stmt* true_new = true_old ? true_old->accept_mutator(this) : true_old; - Stmt* false_new = false_old ? false_old->accept_mutator(this) : false_old; + ExprPtr cond_new = cond_old->accept_mutator(this); + StmtPtr true_new = true_old ? true_old->accept_mutator(this) : true_old; + StmtPtr false_new = false_old ? false_old->accept_mutator(this) : false_old; // If the condition is constant then we can choose the right branch now. 
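(Both the expression-level IfThenElse above and the statement-level Cond being mutated here fold a constant condition into the surviving branch; a sketch for the expression form, same setup assumptions as the first sketch.)

void ifThenElseConstantCondition() {
  KernelScope kernel_scope;
  VarHandle x("x", kInt);
  // A constant condition selects the matching branch: this folds to x + 1.
  ExprHandle ite = IfThenElse::make(ExprHandle(1), x + ExprHandle(1), x + ExprHandle(2));
  std::cout << IRSimplifier::simplify(ite) << "\n";
}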
if (cond_new->isConstant()) { @@ -1476,13 +1487,13 @@ Stmt* PolynomialBase::mutate(Cond* v) { return true_new; } - Block* true_block = dynamic_cast(true_new); - Block* false_block = dynamic_cast(false_new); + BlockPtr true_block = to(true_new); + BlockPtr false_block = to(false_new); bool true_empty = !true_new || (true_block && true_block->nstmts() == 0); bool false_empty = !false_new || (false_block && false_block->nstmts() == 0); if (true_empty && false_empty) { - return new Block({}); + return alloc(std::vector({})); } if (cond_old != cond_new) { v->set_condition(cond_new); @@ -1496,13 +1507,13 @@ Stmt* PolynomialBase::mutate(Cond* v) { return v; } -Stmt* handleForCondReordering(For* loop, Cond* cond) { +StmtPtr handleForCondReordering(ForPtr loop, CondPtr cond) { if (cond->false_stmt()) { return nullptr; } auto condition_vars = VarFinder::find(cond->condition()); - for (auto* v : condition_vars) { + for (auto v : condition_vars) { // If the condition depends on a Var that is modified in the loop body, it // may not be safe to reorder. if (ModifiesVarChecker::check(loop, v)) { @@ -1510,27 +1521,27 @@ Stmt* handleForCondReordering(For* loop, Cond* cond) { } } - For* new_f = loop->cloneWithNewBody(Stmt::clone(cond->true_stmt())); + ForPtr new_f = loop->cloneWithNewBody(Stmt::clone(cond->true_stmt())); return cond->cloneWithNewBody(new_f); } -Stmt* PolynomialBase::mutate(For* v) { - Expr* var = v->var(); - Expr* start = v->start(); - Expr* stop = v->stop(); - Stmt* body = v->body(); +StmtPtr PolynomialBase::mutate(ForPtr v) { + ExprPtr var = v->var(); + ExprPtr start = v->start(); + ExprPtr stop = v->stop(); + StmtPtr body = v->body(); LoopOptions loop_options = v->loop_options(); - Expr* var_new_expr = var->accept_mutator(this); - Var* var_new = dynamic_cast(var_new_expr); - Expr* start_new = start->accept_mutator(this); - Expr* stop_new = stop->accept_mutator(this); - Stmt* body_new = body; + ExprPtr var_new_expr = var->accept_mutator(this); + VarPtr var_new = to(var_new_expr); + ExprPtr start_new = start->accept_mutator(this); + ExprPtr stop_new = stop->accept_mutator(this); + StmtPtr body_new = body; - Expr* loops = new Sub(stop_new, start_new); + ExprPtr loops = alloc(stop_new, start_new); loops = loops->accept_mutator(this); if (loop_options.isDefault() && loops->isConstant()) { if (immediateEquals(loops, 0)) { - return new Block({}); + return alloc(std::vector({})); } else if (immediateEquals(loops, 1)) { body_new = Substitute(body, {{var_new, start_new}}); body_new = body_new->accept_mutator(this); @@ -1540,17 +1551,17 @@ Stmt* PolynomialBase::mutate(For* v) { body_new = body_new->accept_mutator(this); if (!body_new) { - return new Block({}); + return alloc(std::vector({})); } - if (auto* block = dynamic_cast(body_new)) { + if (auto block = to(body_new)) { if (block->nstmts() == 0) { - return new Block({}); + return alloc(std::vector({})); } if (block->nstmts() == 1) { - if (auto* cond = dynamic_cast(block->front())) { - Stmt* reordered = handleForCondReordering(v, cond); + if (auto cond = to(block->front())) { + StmtPtr reordered = handleForCondReordering(v, cond); if (reordered) { return reordered->accept_mutator(this); } @@ -1573,22 +1584,22 @@ Stmt* PolynomialBase::mutate(For* v) { return v; } -Stmt* PolynomialBase::mutate(Block* v) { - std::vector stmts; +StmtPtr PolynomialBase::mutate(BlockPtr v) { + std::vector stmts; // Flatten sub-blocks: bool stmts_changed = false; - for (Stmt* stmt : *v) { - Stmt* stmt_new = stmt->accept_mutator(this); + for (StmtPtr stmt : *v) { + 
StmtPtr stmt_new = stmt->accept_mutator(this); stmts_changed |= stmt != stmt_new; if (stmt_new == nullptr) { continue; } - if (auto* subBlock = dynamic_cast(stmt_new)) { + if (auto subBlock = to(stmt_new)) { for (Block::iterator I = subBlock->begin(), E = subBlock->end(); I != E;) { // Be careful to avoid invalidating the iterator. - Stmt* s = *(I++); + StmtPtr s = *(I++); subBlock->remove_stmt(s); stmts.push_back(s); } @@ -1605,20 +1616,20 @@ Stmt* PolynomialBase::mutate(Block* v) { // TermExpander -Expr* TermExpander::mutate(Term* v) { - Expr* newScalar = v->scalar()->accept_mutator(this); +ExprPtr TermExpander::mutate(TermPtr v) { + ExprPtr newScalar = v->scalar()->accept_mutator(this); if (immediateEquals(newScalar, 0)) { return newScalar; } - std::vector vars; - std::vector multilaneVars; + std::vector vars; + std::vector multilaneVars; // Assume we can reorder here because we wont merge floating terms. - Expr* lastNode{nullptr}; - for (auto* var : v->variables()) { - Expr* node = var->accept_mutator(this); - if (Mul* mul = dynamic_cast(node)) { + ExprPtr lastNode{nullptr}; + for (auto var : v->variables()) { + ExprPtr node = var->accept_mutator(this); + if (MulPtr mul = to(node)) { // If the sub-Expr resolved to a multiplication, lift it into this // term. if (isMultilanePrimitive(mul->lhs())) { @@ -1641,7 +1652,7 @@ Expr* TermExpander::mutate(Term* v) { } } - for (auto* node : multilaneVars) { + for (auto node : multilaneVars) { if (lastNode == nullptr) { lastNode = node; } else { @@ -1652,11 +1663,11 @@ Expr* TermExpander::mutate(Term* v) { } } - for (auto* node : vars) { + for (auto node : vars) { if (lastNode == nullptr) { lastNode = node; } else { - lastNode = new Mul(lastNode, node); + lastNode = alloc(lastNode, node); } } @@ -1667,22 +1678,22 @@ Expr* TermExpander::mutate(Term* v) { auto termDtype = v->scalar()->dtype(); auto lastNodeDtype = lastNode->dtype(); if (termDtype != lastNodeDtype) { - Expr* castV = v->scalar(); + ExprPtr castV = v->scalar(); // Take care of lane mismatch first. if (termDtype.lanes() != lastNodeDtype.lanes()) { - castV = new Broadcast(v->scalar(), lastNodeDtype.lanes()); + castV = alloc(v->scalar(), lastNodeDtype.lanes()); } // Now take care of scalar type as well. if (termDtype.scalar_type() != lastNodeDtype.scalar_type()) { - castV = new Cast(lastNode->dtype(), castV); + castV = alloc(lastNode->dtype(), castV); // For scalars, we can simplify the cast further. if (lastNodeDtype.lanes() == 1) { castV = evaluateOp(castV); } } - lastNode = new Mul(castV, lastNode); + lastNode = alloc(castV, lastNode); } else { - lastNode = new Mul(v->scalar(), lastNode); + lastNode = alloc(v->scalar(), lastNode); } } else { lastNode = v->scalar(); @@ -1695,15 +1706,15 @@ Expr* TermExpander::mutate(Term* v) { // Returns an immediate containing the greatest common divisor of all terms // (inc. the scalar term) in the polynomial. If the GCD is uninteresting // (e.g. 1) then returns nullptr. -Expr* polyGCD(Polynomial* poly) { - Expr* scalar = poly->scalar(); - const std::vector& variables = poly->variables(); +ExprPtr polyGCD(PolynomialPtr poly) { + ExprPtr scalar = poly->scalar(); + const std::vector& variables = poly->variables(); // We ony want to factorize if we're saving complete operations, i.e. no // value in factorizing 6x + 4y into 2 * (3x + 2y) since we don't save work. int opsSaved = 1; // default to saving the scalar. 
long GCD = std::abs(immediateAs(scalar)); - for (auto* t : variables) { + for (auto t : variables) { long termScalar = std::abs(immediateAs(t->scalar())); long newGCD = gcd(std::max(GCD, termScalar), std::min(GCD, termScalar)); if (newGCD == 1) { @@ -1742,34 +1753,34 @@ Expr* polyGCD(Polynomial* poly) { // denotes x, 'divisor' denotes y and 'mod_divisor' denotes z. class ModRound { public: - ModRound(Expr* scalar, Expr* denom, Expr* divisor, Expr* mod_divisor) + ModRound(ExprPtr scalar, ExprPtr denom, ExprPtr divisor, ExprPtr mod_divisor) : scalar(scalar), denom(denom), divisor(divisor), mod_divisor(mod_divisor) {} - Expr* scalar; - Expr* denom; - Expr* divisor; - Expr* mod_divisor; + ExprPtr scalar; + ExprPtr denom; + ExprPtr divisor; + ExprPtr mod_divisor; }; -c10::optional isModRound(Term* e) { - Div* div{nullptr}; - Mod* mod{nullptr}; - Expr* denom{nullptr}; - Expr* divisor{nullptr}; - Expr* mod_divisor{nullptr}; - Expr* multiplier = e->scalar(); - Expr* scalar{nullptr}; - Expr* other{nullptr}; - - for (auto* m : e->variables()) { +c10::optional isModRound(TermPtr e) { + DivPtr div{nullptr}; + ModPtr mod{nullptr}; + ExprPtr denom{nullptr}; + ExprPtr divisor{nullptr}; + ExprPtr mod_divisor{nullptr}; + ExprPtr multiplier = e->scalar(); + ExprPtr scalar{nullptr}; + ExprPtr other{nullptr}; + + for (auto m : e->variables()) { if (m->expr_type() == IRNodeType::kMod) { // TODO: currently only identify terms with one variable being mod; it is // possible to extend this if we have to handle terms like (t/(x%2 * y) % // z) * (x%2 *y). if (!mod) { - mod = dynamic_cast(m); + mod = to(m); } else { return c10::nullopt; } @@ -1778,11 +1789,11 @@ c10::optional isModRound(Term* e) { if (multiplier->isConstant()) { // Take care of lane mismatch first. if (multiplier->dtype().lanes() != m->dtype().lanes()) { - multiplier = new Broadcast(multiplier, m->dtype().lanes()); + multiplier = alloc(multiplier, m->dtype().lanes()); } // Take care of scalar type mismatch. if (multiplier->dtype().scalar_type() != m->dtype().scalar_type()) { - multiplier = new Cast(m->dtype(), multiplier); + multiplier = alloc(m->dtype(), multiplier); if (m->dtype().lanes() == 1) { multiplier = evaluateOp(multiplier); } @@ -1790,7 +1801,7 @@ c10::optional isModRound(Term* e) { } // All non-mod vairables are considered as part of the multiplier. - multiplier = new Mul(multiplier, m); + multiplier = alloc(multiplier, m); } } multiplier = IRSimplifier::simplify(multiplier); @@ -1803,7 +1814,7 @@ c10::optional isModRound(Term* e) { mod_divisor = IRSimplifier::simplify(mod->rhs()); other = mod->lhs(); - if (!(div = dynamic_cast(other))) { + if (!(div = to
(other))) { return c10::nullopt; } @@ -1825,18 +1836,19 @@ c10::optional isModRound(Term* e) { // transformations. if (divisor->isConstant() && multiplier->isConstant()) { // If both are scalar we may be able to find a common factor. - if (immediateEquals(evaluateOp(new Mod(multiplier, divisor)), 0)) { + if (immediateEquals(evaluateOp(alloc(multiplier, divisor)), 0)) { // The common factor becomes 'scalar' of the term, e.g.,in t/3%7*6, // divisor=multiplier=3, scalar=2. - Expr* c = evaluateOp(new Div(multiplier, divisor)); + ExprPtr c = evaluateOp(alloc
<Div>(multiplier, divisor)); scalar = c; - } else if (immediateEquals(evaluateOp(new Mod(divisor, multiplier)), 0)) { + } else if (immediateEquals( + evaluateOp(alloc<Mod>(divisor, multiplier)), 0)) { // The common factor becomes part of 'denom', e.g., in t/14%7*2, // divisor=multiplier=2, denom=t/7. - Expr* c = evaluateOp(new Div(divisor, multiplier)); + ExprPtr c = evaluateOp(alloc<Div>
<Div>(divisor, multiplier)); divisor = multiplier; // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - denom = IRSimplifier::simplify(new Div(other, c)); + denom = IRSimplifier::simplify(alloc<Div>
(other, c)); } else { return c10::nullopt; } @@ -1861,14 +1873,14 @@ c10::optional isModRound(Term* e) { // (1) Round + Mod pattern: (x/y) * y + x % y => RoundOff(x,y) + Mod(x, y) => x // (2) Mod round + Mod pattern: (x/y % z)*y + x%y => ModRound(x, y, z) + Mod(x, // y) => x % (y*z) -Expr* simplifyRoundModPattern(Polynomial* poly) { - std::vector rounds; - std::vector mods; - std::vector mod_rounds; - std::vector others; +ExprPtr simplifyRoundModPattern(PolynomialPtr poly) { + std::vector rounds; + std::vector mods; + std::vector mod_rounds; + std::vector others; // Split out the Mod, ModRounds and RoundOffs operations so we can inspect. - for (auto* c : poly->variables()) { + for (auto c : poly->variables()) { if (c->variables().size() > 1) { if (auto a = isModRound(c)) { mod_rounds.push_back(c); @@ -1878,9 +1890,9 @@ Expr* simplifyRoundModPattern(Polynomial* poly) { continue; } - Expr* e = c->variables()[0]; + ExprPtr e = c->variables()[0]; - if (dynamic_cast(e)) { + if (to(e)) { rounds.push_back(c); // NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage) } else if (e->expr_type() == IRNodeType::kMod) { @@ -1901,7 +1913,7 @@ Expr* simplifyRoundModPattern(Polynomial* poly) { HashProvider& hasher = poly->hasher(); bool didAnything = false; - std::vector mods_merged; + std::vector mods_merged; bool repeat = true; // Repeat merging terms till there are no Mods or the terms cannot be merged // any further. @@ -1909,15 +1921,15 @@ Expr* simplifyRoundModPattern(Polynomial* poly) { repeat = false; // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) for (int64_t i = mods.size() - 1; i >= 0; i--) { - Term* m = mods[i]; - Mod* mod = dynamic_cast(m->variables()[0]); + TermPtr m = mods[i]; + ModPtr mod = to(m->variables()[0]); CHECK(mod); - Expr* mod_lhs = IRSimplifier::simplify(mod->lhs()); - Expr* mod_rhs = IRSimplifier::simplify(mod->rhs()); + ExprPtr mod_lhs = IRSimplifier::simplify(mod->lhs()); + ExprPtr mod_rhs = IRSimplifier::simplify(mod->rhs()); bool merged = false; // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) for (int64_t j = mod_rounds.size() - 1; j >= 0; j--) { - Term* mr = mod_rounds[j]; + TermPtr mr = mod_rounds[j]; auto a = isModRound(mr); CHECK(a); ModRound* mod_round = dynamic_cast(*a); @@ -1926,7 +1938,7 @@ Expr* simplifyRoundModPattern(Polynomial* poly) { // optimization. E.g. 
it's possible to do: 2 * (x/y%z) * y + (x%y) => // x%(y*z) + (x/y%z) * y if (!immediateEquals( - evaluateOp(new Sub(mod_round->scalar, m->scalar())), 0)) { + evaluateOp(alloc(mod_round->scalar, m->scalar())), 0)) { continue; } // Valid optimization if mod LHS matches denom and mod RHS matches @@ -1934,12 +1946,12 @@ Expr* simplifyRoundModPattern(Polynomial* poly) { if (hasher.hash(mod_round->denom) == hasher.hash(mod_lhs) && hasher.hash(mod_round->divisor) == hasher.hash(mod_rhs)) { // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - Term* merged_m = new Term( + TermPtr merged_m = alloc( hasher, mod_round->scalar, - IRSimplifier::simplify(new Mod( + IRSimplifier::simplify(alloc( mod_round->denom, - new Mul(mod_round->divisor, mod_round->mod_divisor)))); + alloc(mod_round->divisor, mod_round->mod_divisor)))); mods_merged.push_back(merged_m); merged = true; repeat = true; @@ -1956,8 +1968,8 @@ Expr* simplifyRoundModPattern(Polynomial* poly) { // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) for (int64_t k = rounds.size() - 1; k >= 0; k--) { - Term* r = rounds[k]; - RoundOff* roundoff = dynamic_cast(r->variables()[0]); + TermPtr r = rounds[k]; + RoundOffPtr roundoff = to(r->variables()[0]); CHECK(roundoff); // TODO: for now don't attempt partial factorization of this @@ -1965,15 +1977,15 @@ Expr* simplifyRoundModPattern(Polynomial* poly) { // (x/y) * y but unsure thats actually much better, particulary with // CSE. if (!immediateEquals( - evaluateOp(new Sub(r->scalar(), m->scalar())), 0)) { + evaluateOp(alloc(r->scalar(), m->scalar())), 0)) { continue; } - Expr* round_lhs = IRSimplifier::simplify(roundoff->lhs()); - Expr* round_rhs = IRSimplifier::simplify(roundoff->rhs()); + ExprPtr round_lhs = IRSimplifier::simplify(roundoff->lhs()); + ExprPtr round_rhs = IRSimplifier::simplify(roundoff->rhs()); // Valid optimization if LHS and RHS are equal for both. if (hasher.hash(round_lhs) == hasher.hash(mod_lhs) && hasher.hash(round_rhs) == hasher.hash(mod_rhs)) { - Term* merged_r = new Term(hasher, r->scalar(), round_lhs); + TermPtr merged_r = alloc(hasher, r->scalar(), round_lhs); others.push_back(merged_r); merged = true; didAnything = true; @@ -2013,16 +2025,16 @@ Expr* simplifyRoundModPattern(Polynomial* poly) { others.insert(others.end(), rounds.begin(), rounds.end()); } - return new Polynomial(hasher, poly->scalar(), others); + return alloc(hasher, poly->scalar(), others); } // Trivially factorize terms by GCD of scalar components. -Term* PolynomialBase::factorizePolynomial(Polynomial* poly) { - Expr* scalar = poly->scalar(); - const std::vector& variables = poly->variables(); +TermPtr PolynomialBase::factorizePolynomial(PolynomialPtr poly) { + ExprPtr scalar = poly->scalar(); + const std::vector& variables = poly->variables(); // Compute the GCD of terms. - Expr* GCD = polyGCD(poly); + ExprPtr GCD = polyGCD(poly); // No GCD means 0 or 1 and can't be factored. if (!GCD) { @@ -2030,40 +2042,42 @@ Term* PolynomialBase::factorizePolynomial(Polynomial* poly) { } // Create new struture. - std::vector newPolyTerms; + std::vector newPolyTerms; newPolyTerms.reserve(variables.size()); - for (auto* t : variables) { + for (auto t : variables) { // New term with the scalar divided by the GCD. - newPolyTerms.push_back(new Term( - poly->hasher(), evaluateOp(new Div(t->scalar(), GCD)), t->variables())); + newPolyTerms.push_back(alloc( + poly->hasher(), + evaluateOp(alloc
<Div>(t->scalar(), GCD)), + t->variables())); } - Polynomial* newPoly = new Polynomial( - poly->hasher(), evaluateOp(new Div(scalar, GCD)), newPolyTerms); + PolynomialPtr newPoly = alloc<Polynomial>( + poly->hasher(), evaluateOp(alloc<Div>
(scalar, GCD)), newPolyTerms); - return new Term(poly->hasher(), GCD, newPoly); + return alloc(poly->hasher(), GCD, newPoly); } -Expr* TermExpander::mutate(Polynomial* v) { +ExprPtr TermExpander::mutate(PolynomialPtr v) { if (v->variables().empty()) { return v->scalar(); } // If this Polynomial can be factorized: do it, then expand the result. - if (Expr* simplified = simplifyRoundModPattern(v)) { + if (ExprPtr simplified = simplifyRoundModPattern(v)) { return simplified->accept_mutator(this); } // If this Polynomial can be factorized: do it, then expand the result. - if (Expr* factorized = factorizePolynomial(v)) { + if (ExprPtr factorized = factorizePolynomial(v)) { return factorized->accept_mutator(this); } - std::vector addTerms; - std::vector subTerms; + std::vector addTerms; + std::vector subTerms; // partition the terms into a list to add and list to subtract. - for (auto* node : v->variables()) { + for (auto node : v->variables()) { if (immediateIsNegative(node->scalar())) { subTerms.push_back(node); } else if (!immediateEquals(node->scalar(), 0)) { @@ -2073,10 +2087,10 @@ Expr* TermExpander::mutate(Polynomial* v) { } // The last node constructed. - Expr* lastNode{nullptr}; + ExprPtr lastNode{nullptr}; - for (auto* node : addTerms) { - Expr* simpleNode = node->accept_mutator(this); + for (auto node : addTerms) { + ExprPtr simpleNode = node->accept_mutator(this); if (lastNode == nullptr) { lastNode = simpleNode; @@ -2084,7 +2098,7 @@ Expr* TermExpander::mutate(Polynomial* v) { } if (isMultilanePrimitive(simpleNode)) { - auto* ret = combineMultilane(lastNode, simpleNode); + auto ret = combineMultilane(lastNode, simpleNode); if (ret) { // simplify result first, then expand. lastNode = ret->accept_mutator(simplifier_); @@ -2093,14 +2107,14 @@ Expr* TermExpander::mutate(Polynomial* v) { } } - lastNode = new Add(lastNode, simpleNode); + lastNode = alloc(lastNode, simpleNode); } // If we have no add terms the scalar should go first. // E.g. 1 - x. bool scalarWritten = false; if (lastNode == nullptr) { - auto* scalarNode = v->scalar()->accept_mutator(simplifier_); + auto scalarNode = v->scalar()->accept_mutator(simplifier_); if (!immediateEquals(scalarNode, 0)) { lastNode = scalarNode; @@ -2108,7 +2122,7 @@ Expr* TermExpander::mutate(Polynomial* v) { } } - for (auto* node : subTerms) { + for (auto node : subTerms) { // Can still be first node if scalarVal is 0. if (lastNode == nullptr) { lastNode = node->accept_mutator(this); @@ -2116,10 +2130,10 @@ Expr* TermExpander::mutate(Polynomial* v) { } // Negate the term back to positive since we'll be subtracting it. - Expr* negated = evaluateOp(new Mul( + ExprPtr negated = evaluateOp(alloc( getImmediateByType(node->scalar()->dtype(), -1), node->scalar())); - Term* newRHS = new Term(node->hasher(), negated, node->variables()); - lastNode = new Sub(lastNode, newRHS->accept_mutator(this)); + TermPtr newRHS = alloc(node->hasher(), negated, node->variables()); + lastNode = alloc(lastNode, newRHS->accept_mutator(this)); } if (scalarWritten || immediateEquals(v->scalar(), 0)) { @@ -2131,24 +2145,24 @@ Expr* TermExpander::mutate(Polynomial* v) { if (immediateIsNegative(v->scalar())) { // Negate the scalar and subtract. 
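(The round/mod patterns that simplifyRoundModPattern above targets, reached through this Polynomial expansion, in test form; same setup assumptions as the first sketch.)

void roundModPatterns() {
  KernelScope kernel_scope;
  VarHandle x("x", kInt);
  // Round + Mod: (x / 4) * 4 + x % 4 => x.
  ExprHandle roundMod = (x / ExprHandle(4)) * ExprHandle(4) + x % ExprHandle(4);
  std::cout << IRSimplifier::simplify(roundMod) << "\n";
  // ModRound + Mod: (x / 4 % 3) * 4 + x % 4 => x % 12.
  ExprHandle modRoundMod =
      (x / ExprHandle(4) % ExprHandle(3)) * ExprHandle(4) + x % ExprHandle(4);
  std::cout << IRSimplifier::simplify(modRoundMod) << "\n";
}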
- Expr* negated = evaluateOp( - new Mul(getImmediateByType(lastNode->dtype(), -1), v->scalar())); - lastNode = new Sub(lastNode, evaluateOp(negated)); + ExprPtr negated = evaluateOp( + alloc(getImmediateByType(lastNode->dtype(), -1), v->scalar())); + lastNode = alloc(lastNode, evaluateOp(negated)); } else { // we want to avoid a cast to the scalar if it would happen. // NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage) if (v->scalar()->dtype() != lastNode->dtype()) { - lastNode = new Add( - lastNode, evaluateOp(new Cast(lastNode->dtype(), v->scalar()))); + lastNode = alloc( + lastNode, evaluateOp(alloc(lastNode->dtype(), v->scalar()))); } else { - lastNode = new Add(lastNode, v->scalar()); + lastNode = alloc(lastNode, v->scalar()); } } return lastNode; } -Expr* TermExpander::mutate(MaxTerm* v) { +ExprPtr TermExpander::mutate(MaxTermPtr v) { auto& variables = v->variables(); if (variables.empty()) { if (!v->scalar()) { @@ -2159,19 +2173,19 @@ Expr* TermExpander::mutate(MaxTerm* v) { return v->scalar(); } // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - Expr* max; + ExprPtr max; if (v->scalar()) { - max = new Max(variables[0], v->scalar(), v->propagate_nans()); + max = alloc(variables[0], v->scalar(), v->propagate_nans()); } else { max = variables[0]; } for (size_t i = 1; i < variables.size(); i++) { - max = new Max(max, variables[i], v->propagate_nans()); + max = alloc(max, variables[i], v->propagate_nans()); } return max->accept_mutator(this); } -Expr* TermExpander::mutate(MinTerm* v) { +ExprPtr TermExpander::mutate(MinTermPtr v) { auto& variables = v->variables(); if (variables.empty()) { if (!v->scalar()) { @@ -2182,35 +2196,35 @@ Expr* TermExpander::mutate(MinTerm* v) { return v->scalar(); } // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - Expr* min; + ExprPtr min; if (v->scalar()) { - min = new Min(variables[0], v->scalar(), v->propagate_nans()); + min = alloc(variables[0], v->scalar(), v->propagate_nans()); } else { min = variables[0]; } for (size_t i = 1; i < variables.size(); i++) { - min = new Min(min, variables[i], v->propagate_nans()); + min = alloc(min, variables[i], v->propagate_nans()); } return min->accept_mutator(this); } // Expands RoundOff(x, y) => Term(1, Div(x, y), y), which will later be expanded // to Mul(Div(x, y), y). -Expr* TermExpander::mutate(RoundOff* v) { - Term* term = new Term( +ExprPtr TermExpander::mutate(RoundOffPtr v) { + TermPtr term = alloc( simplifier_->hasher(), getImmediateByType(v->dtype(), 1), - new Div(v->lhs(), v->rhs()), + alloc
(v->lhs(), v->rhs()), v->rhs()); return term->accept_mutator(this); } -Expr* buf_flat_size(Buf* v) { - std::vector dims = v->dims(); +ExprPtr buf_flat_size(BufPtr v) { + std::vector dims = v->dims(); - Expr* flattened = getImmediateByType(kInt, 1); + ExprPtr flattened = getImmediateByType(kInt, 1); for (auto& dim : dims) { - flattened = new Mul(flattened, dim); + flattened = alloc(flattened, dim); } flattened = IRSimplifier::simplify(flattened); @@ -2218,11 +2232,11 @@ Expr* buf_flat_size(Buf* v) { return flattened; } -Stmt* TermExpander::mutate(Allocate* v) { - Buf* buf = v->buf(); - Buf* buf_new = dynamic_cast(v->buf()->accept_mutator(this)); +StmtPtr TermExpander::mutate(AllocatePtr v) { + BufPtr buf = v->buf(); + BufPtr buf_new = to(v->buf()->accept_mutator(this)); TORCH_INTERNAL_ASSERT(buf_new); - Expr* flattened = buf_flat_size(buf_new); + ExprPtr flattened = buf_flat_size(buf_new); if (flattened->isConstant() && immediateEquals(flattened, 0)) { eliminated_allocations_.insert(buf_new->base_handle()); @@ -2235,9 +2249,9 @@ Stmt* TermExpander::mutate(Allocate* v) { return v; } -Stmt* TermExpander::mutate(Free* v) { - Buf* buf = v->buf(); - Buf* buf_new = dynamic_cast(v->buf()->accept_mutator(this)); +StmtPtr TermExpander::mutate(FreePtr v) { + BufPtr buf = v->buf(); + BufPtr buf_new = to(v->buf()->accept_mutator(this)); TORCH_INTERNAL_ASSERT(buf_new); if (eliminated_allocations_.count(buf_new->base_handle())) { @@ -2252,13 +2266,13 @@ Stmt* TermExpander::mutate(Free* v) { } // Combines adjactent Cond nodes with identical conditions. -Block* TermExpander::fuseConditions(Block* v) { - std::vector stmts; +BlockPtr TermExpander::fuseConditions(BlockPtr v) { + std::vector stmts; bool did_anything = false; - Cond* prev_cond = nullptr; + CondPtr prev_cond = nullptr; - for (auto* s : *v) { - Cond* cond = dynamic_cast(s); + for (auto s : *v) { + CondPtr cond = to(s); if (!cond) { prev_cond = nullptr; stmts.push_back(s); @@ -2277,8 +2291,8 @@ Block* TermExpander::fuseConditions(Block* v) { // Fuse the two Conds by appending the bodies of the second Cond to the // first. - Block* true_block = new Block({}); - Block* false_block = new Block({}); + BlockPtr true_block = alloc(std::vector({})); + BlockPtr false_block = alloc(std::vector({})); if (prev_cond->true_stmt()) { true_block->splice(true_block->end(), prev_cond->true_stmt()); @@ -2306,9 +2320,9 @@ Block* TermExpander::fuseConditions(Block* v) { } // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - Stmt* new_cond = prev_cond->cloneWithNewBodies(true_block, false_block) - ->accept_mutator(this); - prev_cond = dynamic_cast(new_cond); + StmtPtr new_cond = prev_cond->cloneWithNewBodies(true_block, false_block) + ->accept_mutator(this); + prev_cond = to(new_cond); // erase, which shortens the list. stmts.pop_back(); @@ -2321,24 +2335,24 @@ Block* TermExpander::fuseConditions(Block* v) { } // clean up parents. - for (auto* s : stmts) { + for (auto s : stmts) { if (s->get_parent() == v) { v->remove_stmt(s); } } - return new Block(stmts); + return alloc(stmts); } -Stmt* TermExpander::fuseSyncThreads(Block* block) { +StmtPtr TermExpander::fuseSyncThreads(BlockPtr block) { // only really first if highest level Block. 
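(A statement-level sketch of the Cond fusion implemented in fuseConditions above; BufHandle/Store/Cond/Block are the existing handle helpers, and the exact shape of the fused output is an assumption of the sketch; same setup as the first sketch.)

void adjacentCondsAreFused() {
  KernelScope kernel_scope;
  BufHandle a("A", {ExprHandle(4)}, kInt);
  VarHandle x("x", kInt);
  StmtPtr c1 = Cond::make(CompareSelect::make(x, ExprHandle(0), kGT),
                          Store::make(a, {ExprHandle(0)}, ExprHandle(1)), nullptr);
  StmtPtr c2 = Cond::make(CompareSelect::make(x, ExprHandle(0), kGT),
                          Store::make(a, {ExprHandle(1)}, ExprHandle(2)), nullptr);
  // Two adjacent Conds with identical conditions are merged into a single Cond
  // whose true branch holds both stores.
  std::cout << *IRSimplifier::simplify(Block::make({c1, c2})) << "\n";
}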
bool first = block->get_parent() == nullptr; - SyncThreads* last = nullptr; - std::vector stmts; + SyncThreadsPtr last = nullptr; + std::vector stmts; bool did_anything = false; - for (auto* s : *block) { - SyncThreads* sync = dynamic_cast(s); + for (auto s : *block) { + SyncThreadsPtr sync = to(s); if (!sync) { first = false; last = nullptr; @@ -2366,18 +2380,18 @@ Stmt* TermExpander::fuseSyncThreads(Block* block) { } // clean up parents. - for (auto* s : stmts) { + for (auto s : stmts) { if (s->get_parent() == block) { block->remove_stmt(s); } } - return new Block({stmts}); + return alloc(std::vector({stmts})); } -Stmt* TermExpander::mutate(Block* v) { - Stmt* new_stmt = PolynomialBase::mutate(v); - Block* new_block = dynamic_cast(new_stmt); +StmtPtr TermExpander::mutate(BlockPtr v) { + StmtPtr new_stmt = PolynomialBase::mutate(v); + BlockPtr new_block = to(new_stmt); if (!new_block) { return new_stmt; } @@ -2393,17 +2407,17 @@ Stmt* TermExpander::mutate(Block* v) { // This function records the bounds(range) info of the index var in a for-stmt. // The bounds info will be used later when simplifying expressions with the // index var. -Stmt* SimplifierUnderContext::mutate(For* v) { - Expr* var = v->var(); - Expr* start = v->start(); - Expr* stop = v->stop(); - Stmt* body = v->body(); +StmtPtr SimplifierUnderContext::mutate(ForPtr v) { + ExprPtr var = v->var(); + ExprPtr start = v->start(); + ExprPtr stop = v->stop(); + StmtPtr body = v->body(); LoopOptions loop_options = v->loop_options(); - Expr* var_new_expr = var->accept_mutator(this); - Var* var_new = dynamic_cast(var_new_expr); - Expr* start_new = start->accept_mutator(this); - Expr* stop_new = stop->accept_mutator(this); - Stmt* body_new = body; + ExprPtr var_new_expr = var->accept_mutator(this); + VarPtr var_new = to(var_new_expr); + ExprPtr start_new = start->accept_mutator(this); + ExprPtr stop_new = stop->accept_mutator(this); + StmtPtr body_new = body; // save bounds info before this for-stmt // @@ -2421,22 +2435,23 @@ Stmt* SimplifierUnderContext::mutate(For* v) { // bound info after the for stmt, we can use it to simplify the assignment // stmt x = (i+20)/5 to x = 4. 
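(A sketch of the bound-driven rewrite described in the comment above, where (i + 20) / 5 becomes 4 once 0 <= i < 5 is known. It assumes the context-aware pass is reached through IRSimplifier::simplify on the statement, and it uses the same include/KernelScope setup as the first sketch.)

void divSimplifiedUnderLoopBounds() {
  KernelScope kernel_scope;
  BufHandle b("B", {ExprHandle(5)}, kInt);
  VarHandle i("i", kInt);
  // for (int i = 0; i < 5; i++) { B[i] = (i + 20) / 5; }
  StmtPtr loop = For::make(
      i, ExprHandle(0), ExprHandle(5),
      Store::make(b, {i}, (i + ExprHandle(20)) / ExprHandle(5)));
  // With i's range recorded, distributeDiv rewrites (i + 20) / 5 to 20 / 5 == 4.
  std::cout << *IRSimplifier::simplify(loop) << "\n";
}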
bool has_bounds = false; - std::pair bound_old; - Var* var_key = dynamic_cast(var); + std::pair bound_old; + VarPtr var_key = to(var); auto got = var_bound_info_.find(var_key); if (got != var_bound_info_.end()) { has_bounds = true; bound_old = got->second; } // set bounds info for index var - const std::pair bound_new = std::make_pair(start_new, stop_new); + const std::pair bound_new = + std::make_pair(start_new, stop_new); var_bound_info_[var_key] = bound_new; - Expr* iters = new Sub(stop_new, start_new); + ExprPtr iters = alloc(stop_new, start_new); iters = iters->accept_mutator(this); if (loop_options.isDefault() && iters->isConstant()) { if (immediateEquals(iters, 0)) { - return new Block({}); + return alloc(std::vector({})); } else if (immediateEquals(iters, 1)) { body_new = Substitute(body, {{var_new, start_new}}); body_new = body_new->accept_mutator(this); @@ -2462,19 +2477,19 @@ Stmt* SimplifierUnderContext::mutate(For* v) { } if (!body_new) { - return new Block({}); + return alloc(std::vector({})); } - if (auto* block = dynamic_cast(body_new)) { + if (auto block = to(body_new)) { if (block->nstmts() == 0) { - return new Block({}); + return alloc(std::vector({})); } if (block->nstmts() == 1) { // if the stmt in the loop body is a if-stmt, try to move the branching // out of the loop - if (auto* cond = dynamic_cast(block->front())) { - Stmt* reordered = handleForCondReordering(v, cond); + if (auto cond = to(block->front())) { + StmtPtr reordered = handleForCondReordering(v, cond); if (reordered) { return reordered->accept_mutator(this); } @@ -2523,7 +2538,7 @@ Stmt* SimplifierUnderContext::mutate(For* v) { // returns -1. But currently, both Pytorch and NNC are performing an incorrect // integer division: (-1)/6 = 0. With the current implementation of integer // division, x has to be not negative. d) j is not negative -Expr* distributeDiv(Expr* lhs, Expr* rhs, VarBoundInfo var_bound_info) { +ExprPtr distributeDiv(ExprPtr lhs, ExprPtr rhs, VarBoundInfo var_bound_info) { if (!lhs || !rhs) { return nullptr; } @@ -2533,28 +2548,28 @@ Expr* distributeDiv(Expr* lhs, Expr* rhs, VarBoundInfo var_bound_info) { } // identify n: a positive integer constant - Expr* rhsScalar = rhs->isConstant() ? rhs : nullptr; + ExprPtr rhsScalar = rhs->isConstant() ? 
rhs : nullptr; if (!rhsScalar) { return nullptr; } - Expr* check_n_value = - IRSimplifier::simplify(new CompareSelect(rhsScalar, new IntImm(0), kGT)); + ExprPtr check_n_value = IRSimplifier::simplify( + alloc(rhsScalar, alloc(0), kGT)); if (!immediateEquals(check_n_value, 1)) { return nullptr; } - auto* lhsAdd = dynamic_cast(lhs); + auto lhsAdd = to(lhs); if (!lhsAdd) { return nullptr; } - Expr* lhsAdd1 = lhsAdd->lhs(); - Expr* lhsAdd2 = lhsAdd->rhs(); + ExprPtr lhsAdd1 = lhsAdd->lhs(); + ExprPtr lhsAdd2 = lhsAdd->rhs(); // identify index var 'i' - Var* var_key = dynamic_cast(lhsAdd1); - Expr* main = lhsAdd2; + VarPtr var_key = to(lhsAdd1); + ExprPtr main = lhsAdd2; if (var_key == nullptr) { - var_key = dynamic_cast(lhsAdd2); + var_key = to(lhsAdd2); main = lhsAdd1; } @@ -2572,30 +2587,30 @@ Expr* distributeDiv(Expr* lhs, Expr* rhs, VarBoundInfo var_bound_info) { // open upper bound, i.e., end is one more than the maximum value in the // range auto end = got->second.second; - Expr* check_start = - IRSimplifier::simplify(new CompareSelect(start, new IntImm(0), kGE)); - Expr* check_end = - IRSimplifier::simplify(new CompareSelect(end, rhsScalar, kLE)); + ExprPtr check_start = IRSimplifier::simplify( + alloc(start, alloc(0), kGE)); + ExprPtr check_end = + IRSimplifier::simplify(alloc(end, rhsScalar, kLE)); if (!check_start->isConstant() || !check_end->isConstant() || !immediateEquals(check_start, 1) || !immediateEquals(check_end, 1)) { return nullptr; } - Expr* ret = IRSimplifier::simplify(new Div(main, rhsScalar)); + ExprPtr ret = IRSimplifier::simplify(alloc
(main, rhsScalar)); // simplify type 1) exprs: '(i+x)/n' => 'x/n' - Expr* sign_check = - IRSimplifier::simplify(new CompareSelect(main, new IntImm(0), kGE)); - Expr* main_mod = IRSimplifier::simplify(new Mod(main, rhsScalar)); - Expr* mod_check = IRSimplifier::simplify( - new CompareSelect(new Add(main_mod, end), rhsScalar, kLE)); + ExprPtr sign_check = + IRSimplifier::simplify(alloc(main, alloc(0), kGE)); + ExprPtr main_mod = IRSimplifier::simplify(alloc(main, rhsScalar)); + ExprPtr mod_check = IRSimplifier::simplify( + alloc(alloc(main_mod, end), rhsScalar, kLE)); if (sign_check->isConstant() && immediateEquals(sign_check, 1) && mod_check->isConstant() && immediateEquals(mod_check, 1)) { return ret; } // simplify type 2 exprs: '(i+j*n)/n' => 'j' - auto ret_var = dynamic_cast(ret); + auto ret_var = to(ret); if (ret_var && ret_var->dtype() == kInt) { // retrieve j's range info auto got = var_bound_info.find(ret_var); @@ -2605,7 +2620,7 @@ Expr* distributeDiv(Expr* lhs, Expr* rhs, VarBoundInfo var_bound_info) { // check if j is not negative sign_check = IRSimplifier::simplify( - new CompareSelect(got->second.first, new IntImm(0), kGE)); + alloc(got->second.first, alloc(0), kGE)); if (sign_check->isConstant() && immediateEquals(sign_check, 1)) { return ret_var; } @@ -2640,7 +2655,7 @@ Expr* distributeDiv(Expr* lhs, Expr* rhs, VarBoundInfo var_bound_info) { // returns -1. But currently, both Pytorch and NNC are performing an incorrect // integer division: (-1)/6 = 0. With the current implementation of integer // division, j has to be not negative. d) j is not negative -Expr* distributeMod(Expr* lhs, Expr* rhs, VarBoundInfo var_bound_info) { +ExprPtr distributeMod(ExprPtr lhs, ExprPtr rhs, VarBoundInfo var_bound_info) { if (!lhs || !rhs) { return nullptr; } @@ -2650,31 +2665,31 @@ Expr* distributeMod(Expr* lhs, Expr* rhs, VarBoundInfo var_bound_info) { } // identify n: a positive integer constant - Expr* rhsScalar = rhs->isConstant() ? rhs : nullptr; + ExprPtr rhsScalar = rhs->isConstant() ? 
rhs : nullptr; if (!rhsScalar) { return nullptr; } - Expr* check_n_value = - IRSimplifier::simplify(new CompareSelect(rhsScalar, new IntImm(0), kGT)); + ExprPtr check_n_value = IRSimplifier::simplify( + alloc(rhsScalar, alloc(0), kGT)); if (!immediateEquals(check_n_value, 1)) { return nullptr; } - auto* lhsAdd = dynamic_cast(lhs); + auto lhsAdd = to(lhs); if (!lhsAdd) { return nullptr; } if (!lhsAdd || !rhsScalar) { return nullptr; } - Expr* lhsAdd1 = lhsAdd->lhs(); - Expr* lhsAdd2 = lhsAdd->rhs(); + ExprPtr lhsAdd1 = lhsAdd->lhs(); + ExprPtr lhsAdd2 = lhsAdd->rhs(); // identify index var 'i' - Var* var_key = dynamic_cast(lhsAdd1); - Expr* main = lhsAdd2; + VarPtr var_key = to(lhsAdd1); + ExprPtr main = lhsAdd2; if (var_key == nullptr) { - var_key = dynamic_cast(lhsAdd2); + var_key = to(lhsAdd2); main = lhsAdd1; } if (var_key == nullptr) { @@ -2691,29 +2706,29 @@ Expr* distributeMod(Expr* lhs, Expr* rhs, VarBoundInfo var_bound_info) { // open upper bound, i.e., end is one more than the maximum value in the // range auto end = got->second.second; - Expr* check_start = - IRSimplifier::simplify(new CompareSelect(start, new IntImm(0), kGE)); - Expr* check_end = - IRSimplifier::simplify(new CompareSelect(end, rhsScalar, kLE)); + ExprPtr check_start = IRSimplifier::simplify( + alloc(start, alloc(0), kGE)); + ExprPtr check_end = + IRSimplifier::simplify(alloc(end, rhsScalar, kLE)); if (!check_start->isConstant() || !check_end->isConstant() || !immediateEquals(check_start, 1) || !immediateEquals(check_end, 1)) { return nullptr; } // simplify type 1) exprs: '(i+x)%n' => 'i+x%n' - Expr* sign_check = - IRSimplifier::simplify(new CompareSelect(main, new IntImm(0), kGE)); - Expr* main_mod = IRSimplifier::simplify(new Mod(main, rhsScalar)); - Expr* mod_check = IRSimplifier::simplify( - new CompareSelect(new Add(main_mod, end), rhsScalar, kLE)); + ExprPtr sign_check = + IRSimplifier::simplify(alloc(main, alloc(0), kGE)); + ExprPtr main_mod = IRSimplifier::simplify(alloc(main, rhsScalar)); + ExprPtr mod_check = IRSimplifier::simplify( + alloc(alloc(main_mod, end), rhsScalar, kLE)); if (sign_check->isConstant() && immediateEquals(sign_check, 1) && mod_check->isConstant() && immediateEquals(mod_check, 1)) { - return new Add(var_key, main_mod); + return alloc(var_key, main_mod); } // simplify type 2) exprs: '(i+j*n)%n' => 'i' - Expr* main_div = IRSimplifier::simplify(new Div(main, rhsScalar)); - auto j_var = dynamic_cast(main_div); + ExprPtr main_div = IRSimplifier::simplify(alloc
(main, rhsScalar)); + auto j_var = to(main_div); if (j_var && j_var->dtype() == kInt) { // retrieve j's range info auto got = var_bound_info.find(j_var); @@ -2723,7 +2738,7 @@ Expr* distributeMod(Expr* lhs, Expr* rhs, VarBoundInfo var_bound_info) { // check if j is not negative sign_check = IRSimplifier::simplify( - new CompareSelect(got->second.first, new IntImm(0), kGE)); + alloc(got->second.first, alloc(0), kGE)); if (sign_check->isConstant() && immediateEquals(sign_check, 1)) { return var_key; } @@ -2732,9 +2747,9 @@ Expr* distributeMod(Expr* lhs, Expr* rhs, VarBoundInfo var_bound_info) { return nullptr; } -Expr* SimplifierUnderContext::mutate(Div* v) { - Expr* lhs = v->lhs(); - Expr* rhs = v->rhs(); +ExprPtr SimplifierUnderContext::mutate(DivPtr v) { + ExprPtr lhs = v->lhs(); + ExprPtr rhs = v->rhs(); std::ostringstream oss; if (auto ret = distributeDiv(lhs, rhs, var_bound_info_)) { @@ -2744,17 +2759,17 @@ Expr* SimplifierUnderContext::mutate(Div* v) { return ret->accept_mutator(this); } - Expr* lhs_new = lhs->accept_mutator(this); - Expr* rhs_new = rhs->accept_mutator(this); + ExprPtr lhs_new = lhs->accept_mutator(this); + ExprPtr rhs_new = rhs->accept_mutator(this); if (lhs == lhs_new && rhs == rhs_new) { return v; } - return new Div(lhs_new, rhs_new); + return alloc
(lhs_new, rhs_new); } -Expr* SimplifierUnderContext::mutate(Mod* v) { - Expr* lhs = v->lhs(); - Expr* rhs = v->rhs(); +ExprPtr SimplifierUnderContext::mutate(ModPtr v) { + ExprPtr lhs = v->lhs(); + ExprPtr rhs = v->rhs(); std::ostringstream oss; if (auto ret = distributeMod(lhs, rhs, var_bound_info_)) { @@ -2766,17 +2781,17 @@ Expr* SimplifierUnderContext::mutate(Mod* v) { // i % N -> i if the range of i's values is a subset of [0, N) // where N is an integer constant - auto* lhsVar = dynamic_cast(lhs); - Expr* rhsScalar = rhs->isConstant() ? rhs : nullptr; + auto lhsVar = to(lhs); + ExprPtr rhsScalar = rhs->isConstant() ? rhs : nullptr; if (lhsVar && rhsScalar && !rhsScalar->dtype().is_floating_point()) { auto got = var_bound_info_.find(lhsVar); if (got != var_bound_info_.end()) { auto start = got->second.first; auto end = got->second.second; - Expr* check_start = - IRSimplifier::simplify(new CompareSelect(start, new IntImm(0), kGE)); - Expr* check_end = - IRSimplifier::simplify(new CompareSelect(end, rhsScalar, kLE)); + ExprPtr check_start = IRSimplifier::simplify( + alloc(start, alloc(0), kGE)); + ExprPtr check_end = + IRSimplifier::simplify(alloc(end, rhsScalar, kLE)); if (check_start->isConstant() && check_end->isConstant() && immediateEquals(check_start, 1) && immediateEquals(check_end, 1)) { oss << "SimplifierUnderContext: " << *v << " => " << *lhsVar << "\n"; @@ -2786,18 +2801,18 @@ Expr* SimplifierUnderContext::mutate(Mod* v) { } } - Expr* lhs_new = lhs->accept_mutator(this); - Expr* rhs_new = rhs->accept_mutator(this); + ExprPtr lhs_new = lhs->accept_mutator(this); + ExprPtr rhs_new = rhs->accept_mutator(this); if (lhs == lhs_new && rhs == rhs_new) { return v; } - return new Mod(lhs_new, rhs_new); + return alloc(lhs_new, rhs_new); } -bool exprEquals(Expr* A, Expr* B) { +bool exprEquals(ExprPtr A, ExprPtr B) { try { // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - Expr* diff = IRSimplifier::simplify(new Sub(A, B)); + ExprPtr diff = IRSimplifier::simplify(alloc(A, B)); if (!diff->isConstant()) { return false; } diff --git a/torch/csrc/jit/tensorexpr/ir_simplifier.h b/torch/csrc/jit/tensorexpr/ir_simplifier.h index 2c5ea66..6281b77 100644 --- a/torch/csrc/jit/tensorexpr/ir_simplifier.h +++ b/torch/csrc/jit/tensorexpr/ir_simplifier.h @@ -25,11 +25,11 @@ namespace tensorexpr { // A bunch of helpers for determine the Dtype of the output of a multi argument // Term or Polynomial. template -Dtype promoteTypesVec(Expr* s, std::vector& v) { +Dtype promoteTypesVec(ExprPtr s, std::vector& v) { Dtype t = s->dtype(); bool first = true; - for (auto* e : v) { + for (auto e : v) { if (first) { t = Dtype(t.scalar_type(), e->dtype().lanes()); first = false; @@ -40,13 +40,13 @@ Dtype promoteTypesVec(Expr* s, std::vector& v) { } template -Dtype promoteTypesVec(std::vector& v) { +Dtype promoteTypesVec(std::vector& v) { if (v.empty()) { throw malformed_input("empty list of types"); } Dtype t = v[0]->dtype(); - for (auto* e : v) { + for (auto e : v) { t = promoteTypes(t, e->dtype()); } return t; @@ -54,7 +54,7 @@ Dtype promoteTypesVec(std::vector& v) { template Dtype promoteTypesMap( - Expr* s, + ExprPtr s, std::unordered_map& m) { Dtype t = s->dtype(); bool first = true; @@ -85,35 +85,35 @@ Dtype promoteTypesVar(ExprType* e, Args... es) { } // Creates a new Expr of the given type with the provided lhs and rhs. 
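Before the newBinaryOpOfType helper below, an aside (not part of the patch) summarizing the context-sensitive rules implemented above. Assuming a loop index i whose range [0, 10) has been recorded in var_bound_info_, and a second index j with a known non-negative range, the simplifier can fold:

  // i in [0, 10), n = 10
  (i + 20) / 10    =>  2     // distributeDiv, type 1: '(i+x)/n' => 'x/n'
  (i + j*10) / 10  =>  j     // distributeDiv, type 2: '(i+j*n)/n' => 'j'
  (i + 20) % 10    =>  i     // distributeMod, type 1: 'i + x%n', and 20 % 10 == 0
  i % 10           =>  i     // mutate(ModPtr): i's range is a subset of [0, 10)

Every fold is guarded by the start/end and sign checks shown above, so none of them fire when the bounds are unknown or an operand may be negative.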
-inline Expr* newBinaryOpOfType(
+inline ExprPtr newBinaryOpOfType(
     IRNodeType expr_type,
-    Expr* lhs,
-    Expr* rhs,
+    ExprPtr lhs,
+    ExprPtr rhs,
     bool option) {
   switch (expr_type) {
     // NOLINTNEXTLINE(bugprone-branch-clone)
     case IRNodeType::kAdd:
-      return new Add(lhs, rhs);
+      return alloc<Add>(lhs, rhs);
     case IRNodeType::kSub:
-      return new Sub(lhs, rhs);
+      return alloc<Sub>(lhs, rhs);
     case IRNodeType::kMul:
-      return new Mul(lhs, rhs);
+      return alloc<Mul>(lhs, rhs);
     case IRNodeType::kDiv:
-      return new Div(lhs, rhs);
+      return alloc<Div>
(lhs, rhs); case IRNodeType::kMod: - return new Mod(lhs, rhs); + return alloc(lhs, rhs); case IRNodeType::kMax: - return new Max(lhs, rhs, option); + return alloc(lhs, rhs, option); case IRNodeType::kMin: - return new Min(lhs, rhs, option); + return alloc(lhs, rhs, option); case IRNodeType::kAnd: - return new And(lhs, rhs); + return alloc(lhs, rhs); case IRNodeType::kXor: - return new Xor(lhs, rhs); + return alloc(lhs, rhs); case IRNodeType::kLshift: - return new Lshift(lhs, rhs); + return alloc(lhs, rhs); case IRNodeType::kRshift: - return new Rshift(lhs, rhs); + return alloc(lhs, rhs); default: LOG(FATAL) << "unsupported expr_type: " << static_cast(expr_type); return nullptr; @@ -123,7 +123,7 @@ inline Expr* newBinaryOpOfType( // Uses the evaluator to fold an Expression with constant terms. // E.g. evaluateOp(Add(3, 4)) => 7. // Expr v must not have any unbound Vars. -inline Expr* evaluateOp(Expr* v) { +inline ExprPtr evaluateOp(ExprPtr v) { ExprHandle handle(v); ExprEval eval(handle); @@ -148,7 +148,7 @@ class Term : public ExprNode { public: template // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) - Term(HashProvider& hasher, Expr* s, Args... ts) + Term(HashProvider& hasher, ExprPtr s, Args... ts) : ExprNodeBase(promoteTypesVar(s, ts...)), scalar_(s), hasher_(hasher) { CHECK(s->isConstant()); addComponent(ts...); @@ -156,7 +156,7 @@ class Term : public ExprNode { } // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) - Term(HashProvider& hasher, Expr* s, std::vector v) + Term(HashProvider& hasher, ExprPtr s, std::vector v) : ExprNodeBase(promoteTypesVec(s, v)), variables_(std::move(v)), scalar_(s), @@ -168,8 +168,8 @@ class Term : public ExprNode { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) Term( HashProvider& hasher, - Expr* s, - std::unordered_map varmap) + ExprPtr s, + std::unordered_map varmap) : ExprNodeBase(promoteTypesMap(s, varmap)), scalar_(s), hasher_(hasher) { for (auto& p : varmap) { addComponent(p.second); @@ -177,10 +177,10 @@ class Term : public ExprNode { sort(); } - Expr* scalar() const { + ExprPtr scalar() const { return scalar_; } - const std::vector& variables() const { + const std::vector& variables() const { return variables_; } HashProvider& hasher() const { @@ -192,16 +192,16 @@ class Term : public ExprNode { SimplifierHashType hashVars() const; private: - std::vector variables_; - Expr* scalar_; + std::vector variables_; + ExprPtr scalar_; HashProvider& hasher_; void addComponent() {} - void addComponent(Expr* e) { + void addComponent(ExprPtr e) { variables_.push_back(e); } template - void addComponent(Expr* e, Es... es) { + void addComponent(ExprPtr e, Es... es) { addComponent(e); addComponent(es...); } @@ -217,7 +217,7 @@ class Polynomial : public ExprNode { public: template // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) - Polynomial(HashProvider& hasher, Expr* s, Args... ts) + Polynomial(HashProvider& hasher, ExprPtr s, Args... ts) : ExprNodeBase(promoteTypesVar(s, ts...)), scalar_(s), hasher_(hasher) { CHECK(s->isConstant()); addTerm(ts...); @@ -225,7 +225,7 @@ class Polynomial : public ExprNode { } // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) - Polynomial(HashProvider& hasher, Expr* s, std::vector v) + Polynomial(HashProvider& hasher, ExprPtr s, std::vector v) : ExprNodeBase(promoteTypesVec(s, v)), variables_(std::move(v)), scalar_(s), @@ -235,7 +235,7 @@ class Polynomial : public ExprNode { // Helper constructor for list of terms with no scalar component. 
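To make the new construction idiom concrete before the Polynomial class below, here is a minimal sketch (not part of the patch) of the constant folding evaluateOp performs, mirroring the 'evaluateOp(Add(3, 4)) => 7' comment above; the only assumption is that the alloc<> helper introduced by this patch is in scope:

  // Build 3 + 4 with the new allocation helper and fold it.
  ExprPtr sum = alloc<Add>(alloc<IntImm>(3), alloc<IntImm>(4));
  ExprPtr folded = evaluateOp(sum);  // an IntImm holding 7; no unbound Vars involved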
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) - Polynomial(HashProvider& hasher, std::vector terms) + Polynomial(HashProvider& hasher, std::vector terms) : ExprNodeBase(promoteTypesVec(terms)), variables_(std::move(terms)), scalar_(getImmediateByType(dtype(), 0)), @@ -248,8 +248,8 @@ class Polynomial : public ExprNode { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) Polynomial( HashProvider& hasher, - Expr* s, - std::unordered_map varmap) + ExprPtr s, + std::unordered_map varmap) : ExprNodeBase(promoteTypesMap(s, varmap)), scalar_(s), hasher_(hasher) { for (auto& p : varmap) { addTerm(p.second); @@ -257,10 +257,10 @@ class Polynomial : public ExprNode { sort(); } - Expr* scalar() const { + ExprPtr scalar() const { return scalar_; } - const std::vector& variables() const { + const std::vector& variables() const { return variables_; } HashProvider& hasher() const { @@ -270,15 +270,15 @@ class Polynomial : public ExprNode { SimplifierHashType hashVars() const; private: - std::vector variables_; - Expr* scalar_; + std::vector variables_; + ExprPtr scalar_; HashProvider& hasher_; - void addTerm(Term* t) { + void addTerm(TermPtr t) { variables_.push_back(t); } template - void addTerm(Term* t, Ts... ts) { + void addTerm(TermPtr t, Ts... ts) { addTerm(t); addTerm(ts...); } @@ -289,14 +289,15 @@ class Polynomial : public ExprNode { class RoundOff : public BinaryOpNode { public: - RoundOff(Expr* lhs, Expr* rhs) : BinaryOpNode(lhs, rhs, IRNodeType::kOther) {} + RoundOff(ExprPtr lhs, ExprPtr rhs) + : BinaryOpNode(lhs, rhs, IRNodeType::kOther) {} }; class MaxTerm : public ExprNode { public: template // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) - MaxTerm(HashProvider& hasher, Expr* s, bool p, Args... ts) + MaxTerm(HashProvider& hasher, ExprPtr s, bool p, Args... ts) : ExprNodeBase(s ? promoteTypesVar(s, ts...) : promoteTypesVar(ts...)), scalar_(s), hasher_(hasher), @@ -306,7 +307,7 @@ class MaxTerm : public ExprNode { } // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) - MaxTerm(HashProvider& hasher, Expr* s, bool p, std::vector v) + MaxTerm(HashProvider& hasher, ExprPtr s, bool p, std::vector v) : ExprNodeBase(s ? promoteTypesVec(s, v) : promoteTypesVec(v)), variables_(std::move(v)), scalar_(s), @@ -319,10 +320,10 @@ class MaxTerm : public ExprNode { return propagate_nans_; } - Expr* scalar() const { + ExprPtr scalar() const { return scalar_; } - const std::vector& variables() const { + const std::vector& variables() const { return variables_; } HashProvider& hasher() const { @@ -330,17 +331,17 @@ class MaxTerm : public ExprNode { } private: - std::vector variables_; - Expr* scalar_; + std::vector variables_; + ExprPtr scalar_; HashProvider& hasher_; bool propagate_nans_; void addComponent() {} - void addComponent(Expr* e) { + void addComponent(ExprPtr e) { variables_.push_back(e); } template - void addComponent(Expr* e, Es... es) { + void addComponent(ExprPtr e, Es... es) { addComponent(e); addComponent(es...); } @@ -353,7 +354,7 @@ class MinTerm : public ExprNode { public: template // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) - MinTerm(HashProvider& hasher, Expr* s, bool p, Args... ts) + MinTerm(HashProvider& hasher, ExprPtr s, bool p, Args... ts) : ExprNodeBase(s ? promoteTypesVar(s, ts...) 
: promoteTypesVar(ts...)), scalar_(s), hasher_(hasher), @@ -363,7 +364,7 @@ class MinTerm : public ExprNode { } // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) - MinTerm(HashProvider& hasher, Expr* s, bool p, std::vector v) + MinTerm(HashProvider& hasher, ExprPtr s, bool p, std::vector v) : ExprNodeBase(s ? promoteTypesVec(s, v) : promoteTypesVec(v)), variables_(std::move(v)), scalar_(s), @@ -376,10 +377,10 @@ class MinTerm : public ExprNode { return propagate_nans_; } - Expr* scalar() const { + ExprPtr scalar() const { return scalar_; } - const std::vector& variables() const { + const std::vector& variables() const { return variables_; } HashProvider& hasher() const { @@ -387,17 +388,17 @@ class MinTerm : public ExprNode { } private: - std::vector variables_; - Expr* scalar_; + std::vector variables_; + ExprPtr scalar_; HashProvider& hasher_; bool propagate_nans_; void addComponent() {} - void addComponent(Expr* e) { + void addComponent(ExprPtr e) { variables_.push_back(e); } template - void addComponent(Expr* e, Es... es) { + void addComponent(ExprPtr e, Es... es) { addComponent(e); addComponent(es...); } @@ -407,15 +408,15 @@ class MinTerm : public ExprNode { }; // Context-sensitive IR simplification -using VarBoundInfo = std::unordered_map>; +using VarBoundInfo = std::unordered_map>; class TORCH_API SimplifierUnderContext : public IRMutator { public: ~SimplifierUnderContext() override = default; // Add boundary info for index variables in for-loops - Stmt* mutate(For* v) override; + StmtPtr mutate(ForPtr v) override; - Expr* mutate(Div* v) override; - Expr* mutate(Mod* v) override; + ExprPtr mutate(DivPtr v) override; + ExprPtr mutate(ModPtr v) override; protected: // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) @@ -428,14 +429,14 @@ class TORCH_API PolynomialBase : public IRMutator { public: ~PolynomialBase() override = default; - Stmt* mutate(Block* v) override; + StmtPtr mutate(BlockPtr v) override; - Stmt* mutate(Cond* v) override; + StmtPtr mutate(CondPtr v) override; - Stmt* mutate(For* v) override; + StmtPtr mutate(ForPtr v) override; // Trivially factorize terms by GCD of scalar components. - Term* factorizePolynomial(Polynomial* poly); + TermPtr factorizePolynomial(PolynomialPtr poly); HashProvider& hasher() { return hasher_; @@ -453,89 +454,89 @@ class TORCH_API PolynomialTransformer : public PolynomialBase { // Inserts term into the provided map, in the case of a hash collision // combines the term with the existing and updates the map. void addOrUpdateTerm( - std::unordered_map& varmap, - Term* term); + std::unordered_map& varmap, + TermPtr term); // Add Polynomial expressions, combining Terms representing the same // variables. - Expr* addPolynomials(Polynomial* lhs, Polynomial* rhs); + ExprPtr addPolynomials(PolynomialPtr lhs, PolynomialPtr rhs); - // Insert a new Term into the provided polynomial. If the new term has common - // variables to an existing term it is combined. - Expr* insertTerm(Polynomial* poly, Term* term); + // Insert a new Term into the provided polynomial. If the new term has + // common variables to an existing term it is combined. + ExprPtr insertTerm(PolynomialPtr poly, TermPtr term); // Merge and simplify addition. - Expr* mutate(Add* v) override; + ExprPtr mutate(AddPtr v) override; // Subtract one term from another, cancelling if necessary. 
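As a usage note before the subtraction helpers that follow (a sketch, not part of the patch), the Term/Polynomial machinery is what lets the top-level simplifier merge like terms at the handle level:

  VarHandle x("x", kInt);
  ExprHandle e = x + x * ExprHandle(2) + ExprHandle(3);
  ExprHandle s = IRSimplifier::simplify(e);
  // addOrUpdateTerm combines the two x-terms into a single Term with scalar 3,
  // so s is the equivalent of 3 * x + 3.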
- Expr* subTerms(Term* lhs, Term* rhs, bool negated); + ExprPtr subTerms(TermPtr lhs, TermPtr rhs, bool negated); // Subtract the RHS Polynomial from the LHS Polynomial, cancelling out where // possible. - Expr* subPolynomials(Polynomial* lhs, Polynomial* rhs); + ExprPtr subPolynomials(PolynomialPtr lhs, PolynomialPtr rhs); // Merge and simplify subtraction. - Expr* mutate(Sub* v) override; + ExprPtr mutate(SubPtr v) override; // Multiply two terms together, usually creating a new term with the variable // lists concatenated. - Term* mulTerms(Term* lhs, Term* rhs); + TermPtr mulTerms(TermPtr lhs, TermPtr rhs); // Multiply a Polynomial by a Term. - Expr* polyByTerm(Polynomial* poly, Term* term); + ExprPtr polyByTerm(PolynomialPtr poly, TermPtr term); // Match a rounding pattern and create a RoundOff if found. - Expr* isRoundOff(Expr* lhs, Expr* rhs); + ExprPtr isRoundOff(ExprPtr lhs, ExprPtr rhs); // Inserts a new component into a term, simplifying if possible. - Expr* insertIntoTerm(Term* term, Expr* expr); + ExprPtr insertIntoTerm(TermPtr term, ExprPtr expr); // Merge and simplify multiplication. - Expr* mutate(Mul* v) override; + ExprPtr mutate(MulPtr v) override; - Expr* mutate(Div* v) override; + ExprPtr mutate(DivPtr v) override; - Expr* mutate(Mod* v) override; + ExprPtr mutate(ModPtr v) override; - Expr* mutate(And* v) override { + ExprPtr mutate(AndPtr v) override { return mutateBinaryOp(v, this); } - Expr* mutate(Xor* v) override { + ExprPtr mutate(XorPtr v) override { return mutateBinaryOp(v, this); } - Expr* mutate(Lshift* v) override { + ExprPtr mutate(LshiftPtr v) override { return mutateBinaryOp(v, this); } - Expr* mutate(Rshift* v) override { + ExprPtr mutate(RshiftPtr v) override { return mutateBinaryOp(v, this); } - Expr* mutate(Max* v) override; + ExprPtr mutate(MaxPtr v) override; - Expr* mutate(Min* v) override; + ExprPtr mutate(MinPtr v) override; - Expr* mutate(CompareSelect* v) override; + ExprPtr mutate(CompareSelectPtr v) override; - Expr* mutate(Intrinsics* v) override; + ExprPtr mutate(IntrinsicsPtr v) override; - Expr* mutate(Cast* v) override; + ExprPtr mutate(CastPtr v) override; - Expr* mutate(IfThenElse* v) override; + ExprPtr mutate(IfThenElsePtr v) override; template - static Expr* mutateBinaryOp( + static ExprPtr mutateBinaryOp( BinaryOpNode* v, IRMutator* mutator, bool option = false) { - Expr* lhs = v->lhs(); - Expr* rhs = v->rhs(); - Expr* lhs_new = lhs->accept_mutator(mutator); - Expr* rhs_new = rhs->accept_mutator(mutator); + ExprPtr lhs = v->lhs(); + ExprPtr rhs = v->rhs(); + ExprPtr lhs_new = lhs->accept_mutator(mutator); + ExprPtr rhs_new = rhs->accept_mutator(mutator); - Expr* node = v; + ExprPtr node = v; if (lhs != lhs_new || rhs != rhs_new) { node = newBinaryOpOfType(v->expr_type(), lhs_new, rhs_new, option); @@ -549,16 +550,16 @@ class TORCH_API PolynomialTransformer : public PolynomialBase { return evaluateOp(node); } - static Expr* simplify(Expr* e); + static ExprPtr simplify(ExprPtr e); static ExprHandle simplify(const ExprHandle& e); - static Stmt* simplify(Stmt* e); + static StmtPtr simplify(StmtPtr e); }; // Expands Terms and Polynomial expressions into primitive operations. // Does some simple factorization and reordering. 
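Before the TermExpander declaration below, one more illustrative sketch (not part of the patch): the bitwise operators are routed through mutateBinaryOp, so a node is rebuilt only when a child changed, and a fully constant expression is then folded by evaluateOp:

  ExprPtr e = alloc<And>(alloc<IntImm>(12), alloc<IntImm>(10));
  ExprPtr folded = IRSimplifier::simplify(e);  // IntImm(8), since 12 & 10 == 8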
class TORCH_API TermExpander : public PolynomialBase { PolynomialTransformer* simplifier_; - std::set eliminated_allocations_; + std::set eliminated_allocations_; public: using PolynomialBase::mutate; @@ -569,33 +570,33 @@ class TORCH_API TermExpander : public PolynomialBase { } // Expand Terms out to a series of Muls. - Expr* mutate(Term* v) override; + ExprPtr mutate(TermPtr v) override; // Expand Polynomials out to a series of Adds. - Expr* mutate(Polynomial* v) override; + ExprPtr mutate(PolynomialPtr v) override; // Expand MaxTerms to a series of Max ops. - Expr* mutate(MaxTerm* v) override; + ExprPtr mutate(MaxTermPtr v) override; // Expand MinTerms to a series of Min ops. - Expr* mutate(MinTerm* v) override; + ExprPtr mutate(MinTermPtr v) override; // Expand RoundOff to it's component: Mul(Div(lhs, rhs), rhs). - Expr* mutate(RoundOff* v) override; + ExprPtr mutate(RoundOffPtr v) override; // Eliminate zero length allocations. - Stmt* mutate(Allocate* v) override; - Stmt* mutate(Free* v) override; + StmtPtr mutate(AllocatePtr v) override; + StmtPtr mutate(FreePtr v) override; // Override to enable condition fusing. - Block* fuseConditions(Block* v); - Stmt* fuseSyncThreads(Block* block); - Stmt* mutate(Block* v) override; + BlockPtr fuseConditions(BlockPtr v); + StmtPtr fuseSyncThreads(BlockPtr block); + StmtPtr mutate(BlockPtr v) override; }; class TORCH_API IRSimplifier { public: - static Expr* simplify(Expr* e) { + static ExprPtr simplify(ExprPtr e) { SimplifierUnderContext ctxsimplifier; e = e->accept_mutator(&ctxsimplifier); @@ -617,7 +618,7 @@ class TORCH_API IRSimplifier { return ExprHandle(simplify(e.node())); } - static Stmt* simplify(Stmt* s) { + static StmtPtr simplify(StmtPtr s) { SimplifierUnderContext ctxsimplifier; s = s->accept_mutator(&ctxsimplifier); @@ -639,9 +640,9 @@ class TORCH_API IRSimplifier { }; // Flattens the buf and performs the simplifier on the flattened dims. -Expr* buf_flat_size(Buf* v); +ExprPtr buf_flat_size(BufPtr v); // Returns true if expressions A and B can be simplified to an equal expression. 
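The exprEquals helper declared below is handy in tests: it subtracts the two expressions and checks whether the difference simplifies to a constant zero. A small sketch (not part of the patch), using only node constructors that appear elsewhere in this patch:

  VarPtr x = alloc<Var>("x", kInt);
  ExprPtr a = alloc<Add>(alloc<Mul>(x, alloc<IntImm>(2)), alloc<IntImm>(1));  // 2*x + 1
  ExprPtr b = alloc<Add>(alloc<IntImm>(1), alloc<Add>(x, x));                // 1 + x + x
  bool same = exprEquals(a, b);  // true: (2*x + 1) - (1 + x + x) folds to 0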
-TORCH_API bool exprEquals(Expr* A, Expr* B); +TORCH_API bool exprEquals(ExprPtr A, ExprPtr B); } // namespace tensorexpr } // namespace jit diff --git a/torch/csrc/jit/tensorexpr/ir_verifier.cpp b/torch/csrc/jit/tensorexpr/ir_verifier.cpp index c57080d..c88e92c 100644 --- a/torch/csrc/jit/tensorexpr/ir_verifier.cpp +++ b/torch/csrc/jit/tensorexpr/ir_verifier.cpp @@ -19,39 +19,39 @@ void verifyBitwiseOp(const BitwiseOpNode* v, IRVerifier* verifier) { } } -void IRVerifier::visit(And* v) { +void IRVerifier::visit(AndPtr v) { verifyBitwiseOp(v, this); IRVisitor::visit(v); } -void IRVerifier::visit(Or* v) { +void IRVerifier::visit(OrPtr v) { verifyBitwiseOp(v, this); IRVisitor::visit(v); } -void IRVerifier::visit(Xor* v) { +void IRVerifier::visit(XorPtr v) { verifyBitwiseOp(v, this); IRVisitor::visit(v); } -void IRVerifier::visit(Lshift* v) { +void IRVerifier::visit(LshiftPtr v) { verifyBitwiseOp(v, this); IRVisitor::visit(v); } -void IRVerifier::visit(Rshift* v) { +void IRVerifier::visit(RshiftPtr v) { verifyBitwiseOp(v, this); IRVisitor::visit(v); } -void IRVerifier::visit(Mod* v) { +void IRVerifier::visit(ModPtr v) { if (!v->dtype().is_integral() && !v->dtype().is_floating_point()) { throw std::runtime_error("invalid dtype: " + std::to_string(v->dtype())); } IRVisitor::visit(v); } -void IRVerifier::visit(CompareSelect* v) { +void IRVerifier::visit(CompareSelectPtr v) { if (v->ret_val1()->dtype() != v->ret_val2()->dtype()) { throw malformed_ir("bad dtype in CompareSelect"); } @@ -61,14 +61,14 @@ void IRVerifier::visit(CompareSelect* v) { IRVisitor::visit(v); } -void IRVerifier::visit(Ramp* v) { +void IRVerifier::visit(RampPtr v) { if (v->stride()->dtype() != v->base()->dtype()) { throw malformed_ir("Bad stride in Ramp"); } IRVisitor::visit(v); } -void IRVerifier::visit(Load* v) { +void IRVerifier::visit(LoadPtr v) { auto indices = v->indices(); if (indices.size() > 0 && v->buf()->base_handle()->dtype() != kHandle) { throw malformed_ir( @@ -94,7 +94,7 @@ void IRVerifier::visit(Load* v) { IRVisitor::visit(v); } -void IRVerifier::visit(IfThenElse* v) { +void IRVerifier::visit(IfThenElsePtr v) { if (!v->condition()->dtype().is_integral()) { throw unsupported_dtype(); } @@ -107,12 +107,12 @@ void IRVerifier::visit(IfThenElse* v) { IRVisitor::visit(v); } -void IRVerifier::visit(Intrinsics* v) { +void IRVerifier::visit(IntrinsicsPtr v) { // TODO: add a check for OpArgCount and op_type IRVisitor::visit(v); } -void IRVerifier::visit(Store* v) { +void IRVerifier::visit(StorePtr v) { auto indices = v->indices(); if (indices.size() > 0 && v->buf()->base_handle()->dtype() != kHandle) { throw malformed_ir( @@ -141,7 +141,7 @@ void IRVerifier::visit(Store* v) { IRVisitor::visit(v); } -void IRVerifier::visit(For* v) { +void IRVerifier::visit(ForPtr v) { if (!v->var()) { throw malformed_ir("nullptr Var in For loop"); } else if (!v->start()) { @@ -154,8 +154,8 @@ void IRVerifier::visit(For* v) { IRVisitor::visit(v); } -void IRVerifier::visit(Block* v) { - for (Stmt* s : v->stmts()) { +void IRVerifier::visit(BlockPtr v) { + for (StmtPtr s : v->stmts()) { if (s->get_parent() != v) { throw malformed_ir("Broken child-parent link inside a Block"); } @@ -163,16 +163,16 @@ void IRVerifier::visit(Block* v) { IRVisitor::visit(v); } -void IRVerifier::visit(ExternalCall* v) { +void IRVerifier::visit(ExternalCallPtr v) { IRVisitor::visit(v); } -void verify(Stmt* s) { +void verify(StmtPtr s) { IRVerifier verifier; s->accept(&verifier); } -void verify(Expr* e) { +void verify(ExprPtr e) { IRVerifier verifier; 
e->accept(&verifier); } diff --git a/torch/csrc/jit/tensorexpr/ir_verifier.h b/torch/csrc/jit/tensorexpr/ir_verifier.h index 660c5cd..90bca0e 100644 --- a/torch/csrc/jit/tensorexpr/ir_verifier.h +++ b/torch/csrc/jit/tensorexpr/ir_verifier.h @@ -2,6 +2,7 @@ #include +#include #include namespace torch { @@ -32,26 +33,26 @@ class TORCH_API IRVerifier : public IRVisitor { public: IRVerifier() = default; - void visit(Mod* v) override; - void visit(And* v) override; - void visit(Or* v) override; - void visit(Xor* v) override; - void visit(Lshift* v) override; - void visit(Rshift* v) override; - void visit(CompareSelect* v) override; - void visit(Ramp* v) override; - void visit(Load* v) override; - void visit(IfThenElse* v) override; - void visit(Intrinsics* v) override; - - void visit(ExternalCall* v) override; - void visit(Store* v) override; - void visit(For* v) override; - void visit(Block* v) override; + void visit(ModPtr v) override; + void visit(AndPtr v) override; + void visit(OrPtr v) override; + void visit(XorPtr v) override; + void visit(LshiftPtr v) override; + void visit(RshiftPtr v) override; + void visit(CompareSelectPtr v) override; + void visit(RampPtr v) override; + void visit(LoadPtr v) override; + void visit(IfThenElsePtr v) override; + void visit(IntrinsicsPtr v) override; + + void visit(ExternalCallPtr v) override; + void visit(StorePtr v) override; + void visit(ForPtr v) override; + void visit(BlockPtr v) override; }; -TORCH_API void verify(Stmt*); -TORCH_API void verify(Expr*); +TORCH_API void verify(StmtPtr); +TORCH_API void verify(ExprPtr); TORCH_API void verify(ExprHandle); } // namespace tensorexpr diff --git a/torch/csrc/jit/tensorexpr/ir_visitor.cpp b/torch/csrc/jit/tensorexpr/ir_visitor.cpp index b29baf6..9066544 100644 --- a/torch/csrc/jit/tensorexpr/ir_visitor.cpp +++ b/torch/csrc/jit/tensorexpr/ir_visitor.cpp @@ -17,55 +17,55 @@ static void visit_binary_op(BinaryOpNode* v, IRVisitor* visitor) { v->rhs()->accept(visitor); } -void IRVisitor::visit(Add* v) { +void IRVisitor::visit(AddPtr v) { visit_binary_op(v, this); } -void IRVisitor::visit(Sub* v) { +void IRVisitor::visit(SubPtr v) { visit_binary_op(v, this); } -void IRVisitor::visit(Mul* v) { +void IRVisitor::visit(MulPtr v) { visit_binary_op(v, this); } -void IRVisitor::visit(Div* v) { +void IRVisitor::visit(DivPtr v) { visit_binary_op(v, this); } -void IRVisitor::visit(Mod* v) { +void IRVisitor::visit(ModPtr v) { visit_binary_op(v, this); } -void IRVisitor::visit(Max* v) { +void IRVisitor::visit(MaxPtr v) { visit_binary_op(v, this); } -void IRVisitor::visit(Min* v) { +void IRVisitor::visit(MinPtr v) { visit_binary_op(v, this); } -void IRVisitor::visit(And* v) { +void IRVisitor::visit(AndPtr v) { visit_binary_op(v, this); } -void IRVisitor::visit(Or* v) { +void IRVisitor::visit(OrPtr v) { visit_binary_op(v, this); } -void IRVisitor::visit(Xor* v) { +void IRVisitor::visit(XorPtr v) { visit_binary_op(v, this); } -void IRVisitor::visit(Lshift* v) { +void IRVisitor::visit(LshiftPtr v) { visit_binary_op(v, this); } -void IRVisitor::visit(Rshift* v) { +void IRVisitor::visit(RshiftPtr v) { visit_binary_op(v, this); } -void IRVisitor::visit(CompareSelect* v) { +void IRVisitor::visit(CompareSelectPtr v) { v->lhs()->accept(this); v->rhs()->accept(this); v->ret_val1()->accept(this); @@ -74,69 +74,69 @@ void IRVisitor::visit(CompareSelect* v) { // NOLINTNEXTLINE #define IMM_VISIT(Type, Name) \ - void IRVisitor::visit(Name##Imm* v) {} + void IRVisitor::visit(Name##ImmPtr v) {} AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, 
IMM_VISIT); #undef IMM_VISIT -void IRVisitor::visit(Cast* v) { +void IRVisitor::visit(CastPtr v) { v->src_value()->accept(this); } -void IRVisitor::visit(BitCast* v) { +void IRVisitor::visit(BitCastPtr v) { v->src_value()->accept(this); } -void IRVisitor::visit(Var* v) {} +void IRVisitor::visit(VarPtr v) {} -void IRVisitor::visit(Ramp* v) { +void IRVisitor::visit(RampPtr v) { v->base()->accept(this); v->stride()->accept(this); } -void IRVisitor::visit(Load* v) { +void IRVisitor::visit(LoadPtr v) { v->buf()->accept(this); - for (Expr* ind : v->indices()) { + for (ExprPtr ind : v->indices()) { ind->accept(this); } } -void IRVisitor::visit(Buf* v) { +void IRVisitor::visit(BufPtr v) { v->base_handle()->accept(this); } -void IRVisitor::visit(Store* v) { +void IRVisitor::visit(StorePtr v) { v->buf()->accept(this); - for (Expr* ind : v->indices()) { + for (ExprPtr ind : v->indices()) { ind->accept(this); } v->value()->accept(this); } -void IRVisitor::visit(AtomicAdd* v) { +void IRVisitor::visit(AtomicAddPtr v) { v->buf()->accept(this); - for (Expr* ind : v->indices()) { + for (ExprPtr ind : v->indices()) { ind->accept(this); } v->value()->accept(this); } -void IRVisitor::visit(SyncThreads* v) {} +void IRVisitor::visit(SyncThreadsPtr v) {} -void IRVisitor::visit(ExternalCall* v) { +void IRVisitor::visit(ExternalCallPtr v) { v->buf()->accept(this); - for (Buf* buf_arg : v->buf_args()) { + for (BufPtr buf_arg : v->buf_args()) { buf_arg->accept(this); } - for (Expr* arg : v->args()) { + for (ExprPtr arg : v->args()) { arg->accept(this); } } -void IRVisitor::visit(Block* v) { - for (Stmt* s : *v) { +void IRVisitor::visit(BlockPtr v) { + for (StmtPtr s : *v) { s->accept(this); } } -void IRVisitor::visit(For* v) { +void IRVisitor::visit(ForPtr v) { v->var()->accept(this); v->start()->accept(this); v->stop()->accept(this); @@ -145,43 +145,43 @@ void IRVisitor::visit(For* v) { } } -void IRVisitor::visit(Broadcast* v) { +void IRVisitor::visit(BroadcastPtr v) { v->value()->accept(this); } -void IRVisitor::visit(IfThenElse* v) { +void IRVisitor::visit(IfThenElsePtr v) { v->condition()->accept(this); v->true_value()->accept(this); v->false_value()->accept(this); } -void IRVisitor::visit(Intrinsics* v) { +void IRVisitor::visit(IntrinsicsPtr v) { for (const auto i : c10::irange(v->nparams())) { v->param(i)->accept(this); } } -void IRVisitor::visit(Allocate* v) { +void IRVisitor::visit(AllocatePtr v) { v->buffer_var()->accept(this); - std::vector dims = v->dims(); - for (Expr* dim : dims) { + std::vector dims = v->dims(); + for (ExprPtr dim : dims) { dim->accept(this); } } -void IRVisitor::visit(Free* v) { +void IRVisitor::visit(FreePtr v) { v->buffer_var()->accept(this); } -void IRVisitor::visit(Let* v) { +void IRVisitor::visit(LetPtr v) { v->var()->accept(this); v->value()->accept(this); } -void IRVisitor::visit(Cond* v) { - Expr* condition = v->condition(); - Stmt* true_stmt = v->true_stmt(); - Stmt* false_stmt = v->false_stmt(); +void IRVisitor::visit(CondPtr v) { + ExprPtr condition = v->condition(); + StmtPtr true_stmt = v->true_stmt(); + StmtPtr false_stmt = v->false_stmt(); condition->accept(this); if (true_stmt) { true_stmt->accept(this); @@ -191,47 +191,47 @@ void IRVisitor::visit(Cond* v) { } } -void IRVisitor::visit(Term* v) { +void IRVisitor::visit(TermPtr v) { v->scalar()->accept(this); - for (auto* t : v->variables()) { + for (auto t : v->variables()) { t->accept(this); } } -void IRVisitor::visit(Polynomial* v) { +void IRVisitor::visit(PolynomialPtr v) { v->scalar()->accept(this); - for (auto* t : 
v->variables()) { + for (auto t : v->variables()) { t->accept(this); } } -void IRVisitor::visit(RoundOff* v) { +void IRVisitor::visit(RoundOffPtr v) { v->lhs()->accept(this); v->rhs()->accept(this); } -void IRVisitor::visit(MaxTerm* v) { +void IRVisitor::visit(MaxTermPtr v) { if (v->scalar()) { v->scalar()->accept(this); } - for (auto* t : v->variables()) { + for (auto t : v->variables()) { t->accept(this); } } -void IRVisitor::visit(MinTerm* v) { +void IRVisitor::visit(MinTermPtr v) { if (v->scalar()) { v->scalar()->accept(this); } - for (auto* t : v->variables()) { + for (auto t : v->variables()) { t->accept(this); } } -void IRVisitor::visit(ReduceOp* v) { +void IRVisitor::visit(ReduceOpPtr v) { v->body()->accept(this); - for (auto* r : v->reduce_args()) { + for (auto r : v->reduce_args()) { r->accept(this); } } diff --git a/torch/csrc/jit/tensorexpr/ir_visitor.h b/torch/csrc/jit/tensorexpr/ir_visitor.h index 20616a6..001725f 100644 --- a/torch/csrc/jit/tensorexpr/ir_visitor.h +++ b/torch/csrc/jit/tensorexpr/ir_visitor.h @@ -1,103 +1,59 @@ #pragma once #include #include +#include namespace torch { namespace jit { namespace tensorexpr { -class Add; -class Sub; -class Mul; -class Div; -class Mod; -class Max; -class Min; -class And; -class Or; -class Xor; -class Lshift; -class Rshift; -class CompareSelect; - -#define IMM_DECLARE(Type, Name) class Name##Imm; - -AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_DECLARE) -#undef IMM_DECLARE - -class Cast; -class BitCast; -class Var; -class Buf; -class Ramp; -class Load; -class For; -class Block; -class Store; -class Broadcast; -class IfThenElse; -class Intrinsics; -class Allocate; -class Free; -class Let; -class Cond; -class Term; -class Polynomial; -class RoundOff; -class MaxTerm; -class MinTerm; -class ReduceOp; -class AtomicAdd; -class SyncThreads; -class ExternalCall; - class TORCH_API IRVisitor { public: virtual ~IRVisitor() = default; - virtual void visit(Add* v); - virtual void visit(Sub* v); - virtual void visit(Mul* v); - virtual void visit(Div* v); - virtual void visit(Mod* v); - virtual void visit(Max* v); - virtual void visit(Min* v); - virtual void visit(And* v); - virtual void visit(Or* v); - virtual void visit(Xor* v); - virtual void visit(Lshift* v); - virtual void visit(Rshift* v); - virtual void visit(CompareSelect* v); - -#define IMM_PRINT_VISIT(Type, Name) virtual void visit(Name##Imm* v); + virtual void visit(AddPtr v); + virtual void visit(SubPtr v); + virtual void visit(MulPtr v); + virtual void visit(DivPtr v); + virtual void visit(ModPtr v); + virtual void visit(MaxPtr v); + virtual void visit(MinPtr v); + virtual void visit(AndPtr v); + virtual void visit(OrPtr v); + virtual void visit(XorPtr v); + virtual void visit(LshiftPtr v); + virtual void visit(RshiftPtr v); + virtual void visit(CompareSelectPtr v); + +#define IMM_PRINT_VISIT(Type, Name) virtual void visit(Name##ImmPtr v); AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_PRINT_VISIT) #undef IMM_PRINT_VISIT - virtual void visit(Cast* v); - virtual void visit(BitCast* v); - virtual void visit(Var* v); - virtual void visit(Buf* v); - virtual void visit(Ramp* v); - virtual void visit(Load* v); - virtual void visit(For* v); - virtual void visit(Block* v); - virtual void visit(Store* v); - virtual void visit(Broadcast* v); - virtual void visit(IfThenElse* v); - virtual void visit(Intrinsics* v); - virtual void visit(Allocate* v); - virtual void visit(Free* v); - virtual void visit(Let* v); - virtual void visit(Cond* v); - virtual void visit(Term* v); - virtual void 
visit(Polynomial* v); - virtual void visit(RoundOff* v); - virtual void visit(MaxTerm* v); - virtual void visit(MinTerm* v); - virtual void visit(ReduceOp* v); - virtual void visit(AtomicAdd* v); - virtual void visit(SyncThreads* v); - virtual void visit(ExternalCall* v); + virtual void visit(CastPtr v); + virtual void visit(BitCastPtr v); + virtual void visit(VarPtr v); + virtual void visit(BufPtr v); + virtual void visit(RampPtr v); + virtual void visit(LoadPtr v); + virtual void visit(ForPtr v); + virtual void visit(BlockPtr v); + virtual void visit(StorePtr v); + virtual void visit(BroadcastPtr v); + virtual void visit(IfThenElsePtr v); + virtual void visit(IntrinsicsPtr v); + virtual void visit(AllocatePtr v); + virtual void visit(FreePtr v); + virtual void visit(LetPtr v); + virtual void visit(CondPtr v); + virtual void visit(TermPtr v); + virtual void visit(PolynomialPtr v); + virtual void visit(RoundOffPtr v); + virtual void visit(MaxTermPtr v); + virtual void visit(MinTermPtr v); + virtual void visit(ReduceOpPtr v); + virtual void visit(AtomicAddPtr v); + virtual void visit(SyncThreadsPtr v); + virtual void visit(ExternalCallPtr v); }; } // namespace tensorexpr diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp index c8fc745..faacd02 100644 --- a/torch/csrc/jit/tensorexpr/kernel.cpp +++ b/torch/csrc/jit/tensorexpr/kernel.cpp @@ -201,7 +201,7 @@ c10::optional getTensorInfoJit(torch::jit::Value* v) { c10::optional getTensorInfo(BufHandle b) { std::vector dims; for (auto dim : b.dims()) { - auto val = dynamic_cast(dim.node()); + auto val = to(dim.node()); if (!val) { return c10::nullopt; } @@ -503,20 +503,20 @@ ExprHandle demoteOutput( } // namespace jit } // namespace torch -static at::ScalarType tensorType(Buf* b) { +static at::ScalarType tensorType(BufPtr b) { return static_cast(b->dtype().scalar_type()); } -std::vector bufferSizes(Buf* b) { +std::vector bufferSizes(BufPtr b) { std::vector sizes; for (size_t i = 0; i < b->ndim(); i++) { - sizes.push_back(dynamic_cast(b->dim(i))->value()); + sizes.push_back(to(b->dim(i))->value()); } return sizes; } ExprHandle TensorExprKernel::chunk( - Buf* b, + BufPtr b, size_t chunkIdx, int64_t dim, int64_t chunks, @@ -1161,9 +1161,11 @@ Tensor* computeCatWoConditionals( // output[i,j+l2,k] = inp3[i,j,k] auto output_sizes_expr = ExprHandleVectorToExprVector(outputShape); - auto output_buf = new Buf("aten_cat", output_sizes_expr, ToDtype(high_type)); + auto output_buf = + alloc("aten_cat", output_sizes_expr, ToDtype(high_type)); if (non_empty_inputs.size() == 0) { - return new Tensor(output_buf, new tensorexpr::Block({})); + return new Tensor( + output_buf, alloc(std::vector({}))); } int64_t concat_dim = c10::get(arg_dim); @@ -1172,43 +1174,44 @@ Tensor* computeCatWoConditionals( auto gen_code_for_input = [&](const BufHandle& inp, size_t inp_pos, - Expr* concat_dim_size, + ExprPtr concat_dim_size, const std::vector& dims) { - std::vector for_vars(dims.size()); - std::vector load_indices(dims.size()); - std::vector store_indices(dims.size()); + std::vector for_vars(dims.size()); + std::vector load_indices(dims.size()); + std::vector store_indices(dims.size()); for (size_t i = 0; i < dims.size(); ++i) { - for_vars[i] = new Var( + for_vars[i] = alloc( "i" + c10::to_string(inp_pos) + "_" + c10::to_string(i), kInt); load_indices[i] = for_vars[i]; if (i == norm_concat_dim) { - store_indices[i] = new Add(for_vars[i], concat_dim_size); + store_indices[i] = alloc(for_vars[i], concat_dim_size); } else { 
store_indices[i] = for_vars[i]; } } auto inp_buf = inp.node(); - auto load_expr = new Load(inp_buf, load_indices); + auto load_expr = alloc(inp_buf, load_indices); auto load_promoted = promoteToDtype(ExprHandle(load_expr), high_type); - Stmt* st = new Store(output_buf, store_indices, load_promoted.node()); + StmtPtr st = alloc(output_buf, store_indices, load_promoted.node()); for (size_t i = dims.size(); i > 0; --i) { - st = new For(for_vars[i - 1], new IntImm(0), dims[i - 1].node(), st); + st = + alloc(for_vars[i - 1], alloc(0), dims[i - 1].node(), st); } return st; }; - Expr* concat_dim_size = nullptr; - auto block = new tensorexpr::Block({}); + ExprPtr concat_dim_size = nullptr; + auto block = alloc(std::vector({})); for (size_t i = 0; i < non_empty_inputs.size(); ++i) { auto input_dims = ExprVectorToExprHandleVector(non_empty_inputs[i].node()->dims()); if (concat_dim_size == nullptr) { - concat_dim_size = new IntImm(0); + concat_dim_size = alloc(0); } block->append_stmt(gen_code_for_input( non_empty_inputs[i], i, concat_dim_size, input_dims)); concat_dim_size = - new Add(concat_dim_size, input_dims[norm_concat_dim].node()); + alloc(concat_dim_size, input_dims[norm_concat_dim].node()); } return new Tensor(output_buf, IRSimplifier::simplify(block)); } @@ -1255,8 +1258,7 @@ Tensor* computeCat( std::vector newAxes(axes.begin(), axes.end()); ExprHandle load = promoteToDtype( tensorOrConstant(nonEmptyInputs[0], newAxes), highType); - size_t offset = - dynamic_cast(nonEmptyInputs[0].node()->dim(dim))->value(); + size_t offset = to(nonEmptyInputs[0].node()->dim(dim))->value(); newAxes[dim] = newAxes[dim] - IntImm::make(offset); for (size_t ii = 1; ii < nonEmptyInputs.size(); ++ii) { @@ -1266,7 +1268,7 @@ Tensor* computeCat( load, promoteToDtype(tensorOrConstant(input, newAxes), highType)); - offset += dynamic_cast(input.node()->dim(dim))->value(); + offset += to(input.node()->dim(dim))->value(); newAxes[dim] = axes[dim] - IntImm::make(offset); } @@ -1306,7 +1308,7 @@ Tensor* computeConv2d( // Once we have a performant TE representation for conv2d, we could use it // here instead of the external call! - Stmt* s = ExternalCall::make( + StmtPtr s = ExternalCall::make( ResultBuf, "nnc_aten_conv2d", {inp, w, b}, @@ -2315,9 +2317,9 @@ Tensor* tensorexpr::computeOperandValue( */ // NOLINTNEXTLINE(clang-diagnostic-unused-variable) ExprHandle cur_stride = 1; - std::vector dims, indices; + std::vector dims, indices; for (size_t idx = 0; idx < view_dims.size(); idx++) { - dims.push_back(new IntImm(view_dims[idx])); + dims.push_back(alloc(view_dims[idx])); indices.push_back(axes[idx].node()); } ExprHandle flat_idx = ExprHandle(flatten_index(dims, indices)); @@ -2429,7 +2431,7 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) { } // Return the (lower, upper) loop bounds if they are constants, else nullopt. -c10::optional> loopBounds(For* loop) { +c10::optional> loopBounds(ForPtr loop) { auto start = IRSimplifier::simplify(loop->start()); auto stop = IRSimplifier::simplify(loop->stop()); if (!start->isConstant() || !stop->isConstant()) { @@ -2440,7 +2442,7 @@ c10::optional> loopBounds(For* loop) { } // True if all the loops in this vector have equal bounds. 
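Before loopBoundsAllEqual below, a note on the flatten_index call used in the view lowering above: it computes the usual row-major flat offset from the per-dimension indices, which the simplifier then tidies up. Illustrative arithmetic only (flatten_index itself is defined elsewhere in tensorexpr):

  // flat(i, j, k; D0, D1, D2) = (i * D1 + j) * D2 + k
  // e.g. dims {2, 3, 4}, indices {1, 2, 3}: (1*3 + 2)*4 + 3 = 23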
-bool loopBoundsAllEqual(const std::vector& loops) { +bool loopBoundsAllEqual(const std::vector& loops) { auto bounds = loopBounds(loops[0]); if (!bounds) { return false; @@ -2462,11 +2464,11 @@ bool loopBoundsAllEqual(const std::vector& loops) { // on matching bounds exists to avoid inserting conditionals on the loop // indices where none would be needed, which would significantly complicate // vectorization. -void fuseAllLoops(Stmt* st) { - if (auto block = dynamic_cast(st)) { - std::vector loopsToFuse; +void fuseAllLoops(StmtPtr st) { + if (auto block = to(st)) { + std::vector loopsToFuse; for (auto stmt : *block) { - auto loop = dynamic_cast(stmt); + auto loop = to(stmt); if (!loop) { // Block contains something that's not a loop. Quit. return; @@ -2477,7 +2479,7 @@ void fuseAllLoops(Stmt* st) { return; } // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* fusedLoop; + ForPtr fusedLoop; if (!LoopNest::fuseLoops(loopsToFuse, &fusedLoop)) { return; } @@ -2485,7 +2487,7 @@ void fuseAllLoops(Stmt* st) { } } -Stmt* TensorExprKernel::transformLoops(BackendType backendType, Stmt* st) { +StmtPtr TensorExprKernel::transformLoops(BackendType backendType, StmtPtr st) { torch::jit::tensorexpr::LoopNest l(st, bufOutputs_); GRAPH_DEBUG("Original Stmt:\n", std::to_string(l.root_stmt()), "\n"); @@ -2530,12 +2532,12 @@ Stmt* TensorExprKernel::transformLoops(BackendType backendType, Stmt* st) { if (backendType == kCudaCodeGen) { for (auto buf : bufOutputs_) { - std::vector loops = l.getLoopStmtsFor(buf); + std::vector loops = l.getLoopStmtsFor(buf); if (loops.empty()) { // This happens when Buf is 0-dim continue; } - For* flattened = nullptr; + ForPtr flattened = nullptr; LoopNest::flatten(loops, &flattened); assert(flattened); @@ -2547,7 +2549,7 @@ Stmt* TensorExprKernel::transformLoops(BackendType backendType, Stmt* st) { if (loopLevels == 2) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* inner; + ForPtr inner; const int kDefaultBlockSize = 512; if (blockSize < 0) { blockSize = kDefaultBlockSize; @@ -2557,9 +2559,9 @@ Stmt* TensorExprKernel::transformLoops(BackendType backendType, Stmt* st) { inner->set_gpu_thread_index(0); } else if (loopLevels == 3) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* inner; + ForPtr inner; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* inner1; + ForPtr inner1; // TODO: change the number of microprocessors const int kDefaultBlockCount = 1280; const int kDefaultBlockSize = 256; @@ -2585,13 +2587,13 @@ Stmt* TensorExprKernel::transformLoops(BackendType backendType, Stmt* st) { if (buf->dtype().scalar_type() == ScalarType::Byte) { blockSize = default_uint8_blocksize; } - std::vector loops = l.getLoopStmtsFor(buf); + std::vector loops = l.getLoopStmtsFor(buf); TORCH_INTERNAL_ASSERT(!loops.empty(), "loops should not be empty"); - For* flattened = nullptr; + ForPtr flattened = nullptr; LoopNest::flatten(loops, &flattened); assert(flattened); - For* inner = nullptr; + ForPtr inner = nullptr; LoopNest::splitWithMask(flattened, blockSize, &inner); flattened->set_gpu_block_index(0); inner->set_gpu_thread_index(0); @@ -2605,7 +2607,7 @@ Stmt* TensorExprKernel::transformLoops(BackendType backendType, Stmt* st) { l.vectorizeInnerLoops(); } - Stmt* stmt = l.root_stmt(); + StmtPtr stmt = l.root_stmt(); // Arithmetic Simplification. 
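Gathering the CUDA path above into one place (a sketch, not verbatim code from the patch): each output buffer's loop nest is flattened into a single loop, split by the block size, and the resulting axes are bound to GPU launch dimensions; with the CUDA backend, axis 0 corresponds to blockIdx.x and threadIdx.x respectively:

  std::vector<ForPtr> loops = l.getLoopStmtsFor(buf);
  ForPtr flattened = nullptr;
  LoopNest::flatten(loops, &flattened);          // collapse the nest into one loop
  ForPtr inner = nullptr;
  LoopNest::splitWithMask(flattened, 512, &inner);
  flattened->set_gpu_block_index(0);             // outer loop -> blockIdx.x
  inner->set_gpu_thread_index(0);                // inner loop -> threadIdx.x

The arithmetic simplification pass that follows then cleans up the index math introduced by the flatten and split.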
stmt = IRSimplifier::simplify(stmt); GRAPH_DEBUG("Final Stmt:\n", std::to_string(stmt), "\n"); @@ -2801,7 +2803,7 @@ bool denseAndNonOverlapping( Tensor* TensorExprKernel::convertOutputToCorrectStrides(torch::jit::Value* v) { const TensorTypePtr& tt = v->type()->expect(); TORCH_INTERNAL_ASSERT(bufs_.count(v)); - Buf* buf = bufs_.at(v); + BufPtr buf = bufs_.at(v); // No shape info is present in the graph if (!tt->sizes().concrete_sizes()) { @@ -2892,7 +2894,7 @@ void TensorExprKernel::bindConstant(const torch::jit::Value* v) { te_sizes.push_back(IntImm::make(s)); } - Buf* buf = new Buf( + BufPtr buf = alloc( "const_" + v->debugName(), ExprHandleVectorToExprVector(te_sizes), ToDtype(static_cast(*tt->scalarType()))); @@ -2914,7 +2916,7 @@ void TensorExprKernel::compile() { OptimizeCat(graph_); // Block to collect the Stmts corresponding to all tensors. - auto block = new Block({}); + auto block = alloc(std::vector({})); // Bind inputs to buffers. nInputs_ = graph_->inputs().size(); @@ -2987,7 +2989,7 @@ void TensorExprKernel::compile() { } BackendType backendType = inferBackendTypeFromDevice(device_); - Stmt* stmt = transformLoops(backendType, block); + StmtPtr stmt = transformLoops(backendType, block); // Generate code. codegen_ = CreateCodeGen( @@ -3073,7 +3075,7 @@ std::vector TensorExprKernel::prepareRunArgs( return runArgs; } -Stmt* TensorExprKernel::getCodeGenStmt() { +StmtPtr TensorExprKernel::getCodeGenStmt() { return codegen_->stmt(); } diff --git a/torch/csrc/jit/tensorexpr/kernel.h b/torch/csrc/jit/tensorexpr/kernel.h index d6cc1e7..7b35e1e 100644 --- a/torch/csrc/jit/tensorexpr/kernel.h +++ b/torch/csrc/jit/tensorexpr/kernel.h @@ -19,7 +19,7 @@ template inline std::vector bufferSizes(const T& t) { std::vector sizes; for (size_t i = 0; i < t->ndim(); i++) { - sizes.push_back(dynamic_cast(t->dim(i))->value()); + sizes.push_back(to(t->dim(i))->value()); } return sizes; } @@ -132,7 +132,7 @@ TORCH_API Tensor* computeOperandValue( class TORCH_API TensorExprKernel { struct ConstantDescr { - Buf* buf; + BufPtr buf; void* ptr; }; @@ -151,7 +151,7 @@ class TORCH_API TensorExprKernel { InterpreterState(code_).run(stack); } - Stmt* getCodeGenStmt(); + StmtPtr getCodeGenStmt(); std::string getCodeText(const std::string& attr = "") { return codegen_->getCodeText(attr); @@ -196,7 +196,7 @@ class TORCH_API TensorExprKernel { std::vector> shapes); ExprHandle chunk( - Buf* b, + BufPtr b, size_t chunkIdx, int64_t dim, int64_t chunks, @@ -213,7 +213,7 @@ class TORCH_API TensorExprKernel { void bindConstant(const torch::jit::Value* v); - Stmt* transformLoops(BackendType backendType, Stmt* st); + StmtPtr transformLoops(BackendType backendType, StmtPtr st); std::string getCodeGenName(BackendType backendType); @@ -260,8 +260,8 @@ class TORCH_API TensorExprKernel { std::vector> tensorOutputSizes_; std::vector> tensorOutputStrides_; std::vector tensorOutputTensorOptions_; - std::unordered_set bufOutputs_; - std::unordered_map bufs_; + std::unordered_set bufOutputs_; + std::unordered_map bufs_; std::unordered_map scalars_; std::unordered_map input_name_map_; std::unique_ptr codegen_; diff --git a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp index 7b9929c..eac1f82 100644 --- a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp +++ b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp @@ -167,10 +167,10 @@ class LLVMCodeGenImpl : public IRVisitor { llvm::Type* Int8PtrTy_; llvm::Type* VoidTy_; - std::unordered_map varToArg_; - std::unordered_map varToVal_; - std::unordered_map> 
scopeToVar_; - Block* scope_; + std::unordered_map varToArg_; + std::unordered_map varToVal_; + std::unordered_map> scopeToVar_; + BlockPtr scope_; std::string llvmCode_; std::string asmCode_; @@ -180,7 +180,7 @@ class LLVMCodeGenImpl : public IRVisitor { llvm::Type* dtypeToLLVM(Dtype dtype); llvm::Type* dtypeToLLVMPtr(Dtype dtype); void emitWrapper(const std::vector& params); - void emitKernel(Stmt* stmt, const std::vector& params); + void emitKernel(StmtPtr stmt, const std::vector& params); llvm::Value* toVec(llvm::Value* v, int lanes); enum Arity { @@ -195,17 +195,17 @@ class LLVMCodeGenImpl : public IRVisitor { Arity arity, int lanes); - llvm::Value* varToValue(Var* var); + llvm::Value* varToValue(VarPtr var); void replaceVarMapping( - const std::vector& vars, + const std::vector& vars, const std::vector& vals); llvm::Value* packFuncArgs(const std::vector& func_args); std::vector unpackFuncArgs(llvm::Value* packed, int arg_count); - void processParallelFor(For* v); + void processParallelFor(ForPtr v); public: LLVMCodeGenImpl( - Stmt* stmt, + StmtPtr stmt, const std::vector& args, at::Device device, Dtype dtype, @@ -216,42 +216,42 @@ class LLVMCodeGenImpl : public IRVisitor { llvm::JITTargetAddress getKernelAddress() const; - void visit(Add* v) override; - void visit(Sub* v) override; - void visit(Mul* v) override; - void visit(Div* v) override; - void visit(Mod* v) override; - void visit(Max* v) override; - void visit(Min* v) override; - void visit(And* v) override; - void visit(Or* v) override; - void visit(Xor* v) override; - void visit(Lshift* v) override; - void visit(Rshift* v) override; - void visit(CompareSelect* v) override; - -#define IMM_VISIT_DECLARE(_1, Name) void visit(Name##Imm* v) override; + void visit(AddPtr v) override; + void visit(SubPtr v) override; + void visit(MulPtr v) override; + void visit(DivPtr v) override; + void visit(ModPtr v) override; + void visit(MaxPtr v) override; + void visit(MinPtr v) override; + void visit(AndPtr v) override; + void visit(OrPtr v) override; + void visit(XorPtr v) override; + void visit(LshiftPtr v) override; + void visit(RshiftPtr v) override; + void visit(CompareSelectPtr v) override; + +#define IMM_VISIT_DECLARE(_1, Name) void visit(Name##ImmPtr v) override; AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_VISIT_DECLARE); #undef IMM_VISIT_DECLARE - void visit(Cast* v) override; - void visit(BitCast* v) override; - void visit(Var* v) override; - void visit(Ramp* v) override; - void visit(Load* v) override; - void visit(For* v) override; - void visit(Block* v) override; - void visit(Store* v) override; - void visit(Broadcast* v) override; - void visit(IfThenElse* v) override; - void visit(Intrinsics* v) override; - void visit(Allocate* v) override; - void visit(Free* v) override; - void visit(Let* v) override; - void visit(Cond* v) override; - void visit(ExternalCall* v) override; - - void emitIsNan(Intrinsics* v); + void visit(CastPtr v) override; + void visit(BitCastPtr v) override; + void visit(VarPtr v) override; + void visit(RampPtr v) override; + void visit(LoadPtr v) override; + void visit(ForPtr v) override; + void visit(BlockPtr v) override; + void visit(StorePtr v) override; + void visit(BroadcastPtr v) override; + void visit(IfThenElsePtr v) override; + void visit(IntrinsicsPtr v) override; + void visit(AllocatePtr v) override; + void visit(FreePtr v) override; + void visit(LetPtr v) override; + void visit(CondPtr v) override; + void visit(ExternalCallPtr v) override; + + void emitIsNan(IntrinsicsPtr v); llvm::Value* 
emitUnmaskedLoad(llvm::Value* addr, llvm::Value* idx); llvm::Value* emitMaskedLoad( @@ -291,11 +291,11 @@ void DispatchParallel(int8_t* func, int start, int stop, int8_t* packed_data) { LLVMCodeGen::~LLVMCodeGen() = default; -LLVMCodeGen::LLVMCodeGen(Stmt* stmt) +LLVMCodeGen::LLVMCodeGen(StmtPtr stmt) : LLVMCodeGen(stmt, std::vector()) {} LLVMCodeGen::LLVMCodeGen( - Stmt* stmt, + StmtPtr stmt, const std::vector& args, at::Device device, const std::string& kernel_func_name, @@ -362,7 +362,7 @@ static std::mutex llvmInitMutex; } // namespace LLVMCodeGenImpl::LLVMCodeGenImpl( - Stmt* stmt, + StmtPtr stmt, const std::vector& args, at::Device device, Dtype dtype, @@ -484,7 +484,7 @@ void LLVMCodeGenImpl::emitWrapper(const std::vector& params) { class LLVMIntrinsicsExpander : public GenericIntrinsicsExpander { private: - Expr* mutate(Intrinsics* v) { + ExprPtr mutate(IntrinsicsPtr v) { if (v->op_type() == kTanh) { ScalarType stype = v->dtype().scalar_type(); if (stype == ScalarType::Float) { @@ -504,7 +504,7 @@ class LLVMIntrinsicsExpander : public GenericIntrinsicsExpander { }; void LLVMCodeGenImpl::emitKernel( - Stmt* stmt, + StmtPtr stmt, const std::vector& params) { // Set insert point to the real function. bb_ = llvm::BasicBlock::Create(getContext(), "entry", fn_); @@ -570,7 +570,7 @@ void LLVMCodeGenImpl::emitKernel( // TODO: The binary ops are copypasta. -void LLVMCodeGenImpl::visit(Add* v) { +void LLVMCodeGenImpl::visit(AddPtr v) { v->lhs()->accept(this); auto lhs = this->value_; bool lfp = lhs->getType()->isFPOrFPVectorTy(); @@ -588,7 +588,7 @@ void LLVMCodeGenImpl::visit(Add* v) { } } -void LLVMCodeGenImpl::visit(Sub* v) { +void LLVMCodeGenImpl::visit(SubPtr v) { v->lhs()->accept(this); auto lhs = this->value_; bool lfp = lhs->getType()->isFPOrFPVectorTy(); @@ -606,7 +606,7 @@ void LLVMCodeGenImpl::visit(Sub* v) { } } -void LLVMCodeGenImpl::visit(Mul* v) { +void LLVMCodeGenImpl::visit(MulPtr v) { v->lhs()->accept(this); auto lhs = this->value_; bool lfp = lhs->getType()->isFPOrFPVectorTy(); @@ -624,7 +624,7 @@ void LLVMCodeGenImpl::visit(Mul* v) { } } -void LLVMCodeGenImpl::visit(Div* v) { +void LLVMCodeGenImpl::visit(DivPtr v) { v->lhs()->accept(this); auto lhs = this->value_; bool lfp = lhs->getType()->isFPOrFPVectorTy(); @@ -642,7 +642,7 @@ void LLVMCodeGenImpl::visit(Div* v) { } } -void LLVMCodeGenImpl::visit(And* v) { +void LLVMCodeGenImpl::visit(AndPtr v) { v->lhs()->accept(this); auto lhs = this->value_; bool lfp = lhs->getType()->isFPOrFPVectorTy(); @@ -657,7 +657,7 @@ void LLVMCodeGenImpl::visit(And* v) { } } -void LLVMCodeGenImpl::visit(Or* v) { +void LLVMCodeGenImpl::visit(OrPtr v) { v->lhs()->accept(this); auto lhs = this->value_; bool lfp = lhs->getType()->isFPOrFPVectorTy(); @@ -672,7 +672,7 @@ void LLVMCodeGenImpl::visit(Or* v) { } } -void LLVMCodeGenImpl::visit(Xor* v) { +void LLVMCodeGenImpl::visit(XorPtr v) { v->lhs()->accept(this); auto lhs = this->value_; bool lfp = lhs->getType()->isFPOrFPVectorTy(); @@ -687,7 +687,7 @@ void LLVMCodeGenImpl::visit(Xor* v) { } } -void LLVMCodeGenImpl::visit(Lshift* v) { +void LLVMCodeGenImpl::visit(LshiftPtr v) { v->lhs()->accept(this); auto lhs = this->value_; bool lfp = lhs->getType()->isFPOrFPVectorTy(); @@ -702,7 +702,7 @@ void LLVMCodeGenImpl::visit(Lshift* v) { } } -void LLVMCodeGenImpl::visit(Rshift* v) { +void LLVMCodeGenImpl::visit(RshiftPtr v) { v->lhs()->accept(this); auto lhs = this->value_; bool lfp = lhs->getType()->isFPOrFPVectorTy(); @@ -721,7 +721,7 @@ void LLVMCodeGenImpl::visit(Rshift* v) { } } -void 
LLVMCodeGenImpl::visit(Mod* v) { +void LLVMCodeGenImpl::visit(ModPtr v) { v->lhs()->accept(this); auto lhs = this->value_; bool lfp = lhs->getType()->isFPOrFPVectorTy(); @@ -736,7 +736,7 @@ void LLVMCodeGenImpl::visit(Mod* v) { } } -void LLVMCodeGenImpl::visit(Max* v) { +void LLVMCodeGenImpl::visit(MaxPtr v) { v->lhs()->accept(this); auto lhs = this->value_; v->rhs()->accept(this); @@ -759,7 +759,7 @@ void LLVMCodeGenImpl::visit(Max* v) { irb_.CreateFCmp(llvm::FCmpInst::FCMP_OGT, lhs, rhs), lhs, rhs)); } -void LLVMCodeGenImpl::visit(Min* v) { +void LLVMCodeGenImpl::visit(MinPtr v) { v->lhs()->accept(this); auto lhs = this->value_; v->rhs()->accept(this); @@ -781,7 +781,7 @@ void LLVMCodeGenImpl::visit(Min* v) { irb_.CreateFCmp(llvm::FCmpInst::FCMP_OLT, lhs, rhs), lhs, rhs)); } -void LLVMCodeGenImpl::visit(CompareSelect* v) { +void LLVMCodeGenImpl::visit(CompareSelectPtr v) { auto genUnbiased = [this, v]() -> llvm::Value* { v->lhs()->accept(this); auto lhs = this->value_; @@ -884,17 +884,17 @@ getFromType(llvm::Type* type, T value) { } #define IMM_VISIT_DECLARE(Type, Name) \ - void LLVMCodeGenImpl::visit(Name##Imm* v) { \ + void LLVMCodeGenImpl::visit(Name##ImmPtr v) { \ value_ = getFromType(Name##Ty_, v->value()); \ } AT_FORALL_SCALAR_TYPES(IMM_VISIT_DECLARE); #undef IMM_VISIT_DECLARE -void LLVMCodeGenImpl::visit(HalfImm* v) { +void LLVMCodeGenImpl::visit(HalfImmPtr v) { value_ = llvm::ConstantFP::get(HalfTy_, v->value()); } -void LLVMCodeGenImpl::visit(BoolImm* v) { +void LLVMCodeGenImpl::visit(BoolImmPtr v) { value_ = llvm::ConstantInt::get(BoolTy_, v->value()); } @@ -906,7 +906,7 @@ llvm::Type* llvmTypeToVec(llvm::Type* type, int lanes) { } } -void LLVMCodeGenImpl::visit(Cast* v) { +void LLVMCodeGenImpl::visit(CastPtr v) { v->src_value()->accept(this); llvm::Type* dstType = @@ -978,7 +978,7 @@ void LLVMCodeGenImpl::visit(Cast* v) { } } -void LLVMCodeGenImpl::visit(BitCast* v) { +void LLVMCodeGenImpl::visit(BitCastPtr v) { v->src_value()->accept(this); llvm::Type* dstType = dtypeToLLVM(v->dtype()); @@ -997,11 +997,11 @@ void LLVMCodeGenImpl::visit(BitCast* v) { value_ = irb_.CreateBitOrPointerCast(value_, dstType); } -void LLVMCodeGenImpl::visit(Var* v) { +void LLVMCodeGenImpl::visit(VarPtr v) { value_ = varToValue(v); } -llvm::Value* LLVMCodeGenImpl::varToValue(Var* v) { +llvm::Value* LLVMCodeGenImpl::varToValue(VarPtr v) { // It is possible for v to be in both varToVal_ and varToArgs. // In that case, varToVal_ takes precedence. 
if (varToVal_.count(v)) { @@ -1015,11 +1015,11 @@ llvm::Value* LLVMCodeGenImpl::varToValue(Var* v) { } void LLVMCodeGenImpl::replaceVarMapping( - const std::vector& vars, + const std::vector& vars, const std::vector& vals) { TORCH_CHECK(vars.size() == vals.size()); for (const auto i : c10::irange(vars.size())) { - Var* var = vars[i]; + VarPtr var = vars[i]; llvm::Value* val = vals[i]; if (val) { varToVal_[var] = val; @@ -1029,7 +1029,7 @@ void LLVMCodeGenImpl::replaceVarMapping( } } -void LLVMCodeGenImpl::visit(Ramp* v) { +void LLVMCodeGenImpl::visit(RampPtr v) { v->base()->accept(this); auto base = this->value_; v->stride()->accept(this); @@ -1105,7 +1105,7 @@ llvm::Value* LLVMCodeGenImpl::emitMaskedLoad( return phi; } -void LLVMCodeGenImpl::visit(Load* v) { +void LLVMCodeGenImpl::visit(LoadPtr v) { if (v->dtype().lanes() == 1) { v->base_handle()->accept(this); auto base = this->value_; @@ -1134,9 +1134,9 @@ void LLVMCodeGenImpl::visit(Load* v) { bool unmasked_load = true; // Handle the case where the load is contiguous and unmasked efficiently - auto* idx_ramp = dynamic_cast(v->flat_index()); + auto idx_ramp = to(v->flat_index()); if (idx_ramp) { - auto* stride_imm = dynamic_cast(idx_ramp->stride()); + auto stride_imm = to(idx_ramp->stride()); if (stride_imm && stride_imm->value() == 1) { v->base_handle()->accept(this); auto base = this->value_; @@ -1215,7 +1215,7 @@ std::vector LLVMCodeGenImpl::unpackFuncArgs( // * Move the body into its own closure. // * Identify var across the boundary into arguments and forward them. // * Send the closure and range to the dispatcher for execution. -void LLVMCodeGenImpl::processParallelFor(For* v) { +void LLVMCodeGenImpl::processParallelFor(ForPtr v) { // Create "start" and "stop" values. v->start()->accept(this); auto start = this->value_; @@ -1223,13 +1223,13 @@ void LLVMCodeGenImpl::processParallelFor(For* v) { auto stop = this->value_; // The Vars that need to be forward in the body closure. - std::vector body_arg_vars; + std::vector body_arg_vars; // Corresponding Value* that was used in the old body for the caller. std::vector body_caller_vals; // Corresponding Value* that will be used in the new body closure. std::vector body_closure_args; - // Identify the Var* used in the body, and generated outside. + // Identify the VarPtr used in the body, and generated outside. 
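The comments above outline what processParallelFor does: free variables of the loop body are collected, packed into a single argument buffer, and the outlined body closure is handed to DispatchParallel together with the [start, stop) range. The standalone sketch below models only that calling convention; the typed function pointer and the PackedArgs struct are simplifications for readability, whereas the dispatcher declared earlier in this file traffics in raw int8_t* pointers.

#include <cstdio>

// Stand-in for the runtime dispatcher: runs the closure once per index.
// A real implementation would distribute the range across worker threads.
void DispatchParallelSketch(void (*func)(int, void*),
                            int start, int stop, void* packed) {
  for (int i = start; i < stop; ++i) {
    func(i, packed);
  }
}

// Variables defined outside the loop body that the body still needs.
struct PackedArgs {
  float* out;
  float scale;
};

// The outlined loop body: receives the induction variable and the packed args.
void loop_body(int i, void* raw) {
  auto* args = static_cast<PackedArgs*>(raw);
  args->out[i] = args->scale * static_cast<float>(i);
}

int main() {
  float out[8] = {};
  PackedArgs args{out, 2.0f};
  DispatchParallelSketch(loop_body, 0, 8, &args);
  std::printf("%f\n", out[7]);  // 14.0
}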
VarFinder var_finder; v->body()->accept(&var_finder); auto& vars = var_finder.vars(); @@ -1292,7 +1292,7 @@ void LLVMCodeGenImpl::processParallelFor(For* v) { value_ = llvm::ConstantInt::get(IntTy_, 0); } -void LLVMCodeGenImpl::visit(For* v) { +void LLVMCodeGenImpl::visit(ForPtr v) { if (v->is_parallel()) { processParallelFor(v); return; @@ -1347,11 +1347,11 @@ void LLVMCodeGenImpl::visit(For* v) { value_ = llvm::ConstantInt::get(IntTy_, 0); } -void LLVMCodeGenImpl::visit(Block* v) { - Block* last = scope_; +void LLVMCodeGenImpl::visit(BlockPtr v) { + BlockPtr last = scope_; scope_ = v; - for (Stmt* s : *v) { + for (StmtPtr s : *v) { s->accept(this); } @@ -1359,7 +1359,7 @@ void LLVMCodeGenImpl::visit(Block* v) { auto it = scopeToVar_.find(v); if (it != scopeToVar_.end()) { - for (Var* e : it->second) { + for (VarPtr e : it->second) { if (varToVal_.erase(e) != 1) { throw std::runtime_error("erasing var that doesn't exist"); } @@ -1398,7 +1398,7 @@ void LLVMCodeGenImpl::emitMaskedStore( irb_.SetInsertPoint(tailblock); } -void LLVMCodeGenImpl::visit(Store* v) { +void LLVMCodeGenImpl::visit(StorePtr v) { if (v->value()->dtype().lanes() == 1) { v->base_handle()->accept(this); auto base = this->value_; @@ -1419,9 +1419,9 @@ void LLVMCodeGenImpl::visit(Store* v) { auto val = this->value_; // Handle the case where the store is contiguous and unmasked efficiently - auto* idx_ramp = dynamic_cast(v->flat_index()); + auto idx_ramp = to(v->flat_index()); if (idx_ramp) { - auto* stride_imm = dynamic_cast(idx_ramp->stride()); + auto stride_imm = to(idx_ramp->stride()); if (stride_imm && stride_imm->value() == 1) { idx_ramp->base()->accept(this); auto first_idx = value_; @@ -1453,13 +1453,13 @@ void LLVMCodeGenImpl::visit(Store* v) { value_ = llvm::ConstantInt::get(IntTy_, 0); } -void LLVMCodeGenImpl::visit(Broadcast* v) { +void LLVMCodeGenImpl::visit(BroadcastPtr v) { v->value()->accept(this); int lanes = v->lanes(); value_ = irb_.CreateVectorSplat(lanes, value_); } -void LLVMCodeGenImpl::visit(IfThenElse* v) { +void LLVMCodeGenImpl::visit(IfThenElsePtr v) { v->condition()->accept(this); llvm::Value* condition = value_; llvm::Value* c = irb_.CreateICmpNE( @@ -1509,7 +1509,7 @@ llvm::Value* LLVMCodeGenImpl::toVec(llvm::Value* v, int lanes) { } } -void LLVMCodeGenImpl::emitIsNan(Intrinsics* v) { +void LLVMCodeGenImpl::emitIsNan(IntrinsicsPtr v) { v->param(0)->accept(this); llvm::Type* dstType = dtypeToLLVM(v->dtype()); if (!v->param(0)->dtype().is_floating_point()) { @@ -1583,7 +1583,7 @@ LLVMCodeGenImpl::SimdCallee LLVMCodeGenImpl::getSimdFunction( return SimdCallee{callee.getFunctionType(), callee.getCallee(), useSimd}; } -void LLVMCodeGenImpl::visit(Intrinsics* v) { +void LLVMCodeGenImpl::visit(IntrinsicsPtr v) { llvm::FunctionType* call_ty = nullptr; llvm::Value* call_fn = nullptr; bool call_simd_sleef = false; @@ -1772,7 +1772,7 @@ void LLVMCodeGenImpl::visit(Intrinsics* v) { } } -void LLVMCodeGenImpl::visit(ExternalCall* v) { +void LLVMCodeGenImpl::visit(ExternalCallPtr v) { constexpr int max_buffers = 10; constexpr int max_dimensions = 40; @@ -1783,7 +1783,7 @@ void LLVMCodeGenImpl::visit(ExternalCall* v) { // Prepare a vector of bufs that we need to pass to the external function. // This vector is the output buf followed by the buf_args. 
- std::vector bufs(v->buf_args()); + std::vector bufs(v->buf_args()); bufs.insert(bufs.begin(), v->buf()); int64_t bufs_num = bufs.size(); @@ -1792,7 +1792,7 @@ void LLVMCodeGenImpl::visit(ExternalCall* v) { // Count the size of dims array - it consists of dimension of all bufs // concatenated together. int64_t dims_num = 0; - for (Buf* b : bufs) { + for (BufPtr b : bufs) { dims_num += b->dims().size(); } @@ -1809,7 +1809,7 @@ void LLVMCodeGenImpl::visit(ExternalCall* v) { int i = 0; int dim_idx = 0; - for (Buf* b : bufs) { + for (BufPtr b : bufs) { // Store value for buf pointer auto gep = irb_.CreateInBoundsGEP( buf_ptrs, {llvm::ConstantInt::getSigned(IntTy_, i)}); @@ -1845,7 +1845,7 @@ void LLVMCodeGenImpl::visit(ExternalCall* v) { } i = 0; - for (Expr* arg : v->args()) { + for (ExprPtr arg : v->args()) { auto gep = irb_.CreateInBoundsGEP( extra_args, {llvm::ConstantInt::getSigned(IntTy_, i)}); arg->accept(this); @@ -1886,10 +1886,10 @@ void LLVMCodeGenImpl::visit(ExternalCall* v) { value_ = llvm::ConstantInt::get(IntTy_, 0); } -void LLVMCodeGenImpl::visit(Allocate* v) { +void LLVMCodeGenImpl::visit(AllocatePtr v) { llvm::Value* size = llvm::ConstantInt::getSigned(LongTy_, v->dtype().byte_size()); - for (Expr* e : v->dims()) { + for (ExprPtr e : v->dims()) { e->accept(this); size = irb_.CreateMul(size, irb_.CreateZExt(value_, LongTy_)); } @@ -1918,7 +1918,7 @@ void LLVMCodeGenImpl::visit(Allocate* v) { varToVal_[v->buffer_var()] = malloc; } -void LLVMCodeGenImpl::visit(Free* v) { +void LLVMCodeGenImpl::visit(FreePtr v) { value_ = llvm::ConstantInt::get(IntTy_, 0); llvm::Value* ptr = varToVal_.at(v->buffer_var()); if (!llvm::isa(ptr)) { @@ -1926,7 +1926,7 @@ void LLVMCodeGenImpl::visit(Free* v) { } } -void LLVMCodeGenImpl::visit(Let* v) { +void LLVMCodeGenImpl::visit(LetPtr v) { v->value()->accept(this); if (!varToVal_.count(v->var())) { varToVal_.emplace(v->var(), value_); @@ -1936,7 +1936,7 @@ void LLVMCodeGenImpl::visit(Let* v) { } } -void LLVMCodeGenImpl::visit(Cond* v) { +void LLVMCodeGenImpl::visit(CondPtr v) { // Even if true_stmt and false_stmt are nullptr, // in case condition is a function call with side effect, // we still evaluate it. 
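Taken together, the substitutions in this file follow one pattern: each IR class X gains an XPtr alias, object creation goes through alloc(...), and downcasts go through to(...). The actual definitions of these helpers are not part of these hunks, so the snippet below is only a minimal self-contained sketch of such a layer; the stand-in Node/Add classes and the choice of std::shared_ptr are assumptions made for illustration, not the TensorExpr definitions.

#include <iostream>
#include <memory>
#include <utility>

struct Node {
  virtual ~Node() = default;
};
struct Add : Node {
  Add(int lhs, int rhs) : lhs(lhs), rhs(rhs) {}
  int lhs, rhs;
};

// One alias per IR class, mirroring AddPtr / ForPtr / StmtPtr in the patch.
using NodePtr = std::shared_ptr<Node>;
using AddPtr = std::shared_ptr<Add>;

// alloc<T>(...) stands in for `new T(...)`.
template <class T, class... Args>
std::shared_ptr<T> alloc(Args&&... args) {
  return std::make_shared<T>(std::forward<Args>(args)...);
}

// to<T>(p) stands in for `dynamic_cast<T*>(p)`.
template <class T, class U>
std::shared_ptr<T> to(const std::shared_ptr<U>& p) {
  return std::dynamic_pointer_cast<T>(p);
}

int main() {
  NodePtr n = alloc<Add>(1, 2);  // was: Node* n = new Add(1, 2);
  if (AddPtr a = to<Add>(n)) {   // was: if (Add* a = dynamic_cast<Add*>(n))
    std::cout << a->lhs + a->rhs << "\n";
  }
}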
diff --git a/torch/csrc/jit/tensorexpr/llvm_codegen.h b/torch/csrc/jit/tensorexpr/llvm_codegen.h index eeec032..b33aeff 100644 --- a/torch/csrc/jit/tensorexpr/llvm_codegen.h +++ b/torch/csrc/jit/tensorexpr/llvm_codegen.h @@ -21,7 +21,7 @@ class LLVMCodeGenImpl; class TORCH_API LLVMCodeGen : public CodeGen { public: explicit LLVMCodeGen( - Stmt* stmt, + StmtPtr stmt, const std::vector& args, at::Device device = at::kCPU, const std::string& kernel_func_name = "func", @@ -29,7 +29,7 @@ class TORCH_API LLVMCodeGen : public CodeGen { c10::optional triple = c10::nullopt, c10::optional cpu = c10::nullopt, c10::optional attrs = c10::nullopt); - explicit LLVMCodeGen(Stmt* stmt); + explicit LLVMCodeGen(StmtPtr stmt); LLVMCodeGen() = delete; ~LLVMCodeGen() override; @@ -73,7 +73,7 @@ class TORCH_API LLVMCodeGen : public CodeGen { struct TORCH_API LLVMCodeGenBuilder { using BufferArg = CodeGen::BufferArg; - LLVMCodeGenBuilder(Stmt* stmt, std::vector args) + LLVMCodeGenBuilder(StmtPtr stmt, std::vector args) : stmt_(stmt), args_(std::move(args)) {} LLVMCodeGenBuilder& device(at::Device device) { @@ -112,7 +112,7 @@ struct TORCH_API LLVMCodeGenBuilder { } private: - Stmt* stmt_; + StmtPtr stmt_; std::vector args_; at::Device device_ = at::kCPU; std::string kernelFuncName_ = "func"; diff --git a/torch/csrc/jit/tensorexpr/loopnest.cpp b/torch/csrc/jit/tensorexpr/loopnest.cpp index 407cf90..ea6f093 100644 --- a/torch/csrc/jit/tensorexpr/loopnest.cpp +++ b/torch/csrc/jit/tensorexpr/loopnest.cpp @@ -39,7 +39,7 @@ LoopNest::LoopNest(const LoopNest& other) verify(root_stmt_); } -LoopNest::LoopNest(Stmt* stmt, std::unordered_set output_bufs) +LoopNest::LoopNest(StmtPtr stmt, std::unordered_set output_bufs) : root_stmt_(stmt), output_bufs_(std::move(output_bufs)) { verify(root_stmt_); } @@ -58,11 +58,11 @@ LoopNest::LoopNest(const std::vector& output_tensors) { verify(root_stmt_); } -const std::unordered_set LoopNest::getIntermediateBufs() const { - std::unordered_set result; +const std::unordered_set LoopNest::getIntermediateBufs() const { + std::unordered_set result; auto input_bufs = getInputBufs(); auto bufs = NodeFinder::find(root_stmt_); - for (auto* buf : bufs) { + for (auto buf : bufs) { if (!output_bufs_.count(buf) && !input_bufs.count(buf)) { result.insert(buf); } @@ -70,8 +70,8 @@ const std::unordered_set LoopNest::getIntermediateBufs() const { return result; } -const std::unordered_set LoopNest::getInputBufs() const { - std::unordered_set result; +const std::unordered_set LoopNest::getInputBufs() const { + std::unordered_set result; auto buf_load_store_uses = findLoadOrStoreUses(root_stmt_); for (auto& kv : buf_load_store_uses) { bool has_store = false; @@ -90,39 +90,43 @@ const std::unordered_set LoopNest::getInputBufs() const { class IndexFlattener : public IRMutator { public: - Stmt* flatten(Stmt* s) { + StmtPtr flatten(StmtPtr s) { return s->accept_mutator(this); } - Expr* mutate(Load* v) override { + ExprPtr mutate(LoadPtr v) override { if (v->indices().size() == 1) { return v; } - return new Load( - v->dtype(), v->buf(), {flatten_index(v->buf()->dims(), v->indices())}); + return alloc( + v->dtype(), + v->buf(), + std::vector({flatten_index(v->buf()->dims(), v->indices())})); } - Stmt* mutate(Store* v) override { - Expr* value = v->value(); - Expr* new_value = value->accept_mutator(this); + StmtPtr mutate(StorePtr v) override { + ExprPtr value = v->value(); + ExprPtr new_value = value->accept_mutator(this); if (v->indices().size() == 1 && value == new_value) { - return (Stmt*)v; + return 
(StmtPtr)v; } - return new Store( - v->buf(), {flatten_index(v->buf()->dims(), v->indices())}, new_value); + return alloc( + v->buf(), + std::vector({flatten_index(v->buf()->dims(), v->indices())}), + new_value); } }; class Vectorizer : public IRMutator { public: - Stmt* vectorize(For* v) { - Stmt* body = v->body(); - Var* var = v->var(); - Expr* start = v->start(); - Expr* stop = v->stop(); - - IntImm* start_imm = dynamic_cast(start); - IntImm* stop_imm = dynamic_cast(stop); + StmtPtr vectorize(ForPtr v) { + StmtPtr body = v->body(); + VarPtr var = v->var(); + ExprPtr start = v->start(); + ExprPtr stop = v->stop(); + + IntImmPtr start_imm = to(start); + IntImmPtr stop_imm = to(stop); if (!start_imm) { throw std::runtime_error( "Can't vectorize due to non-constant loop start!"); @@ -137,7 +141,7 @@ class Vectorizer : public IRMutator { start_ = start_imm; lanes_ = stop_imm->value(); - Stmt* new_body = body->accept_mutator(this); + StmtPtr new_body = body->accept_mutator(this); if (new_body == body) { throw std::runtime_error("Vectorization failed!"); } @@ -145,87 +149,87 @@ class Vectorizer : public IRMutator { return new_body; } - Expr* mutate(Add* v) override { - std::vector inputs = {v->lhs(), v->rhs()}; + ExprPtr mutate(AddPtr v) override { + std::vector inputs = {v->lhs(), v->rhs()}; return try_vectorize(v, inputs, [&]() { return ExprHandle(inputs[0]) + ExprHandle(inputs[1]); }); } - Expr* mutate(Sub* v) override { - std::vector inputs = {v->lhs(), v->rhs()}; + ExprPtr mutate(SubPtr v) override { + std::vector inputs = {v->lhs(), v->rhs()}; return try_vectorize(v, inputs, [&]() { return ExprHandle(inputs[0]) - ExprHandle(inputs[1]); }); } - Expr* mutate(Mul* v) override { - std::vector inputs = {v->lhs(), v->rhs()}; + ExprPtr mutate(MulPtr v) override { + std::vector inputs = {v->lhs(), v->rhs()}; return try_vectorize(v, inputs, [&]() { return ExprHandle(inputs[0]) * ExprHandle(inputs[1]); }); } - Expr* mutate(Div* v) override { - std::vector inputs = {v->lhs(), v->rhs()}; + ExprPtr mutate(DivPtr v) override { + std::vector inputs = {v->lhs(), v->rhs()}; return try_vectorize(v, inputs, [&]() { return ExprHandle(inputs[0]) / ExprHandle(inputs[1]); }); } - Expr* mutate(And* v) override { - std::vector inputs = {v->lhs(), v->rhs()}; + ExprPtr mutate(AndPtr v) override { + std::vector inputs = {v->lhs(), v->rhs()}; return try_vectorize(v, inputs, [&]() { return ExprHandle(inputs[0]) & ExprHandle(inputs[1]); }); } - Expr* mutate(Or* v) override { - std::vector inputs = {v->lhs(), v->rhs()}; + ExprPtr mutate(OrPtr v) override { + std::vector inputs = {v->lhs(), v->rhs()}; return try_vectorize(v, inputs, [&]() { return ExprHandle(inputs[0]) | ExprHandle(inputs[1]); }); } - Expr* mutate(Xor* v) override { - std::vector inputs = {v->lhs(), v->rhs()}; + ExprPtr mutate(XorPtr v) override { + std::vector inputs = {v->lhs(), v->rhs()}; return try_vectorize(v, inputs, [&]() { return ExprHandle(inputs[0]) ^ ExprHandle(inputs[1]); }); } - Expr* mutate(Lshift* v) override { - std::vector inputs = {v->lhs(), v->rhs()}; + ExprPtr mutate(LshiftPtr v) override { + std::vector inputs = {v->lhs(), v->rhs()}; return try_vectorize(v, inputs, [&]() { return ExprHandle(inputs[0]) << ExprHandle(inputs[1]); }); } - Expr* mutate(Rshift* v) override { - std::vector inputs = {v->lhs(), v->rhs()}; + ExprPtr mutate(RshiftPtr v) override { + std::vector inputs = {v->lhs(), v->rhs()}; return try_vectorize(v, inputs, [&]() { return ExprHandle(inputs[0]) >> ExprHandle(inputs[1]); }); } - Expr* mutate(Max* v) override { - 
std::vector inputs = {v->lhs(), v->rhs()}; + ExprPtr mutate(MaxPtr v) override { + std::vector inputs = {v->lhs(), v->rhs()}; return try_vectorize(v, inputs, [&]() { return Max::make( ExprHandle(inputs[0]), ExprHandle(inputs[1]), v->propagate_nans()); }); } - Expr* mutate(Min* v) override { - std::vector inputs = {v->lhs(), v->rhs()}; + ExprPtr mutate(MinPtr v) override { + std::vector inputs = {v->lhs(), v->rhs()}; return try_vectorize(v, inputs, [&]() { return Min::make( ExprHandle(inputs[0]), ExprHandle(inputs[1]), v->propagate_nans()); }); } - Expr* mutate(CompareSelect* v) override { - std::vector inputs = { + ExprPtr mutate(CompareSelectPtr v) override { + std::vector inputs = { v->lhs(), v->rhs(), v->ret_val1(), v->ret_val2()}; return try_vectorize(v, inputs, [&]() { return CompareSelect::make( @@ -238,23 +242,23 @@ class Vectorizer : public IRMutator { }); } - Expr* mutate(BitCast* v) override { - std::vector inputs = {v->src_value()}; + ExprPtr mutate(BitCastPtr v) override { + std::vector inputs = {v->src_value()}; return try_vectorize(v, inputs, [&]() { return BitCast::make( Dtype(v->dtype().scalar_type(), lanes_), ExprHandle(inputs[0])); }); } - Expr* mutate(Cast* v) override { - std::vector inputs = {v->src_value()}; + ExprPtr mutate(CastPtr v) override { + std::vector inputs = {v->src_value()}; return try_vectorize(v, inputs, [&]() { return Cast::make( Dtype(v->dtype().scalar_type(), lanes_), ExprHandle(inputs[0])); }); } - Expr* mutate(Var* v) override { + ExprPtr mutate(VarPtr v) override { if (v == var_) { return Ramp::make(ExprHandle(start_), 1, lanes_).node(); } @@ -262,12 +266,12 @@ class Vectorizer : public IRMutator { return v; } - Expr* mutate(Ramp* v) override { - Expr* base = v->base(); - Expr* stride = v->stride(); + ExprPtr mutate(RampPtr v) override { + ExprPtr base = v->base(); + ExprPtr stride = v->stride(); - Expr* base_new = base->accept_mutator(this); - Expr* stride_new = stride->accept_mutator(this); + ExprPtr base_new = base->accept_mutator(this); + ExprPtr stride_new = stride->accept_mutator(this); if (base_new == base && stride_new == stride) { return v; @@ -276,30 +280,30 @@ class Vectorizer : public IRMutator { throw std::runtime_error("Can't vectorize a Ramp!"); } - Expr* mutate(Load* v) override { + ExprPtr mutate(LoadPtr v) override { Dtype dtype(v->dtype().scalar_type(), lanes_); - Buf* buf = v->buf(); - std::vector inputs = {v->flat_index()}; + BufPtr buf = v->buf(); + std::vector inputs = {v->flat_index()}; return try_vectorize(v, inputs, [&]() { return Load::make(dtype, BufHandle(buf), {ExprHandle(inputs[0])}); }); } - Expr* mutate(ReduceOp* v) override { + ExprPtr mutate(ReduceOpPtr v) override { Dtype dtype(v->dtype().scalar_type(), lanes_); - std::vector inputs = {v->body()}; + std::vector inputs = {v->body()}; - auto* out = try_vectorize(v, inputs, [&]() { + auto out = try_vectorize(v, inputs, [&]() { return ExprHandle( - new ReduceOp(inputs[0], v->reduce_args(), v->reducer())); + alloc(inputs[0], v->reduce_args(), v->reducer())); }); return out; } - Expr* mutate(Broadcast* v) override { - Expr* val = v->value(); - Expr* new_val = val->accept_mutator(this); + ExprPtr mutate(BroadcastPtr v) override { + ExprPtr val = v->value(); + ExprPtr new_val = val->accept_mutator(this); if (new_val == val) { return v; } @@ -307,69 +311,69 @@ class Vectorizer : public IRMutator { throw std::runtime_error("Can't vectorize a Broadcast!"); } - Expr* mutate(IfThenElse* v) override { - Expr* condition = v->condition(); - Expr* new_condition = 
condition->accept_mutator(this); + ExprPtr mutate(IfThenElsePtr v) override { + ExprPtr condition = v->condition(); + ExprPtr new_condition = condition->accept_mutator(this); if (new_condition != condition) { throw std::runtime_error("Can't vectorize an IfThenElse condition!"); } - std::vector inputs = {v->true_value(), v->false_value()}; + std::vector inputs = {v->true_value(), v->false_value()}; return try_vectorize(v, inputs, [&]() { return IfThenElse::make( ExprHandle(condition), ExprHandle(inputs[0]), ExprHandle(inputs[1])); }); } - Expr* mutate(Intrinsics* v) override { - std::vector inputs = v->params(); + ExprPtr mutate(IntrinsicsPtr v) override { + std::vector inputs = v->params(); return try_vectorize(v, inputs, [&]() { - return ExprHandle(new Intrinsics(v->op_type(), inputs)); + return ExprHandle(alloc(v->op_type(), inputs)); }); } - Stmt* mutate(Store* v) override { - Buf* buf = v->buf(); - std::vector inputs = {v->flat_index(), v->value()}; + StmtPtr mutate(StorePtr v) override { + BufPtr buf = v->buf(); + std::vector inputs = {v->flat_index(), v->value()}; return try_vectorize(v, inputs, [&]() { return Store::make( BufHandle(buf), {ExprHandle(inputs[0])}, ExprHandle(inputs[1])); }); } - Stmt* mutate(For* v) override { - Var* var = v->var(); - Expr* start = v->start(); - Expr* stop = v->stop(); + StmtPtr mutate(ForPtr v) override { + VarPtr var = v->var(); + ExprPtr start = v->start(); + ExprPtr stop = v->stop(); LoopOptions loop_options = v->loop_options(); - Expr* new_start = start->accept_mutator(this); - Expr* new_stop = stop->accept_mutator(this); + ExprPtr new_start = start->accept_mutator(this); + ExprPtr new_stop = stop->accept_mutator(this); if (new_start != start || new_stop != stop) { throw std::runtime_error( "Can't vectorize nested For with dependent loop bounds!"); } - Stmt* body = v->body(); - Stmt* new_body = body->accept_mutator(this); + StmtPtr body = v->body(); + StmtPtr new_body = body->accept_mutator(this); if (new_body == body) { - return (For*)v; + return (ForPtr)v; } - return new For(var, new_start, new_stop, new_body, loop_options); + return alloc(var, new_start, new_stop, new_body, loop_options); } - Stmt* mutate(Block* v) override { + StmtPtr mutate(BlockPtr v) override { // IRMutator does in-place mutations. But the logic in vectorization checks // for success by looking for a new stmt. So, we override the in-place // mutations and create a clone here if any of its statements change. // TODO: Can we change the logic of vectorizer so that we don't need this? 
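A note on the convention the comment above relies on: throughout the Vectorizer, "did anything vectorize" is answered by pointer identity, so an override that mutated in place would make success indistinguishable from failure. A minimal self-contained model of that contract, using a hypothetical mutate helper rather than the real IRMutator:

#include <cassert>
#include <memory>

struct Stmt { virtual ~Stmt() = default; };
using StmtPtr = std::shared_ptr<Stmt>;

// Hypothetical mutator: returns the input pointer unchanged when it has
// nothing to do, and a fresh node when it rewrites something.
StmtPtr mutate(const StmtPtr& s, bool can_vectorize) {
  if (!can_vectorize) {
    return s;                        // same pointer: caller sees "no progress"
  }
  return std::make_shared<Stmt>();   // new node: caller sees "vectorized"
}

int main() {
  StmtPtr body = std::make_shared<Stmt>();
  StmtPtr new_body = mutate(body, /*can_vectorize=*/false);
  assert(new_body == body);   // unchanged pointer: vectorization failed
  new_body = mutate(body, /*can_vectorize=*/true);
  assert(new_body != body);   // fresh pointer: vectorization succeeded
}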
bool any_change = false; - std::vector stmts; - for (Stmt* stmt : *v) { - Stmt* stmt_new = stmt->accept_mutator(this); + std::vector stmts; + for (StmtPtr stmt : *v) { + StmtPtr stmt_new = stmt->accept_mutator(this); if (stmt != stmt_new) { any_change = true; } else { @@ -380,13 +384,13 @@ class Vectorizer : public IRMutator { } } if (any_change) { - return new Block(stmts); + return alloc(stmts); } return v; } template - Expr* try_vectorize(Expr* e, std::vector& inputs, T&& vec_ctor) { + ExprPtr try_vectorize(ExprPtr e, std::vector& inputs, T&& vec_ctor) { bool vectorize = vectorize_inputs(inputs); if (vectorize) { return vec_ctor().node(); @@ -396,22 +400,22 @@ class Vectorizer : public IRMutator { } template - Stmt* try_vectorize(Stmt* s, std::vector& inputs, T&& vec_ctor) { + StmtPtr try_vectorize(StmtPtr s, std::vector& inputs, T&& vec_ctor) { bool vectorize = vectorize_inputs(inputs); if (vectorize) { return vec_ctor(); } - return (Stmt*)s; + return (StmtPtr)s; } - bool vectorize_inputs(std::vector& inputs) { + bool vectorize_inputs(std::vector& inputs) { bool any_vectorized = false; - std::vector new_inputs; + std::vector new_inputs; // Attempt to vectorize each input. - for (Expr*& in : inputs) { - Expr* new_in = in->accept_mutator(this); + for (ExprPtr& in : inputs) { + ExprPtr new_in = in->accept_mutator(this); new_inputs.push_back(new_in); if (new_in != in) { any_vectorized = true; @@ -436,20 +440,20 @@ class Vectorizer : public IRMutator { return true; } - Var* var_ = nullptr; + VarPtr var_ = nullptr; int lanes_ = 0; - Expr* start_ = nullptr; + ExprPtr start_ = nullptr; }; -bool LoopNest::vectorize(For* f) { - Block* b = dynamic_cast(f->get_parent()); +bool LoopNest::vectorize(ForPtr f) { + BlockPtr b = to(f->get_parent()); if (!b) { return false; } // Can't vectorize reduction axes. auto reductions = NodeFinder::find(f); - for (auto* r : reductions) { + for (auto r : reductions) { if (std::find(r->reduce_args().begin(), r->reduce_args().end(), f->var()) != r->reduce_args().end()) { return false; @@ -457,12 +461,12 @@ bool LoopNest::vectorize(For* f) { } Vectorizer v; - Stmt* new_f = nullptr; + StmtPtr new_f = nullptr; try { new_f = Stmt::clone(f); - normalize(dynamic_cast(new_f)); + normalize(to(new_f)); new_f = FlattenIndexes(new_f); - new_f = v.vectorize(dynamic_cast(new_f)); + new_f = v.vectorize(to(new_f)); } catch (std::runtime_error& e) { // We clone f before vectorizing. So, any partial vectorization will // have modified the clone. In case of an exception, we can continue @@ -486,17 +490,17 @@ void LoopNest::initialize( output_bufs_.insert(t->buf()); } - std::vector loops; + std::vector loops; for (Tensor* t : tensors_to_compute) { - Stmt* loop = t->stmt(); + StmtPtr loop = t->stmt(); if (loop->get_parent()) { std::cerr << "Error: creating a loopnest from already used Tensors\n"; loops = {}; break; } // Flatten initializers. 
- if (Block* block = dynamic_cast(loop)) { - for (auto* s : block->stmts()) { + if (BlockPtr block = to(loop)) { + for (auto s : block->stmts()) { block->remove_stmt(s); loops.push_back(s); } @@ -505,24 +509,24 @@ void LoopNest::initialize( } } - root_stmt_ = new Block(loops); + root_stmt_ = alloc(loops); } class FunctionInliner : public IRMutator { public: - FunctionInliner(Store* producer, std::unordered_set outputs) + FunctionInliner(StorePtr producer, std::unordered_set outputs) : buf_(producer->buf()), producer_(producer), outputs_(std::move(outputs)) { - for (auto* i : producer->indices()) { - if (auto index_var = dynamic_cast(i)) { + for (auto i : producer->indices()) { + if (auto index_var = to(i)) { index_vars_.insert(index_var); producer_index_vars_.push_back(index_var); - } else if (dynamic_cast(i) != nullptr) { + } else if (to(i) != nullptr) { // If the index can be a constant, then that dimension must have size 1 // (since we don't support in-place writes). Resolves issue 52581. TORCH_INTERNAL_ASSERT( - dynamic_cast(i)->value() == 0, + to(i)->value() == 0, "Constant index impression should always be zero"); producer_index_vars_.push_back(nullptr); } else { @@ -532,16 +536,16 @@ class FunctionInliner : public IRMutator { } private: - Expr* mutate_loads(Buf* buf, std::vector dims) { - std::vector index_vars; + ExprPtr mutate_loads(BufPtr buf, std::vector dims) { + std::vector index_vars; TORCH_INTERNAL_ASSERT(buf->ndim() == producer_index_vars_.size()); for (const auto i : c10::irange(buf->ndim())) { - Var* func_callee_arg = producer_index_vars_.at(i); - Expr* func_caller_param = dims.at(i); + VarPtr func_callee_arg = producer_index_vars_.at(i); + ExprPtr func_caller_param = dims.at(i); if (func_callee_arg == nullptr) { TORCH_INTERNAL_ASSERT( - dynamic_cast(func_caller_param) != nullptr && - dynamic_cast(func_caller_param)->value() == 0, + to(func_caller_param) != nullptr && + to(func_caller_param)->value() == 0, "We are implicitly assuming that if you have an index of 0, that must also be inlined into an index of 0"); continue; } @@ -558,15 +562,15 @@ class FunctionInliner : public IRMutator { } // Call the actual replacement. - Expr* body = producer_->value(); - Expr* result = Expr::clone(body)->accept_mutator(this); + ExprPtr body = producer_->value(); + ExprPtr result = Expr::clone(body)->accept_mutator(this); // Remove the mappings we created for this function parameters. - for (auto* v : index_vars) { + for (auto v : index_vars) { for (auto& pair : random_bindings_) { if (pair.second.erase(v)) { - Expr* inlined = inline_mapping_[v]; - for (auto* nv : VarFinder::find(inlined)) { + ExprPtr inlined = inline_mapping_[v]; + for (auto nv : VarFinder::find(inlined)) { pair.second.insert(nv); } } @@ -576,8 +580,8 @@ class FunctionInliner : public IRMutator { return result; } - Expr* mutate(Load* v) override { - Buf* buf = v->buf(); + ExprPtr mutate(LoadPtr v) override { + BufPtr buf = v->buf(); if (buf != buf_) { return IRMutator::mutate(v); } @@ -590,39 +594,39 @@ class FunctionInliner : public IRMutator { } // Replace the target variable with the caller expressions. - Expr* mutate(Var* v) override { + ExprPtr mutate(VarPtr v) override { auto iter = inline_mapping_.find(v); if (iter == inline_mapping_.end()) { return v; } else { - Expr* expr = iter->second; + ExprPtr expr = iter->second; // Continue to transform the value from the lookup table. return expr->accept_mutator(this); } } // Handle random intrinsics which should be cached. 
- Expr* mutate(Intrinsics* v) override { + ExprPtr mutate(IntrinsicsPtr v) override { if (!in_producer_ || v->op_type() != kRand) { return IRMutator::mutate(v); } - // Create a new Let Statment for the random variable, which we can refer to - // multiple times and resolve the same value (ie. store it in a scalar + // Create a new Let Statement for the random variable, which we can refer + // to multiple times and resolve the same value (ie. store it in a scalar // rather than the Tensor). const std::string& name = buf_->name_hint(); - Var* new_var = new Var(name, v->dtype()); - random_bindings_[new Let(new_var, v)] = index_vars_; + VarPtr new_var = alloc(name, v->dtype()); + random_bindings_[alloc(new_var, v)] = index_vars_; return new_var; } // Remove the buffer write from the inlined function. - Stmt* mutate(Store* v) override { + StmtPtr mutate(StorePtr v) override { // If the buf_ is in the outputs set, keep its statement intact. Otherwise, // remove it. if (v == producer_ && !outputs_.count(buf_)) { in_producer_ = true; - producer_ = dynamic_cast(IRMutator::mutate(v)); + producer_ = to(IRMutator::mutate(v)); TORCH_INTERNAL_ASSERT(producer_ != nullptr); in_producer_ = false; return nullptr; @@ -632,10 +636,10 @@ class FunctionInliner : public IRMutator { } // Any Random Instrinsics that were turned into vars must be inserted here. - Stmt* mutate(Block* v) override { - std::vector stmts; - for (Stmt* stmt : *v) { - Stmt* stmt_new = stmt->accept_mutator(this); + StmtPtr mutate(BlockPtr v) override { + std::vector stmts; + for (StmtPtr stmt : *v) { + StmtPtr stmt_new = stmt->accept_mutator(this); if (!stmt_new) { continue; } @@ -650,15 +654,15 @@ class FunctionInliner : public IRMutator { return Block::make(stmts); } - Stmt* mutate(For* v) override { - For* res = dynamic_cast(IRMutator::mutate(v)); + StmtPtr mutate(ForPtr v) override { + ForPtr res = to(IRMutator::mutate(v)); if (!res) { return nullptr; } // Find any random bindings that should be defined in this loops body. - std::vector bindings_this_loop; - Var* fv = v->var(); + std::vector bindings_this_loop; + VarPtr fv = v->var(); for (auto& pair : random_bindings_) { auto& index_var = pair.second; if (index_var.erase(fv)) { @@ -666,7 +670,7 @@ class FunctionInliner : public IRMutator { } } - for (auto* l : bindings_this_loop) { + for (auto l : bindings_this_loop) { res->body()->prepend_stmt(l); random_bindings_.erase(l); } @@ -674,43 +678,43 @@ class FunctionInliner : public IRMutator { } private: - Buf* buf_; - Store* producer_; + BufPtr buf_; + StorePtr producer_; // Index Vars present in the producer. - std::unordered_set index_vars_; - std::vector producer_index_vars_; + std::unordered_set index_vars_; + std::vector producer_index_vars_; - std::unordered_map inline_mapping_; + std::unordered_map inline_mapping_; // In the producer's scope - we need to bind any calls to rand(). 
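Why rand() needs this caching: if the producer body were substituted at every load site, each site would evaluate the intrinsic again and read a different random value. Binding it once to a Let-defined scalar and reusing that scalar preserves the original single evaluation. A plain C++ analogy of the effect (not the IR rewrite itself):

#include <cstdio>
#include <cstdlib>

int main() {
  // Naive inlining would re-evaluate the intrinsic at every use:
  //   int a = std::rand() % 100, b = std::rand() % 100;  // a and b may differ
  // What FunctionInliner emits instead: one Let-bound scalar, reused.
  int r = std::rand() % 100;  // the "Let new_var = rand()" binding
  int a = r;                  // every former load of the producer reads r
  int b = r;
  std::printf("%d %d\n", a, b);  // always equal
}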
bool in_producer_ = false; - std::unordered_map> random_bindings_; - std::unordered_set outputs_; + std::unordered_map> random_bindings_; + std::unordered_set outputs_; }; -bool LoopNest::computeInline(Stmt* s) { - auto* s_store = dynamic_cast(s); +bool LoopNest::computeInline(StmtPtr s) { + auto s_store = to(s); if (s_store == nullptr) { throw std::logic_error("Could not find buffer producer to inline"); } return computeInline(s_store->buf()); } -bool LoopNest::computeInline(Buf* b) { +bool LoopNest::computeInline(BufPtr b) { // If buf is used or defined in an ExternalCall, we cannot inline it auto buf_load_store_uses = findLoadOrStoreUses(root_stmt_); for (auto& use : buf_load_store_uses.at(b)) { - Stmt* s = use.s; - if (dynamic_cast(s)) { + StmtPtr s = use.s; + if (to(s)) { return false; } } // Find producers. - Store* relevant_store{nullptr}; + StorePtr relevant_store{nullptr}; auto stores = NodeFinder::find(root_stmt_); - for (auto* s : stores) { + for (auto s : stores) { if (s->buf() == b) { auto reductions = NodeFinder::find(s); if (!reductions.empty()) { @@ -738,7 +742,7 @@ bool LoopNest::computeInline(Buf* b) { // difficult synchronization logic across blocks. Inlining trivial reads does // not duplicate work void LoopNest::inlineIntermediateBufs(bool allow_duplicated_work) { - std::unordered_set bufs_to_inline; + std::unordered_set bufs_to_inline; auto intermediate_bufs = getIntermediateBufs(); if (allow_duplicated_work) { @@ -757,15 +761,15 @@ void LoopNest::inlineIntermediateBufs(bool allow_duplicated_work) { // tensors, always inline, bc we are not duplicating any work // and avoiding an intermediary buffer if (stores.size() == 1) { - if (auto store = dynamic_cast(stores[0].s)) { - auto input_as_load = dynamic_cast(store->value()); + if (auto store = to(stores[0].s)) { + auto input_as_load = to(store->value()); if (input_as_load && input_bufs.count(input_as_load->buf())) { bufs_to_inline.insert(buf); continue; } } else { // If S is not a store, it must be an ExternalCall. 
- TORCH_INTERNAL_ASSERT(dynamic_cast(stores[0].s)); + TORCH_INTERNAL_ASSERT(to(stores[0].s)); } } @@ -791,28 +795,29 @@ void LoopNest::inlineIntermediateBufs(bool allow_duplicated_work) { // TODO: Unify with DepTracker class LoadOrStoreUseFinder : public IRVisitor { public: - std::unordered_map> findUses(Stmt* s) { + std::unordered_map> findUses( + StmtPtr s) { uses_.clear(); s->accept(this); return uses_; } private: - void visit(Store* v) override { + void visit(StorePtr v) override { if (stores_[v->buf()].insert(last_stmt_).second) { - uses_[v->buf()].push_back({(Stmt*)v, true}); + uses_[v->buf()].push_back({(StmtPtr)v, true}); } - last_stmt_ = (Stmt*)v; + last_stmt_ = (StmtPtr)v; IRVisitor::visit(v); } - void visit(ExternalCall* v) override { + void visit(ExternalCallPtr v) override { if (stores_[v->buf()].insert(last_stmt_).second) { - uses_[v->buf()].push_back({(Stmt*)v, true}); + uses_[v->buf()].push_back({(StmtPtr)v, true}); } - last_stmt_ = (Stmt*)v; + last_stmt_ = (StmtPtr)v; - for (Buf* input_buf : v->buf_args()) { + for (BufPtr input_buf : v->buf_args()) { if (loads_[input_buf].insert(last_stmt_).second) { uses_[input_buf].push_back({last_stmt_, false}); } @@ -821,23 +826,23 @@ class LoadOrStoreUseFinder : public IRVisitor { IRVisitor::visit(v); } - void visit(Load* v) override { + void visit(LoadPtr v) override { if (loads_[v->buf()].insert(last_stmt_).second) { uses_[v->buf()].push_back({last_stmt_, false}); } IRVisitor::visit(v); } - Stmt* last_stmt_ = nullptr; - std::unordered_map> uses_; + StmtPtr last_stmt_ = nullptr; + std::unordered_map> uses_; // Sets of loads and stores in order to keep the results unique - std::unordered_map> loads_; - std::unordered_map> stores_; + std::unordered_map> loads_; + std::unordered_map> stores_; }; -std::unordered_map> findLoadOrStoreUses( - Stmt* s) { +std::unordered_map> findLoadOrStoreUses( + StmtPtr s) { LoadOrStoreUseFinder uf; return uf.findUses(s); } @@ -845,46 +850,46 @@ std::unordered_map> findLoadOrStoreUses( class ContainedStmtsFinder : public IRVisitor { public: // Simply list all Stores and Block that are children of the given stmt - const std::unordered_set& findContainedStmts(Stmt* s) { + const std::unordered_set& findContainedStmts(StmtPtr s) { contained_.clear(); s->accept(this); return contained_; } private: - void visit(Store* v) override { - contained_.insert((Stmt*)v); + void visit(StorePtr v) override { + contained_.insert((StmtPtr)v); IRVisitor::visit(v); } - void visit(ExternalCall* v) override { - contained_.insert((Stmt*)v); + void visit(ExternalCallPtr v) override { + contained_.insert((StmtPtr)v); IRVisitor::visit(v); } - void visit(Block* v) override { - contained_.insert((Stmt*)v); + void visit(BlockPtr v) override { + contained_.insert((StmtPtr)v); IRVisitor::visit(v); } - std::unordered_set contained_; + std::unordered_set contained_; }; -bool containsAll(const std::vector& uses, Block* b) { - std::unordered_set not_found; +bool containsAll(const std::vector& uses, BlockPtr b) { + std::unordered_set not_found; for (auto use : uses) { not_found.insert(use.s); } ContainedStmtsFinder csf; - const std::unordered_set& contained = csf.findContainedStmts(b); + const std::unordered_set& contained = csf.findContainedStmts(b); for (auto s : contained) { not_found.erase(s); } return not_found.empty(); } -Block* findParentBlock(Stmt* s) { +BlockPtr findParentBlock(StmtPtr s) { while (s) { - if (auto b = dynamic_cast(s)) { + if (auto b = to(s)) { return b; } s = s->get_parent(); @@ -892,33 +897,33 @@ Block* 
findParentBlock(Stmt* s) { return nullptr; } -Block* findLowestContainingBlock(const std::vector& uses) { +BlockPtr findLowestContainingBlock(const std::vector& uses) { // TODO: we're not using the most efficient algorithm here for simplicity. // Replace with something more performant in case it becomes a bottleneck. - Block* b = findParentBlock(uses[0].s); + BlockPtr b = findParentBlock(uses[0].s); while (b && !containsAll(uses, b)) { b = findParentBlock(b->get_parent()); } return b; } -Stmt* LoopNest::insertAllocFree(Stmt* stmt) { +StmtPtr LoopNest::insertAllocFree(StmtPtr stmt) { auto intermediate_bufs = getIntermediateBufs(); if (intermediate_bufs.size() == 0ULL) { return stmt; } - Block* b = dynamic_cast(stmt); + BlockPtr b = to(stmt); if (!b) { - b = new Block({stmt}); + b = alloc(std::vector({stmt})); } - std::unordered_map> uses = + std::unordered_map> uses = findLoadOrStoreUses(stmt); // Insert allocations and frees for temporary buffers at global scope. - for (Buf* buf : intermediate_bufs) { - b->prepend_stmt(new Allocate(buf)); - b->append_stmt(new Free(buf)); + for (BufPtr buf : intermediate_bufs) { + b->prepend_stmt(alloc(buf)); + b->append_stmt(alloc(buf)); } return b; @@ -926,15 +931,15 @@ Stmt* LoopNest::insertAllocFree(Stmt* stmt) { class StmtDeleter : public IRMutator { public: - StmtDeleter(const std::unordered_set& targets) : targets_(targets) {} + StmtDeleter(const std::unordered_set& targets) : targets_(targets) {} private: - Stmt* mutate(Block* v) override { - std::vector stmts; + StmtPtr mutate(BlockPtr v) override { + std::vector stmts; - for (auto* s : v->stmts()) { + for (auto s : v->stmts()) { if (targets_.count(s) == 0) { - Stmt* ns = s->accept_mutator(this); + StmtPtr ns = s->accept_mutator(this); if (ns) { stmts.push_back(Stmt::clone(ns)); } @@ -944,7 +949,7 @@ class StmtDeleter : public IRMutator { return Block::make(stmts); } - const std::unordered_set& targets_; + const std::unordered_set& targets_; }; void LoopNest::eliminateDeadStores() { @@ -952,9 +957,9 @@ void LoopNest::eliminateDeadStores() { MemDependencyChecker checker(getInputBufs(), getOutputBufs()); root_stmt_->accept(&checker); - std::unordered_set deadStores; + std::unordered_set deadStores; std::vector> outputAccesses; - for (auto* o : getOutputBufs()) { + for (auto o : getOutputBufs()) { outputAccesses.push_back(checker.output(o)); } @@ -997,10 +1002,10 @@ namespace { // the rest of the IR nodes (the ones not touched directly) to be cloned. class IfThenElseReplacer : public IRCloner { public: - IfThenElseReplacer(IfThenElse* to_replace, Expr* new_expr) + IfThenElseReplacer(IfThenElsePtr to_replace, ExprPtr new_expr) : to_replace_(to_replace), new_expr_(new_expr) {} - Expr* mutate(IfThenElse* i) override { + ExprPtr mutate(IfThenElsePtr i) override { if (i == to_replace_) { return new_expr_; } @@ -1008,8 +1013,8 @@ class IfThenElseReplacer : public IRCloner { } private: - IfThenElse* to_replace_; - Expr* new_expr_; + IfThenElsePtr to_replace_; + ExprPtr new_expr_; }; // Check if the given condition is optimizable. @@ -1021,12 +1026,12 @@ class IfThenElseReplacer : public IRCloner { // * sets `compared_value` to `expr`, and // * returns true. 
bool isConditionOptimizable( - Expr* condition, - Var** cond_var, - Expr** compared_value) { - auto cs = dynamic_cast(condition); + ExprPtr condition, + VarPtr* cond_var, + ExprPtr* compared_value) { + auto cs = to(condition); if (cs && cs->compare_select_op() == kLT) { - auto var = dynamic_cast(cs->lhs()); + auto var = to(cs->lhs()); if (var) { *cond_var = var; *compared_value = cs->rhs(); @@ -1054,13 +1059,13 @@ bool isConditionOptimizable( // * sub_exprs to the list of sub-expressions that are the result of this // if-then-else expression. bool isConditionalFromCat( - IfThenElse* ite, - Var** cond_var, - std::vector* comp_values, - std::vector* sub_exprs) { - Var* var = nullptr; + IfThenElsePtr ite, + VarPtr* cond_var, + std::vector* comp_values, + std::vector* sub_exprs) { + VarPtr var = nullptr; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - Expr* comp_value; + ExprPtr comp_value; if (isConditionOptimizable(ite->condition(), &var, &comp_value)) { if (*cond_var == nullptr) { *cond_var = var; @@ -1069,7 +1074,7 @@ bool isConditionalFromCat( // expressions. Can not optimize such cases. return false; } - auto true_ite = dynamic_cast(ite->true_value()); + auto true_ite = to(ite->true_value()); if (true_ite) { if (!isConditionalFromCat(true_ite, cond_var, comp_values, sub_exprs)) { return false; @@ -1077,7 +1082,7 @@ bool isConditionalFromCat( } else { sub_exprs->push_back(ite->true_value()); } - auto false_ite = dynamic_cast(ite->false_value()); + auto false_ite = to(ite->false_value()); if (false_ite) { return false; } @@ -1088,7 +1093,7 @@ bool isConditionalFromCat( return false; } -bool areConstantsAndSorted(const std::vector& comp_values) { +bool areConstantsAndSorted(const std::vector& comp_values) { std::vector comp_consts; comp_consts.reserve(comp_values.size()); for (auto c : comp_values) { @@ -1106,16 +1111,16 @@ bool LoopNest::optimizeConditionals() { // Consider every store in the root_stmt_ and try to optimize the // conditionals in that store. auto stores = NodeFinder::find(root_stmt_); - std::unordered_set split_fors; + std::unordered_set split_fors; for (auto store : stores) { - Var* cond_var = nullptr; + VarPtr cond_var = nullptr; // `comp_values` represent the list of compared values that will be // collected as we check for the expected pattern. Since that will // only include the RHS of the conditions in the if-then-else expressions // we need to start with `0` which is the initial bound, given that we // only handle normalized loops (check for this is done below). - std::vector comp_values = {new IntImm(0)}; - std::vector sub_exprs; + std::vector comp_values = {alloc(0)}; + std::vector sub_exprs; auto ifthenelse_exprs = NodeFinder::find(store); if (ifthenelse_exprs.empty()) { continue; @@ -1174,14 +1179,14 @@ bool LoopNest::optimizeConditionals() { // Remove all the if-then-else expressions from this store and create // one loop per sub-expression. 
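Concretely, the shape being matched is the nested select that a cat-like concatenation lowers to, as the helper name isConditionalFromCat suggests: every condition is `index_var < constant`, the constants are sorted, and each branch selects a different sub-expression. Under those guarantees the single guarded loop can be split into one unguarded loop per segment. The snippet below shows the effect on plain arrays; the buffer names and sizes are made up for illustration:

#include <cstdio>
#include <vector>

int main() {
  std::vector<float> A(5, 1.f), B(5, 2.f), C(10);

  // Before: one loop over the whole concat axis, guarded per element.
  for (int i = 0; i < 10; ++i) {
    C[i] = (i < 5) ? A[i] : B[i - 5];
  }

  // After optimizeConditionals: one loop per segment, bounds taken from the
  // sorted cut points {0, 5, 10}; no per-element branch remains.
  for (int i = 0; i < 5; ++i) {
    C[i] = A[i];
  }
  for (int i = 5; i < 10; ++i) {
    C[i] = B[i - 5];
  }
  std::printf("%f %f\n", C[0], C[9]);
}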
- std::vector split_loops; + std::vector split_loops; auto cond_to_replace = ifthenelse_exprs.front(); for (size_t i = 0; i < sub_exprs.size(); ++i) { IfThenElseReplacer ifthenelseReplacer(cond_to_replace, sub_exprs[i]); auto new_store = store->accept_mutator(&ifthenelseReplacer); auto new_for_body = for_to_split->body()->clone_and_replace(store, new_store); - auto new_for = new For( + auto new_for = alloc( for_to_split->var(), comp_values[i], comp_values[i + 1], @@ -1189,30 +1194,30 @@ bool LoopNest::optimizeConditionals() { LoopNest::normalize(new_for); split_loops.push_back(new_for); } - auto par = dynamic_cast(for_to_split->get_parent()); - par->replace_stmt(for_to_split, new Block(split_loops)); + auto par = to(for_to_split->get_parent()); + par->replace_stmt(for_to_split, alloc(split_loops)); } root_stmt_ = IRSimplifier::simplify(root_stmt_); return true; } void LoopNest::vectorizeInnerLoops() { - std::vector innerLoops; - std::vector worklist; + std::vector innerLoops; + std::vector worklist; // Find outer-most For loops - if (For* rootF = dynamic_cast(root_stmt_)) { + if (ForPtr rootF = to(root_stmt_)) { worklist.push_back(rootF); - } else if (Block* body = dynamic_cast(root_stmt_)) { - std::vector blocks = {body}; + } else if (BlockPtr body = to(root_stmt_)) { + std::vector blocks = {body}; while (blocks.size()) { - Block* b = blocks.back(); + BlockPtr b = blocks.back(); blocks.pop_back(); - for (Stmt* s : *b) { - if (For* f = dynamic_cast(s)) { + for (StmtPtr s : *b) { + if (ForPtr f = to(s)) { worklist.push_back(f); - } else if (Block* b2 = dynamic_cast(s)) { + } else if (BlockPtr b2 = to(s)) { blocks.push_back(b2); } } @@ -1222,13 +1227,13 @@ void LoopNest::vectorizeInnerLoops() { // Traverse the For loop nest find inner-most loops, which are // vectorization candidates. while (worklist.size()) { - For* f = worklist.back(); + ForPtr f = worklist.back(); worklist.pop_back(); bool containsSubLoops = false; - if (Block* body = dynamic_cast(f->body())) { - for (Stmt* s2 : *body) { - if (For* f2 = dynamic_cast(s2)) { + if (BlockPtr body = to(f->body())) { + for (StmtPtr s2 : *body) { + if (ForPtr f2 = to(s2)) { containsSubLoops = true; worklist.push_back(f2); } @@ -1241,11 +1246,11 @@ void LoopNest::vectorizeInnerLoops() { } // vectorize inner loops. 
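With kBodyVectorWidth = 8 and kTailVectorWidth = 4 (defined just below), an arbitrary trip count decomposes into at most three pieces: 8-wide vector iterations, 4-wide vector iterations over the first tail, and a scalar remainder. A worked example for a hypothetical trip count of 21:

#include <cstdio>

int main() {
  const int n = 21;                  // hypothetical trip count
  const int body_w = 8, tail_w = 4;  // kBodyVectorWidth, kTailVectorWidth

  int body_iters = n / body_w;       // 2 vector iterations of width 8
  int rem        = n % body_w;       // 5 iterations left for the tail
  int tail_iters = rem / tail_w;     // 1 vector iteration of width 4
  int scalar     = rem % tail_w;     // 1 scalar iteration remains

  std::printf("%d x %d-wide, %d x %d-wide, %d scalar\n",
              body_iters, body_w, tail_iters, tail_w, scalar);
}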
- for (For* loop : innerLoops) { + for (ForPtr loop : innerLoops) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* split1; + ForPtr split1; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* tail1; + ForPtr tail1; static const int kBodyVectorWidth = 8; splitWithTail(loop, kBodyVectorWidth, &split1, &tail1); @@ -1253,9 +1258,9 @@ void LoopNest::vectorizeInnerLoops() { if (tail1) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* split2; + ForPtr split2; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* tail2; + ForPtr tail2; static const int kTailVectorWidth = 4; splitWithTail(tail1, kTailVectorWidth, &split2, &tail2); vectorize(split2); @@ -1263,10 +1268,10 @@ void LoopNest::vectorizeInnerLoops() { } } -void LoopNest::sliceHead(For* f, int factor, For** head, For** tail) { - if (dynamic_cast(f->start()) && dynamic_cast(f->stop())) { - int start_val = dynamic_cast(f->start())->value(); - int stop_val = dynamic_cast(f->stop())->value(); +void LoopNest::sliceHead(ForPtr f, int factor, ForPtr* head, ForPtr* tail) { + if (to(f->start()) && to(f->stop())) { + int start_val = to(f->start())->value(); + int stop_val = to(f->stop())->value(); int size_val = stop_val - start_val; if (factor >= size_val) { *head = f; @@ -1279,15 +1284,15 @@ void LoopNest::sliceHead(For* f, int factor, For** head, For** tail) { throw malformed_input("sliceHead attempted on null loop", f); } - Block* p = dynamic_cast(f->get_parent()); + BlockPtr p = to(f->get_parent()); if (!p) { throw malformed_input("sliceHead attempted on loop with no parent", p); } - Expr* head_end = - new Min(new Add(f->start(), new IntImm(factor)), f->stop(), true); - *head = new For(f->var(), f->start(), head_end, Stmt::clone(f->body())); - *tail = new For( + ExprPtr head_end = alloc( + alloc(f->start(), alloc(factor)), f->stop(), true); + *head = alloc(f->var(), f->start(), head_end, Stmt::clone(f->body())); + *tail = alloc( f->var(), head_end, f->stop(), Stmt::clone(f->body()), f->loop_options()); p->replace_stmt(f, *head); @@ -1298,16 +1303,16 @@ void LoopNest::sliceHead(For* f, int factor, For** head, For** tail) { LoopNest::normalize(*tail); } } -void LoopNest::sliceHead(For* f, int factor) { +void LoopNest::sliceHead(ForPtr f, int factor) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For *head, *tail; + ForPtr head, tail; sliceHead(f, factor, &head, &tail); } -void LoopNest::sliceTail(For* f, int factor, For** head, For** tail) { - if (dynamic_cast(f->start()) && dynamic_cast(f->stop())) { - int start_val = dynamic_cast(f->start())->value(); - int stop_val = dynamic_cast(f->stop())->value(); +void LoopNest::sliceTail(ForPtr f, int factor, ForPtr* head, ForPtr* tail) { + if (to(f->start()) && to(f->stop())) { + int start_val = to(f->start())->value(); + int stop_val = to(f->stop())->value(); int size_val = stop_val - start_val; if (factor >= size_val) { *head = nullptr; @@ -1320,20 +1325,20 @@ void LoopNest::sliceTail(For* f, int factor, For** head, For** tail) { throw malformed_input("sliceTail attempted on null loop", f); } - Block* p = dynamic_cast(f->get_parent()); + BlockPtr p = to(f->get_parent()); if (!p) { throw malformed_input("sliceTail attempted on loop with no parent", p); } - Expr* tail_start = - new Max(f->start(), new Sub(f->stop(), new IntImm(factor)), true); - *head = new For( + ExprPtr tail_start = alloc( + f->start(), alloc(f->stop(), alloc(factor)), true); + *head = alloc( f->var(), f->start(), tail_start, Stmt::clone(f->body()), f->loop_options()); - *tail = new 
For(f->var(), tail_start, f->stop(), Stmt::clone(f->body())); + *tail = alloc(f->var(), tail_start, f->stop(), Stmt::clone(f->body())); p->replace_stmt(f, *head); p->insert_stmt_after(*tail, *head); @@ -1343,32 +1348,36 @@ void LoopNest::sliceTail(For* f, int factor, For** head, For** tail) { LoopNest::normalize(*head); } } -void LoopNest::sliceTail(For* f, int factor) { +void LoopNest::sliceTail(ForPtr f, int factor) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For *head, *tail; + ForPtr head, tail; sliceTail(f, factor, &head, &tail); } -void LoopNest::splitWithTail(For* f, int factor) { +void LoopNest::splitWithTail(ForPtr f, int factor) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For *inner, *tail; + ForPtr inner, tail; splitWithTail(f, factor, &inner, &tail); } -void LoopNest::splitWithTail(For* f, int factor, For** inner, For** tail) { +void LoopNest::splitWithTail( + ForPtr f, + int factor, + ForPtr* inner, + ForPtr* tail) { if (!f) { throw malformed_input("splitWithTail attempted on null loop", f); } - Block* p = dynamic_cast(f->get_parent()); + BlockPtr p = to(f->get_parent()); if (!p) { throw malformed_input("splitWithTail attempted on loop with no parent", p); } bool tail_is_needed = true; - if (dynamic_cast(f->start()) && dynamic_cast(f->stop())) { - int start_val = dynamic_cast(f->start())->value(); - int stop_val = dynamic_cast(f->stop())->value(); + if (to(f->start()) && to(f->stop())) { + int start_val = to(f->start())->value(); + int stop_val = to(f->stop())->value(); int size_val = stop_val - start_val; int tail_size = size_val % factor; if (tail_size == 0) { @@ -1376,60 +1385,63 @@ void LoopNest::splitWithTail(For* f, int factor, For** inner, For** tail) { } } - IntImm* factor_expr = new IntImm(factor); - Expr* size = new Sub(f->stop(), f->start()); - Expr* split_count = new Div(size, factor_expr); - Expr* tail_size = new Mod(size, factor_expr); + IntImmPtr factor_expr = alloc(factor); + ExprPtr size = alloc(f->stop(), f->start()); + ExprPtr split_count = alloc
(size, factor_expr); + ExprPtr tail_size = alloc(size, factor_expr); const std::string& loop_var_name = f->var()->name_hint(); Dtype loop_var_dtype = f->var()->dtype(); - Var* i_inner = new Var(loop_var_name + "_inner", loop_var_dtype); - Var* i_outer = new Var(loop_var_name + "_outer", loop_var_dtype); + VarPtr i_inner = alloc(loop_var_name + "_inner", loop_var_dtype); + VarPtr i_outer = alloc(loop_var_name + "_outer", loop_var_dtype); // x -> x.outer * inner.size + x.inner - Expr* combined_index1 = new Add(new Mul(i_outer, factor_expr), i_inner); + ExprPtr combined_index1 = + alloc(alloc(i_outer, factor_expr), i_inner); if (tail_is_needed) { - Var* i_tail = new Var(loop_var_name + "_tail", loop_var_dtype); + VarPtr i_tail = alloc(loop_var_name + "_tail", loop_var_dtype); // x -> x.tail + outer.size * inner.size - Expr* combined_index2 = new Add(i_tail, new Mul(split_count, factor_expr)); + ExprPtr combined_index2 = + alloc(i_tail, alloc(split_count, factor_expr)); - Stmt* body_tail = + StmtPtr body_tail = SubstituteInClone(f->body(), {{f->var(), combined_index2}}); - *tail = new For(i_tail, new IntImm(0), tail_size, body_tail); + *tail = alloc(i_tail, alloc(0), tail_size, body_tail); p->insert_stmt_after(*tail, f); } else { *tail = nullptr; } - Stmt* body_inner = Substitute(f->removeBody(), {{f->var(), combined_index1}}); + StmtPtr body_inner = + Substitute(f->removeBody(), {{f->var(), combined_index1}}); - *inner = new For(i_inner, new IntImm(0), factor_expr, body_inner); + *inner = alloc(i_inner, alloc(0), factor_expr, body_inner); // The input loop `f` will be the outer loop after split. f->set_var(i_outer); - f->set_start(new IntImm(0)); + f->set_start(alloc(0)); f->set_stop(split_count); f->set_body(*inner); } -void LoopNest::splitWithMask(For* f, int factor) { +void LoopNest::splitWithMask(ForPtr f, int factor) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* inner; + ForPtr inner; splitWithMask(f, factor, &inner); } -void LoopNest::splitWithMask(For* f, int factor, For** inner) { - Block* p = dynamic_cast(f->get_parent()); +void LoopNest::splitWithMask(ForPtr f, int factor, ForPtr* inner) { + BlockPtr p = to(f->get_parent()); if (!p) { std::cerr << "Parent is not a Block!\n"; return; } bool tail_is_needed = true; - Expr* start = IRSimplifier::simplify(f->start()); - Expr* stop = IRSimplifier::simplify(f->stop()); + ExprPtr start = IRSimplifier::simplify(f->start()); + ExprPtr stop = IRSimplifier::simplify(f->stop()); if (start->isConstant() && stop->isConstant()) { int start_val = immediateAs(start); int stop_val = immediateAs(stop); @@ -1440,69 +1452,70 @@ void LoopNest::splitWithMask(For* f, int factor, For** inner) { } } - IntImm* factor_expr = new IntImm(factor); - Expr* size = new Sub(f->stop(), f->start()); + IntImmPtr factor_expr = alloc(factor); + ExprPtr size = alloc(f->stop(), f->start()); // split_count = (size + factor - 1) / factor - Expr* split_count = - new Div(new Sub(new Add(size, factor_expr), new IntImm(1)), factor_expr); + ExprPtr split_count = alloc
( + alloc(alloc(size, factor_expr), alloc(1)), factor_expr); const std::string& loop_var_name = f->var()->name_hint(); Dtype loop_var_dtype = f->var()->dtype(); - Var* i_inner = new Var(loop_var_name + "_inner", loop_var_dtype); - Var* i_outer = new Var(loop_var_name + "_outer", loop_var_dtype); + VarPtr i_inner = alloc(loop_var_name + "_inner", loop_var_dtype); + VarPtr i_outer = alloc(loop_var_name + "_outer", loop_var_dtype); // x -> x.outer * inner.size + x.inner - Expr* combined_index = new Add(new Mul(i_outer, factor_expr), i_inner); + ExprPtr combined_index = + alloc(alloc(i_outer, factor_expr), i_inner); - Stmt* body_inner = f->removeBody(); + StmtPtr body_inner = f->removeBody(); // TODO: is it ok that we're doing it eagerly? In the other implementation we // are only materializing predicates at the last, lowering, step. if (tail_is_needed) { - IntImm* start = dynamic_cast(f->start()); + IntImmPtr start = to(f->start()); if (!start || start->value() != 0) { throw unimplemented_lowering(); } - Expr* predicate = + ExprPtr predicate = CompareSelect::make(ExprHandle(f->var()), ExprHandle(f->stop()), kLT) .node(); body_inner = Cond::make(ExprHandle(predicate), body_inner, nullptr); } body_inner = Substitute(body_inner, {{f->var(), combined_index}}); - *inner = new For(i_inner, new IntImm(0), factor_expr, body_inner); + *inner = alloc(i_inner, alloc(0), factor_expr, body_inner); // The input loop `f` will be the outer loop after split. f->set_var(i_outer); - f->set_start(new IntImm(0)); + f->set_start(alloc(0)); f->set_stop(split_count); f->set_body(*inner); } -std::vector LoopNest::distributeLoop( - For* loop, - const std::unordered_set& pivots) { +std::vector LoopNest::distributeLoop( + ForPtr loop, + const std::unordered_set& pivots) { TORCH_INTERNAL_ASSERT(loop); auto root = loop->get_parent(); if (root == nullptr) { throw malformed_input("Loop without parent: ", loop); } - auto root_block = dynamic_cast(root); + auto root_block = to(root); if (root_block == nullptr) { throw malformed_input( "Loop's parent must be a Block, instead found ", root); } // Extract bodies for all the loops after distribution. - std::vector new_loop_bodies; - auto new_loop_body = new Block({}); + std::vector new_loop_bodies; + auto new_loop_body = alloc(std::vector({})); while (!loop->body()->empty()) { auto s = loop->body()->front(); loop->body()->remove_stmt(s); new_loop_body->append_stmt(s); if (pivots.count(s)) { new_loop_bodies.push_back(new_loop_body); - new_loop_body = new Block({}); + new_loop_body = alloc(std::vector({})); } } if (!new_loop_body->empty()) { @@ -1511,7 +1524,7 @@ std::vector LoopNest::distributeLoop( // The first loop body has to be in the original loop. loop->body()->splice(loop->body()->begin(), new_loop_bodies.front()); - std::vector new_loops = {loop}; + std::vector new_loops = {loop}; // Create loops for all the remaining blocks. // Add all the new loops to the parent block. 
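For orientation, the effect distributeLoop is building toward: every pivot statement closes off one loop body, so a single loop containing several statements becomes a sequence of loops over the same iteration range. An effect-level sketch in plain C++ (not the IR manipulation above):

#include <cstdio>
#include <vector>

int main() {
  const int n = 4;
  std::vector<int> a(n), b(n);

  // Before distribution: one loop, two statements in its body.
  for (int i = 0; i < n; ++i) {
    a[i] = i;         // pivot statement: the split happens after this
    b[i] = a[i] * 2;
  }

  // After distributeLoop: the same statements, one loop each, same range.
  for (int i = 0; i < n; ++i) {
    a[i] = i;
  }
  for (int i = 0; i < n; ++i) {
    b[i] = a[i] * 2;
  }
  std::printf("%d %d\n", a[3], b[3]);
}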
@@ -1524,13 +1537,13 @@ std::vector LoopNest::distributeLoop( return new_loops; } -std::vector LoopNest::distributeLoop(For* loop) { - std::unordered_set stmtsInBlock( +std::vector LoopNest::distributeLoop(ForPtr loop) { + std::unordered_set stmtsInBlock( loop->body()->begin(), loop->body()->end()); return distributeLoop(loop, stmtsInBlock); } -std::vector LoopNest::distributeLoopAndParents(For* loop) { +std::vector LoopNest::distributeLoopAndParents(ForPtr loop) { auto parentLoop = getParentLoop(loop); auto result = distributeLoop(loop); if (parentLoop) { @@ -1539,13 +1552,14 @@ std::vector LoopNest::distributeLoopAndParents(For* loop) { return result; } -std::vector LoopNest::distributeLoopOverInnerLoops(For* loop) { +std::vector LoopNest::distributeLoopOverInnerLoops(ForPtr loop) { auto loops = NodeFinder::find(loop); - std::unordered_set loopsSet(loops.begin(), loops.end()); + std::unordered_set loopsSet(loops.begin(), loops.end()); return distributeLoop(loop, loopsSet); } -std::vector LoopNest::distributeLoopAndParentsOverInnerLoops(For* loop) { +std::vector LoopNest::distributeLoopAndParentsOverInnerLoops( + ForPtr loop) { auto parentLoop = getParentLoop(loop); auto result = distributeLoopOverInnerLoops(loop); if (parentLoop) { @@ -1554,13 +1568,15 @@ std::vector LoopNest::distributeLoopAndParentsOverInnerLoops(For* loop) { return result; } -bool areEqual(Expr* expr1, Expr* expr2) { - auto diff = IRSimplifier::simplify(new Sub(expr1, expr2)); +bool areEqual(ExprPtr expr1, ExprPtr expr2) { + auto diff = IRSimplifier::simplify(alloc(expr1, expr2)); return diff->isConstant() && (immediateAs(diff) == 0); }; -bool doesExprContainAnyVar(Expr* expr, const std::unordered_set& vars) { - for (auto* v : VarFinder::find(expr)) { +bool doesExprContainAnyVar( + ExprPtr expr, + const std::unordered_set& vars) { + for (auto v : VarFinder::find(expr)) { if (vars.count(v)) { return true; } @@ -1572,9 +1588,9 @@ bool doesExprContainAnyVar(Expr* expr, const std::unordered_set& vars) { // that are loop-independent w.r.t. the given list of outer loop // variables. bool areIndicesLoopIndependent( - const std::vector& expr_list1, - const std::vector& expr_list2, - const std::unordered_set& outer_loop_vars) { + const std::vector& expr_list1, + const std::vector& expr_list2, + const std::unordered_set& outer_loop_vars) { if (expr_list1.size() != expr_list2.size()) { return false; } @@ -1591,11 +1607,11 @@ bool areIndicesLoopIndependent( return true; } -bool LoopNest::hasLoopCarriedDependence(For* loop) { +bool LoopNest::hasLoopCarriedDependence(ForPtr loop) { analysis::MemDependencyChecker analyzer; loop->accept(&analyzer); - std::unordered_set outer_loop_vars = {loop->var()}; + std::unordered_set outer_loop_vars = {loop->var()}; auto outer_loops = LoopNest::getEnclosingLoopNest(loop); for (auto l : outer_loops) { outer_loop_vars.insert(l->var()); @@ -1687,7 +1703,9 @@ bool LoopNest::hasLoopCarriedDependence(For* loop) { return false; } -bool LoopNest::unsafeFuseLoops(const std::vector& loops, For** fused) { +bool LoopNest::unsafeFuseLoops( + const std::vector& loops, + ForPtr* fused) { if (loops.empty()) { return false; } @@ -1707,7 +1725,7 @@ bool LoopNest::unsafeFuseLoops(const std::vector& loops, For** fused) { return false; } } - auto root_block = dynamic_cast(root); + auto root_block = to(root); if (root_block == nullptr) { return false; } @@ -1735,7 +1753,7 @@ bool LoopNest::unsafeFuseLoops(const std::vector& loops, For** fused) { // onwards and moving them into the first loop's body. 
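The rewrite described in the comment above, splicing the later loop bodies into the first loop after substituting their loop variables, can be illustrated on plain loops. This is a minimal C++ sketch, not the NNC API; all names are illustrative.

#include <cassert>

int main() {
  const int N = 8;
  int A[N], B[N], A2[N], B2[N];

  // Two adjacent loops over the same range (the input to unsafeFuseLoops).
  for (int i = 0; i < N; ++i) A[i] = 2 * i;
  for (int j = 0; j < N; ++j) B[j] = A[j] + 1;

  // Fused form: the second body is spliced into the first after substituting
  // j -> i, so the surviving loop is the first one. distributeLoop performs
  // the inverse rewrite.
  for (int i = 0; i < N; ++i) {
    A2[i] = 2 * i;
    B2[i] = A2[i] + 1;
  }

  for (int i = 0; i < N; ++i) assert(A[i] == A2[i] && B[i] == B2[i]);
}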
// This way the final fused loop will be the same as the first loop. for (size_t i = 1; i < loops.size(); ++i) { - auto body = dynamic_cast(SubstituteInClone( + auto body = to(SubstituteInClone( loops[i]->body(), {{loops[i]->var(), first_loop->var()}})); first_loop->body()->splice(first_loop->body()->end(), body); root_block->remove_stmt(loops[i]); @@ -1745,7 +1763,7 @@ bool LoopNest::unsafeFuseLoops(const std::vector& loops, For** fused) { return true; } -bool LoopNest::fuseLoops(const std::vector& loops, For** fused) { +bool LoopNest::fuseLoops(const std::vector& loops, ForPtr* fused) { if (loops.empty()) { return false; } @@ -1774,16 +1792,16 @@ bool LoopNest::fuseLoops(const std::vector& loops, For** fused) { // This check can be done only after the loops are fused into one. But if the // check is violated, we need to return the given loops in the original form. // So, we create a clone of all the loops, fuse them and check for this. - std::vector loops_copy; + std::vector loops_copy; loops_copy.reserve(loops.size()); - Block* parent = new Block({}); + BlockPtr parent = alloc(std::vector({})); for (auto& l : loops) { auto l_copy = Stmt::clone(l); - loops_copy.push_back(dynamic_cast(l_copy)); + loops_copy.push_back(to(l_copy)); parent->append_stmt(l_copy); } // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* fused_copy; + ForPtr fused_copy; bool ret = unsafeFuseLoops(loops_copy, &fused_copy); if (!ret || hasLoopCarriedDependence(fused_copy)) { return false; @@ -1793,8 +1811,8 @@ bool LoopNest::fuseLoops(const std::vector& loops, For** fused) { return unsafeFuseLoops(loops, fused); } -For* findOuterFor(For* a, For* b) { - Stmt* s = b; // guess b is the latter. +ForPtr findOuterFor(ForPtr a, ForPtr b) { + StmtPtr s = b; // guess b is the latter. while (s != nullptr) { if (s == a) { // yes, b is after a. @@ -1817,24 +1835,24 @@ For* findOuterFor(For* a, For* b) { return nullptr; } -void LoopNest::reorderAxis(For* a, For* b) { +void LoopNest::reorderAxis(ForPtr a, ForPtr b) { if (a == b) { // nothing to do. return; } // find inner and outer. - For* outer = findOuterFor(a, b); + ForPtr outer = findOuterFor(a, b); if (outer == nullptr) { throw std::runtime_error("Reordered a loop not in LoopNest"); } - For* inner = a == outer ? b : a; - std::deque internal_axes; + ForPtr inner = a == outer ? b : a; + std::deque internal_axes; // Find relevant axes, store reversed. - Stmt* s = inner; + StmtPtr s = inner; while (s != outer) { - if (For* f = dynamic_cast(s)) { + if (ForPtr f = to(s)) { internal_axes.push_back(f); } @@ -1844,26 +1862,27 @@ void LoopNest::reorderAxis(For* a, For* b) { internal_axes.push_back(outer); - Block* root = dynamic_cast(outer->get_parent()); + BlockPtr root = to(outer->get_parent()); CHECK(root); // Do a shallow copy of the inner blocks. 
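The clone-and-check step in fuseLoops above guards against fusions that change program behavior. A minimal standalone C++ sketch of the kind of loop-carried dependence it must reject (names illustrative, not the NNC API):

#include <cassert>

int main() {
  const int N = 8;
  int A[N + 1], B[N], Af[N + 1], Bf[N];
  for (int k = 0; k <= N; ++k) { A[k] = -1; Af[k] = -1; }

  // Unfused: the whole first loop finishes before the second loop reads A.
  for (int i = 0; i < N; ++i) A[i] = i;
  for (int i = 0; i < N; ++i) B[i] = A[i + 1];

  // Naively fused: iteration i reads Af[i + 1] before iteration i + 1 has
  // written it, so the result changes -- a loop-carried dependence.
  // unsafeFuseLoops would still perform this rewrite; fuseLoops must not.
  for (int i = 0; i < N; ++i) {
    Af[i] = i;
    Bf[i] = Af[i + 1];
  }

  assert(B[0] == 1);    // unfused result
  assert(Bf[0] == -1);  // fused result differs
}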
- Block* body = new Block({}); + BlockPtr body = alloc(std::vector({})); body->splice(body->end(), inner->body()); - For* before{outer}; - For* after{nullptr}; - For* last = internal_axes.front(); - Stmt* newInner = body; + ForPtr before{outer}; + ForPtr after{nullptr}; + ForPtr last = internal_axes.front(); + StmtPtr newInner = body; s = inner; while (s != outer) { - if (auto cond = dynamic_cast(s->get_parent())) { + if (auto cond = to(s->get_parent())) { if (s == cond->true_stmt()) { newInner = cond->cloneWithNewBody(newInner); } else { // s is the false branch of Cond - newInner = cond->cloneWithNewBodies(new Block({}), newInner); + newInner = cond->cloneWithNewBodies( + alloc(std::vector({})), newInner); } } s = s->get_parent(); @@ -1883,7 +1902,7 @@ void LoopNest::reorderAxis(For* a, For* b) { // When reordering loop i and j we need to ensure that Statement A and C are // still both executed with the loop extents of i, and that the three // statements are not reordered (as much as possible). - for (auto* loop : internal_axes) { + for (auto loop : internal_axes) { // If the inner loop had a component after the loop we must wrap it in a For // loop matching this level of the tree. if (after != nullptr) { @@ -1894,7 +1913,7 @@ void LoopNest::reorderAxis(For* a, For* b) { bool hadBeforeStmts = false; for (auto I = loop->body()->begin(), E = loop->body()->end(); I != E;) { // Be careful not to invalidate the iterator. - Stmt* s = *(I++); + StmtPtr s = *(I++); if (s == last) { // This is the midpoint. loop->body()->remove_stmt(s); @@ -1925,7 +1944,7 @@ void LoopNest::reorderAxis(For* a, For* b) { std::swap(internal_axes.front(), internal_axes.back()); // Create the reordered internals: - for (auto* loop : internal_axes) { + for (auto loop : internal_axes) { newInner = loop->cloneWithNewBody(newInner); } @@ -1956,8 +1975,8 @@ bool isValidPermutation(std::vector permutation) { return isTrivialPermutation(permutation); } -std::vector LoopNest::reorder( - const std::vector& loops, +std::vector LoopNest::reorder( + const std::vector& loops, const std::vector& permutation) { if (loops.size() != permutation.size()) { throw malformed_input("invalid permutation size"); @@ -1975,13 +1994,13 @@ std::vector LoopNest::reorder( throw malformed_input("reorder is only allowed on perfectly nested loops"); } - auto parent = dynamic_cast(loops.front()->get_parent()); + auto parent = to(loops.front()->get_parent()); if (parent == nullptr) { throw malformed_input("parent of the loops must be a Block"); } // Reorder the loops according to the permutation. - std::vector result(loops.size()); + std::vector result(loops.size()); for (size_t i = 0; i < loops.size(); ++i) { result[i] = loops[permutation[i]]; } @@ -1991,10 +2010,10 @@ std::vector LoopNest::reorder( // We use an empty block statement to replace the outermost loop // so that we know the position where the outermost reordered loop // is to be inserted. 
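The bookkeeping in reorderAxis above (the `before`/`after` loops around the reordered core) preserves statements that sit in the outer loop but outside the inner one. A minimal C++ sketch of the intended effect, independent of the NNC types:

#include <cassert>
#include <vector>

int main() {
  const int N = 4, M = 3;
  std::vector<int> a(N), b(N * M), a2(N), b2(N * M);

  // Original nest: statement A sits in the i loop before the j loop.
  for (int i = 0; i < N; ++i) {
    a[i] = i;                   // A: runs N times
    for (int j = 0; j < M; ++j)
      b[i * M + j] = i + j;     // B: runs N*M times
  }

  // After reordering axes i and j, B moves under j-then-i, while A is kept
  // in its own i loop so it still runs exactly N times.
  for (int i = 0; i < N; ++i)
    a2[i] = i;                  // A, hoisted into its own i loop
  for (int j = 0; j < M; ++j)
    for (int i = 0; i < N; ++i)
      b2[i * M + j] = i + j;    // B under the reordered axes

  assert(a == a2 && b == b2);
}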
- auto empty_block = new Block({}); + auto empty_block = alloc(std::vector({})); parent->replace_stmt(loops.front(), empty_block); for (size_t i = 1; i < loops.size(); ++i) { - auto block = dynamic_cast(loops[i]->get_parent()); + auto block = to(loops[i]->get_parent()); TORCH_INTERNAL_ASSERT(block); block->remove_stmt(loops[i]); } @@ -2008,7 +2027,7 @@ std::vector LoopNest::reorder( return result; } -For* LoopNest::getLoopAt(For* root, const std::vector& indices) const { +ForPtr LoopNest::getLoopAt(ForPtr root, const std::vector& indices) const { if (indices.empty()) { return root; } @@ -2016,14 +2035,14 @@ For* LoopNest::getLoopAt(For* root, const std::vector& indices) const { throw malformed_input("root loop is null"); } - For* curr = root; + ForPtr curr = root; for (auto i : indices) { if (i < 0 || curr->body()->nstmts() <= i) { return nullptr; } - std::list::iterator stmtp = curr->body()->begin(); + std::list::iterator stmtp = curr->body()->begin(); std::advance(stmtp, i); - curr = dynamic_cast(*stmtp); + curr = to(*stmtp); if (curr == nullptr) { return nullptr; } @@ -2032,8 +2051,8 @@ For* LoopNest::getLoopAt(For* root, const std::vector& indices) const { return curr; } -For* LoopNest::tile(For* x, For* y, int x_factor, int y_factor) { - auto parent = dynamic_cast(x->get_parent()); +ForPtr LoopNest::tile(ForPtr x, ForPtr y, int x_factor, int y_factor) { + auto parent = to(x->get_parent()); if (parent == nullptr) { throw malformed_input("parent of the loops must be a Block"); } @@ -2043,10 +2062,10 @@ For* LoopNest::tile(For* x, For* y, int x_factor, int y_factor) { // Split x, y axes by x_factor and y_factor // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For *yi, *ytail; + ForPtr yi, ytail; splitWithTail(y, y_factor, &yi, &ytail); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For *xi, *xtail; + ForPtr xi, xtail; splitWithTail(x, x_factor, &xi, &xtail); // Distribute xi over yo and ytail so we can manipulate the loop order of {xo, @@ -2055,14 +2074,14 @@ For* LoopNest::tile(For* x, For* y, int x_factor, int y_factor) { // For {xi, yo, yi}, reorder the axes to be yo, xi, yi xi = loops.front(); - For* yo = dynamic_cast(xi->body()->stmts().front()); + ForPtr yo = to(xi->body()->stmts().front()); CHECK(yo); reorder({xi, yo}, {1, 0}); // For {xi, ytail}, reorder the axes to be ytail, xi if (loops.size() == 2) { xi = loops.back(); - ytail = dynamic_cast(xi->body()->stmts().front()); + ytail = to(xi->body()->stmts().front()); CHECK(ytail); reorder({xi, ytail}, {1, 0}); } @@ -2070,7 +2089,7 @@ For* LoopNest::tile(For* x, For* y, int x_factor, int y_factor) { return xtail; } -bool LoopNest::areLoopsPerfectlyNested(const std::vector& loops) { +bool LoopNest::areLoopsPerfectlyNested(const std::vector& loops) { if (loops.size() < 2) { return true; } @@ -2083,8 +2102,8 @@ bool LoopNest::areLoopsPerfectlyNested(const std::vector& loops) { return true; } -void LoopNest::unroll(For* f, Stmt** unrolled) { - Block* p = dynamic_cast(f->get_parent()); +void LoopNest::unroll(ForPtr f, StmtPtr* unrolled) { + BlockPtr p = to(f->get_parent()); if (!f) { throw malformed_input("unroll attempted on null loop"); } else if (!p) { @@ -2100,7 +2119,7 @@ void LoopNest::unroll(For* f, Stmt** unrolled) { throw std::runtime_error("Can't unroll due to non-constant loop stop!"); } - std::vector unrolled_stmts; + std::vector unrolled_stmts; int start_val = immediateAs(start_expr); int stop_val = immediateAs(stop_expr); for (int current = start_val; current < stop_val; ++current) { @@ -2109,26 +2128,26 @@ 
void LoopNest::unroll(For* f, Stmt** unrolled) { stmt, {{f->var(), getImmediateByType(f->var()->dtype(), current)}})); } } - *unrolled = new Block(unrolled_stmts); + *unrolled = alloc(unrolled_stmts); *unrolled = IRSimplifier::simplify(*unrolled); p->replace_stmt(f, *unrolled); } -void LoopNest::unroll(For* f) { +void LoopNest::unroll(ForPtr f) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - Stmt* unrolled; + StmtPtr unrolled; unroll(f, &unrolled); } -bool LoopNest::isNormalized(For* f) { +bool LoopNest::isNormalized(ForPtr f) { if (f->start()->isConstant()) { return immediateAs(f->start()) == 0; } return false; } -bool LoopNest::normalize(For* f) { +bool LoopNest::normalize(ForPtr f) { if (!f) { throw malformed_input("normalize attempted on null loop"); } @@ -2142,31 +2161,31 @@ bool LoopNest::normalize(For* f) { f->body(), {{f->var(), (VarHandle(f->var()) + ExprHandle(f->start())).node()}}); f->set_body(IRSimplifier::simplify(for_body_normalized)); - f->set_stop(IRSimplifier::simplify(new Sub(f->stop(), f->start()))); - f->set_start(new IntImm(0)); + f->set_stop(IRSimplifier::simplify(alloc(f->stop(), f->start()))); + f->set_start(alloc(0)); return true; } // This function expects that there are 'num' loops perfectly nested within // and including 'f'. -std::vector LoopNest::getLoopStmtsInLoopNest(For* f, size_t num) { - std::vector loops(num); - For* curr_for = f; +std::vector LoopNest::getLoopStmtsInLoopNest(ForPtr f, size_t num) { + std::vector loops(num); + ForPtr curr_for = f; loops[0] = curr_for; for (size_t i = 1; i < num; ++i) { TORCH_INTERNAL_ASSERT(curr_for->body()->nstmts() == 1); - curr_for = dynamic_cast(curr_for->body()->front()); + curr_for = to(curr_for->body()->front()); TORCH_INTERNAL_ASSERT(curr_for); loops[i] = curr_for; } return loops; } -bool LoopNest::flatten(const std::vector& loops, For** flattened) { +bool LoopNest::flatten(const std::vector& loops, ForPtr* flattened) { if (loops.empty()) { throw malformed_input("flatten attempted on empty set of loops"); } - Block* p = dynamic_cast(loops[0]->get_parent()); + BlockPtr p = to(loops[0]->get_parent()); if (!p) { throw malformed_input("flatten attempted on loops with no parent"); } @@ -2204,37 +2223,37 @@ bool LoopNest::flatten(const std::vector& loops, For** flattened) { // NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage) auto normalized_loops = getLoopStmtsInLoopNest(loops.front(), loops.size()); - auto flat_var = new Var( + auto flat_var = alloc( normalized_loops[0]->var()->name_hint() + "_flat", normalized_loops[0]->var()->dtype()); VarMapping var_mapping; - Expr* stop = new IntImm(1); + ExprPtr stop = alloc(1); for (size_t i = 0; i < normalized_loops.size(); ++i) { size_t idx = normalized_loops.size() - i - 1; auto curr_loop = normalized_loops[idx]; - Expr* div = new Div(flat_var, stop); - Expr* sub_expr = idx == 0 ? div : new Mod(div, curr_loop->stop()); + ExprPtr div = alloc
(flat_var, stop); + ExprPtr sub_expr = idx == 0 ? div : alloc(div, curr_loop->stop()); var_mapping.push_back(std::make_pair(curr_loop->var(), sub_expr)); - stop = new Mul(curr_loop->stop(), stop); + stop = alloc(curr_loop->stop(), stop); } auto flattened_body = Substitute(normalized_loops.back()->removeBody(), var_mapping); normalized_loops.front()->set_var(flat_var); - normalized_loops.front()->set_start(new IntImm(0)); + normalized_loops.front()->set_start(alloc(0)); normalized_loops.front()->set_stop(stop); normalized_loops.front()->set_body(flattened_body); *flattened = normalized_loops.front(); return true; } -bool LoopNest::flatten(const std::vector& loops) { +bool LoopNest::flatten(const std::vector& loops) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* flattened; + ForPtr flattened; return flatten(loops, &flattened); } -void LoopNest::compressBuffer(Buf* buf, Stmt* stmt) { +void LoopNest::compressBuffer(BufPtr buf, StmtPtr stmt) { // Loop iterations in NNC IR do not follow sequential semantics by default. // In other words, the iterations of the loops could be executed in any // random order without affecting correctness. This constraint in turn @@ -2265,7 +2284,7 @@ void LoopNest::compressBuffer(Buf* buf, Stmt* stmt) { auto reads = StmtsReadingBuf::find(stmt, buf); // Find the parent common to all the buffer accesses. - Block* parent = dynamic_cast(writes.front()->get_parent()); + BlockPtr parent = to(writes.front()->get_parent()); TORCH_INTERNAL_ASSERT(parent); for (auto w : writes) { parent = Block::getSharedParent(parent, w); @@ -2276,7 +2295,7 @@ void LoopNest::compressBuffer(Buf* buf, Stmt* stmt) { // Collect all the loops that are above the common parent. auto loops = LoopNest::getEnclosingLoopNest(parent); - std::unordered_set loop_vars; + std::unordered_set loop_vars; for (auto l : loops) { loop_vars.insert(l->var()); } @@ -2287,7 +2306,7 @@ void LoopNest::compressBuffer(Buf* buf, Stmt* stmt) { // Vector to indicate which dimensions could be compressed away. std::vector dims(buf->dims().size(), true); - auto check_indices = [&](const std::vector& indices) { + auto check_indices = [&](const std::vector& indices) { TORCH_INTERNAL_ASSERT(indices.size() == dims.size()); for (size_t i = 0; i < indices.size(); ++i) { auto index_vars = NodeFinder::find(indices[i]); @@ -2320,21 +2339,21 @@ void LoopNest::compressBuffer(Buf* buf, Stmt* stmt) { } // Compress buffer by removing the marked dims. - std::vector new_dims(buf->dims()); + std::vector new_dims(buf->dims()); for (size_t i = 0; i < dims.size(); ++i) { if (dims[i]) { - new_dims[i] = new IntImm(1); + new_dims[i] = alloc(1); } } buf->set_dims(new_dims); // Modify all access to reflect the removed dims. 
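The dimensions marked above are the ones whose indices never depend on a loop enclosing the accesses' common parent, so each "slice" of the buffer is produced and consumed within a single iteration and the dimension can shrink to 1. A minimal standalone C++ sketch of that effect (names illustrative, not the NNC API):

#include <cassert>
#include <vector>

int main() {
  const int N = 5, M = 6;
  std::vector<int> A(N * M), B(N * (M - 1)), Arow(M), B2(N * (M - 1));

  // Uncompressed: a full N x M scratch buffer, even though row i is produced
  // and consumed entirely within iteration i of the outer loop.
  for (int i = 0; i < N; ++i) {
    for (int j = 0; j < M; ++j)
      A[i * M + j] = i * j;
    for (int j = 0; j < M - 1; ++j)
      B[i * (M - 1) + j] = A[i * M + j] + A[i * M + j + 1];
  }

  // Compressed: the i dimension of the scratch buffer collapses to 1 and the
  // i index in every access becomes 0, so one row of storage suffices.
  for (int i = 0; i < N; ++i) {
    for (int j = 0; j < M; ++j)
      Arow[j] = i * j;
    for (int j = 0; j < M - 1; ++j)
      B2[i * (M - 1) + j] = Arow[j] + Arow[j + 1];
  }

  assert(B == B2);
}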
- auto get_new_indices = [&](const std::vector& indices) { + auto get_new_indices = [&](const std::vector& indices) { TORCH_INTERNAL_ASSERT(indices.size() == dims.size()); - std::vector new_indices(indices); + std::vector new_indices(indices); for (size_t i = 0; i < dims.size(); ++i) { if (dims[i]) { - new_indices[i] = new IntImm(0); + new_indices[i] = alloc(0); } } return new_indices; @@ -2351,27 +2370,27 @@ void LoopNest::compressBuffer(Buf* buf, Stmt* stmt) { } } -void LoopNest::compressAllBuffers(Stmt* stmt) { +void LoopNest::compressAllBuffers(StmtPtr stmt) { for (auto buf : BufFinder::find(stmt)) { - compressBuffer(const_cast(buf), stmt); + compressBuffer(const_cast(buf), stmt); } } -std::vector LoopNest::getLoopStmtsFor(Tensor* t) const { - Stmt* cur_stmt = getLoopBodyFor(t); +std::vector LoopNest::getLoopStmtsFor(Tensor* t) const { + StmtPtr cur_stmt = getLoopBodyFor(t); return getLoopStmtsFor(cur_stmt); } -std::vector LoopNest::getLoopStmtsFor(Buf* buf) const { - Stmt* cur_stmt = getLoopBodyFor(buf); +std::vector LoopNest::getLoopStmtsFor(BufPtr buf) const { + StmtPtr cur_stmt = getLoopBodyFor(buf); return getLoopStmtsFor(cur_stmt); } -std::vector LoopNest::getLoopStmtsFor(Stmt* s) const { - std::vector result; +std::vector LoopNest::getLoopStmtsFor(StmtPtr s) const { + std::vector result; while (s) { - if (auto* loop = dynamic_cast(s)) { + if (auto loop = to(s)) { result.push_back(loop); } s = s->get_parent(); @@ -2380,25 +2399,25 @@ std::vector LoopNest::getLoopStmtsFor(Stmt* s) const { return result; } -Stmt* LoopNest::getLoopBodyFor(Tensor* t) const { +StmtPtr LoopNest::getLoopBodyFor(Tensor* t) const { return getLoopBodyFor(t->buf()); } -Stmt* LoopNest::getLoopBodyFor(Buf* buf) const { +StmtPtr LoopNest::getLoopBodyFor(BufPtr buf) const { auto writes = WritesToBuf::find(root_stmt_, buf); // special case for reduction Tensors, ignore the initializer if it's the only // op: if (writes.size() == 2) { - if (Store* s = dynamic_cast(writes.back())) { - if (ReduceOp* r = dynamic_cast(s->value())) { - return (Stmt*)s; // NOLINT + if (StorePtr s = to(writes.back())) { + if (ReduceOpPtr r = to(s->value())) { + return (StmtPtr)s; // NOLINT } } } - Stmt* res = nullptr; - for (auto* s : writes) { + StmtPtr res = nullptr; + for (auto s : writes) { if (!res) { res = s; continue; @@ -2407,22 +2426,22 @@ Stmt* LoopNest::getLoopBodyFor(Buf* buf) const { res = Block::getSharedParent(res, s); } - return (Stmt*)res; // NOLINT + return (StmtPtr)res; // NOLINT } -For* LoopNest::getParentLoop(Stmt* st) { +ForPtr LoopNest::getParentLoop(StmtPtr st) { if (st == nullptr) { return nullptr; } auto par = st->get_parent(); - if (auto f = dynamic_cast(par)) { + if (auto f = to(par)) { return f; } return getParentLoop(par); } -std::vector LoopNest::getEnclosingLoopNest(Stmt* st) { - std::vector loops; +std::vector LoopNest::getEnclosingLoopNest(StmtPtr st) { + std::vector loops; auto f = getParentLoop(st); while (f) { loops.push_back(f); @@ -2432,13 +2451,14 @@ std::vector LoopNest::getEnclosingLoopNest(Stmt* st) { return loops; } -std::vector LoopNest::getAllWritesToBuf(Buf* buf) const { +std::vector LoopNest::getAllWritesToBuf(BufPtr buf) const { return WritesToBuf::find(root_stmt_, buf); } -std::vector LoopNest::getAllInnermostLoopsWritingToBuf(Buf* buf) const { +std::vector LoopNest::getAllInnermostLoopsWritingToBuf( + BufPtr buf) const { auto writes = getAllWritesToBuf(buf); - std::vector innermost_loops; + std::vector innermost_loops; innermost_loops.reserve(writes.size()); for (auto w : writes) { 
innermost_loops.push_back(LoopNest::getParentLoop(w)); @@ -2446,10 +2466,10 @@ std::vector LoopNest::getAllInnermostLoopsWritingToBuf(Buf* buf) const { return innermost_loops; } -std::vector> LoopNest::getAllLoopNestsWritingToBuf( - Buf* buf) const { +std::vector> LoopNest::getAllLoopNestsWritingToBuf( + BufPtr buf) const { auto writes = getAllWritesToBuf(buf); - std::vector> loopnests; + std::vector> loopnests; loopnests.reserve(writes.size()); for (auto w : writes) { loopnests.emplace_back(LoopNest::getEnclosingLoopNest(w)); @@ -2457,12 +2477,12 @@ std::vector> LoopNest::getAllLoopNestsWritingToBuf( return loopnests; } -Stmt* LoopNest::simplify() { +StmtPtr LoopNest::simplify() { root_stmt_ = IRSimplifier::simplify(root_stmt_); return root_stmt_; } -Stmt* FlattenIndexes(Stmt* s) { +StmtPtr FlattenIndexes(StmtPtr s) { IndexFlattener idx_flattener; return idx_flattener.flatten(s); } @@ -2471,34 +2491,37 @@ Stmt* FlattenIndexes(Stmt* s) { // LoopNest::computeAt for more details. class LoopComputeAtRewriter : public IRMutator { public: - LoopComputeAtRewriter(Buf* buf, Buf* new_buf, std::vector offsets) + LoopComputeAtRewriter( + BufPtr buf, + BufPtr new_buf, + std::vector offsets) : buf_(buf), new_buf_(new_buf), offsets_(std::move(offsets)) {} private: - Buf* buf_; - Buf* new_buf_; - std::vector offsets_; + BufPtr buf_; + BufPtr new_buf_; + std::vector offsets_; - Expr* mutate(Load* v) override { + ExprPtr mutate(LoadPtr v) override { if (v->buf() != buf_) { return v; } - std::vector new_indices(v->indices().size()); + std::vector new_indices(v->indices().size()); for (const auto i : c10::irange(v->indices().size())) { new_indices[i] = - IRSimplifier::simplify(new Sub(v->indices()[i], offsets_[i])); + IRSimplifier::simplify(alloc(v->indices()[i], offsets_[i])); } - return new Load(v->dtype(), new_buf_, new_indices); + return alloc(v->dtype(), new_buf_, new_indices); } }; -static Store* getStoreStmtOfProducer(Stmt* s) { - if (Store* st = dynamic_cast(s)) { +static StorePtr getStoreStmtOfProducer(StmtPtr s) { + if (StorePtr st = to(s)) { return st; } - if (Block* b = dynamic_cast(s)) { - for (Stmt* ss : *b) { - if (Store* st = dynamic_cast(ss)) { + if (BlockPtr b = to(s)) { + for (StmtPtr ss : *b) { + if (StorePtr st = to(ss)) { return st; } } @@ -2506,11 +2529,11 @@ static Store* getStoreStmtOfProducer(Stmt* s) { return nullptr; } -static std::vector getOuterLoopIndexes(Stmt* s) { - std::vector res; - Stmt* cur = s; +static std::vector getOuterLoopIndexes(StmtPtr s) { + std::vector res; + StmtPtr cur = s; while (cur) { - if (auto l = dynamic_cast(cur)) { + if (auto l = to(cur)) { res.push_back(l->var()); } cur = cur->get_parent(); @@ -2520,63 +2543,63 @@ static std::vector getOuterLoopIndexes(Stmt* s) { class CacheReplacer : public IRMutator { public: - CacheReplacer(Buf* buffer, Buf* cache, std::vector& offsets) + CacheReplacer(BufPtr buffer, BufPtr cache, std::vector& offsets) : buf_(buffer), cache_(cache), offsets_(offsets) {} private: - Expr* mutate(Load* v) override { - Buf* buf = v->buf(); + ExprPtr mutate(LoadPtr v) override { + BufPtr buf = v->buf(); if (buf != buf_) { return IRMutator::mutate(v); } // Map indices to call-parameters. 
- std::vector newIndices; + std::vector newIndices; TORCH_INTERNAL_ASSERT(offsets_.size() == v->indices().size()); for (size_t i = 0; i < v->indices().size(); ++i) { - Expr* index = v->indices()[i]->accept_mutator(this); - Expr* offset = offsets_[i]; - Expr* sub = IRSimplifier::simplify(new Sub(index, offset)); + ExprPtr index = v->indices()[i]->accept_mutator(this); + ExprPtr offset = offsets_[i]; + ExprPtr sub = IRSimplifier::simplify(alloc(index, offset)); newIndices.push_back(sub); } - return new Load(cache_, newIndices); + return alloc(cache_, newIndices); } - Stmt* mutate(Store* v) override { - Buf* buf = v->buf(); + StmtPtr mutate(StorePtr v) override { + BufPtr buf = v->buf(); if (buf != buf_) { return IRMutator::mutate(v); } - Expr* newValue = v->value()->accept_mutator(this); + ExprPtr newValue = v->value()->accept_mutator(this); // Map indices to call-parameters. - std::vector newIndices; + std::vector newIndices; TORCH_INTERNAL_ASSERT(offsets_.size() == v->indices().size()); for (size_t i = 0; i < v->indices().size(); ++i) { - Expr* index = v->indices()[i]->accept_mutator(this); - Expr* offset = offsets_[i]; - Expr* sub = IRSimplifier::simplify(new Sub(index, offset)); + ExprPtr index = v->indices()[i]->accept_mutator(this); + ExprPtr offset = offsets_[i]; + ExprPtr sub = IRSimplifier::simplify(alloc(index, offset)); newIndices.push_back(sub); } - return new Store(cache_, newIndices, newValue); + return alloc(cache_, newIndices, newValue); } - Buf* buf_; - Buf* cache_; - std::vector& offsets_; + BufPtr buf_; + BufPtr cache_; + std::vector& offsets_; }; LoopNest::AccessResult LoopNest::cacheAccesses( - Buf* producer, + BufPtr producer, const std::string& name, - Stmt* consumer) { - ReduceOp* reduceOp{nullptr}; + StmtPtr consumer) { + ReduceOpPtr reduceOp{nullptr}; auto stores = NodeFinder::find(consumer); - for (auto* store : stores) { - if (auto ro = dynamic_cast(store->value())) { + for (auto store : stores) { + if (auto ro = to(store->value())) { if (store->buf() != producer) { continue; } @@ -2605,45 +2628,46 @@ LoopNest::AccessResult LoopNest::cacheAccesses( bool hasWrites = info.kind == kStore || info.kind == kMutate; std::vector var_names = {"i", "j", "k", "l", "m", "n", "o", "p"}; - std::vector tmp_dims; - std::vector new_loop_vars; - std::vector new_loop_vars_expr; + std::vector tmp_dims; + std::vector new_loop_vars; + std::vector new_loop_vars_expr; // Determine the size of the cache, and create a loop var for each dimension. for (size_t i = 0; i < info.start.size(); ++i) { - Expr* dim = IRSimplifier::simplify( - new Add(new Sub(info.stop[i], info.start[i]), new IntImm(1))); + ExprPtr dim = IRSimplifier::simplify( + alloc(alloc(info.stop[i], info.start[i]), alloc(1))); tmp_dims.push_back(dim); - new_loop_vars.push_back(new Var(var_names[i % var_names.size()], kInt)); + new_loop_vars.push_back(alloc(var_names[i % var_names.size()], kInt)); new_loop_vars_expr.push_back(new_loop_vars[i]); } // Create the var. - Buf* tmp_buf = new Buf(new Var(name, kHandle), tmp_dims, producer->dtype()); + BufPtr tmp_buf = + alloc(alloc(name, kHandle), tmp_dims, producer->dtype()); // determine the offsets for calls into the cache based off the loop start of // each axis. - std::vector tmp_params; + std::vector tmp_params; for (size_t i = 0; i < new_loop_vars.size(); ++i) { - tmp_params.push_back(new Add(new_loop_vars[i], info.start[i])); + tmp_params.push_back(alloc(new_loop_vars[i], info.start[i])); } // Replace acceses to the producer in the consumer with the cache. 
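The cache sizing and offsetting above (dimensions of extent stop - start + 1, accesses rewritten as index - start) can be pictured on a one-dimensional example. A minimal C++ sketch, independent of the NNC types; the names are illustrative:

#include <cassert>
#include <vector>

int main() {
  const int N = 16, lo = 4, hi = 11;   // consumer touches P[lo..hi] inclusive
  std::vector<int> P(N), out(hi - lo + 1), out2(hi - lo + 1);
  for (int x = 0; x < N; ++x) P[x] = 10 * x;

  // Direct accesses to the producer.
  for (int x = lo; x <= hi; ++x) out[x - lo] = P[x] + 1;

  // Cached accesses: the cache covers [lo, hi], so its extent is hi - lo + 1
  // and every producer index is offset by the region start, mirroring the
  // index - offset rewrite performed by CacheReplacer.
  std::vector<int> C(hi - lo + 1);
  for (int k = 0; k < hi - lo + 1; ++k) C[k] = P[k + lo];  // fill the cache
  for (int x = lo; x <= hi; ++x) out2[x - lo] = C[x - lo] + 1;

  assert(out == out2);
}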
CacheReplacer replacer(producer, tmp_buf, info.start); // TODO: Can we reuse 'consumer' below without cloning? - Stmt* new_consumer = + StmtPtr new_consumer = IRSimplifier::simplify(Stmt::clone(consumer)->accept_mutator(&replacer)); // replace the old consumer with the replaced consumer. - Block* consumer_block = nullptr; + BlockPtr consumer_block = nullptr; // if the consumer is a block, we should mutate it in place. - if ((consumer_block = dynamic_cast(consumer))) { + if ((consumer_block = to(consumer))) { consumer_block->clear(); consumer_block->append_stmt(new_consumer); } else { - consumer_block = dynamic_cast(consumer->get_parent()); + consumer_block = to(consumer->get_parent()); assert(consumer_block); consumer_block->replace_stmt(consumer, new_consumer); } @@ -2654,9 +2678,9 @@ LoopNest::AccessResult LoopNest::cacheAccesses( // Instead we need to create a new ReduceOp. bool on_reduce_axis = false; if (reduceOp) { - std::set reduce_args( + std::set reduce_args( reduceOp->reduce_args().begin(), reduceOp->reduce_args().end()); - std::set enclosing_vars; + std::set enclosing_vars; for (auto enclosing_for_stmt : NodeFinder::find(consumer)) { enclosing_vars.insert(enclosing_for_stmt->var()); } @@ -2670,29 +2694,29 @@ LoopNest::AccessResult LoopNest::cacheAccesses( // reduceOp means we had both loads and stores. // Init cache to 0. - Stmt* tmp_init = new Store( + StmtPtr tmp_init = alloc( tmp_buf, new_loop_vars_expr, getImmediateByType(tmp_buf->dtype(), 0)); for (int64_t i = new_loop_vars.size() - 1; i >= 0; --i) { tmp_init = - new For(new_loop_vars[i], new IntImm(0), tmp_dims[i], tmp_init); + alloc(new_loop_vars[i], alloc(0), tmp_dims[i], tmp_init); } consumer_block->insert_stmt_before(tmp_init, new_consumer); // Reduce back to the original buffer: - Stmt* tmp_store = new Store( + StmtPtr tmp_store = alloc( producer, tmp_params, reduceOp->reducer()( producer, - ExprHandle(new Load(tmp_buf, new_loop_vars_expr)), + ExprHandle(alloc(tmp_buf, new_loop_vars_expr)), tmp_params, {})); for (int64_t i = new_loop_vars.size() - 1; i >= 0; --i) { - tmp_store = - new For(new_loop_vars[i], new IntImm(0), tmp_dims[i], tmp_store); + tmp_store = alloc( + new_loop_vars[i], alloc(0), tmp_dims[i], tmp_store); } consumer_block->insert_stmt_after(tmp_store, new_consumer); @@ -2702,12 +2726,12 @@ LoopNest::AccessResult LoopNest::cacheAccesses( if (hasReads) { // Fill the cache with values from the consumer. - Stmt* tmp_store = - new Store(tmp_buf, new_loop_vars_expr, new Load(producer, tmp_params)); + StmtPtr tmp_store = alloc( + tmp_buf, new_loop_vars_expr, alloc(producer, tmp_params)); for (int64_t i = new_loop_vars.size() - 1; i >= 0; --i) { - tmp_store = - new For(new_loop_vars[i], new IntImm(0), tmp_dims[i], tmp_store); + tmp_store = alloc( + new_loop_vars[i], alloc(0), tmp_dims[i], tmp_store); } consumer_block->insert_stmt_before(tmp_store, new_consumer); @@ -2715,12 +2739,12 @@ LoopNest::AccessResult LoopNest::cacheAccesses( if (hasWrites) { // sync the cache back to the producer buf. 
- Stmt* tmp_store = - new Store(producer, tmp_params, new Load(tmp_buf, new_loop_vars_expr)); + StmtPtr tmp_store = alloc( + producer, tmp_params, alloc(tmp_buf, new_loop_vars_expr)); for (int64_t i = new_loop_vars.size() - 1; i >= 0; --i) { - tmp_store = - new For(new_loop_vars[i], new IntImm(0), tmp_dims[i], tmp_store); + tmp_store = alloc( + new_loop_vars[i], alloc(0), tmp_dims[i], tmp_store); } consumer_block->insert_stmt_after(tmp_store, new_consumer); @@ -2837,8 +2861,8 @@ LoopNest::AccessResult LoopNest::cacheAccesses( * `temp` instead of `producer`. The indices in the corresponding accesses * also need to be offset. */ -void LoopNest::computeAt(Stmt* s, For* f) { - Store* st = getStoreStmtOfProducer(s); +void LoopNest::computeAt(StmtPtr s, ForPtr f) { + StorePtr st = getStoreStmtOfProducer(s); if (!st) { return; } @@ -2855,16 +2879,16 @@ void LoopNest::computeAt(Stmt* s, For* f) { } // Compute dimensions of the temp buffer we would need to allocate - std::vector dims = getBoundExtents(bounds_it->second); + std::vector dims = getBoundExtents(bounds_it->second); // TODO: Use name-hint of the producer instead of "temp" - Buf* temp_buf = new Buf("temp", dims, st->value()->dtype()); + BufPtr temp_buf = alloc("temp", dims, st->value()->dtype()); // Generate index variables for 'temp' - std::vector temp_indices(dims.size()); + std::vector temp_indices(dims.size()); for (const auto i : c10::irange(dims.size())) { // TODO: Use name-hint of the producer indices instead of 'idx' - temp_indices[i] = new Var(std::string("idx") + c10::to_string(i), kInt); + temp_indices[i] = alloc(std::string("idx") + c10::to_string(i), kInt); } // Prepare substitute rules for constructing the temp statement from the prod @@ -2874,27 +2898,27 @@ void LoopNest::computeAt(Stmt* s, For* f) { // modified (e.g. split or merged) so that the loop indices no longer // correspond to the indices of the original expression and even their number // might be different. In that case, the loop below would crash. - std::vector prod_indices = getOuterLoopIndexes(s); - std::vector> rewrite_indices_map; - std::vector offsets; + std::vector prod_indices = getOuterLoopIndexes(s); + std::vector> rewrite_indices_map; + std::vector offsets; for (const TensorAccessBoundsInfo& p : bounds_it->second) { for (const auto i : c10::irange(p.start.size())) { if (offsets.size() <= i) { offsets.push_back(p.start[i]); } else { offsets[i] = - IRSimplifier::simplify(new Min(offsets[i], p.start[i], true)); + IRSimplifier::simplify(alloc(offsets[i], p.start[i], true)); } } } for (const auto i : c10::irange(prod_indices.size())) { rewrite_indices_map.push_back( - {prod_indices[i], new Add(temp_indices[i], offsets[i])}); + {prod_indices[i], alloc(temp_indices[i], offsets[i])}); } // Construct the temp statement - Stmt* bd = new Store( + StmtPtr bd = alloc( temp_buf, temp_indices, SubstituteInClone(st->value(), rewrite_indices_map)); @@ -2904,11 +2928,8 @@ void LoopNest::computeAt(Stmt* s, For* f) { // We're creating loops from innermost to outermost, so we need to access // dimensions in reversed order. 
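The temp buffer built here has one extent per needed dimension, and the consumer's producer accesses are redirected to it with indices offset by the start of the needed region. A minimal C++ sketch of computeAt's effect on a 1-D stencil (illustrative names only, not the NNC API):

#include <cassert>
#include <vector>

int prod(int x) { return x * x; }  // stand-in for the producer expression

int main() {
  const int N = 10;
  std::vector<int> cons(N), cons2(N);

  // Before: the producer is fully materialized, then consumed.
  std::vector<int> P(N + 1);
  for (int x = 0; x <= N; ++x) P[x] = prod(x);
  for (int i = 0; i < N; ++i) cons[i] = P[i] + P[i + 1];

  // After computeAt at the i loop: each iteration needs only P[i..i+1], so a
  // temp of extent 2 is computed inside the consumer loop and producer
  // accesses become temp[index - offset], with offset == i here.
  for (int i = 0; i < N; ++i) {
    int temp[2];
    for (int k = 0; k < 2; ++k) temp[k] = prod(i + k);
    cons2[i] = temp[0] + temp[1];
  }

  assert(cons == cons2);
}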
size_t dim_idx = dims.size() - 1 - i; - bd = new For( - dynamic_cast(temp_indices[dim_idx]), - new IntImm(0), - dims[dim_idx], - bd); + bd = alloc( + to(temp_indices[dim_idx]), alloc(0), dims[dim_idx], bd); } // Add constructed stmts to the consumer loop @@ -2916,9 +2937,9 @@ void LoopNest::computeAt(Stmt* s, For* f) { // Rewrite accesses to producer in consumer with accesses to temp LoopComputeAtRewriter lr(st->buf(), temp_buf, offsets); - Stmt* new_f = f->accept_mutator(&lr); + StmtPtr new_f = f->accept_mutator(&lr); if (f != new_f) { - Block* bb = dynamic_cast(f->get_parent()); + BlockPtr bb = to(f->get_parent()); bb->replace_stmt(f, new_f); } } @@ -2926,10 +2947,10 @@ void LoopNest::computeAt(Stmt* s, For* f) { class RfactorStoreRewriter : public IRMutator { public: RfactorStoreRewriter( - Buf* old_buf, - const std::vector& old_indices, - Buf* new_buf, - Var* reduction_var) + BufPtr old_buf, + const std::vector& old_indices, + BufPtr new_buf, + VarPtr reduction_var) : old_buf_(old_buf), old_indices_(old_indices), new_buf_(new_buf), @@ -2938,7 +2959,7 @@ class RfactorStoreRewriter : public IRMutator { new_indices_.push_back(reduction_var_); } - Expr* mutate(Load* v) override { + ExprPtr mutate(LoadPtr v) override { if (v->buf() != old_buf_) { return IRMutator::mutate(v); } @@ -2956,23 +2977,23 @@ class RfactorStoreRewriter : public IRMutator { return IRMutator::mutate(v); } - return new Load(new_buf_, new_indices_); + return alloc(new_buf_, new_indices_); } - Expr* mutate(ReduceOp* v) override { - Expr* body_new = v->body()->accept_mutator(this); + ExprPtr mutate(ReduceOpPtr v) override { + ExprPtr body_new = v->body()->accept_mutator(this); - std::vector new_reduce_args; - for (auto* r : v->reduce_args()) { + std::vector new_reduce_args; + for (auto r : v->reduce_args()) { if (r != reduction_var_) { new_reduce_args.push_back(r); } } - return new ReduceOp(body_new, new_reduce_args, v->reducer()); + return alloc(body_new, new_reduce_args, v->reducer()); } - Stmt* mutate(Store* v) override { + StmtPtr mutate(StorePtr v) override { if (v->buf() != old_buf_) { return IRMutator::mutate(v); } @@ -2990,26 +3011,29 @@ class RfactorStoreRewriter : public IRMutator { return IRMutator::mutate(v); } - Expr* new_value = v->value()->accept_mutator(this); - return new Store(new_buf_, new_indices_, new_value); + ExprPtr new_value = v->value()->accept_mutator(this); + return alloc(new_buf_, new_indices_, new_value); } private: - Buf* old_buf_; - const std::vector& old_indices_; - Buf* new_buf_; - Var* reduction_var_; - std::vector new_indices_; + BufPtr old_buf_; + const std::vector& old_indices_; + BufPtr new_buf_; + VarPtr reduction_var_; + std::vector new_indices_; }; -bool LoopNest::rfactor(Stmt* st, For* target_for) { - Buf* tmp_buf = nullptr; +bool LoopNest::rfactor(StmtPtr st, ForPtr target_for) { + BufPtr tmp_buf = nullptr; return rfactor(st, target_for, &tmp_buf); } -bool LoopNest::rfactor(Stmt* st, For* outer_reduction_for, Buf** rfac_buf_ptr) { - Store* reduction_store = dynamic_cast(st); - ReduceOp* reduce_op = dynamic_cast(reduction_store->value()); +bool LoopNest::rfactor( + StmtPtr st, + ForPtr outer_reduction_for, + BufPtr* rfac_buf_ptr) { + StorePtr reduction_store = to(st); + ReduceOpPtr reduce_op = to(reduction_store->value()); if (!reduce_op) { // Not a reduction store return false; @@ -3017,9 +3041,9 @@ bool LoopNest::rfactor(Stmt* st, For* outer_reduction_for, Buf** rfac_buf_ptr) { auto orig_buf = reduction_store->buf(); auto orig_buf_indices = reduction_store->indices(); - Var* 
reduction_var = outer_reduction_for->var(); + VarPtr reduction_var = outer_reduction_for->var(); - std::set reduce_args = { + std::set reduce_args = { reduce_op->reduce_args().begin(), reduce_op->reduce_args().end()}; if (reduce_args.size() < 2) { @@ -3029,15 +3053,15 @@ bool LoopNest::rfactor(Stmt* st, For* outer_reduction_for, Buf** rfac_buf_ptr) { // Verify that outer_reduction_for is a perfect loop nest with all loops being // reductions - Stmt* cur = outer_reduction_for; - while (For* cur_for = dynamic_cast(cur)) { + StmtPtr cur = outer_reduction_for; + while (ForPtr cur_for = to(cur)) { if (!reduce_args.count(cur_for->var())) { // output axis inside outer_reduction_for are not allowed return false; } reduce_args.erase(cur_for->var()); - Block* b = cur_for->body(); + BlockPtr b = cur_for->body(); if (b->nstmts() != 1) { return false; } @@ -3056,41 +3080,41 @@ bool LoopNest::rfactor(Stmt* st, For* outer_reduction_for, Buf** rfac_buf_ptr) { // assert: reduce_axis match loop vars from outer_reduction_for and inside // assert: no other stmts in outer_reduction_for or its child loops - std::vector rfac_dims = orig_buf->dims(); - Expr* extra_dim = IRSimplifier::simplify( - new Sub(outer_reduction_for->stop(), outer_reduction_for->start())); + std::vector rfac_dims = orig_buf->dims(); + ExprPtr extra_dim = IRSimplifier::simplify( + alloc(outer_reduction_for->stop(), outer_reduction_for->start())); rfac_dims.push_back(extra_dim); - Expr* rfac_init = - new Cast(reduce_op->dtype(), reduce_op->reducer().initializer()); + ExprPtr rfac_init = + alloc(reduce_op->dtype(), reduce_op->reducer().initializer()); - *rfac_buf_ptr = new Buf( + *rfac_buf_ptr = alloc( orig_buf->name_hint() + "_rfac", rfac_dims, reduce_op->dtype(), rfac_init); - Buf* rfac_buf = *rfac_buf_ptr; + BufPtr rfac_buf = *rfac_buf_ptr; // Rewrite the original reduction store to use the temporary rfac buffer: // 1) X[*indexes] --> T[*indexes + {reduction_var}] // 2) reduce_axis -= {reduction_var} RfactorStoreRewriter rfac_rewriter( orig_buf, orig_buf_indices, rfac_buf, reduction_var); - dynamic_cast(st->get_parent()) + to(st->get_parent()) ->replace_stmt(st, st->accept_mutator(&rfac_rewriter)); // Insert a store for the final reduction over the temp buffer into the // original buffer: // X[*indexes] = ReduceOp(X[*indexes] + T[*indexes + {reduction_var}], // reduce_axis={reduction_var}) - Block* b = outer_reduction_for->body(); + BlockPtr b = outer_reduction_for->body(); TORCH_INTERNAL_ASSERT(b->nstmts() == 1); - Stmt* first_reduction_loop = b->stmts().front(); + StmtPtr first_reduction_loop = b->stmts().front(); auto rfac_buf_indices = orig_buf_indices; rfac_buf_indices.emplace_back(reduction_var); - Expr* final_reduce_load = new Load(rfac_buf, rfac_buf_indices); + ExprPtr final_reduce_load = alloc(rfac_buf, rfac_buf_indices); outer_reduction_for->body()->insert_stmt_after( - new Store( + alloc( orig_buf, orig_buf_indices, reduce_op->reducer()( @@ -3100,7 +3124,8 @@ bool LoopNest::rfactor(Stmt* st, For* outer_reduction_for, Buf** rfac_buf_ptr) { // Insert an initialization store for the temp buffer: // T[a,b,c] = init outer_reduction_for->body()->insert_stmt_before( - new Store(rfac_buf, rfac_buf_indices, rfac_init), first_reduction_loop); + alloc(rfac_buf, rfac_buf_indices, rfac_init), + first_reduction_loop); return true; } diff --git a/torch/csrc/jit/tensorexpr/loopnest.h b/torch/csrc/jit/tensorexpr/loopnest.h index 3f467e1..c8cf2d8 100644 --- a/torch/csrc/jit/tensorexpr/loopnest.h +++ b/torch/csrc/jit/tensorexpr/loopnest.h @@ -6,6 
+6,7 @@ #include #include +#include namespace torch { namespace jit { @@ -34,21 +35,21 @@ class TORCH_API LoopNest { // A constructor for building a LoopNest from an Stmt and a list of output // buffers. - LoopNest(Stmt* stmt, std::unordered_set output_bufs); + LoopNest(StmtPtr stmt, std::unordered_set output_bufs); // A constructor for building a LoopNest from another loopnest. It clones the // other loopnest's stmt. LoopNest(const LoopNest& other); - Stmt* root_stmt() const { + StmtPtr root_stmt() const { return root_stmt_; } - std::vector getLoopStmtsFor(Tensor*) const; - std::vector getLoopStmtsFor(Buf*) const; - std::vector getLoopStmtsFor(Stmt*) const; - Stmt* getLoopBodyFor(Tensor*) const; - Stmt* getLoopBodyFor(Buf*) const; + std::vector getLoopStmtsFor(Tensor*) const; + std::vector getLoopStmtsFor(BufPtr) const; + std::vector getLoopStmtsFor(StmtPtr) const; + StmtPtr getLoopBodyFor(Tensor*) const; + StmtPtr getLoopBodyFor(BufPtr) const; // Returns the For stmt indexed by 'indices' in the 'root' For stmt. //'indices' indicates the path to the returned loop from 'root' in AST, e.g., @@ -68,17 +69,17 @@ class TORCH_API LoopNest { // the path from 'root' to 'j_loop' is [0] // the path from 'root' to 'k1_loop' is [0, 0] // the path from 'root' to 'k2_loop' is [0, 2] - For* getLoopAt(For* root, const std::vector& indices) const; + ForPtr getLoopAt(ForPtr root, const std::vector& indices) const; // Returns the For stmt that is immediately enclosing the given stmt. - static For* getParentLoop(Stmt* st); + static ForPtr getParentLoop(StmtPtr st); // Returns the list of For stmts corresponding to the loopnest that is // enclosing the given stmt. - static std::vector getEnclosingLoopNest(Stmt* st); + static std::vector getEnclosingLoopNest(StmtPtr st); // Returns a list of all Stmts that write to the given buf. - std::vector getAllWritesToBuf(Buf*) const; + std::vector getAllWritesToBuf(BufPtr) const; // The following methods return the For loops that contain writes to // the given buf. @@ -98,18 +99,18 @@ class TORCH_API LoopNest { // to buf. // For the above example: // getAllInnermostLoopsWritingToBuf(a) => {j1, k2, j3} - std::vector getAllInnermostLoopsWritingToBuf(Buf*) const; + std::vector getAllInnermostLoopsWritingToBuf(BufPtr) const; // Returns a list of For loopnests which contain a Stmt that writes to // the given buf. Each loopnest here is a vector For loops. // For the above example: // getAllLoopNestsWritingToBuf(a) => {{i1,j1}, {i2,j2,k2}, {i2,j3}} - std::vector> getAllLoopNestsWritingToBuf(Buf*) const; + std::vector> getAllLoopNestsWritingToBuf(BufPtr) const; - Stmt* simplify(); + StmtPtr simplify(); - bool computeInline(Stmt* s); - bool computeInline(Buf* b); + bool computeInline(StmtPtr s); + bool computeInline(BufPtr b); void inlineIntermediateBufs(bool allow_duplicated_work); // Optimizes conditionals. @@ -156,10 +157,10 @@ class TORCH_API LoopNest { // So, the pointer to the input loop should be valid after splitting and // will point to the outer loop. The `inner` and `tail` parameters will be // set to point to the inner and tail loops that are generated. - static void splitWithTail(For* f, int factor, For** inner, For** tail); + static void splitWithTail(ForPtr f, int factor, ForPtr* inner, ForPtr* tail); // A convenience wrapper when the caller does not need to access the // split loops. 
- static void splitWithTail(For* f, int factor); + static void splitWithTail(ForPtr f, int factor); // Splits the given loop into 2 nested loops with the given factor as the // inner loop bound. If the factor does not evenly divide the loop bound, @@ -184,10 +185,10 @@ class TORCH_API LoopNest { // So, the pointer to the input loop should be valid after splitting and // will point to the outer loop. The `inner` parameter will be set to point // to the inner loop that is generated. - static void splitWithMask(For* f, int factor, For** inner); + static void splitWithMask(ForPtr f, int factor, ForPtr* inner); // A convenience wrapper when the caller does not need to access the // split loops. - static void splitWithMask(For* f, int factor); + static void splitWithMask(ForPtr f, int factor); // The following methods support loop distribution. // For example, consider the following code. This will be used to @@ -220,9 +221,9 @@ class TORCH_API LoopNest { // : for i // S6: for k // S7: B[i] = B[i] + - static std::vector distributeLoop( - For* loop, - const std::unordered_set& pivots); + static std::vector distributeLoop( + ForPtr loop, + const std::unordered_set& pivots); // This method distributes the given loop over every stmt in its body. // @@ -240,7 +241,7 @@ class TORCH_API LoopNest { // : for i // S6: for k // S7: B[i] = B[i] + - static std::vector distributeLoop(For* loop); + static std::vector distributeLoop(ForPtr loop); // Same as above, but also distribute parent loops. // Returns the result of distributing the outermost loop. // @@ -260,7 +261,7 @@ class TORCH_API LoopNest { // : for i // S6: for k // S7: B[i] = B[i] + - static std::vector distributeLoopAndParents(For* loop); + static std::vector distributeLoopAndParents(ForPtr loop); // This method distributes the given loop over its body by splitting // after every For stmt in its body. @@ -277,7 +278,7 @@ class TORCH_API LoopNest { // S5: B[i] = A[i] // S6: for k // S7: B[i] = B[i] + - static std::vector distributeLoopOverInnerLoops(For* loop); + static std::vector distributeLoopOverInnerLoops(ForPtr loop); // Same as above, but also distribute parent loops. // Returns the result of distributing the outermost loop. // @@ -294,7 +295,8 @@ class TORCH_API LoopNest { // S5: B[i] = A[i] // S6: for k // S7: B[i] = B[i] + - static std::vector distributeLoopAndParentsOverInnerLoops(For* loop); + static std::vector distributeLoopAndParentsOverInnerLoops( + ForPtr loop); // This method performs loop fusion. // For example, consider the following code. @@ -323,7 +325,7 @@ class TORCH_API LoopNest { // Below are the two requirements to apply unsafeFuseLoops: // * All the loops have the same parent. // * There are no statements between these loops in their parent body. - static bool unsafeFuseLoops(const std::vector& loops, For** fused); + static bool unsafeFuseLoops(const std::vector& loops, ForPtr* fused); // Loop fusion is done only when all the conditions below are satisfied. // * All the loops have the same parent. @@ -331,9 +333,9 @@ class TORCH_API LoopNest { // * The start bounds are the same for all loops. // * The stop bounds are the same for all loops. // * Fusing the loops does not violate or add any dependencies. - static bool fuseLoops(const std::vector& loops, For** fused); + static bool fuseLoops(const std::vector& loops, ForPtr* fused); - static void reorderAxis(For* a, For* b); + static void reorderAxis(ForPtr a, ForPtr b); // Reorder the given list of loops according to the permutation specified. 
// Here `permutation[i]` represents the position of the loop in the input @@ -353,8 +355,8 @@ class TORCH_API LoopNest { // for p // for q // A[p,q,r,s] = - static std::vector reorder( - const std::vector& loops, + static std::vector reorder( + const std::vector& loops, const std::vector& permutation); // Tile takes a 2d domain (x, y) and splits it into small rectangular blocks @@ -395,24 +397,24 @@ class TORCH_API LoopNest { // for k: (0, 32) // A[i_outer * 4 + i_inner, 7 * 9 + j_tail] = // B[i_outer * 4 + i_inner, k] + C[7 * 9 + j_tail, k] - For* tile(For* x, For* y, int x_factor, int y_factor); + ForPtr tile(ForPtr x, ForPtr y, int x_factor, int y_factor); // Returns true if the given loops are perfectly nested, i.e., every loop // (except the innermost) should have exactly one statement in its body // and that statement must be the next inner loop. - static bool areLoopsPerfectlyNested(const std::vector& loops); + static bool areLoopsPerfectlyNested(const std::vector& loops); // Returns true if the given loop has a loop-carried dependence. - static bool hasLoopCarriedDependence(For* loop); + static bool hasLoopCarriedDependence(ForPtr loop); - static void unroll(For* f, Stmt** unrolled); - static void unroll(For* f); + static void unroll(ForPtr f, StmtPtr* unrolled); + static void unroll(ForPtr f); - static bool normalize(For* f); - static bool isNormalized(For* f); + static bool normalize(ForPtr f); + static bool isNormalized(ForPtr f); - static bool flatten(const std::vector& f, For** flattened); - static bool flatten(const std::vector& f); + static bool flatten(const std::vector& f, ForPtr* flattened); + static bool flatten(const std::vector& f); // Compresses the given buffer based on its use in the given Stmts. // @@ -442,7 +444,7 @@ class TORCH_API LoopNest { // B[i,j] = A[0,j] + A[0, j+1] // } // } - static void compressBuffer(Buf* buf, Stmt* stmt); + static void compressBuffer(BufPtr buf, StmtPtr stmt); // Compresses all buffers in the given statement. // @@ -451,32 +453,32 @@ class TORCH_API LoopNest { // kernel statement to avoid incorrect buffer compressions. // // TODO: Add an IR verifier check to detect invalidly compressed buffers. - static void compressAllBuffers(Stmt* stmt); + static void compressAllBuffers(StmtPtr stmt); // Get 'num' loops from the loopnest starting at 'f'. - static std::vector getLoopStmtsInLoopNest(For* f, size_t num); + static std::vector getLoopStmtsInLoopNest(ForPtr f, size_t num); // LoopOptions are propagated to tail. - static void sliceHead(For* f, int factor, For** head, For** tail); - static void sliceHead(For* f, int factor); + static void sliceHead(ForPtr f, int factor, ForPtr* head, ForPtr* tail); + static void sliceHead(ForPtr f, int factor); // LoopOptions are propagated to head. - static void sliceTail(For* f, int factor, For** head, For** tail); - static void sliceTail(For* f, int factor); + static void sliceTail(ForPtr f, int factor, ForPtr* head, ForPtr* tail); + static void sliceTail(ForPtr f, int factor); - using AccessResult = std::pair; + using AccessResult = std::pair; // Insert a cache for the consumer's usages of the buffer produced in // consumer, and redirect reads and writes in the consumer to that cache. // Returns a pair of the new cache buffer, and the new rewritten consumer. static AccessResult cacheAccesses( - Buf* producer, + BufPtr producer, const std::string& name, - Stmt* consumer); + StmtPtr consumer); // Insert a temporary computation of statement S in the scope of loop AT. 
// S is assumed to be a Store or a Block containing a Store. Along with the // computation itself, this transformation inserts Alloc/Free statements for // the temporary buffer used in the computation. - static void computeAt(Stmt* s, For* at); + static void computeAt(StmtPtr s, ForPtr at); // Rfactor a reduction axis into a normal axis. // @@ -520,13 +522,16 @@ class TORCH_API LoopNest { // S4: for k # reduction axis // X_rfac[i,j] = ReduceOp(X_rfac[i,j] + Y[i,j,k], reduce_axis={k}) // X[i] = ReduceOp(X[i] + X_rfac[i,j], reduce_axis={j}) - static bool rfactor(Stmt* s, For* outer_reduction_for); - static bool rfactor(Stmt* s, For* outer_reduction_for, Buf** rfac_buf_ptr); + static bool rfactor(StmtPtr s, ForPtr outer_reduction_for); + static bool rfactor( + StmtPtr s, + ForPtr outer_reduction_for, + BufPtr* rfac_buf_ptr); // Vectorize the given loop. This method requires that the given loop // does not perform a reduction. // It returns true if vectorization is successful and false otherwise. - static bool vectorize(For*); + static bool vectorize(ForPtr); // Find the inner-most loops and vectorize them. Currently, this only works // for the LLVM backend, when no reductions are involved. @@ -535,8 +540,8 @@ class TORCH_API LoopNest { void eliminateDeadStores(); void prepareForCodegen(); - const std::unordered_set getInputBufs() const; - const std::unordered_set getOutputBufs() const { + const std::unordered_set getInputBufs() const; + const std::unordered_set getOutputBufs() const { return output_bufs_; } @@ -544,32 +549,32 @@ class TORCH_API LoopNest { void initialize( const std::vector& output_tensors, const std::vector& tensors_to_compute); - Stmt* insertAllocFree(Stmt* stmt); - const std::unordered_set getIntermediateBufs() const; + StmtPtr insertAllocFree(StmtPtr stmt); + const std::unordered_set getIntermediateBufs() const; - Stmt* root_stmt_; + StmtPtr root_stmt_; - std::unordered_set output_bufs_; + std::unordered_set output_bufs_; }; -TORCH_API Stmt* FlattenIndexes(Stmt* s); +TORCH_API StmtPtr FlattenIndexes(StmtPtr s); // TODO: Revisit this once we decide on how dependencies analysis should look // like. Maybe we would choose to use a different API and BufUse would be // removed, or if we decide to keep it we need to properly document its API. struct BufLoadOrStoreUse { - Stmt* s; + StmtPtr s; bool isStore; }; /* * Returns a map ( Buf -> uses of this Buf), uses are represented as vectors of - * BufUse elements, which are Stmt* and a bool isStore flag. The order of uses + * BufUse elements, which are StmtPtr and a bool isStore flag. The order of uses * in the vectors reflects the order in which the uses appear in the given * statement. 
*/ -std::unordered_map> findLoadOrStoreUses( - Stmt* s); +std::unordered_map> findLoadOrStoreUses( + StmtPtr s); } // namespace tensorexpr } // namespace jit diff --git a/torch/csrc/jit/tensorexpr/mem_dependency_checker.cpp b/torch/csrc/jit/tensorexpr/mem_dependency_checker.cpp index 626cd81..8f6f2b1 100644 --- a/torch/csrc/jit/tensorexpr/mem_dependency_checker.cpp +++ b/torch/csrc/jit/tensorexpr/mem_dependency_checker.cpp @@ -59,15 +59,15 @@ void getDependentsChain( // AccessInfo -std::vector AccessInfo::getIndices() const { - std::vector indices; +std::vector AccessInfo::getIndices() const { + std::vector indices; if (expr_) { - if (auto* load = dynamic_cast(expr_)) { + if (auto load = to(expr_)) { indices = load->indices(); } } else { - if (auto* store = dynamic_cast(stmt_)) { + if (auto store = to(stmt_)) { indices = store->indices(); } } @@ -184,12 +184,14 @@ void AccessInfo::dumpDOT(std::ostream& os) const { os << "label = \"" << AccessToString(type_) << "\\n " << *var_ << "["; if (bounds_.size() > 0) { for (size_t i = 0; i < bounds_.size() - 1; ++i) { - os << *IRSimplifier::simplify(new Add(bounds_[i].end, new IntImm(1))) + os << *IRSimplifier::simplify( + alloc(bounds_[i].end, alloc(1))) << ", "; } size_t i = bounds_.size() - 1; - os << *IRSimplifier::simplify(new Add(bounds_[i].end, new IntImm(1))); + os << *IRSimplifier::simplify( + alloc(bounds_[i].end, alloc(1))); os << "]\"\n "; } if (isWrite()) { @@ -255,12 +257,12 @@ MemDependencyChecker::MemDependencyChecker() { } MemDependencyChecker::MemDependencyChecker( - const std::unordered_set& inputs, - const std::unordered_set& outputs) { - for (auto* s : inputs) { + const std::unordered_set& inputs, + const std::unordered_set& outputs) { + for (auto s : inputs) { inputs_[s] = nullptr; } - for (auto* s : outputs) { + for (auto s : outputs) { outputs_[s] = nullptr; } @@ -320,15 +322,15 @@ DependencySet MemDependencyChecker::getAllWriteDependencies( return writes; } -bool MemDependencyChecker::dependsDirectly(Expr* A, Stmt* B) { +bool MemDependencyChecker::dependsDirectly(ExprPtr A, StmtPtr B) { return dependsDirectlyHelper(A, B); } -bool MemDependencyChecker::dependsDirectly(Stmt* A, Stmt* B) { +bool MemDependencyChecker::dependsDirectly(StmtPtr A, StmtPtr B) { return dependsDirectlyHelper(A, B); } -bool MemDependencyChecker::dependsDirectly(Buf* O, Stmt* B) { +bool MemDependencyChecker::dependsDirectly(BufPtr O, StmtPtr B) { auto outputAccess = output(O); auto bWrites = getAllWritesWithin(B); @@ -341,7 +343,7 @@ bool MemDependencyChecker::dependsDirectly(Buf* O, Stmt* B) { return false; } -bool MemDependencyChecker::dependsDirectly(Stmt* A, Buf* I) { +bool MemDependencyChecker::dependsDirectly(StmtPtr A, BufPtr I) { auto aReads = getAllReadsWithin(A); auto inputAccess = input(I); @@ -354,7 +356,7 @@ bool MemDependencyChecker::dependsDirectly(Stmt* A, Buf* I) { return false; } -bool MemDependencyChecker::dependsDirectly(Expr* A, Buf* I) { +bool MemDependencyChecker::dependsDirectly(ExprPtr A, BufPtr I) { auto aReads = getAllReadsWithin(A); auto inputAccess = input(I); @@ -373,15 +375,15 @@ bool MemDependencyChecker::dependsDirectly( return A->hasDependency(B) && B->isWrite(); } -bool MemDependencyChecker::dependsIndirectly(Expr* A, Stmt* B) { +bool MemDependencyChecker::dependsIndirectly(ExprPtr A, StmtPtr B) { return dependsIndirectlyHelper(A, B); } -bool MemDependencyChecker::dependsIndirectly(Stmt* A, Stmt* B) { +bool MemDependencyChecker::dependsIndirectly(StmtPtr A, StmtPtr B) { return dependsIndirectlyHelper(A, B); } -bool 
MemDependencyChecker::dependsIndirectly(Buf* O, Stmt* B) { +bool MemDependencyChecker::dependsIndirectly(BufPtr O, StmtPtr B) { auto outputAccess = output(O); DependencySet dependencies; @@ -397,7 +399,7 @@ bool MemDependencyChecker::dependsIndirectly(Buf* O, Stmt* B) { return false; } -bool MemDependencyChecker::dependsIndirectly(Stmt* A, Buf* I) { +bool MemDependencyChecker::dependsIndirectly(StmtPtr A, BufPtr I) { auto aReads = getAllReadsWithin(A); auto inputAccess = input(I); @@ -406,7 +408,7 @@ bool MemDependencyChecker::dependsIndirectly(Stmt* A, Buf* I) { return aDeps.count(inputAccess) != 0; } -bool MemDependencyChecker::dependsIndirectly(Expr* A, Buf* I) { +bool MemDependencyChecker::dependsIndirectly(ExprPtr A, BufPtr I) { auto aReads = getAllReadsWithin(A); auto inputAccess = input(I); @@ -415,7 +417,7 @@ bool MemDependencyChecker::dependsIndirectly(Expr* A, Buf* I) { return aDeps.count(inputAccess) != 0; } -bool MemDependencyChecker::dependsIndirectly(Buf* O, Buf* I) { +bool MemDependencyChecker::dependsIndirectly(BufPtr O, BufPtr I) { auto outputAccess = output(O); auto inputAccess = input(I); @@ -438,7 +440,7 @@ bool MemDependencyChecker::dependsIndirectly( return true; } -std::shared_ptr MemDependencyChecker::accessFor(Stmt* A) const { +std::shared_ptr MemDependencyChecker::accessFor(StmtPtr A) const { auto bound = stmtToAccess_.equal_range(A); for (auto it = bound.first; it != bound.second; ++it) { if (it->second->expr() == nullptr) { @@ -448,7 +450,7 @@ std::shared_ptr MemDependencyChecker::accessFor(Stmt* A) const { return nullptr; } -std::shared_ptr MemDependencyChecker::accessFor(Expr* A) const { +std::shared_ptr MemDependencyChecker::accessFor(ExprPtr A) const { // TODO exprs can have multiple accesses... we're returning the first but that // isn't great. Can't do much here. auto bound = exprToAccess_.equal_range(A); @@ -460,7 +462,7 @@ std::shared_ptr MemDependencyChecker::accessFor(Expr* A) const { } std::unordered_set> MemDependencyChecker:: - accessesWithin(Stmt* A) const { + accessesWithin(StmtPtr A) const { auto it = scopeToAccesses_.find(A); if (it != scopeToAccesses_.end()) { return std::unordered_set>( @@ -476,11 +478,11 @@ std::unordered_set> MemDependencyChecker:: } std::unordered_set> MemDependencyChecker:: - accessesWithin(Expr* A) const { + accessesWithin(ExprPtr A) const { return {accessFor(A)}; } -std::shared_ptr MemDependencyChecker::input(Buf* b) const { +std::shared_ptr MemDependencyChecker::input(BufPtr b) const { auto it = inputs_.find(b); if (it == inputs_.end()) { return nullptr; @@ -488,7 +490,7 @@ std::shared_ptr MemDependencyChecker::input(Buf* b) const { return it->second; } -std::shared_ptr MemDependencyChecker::output(Buf* b) const { +std::shared_ptr MemDependencyChecker::output(BufPtr b) const { auto it = outputs_.find(b); if (it == outputs_.end()) { return nullptr; @@ -498,18 +500,18 @@ std::shared_ptr MemDependencyChecker::output(Buf* b) const { // Node visitors: -void MemDependencyChecker::visit(Store* v) { - Stmt* last = lastStmt_; +void MemDependencyChecker::visit(StorePtr v) { + StmtPtr last = lastStmt_; lastStmt_ = v; v->value()->accept(this); - for (Expr* ind : v->indices()) { + for (ExprPtr ind : v->indices()) { ind->accept(this); } lastStmt_ = last; // Create a new AccessInfo for the store. 
- Var* var = v->buf()->base_handle(); + VarPtr var = v->buf()->base_handle(); auto info = std::make_shared( nextAccess_++, AccessType::Store, v, var, getIndicesBounds(v->indices())); @@ -530,19 +532,19 @@ void MemDependencyChecker::visit(Store* v) { currentScope_->accesses_.push_back(info); } -void MemDependencyChecker::visit(Load* v) { +void MemDependencyChecker::visit(LoadPtr v) { // Create a temporary scope to hold any loads that occur within the indices of // this load. auto indicesScope = std::make_shared(currentScope_->block, currentScope_); currentScope_ = indicesScope; - for (Expr* ind : v->indices()) { + for (ExprPtr ind : v->indices()) { ind->accept(this); } // Create a new AccessInfo for the load. - Var* var = v->buf()->base_handle(); + VarPtr var = v->buf()->base_handle(); auto load = std::make_shared( nextAccess_++, AccessType::Load, @@ -582,32 +584,32 @@ void MemDependencyChecker::visit(Load* v) { bool executionSafetyCheck( const std::shared_ptr& info, const std::shared_ptr& other, - const std::vector& aStrides, - const std::vector& oStrides, + const std::vector& aStrides, + const std::vector& oStrides, bool parallelized) { if (aStrides.empty() || oStrides.empty()) { return false; } TORCH_INTERNAL_ASSERT(info->bounds().size() == other->bounds().size()); for (size_t b = 0; b < info->bounds().size(); ++b) { - Expr* aIndexStride = aStrides[b]; - Expr* oIndexStride = oStrides[b]; + ExprPtr aIndexStride = aStrides[b]; + ExprPtr oIndexStride = oStrides[b]; // can't be safe on this index if we can't determine stride. if (!aIndexStride->isConstant() || !oIndexStride->isConstant()) { continue; } - Expr* minStride = - IRSimplifier::simplify(new Min(aIndexStride, oIndexStride, true)); - Expr* maxStride = - IRSimplifier::simplify(new Max(aIndexStride, oIndexStride, true)); + ExprPtr minStride = + IRSimplifier::simplify(alloc(aIndexStride, oIndexStride, true)); + ExprPtr maxStride = + IRSimplifier::simplify(alloc(aIndexStride, oIndexStride, true)); // If the first access has no stride don't apply safety). if (immediateEquals(minStride, 0)) { continue; } - Expr* modCheck = IRSimplifier::simplify(new Mod(maxStride, minStride)); + ExprPtr modCheck = IRSimplifier::simplify(alloc(maxStride, minStride)); // if the strides can't have easily inferable distinct offsets, they're not // safe. @@ -621,33 +623,34 @@ bool executionSafetyCheck( // axis is the same sign as the common stride, then they will not // overlap. - Expr* startDiff = IRSimplifier::simplify( - new Sub(info->bounds()[b].start, other->bounds()[b].start)); + ExprPtr startDiff = IRSimplifier::simplify( + alloc(info->bounds()[b].start, other->bounds()[b].start)); bool diffNegative = immediateIsNegative(startDiff); bool strideNegative = immediateIsNegative(minStride); // Invert the startDiff so mod works. if (diffNegative != strideNegative) { - startDiff = IRSimplifier::simplify(new Sub(new IntImm(0), startDiff)); + startDiff = + IRSimplifier::simplify(alloc(alloc(0), startDiff)); } // If both accesses have the same stride, and the difference in start // element is smaller than this stride then the entire range is distinct. 
if (exprEquals(minStride, maxStride)) { - Expr* check1 = - IRSimplifier::simplify(new CompareSelect(startDiff, minStride, kLT)); + ExprPtr check1 = IRSimplifier::simplify( + alloc(startDiff, minStride, kLT)); if (check1->isConstant() && immediateEquals(check1, 1)) { return true; } } - startDiff = IRSimplifier::simplify(new Mod(startDiff, minStride)); + startDiff = IRSimplifier::simplify(alloc(startDiff, minStride)); CompareSelectOperation op = strideNegative ? kLT : kGT; - Expr* check = - IRSimplifier::simplify(new CompareSelect(startDiff, new IntImm(0), op)); + ExprPtr check = IRSimplifier::simplify( + alloc(startDiff, alloc(0), op)); // If the start difference modulo the minimum stride is offset from that // stride, then the ranges have distinct strides. @@ -667,10 +670,10 @@ bool executionSafetyCheck( return false; } -void MemDependencyChecker::visit(For* v) { - Var* var = v->var(); +void MemDependencyChecker::visit(ForPtr v) { + VarPtr var = v->var(); - Stmt* last = lastStmt_; + StmtPtr last = lastStmt_; lastStmt_ = v; v->var()->accept(this); @@ -713,22 +716,22 @@ void MemDependencyChecker::visit(For* v) { // access, which we do via substituting the loop var with (var+1) into the // indices expr. - std::vector> loopStrides; + std::vector> loopStrides; loopStrides.resize(currentScope_->accesses_.size()); for (size_t a = 0; a < currentScope_->accesses_.size(); ++a) { auto& info = currentScope_->accesses_[a]; - std::vector indices = info->getIndices(); + std::vector indices = info->getIndices(); - std::vector& loopIndicesStride = loopStrides[a]; + std::vector& loopIndicesStride = loopStrides[a]; loopIndicesStride.resize(indices.size()); // index expr must depend on the loop var in some way to have a stride. for (const auto i : c10::irange(indices.size())) { VarFinder vf; if (vf.find(indices[i]).count(var) == 0) { - loopIndicesStride[i] = new IntImm(0); + loopIndicesStride[i] = alloc(0); } else { // If we've previously swapped the start and end of this bound, we // should apply the substitution to the reverse of the bounds. @@ -737,25 +740,25 @@ void MemDependencyChecker::visit(For* v) { SubstituteInClone(info->bounds()[i].end, {{var, v->start()}})); info->bounds()[i].start = IRSimplifier::simplify(SubstituteInClone( info->bounds()[i].start, - {{var, new Sub(v->stop(), new IntImm(1))}})); + {{var, alloc(v->stop(), alloc(1))}})); } else { info->bounds()[i].start = IRSimplifier::simplify( SubstituteInClone(info->bounds()[i].start, {{var, v->start()}})); info->bounds()[i].end = IRSimplifier::simplify(SubstituteInClone( info->bounds()[i].end, - {{var, new Sub(v->stop(), new IntImm(1))}})); + {{var, alloc(v->stop(), alloc(1))}})); } - Expr* zeroStep = indices[i]; - Expr* oneStep = - SubstituteInClone(indices[i], {{var, new Add(var, new IntImm(1))}}); + ExprPtr zeroStep = indices[i]; + ExprPtr oneStep = SubstituteInClone( + indices[i], {{var, alloc(var, alloc(1))}}); loopIndicesStride[i] = - IRSimplifier::simplify(new Sub(oneStep, zeroStep)); + IRSimplifier::simplify(alloc(oneStep, zeroStep)); // If the start < end then swap the order of the bound. 
- Expr* diff = IRSimplifier::simplify( - new Sub(info->bounds()[i].end, info->bounds()[i].start)); + ExprPtr diff = IRSimplifier::simplify( + alloc(info->bounds()[i].end, info->bounds()[i].start)); if (diff->isConstant() && immediateIsNegative(diff)) { info->bounds()[i].swap(); } @@ -782,10 +785,11 @@ void MemDependencyChecker::visit(For* v) { bound.start = IRSimplifier::simplify( SubstituteInClone(bound.start, {{var, v->start()}})); bound.end = IRSimplifier::simplify(SubstituteInClone( - bound.end, {{var, new Sub(v->stop(), new IntImm(1))}})); + bound.end, {{var, alloc(v->stop(), alloc(1))}})); // If the start < end then swap the order of the bound. - Expr* diff = IRSimplifier::simplify(new Sub(bound.end, bound.start)); + ExprPtr diff = + IRSimplifier::simplify(alloc(bound.end, bound.start)); if (diff->isConstant() && immediateIsNegative(diff)) { bound.swap(); } @@ -798,7 +802,7 @@ void MemDependencyChecker::visit(For* v) { v->loop_options().is_gpu_thread_index(); // Store buffers allocated at this scope. - std::unordered_set local_intermediates; + std::unordered_set local_intermediates; // Scanning from the top of the loop, we look for accesses which may depend // on a previous or parallel loop iteration. @@ -901,8 +905,8 @@ void MemDependencyChecker::visit(For* v) { currentScope_ = currentScope_->parent; } -void MemDependencyChecker::visit(Cond* v) { - Stmt* last = lastStmt_; +void MemDependencyChecker::visit(CondPtr v) { + StmtPtr last = lastStmt_; lastStmt_ = v; auto enclosingScope = @@ -911,8 +915,8 @@ void MemDependencyChecker::visit(Cond* v) { // condition is in enclosing scope. v->condition()->accept(this); - Block* true_stmt = v->true_stmt(); - Block* false_stmt = v->false_stmt(); + BlockPtr true_stmt = v->true_stmt(); + BlockPtr false_stmt = v->false_stmt(); // Create scopes so the Block visitor doesn't create and merge a new scope. auto trueScope = std::make_shared(true_stmt, enclosingScope); @@ -950,12 +954,12 @@ void MemDependencyChecker::visit(Cond* v) { lastStmt_ = last; } -void MemDependencyChecker::visit(IfThenElse* v) { +void MemDependencyChecker::visit(IfThenElsePtr v) { // condition is in enclosing scope. v->condition()->accept(this); - Expr* true_value = v->true_value(); - Expr* false_value = v->false_value(); + ExprPtr true_value = v->true_value(); + ExprPtr false_value = v->false_value(); auto enclosingScope = currentScope_; @@ -986,13 +990,13 @@ void MemDependencyChecker::visit(IfThenElse* v) { currentScope_ = enclosingScope; } -void MemDependencyChecker::visit(CompareSelect* v) { +void MemDependencyChecker::visit(CompareSelectPtr v) { // condition is in enclosing scope. v->lhs()->accept(this); v->rhs()->accept(this); - Expr* true_value = v->ret_val1(); - Expr* false_value = v->ret_val2(); + ExprPtr true_value = v->ret_val1(); + ExprPtr false_value = v->ret_val2(); auto enclosingScope = currentScope_; @@ -1025,15 +1029,16 @@ void MemDependencyChecker::visit(CompareSelect* v) { // Inserts accesses for a map of buffers (ie. for inputs and outputs). 
void MemDependencyChecker::insertBuffers( - std::unordered_map>& bufs, + std::unordered_map>& bufs, AccessType type) { for (auto& pair : bufs) { - Buf* b = pair.first; - Var* var = b->base_handle(); + BufPtr b = pair.first; + VarPtr var = b->base_handle(); IndexBounds bounds; - for (auto* d : b->dims()) { + for (auto d : b->dims()) { bounds.push_back( - {new IntImm(0), IRSimplifier::simplify(new Sub(d, new IntImm(1)))}); + {alloc(0), + IRSimplifier::simplify(alloc(d, alloc(1)))}); } auto info = std::make_shared(nextAccess_++, type, nullptr, var, bounds); @@ -1046,7 +1051,7 @@ void MemDependencyChecker::insertBuffers( } } -void MemDependencyChecker::visit(Block* v) { +void MemDependencyChecker::visit(BlockPtr v) { auto prev_scope = currentScope_; // handle kernel inputs. @@ -1055,14 +1060,14 @@ void MemDependencyChecker::visit(Block* v) { } if (currentScope_->block != v) { - currentScope_ = std::make_shared((Block*)v, prev_scope); + currentScope_ = std::make_shared((BlockPtr)v, prev_scope); } - for (auto* s : *v) { + for (auto s : *v) { s->accept(this); } - for (auto* v : currentScope_->localVars) { + for (auto v : currentScope_->localVars) { knownVarBounds_.erase(v); } for (auto& pair : currentScope_->shadowedVarBounds) { @@ -1082,15 +1087,15 @@ void MemDependencyChecker::visit(Block* v) { } } -void MemDependencyChecker::visit(Let* v) { - Stmt* last = lastStmt_; +void MemDependencyChecker::visit(LetPtr v) { + StmtPtr last = lastStmt_; lastStmt_ = v; IRVisitor::visit(v); lastStmt_ = last; - Var* var = v->var(); + VarPtr var = v->var(); if (knownVarBounds_.count(var) != 0) { currentScope_->shadowedVarBounds[var] = knownVarBounds_[var]; } @@ -1101,17 +1106,17 @@ void MemDependencyChecker::visit(Let* v) { // Don't support AtomicAdd yet, it's a bit more complex since it's both a read // and a write. It's only inserted during Cuda codegen so this should be okay. -void MemDependencyChecker::visit(AtomicAdd* v) { +void MemDependencyChecker::visit(AtomicAddPtr v) { throw std::runtime_error("MemDependencyChecker AtomicAdd unimplemented"); } -void MemDependencyChecker::visit(Allocate* v) { - Stmt* last = lastStmt_; +void MemDependencyChecker::visit(AllocatePtr v) { + StmtPtr last = lastStmt_; lastStmt_ = v; IRVisitor::visit(v); - Var* var = v->buffer_var(); + VarPtr var = v->buffer_var(); IndexBounds bounds; // TODO: remove the "buf_flat_size" process below and extend the buf bound // check to support N-d indices access and 1-d index access. @@ -1120,9 +1125,9 @@ void MemDependencyChecker::visit(Allocate* v) { // identify 1-d index access for N-d bufs. Thus we flatten N-d bufs here to // avoid failing the bound check. But this is not the correct approach and // should be fixed. 
- Expr* flat_size = buf_flat_size(v->buf()); - flat_size = IRSimplifier::simplify(new Sub(flat_size, new IntImm(1))); - bounds.push_back({new IntImm(0), flat_size}); + ExprPtr flat_size = buf_flat_size(v->buf()); + flat_size = IRSimplifier::simplify(alloc(flat_size, alloc(1))); + bounds.push_back({alloc(0), flat_size}); auto info = std::make_shared( nextAccess_++, AccessType::Alloc, nullptr, var, bounds); @@ -1136,13 +1141,13 @@ void MemDependencyChecker::visit(Allocate* v) { lastStmt_ = last; } -void MemDependencyChecker::visit(Free* v) { - Stmt* last = lastStmt_; +void MemDependencyChecker::visit(FreePtr v) { + StmtPtr last = lastStmt_; lastStmt_ = v; IRVisitor::visit(v); - Var* var = v->buffer_var(); + VarPtr var = v->buffer_var(); auto it = intermediates_.find(var); TORCH_INTERNAL_ASSERT(it != intermediates_.end()); @@ -1243,7 +1248,7 @@ void MemDependencyChecker::mergeScope( // Copy open writes up. for (auto& pair : child->openWrites_) { - Var* var = pair.first; + VarPtr var = pair.first; // Intentionally using operator[], we want it to be created if it does not // exist. @@ -1266,7 +1271,7 @@ class VarBoundBinder : public IRVisitor { public: VarBoundBinder(const VarBoundMap& vars) : vars_(vars) {} - Bound getBounds(Expr* e) { + Bound getBounds(ExprPtr e) { min_ = e; max_ = e; e->accept(this); @@ -1276,7 +1281,7 @@ class VarBoundBinder : public IRVisitor { } private: - void visit(Var* v) override { + void visit(VarPtr v) override { auto it = vars_.find(v); if (it == vars_.end()) { return; @@ -1286,17 +1291,17 @@ class VarBoundBinder : public IRVisitor { max_ = SubstituteInClone(max_, {{v, it->second.end}}); } - Expr* min_{nullptr}; - Expr* max_{nullptr}; + ExprPtr min_{nullptr}; + ExprPtr max_{nullptr}; const VarBoundMap& vars_; }; std::vector MemDependencyChecker::getIndicesBounds( - const std::vector& indices) { + const std::vector& indices) { std::vector bounds; bounds.reserve(indices.size()); VarBoundBinder binder(knownVarBounds_); - for (auto* s : indices) { + for (auto s : indices) { bounds.push_back(binder.getBounds(s)); } return bounds; diff --git a/torch/csrc/jit/tensorexpr/mem_dependency_checker.h b/torch/csrc/jit/tensorexpr/mem_dependency_checker.h index 43a7d5d..5363d2f 100644 --- a/torch/csrc/jit/tensorexpr/mem_dependency_checker.h +++ b/torch/csrc/jit/tensorexpr/mem_dependency_checker.h @@ -40,8 +40,8 @@ class TORCH_API AccessInfo { AccessInfo( size_t id, AccessType type, - Stmt* stmt, - Var* var, + StmtPtr stmt, + VarPtr var, IndexBounds bounds) : id_(id), type_(type), @@ -53,9 +53,9 @@ class TORCH_API AccessInfo { AccessInfo( size_t id, AccessType type, - Expr* expr, - Stmt* stmt, - Var* var, + ExprPtr expr, + StmtPtr stmt, + VarPtr var, IndexBounds bounds) : id_(id), type_(type), @@ -77,18 +77,18 @@ class TORCH_API AccessInfo { // The enclosing Stmt this access represents. E.g. if this is a Store then // Stmt is the Store itself, while if the access is caused by an Expr, this is // the most immediate parent Stmt. - Stmt* stmt() const { + StmtPtr stmt() const { return stmt_; } // If the access is represented by an Expr (such as Load or Call) then this is // it, otherwise it's nullptr. - Expr* expr() const { + ExprPtr expr() const { return expr_; } // The Var representing the underlying Buffer. - Var* var() const { + VarPtr var() const { return var_; } @@ -114,7 +114,7 @@ class TORCH_API AccessInfo { } // Returns the symbolic expression of the indices of this access. 
- std::vector getIndices() const; + std::vector getIndices() const; // Establishes a dependency or dependent relationship with another access. void addDependency(const std::shared_ptr& write); @@ -149,9 +149,9 @@ class TORCH_API AccessInfo { private: size_t id_; AccessType type_; - Stmt* stmt_; - Expr* expr_; - Var* var_; + StmtPtr stmt_; + ExprPtr expr_; + VarPtr var_; IndexBounds bounds_; // Yes these should be sorted. @@ -159,7 +159,7 @@ class TORCH_API AccessInfo { std::map> dependents_; }; -using VarBoundMap = std::unordered_map; +using VarBoundMap = std::unordered_map; /* MemDepedencyChecker analyses a IR fragment and builds a dependency graph of * accesses contained within. @@ -176,8 +176,8 @@ class TORCH_API MemDependencyChecker : public IRVisitor { public: MemDependencyChecker(); MemDependencyChecker( - const std::unordered_set& inputs, - const std::unordered_set& outputs); + const std::unordered_set& inputs, + const std::unordered_set& outputs); MemDependencyChecker( const std::vector& inputs, const std::vector& outputs); @@ -193,15 +193,15 @@ class TORCH_API MemDependencyChecker : public IRVisitor { // about it. // Returns true if any read in A has a direct dependence on a write in B. - bool dependsDirectly(Stmt* A, Stmt* B); - bool dependsDirectly(Expr* A, Stmt* B); + bool dependsDirectly(StmtPtr A, StmtPtr B); + bool dependsDirectly(ExprPtr A, StmtPtr B); // Returns true of the output depends directly on a write contained in B. - bool dependsDirectly(Buf* output, Stmt* B); + bool dependsDirectly(BufPtr output, StmtPtr B); // Returns true if a read in A depends directly on the provided input. - bool dependsDirectly(Stmt* A, Buf* input); - bool dependsDirectly(Expr* A, Buf* input); + bool dependsDirectly(StmtPtr A, BufPtr input); + bool dependsDirectly(ExprPtr A, BufPtr input); // Outputs/inputs cannot depend directly. @@ -211,18 +211,18 @@ class TORCH_API MemDependencyChecker : public IRVisitor { const std::shared_ptr& B); // Returns true if any read in A has an ancestor write contained in B. - bool dependsIndirectly(Stmt* A, Stmt* B); - bool dependsIndirectly(Expr* A, Stmt* B); + bool dependsIndirectly(StmtPtr A, StmtPtr B); + bool dependsIndirectly(ExprPtr A, StmtPtr B); // Returns true of the output depends indirectly on a write contained in B. - bool dependsIndirectly(Buf* output, Stmt* B); + bool dependsIndirectly(BufPtr output, StmtPtr B); // Returns true if a read in A depends indirectly on the provided input. - bool dependsIndirectly(Stmt* A, Buf* input); - bool dependsIndirectly(Expr* A, Buf* input); + bool dependsIndirectly(StmtPtr A, BufPtr input); + bool dependsIndirectly(ExprPtr A, BufPtr input); // returns true if the output uses any load of the input. - bool dependsIndirectly(Buf* output, Buf* input); + bool dependsIndirectly(BufPtr output, BufPtr input); // Returns true if the access A has a dependency chain to access B. bool dependsIndirectly( @@ -230,19 +230,21 @@ class TORCH_API MemDependencyChecker : public IRVisitor { const std::shared_ptr& B); // Returns the AccessInfo - std::shared_ptr accessFor(Stmt* A) const; - std::shared_ptr accessFor(Expr* A) const; + std::shared_ptr accessFor(StmtPtr A) const; + std::shared_ptr accessFor(ExprPtr A) const; // Returns all AccessInfos. - std::unordered_set> accessesWithin(Stmt* A) const; + std::unordered_set> accessesWithin( + StmtPtr A) const; // TODO: this will return only the AccessInfo for A. It's included for // completeness but be aware it wont return accesses used in the computation // of A. 
- std::unordered_set> accessesWithin(Expr* A) const; + std::unordered_set> accessesWithin( + ExprPtr A) const; // Accesses relating to input and output buffers. - std::shared_ptr input(Buf* B) const; - std::shared_ptr output(Buf* B) const; + std::shared_ptr input(BufPtr B) const; + std::shared_ptr output(BufPtr B) const; // Returns the full history of reads and writes. const std::vector>& getHistory() const; @@ -252,49 +254,49 @@ class TORCH_API MemDependencyChecker : public IRVisitor { private: // Node visitors. - void visit(Store* v) override; - void visit(Load* v) override; - void visit(For* v) override; - void visit(Cond* v) override; - void visit(IfThenElse* v) override; - void visit(CompareSelect* v) override; - void visit(Block* v) override; - void visit(Let* v) override; - void visit(AtomicAdd* v) override; - void visit(Allocate* v) override; - void visit(Free* v) override; + void visit(StorePtr v) override; + void visit(LoadPtr v) override; + void visit(ForPtr v) override; + void visit(CondPtr v) override; + void visit(IfThenElsePtr v) override; + void visit(CompareSelectPtr v) override; + void visit(BlockPtr v) override; + void visit(LetPtr v) override; + void visit(AtomicAddPtr v) override; + void visit(AllocatePtr v) override; + void visit(FreePtr v) override; using BoundRelationship = std::pair>; // An internal struct holding the accesses found within a scope Block. struct Scope { - Scope(Block* b, std::shared_ptr p) + Scope(BlockPtr b, std::shared_ptr p) : block(b), parent(std::move(p)) {} - Block* block; + BlockPtr block; std::shared_ptr parent; - std::unordered_map shadowedVarBounds; - std::unordered_set localVars; + std::unordered_map shadowedVarBounds; + std::unordered_set localVars; std::vector> accesses_; - std::unordered_map> openWrites_; + std::unordered_map> openWrites_; }; std::shared_ptr currentScope_; bool allowExecutionOrderAnalysis_{false}; - std::unordered_multimap> stmtToAccess_; - std::unordered_multimap> exprToAccess_; - std::unordered_map>> + std::unordered_multimap> stmtToAccess_; + std::unordered_multimap> exprToAccess_; + std::unordered_map>> scopeToAccesses_; VarBoundMap knownVarBounds_; // Finds all accesses that are reads within the scope of v. - template - DependencySet getAllReadsWithin(StmtOrExpr* v) { + template + DependencySet getAllReadsWithin(StmtOrExprPtr v) { DependencySet reads; auto insertAllReads = [&](const auto& nodes) { for (auto* l : nodes) { @@ -317,7 +319,7 @@ class TORCH_API MemDependencyChecker : public IRVisitor { // Finds all accesses that are writes within the scope of v. // Writes cannot occur in Exprs, so this is a little simpler. - DependencySet getAllWritesWithin(Stmt* v) { + DependencySet getAllWritesWithin(StmtPtr v) { DependencySet writes; // writes just Store currently. @@ -334,8 +336,8 @@ class TORCH_API MemDependencyChecker : public IRVisitor { } // Templated helpers to work on either Exprs or Stmts. 
- template - bool dependsDirectlyHelper(StmtOrExpr* A, Stmt* B) { + template + bool dependsDirectlyHelper(StmtOrExprPtr A, StmtPtr B) { auto aReads = getAllReadsWithin(A); auto bWrites = getAllWritesWithin(B); @@ -350,8 +352,8 @@ class TORCH_API MemDependencyChecker : public IRVisitor { return false; } - template - bool dependsIndirectlyHelper(StmtOrExpr* A, Stmt* B) { + template + bool dependsIndirectlyHelper(StmtOrExprPtr A, StmtPtr B) { auto aReads = getAllReadsWithin(A); auto bWrites = getAllWritesWithin(B); @@ -369,13 +371,13 @@ class TORCH_API MemDependencyChecker : public IRVisitor { DependencySet getAllWriteDependencies(const DependencySet& products); // Maps for inputs and outputs, since they aren't present directly in the IR. - std::unordered_map> inputs_; - std::unordered_map> outputs_; - std::unordered_map> intermediates_; + std::unordered_map> inputs_; + std::unordered_map> outputs_; + std::unordered_map> intermediates_; // Inserts accesses for Buf's: specifically for inputs and outputs. void insertBuffers( - std::unordered_map>& bufs, + std::unordered_map>& bufs, AccessType type); // Update the write history with a new write, adding dependencies and closing @@ -395,10 +397,10 @@ class TORCH_API MemDependencyChecker : public IRVisitor { bool closeOverlapped = true); // Binds symbolic vars in indices with the low and high bound for those vars. - std::vector getIndicesBounds(const std::vector& indices); + std::vector getIndicesBounds(const std::vector& indices); size_t nextAccess_{0}; - Stmt* lastStmt_{nullptr}; + StmtPtr lastStmt_{nullptr}; }; } // namespace analysis diff --git a/torch/csrc/jit/tensorexpr/operators/conv2d.cpp b/torch/csrc/jit/tensorexpr/operators/conv2d.cpp index f83b34b..c4af83a 100644 --- a/torch/csrc/jit/tensorexpr/operators/conv2d.cpp +++ b/torch/csrc/jit/tensorexpr/operators/conv2d.cpp @@ -75,14 +75,14 @@ Tensor* conv2d_depthwise_static( constexpr int kLoopH = 2, kLoopW = 3; if (R == 3 && stride == 2 && pad == 1) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For *head, *tail; + ForPtr head, tail; auto loops = nest.getLoopStmtsFor(conv); nest.sliceHead(loops[kLoopW], 2, &head, &tail); loops = nest.getLoopStmtsFor(conv); nest.sliceHead(loops[kLoopH], 2, &head, &tail); } else if (R == 3 && stride == 1 && pad == 1) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For *main, *peeled; + ForPtr main, peeled; auto loops = nest.getAllLoopNestsWritingToBuf(conv->buf()); main = loops[1][kLoopW]; nest.sliceHead(main, 1, &peeled, &main); diff --git a/torch/csrc/jit/tensorexpr/operators/matmul.cpp b/torch/csrc/jit/tensorexpr/operators/matmul.cpp index 3fdd086..23cb455 100644 --- a/torch/csrc/jit/tensorexpr/operators/matmul.cpp +++ b/torch/csrc/jit/tensorexpr/operators/matmul.cpp @@ -21,11 +21,11 @@ Tensor* computeMatmul( auto size_b = b.dims(); // We currently only support rank 2 matmuls TORCH_INTERNAL_ASSERT(size_a.size() == 2 && size_b.size() == 2); - auto total_size = dynamic_cast( - IRSimplifier::simplify( - cast(size_a[0]) * cast(size_a[1]) * - cast(size_b[1])) - .node()); + auto total_size = + to(IRSimplifier::simplify( + cast(size_a[0]) * cast(size_a[1]) * + cast(size_b[1])) + .node()); // For small sizes, where N*M*K < 1000, lower matmul to a naive 3-level // loopnest. 
The number is not tuned very carefully, and in future we should diff --git a/torch/csrc/jit/tensorexpr/operators/softmax.cpp b/torch/csrc/jit/tensorexpr/operators/softmax.cpp index edb911e..d6cb6c0 100644 --- a/torch/csrc/jit/tensorexpr/operators/softmax.cpp +++ b/torch/csrc/jit/tensorexpr/operators/softmax.cpp @@ -129,8 +129,8 @@ Tensor* computeSoftmax( }); return new Tensor( result->buf(), - new tensorexpr::Block( - {max->stmt(), e->stmt(), sum->stmt(), result->stmt()})); + alloc(std::vector( + {max->stmt(), e->stmt(), sum->stmt(), result->stmt()}))); } auto log_sum = Compute( @@ -147,12 +147,12 @@ Tensor* computeSoftmax( }); return new Tensor( result->buf(), - new tensorexpr::Block( + alloc(std::vector( {max->stmt(), e->stmt(), sum->stmt(), log_sum->stmt(), - result->stmt()})); + result->stmt()}))); } } // namespace tensorexpr diff --git a/torch/csrc/jit/tensorexpr/reduction.cpp b/torch/csrc/jit/tensorexpr/reduction.cpp index 2245860..1727482 100644 --- a/torch/csrc/jit/tensorexpr/reduction.cpp +++ b/torch/csrc/jit/tensorexpr/reduction.cpp @@ -6,21 +6,21 @@ namespace torch { namespace jit { namespace tensorexpr { -ReduceOp* Reducer::operator()( - Buf* result_buf, +ReduceOpPtr Reducer::operator()( + BufPtr result_buf, ExprHandle body, - const std::vector& output, - const std::vector& inner) const { - return new ReduceOp( + const std::vector& output, + const std::vector& inner) const { + return alloc( complete(result_buf, interaction_, body, output, inner), inner, *this); } -ReduceOp* Reducer::operator()( - Buf* result_buf, - Expr* body, - const std::vector& output, - const std::vector& inner) const { - return new ReduceOp( +ReduceOpPtr Reducer::operator()( + BufPtr result_buf, + ExprPtr body, + const std::vector& output, + const std::vector& inner) const { + return alloc( complete(result_buf, interaction_, ExprHandle(body), output, inner), inner, *this); diff --git a/torch/csrc/jit/tensorexpr/reduction.h b/torch/csrc/jit/tensorexpr/reduction.h index e482191..08aef01 100644 --- a/torch/csrc/jit/tensorexpr/reduction.h +++ b/torch/csrc/jit/tensorexpr/reduction.h @@ -35,21 +35,21 @@ class TORCH_API Reducer { } virtual ~Reducer() = default; - Expr* initializer() const { + ExprPtr initializer() const { return init_; } - ReduceOp* operator()( - Buf* result_buf, + ReduceOpPtr operator()( + BufPtr result_buf, ExprHandle body, - const std::vector& output, - const std::vector& inner) const; + const std::vector& output, + const std::vector& inner) const; - ReduceOp* operator()( - Buf* result_buf, - Expr* body, - const std::vector& output, - const std::vector& inner) const; + ReduceOpPtr operator()( + BufPtr result_buf, + ExprPtr body, + const std::vector& output, + const std::vector& inner) const; // Polymorphic handling of Body functions with a variety of parameters. static ExprHandle getReduceBody( @@ -103,20 +103,20 @@ class TORCH_API Reducer { // Completes the reduction operator by applying the interaction function to // the accumulation and the body expression. 
- static Expr* complete( - Buf* accumulator, + static ExprPtr complete( + BufPtr accumulator, ReduceInteraction interaction, ExprHandle body, - const std::vector& output_args, - const std::vector& reduce_args) { + const std::vector& output_args, + const std::vector& reduce_args) { ExprHandle accum = - ExprHandle(new Load(body.dtype(), accumulator, output_args)); + ExprHandle(alloc(body.dtype(), accumulator, output_args)); auto e = interaction(accum, body); return e.node(); } private: - Expr* init_; + ExprPtr init_; ReduceInteraction interaction_; }; @@ -128,14 +128,17 @@ class TORCH_API Reducer { class TORCH_API ReduceOp : public ExprNode { public: // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) - ReduceOp(Expr* body, std::vector reduce_args, const Reducer& reducer) + ReduceOp( + ExprPtr body, + std::vector reduce_args, + const Reducer& reducer) : ExprNodeBase(body->dtype()), body_(body), reduce_args_(std::move(reduce_args)), reducer_(reducer) {} // return the body expression which obtains the value to be reduced. - Expr* body() const { + ExprPtr body() const { return body_; } @@ -145,13 +148,13 @@ class TORCH_API ReduceOp : public ExprNode { } // returns variables associated with the axes of reduction. - const std::vector& reduce_args() const { + const std::vector& reduce_args() const { return reduce_args_; } private: - Expr* body_; - std::vector reduce_args_; + ExprPtr body_; + std::vector reduce_args_; const Reducer reducer_; }; @@ -216,11 +219,11 @@ class Minimum : public Reducer { class ReductionExpander : public IRMutator { public: - Stmt* expand(Stmt* s) { + StmtPtr expand(StmtPtr s) { return s->accept_mutator(this); } - Expr* mutate(ReduceOp* v) override { + ExprPtr mutate(ReduceOpPtr v) override { return v->body(); } }; diff --git a/torch/csrc/jit/tensorexpr/registerizer.cpp b/torch/csrc/jit/tensorexpr/registerizer.cpp index 1036311..07aee20 100644 --- a/torch/csrc/jit/tensorexpr/registerizer.cpp +++ b/torch/csrc/jit/tensorexpr/registerizer.cpp @@ -7,7 +7,7 @@ namespace registerizer { // AccessInfo -void AccessInfo::addStore(Store* store, const std::shared_ptr& scope) { +void AccessInfo::addStore(StorePtr store, const std::shared_ptr& scope) { block_ = block_ ? Block::getSharedParent(block_, scope->block()) : scope->block(); @@ -17,7 +17,8 @@ void AccessInfo::addStore(Store* store, const std::shared_ptr& scope) { first_usage_ = first_usage_ ? block_->getEnclosedRoot(first_usage_) : store; last_usage_ = store; - store_cost_ = IRSimplifier::simplify(new Add(store_cost_, new IntImm(1))); + store_cost_ = + IRSimplifier::simplify(alloc(store_cost_, alloc(1))); stores_.push_back(store); conditionId_ = scope->conditionId(); @@ -25,15 +26,15 @@ void AccessInfo::addStore(Store* store, const std::shared_ptr& scope) { } void AccessInfo::addLoad( - Load* load, + LoadPtr load, const std::shared_ptr& scope, - Stmt* usage) { + StmtPtr usage) { block_ = block_ ? Block::getSharedParent(block_, scope->block()) : scope->block(); first_usage_ = first_usage_ ? 
block_->getEnclosedRoot(first_usage_) : usage; last_usage_ = usage; - load_cost_ = IRSimplifier::simplify(new Add(load_cost_, new IntImm(1))); + load_cost_ = IRSimplifier::simplify(alloc(load_cost_, alloc(1))); loads_.push_back(load); conditionId_ = scope->conditionId(); @@ -45,16 +46,17 @@ void AccessInfo::merge(const std::shared_ptr& other) { TORCH_INTERNAL_ASSERT(indices_.size() == other->indices().size()); last_usage_ = other->last_usage(); - for (auto* s : other->stores()) { + for (auto s : other->stores()) { stores_.push_back(s); } - for (auto* l : other->loads()) { + for (auto l : other->loads()) { loads_.push_back(l); } store_cost_ = - IRSimplifier::simplify(new Add(store_cost_, other->store_cost())); - load_cost_ = IRSimplifier::simplify(new Add(load_cost_, other->load_cost())); + IRSimplifier::simplify(alloc(store_cost_, other->store_cost())); + load_cost_ = + IRSimplifier::simplify(alloc(load_cost_, other->load_cost())); block_ = Block::getSharedParent(block_, other->block()); // update first and last usage to be in the parent Block. @@ -73,7 +75,7 @@ bool AccessInfo::overlaps(const std::shared_ptr& other) { // dimension. bool overlap = true; for (size_t i = 0; i < indices_.size(); ++i) { - Expr* diff = new Sub(indices_[i], other_indices[i]); + ExprPtr diff = alloc(indices_[i], other_indices[i]); diff = IRSimplifier::simplify(diff); if (diff->isConstant() && !immediateEquals(diff, 0)) { @@ -85,9 +87,9 @@ bool AccessInfo::overlaps(const std::shared_ptr& other) { return overlap; } -bool AccessInfo::dependsOnVar(Var* v) { +bool AccessInfo::dependsOnVar(VarPtr v) { VarFinder vf; - for (auto* i : indices_) { + for (auto i : indices_) { i->accept(&vf); } @@ -105,10 +107,10 @@ std::shared_ptr AccessInfo::cloneWithHiddenInfo( newInfo->firstUsageOverlapped_ = orig->firstUsageOverlapped_; newInfo->store_cost_ = orig->store_cost_; newInfo->load_cost_ = orig->load_cost_; - for (auto* s : orig->stores_) { + for (auto s : orig->stores_) { newInfo->stores_.push_back(s); } - for (auto* s : orig->loads_) { + for (auto s : orig->loads_) { newInfo->loads_.push_back(s); } @@ -119,7 +121,7 @@ std::shared_ptr AccessInfo::cloneWithHiddenInfo( void AccessInfo::print() const { std::cout << "Access: " << *buf_ << "{"; - for (auto* i : indices_) { + for (auto i : indices_) { std::cout << *i << " "; } std::cout << "} stores: " << stores_.size() << " (" << *store_cost_ << ") -"; @@ -137,7 +139,7 @@ void Scope::closeAccess(const std::shared_ptr& info) { closedAccesses_.push_back(info); } -AccessHashMap& Scope::getAccessMapByBuf(Buf* b) { +AccessHashMap& Scope::getAccessMapByBuf(BufPtr b) { auto it = openAccesses_.find(b); if (it == openAccesses_.end()) { // create and return @@ -177,7 +179,7 @@ void RegisterizerAnalysis::closeAccessIntoScope( scope->closeAccess(info); } -void RegisterizerAnalysis::visit(For* v) { +void RegisterizerAnalysis::visit(ForPtr v) { if (v->loop_options().is_gpu_block_index() || v->loop_options().is_gpu_thread_index()) { throw malformed_input( @@ -193,7 +195,8 @@ void RegisterizerAnalysis::visit(For* v) { v->body()->accept(this); stmtStack_.pop_front(); - Expr* loopExtent = IRSimplifier::simplify(new Sub(v->stop(), v->start())); + ExprPtr loopExtent = + IRSimplifier::simplify(alloc(v->stop(), v->start())); // now we need to see which accesses we can hoist out of the for loop, their // costs should be multiplied by the loop extent. 
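[The cost bookkeeping in AccessInfo::addStore/addLoad and the loop-extent computation in RegisterizerAnalysis::visit(ForPtr) in the hunks above now build their Add/Sub/IntImm nodes through the allocation helper instead of bare `new`. The snippet below is a minimal illustrative sketch of that construction idiom, not code from the patch; it assumes the helper takes the IR node type as a template argument (matching the node types named by the pre-patch `new` expressions) and the ExprPtr alias used throughout these hunks.]

#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>

using namespace torch::jit::tensorexpr;

// Bump a store/load cost by one, as AccessInfo::addStore/addLoad do above.
// Pre-patch spelling: IRSimplifier::simplify(new Add(cost, new IntImm(1)))
ExprPtr bumpCost(ExprPtr cost) {
  return IRSimplifier::simplify(alloc<Add>(cost, alloc<IntImm>(1)));
}

// Loop extent as computed in RegisterizerAnalysis::visit(ForPtr) above.
// Pre-patch spelling: IRSimplifier::simplify(new Sub(stop, start))
ExprPtr loopExtent(ExprPtr stop, ExprPtr start) {
  return IRSimplifier::simplify(alloc<Sub>(stop, start));
}
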
@@ -224,7 +227,7 @@ void RegisterizerAnalysis::visit(For* v) { bool closed = false; // If this access depends on a locally scoped variable, it cannot be // hosted out of the loop. - for (auto* v : currentScope_->localVars()) { + for (auto v : currentScope_->localVars()) { if (candidate->dependsOnVar(v)) { closeAccessIntoScope(candidate, currentScope_); closed = true; @@ -260,10 +263,10 @@ void RegisterizerAnalysis::visit(For* v) { mergeCurrentScopeIntoParent(); }; -void RegisterizerAnalysis::visit(Cond* v) { - Expr* condition = v->condition(); - Block* true_stmt = v->true_stmt(); - Block* false_stmt = v->false_stmt(); +void RegisterizerAnalysis::visit(CondPtr v) { + ExprPtr condition = v->condition(); + BlockPtr true_stmt = v->true_stmt(); + BlockPtr false_stmt = v->false_stmt(); stmtStack_.push_front(v); @@ -300,10 +303,10 @@ void RegisterizerAnalysis::visit(Cond* v) { // IfThenElses are just like Conds except they are not Stmts, which means no // registerization can occur internally. However, the first reference to an // access can occur within one if its visible outside the condition. -void RegisterizerAnalysis::visit(IfThenElse* v) { - Expr* condition = v->condition(); - Expr* true_value = v->true_value(); - Expr* false_value = v->false_value(); +void RegisterizerAnalysis::visit(IfThenElsePtr v) { + ExprPtr condition = v->condition(); + ExprPtr true_value = v->true_value(); + ExprPtr false_value = v->false_value(); // condition is in enclosing scope. condition->accept(this); @@ -335,7 +338,7 @@ void RegisterizerAnalysis::visit(IfThenElse* v) { } } -void RegisterizerAnalysis::visit(Let* v) { +void RegisterizerAnalysis::visit(LetPtr v) { currentScope_->addLocalVar(v->var()); stmtStack_.push_front(v); @@ -343,7 +346,7 @@ void RegisterizerAnalysis::visit(Let* v) { stmtStack_.pop_front(); } -void RegisterizerAnalysis::visit(Block* v) { +void RegisterizerAnalysis::visit(BlockPtr v) { auto prev_scope = currentScope_; if (currentScope_->block() != v) { currentScope_ = std::make_shared(v, prev_scope); @@ -351,7 +354,7 @@ void RegisterizerAnalysis::visit(Block* v) { stmtStack_.push_front(v); - for (auto* s : *v) { + for (auto s : *v) { s->accept(this); if (currentScope_->block() != v) { // merge the inner block's accesses into this Block's accesses. @@ -371,7 +374,7 @@ void RegisterizerAnalysis::visit(Block* v) { } } -void RegisterizerAnalysis::visit(Store* v) { +void RegisterizerAnalysis::visit(StorePtr v) { stmtStack_.push_front(v); v->value()->accept(this); stmtStack_.pop_front(); @@ -383,7 +386,7 @@ void RegisterizerAnalysis::visit(Store* v) { // hash the Store: SimplifierHashType accessHash = hasher_.hash(v->buf()); - for (auto* i : v->indices()) { + for (auto i : v->indices()) { accessHash = hasher_.hash_combine(accessHash, i); } @@ -425,14 +428,14 @@ void RegisterizerAnalysis::visit(Store* v) { } } -void RegisterizerAnalysis::visit(Load* v) { +void RegisterizerAnalysis::visit(LoadPtr v) { if (v->indices().empty()) { // already a scalar. return; } // hash the Load: SimplifierHashType accessHash = hasher_.hash(v->buf()); - for (auto* i : v->indices()) { + for (auto i : v->indices()) { accessHash = hasher_.hash_combine(accessHash, i); } @@ -560,7 +563,7 @@ void RegisterizerAnalysis::mergeCurrentScopeIntoParent() { // copy across current open accesses, merging as necessary. 
// for each Buf with an open access: for (auto& pair : currentScope_->openAccesses()) { - Buf* buf = pair.first; + BufPtr buf = pair.first; if (pair.second.empty()) { continue; } @@ -604,7 +607,7 @@ void RegisterizerAnalysis::mergeCurrentScopeIntoParent() { // If this access depends on a locally scoped variable, it cannot be // lifted out of the loop. - for (auto* v : currentScope_->localVars()) { + for (auto v : currentScope_->localVars()) { if (candidate->dependsOnVar(v)) { closeAccessIntoScope(candidate, parent); handled = true; @@ -637,7 +640,7 @@ std::vector> RegisterizerAnalysis::getCandidates() { // RegisterizerReplacer -Expr* RegisterizerReplacer::mutate(Load* v) { +ExprPtr RegisterizerReplacer::mutate(LoadPtr v) { auto it = loadToAccess_.find(v); if (it == loadToAccess_.end()) { // This access cannot be registerized. @@ -649,7 +652,7 @@ Expr* RegisterizerReplacer::mutate(Load* v) { return info->replacement().var; } -Stmt* RegisterizerReplacer::mutate(Store* v) { +StmtPtr RegisterizerReplacer::mutate(StorePtr v) { if (eliminatedIntializers_.count(v) != 0) { // This store is the intializer for a scalar var that is already inserted. return nullptr; @@ -663,22 +666,23 @@ Stmt* RegisterizerReplacer::mutate(Store* v) { auto& info = it->second; - Expr* new_val = v->value()->accept_mutator(this); + ExprPtr new_val = v->value()->accept_mutator(this); - return new Store(info->replacement().var_wrapper, {}, new_val); + return alloc( + info->replacement().var_wrapper, std::vector({}), new_val); } -Stmt* RegisterizerReplacer::mutate(Block* v) { +StmtPtr RegisterizerReplacer::mutate(BlockPtr v) { auto& scope = parentToAccesses_[v]; - std::vector stmts; - for (Stmt* stmt : v->stmts()) { + std::vector stmts; + for (StmtPtr stmt : v->stmts()) { { // Insert the initializer for any Scalars scoped to this block. auto it = scope.initializerPoints_.find(stmt); if (it != scope.initializerPoints_.end()) { for (auto& info : it->second) { - Stmt* initializer = + StmtPtr initializer = info->replacement().initializer->accept_mutator(this); stmts.push_back(initializer); } @@ -686,7 +690,7 @@ Stmt* RegisterizerReplacer::mutate(Block* v) { } } - Stmt* stmt_new = stmt->accept_mutator(this); + StmtPtr stmt_new = stmt->accept_mutator(this); if (stmt_new) { if (stmt_new->get_parent()) { stmt_new = Stmt::clone(stmt_new); @@ -699,8 +703,8 @@ Stmt* RegisterizerReplacer::mutate(Block* v) { auto it = scope.finalizePoints_.find(stmt); if (it != scope.finalizePoints_.end()) { for (auto& info : it->second) { - Store* finalizer = - new Store(info->buf(), info->indices(), info->replacement().var); + StorePtr finalizer = alloc( + info->buf(), info->indices(), info->replacement().var); stmts.push_back(finalizer); } scope.finalizePoints_.erase(it); @@ -708,14 +712,14 @@ Stmt* RegisterizerReplacer::mutate(Block* v) { } } - return new Block(stmts); + return alloc(stmts); } void RegisterizerReplacer::buildReplacements() { // Traverse the list of replacements, creating vars and updating our local // maps. for (auto& info : infoSet_) { - Var* v = new Var( + VarPtr v = alloc( info->buf()->name_hint() + "_" + c10::to_string(getBufferAccessCount(info->buf())), info->buf()->dtype()); @@ -723,12 +727,13 @@ void RegisterizerReplacer::buildReplacements() { info->replacement().var = v; // we need to wrap the Var in a Buf so we can Load or Store it. 
- info->replacement().var_wrapper = new Buf(v, {}, info->buf()->dtype()); + info->replacement().var_wrapper = + alloc(v, std::vector({}), info->buf()->dtype()); bool first = true; - for (auto* s : info->stores()) { + for (auto s : info->stores()) { if (first && info->first_usage() == s && !info->firstUsageOverlapped()) { - info->replacement().initializer = new Let(v, s->value()); + info->replacement().initializer = alloc(v, s->value()); eliminatedIntializers_.insert(s); } else { storeToAccess_[s] = info; @@ -737,7 +742,7 @@ void RegisterizerReplacer::buildReplacements() { first = false; } - for (auto* s : info->loads()) { + for (auto s : info->loads()) { loadToAccess_[s] = info; } @@ -752,8 +757,8 @@ void RegisterizerReplacer::buildReplacements() { // create a default initializer by reading the access. if (info->replacement().initializer == nullptr) { - info->replacement().initializer = new Let( - v, new Load(info->buf()->dtype(), info->buf(), info->indices())); + info->replacement().initializer = alloc( + v, alloc(info->buf()->dtype(), info->buf(), info->indices())); } } } @@ -761,13 +766,13 @@ void RegisterizerReplacer::buildReplacements() { } // namespace registerizer // Apply scalar replacement to all accesses in s. -Stmt* registerize(Stmt* s) { +StmtPtr registerize(StmtPtr s) { s = IRSimplifier::simplify(s); // The outermost node must be a Block so we have somewhere to put outer scope // scalars. - if (!dynamic_cast(s)) { - s = new Block({s}); + if (!to(s)) { + s = alloc(std::vector({s})); } registerizer::RegisterizerAnalysis analysis; s->accept(&analysis); diff --git a/torch/csrc/jit/tensorexpr/registerizer.h b/torch/csrc/jit/tensorexpr/registerizer.h index d44cab6..75eface 100644 --- a/torch/csrc/jit/tensorexpr/registerizer.h +++ b/torch/csrc/jit/tensorexpr/registerizer.h @@ -54,22 +54,25 @@ class AccessInfo { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) AccessInfo( SimplifierHashType h, - Buf* b, - std::vector i, + BufPtr b, + std::vector i, size_t accessOrder) : hash_(h), buf_(b), indices_(std::move(i)), - store_cost_(new IntImm(0)), - load_cost_(new IntImm(0)), + store_cost_(alloc(0)), + load_cost_(alloc(0)), accessOrder_(accessOrder) {} // Adds a Store to this access, which is in the provided scope. - void addStore(Store* store, const std::shared_ptr& scope); + void addStore(StorePtr store, const std::shared_ptr& scope); // Adds a Load to this access, which occurs in the usage Stmt in the provided // scope. - void addLoad(Load* load, const std::shared_ptr& scope, Stmt* usage); + void addLoad( + LoadPtr load, + const std::shared_ptr& scope, + StmtPtr usage); // Merge another AccessInfo into this one. void merge(const std::shared_ptr& other); @@ -78,7 +81,7 @@ class AccessInfo { bool overlaps(const std::shared_ptr& other); // Returns true if the indices of this access depend on the provided Var. - bool dependsOnVar(Var* v); + bool dependsOnVar(VarPtr v); // Clone this AccessInfo, and set this as the new accesses' hiddenAccess. 
static std::shared_ptr cloneWithHiddenInfo( @@ -91,30 +94,30 @@ class AccessInfo { return hash_; } - Buf* buf() const { + BufPtr buf() const { return buf_; } - const std::vector& indices() const { + const std::vector& indices() const { return indices_; } - Block* block() const { + BlockPtr block() const { return block_; } - void setEnclosingBlock(Block* b) { + void setEnclosingBlock(BlockPtr b) { block_ = b; } - Stmt* first_usage() const { + StmtPtr first_usage() const { return first_usage_; } - Stmt* last_usage() const { + StmtPtr last_usage() const { return last_usage_; } - void setUsageMarks(Stmt* first, Stmt* last) { + void setUsageMarks(StmtPtr first, StmtPtr last) { first_usage_ = first; last_usage_ = last; } @@ -123,25 +126,25 @@ class AccessInfo { return firstUsageOverlapped_; } - Expr* store_cost() const { + ExprPtr store_cost() const { return store_cost_; } - Expr* load_cost() const { + ExprPtr load_cost() const { return load_cost_; } - const std::vector& stores() const { + const std::vector& stores() const { return stores_; } - const std::vector& loads() const { + const std::vector& loads() const { return loads_; } - void hoistCosts(Expr* extent) { - store_cost_ = IRSimplifier::simplify(new Mul(store_cost_, extent)); - load_cost_ = IRSimplifier::simplify(new Mul(load_cost_, extent)); + void hoistCosts(ExprPtr extent) { + store_cost_ = IRSimplifier::simplify(alloc(store_cost_, extent)); + load_cost_ = IRSimplifier::simplify(alloc(load_cost_, extent)); } size_t conditionId() const { @@ -163,9 +166,9 @@ class AccessInfo { // Holds state relating to the scalar variable we will insert to replace some // number of loads and stores. struct ScalarReplacement { - Var* var{nullptr}; - Buf* var_wrapper{nullptr}; - Let* initializer{nullptr}; + VarPtr var{nullptr}; + BufPtr var_wrapper{nullptr}; + LetPtr initializer{nullptr}; }; ScalarReplacement& replacement() { @@ -174,12 +177,12 @@ class AccessInfo { private: SimplifierHashType hash_; - Buf* buf_; - std::vector indices_; - Block* block_{nullptr}; + BufPtr buf_; + std::vector indices_; + BlockPtr block_{nullptr}; - Stmt* first_usage_{nullptr}; - Stmt* last_usage_{nullptr}; + StmtPtr first_usage_{nullptr}; + StmtPtr last_usage_{nullptr}; // Whether or not this access is overlapped in the first Stmt it appears. This // means we cannot use it's first Store as the initializer. @@ -187,13 +190,13 @@ class AccessInfo { // The cost in real ops that this access represents, to enable // filtering accesses that wont save any loads or stores. - Expr* store_cost_; - Expr* load_cost_; + ExprPtr store_cost_; + ExprPtr load_cost_; // The actual Stores and Loads which represent this access. // Be careful with these, any mutator will invalidate these pointers. - std::vector stores_; - std::vector loads_; + std::vector stores_; + std::vector loads_; // An identifier representing the conditional block, if any, this access // depends on. 
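[Because the allocation helper forwards its arguments, the braced initializer lists that the Block and Buf constructors used to receive now need an explicit std::vector element type, which is why registerize() and RegisterizerReplacer::buildReplacements() in the hunks above spell them out. The sketch below shows the same pattern; it is illustrative only and assumes the helper perfect-forwards to the node constructor.]

#include <vector>

#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/stmt.h>

using namespace torch::jit::tensorexpr;

// Wrap a lone statement in a Block, mirroring the check at the top of
// registerize() above.
// Pre-patch spelling: if (!dynamic_cast<Block*>(s)) s = new Block({s});
StmtPtr ensureBlock(StmtPtr s) {
  if (to<Block>(s)) {
    return s;
  }
  return alloc<Block>(std::vector<StmtPtr>({s}));
}

[The zero-dimensional Buf wrapper built for each scalar replacement follows the same shape, passing an explicit empty std::vector<ExprPtr> for its dims.]
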
@@ -219,12 +222,12 @@ using AccessHashMap = class Scope { public: // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) - Scope(Block* b, std::shared_ptr parent, size_t conditionId = 0) + Scope(BlockPtr b, std::shared_ptr parent, size_t conditionId = 0) : block_(b), parent_(std::move(parent)), conditionId_(conditionId) {} - AccessHashMap& getAccessMapByBuf(Buf* b); + AccessHashMap& getAccessMapByBuf(BufPtr b); - std::unordered_map& openAccesses() { + std::unordered_map& openAccesses() { return openAccesses_; } @@ -232,7 +235,7 @@ class Scope { return closedAccesses_; } - Block* block() const { + BlockPtr block() const { return block_; } @@ -244,10 +247,10 @@ class Scope { return conditionId_; } - const std::unordered_set& localVars() const { + const std::unordered_set& localVars() const { return localVars_; } - void addLocalVar(Var* v) { + void addLocalVar(VarPtr v) { localVars_.insert(v); } @@ -261,11 +264,11 @@ class Scope { // overlap with other accesses to the same buf. Buf -> // Hash -> // Access - std::unordered_map openAccesses_; + std::unordered_map openAccesses_; std::vector> closedAccesses_; // The Block object this scope represents. - Block* block_; + BlockPtr block_; // The enclosing scope object. std::shared_ptr parent_; @@ -274,7 +277,7 @@ class Scope { size_t conditionId_; // A set of variables local to this scope (e.g. loop vars). - std::unordered_set localVars_; + std::unordered_set localVars_; }; /* Analyzes the graph and collects accesses to the same symbolic tensor element @@ -320,25 +323,25 @@ class TORCH_API RegisterizerAnalysis : public IRVisitor { : currentScope_(std::make_shared(nullptr, nullptr, 0)) {} ~RegisterizerAnalysis() override = default; - void visit(For* v) override; + void visit(ForPtr v) override; - void visit(Cond* v) override; + void visit(CondPtr v) override; - void visit(Block* v) override; + void visit(BlockPtr v) override; - void visit(Store* v) override; + void visit(StorePtr v) override; - void visit(Load* v) override; + void visit(LoadPtr v) override; - void visit(IfThenElse* v) override; + void visit(IfThenElsePtr v) override; - void visit(Let* v) override; + void visit(LetPtr v) override; -#define STMT_ON_STACK(Op) \ - void visit(Op* v) override { \ - stmtStack_.push_front(v); \ - IRVisitor::visit(v); \ - stmtStack_.pop_front(); \ +#define STMT_ON_STACK(Op) \ + void visit(Op##Ptr v) override { \ + stmtStack_.push_front(v); \ + IRVisitor::visit(v); \ + stmtStack_.pop_front(); \ } STMT_ON_STACK(AtomicAdd); @@ -359,7 +362,7 @@ class TORCH_API RegisterizerAnalysis : public IRVisitor { std::unordered_set exprConditionals_; // A stack of enclosing Stmts for tracking the usage Stmt of Loads. - std::deque stmtStack_; + std::deque stmtStack_; // The current scope being analyzed. std::shared_ptr currentScope_; @@ -381,17 +384,17 @@ class TORCH_API RegisterizerReplacer : public IRMutator { buildReplacements(); } - Expr* mutate(Load* v) override; + ExprPtr mutate(LoadPtr v) override; - Stmt* mutate(Store* v) override; + StmtPtr mutate(StorePtr v) override; - Stmt* mutate(Block* v) override; + StmtPtr mutate(BlockPtr v) override; private: struct ReplacerScope { - std::unordered_map>> + std::unordered_map>> initializerPoints_; - std::unordered_map>> + std::unordered_map>> finalizePoints_; }; @@ -400,18 +403,18 @@ class TORCH_API RegisterizerReplacer : public IRMutator { // State relating to the accesses yet to be replaced. 
std::vector>& infoSet_; - std::unordered_map> storeToAccess_; - std::unordered_map> loadToAccess_; - std::unordered_map parentToAccesses_; + std::unordered_map> storeToAccess_; + std::unordered_map> loadToAccess_; + std::unordered_map parentToAccesses_; // Holds the set of Stores that should be pulled into an initializer, so they // can be eliminated. - std::set eliminatedIntializers_; + std::set eliminatedIntializers_; // Tracks the number of times we've seen each buffer, so we can name the // scalar Vars appropriately. - std::unordered_map bufferAccessCounts_; - unsigned int getBufferAccessCount(Buf* b) { + std::unordered_map bufferAccessCounts_; + unsigned int getBufferAccessCount(BufPtr b) { return ++bufferAccessCounts_[b]; } }; @@ -420,7 +423,7 @@ class TORCH_API RegisterizerReplacer : public IRMutator { // Apply scalar replacement to all accesses in s. // To produce safe code, this must occur after handling parallelized axes and // atomics. -TORCH_API Stmt* registerize(Stmt* s); +TORCH_API StmtPtr registerize(StmtPtr s); } // namespace tensorexpr } // namespace jit diff --git a/torch/csrc/jit/tensorexpr/stmt.h b/torch/csrc/jit/tensorexpr/stmt.h index 07eef0c..0b4a2e4 100644 --- a/torch/csrc/jit/tensorexpr/stmt.h +++ b/torch/csrc/jit/tensorexpr/stmt.h @@ -18,9 +18,9 @@ class TORCH_API Stmt : public KernelScopedObject { public: Stmt() = default; virtual void accept(IRVisitor* visitor) = 0; - virtual Stmt* accept_mutator(IRMutator* mutator) = 0; + virtual StmtPtr accept_mutator(IRMutator* mutator) = 0; - Stmt* get_parent() const { + StmtPtr get_parent() const { return parent_; } @@ -31,15 +31,15 @@ class TORCH_API Stmt : public KernelScopedObject { * cloned. Note that the variables are not deep-copied since they are * immutable. */ - static Stmt* clone(Stmt* s); + static StmtPtr clone(StmtPtr s); protected: - static void set_parent(Stmt* s, Stmt* new_parent) { + static void set_parent(StmtPtr s, StmtPtr new_parent) { s->parent_ = new_parent; } private: - Stmt* parent_ = nullptr; + StmtPtr parent_ = nullptr; }; template @@ -47,25 +47,23 @@ class StmtNode : public Stmt { public: using StmtNodeBase = StmtNode; void accept(IRVisitor* visitor) override { - visitor->visit(static_cast(this)); + visitor->visit(static_to(this)); } - Stmt* accept_mutator(IRMutator* mutator) override; + StmtPtr accept_mutator(IRMutator* mutator) override; StmtNode() = default; }; template -Stmt* StmtNode::accept_mutator(IRMutator* mutator) { - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) - StmtNode* this_mutable = const_cast(this); - return mutator->mutate(static_cast(this_mutable)); +StmtPtr StmtNode::accept_mutator(IRMutator* mutator) { + return mutator->mutate(static_to(this)); } // Concrete Stmt classes class TORCH_API Block : public StmtNode { public: - static Block* make(const std::vector& stmts) { + static BlockPtr make(const std::vector& stmts) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - std::vector valid_stmts; + std::vector valid_stmts; for (auto& stmt : stmts) { if (!stmt) { continue; @@ -75,7 +73,7 @@ class TORCH_API Block : public StmtNode { if (valid_stmts.empty()) { return nullptr; } - return new Block(valid_stmts); + return alloc(valid_stmts); } int nstmts() const { @@ -85,7 +83,7 @@ class TORCH_API Block : public StmtNode { return stmts_.empty(); } - void prepend_stmt(Stmt* s) { + void prepend_stmt(StmtPtr s) { if (s->get_parent()) { throw malformed_input("Block prepend Stmt with existing parent", s); } @@ -93,7 +91,7 @@ class TORCH_API Block : public StmtNode { 
stmts_.push_front(s); set_parent(s, this); } - void append_stmt(Stmt* s) { + void append_stmt(StmtPtr s) { if (s->get_parent()) { throw malformed_input("Block append Stmt with existing parent", s); } @@ -102,7 +100,7 @@ class TORCH_API Block : public StmtNode { set_parent(s, this); } - void insert_stmt_before(Stmt* s, Stmt* before) { + void insert_stmt_before(StmtPtr s, StmtPtr before) { if (s->get_parent()) { throw malformed_input("Block append Stmt with existing parent", s); } @@ -117,7 +115,7 @@ class TORCH_API Block : public StmtNode { set_parent(s, this); } - void insert_stmt_after(Stmt* s, Stmt* after) { + void insert_stmt_after(StmtPtr s, StmtPtr after) { if (s->get_parent()) { throw malformed_input("Block append Stmt with existing parent", s); } @@ -134,7 +132,7 @@ class TORCH_API Block : public StmtNode { set_parent(s, this); } - bool replace_stmt(Stmt* old_stmt, Stmt* new_stmt) { + bool replace_stmt(StmtPtr old_stmt, StmtPtr new_stmt) { if (new_stmt->get_parent()) { throw malformed_input( "Block replace Stmt with existing parent", new_stmt); @@ -154,16 +152,16 @@ class TORCH_API Block : public StmtNode { // Creates a new block by cloning `this` block and replacing the given // statement with a new statement. Note that `old_stmt` refers to a statement // in `this` block. If the `old_stmt` is not found, it will return `nullptr`. - Block* clone_and_replace(Stmt* old_stmt, Stmt* new_stmt) { + BlockPtr clone_and_replace(StmtPtr old_stmt, StmtPtr new_stmt) { if (new_stmt->get_parent()) { throw malformed_input( "Block replace Stmt with existing parent", new_stmt); } // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - std::vector stmts(stmts_.begin(), stmts_.end()); + std::vector stmts(stmts_.begin(), stmts_.end()); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - std::vector cloned_stmts(stmts.size()); + std::vector cloned_stmts(stmts.size()); bool found = false; for (int i = 0; i < static_cast(stmts.size()); ++i) { if (stmts[i] == old_stmt) { @@ -176,10 +174,10 @@ class TORCH_API Block : public StmtNode { if (!found) { return nullptr; } - return new Block(cloned_stmts); + return alloc(cloned_stmts); } - bool remove_stmt(Stmt* stmt) { + bool remove_stmt(StmtPtr stmt) { auto pos = std::find(stmts_.begin(), stmts_.end(), stmt); if (pos == stmts_.end()) { return false; @@ -190,7 +188,7 @@ class TORCH_API Block : public StmtNode { return true; } - std::list stmts() const { + std::list stmts() const { return stmts_; } @@ -201,18 +199,18 @@ class TORCH_API Block : public StmtNode { stmts_.clear(); } - void set_stmts(const std::vector& stmts) { + void set_stmts(const std::vector& stmts) { clear(); init(stmts); } // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) - explicit Block(const std::vector& stmts) { + explicit Block(const std::vector& stmts) { init(stmts); } - typedef std::list::iterator iterator; - typedef std::list::const_iterator const_iterator; + typedef std::list::iterator iterator; + typedef std::list::const_iterator const_iterator; iterator begin() { return stmts_.begin(); @@ -230,37 +228,37 @@ class TORCH_API Block : public StmtNode { return stmts_.end(); } - Stmt* front() { + StmtPtr front() { return stmts_.front(); } - Stmt* front() const { + StmtPtr front() const { return stmts_.front(); } - Stmt* back() { + StmtPtr back() { return stmts_.back(); } - Stmt* back() const { + StmtPtr back() const { return stmts_.back(); } - void splice(Block::iterator it, Block* other) { - for (Stmt* s : *other) { + void splice(Block::iterator it, BlockPtr other) { + for (StmtPtr s 
: *other) { set_parent(s, this); } stmts_.splice(it, other->stmts_); } - static Block* getSharedParent(Stmt* p1, Stmt* p2) { + static BlockPtr getSharedParent(StmtPtr p1, StmtPtr p2) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - std::unordered_set enclosing; + std::unordered_set enclosing; - Stmt* p1_p = p1; + StmtPtr p1_p = p1; while (p1_p) { - if (Block* b = dynamic_cast(p1_p)) { + if (BlockPtr b = to(p1_p)) { if (b) { enclosing.insert(b); } @@ -268,9 +266,9 @@ class TORCH_API Block : public StmtNode { p1_p = p1_p->get_parent(); } - Stmt* p2_p = p2; + StmtPtr p2_p = p2; while (p2_p) { - if (Block* b = dynamic_cast(p2_p)) { + if (BlockPtr b = to(p2_p)) { if (enclosing.count(b) != 0) { return b; } @@ -282,7 +280,7 @@ class TORCH_API Block : public StmtNode { } // returns the immediate child containing statement s. - Stmt* getEnclosedRoot(Stmt* s) const { + StmtPtr getEnclosedRoot(StmtPtr s) const { while (s && s->get_parent() != this) { s = s->get_parent(); } @@ -290,10 +288,10 @@ class TORCH_API Block : public StmtNode { } private: - std::list stmts_; + std::list stmts_; - void init(const std::vector& stmts) { - for (Stmt* s : stmts) { + void init(const std::vector& stmts) { + for (StmtPtr s : stmts) { if (!s) { continue; } @@ -310,46 +308,46 @@ class TORCH_API Block : public StmtNode { class TORCH_API Store : public StmtNode { public: - Var* base_handle() const { + VarPtr base_handle() const { return buf_->base_handle(); } - std::vector indices() const { + std::vector indices() const { return indices_; } - Expr* flat_index() const { + ExprPtr flat_index() const { TORCH_CHECK(indices_.size() == 1, "Indices haven't been flattened."); return indices_[0]; } - Expr* value() const { + ExprPtr value() const { return value_; } - Buf* buf() const { + BufPtr buf() const { return buf_; } - void set_buf(Buf* buf) { + void set_buf(BufPtr buf) { buf_ = buf; } - void set_indices(std::vector indices) { + void set_indices(std::vector indices) { indices_ = std::move(indices); } - void set_value(Expr* value) { + void set_value(ExprPtr value) { value_ = value; } - static Store* make( + static StorePtr make( const BufHandle& buf, const std::vector& indices, const ExprHandle& value); - Store(Buf* buf, std::vector indices, Expr* value); + Store(BufPtr buf, std::vector indices, ExprPtr value); private: - Buf* buf_; - std::vector indices_; - Expr* value_; + BufPtr buf_; + std::vector indices_; + ExprPtr value_; }; // Allocate a buffer of given shapes and dtypes and bind it with the given @@ -357,11 +355,11 @@ class TORCH_API Store : public StmtNode { // explicitly freed. An unfreed memory is likely considered an error. class TORCH_API Allocate : public StmtNode { public: - static Allocate* make(const BufHandle& buf_handle) { - return new Allocate(buf_handle.node()); + static AllocatePtr make(const BufHandle& buf_handle) { + return alloc(buf_handle.node()); } - Var* buffer_var() const { + VarPtr buffer_var() const { return buf_->base_handle(); } @@ -369,149 +367,149 @@ class TORCH_API Allocate : public StmtNode { return buf_->dtype(); } - const std::vector dims() const { + const std::vector dims() const { return buf_->dims(); } - Buf* buf() const { + BufPtr buf() const { return buf_; } - void set_buf(Buf* buf) { + void set_buf(BufPtr buf) { buf_ = buf; } - explicit Allocate(Buf* buf) : buf_(buf) {} + explicit Allocate(BufPtr buf) : buf_(buf) {} private: - Buf* buf_; + BufPtr buf_; // TODO: add memory types. }; // Free the specific buffer. It is an error. 
class TORCH_API Free : public StmtNode { public: - static Free* make(const BufHandle& buf_handle) { - return new Free(buf_handle.node()); + static FreePtr make(const BufHandle& buf_handle) { + return alloc(buf_handle.node()); } - Var* buffer_var() const { + VarPtr buffer_var() const { return buf_->base_handle(); } - Buf* buf() const { + BufPtr buf() const { return buf_; } - void set_buf(Buf* buf) { + void set_buf(BufPtr buf) { buf_ = buf; } - explicit Free(Buf* buf) : buf_(buf) {} + explicit Free(BufPtr buf) : buf_(buf) {} private: - Buf* buf_; + BufPtr buf_; }; class TORCH_API Let : public StmtNode { public: - static Let* make(const VarHandle& var, const ExprHandle& val) { - return new Let(var.node(), val.node()); + static LetPtr make(const VarHandle& var, const ExprHandle& val) { + return alloc(var.node(), val.node()); } - Let(Var* var, Expr* val) : dtype_(var->dtype()), var_(var), val_(val) {} + Let(VarPtr var, ExprPtr val) : dtype_(var->dtype()), var_(var), val_(val) {} Dtype dtype() const { return dtype_; } - Var* var() const { + VarPtr var() const { return var_; } - Expr* value() const { + ExprPtr value() const { return val_; } - void set_var(Var* var) { + void set_var(VarPtr var) { var_ = var; } - void set_val(Expr* val) { + void set_val(ExprPtr val) { val_ = val; } private: Dtype dtype_; - Var* var_; - Expr* val_; + VarPtr var_; + ExprPtr val_; }; class TORCH_API Cond : public StmtNode { public: - static Cond* make( + static CondPtr make( const ExprHandle& condition, - Stmt* true_stmt, - Stmt* false_stmt) { - return new Cond(condition.node(), true_stmt, false_stmt); + StmtPtr true_stmt, + StmtPtr false_stmt) { + return alloc(condition.node(), true_stmt, false_stmt); } - Expr* condition() const { + ExprPtr condition() const { return condition_; } - Block* true_stmt() const { + BlockPtr true_stmt() const { return true_stmt_; } - Block* false_stmt() const { + BlockPtr false_stmt() const { return false_stmt_; } - void set_condition(Expr* condition) { + void set_condition(ExprPtr condition) { condition_ = condition; } - void set_true_stmt(Stmt* true_stmt) { + void set_true_stmt(StmtPtr true_stmt) { if (true_stmt) { - Block* b = dynamic_cast(true_stmt); + BlockPtr b = to(true_stmt); if (!b) { - b = new Block({true_stmt}); + b = alloc(std::vector({true_stmt})); } true_stmt_ = b; set_parent(true_stmt_, this); } } - void set_false_stmt(Stmt* false_stmt) { + void set_false_stmt(StmtPtr false_stmt) { if (false_stmt) { - Block* b = dynamic_cast(false_stmt); + BlockPtr b = to(false_stmt); if (!b) { - b = new Block({false_stmt}); + b = alloc(std::vector({false_stmt})); } false_stmt_ = b; set_parent(false_stmt_, this); } } - Cond(Expr* condition, Stmt* true_stmt, Stmt* false_stmt) + Cond(ExprPtr condition, StmtPtr true_stmt, StmtPtr false_stmt) : condition_(condition) { set_true_stmt(true_stmt); set_false_stmt(false_stmt); } - Cond* cloneWithNewBodies(Stmt* true_stmt, Stmt* false_stmt) { - return new Cond(condition_, true_stmt, false_stmt); + CondPtr cloneWithNewBodies(StmtPtr true_stmt, StmtPtr false_stmt) { + return alloc(condition_, true_stmt, false_stmt); } - Cond* cloneWithNewBody(Stmt* true_stmt) { - return new Cond(condition_, true_stmt, nullptr); + CondPtr cloneWithNewBody(StmtPtr true_stmt) { + return alloc(condition_, true_stmt, nullptr); } private: - Expr* condition_; - Block* true_stmt_ = nullptr; - Block* false_stmt_ = nullptr; + ExprPtr condition_; + BlockPtr true_stmt_ = nullptr; + BlockPtr false_stmt_ = nullptr; }; class TORCH_API LoopOptions { @@ -630,11 +628,11 @@ class 
TORCH_API LoopOptions { !is_parallel_; } - void set_buffer_mapping(const std::unordered_map& map) { + void set_buffer_mapping(const std::unordered_map& map) { map_input_to_tensor_bufs_ = map; } - std::unordered_map get_buffer_mapping() const { + std::unordered_map get_buffer_mapping() const { return map_input_to_tensor_bufs_; } @@ -642,59 +640,64 @@ class TORCH_API LoopOptions { int gpu_block_index_{IDX_UNSET}; int gpu_thread_index_{IDX_UNSET}; bool is_parallel_{false}; - std::unordered_map map_input_to_tensor_bufs_; + std::unordered_map map_input_to_tensor_bufs_; }; class TORCH_API For : public StmtNode { public: - Var* var() const { + VarPtr var() const { return var_; } - Expr* start() const { + ExprPtr start() const { return start_; } - Expr* stop() const { + ExprPtr stop() const { return stop_; } - Block* body() const { + BlockPtr body() const { return body_; } - static For* make( + static ForPtr make( const VarHandle& var, const ExprHandle& start, const ExprHandle& stop, - Stmt* body) { + StmtPtr body) { if (!body) { return nullptr; } - return new For(var.node(), start.node(), stop.node(), body); + return alloc(var.node(), start.node(), stop.node(), body); } - static For* make( + static ForPtr make( const VarHandle& var, const ExprHandle& start, const ExprHandle& stop, - Stmt* body, + StmtPtr body, const LoopOptions& loop_options) { if (!body) { return nullptr; } - return new For(var.node(), start.node(), stop.node(), body, loop_options); + return alloc( + var.node(), start.node(), stop.node(), body, loop_options); } const LoopOptions loop_options() const { return loop_options_; } - For(Var* var, Expr* start, Expr* stop, Stmt* body) + For(VarPtr var, ExprPtr start, ExprPtr stop, StmtPtr body) : var_(var), start_(start), stop_(stop) { - Block* b = dynamic_cast(body); + BlockPtr b = to(body); if (!b) { - b = new Block({body}); + b = alloc(std::vector({body})); } body_ = b; set_parent(body_, this); } - For(Var* var, Expr* start, Expr* stop, Stmt* body, LoopOptions loop_options) + For(VarPtr var, + ExprPtr start, + ExprPtr stop, + StmtPtr body, + LoopOptions loop_options) : var_(var), start_(start), stop_(stop), @@ -709,9 +712,9 @@ class TORCH_API For : public StmtNode { throw malformed_input("invalid Body in For loop", body); } - Block* b = dynamic_cast(body); + BlockPtr b = to(body); if (!b) { - b = new Block({body}); + b = alloc(std::vector({body})); } body_ = b; set_parent(body_, this); @@ -733,47 +736,47 @@ class TORCH_API For : public StmtNode { return loop_options_.is_parallel(); } - void set_buffer_map(const std::unordered_map& map) { + void set_buffer_map(const std::unordered_map& map) { loop_options_.set_buffer_mapping(map); } - For* cloneWithNewBody(Stmt* body) const { - return new For(var_, start_, stop_, body, loop_options_); + ForPtr cloneWithNewBody(StmtPtr body) const { + return alloc(var_, start_, stop_, body, loop_options_); } - Block* removeBody() { + BlockPtr removeBody() { auto res = body_; set_parent(res, nullptr); body_ = nullptr; return res; } - void set_body(Stmt* body) { - Block* b = dynamic_cast(body); + void set_body(StmtPtr body) { + BlockPtr b = to(body); if (!b) { - b = new Block({body}); + b = alloc(std::vector({body})); } body_ = b; set_parent(body_, this); } - void set_start(Expr* start) { + void set_start(ExprPtr start) { start_ = start; } - void set_stop(Expr* stop) { + void set_stop(ExprPtr stop) { stop_ = stop; } - void set_var(Var* var) { + void set_var(VarPtr var) { var_ = var; } private: - Var* var_; - Expr* start_; - Expr* stop_; - Block* body_; + 
VarPtr var_; + ExprPtr start_; + ExprPtr stop_; + BlockPtr body_; LoopOptions loop_options_; }; @@ -784,46 +787,46 @@ class TORCH_API For : public StmtNode { class TORCH_API AtomicAdd : public StmtNode { public: // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) - AtomicAdd(Buf* buf, std::vector indices, Expr* value) + AtomicAdd(BufPtr buf, std::vector indices, ExprPtr value) : buf_(buf), indices_(std::move(indices)), value_(value) {} - Var* base_handle() const { + VarPtr base_handle() const { return buf_->base_handle(); } - Buf* buf() const { + BufPtr buf() const { return buf_; } - Expr* flat_index() const { + ExprPtr flat_index() const { TORCH_CHECK(indices_.size() == 1, "Indices haven't been flattened."); return indices_[0]; } - Expr* value() const { + ExprPtr value() const { return value_; } - const std::vector& indices() const { + const std::vector& indices() const { return indices_; } - void set_buf(Buf* buf) { + void set_buf(BufPtr buf) { buf_ = buf; } - void set_indices(std::vector indices) { + void set_indices(std::vector indices) { indices_ = std::move(indices); } - void set_value(Expr* value) { + void set_value(ExprPtr value) { value_ = value; } private: - Buf* buf_; - std::vector indices_; - Expr* value_; + BufPtr buf_; + std::vector indices_; + ExprPtr value_; }; class TORCH_API SyncThreads : public StmtNode { @@ -852,13 +855,13 @@ class TORCH_API SyncThreads : public StmtNode { */ class TORCH_API ExternalCall : public StmtNode { public: - static ExternalCall* make( + static ExternalCallPtr make( BufHandle buf, const std::string& func_name, const std::vector& buf_args, const std::vector& args); - Buf* buf() const { + BufPtr buf() const { return buf_; } @@ -866,42 +869,42 @@ class TORCH_API ExternalCall : public StmtNode { return func_name_; } - std::vector buf_args() const { + std::vector buf_args() const { return buf_args_; } - std::vector args() const { + std::vector args() const { return args_; } - void set_buf(Buf* buf) { + void set_buf(BufPtr buf) { buf_ = buf; } - void set_buf_args(std::vector buf_args) { + void set_buf_args(std::vector buf_args) { buf_args_ = std::move(buf_args); } - void set_args(std::vector args) { + void set_args(std::vector args) { args_ = std::move(args); } // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) ExternalCall( - Buf* buf, + BufPtr buf, std::string func_name, - std::vector buf_args, - std::vector args) + std::vector buf_args, + std::vector args) : buf_(buf), func_name_(std::move(func_name)), buf_args_(std::move(buf_args)), args_(std::move(args)) {} private: - Buf* buf_; + BufPtr buf_; std::string func_name_; - std::vector buf_args_; - std::vector args_; + std::vector buf_args_; + std::vector args_; }; } // namespace tensorexpr diff --git a/torch/csrc/jit/tensorexpr/tensor.cpp b/torch/csrc/jit/tensorexpr/tensor.cpp index 3c62926..9df70f8 100644 --- a/torch/csrc/jit/tensorexpr/tensor.cpp +++ b/torch/csrc/jit/tensorexpr/tensor.cpp @@ -9,14 +9,14 @@ namespace torch { namespace jit { namespace tensorexpr { -Stmt* Tensor::constructStmt( - const std::vector& args, - Expr* body, - const std::vector& reduce_dims, - const std::vector& reduce_args) const { - std::vector indices(args.begin(), args.end()); +StmtPtr Tensor::constructStmt( + const std::vector& args, + ExprPtr body, + const std::vector& reduce_dims, + const std::vector& reduce_args) const { + std::vector indices(args.begin(), args.end()); - Stmt* s = new Store(buf_, indices, body); + StmtPtr s = alloc(buf_, indices, body); size_t ndim = buf()->ndim(); size_t reduce_ndim = 
reduce_dims.size(); @@ -25,25 +25,25 @@ Stmt* Tensor::constructStmt( return s; } - Expr* init_expr = buf()->initializer(); + ExprPtr init_expr = buf()->initializer(); if (reduce_ndim > 0) { for (const auto i : c10::irange(reduce_ndim)) { // Going in reverse order: from innermost loop to the outermost size_t dim_index = reduce_ndim - i - 1; - s = new For( - reduce_args[dim_index], new IntImm(0), reduce_dims[dim_index], s); + s = alloc( + reduce_args[dim_index], alloc(0), reduce_dims[dim_index], s); } if (init_expr) { - Store* init_stmt = new Store(buf(), indices, init_expr); - s = new Block({init_stmt, s}); + StorePtr init_stmt = alloc(buf(), indices, init_expr); + s = alloc(std::vector({init_stmt, s})); } } for (const auto i : c10::irange(ndim)) { // Going in reverse order: from innermost loop to the outermost size_t dim_index = ndim - i - 1; - s = new For(args[dim_index], new IntImm(0), buf()->dim(dim_index), s); + s = alloc(args[dim_index], alloc(0), buf()->dim(dim_index), s); } return s; } @@ -52,11 +52,11 @@ Tensor* Compute( const std::string& name, const std::vector& dim_args, const std::function&)>& body_func) { - std::vector dims; - std::vector args; + std::vector dims; + std::vector args; unpack_dim_args(dim_args, &dims, &args); - Expr* body = body_func(VarVectorToVarHandleVector(args)).node(); - Buf* buf = new Buf(name, dims, body->dtype()); + ExprPtr body = body_func(VarVectorToVarHandleVector(args)).node(); + BufPtr buf = alloc(name, dims, body->dtype()); return new Tensor(buf, args, body); } @@ -68,11 +68,11 @@ Tensor* Compute( throw malformed_input("mismatch between body and arg size (1)"); } - std::vector dims; - std::vector args; + std::vector dims; + std::vector args; unpack_dim_args(dim_args, &dims, &args); - Expr* body = body_func(VarHandle(args[0])).node(); - Buf* buf = new Buf(name, dims, body->dtype()); + ExprPtr body = body_func(VarHandle(args[0])).node(); + BufPtr buf = alloc(name, dims, body->dtype()); return new Tensor(buf, args, body); } @@ -84,11 +84,11 @@ Tensor* Compute( if (dim_args.size() != 2) { throw malformed_input("mismatch between body and arg size (2)"); } - std::vector dims; - std::vector args; + std::vector dims; + std::vector args; unpack_dim_args(dim_args, &dims, &args); - Expr* body = body_func(VarHandle(args[0]), VarHandle(args[1])).node(); - Buf* buf = new Buf(name, dims, body->dtype()); + ExprPtr body = body_func(VarHandle(args[0]), VarHandle(args[1])).node(); + BufPtr buf = alloc(name, dims, body->dtype()); return new Tensor(buf, args, body); } @@ -101,13 +101,13 @@ Tensor* Compute( if (dim_args.size() != 3) { throw malformed_input("mismatch between body and arg size (3)"); } - std::vector dims; - std::vector args; + std::vector dims; + std::vector args; unpack_dim_args(dim_args, &dims, &args); - Expr* body = + ExprPtr body = body_func(VarHandle(args[0]), VarHandle(args[1]), VarHandle(args[2])) .node(); - Buf* buf = new Buf(name, dims, body->dtype()); + BufPtr buf = alloc(name, dims, body->dtype()); return new Tensor(buf, args, body); } @@ -122,16 +122,16 @@ Tensor* Compute( if (dim_args.size() != 4) { throw malformed_input("mismatch between body and arg size (4)"); } - std::vector dims; - std::vector args; + std::vector dims; + std::vector args; unpack_dim_args(dim_args, &dims, &args); - Expr* body = body_func( - VarHandle(args[0]), - VarHandle(args[1]), - VarHandle(args[2]), - VarHandle(args[3])) - .node(); - Buf* buf = new Buf(name, dims, body->dtype()); + ExprPtr body = body_func( + VarHandle(args[0]), + VarHandle(args[1]), + 
VarHandle(args[2]), + VarHandle(args[3])) + .node(); + BufPtr buf = alloc(name, dims, body->dtype()); return new Tensor(buf, args, body); } diff --git a/torch/csrc/jit/tensorexpr/tensor.h b/torch/csrc/jit/tensorexpr/tensor.h index a2f1b91..3eb02c6 100644 --- a/torch/csrc/jit/tensorexpr/tensor.h +++ b/torch/csrc/jit/tensorexpr/tensor.h @@ -15,28 +15,29 @@ namespace tensorexpr { class TORCH_API Tensor : KernelScopedObject { public: // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) - Tensor(Buf* buf, const std::vector& args, Expr* body) : buf_(buf) { + Tensor(BufPtr buf, const std::vector& args, ExprPtr body) + : buf_(buf) { stmt_ = constructStmt(args, body, {}, {}); } // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) Tensor( - Buf* buf, - const std::vector& args, - const std::vector& reduce_dims, - const std::vector& reduce_args, - Expr* body) + BufPtr buf, + const std::vector& args, + const std::vector& reduce_dims, + const std::vector& reduce_args, + ExprPtr body) : buf_(buf) { stmt_ = constructStmt(args, body, reduce_dims, reduce_args); } - Tensor(Buf* buf, Stmt* stmt) : buf_(buf), stmt_(stmt) {} + Tensor(BufPtr buf, StmtPtr stmt) : buf_(buf), stmt_(stmt) {} - Buf* buf() const { + BufPtr buf() const { return buf_; } - Stmt* stmt() const { + StmtPtr stmt() const { return stmt_; } @@ -46,14 +47,14 @@ class TORCH_API Tensor : KernelScopedObject { inline ExprHandle load(const Ts&... ts); private: - Stmt* constructStmt( - const std::vector& args, - Expr* body, - const std::vector& reduce_dims, - const std::vector& reduce_args) const; - - Buf* buf_; - Stmt* stmt_; + StmtPtr constructStmt( + const std::vector& args, + ExprPtr body, + const std::vector& reduce_dims, + const std::vector& reduce_args) const; + + BufPtr buf_; + StmtPtr stmt_; }; // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) @@ -95,7 +96,7 @@ class Placeholder { explicit Placeholder(const std::vector& dims) : Placeholder(BufHandle("_", dims, kFloat)) {} - Buf* data() const { + BufPtr data() const { return data_; } BufHandle handle() const { @@ -107,10 +108,10 @@ class Placeholder { int ndim() const { return data_->ndim(); } - Expr* dim(int index) const { + ExprPtr dim(int index) const { return data_->dim(index); } - std::vector dims() const { + std::vector dims() const { return data_->dims(); } @@ -122,15 +123,15 @@ class Placeholder { inline ExprHandle load(const std::vector& args) const; - inline Store* store( + inline StorePtr store( const std::vector& args, const ExprHandle& val) const { - return new Store(data(), ExprHandleVectorToExprVector(args), val.node()); + return alloc(data(), ExprHandleVectorToExprVector(args), val.node()); } private: - Buf* data_; - std::vector strides_; + BufPtr data_; + std::vector strides_; }; TORCH_API Tensor* Compute( @@ -163,14 +164,14 @@ TORCH_API Tensor* Compute( inline void unpack_dim_args( const std::vector& dim_args, - std::vector* dims, - std::vector* vars) { + std::vector* dims, + std::vector* vars) { dims->clear(); vars->clear(); for (const DimArg& dim_arg : dim_args) { - Expr* expr = dim_arg.dim().node(); + ExprPtr expr = dim_arg.dim().node(); dims->push_back(expr); - vars->push_back(new Var( + vars->push_back(alloc( dim_arg.name_hint(), expr->dtype().scalar_type() == ScalarType::Long ? 
kLong : kInt)); } @@ -186,45 +187,45 @@ Tensor* Reduce( const BodyFunc& body_func, const std::vector& reduce_args) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - std::vector dims; + std::vector dims; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - std::vector vars; + std::vector vars; unpack_dim_args(dim_args, &dims, &vars); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - std::vector reduce_dims; + std::vector reduce_dims; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - std::vector reduce_vars; + std::vector reduce_vars; unpack_dim_args(reduce_args, &reduce_dims, &reduce_vars); // If reduce_vars is empty, then it's not a reduction, but rather a simple // copy if (reduce_vars.empty()) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - Expr* body = + ExprPtr body = Reducer::getReduceBody(body_func, VarVectorToVarHandleVector(vars)) .node(); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - Buf* func_result = new Buf(func_name, dims, body->dtype()); + BufPtr func_result = alloc(func_name, dims, body->dtype()); return new Tensor(func_result, vars, body); } // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - std::vector all_vars; + std::vector all_vars; all_vars.insert(all_vars.end(), vars.begin(), vars.end()); all_vars.insert(all_vars.end(), reduce_vars.begin(), reduce_vars.end()); ExprHandle body = Reducer::getReduceBody(body_func, VarVectorToVarHandleVector(all_vars)); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - std::vector output_args(vars.begin(), vars.end()); + std::vector output_args(vars.begin(), vars.end()); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - Expr* init_expr = new Cast( + ExprPtr init_expr = alloc( body.dtype(), init_func(VarVectorToVarHandleVector(vars)).node()); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - Buf* func_result = new Buf(func_name, dims, body.dtype(), init_expr); + BufPtr func_result = alloc(func_name, dims, body.dtype(), init_expr); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - ReduceOp* reduce_op = reducer(func_result, body, output_args, reduce_vars); + ReduceOpPtr reduce_op = reducer(func_result, body, output_args, reduce_vars); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) Tensor* t = new Tensor(func_result, vars, reduce_dims, reduce_vars, reduce_op); @@ -300,14 +301,14 @@ template inline ExprHandle Placeholder::load(const Ts&... 
ts) const { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) std::vector params({ExprHandle(ts)...}); - return ExprHandle(new Load(data(), ExprHandleVectorToExprVector(params))); + return ExprHandle(alloc(data(), ExprHandleVectorToExprVector(params))); } template inline ExprHandle Placeholder::load(const std::vector& args) const { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) std::vector params(args.begin(), args.end()); - return ExprHandle(new Load(data(), ExprHandleVectorToExprVector(params))); + return ExprHandle(alloc(data(), ExprHandleVectorToExprVector(params))); } inline ExprHandle Placeholder::load(const std::vector& args) const { diff --git a/torch/csrc/jit/tensorexpr/tensorexpr_init.cpp b/torch/csrc/jit/tensorexpr/tensorexpr_init.cpp index 5b28716..304a317 100644 --- a/torch/csrc/jit/tensorexpr/tensorexpr_init.cpp +++ b/torch/csrc/jit/tensorexpr/tensorexpr_init.cpp @@ -190,7 +190,7 @@ void initTensorExprBindings(PyObject* module) { py::return_value_policy::reference); py::class_>(te, "Tensor") .def(py::init( - [](BufHandle& b, Stmt* s) { return new Tensor(b.node(), s); })) + [](BufHandle& b, StmtPtr s) { return new Tensor(b.node(), s); })) .def( "load", [](Tensor& self, const std::vector& v) { @@ -322,7 +322,7 @@ void initTensorExprBindings(PyObject* module) { py::return_value_policy::reference); py::class_>(te, "Stmt") - .def(py::init([](const std::vector& stmts) { + .def(py::init([](const std::vector& stmts) { return tensorexpr::Block::make(stmts); })) .def("__str__", [](Stmt& self) { @@ -362,14 +362,16 @@ void initTensorExprBindings(PyObject* module) { [](const VarHandle& var, const ExprHandle& start, const ExprHandle& stop, - Stmt* body) { return For::make(var, start, stop, body); }, + StmtPtr body) { return For::make(var, start, stop, body); }, py::return_value_policy::reference); py::class_>(te, "Cond") .def_static( "make", - [](const ExprHandle& condition, Stmt* true_stmt, Stmt* false_stmt) { - return new Cond(condition.node(), true_stmt, false_stmt); + [](const ExprHandle& condition, + StmtPtr true_stmt, + StmtPtr false_stmt) { + return alloc(condition.node(), true_stmt, false_stmt); }, py::return_value_policy::reference) .def("true_stmt", &Cond::true_stmt, py::return_value_policy::reference) @@ -379,7 +381,7 @@ void initTensorExprBindings(PyObject* module) { tensorexpr::Block, Stmt, std::unique_ptr>(te, "Block") - .def(py::init([](const std::vector& stmts) { + .def(py::init([](const std::vector& stmts) { return tensorexpr::Block::make(stmts); })) .def( @@ -392,8 +394,8 @@ void initTensorExprBindings(PyObject* module) { py::class_(te, "LoopNest") .def(py::init&>()) - .def(py::init([](Stmt* s, const std::vector& bufs) { - std::unordered_set buf_nodes; + .def(py::init([](StmtPtr s, const std::vector& bufs) { + std::unordered_set buf_nodes; for (auto& buf : bufs) { buf_nodes.insert(buf.node()); } @@ -427,7 +429,7 @@ void initTensorExprBindings(PyObject* module) { py::return_value_policy::reference) .def( "get_enclosing_loopnest", - [](const LoopNest& self, Stmt* s) { + [](const LoopNest& self, StmtPtr s) { return self.getEnclosingLoopNest(s); }, py::return_value_policy::reference) @@ -445,117 +447,119 @@ void initTensorExprBindings(PyObject* module) { py::return_value_policy::reference) .def( "get_loop_at", - [](const LoopNest& self, For* root, const std::vector& indices) { + [](const LoopNest& self, + ForPtr root, + const std::vector& indices) { return self.getLoopAt(root, indices); }, py::return_value_policy::reference) .def( "get_parent_loop", - [](const 
LoopNest& self, Stmt* s) { return self.getParentLoop(s); }, + [](const LoopNest& self, StmtPtr s) { return self.getParentLoop(s); }, py::return_value_policy::reference) .def_static( "get_loop_stmts_in_loopnest", - [](For* f, size_t num) { + [](ForPtr f, size_t num) { return LoopNest::getLoopStmtsInLoopNest(f, num); }, py::return_value_policy::reference) .def( "split_with_tail", - [](For* f, int factor) { - For *inner = nullptr, *tail = nullptr; + [](ForPtr f, int factor) { + ForPtr inner = nullptr, tail = nullptr; LoopNest::splitWithTail(f, factor, &inner, &tail); return std::make_tuple(inner, tail); }, py::return_value_policy::reference) .def( "split_with_mask", - [](For* f, int factor) { - For* inner = nullptr; + [](ForPtr f, int factor) { + ForPtr inner = nullptr; LoopNest::splitWithMask(f, factor, &inner); return inner; }, py::return_value_policy::reference) .def( "slice_head", - [](For* f, int factor) { - For *head = nullptr, *tail = nullptr; + [](ForPtr f, int factor) { + ForPtr head = nullptr, tail = nullptr; LoopNest::sliceHead(f, factor, &head, &tail); return std::make_tuple(head, tail); }, py::return_value_policy::reference) .def( "slice_tail", - [](For* f, int factor) { - For *head = nullptr, *tail = nullptr; + [](ForPtr f, int factor) { + ForPtr head = nullptr, tail = nullptr; LoopNest::sliceTail(f, factor, &head, &tail); return std::make_tuple(head, tail); }, py::return_value_policy::reference) .def_static( "normalize", - [](For* f) { + [](ForPtr f) { LoopNest::normalize(f); return f; }, py::return_value_policy::reference) .def( "tile", - [](LoopNest& self, For* x, For* y, int x_factor, int y_factor) { + [](LoopNest& self, ForPtr x, ForPtr y, int x_factor, int y_factor) { return self.tile(x, y, x_factor, y_factor); }, py::return_value_policy::reference) .def_static( "distribute_loop", - [](For* f) { return LoopNest::distributeLoop(f); }, + [](ForPtr f) { return LoopNest::distributeLoop(f); }, py::return_value_policy::reference) .def_static( "distribute_loop", - [](For* f, const std::unordered_set& pivots) { + [](ForPtr f, const std::unordered_set& pivots) { return LoopNest::distributeLoop(f, pivots); }, py::return_value_policy::reference) .def_static( "distribute_loop_over_inner_loops", - [](For* f) { return LoopNest::distributeLoopOverInnerLoops(f); }, + [](ForPtr f) { return LoopNest::distributeLoopOverInnerLoops(f); }, py::return_value_policy::reference) .def_static( "unsafe_fuse_loops", - [](const std::vector& loops) { - For* fused_loop = nullptr; + [](const std::vector& loops) { + ForPtr fused_loop = nullptr; LoopNest::unsafeFuseLoops(loops, &fused_loop); return fused_loop; }, py::return_value_policy::reference) .def_static( "fuse_loops", - [](const std::vector& loops) { - For* fused_loop = nullptr; + [](const std::vector& loops) { + ForPtr fused_loop = nullptr; LoopNest::fuseLoops(loops, &fused_loop); return fused_loop; }, py::return_value_policy::reference) .def_static( "reorder", - [](const std::vector& loops, + [](const std::vector& loops, const std::vector& permutation) { return LoopNest::reorder(loops, permutation); }, py::return_value_policy::reference) .def( "unroll", - [](const LoopNest& self, For* f) { - Stmt* unrolled = nullptr; + [](const LoopNest& self, ForPtr f) { + StmtPtr unrolled = nullptr; self.unroll(f, &unrolled); return unrolled; }, py::return_value_policy::reference) .def( "vectorize", - [](For* f) { LoopNest::vectorize(f); }, + [](ForPtr f) { LoopNest::vectorize(f); }, py::return_value_policy::reference) .def_static( "compress_buffer", - 
[](BufHandle& buf, Stmt* stmt) { + [](BufHandle& buf, StmtPtr stmt) { return LoopNest::compressBuffer(buf.node(), stmt); }, py::return_value_policy::reference) @@ -563,16 +567,18 @@ void initTensorExprBindings(PyObject* module) { "cache_accesses", [](const BufHandle& producer, const std::string& name, - Stmt* consumer) { - std::pair ret = + StmtPtr consumer) { + std::pair ret = LoopNest::cacheAccesses(producer.node(), name, consumer); return std::make_pair(BufHandle(ret.first), ret.second); }, py::return_value_policy::reference) - .def("compute_at", [](Stmt* s, For* at) { LoopNest::computeAt(s, at); }) + .def( + "compute_at", + [](StmtPtr s, ForPtr at) { LoopNest::computeAt(s, at); }) .def( "compute_inline", - [](LoopNest& self, Stmt* s) { self.computeInline(s); }, + [](LoopNest& self, StmtPtr s) { self.computeInline(s); }, py::return_value_policy::reference) .def( "compute_inline", @@ -582,16 +588,16 @@ void initTensorExprBindings(PyObject* module) { py::return_value_policy::reference) .def( "rfactor", - [](Stmt* s, For* target_for) { - Buf* rfac_buf = nullptr; + [](StmtPtr s, ForPtr target_for) { + BufPtr rfac_buf = nullptr; LoopNest::rfactor(s, target_for, &rfac_buf); return BufHandle(rfac_buf); }, py::return_value_policy::reference) .def( "flatten", - [](const std::vector& loops) { - For* flattened = nullptr; + [](const std::vector& loops) { + ForPtr flattened = nullptr; LoopNest::flatten(loops, &flattened); return flattened; }, @@ -623,7 +629,7 @@ void initTensorExprBindings(PyObject* module) { te.def( "simplify", - [](Stmt* stmt) { return IRSimplifier::simplify(stmt); }, + [](StmtPtr stmt) { return IRSimplifier::simplify(stmt); }, py::return_value_policy::reference); te.def( @@ -779,7 +785,7 @@ void initTensorExprBindings(PyObject* module) { te.def( "construct_codegen", [](const std::string& name, - Stmt* stmt, + StmtPtr stmt, const std::vector& args) { CodeGen* cg = nullptr; if (name == "llvm") { diff --git a/torch/csrc/jit/tensorexpr/unique_name_manager.cpp b/torch/csrc/jit/tensorexpr/unique_name_manager.cpp index 3ed82bb..bbd9fd7 100644 --- a/torch/csrc/jit/tensorexpr/unique_name_manager.cpp +++ b/torch/csrc/jit/tensorexpr/unique_name_manager.cpp @@ -8,7 +8,7 @@ namespace torch { namespace jit { namespace tensorexpr { -const std::string& UniqueNameManager::get_unique_name(Var* v) { +const std::string& UniqueNameManager::get_unique_name(VarPtr v) { // Find if we have already encountered this variable. auto iter = unique_name_mapping_.find(v); if (iter != unique_name_mapping_.end()) { diff --git a/torch/csrc/jit/tensorexpr/unique_name_manager.h b/torch/csrc/jit/tensorexpr/unique_name_manager.h index 1e945ab..a52e759 100644 --- a/torch/csrc/jit/tensorexpr/unique_name_manager.h +++ b/torch/csrc/jit/tensorexpr/unique_name_manager.h @@ -5,6 +5,7 @@ #include #include +#include namespace torch { namespace jit { @@ -13,7 +14,7 @@ namespace tensorexpr { class VarHandle; class Var; -using VarNameMap = std::unordered_map; +using VarNameMap = std::unordered_map; // A manager to get unique names from vars. 
// It starts with the name hints of the var and append "_" + $counter until it @@ -23,7 +24,7 @@ class TORCH_API UniqueNameManager { public: const std::string& get_unique_name(const VarHandle& v); - const std::string& get_unique_name(Var* v); + const std::string& get_unique_name(VarPtr v); private: friend class ScopedVarName; diff --git a/torch/csrc/jit/tensorexpr/var_substitutor.h b/torch/csrc/jit/tensorexpr/var_substitutor.h index f5a6227..e036114 100644 --- a/torch/csrc/jit/tensorexpr/var_substitutor.h +++ b/torch/csrc/jit/tensorexpr/var_substitutor.h @@ -13,15 +13,15 @@ namespace torch { namespace jit { namespace tensorexpr { -using VarMapping = std::vector>; +using VarMapping = std::vector>; class VarSubMutator : public IRMutator { public: // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) VarSubMutator(const VarMapping& var_mapping) { for (auto& entry : var_mapping) { - Var* key_var = entry.first; - Expr* value = entry.second; + VarPtr key_var = entry.first; + ExprPtr value = entry.second; if (!key_var) { throw malformed_input("missing key in VarSubMutator"); } @@ -29,7 +29,7 @@ class VarSubMutator : public IRMutator { } } - Expr* mutate(Var* var) override { + ExprPtr mutate(VarPtr var) override { auto iter = var_mapping_.find(var); if (iter == var_mapping_.end()) { return var; @@ -37,14 +37,14 @@ class VarSubMutator : public IRMutator { return iter->second; } - Expr* mutate(ReduceOp* var) override { + ExprPtr mutate(ReduceOpPtr var) override { auto body = var->body()->accept_mutator(this); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - std::vector new_inner; + std::vector new_inner; - for (auto* v : var->reduce_args()) { - Expr* e = v->accept_mutator(this); - if (Var* new_var = dynamic_cast(e)) { + for (auto v : var->reduce_args()) { + ExprPtr e = v->accept_mutator(this); + if (VarPtr new_var = to(e)) { new_inner.push_back(new_var); } else { VarFinder varFinder; @@ -54,11 +54,11 @@ class VarSubMutator : public IRMutator { } } - return new ReduceOp(body, new_inner, var->reducer()); + return alloc(body, new_inner, var->reducer()); } private: - std::unordered_map var_mapping_; + std::unordered_map var_mapping_; }; } // namespace tensorexpr -- 2.7.4
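
The hunks above all lean on the same small wrapper layer: every `X*` in a signature becomes an `XPtr` alias, allocations go through `alloc(...)`, and downcasts go through `to(...)` / `static_to(...)`. The definitions of that layer are not part of the hunks shown here, so the following is only a minimal sketch under the assumption that the aliases still wrap raw pointers at this stage; the helper names match the call sites above, but their bodies and signatures are illustrative, not the patch's actual implementation.

    #include <memory>
    #include <utility>

    namespace torch {
    namespace jit {
    namespace tensorexpr {

    class Expr;
    class Stmt;

    // While the migration is in flight, the Ptr aliases can simply name raw
    // pointers; switching them to std::shared_ptr later then only touches the
    // header that defines them, not every call site.
    using ExprPtr = Expr*;
    using StmtPtr = Stmt*;
    // ...one alias per IR node: BlockPtr, BufPtr, VarPtr, ForPtr, StorePtr, etc.

    // alloc<T>(args...) centralizes allocation so that "new T(args...)" call
    // sites do not have to change again when the ownership model changes.
    template <typename T, typename... Args>
    T* alloc(Args&&... args) {
      return new T(std::forward<Args>(args)...);
    }

    // to<T>(p) and static_to<T>(p) centralize downcasts for the same reason;
    // with shared_ptr they would become dynamic/static_pointer_cast.
    template <typename T, typename U>
    T* to(U* p) {
      return dynamic_cast<T*>(p);
    }

    template <typename T, typename U>
    T* static_to(U* p) {
      return static_cast<T*>(p);
    }

    } // namespace tensorexpr
    } // namespace jit
    } // namespace torch

Under those assumptions, the call-site pattern visible in the For and Cond hunks reads as `BlockPtr b = to<Block>(body); if (!b) b = alloc<Block>(std::vector<StmtPtr>({body}));`, the explicit `std::vector<StmtPtr>` being the extra verbosity that a brace-initialized `new Block({body})` did not need.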