From: Mikhail Zolotukhin Date: Fri, 10 Sep 2021 01:48:17 +0000 (-0700) Subject: [TensorExpr] Simplify TE IR before applying any transformations. (#64717) X-Git-Tag: accepted/tizen/8.0/unified/20231005.095509~318 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=a17d6c7f80e429a1c21413329a2d73cadb11a6ef;p=platform%2Fupstream%2Fpytorch.git [TensorExpr] Simplify TE IR before applying any transformations. (#64717) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/64717 This also exposed several bugs, which are fixed in this PR. Differential Revision: D30826408 D30826408 Test Plan: Imported from OSS Reviewed By: navahgar Pulled By: ZolotukhinM fbshipit-source-id: a67ec5739aceed9ffdf0d24f77eb3787cefe4560 --- diff --git a/test/cpp/tensorexpr/test_kernel.cpp b/test/cpp/tensorexpr/test_kernel.cpp index 1a6d086..0bcef07 100644 --- a/test/cpp/tensorexpr/test_kernel.cpp +++ b/test/cpp/tensorexpr/test_kernel.cpp @@ -248,7 +248,11 @@ TEST_F(Kernel, Huge) { TensorExprKernel k(graph); std::ostringstream oss; oss << *k.getCodeGenStmt(); - const std::string& verification_pattern = "# CHECK: 4000000000"; + // The 4000000000 iterations loop will be split into 500000000 x 8 and the + // outer loop will be parallel. If LLVM is not present, it will not be split, + // and to cover both of these cases we're looking for 00000000ll; in the + // output. + const std::string& verification_pattern = R"IR(# CHECK: 00000000ll;)IR"; torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); } @@ -629,17 +633,15 @@ TEST_F(Kernel, CatWoConditionals) { const std::string& verification_pattern = R"IR( # CHECK: for -# CHECK-NEXT: for -# CHECK-NEXT: for -# CHECK-NEXT: aten_cat # CHECK: for -# CHECK-NEXT: for -# CHECK-NEXT: for -# CHECK-NEXT: aten_cat # CHECK: for -# CHECK-NEXT: for -# CHECK-NEXT: for -# CHECK-NEXT: aten_cat)IR"; +# CHECK: aten_cat +# CHECK: for +# CHECK: for +# CHECK: aten_cat +# CHECK: for +# CHECK: for +# CHECK: aten_cat)IR"; torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); auto a = at::rand({5, 3, 2}, TensorOptions(kCPU).dtype(at::kFloat)); diff --git a/torch/csrc/jit/tensorexpr/ir.h b/torch/csrc/jit/tensorexpr/ir.h index 65a362e..4448ea4 100644 --- a/torch/csrc/jit/tensorexpr/ir.h +++ b/torch/csrc/jit/tensorexpr/ir.h @@ -902,11 +902,6 @@ class TORCH_API Intrinsics : public ExprNode { IntrinsicsOp op_type_; }; -class Polynomial; -class Term; -class MaxTerm; -class MinTerm; - TORCH_API std::vector ExprHandleVectorToExprVector( const std::vector&); TORCH_API std::vector ExprVectorToExprHandleVector( diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp index 17e4c96..cfb27ee 100644 --- a/torch/csrc/jit/tensorexpr/kernel.cpp +++ b/torch/csrc/jit/tensorexpr/kernel.cpp @@ -2551,6 +2551,9 @@ void fuseAllLoops(StmtPtr st) { } loopsToFuse.push_back(loop); } + if (loopsToFuse.empty()) { + return; + } if (!loopBoundsAllEqual(loopsToFuse)) { return; } @@ -2658,6 +2661,8 @@ StmtPtr TensorExprKernel::transformLoops(BackendType backendType, StmtPtr st) { auto root_stmt = l.root_stmt(); root_stmt->accept(block_analysis.get()); } + l.simplify(); + GRAPH_DEBUG("after simplify", *l.root_stmt()); // Inlining output & intermediate buffers can duplicate computation. // Duplicating work can slow down the program if it's not ameliorated in some @@ -3030,6 +3035,7 @@ Tensor TensorExprKernel::convertOutputToCorrectStrides(torch::jit::Value* v) { // cur_idx = absolute // stride // absolute = absolute % stride + auto zero = LongImm::make(0); return Compute( "output_1", dims, [&](const std::vector& axes_input) { std::vector axes(axes_input.begin(), axes_input.end()); @@ -3042,17 +3048,17 @@ Tensor TensorExprKernel::convertOutputToCorrectStrides(torch::jit::Value* v) { reverse_sort_indices(strides); std::vector new_axes(sorted_stride_indices.size()); for (size_t stride_index : sorted_stride_indices) { - auto stride = strides[stride_index]; auto size = sizes[stride_index]; - auto index = absolute_position / - ExprHandle(immLike(absolute_position, stride)); + auto index = zero; if (size != 1) { + auto stride = strides[stride_index]; + index = absolute_position / + ExprHandle(immLike(absolute_position, stride)); absolute_position = absolute_position % ExprHandle(immLike(absolute_position, stride)); } new_axes[stride_index] = index; } - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) return BufHandle(buf).load(new_axes); }); } diff --git a/torch/csrc/jit/tensorexpr/loopnest.cpp b/torch/csrc/jit/tensorexpr/loopnest.cpp index 0750b34..3742b93 100644 --- a/torch/csrc/jit/tensorexpr/loopnest.cpp +++ b/torch/csrc/jit/tensorexpr/loopnest.cpp @@ -747,6 +747,9 @@ bool LoopNest::computeInline(StmtPtr s) { bool LoopNest::computeInline(BufPtr b) { // If buf is used or defined in an ExternalCall, we cannot inline it auto buf_load_store_uses = findLoadOrStoreUses(root_stmt_); + if (!buf_load_store_uses.count(b)) { + return false; + } for (auto& use : buf_load_store_uses.at(b)) { StmtPtr s = use.s; if (to(s)) {