[TensorExpr] Simplify TE IR before applying any transformations. (#64717)
authorMikhail Zolotukhin <mvz@fb.com>
Fri, 10 Sep 2021 01:48:17 +0000 (18:48 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Fri, 10 Sep 2021 01:50:51 +0000 (18:50 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64717

This also exposed several bugs, which are fixed in this PR.

Differential Revision:
D30826408
D30826408

Test Plan: Imported from OSS

Reviewed By: navahgar

Pulled By: ZolotukhinM

fbshipit-source-id: a67ec5739aceed9ffdf0d24f77eb3787cefe4560

test/cpp/tensorexpr/test_kernel.cpp
torch/csrc/jit/tensorexpr/ir.h
torch/csrc/jit/tensorexpr/kernel.cpp
torch/csrc/jit/tensorexpr/loopnest.cpp

index 1a6d086..0bcef07 100644 (file)
@@ -248,7 +248,11 @@ TEST_F(Kernel, Huge) {
   TensorExprKernel k(graph);
   std::ostringstream oss;
   oss << *k.getCodeGenStmt();
-  const std::string& verification_pattern = "# CHECK: 4000000000";
+  // The 4000000000 iterations loop will be split into 500000000 x 8 and the
+  // outer loop will be parallel. If LLVM is not present, it will not be split,
+  // and to cover both of these cases we're looking for 00000000ll; in the
+  // output.
+  const std::string& verification_pattern = R"IR(# CHECK: 00000000ll;)IR";
   torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
 }
 
@@ -629,17 +633,15 @@ TEST_F(Kernel, CatWoConditionals) {
   const std::string& verification_pattern =
       R"IR(
 # CHECK: for
-# CHECK-NEXT: for
-# CHECK-NEXT: for
-# CHECK-NEXT: aten_cat
 # CHECK: for
-# CHECK-NEXT: for
-# CHECK-NEXT: for
-# CHECK-NEXT: aten_cat
 # CHECK: for
-# CHECK-NEXT: for
-# CHECK-NEXT: for
-# CHECK-NEXT: aten_cat)IR";
+# CHECK: aten_cat
+# CHECK: for
+# CHECK: for
+# CHECK: aten_cat
+# CHECK: for
+# CHECK: for
+# CHECK: aten_cat)IR";
   torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
 
   auto a = at::rand({5, 3, 2}, TensorOptions(kCPU).dtype(at::kFloat));
index 65a362e..4448ea4 100644 (file)
@@ -902,11 +902,6 @@ class TORCH_API Intrinsics : public ExprNode<Intrinsics> {
   IntrinsicsOp op_type_;
 };
 
-class Polynomial;
-class Term;
-class MaxTerm;
-class MinTerm;
-
 TORCH_API std::vector<ExprPtr> ExprHandleVectorToExprVector(
     const std::vector<ExprHandle>&);
 TORCH_API std::vector<ExprHandle> ExprVectorToExprHandleVector(
index 17e4c96..cfb27ee 100644 (file)
@@ -2551,6 +2551,9 @@ void fuseAllLoops(StmtPtr st) {
       }
       loopsToFuse.push_back(loop);
     }
+    if (loopsToFuse.empty()) {
+      return;
+    }
     if (!loopBoundsAllEqual(loopsToFuse)) {
       return;
     }
@@ -2658,6 +2661,8 @@ StmtPtr TensorExprKernel::transformLoops(BackendType backendType, StmtPtr st) {
     auto root_stmt = l.root_stmt();
     root_stmt->accept(block_analysis.get());
   }
+  l.simplify();
+  GRAPH_DEBUG("after simplify", *l.root_stmt());
 
   // Inlining output & intermediate buffers can duplicate computation.
   // Duplicating work can slow down the program if it's not ameliorated in some
@@ -3030,6 +3035,7 @@ Tensor TensorExprKernel::convertOutputToCorrectStrides(torch::jit::Value* v) {
   //     cur_idx = absolute // stride
   //     absolute = absolute % stride
 
+  auto zero = LongImm::make(0);
   return Compute(
       "output_1", dims, [&](const std::vector<VarHandle>& axes_input) {
         std::vector<ExprHandle> axes(axes_input.begin(), axes_input.end());
@@ -3042,17 +3048,17 @@ Tensor TensorExprKernel::convertOutputToCorrectStrides(torch::jit::Value* v) {
             reverse_sort_indices(strides);
         std::vector<ExprHandle> new_axes(sorted_stride_indices.size());
         for (size_t stride_index : sorted_stride_indices) {
-          auto stride = strides[stride_index];
           auto size = sizes[stride_index];
-          auto index = absolute_position /
-              ExprHandle(immLike(absolute_position, stride));
+          auto index = zero;
           if (size != 1) {
+            auto stride = strides[stride_index];
+            index = absolute_position /
+                ExprHandle(immLike(absolute_position, stride));
             absolute_position = absolute_position %
                 ExprHandle(immLike(absolute_position, stride));
           }
           new_axes[stride_index] = index;
         }
-        // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
         return BufHandle(buf).load(new_axes);
       });
 }
index 0750b34..3742b93 100644 (file)
@@ -747,6 +747,9 @@ bool LoopNest::computeInline(StmtPtr s) {
 bool LoopNest::computeInline(BufPtr b) {
   // If buf is used or defined in an ExternalCall, we cannot inline it
   auto buf_load_store_uses = findLoadOrStoreUses(root_stmt_);
+  if (!buf_load_store_uses.count(b)) {
+    return false;
+  }
   for (auto& use : buf_load_store_uses.at(b)) {
     StmtPtr s = use.s;
     if (to<ExternalCall>(s)) {