[nnc] Handled cast in index expression during inlining (#64716)
authorRaghavan Raman <raghavanr@fb.com>
Thu, 9 Sep 2021 15:26:16 +0000 (08:26 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Thu, 9 Sep 2021 15:30:52 +0000 (08:30 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/64716

Test Plan: Imported from OSS

Reviewed By: ZolotukhinM

Differential Revision: D30826388

Pulled By: navahgar

fbshipit-source-id: 7e446602f650527e0d954e437f0370602019e040

test/cpp/tensorexpr/test_loopnest.cpp
torch/csrc/jit/tensorexpr/eval.cpp
torch/csrc/jit/tensorexpr/eval.h
torch/csrc/jit/tensorexpr/loopnest.cpp

index b1d59a1..feae1ca 100644 (file)
@@ -1663,6 +1663,47 @@ TEST(LoopNest, ScheduleInlineOutputTensors) {
 # CHECK:       y[m2, n2, k2] = (k2 * m2) * n2 + m2;)IR");
 }
 
+TEST(LoopNest, ScheduleInlineBufferIndicesWithCast) {
+  // Input IR:
+  //     for (int64_t i = 0; i < 100; i++) {
+  //       A[0ll,i] = i * 500ll;
+  //     }
+  //     for (int64_t j = 0; j < 100; j++) {
+  //       B[0ll,j] = A[(int64_t)0, j] + j * 100ll;
+  //     }
+  BufHandle a_buf("A", {20, 100}, kLong);
+  BufHandle b_buf("B", {20, 100}, kLong);
+  VarHandle i("i", kLong);
+  VarHandle j("j", kLong);
+  auto forI = For::make(
+      i,
+      0,
+      100,
+      Store::make(
+          a_buf,
+          {static_cast<int64_t>(0), i},
+          Mul::make(i, static_cast<int64_t>(500))));
+  auto forJ = For::make(
+      j,
+      0,
+      100,
+      Store::make(
+          b_buf,
+          {static_cast<int64_t>(0), j},
+          Add::make(
+              Load::make(a_buf, {0, j}),
+              Mul::make(j, static_cast<int64_t>(100)))));
+  auto par = Block::make({forI, forJ});
+
+  LoopNest l(par, {b_buf.node()});
+  l.computeInline(a_buf.node());
+
+  checkIR(l.root_stmt(), R"IR(
+    # CHECK: for (int64_t j = 0; j < 100; j++) {
+    # CHECK:   B[0ll, j] = j * 500ll + j * 100ll;
+    # CHECK: })IR");
+}
+
 TEST(LoopNest, ScheduleFuserStyle) {
   const int kVectorSize = 8;
   const int kVectorCount = 128;
index 4582433..6baa8ff 100644 (file)
@@ -1062,6 +1062,16 @@ void SimpleIREvaluator::bindVar(VarPtr v, ExprPtr e) {
 Value SimpleIREvaluator::value() const {
   return impl_->value();
 }
+
+c10::optional<int64_t> evalInt(ExprPtr e) {
+  try {
+    return ExprEval<SimpleIREvaluator>(cast<int64_t>(ExprHandle(e)))
+        .value<int64_t>();
+  } catch (std::runtime_error& err) {
+    return c10::nullopt;
+  }
+}
+
 } // namespace tensorexpr
 } // namespace jit
 } // namespace torch
index e11bb16..9da1d38 100644 (file)
@@ -267,6 +267,10 @@ class ExprEval {
   Value ret_value_;
 };
 
+// Evaluates the given expression and returns an int64_t value if the result of
+// the given expression is int64_t.
+c10::optional<int64_t> evalInt(ExprPtr e);
+
 // Substitutes the given vars with their corresponding expressions in the input
 // expression.
 inline ExprPtr Substitute(ExprPtr expr, const VarMapping& var_mapping) {
index a957b29..0750b34 100644 (file)
@@ -573,14 +573,13 @@ class FunctionInliner : public IRMutator {
       VarPtr func_callee_arg = producer_index_vars_.at(i);
       ExprPtr func_caller_param = dims.at(i);
       if (func_callee_arg == nullptr) {
+        auto param_val = evalInt(func_caller_param);
         TORCH_INTERNAL_ASSERT(
-            intValue(func_caller_param) && *intValue(func_caller_param) == 0,
+            param_val && *param_val == 0,
             buildErrorMessage(
                 "We are implicitly assuming that if you have an index of 0, that must also be inlined into an index of 0"));
         continue;
       }
-      if (func_callee_arg == nullptr)
-        continue;
       auto iter = inline_mapping_.find(func_callee_arg);
       if (iter != inline_mapping_.end()) {
         throw std::logic_error(