# CHECK: y[m2, n2, k2] = (k2 * m2) * n2 + m2;)IR");
}
+TEST(LoopNest, ScheduleInlineBufferIndicesWithCast) {
+ // Input IR:
+ //   for (int64_t i = 0; i < 100; i++) {
+ //     A[0ll, i] = i * 500ll;
+ //   }
+ //   for (int64_t j = 0; j < 100; j++) {
+ //     B[0ll, j] = A[(int64_t)0, j] + j * 100ll;
+ //   }
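+ //
+ // Note: the literal 0 in Load::make(a_buf, {0, j}) below is an int
+ // immediate while j is int64_t, so index promotion wraps it in a cast,
+ // which is why the load reads A[(int64_t)0, j]. This test verifies that
+ // computeInline still recognizes that casted index as the constant 0.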
+ BufHandle a_buf("A", {20, 100}, kLong);
+ BufHandle b_buf("B", {20, 100}, kLong);
+ VarHandle i("i", kLong);
+ VarHandle j("j", kLong);
+ auto forI = For::make(
+ i,
+ 0,
+ 100,
+ Store::make(
+ a_buf,
+ {static_cast<int64_t>(0), i},
+ Mul::make(i, static_cast<int64_t>(500))));
+ auto forJ = For::make(
+ j,
+ 0,
+ 100,
+ Store::make(
+ b_buf,
+ {static_cast<int64_t>(0), j},
+ Add::make(
+ Load::make(a_buf, {0, j}),
+ Mul::make(j, static_cast<int64_t>(100)))));
+ auto par = Block::make({forI, forJ});
+
+ LoopNest l(par, {b_buf.node()});
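+
+ // Inlining A into B requires folding the casted (int64_t)0 load index
+ // down to the constant 0 so the producer store can be substituted.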
+ l.computeInline(a_buf.node());
+
+ checkIR(l.root_stmt(), R"IR(
+ # CHECK: for (int64_t j = 0; j < 100; j++) {
+ # CHECK: B[0ll, j] = j * 500ll + j * 100ll;
+ # CHECK: })IR");
+}
+
TEST(LoopNest, ScheduleFuserStyle) {
const int kVectorSize = 8;
const int kVectorCount = 128;
VarPtr func_callee_arg = producer_index_vars_.at(i);
ExprPtr func_caller_param = dims.at(i);
if (func_callee_arg == nullptr) {
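+ // producer_index_vars_ holds nullptr for dimensions that the producer
+ // indexed with a constant 0. Use evalInt rather than intValue here:
+ // intValue only matches bare immediates, while evalInt also folds
+ // constant expressions such as a cast around an immediate, e.g. the
+ // (int64_t)0 that index promotion produces.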
+ auto param_val = evalInt(func_caller_param);
TORCH_INTERNAL_ASSERT(
- intValue(func_caller_param) && *intValue(func_caller_param) == 0,
+ param_val && *param_val == 0,
buildErrorMessage(
"We are implicitly assuming that if you have an index of 0, that must also be inlined into an index of 0"));
continue;
}
- if (func_callee_arg == nullptr)
- continue;
auto iter = inline_mapping_.find(func_callee_arg);
if (iter != inline_mapping_.end()) {
throw std::logic_error(