}
}
+TEST_F(Kernel, ParallelStrided) {
+ KernelScope kernel_scope;
+
+ // Graph with one contiguous input (%0) and one non-contiguous input
+ // (%1: stride 2 in every dimension); sized large (5*3*40005 elements)
+ // so the fused kernel has enough work to take the parallel path.
+ const auto graph_string = R"IR(
+ graph(%0 : Float(5, 3, 40005, strides=[120015, 40005, 1], device=cpu),
+ %1 : Float(5, 3, 40005, strides=[960120, 160020, 2], device=cpu)):
+ %2 : Float(5, 3, 40005, strides=[120015, 40005, 1]) = aten::mul(%0, %1)
+ %3 : Float(5, 3, 40005, strides=[120015, 40005, 1]) = aten::mul(%0, %2)
+ return (%3))IR";
+ auto graph = std::make_shared<Graph>();
+ parseIR(graph_string, &*graph);
+
+ auto a = at::rand({5, 3, 40005}, TensorOptions(kCPU).dtype(at::kFloat));
+ // Build b as an every-other-element slice of a double-sized tensor so its
+ // strides match the %1 declaration above.
+ auto b = at::rand({10, 6, 80010}, TensorOptions(kCPU).dtype(at::kFloat))
+ .index(
+ {Slice(None, None, 2),
+ Slice(None, None, 2),
+ Slice(None, None, 2)});
+ auto ref = a * (a * b);
+ auto o = at::zeros_like(ref);
+ TensorExprKernel k(graph);
+ std::vector<at::Tensor> inputs = {a, b};
+ std::vector<IValue> stack = fmap<IValue>(inputs);
+ k.run(stack);
+ o = stack[0].toTensor();
+ // Spot-check the first 5*3 elements against the eager-mode reference
+ // (both tensors are contiguous, so flat indexing lines up).
+ for (size_t i = 0; i < 5 * 3; i++) {
+ CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]);
+ }
+}
+
TEST_F(Kernel, DISABLED_Shape_Inference) {
// disabled: doesn't do stride propagation, and isn't being used currently
#include <torch/csrc/jit/tensorexpr/kernel.h>
#include <ATen/ExpandUtils.h>
+#include <ATen/Parallel.h>
#include <ATen/TensorGeometry.h>
#include <c10/util/irange.h>
#include <c10/util/string_utils.h>
}
}
+// Compute the trip count of a loop if it is a constant.
+// Simplifies (stop - start) and returns its value when it folds to an
+// integral constant; c10::nullopt otherwise. Marked static for internal
+// linkage, consistent with the other file-local pruning helpers below.
+static c10::optional<int64_t> tripCount(ForPtr loop) {
+ auto tc = IRSimplifier::simplify(
+ cast<int64_t>(ExprHandle(loop->stop()) - ExprHandle(loop->start())));
+ if (auto val = to<LongImm>(tc.node())) {
+ return val->value();
+ }
+ return c10::nullopt;
+}
+
+// Prune innermost loops until iterations satisfies a minimum grain size.
+static void pruneByGrainSize(std::vector<ForPtr>& loops) {
+ constexpr int64_t minGrainSize = 32768;
+ int64_t grainSize = 1;
+ // Walk from the innermost loop outward, accumulating the product of
+ // constant trip counts seen so far.
+ for (int64_t i = loops.size(); i > 0; i--) {
+ auto tc = tripCount(loops[i - 1]);
+ if (!tc) {
+ // Non-constant trip count: cannot reason about grain size beyond here.
+ break;
+ }
+ grainSize *= *tc;
+ // While the accumulated work is still below the threshold, drop the
+ // innermost remaining loop (pop_back only removes elements at indices
+ // >= i-1, so the shrinking vector is safe to index on later iterations).
+ if (grainSize < minGrainSize) {
+ loops.pop_back();
+ }
+ }
+}
+
+// Retain enough outermost loops to fill the number of threads.
+// Walks outer loops inward, keeping loops until their combined constant
+// trip count covers at::get_num_threads(); everything further in is erased
+// from the candidate list (a non-constant trip count also stops the walk).
+static void pruneByThreadCount(std::vector<ForPtr>& loops) {
+ const auto threads = at::get_num_threads();
+ int64_t trips = 1;
+ auto keep = loops.begin();
+ while (keep != loops.end() && trips < threads) {
+ auto tc = tripCount(*keep);
+ if (!tc) {
+ break;
+ }
+ trips *= *tc;
+ ++keep;
+ }
+ loops.erase(keep, loops.end());
+}
+
+// Flatten and parallelize outer loops, subject to a minimum number of elements
+// in the inner loop, and a maximum level of thread-level parallelism in the
+// outer loops.
+template <typename Bufs>
+static void parallelizeOuterLoops(LoopNest& l, Bufs&& bufs) {
+ for (auto const& buf : bufs) {
+ auto loops = l.getLoopStmtsFor(buf);
+ pruneByGrainSize(loops);
+ pruneByThreadCount(loops);
+
+ // Nothing survived pruning; nothing to parallelize.
+ if (loops.empty()) {
+ continue;
+ }
+ // Skip nests containing reductions.
+ if (!NodeFinder<ReduceOp>::find(loops[0]).empty()) {
+ continue;
+ }
+ // Skip nests whose iterations are not independent.
+ if (LoopNest::hasLoopCarriedDependence(loops[0])) {
+ continue;
+ }
+ // Collapse the surviving loops into one and mark it parallel.
+ ForPtr flattened = nullptr;
+ if (loops.size() == 1) {
+ flattened = loops[0];
+ } else {
+ LoopNest::flatten(loops, &flattened);
+ }
+ if (flattened) {
+ flattened->set_parallel();
+ }
+ }
+}
+
StmtPtr TensorExprKernel::transformLoops(BackendType backendType, StmtPtr st) {
torch::jit::tensorexpr::LoopNest l(st, bufOutputs_);
GRAPH_DEBUG("Original Stmt:\n", std::to_string(l.root_stmt()), "\n");
if (backendType == kLLVMCodeGen) {
fuseAllLoops(l.root_stmt());
GRAPH_DEBUG("after fuse", *l.root_stmt());
+ parallelizeOuterLoops(l, bufOutputs_);
+ GRAPH_DEBUG("after parallelize", *l.root_stmt());
}
if (backendType == kCudaCodeGen) {
}
l.prepareForCodegen();
+ GRAPH_DEBUG("after prepareForCodegen", *l.root_stmt());
+ l.simplify();
+ GRAPH_DEBUG("after simplification", *l.root_stmt());
if (backendType == kLLVMCodeGen && !hasReduction) {
l.vectorizeInnerLoops();
+ GRAPH_DEBUG("after vectorization", *l.root_stmt());
}
StmtPtr stmt = l.root_stmt();
}
};
+extern "C" {
typedef void (*ParallelCallee)(int index, int8_t* packed_data);
-void DispatchParallel(int8_t* func, int start, int stop, int8_t* packed_data) {
+// Trampoline invoked from LLVM-JITed code to run a parallel loop.
+// `func` is the JITed per-iteration callee (cast below) and `packed_data`
+// carries its packed arguments. Declared noexcept because the generated
+// call site is marked nounwind and cannot unwind through this frame.
+void DispatchParallel(
+ int8_t* func,
+ int start,
+ int stop,
+ int8_t* packed_data) noexcept {
 // TODO: preserve the func type.
- ParallelCallee callee = reinterpret_cast<ParallelCallee>(func);
- at::parallel_for(start, stop, 1, [&](int64_t f_begin, int64_t f_end) {
- for (int index = f_begin; index < f_end; index++) {
- callee(index, packed_data);
- }
- });
+ try {
+ ParallelCallee callee = reinterpret_cast<ParallelCallee>(func);
+ // Grain size 1: let at::parallel_for partition [start, stop) freely.
+ at::parallel_for(start, stop, 1, [&](int64_t f_begin, int64_t f_end) {
+ for (int index = f_begin; index < f_end; index++) {
+ callee(index, packed_data);
+ }
+ });
+ } catch (...) {
+ // Deliberate swallow: an exception escaping a noexcept function would
+ // call std::terminate.
+ }
+}
}
} // namespace tensorexpr
module_->getOrInsertFunction("DispatchParallel", dispatcher_fntype);
llvm::Function* dispatcher =
llvm::cast<llvm::Function>(dispatcher_callee.getCallee());
+ dispatcher->addFnAttr(llvm::Attribute::NoUnwind);
irb_.CreateCall(
dispatcher, {func_value, start, stop, packed_caller_args_ptr});
value_ = llvm::ConstantInt::get(IntTy_, 0);
namespace jit {
namespace tensorexpr {
-void DispatchParallel(int8_t* func, int start, int stop, int8_t* packed_data);
+extern "C" {
+void DispatchParallel(
+ int8_t* func,
+ int start,
+ int stop,
+ int8_t* packed_data) noexcept;
+}
inline std::string formatError(llvm::Error&& err, const char* msg) {
static constexpr char* defaultErrorMsg = "Unexpected failure in LLVM JIT";
});
}
+ // Vectorized lowering of Mod: collect both operands and, when
+ // try_vectorize succeeds, rebuild the expression as lhs % rhs over the
+ // (possibly widened) operands. Mirrors the sibling binary-op overloads.
+ ExprPtr mutate(ModPtr v) override {
+ std::vector<ExprPtr> inputs = {v->lhs(), v->rhs()};
+ return try_vectorize(v, inputs, [&]() {
+ return ExprHandle(inputs[0]) % ExprHandle(inputs[1]);
+ });
+ }
+
ExprPtr mutate(AndPtr v) override {
std::vector<ExprPtr> inputs = {v->lhs(), v->rhs()};
return try_vectorize(v, inputs, [&]() {