From: Raghavan Raman Date: Thu, 9 Sep 2021 15:26:16 +0000 (-0700) Subject: [nnc] Handled cast in index expression during inlining (#64716) X-Git-Tag: accepted/tizen/8.0/unified/20231005.095509~337 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=b7c86365d11f6f9ec280c636e6abfec63b33104d;p=platform%2Fupstream%2Fpytorch.git [nnc] Handled cast in index expression during inlining (#64716) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/64716 Test Plan: Imported from OSS Reviewed By: ZolotukhinM Differential Revision: D30826388 Pulled By: navahgar fbshipit-source-id: 7e446602f650527e0d954e437f0370602019e040 --- diff --git a/test/cpp/tensorexpr/test_loopnest.cpp b/test/cpp/tensorexpr/test_loopnest.cpp index b1d59a1..feae1ca 100644 --- a/test/cpp/tensorexpr/test_loopnest.cpp +++ b/test/cpp/tensorexpr/test_loopnest.cpp @@ -1663,6 +1663,47 @@ TEST(LoopNest, ScheduleInlineOutputTensors) { # CHECK: y[m2, n2, k2] = (k2 * m2) * n2 + m2;)IR"); } +TEST(LoopNest, ScheduleInlineBufferIndicesWithCast) { + // Input IR: + // for (int64_t i = 0; i < 100; i++) { + // A[0ll,i] = i * 500ll; + // } + // for (int64_t j = 0; j < 100; j++) { + // B[0ll,j] = A[(int64_t)0, j] + j * 100ll; + // } + BufHandle a_buf("A", {20, 100}, kLong); + BufHandle b_buf("B", {20, 100}, kLong); + VarHandle i("i", kLong); + VarHandle j("j", kLong); + auto forI = For::make( + i, + 0, + 100, + Store::make( + a_buf, + {static_cast<int64_t>(0), i}, + Mul::make(i, static_cast<int64_t>(500)))); + auto forJ = For::make( + j, + 0, + 100, + Store::make( + b_buf, + {static_cast<int64_t>(0), j}, + Add::make( + Load::make(a_buf, {0, j}), + Mul::make(j, static_cast<int64_t>(100))))); + auto par = Block::make({forI, forJ}); + + LoopNest l(par, {b_buf.node()}); + l.computeInline(a_buf.node()); + + checkIR(l.root_stmt(), R"IR( + # CHECK: for (int64_t j = 0; j < 100; j++) { + # CHECK: B[0ll, j] = j * 500ll + j * 100ll; + # CHECK: })IR"); +} + TEST(LoopNest, ScheduleFuserStyle) { const int kVectorSize = 8; const int kVectorCount = 
128; diff --git a/torch/csrc/jit/tensorexpr/eval.cpp b/torch/csrc/jit/tensorexpr/eval.cpp index 4582433..6baa8ff 100644 --- a/torch/csrc/jit/tensorexpr/eval.cpp +++ b/torch/csrc/jit/tensorexpr/eval.cpp @@ -1062,6 +1062,16 @@ void SimpleIREvaluator::bindVar(VarPtr v, ExprPtr e) { Value SimpleIREvaluator::value() const { return impl_->value(); } + +c10::optional<int64_t> evalInt(ExprPtr e) { + try { + return ExprEval<SimpleIREvaluator>(cast<int64_t>(ExprHandle(e))) + .value<int64_t>(); + } catch (std::runtime_error& err) { + return c10::nullopt; + } +} + } // namespace tensorexpr } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/tensorexpr/eval.h b/torch/csrc/jit/tensorexpr/eval.h index e11bb16..9da1d38 100644 --- a/torch/csrc/jit/tensorexpr/eval.h +++ b/torch/csrc/jit/tensorexpr/eval.h @@ -267,6 +267,10 @@ class ExprEval { Value ret_value_; }; +// Evaluates the given expression and returns an int64_t value if the result of +// the given expression is int64_t. +c10::optional<int64_t> evalInt(ExprPtr e); + // Substitutes the given vars with their corresponding expressions in the input // expression. 
inline ExprPtr Substitute(ExprPtr expr, const VarMapping& var_mapping) { diff --git a/torch/csrc/jit/tensorexpr/loopnest.cpp b/torch/csrc/jit/tensorexpr/loopnest.cpp index a957b29..0750b34 100644 --- a/torch/csrc/jit/tensorexpr/loopnest.cpp +++ b/torch/csrc/jit/tensorexpr/loopnest.cpp @@ -573,14 +573,13 @@ class FunctionInliner : public IRMutator { VarPtr func_callee_arg = producer_index_vars_.at(i); ExprPtr func_caller_param = dims.at(i); if (func_callee_arg == nullptr) { + auto param_val = evalInt(func_caller_param); TORCH_INTERNAL_ASSERT( - intValue(func_caller_param) && *intValue(func_caller_param) == 0, + param_val && *param_val == 0, buildErrorMessage( "We are implicitly assuming that if you have an index of 0, that must also be inlined into an index of 0")); continue; } - if (func_callee_arg == nullptr) - continue; auto iter = inline_mapping_.find(func_callee_arg); if (iter != inline_mapping_.end()) { throw std::logic_error(