TensorExprKernel k(graph);
std::ostringstream oss;
oss << *k.getCodeGenStmt();
- const std::string& verification_pattern = "# CHECK: 4000000000";
+ // The 4000000000-iteration loop will be split into 500000000 x 8, with the
+ // outer loop made parallel. If LLVM is not present, the loop is not split.
+ // To cover both cases we look for "00000000ll;" in the output: it is the
+ // suffix shared by the loop bounds 4000000000ll; and 500000000ll;.
+ const std::string& verification_pattern = R"IR(# CHECK: 00000000ll;)IR";
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}
const std::string& verification_pattern =
R"IR(
# CHECK: for
-# CHECK-NEXT: for
-# CHECK-NEXT: for
-# CHECK-NEXT: aten_cat
# CHECK: for
-# CHECK-NEXT: for
-# CHECK-NEXT: for
-# CHECK-NEXT: aten_cat
# CHECK: for
-# CHECK-NEXT: for
-# CHECK-NEXT: for
-# CHECK-NEXT: aten_cat)IR";
+# CHECK: aten_cat
+# CHECK: for
+# CHECK: for
+# CHECK: aten_cat
+# CHECK: for
+# CHECK: for
+# CHECK: aten_cat)IR";
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
auto a = at::rand({5, 3, 2}, TensorOptions(kCPU).dtype(at::kFloat));
}
loopsToFuse.push_back(loop);
}
+ if (loopsToFuse.empty()) {
+ return;
+ }
if (!loopBoundsAllEqual(loopsToFuse)) {
return;
}
auto root_stmt = l.root_stmt();
root_stmt->accept(block_analysis.get());
}
+ l.simplify();
+ GRAPH_DEBUG("after simplify", *l.root_stmt());
// Inlining output & intermediate buffers can duplicate computation.
// Duplicating work can slow down the program if it's not ameliorated in some
// cur_idx = absolute // stride
// absolute = absolute % stride
+ auto zero = LongImm::make(0);
return Compute(
"output_1", dims, [&](const std::vector<VarHandle>& axes_input) {
std::vector<ExprHandle> axes(axes_input.begin(), axes_input.end());
reverse_sort_indices(strides);
std::vector<ExprHandle> new_axes(sorted_stride_indices.size());
for (size_t stride_index : sorted_stride_indices) {
- auto stride = strides[stride_index];
auto size = sizes[stride_index];
- auto index = absolute_position /
- ExprHandle(immLike(absolute_position, stride));
+ auto index = zero;
if (size != 1) {
+ auto stride = strides[stride_index];
+ index = absolute_position /
+ ExprHandle(immLike(absolute_position, stride));
absolute_position = absolute_position %
ExprHandle(immLike(absolute_position, stride));
}
new_axes[stride_index] = index;
}
- // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
return BufHandle(buf).load(new_axes);
});
}
bool LoopNest::computeInline(BufPtr b) {
// If buf is used or defined in an ExternalCall, we cannot inline it
auto buf_load_store_uses = findLoadOrStoreUses(root_stmt_);
+ if (!buf_load_store_uses.count(b)) {
+ return false;
+ }
for (auto& use : buf_load_store_uses.at(b)) {
StmtPtr s = use.s;
if (to<ExternalCall>(s)) {