[nnc] Updated inlining to handle cases when producer indices are constants after...
authorRaghavan Raman <raghavanr@fb.com>
Fri, 17 Sep 2021 17:50:43 +0000 (10:50 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Fri, 17 Sep 2021 18:28:48 +0000 (11:28 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/65044

Test Plan: Imported from OSS

Reviewed By: ZolotukhinM

Differential Revision: D30954655

Pulled By: navahgar

fbshipit-source-id: dfaedb5af710b2625ceec3a443a6c4e34158ab16

test/cpp/tensorexpr/test_loopnest.cpp
torch/csrc/jit/tensorexpr/loopnest.cpp

index d1c5da5..f8e73e9 100644 (file)
@@ -1704,7 +1704,7 @@ TEST(LoopNest, ScheduleInlineWithCompoundIndices) {
     # CHECK-NEXT:   B[)IR");
 }
 
-TEST(LoopNest, ScheduleInlineBufferIndicesWithCast) {
+TEST(LoopNest, ScheduleInlineConsumerIndicesWithCast) {
   // Input IR:
   //     for (int64_t i = 0; i < 100; i++) {
   //       A[0ll,i] = i * 500ll;
@@ -1737,7 +1737,45 @@ TEST(LoopNest, ScheduleInlineBufferIndicesWithCast) {
   auto par = Block::make({forI, forJ});
 
   LoopNest l(par, {b_buf.node()});
-  l.computeInline(a_buf.node());
+  ASSERT_TRUE(l.computeInline(a_buf.node()));
+
+  checkIR(l.root_stmt(), R"IR(
+    # CHECK: for (int64_t j = 0; j < 100; j++) {
+    # CHECK:   B[0ll, j] = j * 500ll + j * 100ll;
+    # CHECK: })IR");
+}
+
+TEST(LoopNest, ScheduleInlineProducerIndicesWithCast) {
+  // Input IR:
+  //     for (int64_t i = 0; i < 100; i++) {
+  //       A[(int64_t)0,i] = i * 500ll;
+  //     }
+  //     for (int64_t j = 0; j < 100; j++) {
+  //       B[0ll,j] = A[0ll, j] + j * 100ll;
+  //     }
+  BufHandle a_buf("A", {20, 100}, kLong);
+  BufHandle b_buf("B", {20, 100}, kLong);
+  VarHandle i("i", kLong);
+  VarHandle j("j", kLong);
+  auto forI = For::make(
+      i,
+      0,
+      100,
+      Store::make(a_buf, {0, i}, Mul::make(i, static_cast<int64_t>(500))));
+  auto forJ = For::make(
+      j,
+      0,
+      100,
+      Store::make(
+          b_buf,
+          {static_cast<int64_t>(0), j},
+          Add::make(
+              Load::make(a_buf, {static_cast<int64_t>(0), j}),
+              Mul::make(j, static_cast<int64_t>(100)))));
+  auto par = Block::make({forI, forJ});
+
+  LoopNest l(par, {b_buf.node()});
+  ASSERT_TRUE(l.computeInline(a_buf.node()));
 
   checkIR(l.root_stmt(), R"IR(
     # CHECK: for (int64_t j = 0; j < 100; j++) {
index 755af43..af80439 100644 (file)
@@ -655,16 +655,15 @@ class FunctionInliner : public IRMutator {
       if (auto index_var = to<Var>(i)) {
         index_vars_.insert(index_var);
         producer_index_vars_.push_back(index_var);
-      } else if (intValue(i)) {
+      } else {
         // If the index can be a constant, then that dimension must have size 1
         // (since we don't support in-place writes). Resolves issue 52581.
-        if (*intValue(i) != 0) {
+        auto index_val = evalInt(i);
+        if (!index_val || *index_val != 0) {
           success_ = false;
+          break;
         }
         producer_index_vars_.push_back(nullptr);
-      } else {
-        // Cannot inline Buf with compound indices
-        success_ = false;
       }
     }
   }