}
}
+TEST_F(Kernel, ParallelStrided) {
+ // Compile a two-mul TE graph whose second input (%1) is declared with
+ // non-contiguous (doubled) strides, and check the kernel's output against
+ // eager-mode ATen on a real strided tensor produced by step-2 slicing.
+ const auto graph_string = R"IR(
+ graph(%0 : Float(5, 3, 40005, strides=[120015, 40005, 1], device=cpu),
+ %1 : Float(5, 3, 40005, strides=[960120, 160020, 2], device=cpu)):
+ %2 : Float(5, 3, 40005, strides=[120015, 40005, 1]) = aten::mul(%0, %1)
+ %3 : Float(5, 3, 40005, strides=[120015, 40005, 1]) = aten::mul(%0, %2)
+ return (%3))IR";
+ auto graph = std::make_shared<Graph>();
+ parseIR(graph_string, &*graph);
+
+ auto a = at::rand({5, 3, 40005}, TensorOptions(kCPU).dtype(at::kFloat));
+ // `b` views every other element of a tensor of twice the extent in each
+ // dim, which yields exactly the strides=[960120, 160020, 2] layout that
+ // %1 declares above.
+ auto b = at::rand({10, 6, 80010}, TensorOptions(kCPU).dtype(at::kFloat))
+ .index(
+ {Slice(None, None, 2),
+ Slice(None, None, 2),
+ Slice(None, None, 2)});
+ auto ref = a * (a * b);
+ auto o = at::zeros_like(ref);
+ TensorExprKernel k(graph);
+ std::vector<at::Tensor> inputs = {a, b};
+ std::vector<IValue> stack = fmap<IValue>(inputs);
+ k.run(stack);
+ o = stack[0].toTensor();
+ // Spot-checks only the first 5*3 floats of the flat buffer, not all
+ // 5*3*40005 elements — presumably a cheap sanity check; TODO confirm a
+ // full comparison was deliberately skipped.
+ for (size_t i = 0; i < 5 * 3; i++) {
+ CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]);
+ }
+}
+
TEST_F(Kernel, DISABLED_Shape_Inference) {
// disabled: doesn't do stride propagation, and isn't being used currently
using namespace torch::jit::tensorexpr;
struct WithCPUFuser {
+ // RAII guard: sets the CPU fuser to `val` for the enclosing scope and
+ // restores the previous setting on destruction. The separate
+ // parallel-CPU opt-in toggle is removed here — this patch appears to make
+ // CPU fusion independent of that flag (confirm against the fuser change).
- WithCPUFuser(bool val = true)
- : cpuFuserEnabled(canFuseOnCPU()), parallel(texprParallelCPUEnabled()) {
+ WithCPUFuser(bool val = true) : cpuFuserEnabled(canFuseOnCPU()) {
overrideCanFuseOnCPU(val);
- setTexprParallelCPUEnabled(true);
}
~WithCPUFuser() {
overrideCanFuseOnCPU(cpuFuserEnabled);
- setTexprParallelCPUEnabled(parallel);
}
+ // Saved prior fuser state, restored by the destructor.
bool cpuFuserEnabled;
- bool parallel;
};
TEST(TEFuserPass, FuserPass_1) {
torch._C._debug_set_fusion_group_inlining(False)
self.old_te_must_use_llvm_cpu = torch._C._jit_get_te_must_use_llvm_cpu()
torch._C._jit_set_te_must_use_llvm_cpu(False)
- self.old_fuse_parallel = torch._C._jit_texpr_parallel_cpu_enabled()
- torch._C._jit_set_texpr_parallel_cpu_enabled(True)
def tearDown(self):
torch._C._jit_set_profiling_executor(self.prev_exec)
torch._C._jit_set_texpr_reductions_enabled(self.old_reduction_enabled)
torch._C._debug_set_fusion_group_inlining(self.old_fusion_inlining)
torch._C._jit_set_te_must_use_llvm_cpu(self.old_te_must_use_llvm_cpu)
- torch._C._jit_set_texpr_parallel_cpu_enabled(self.old_fuse_parallel)
def test_tensor_type_not_determined_by_inputs(self):
@torch.jit.script
self.old_te_must_use_llvm_cpu = torch._C._jit_get_te_must_use_llvm_cpu()
torch._C._jit_set_te_must_use_llvm_cpu(False)
- # TODO: CPU fuser currently is disabled when multithreading.
- self.old_fuse_parallel = torch._C._jit_texpr_parallel_cpu_enabled()
- torch._C._jit_set_texpr_parallel_cpu_enabled(True)
-
self.devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
self.int_dtypes = [
torch.int8,
torch._C._jit_set_texpr_fuser_enabled(self.texpr_fuser_state)
torch._C._jit_set_te_must_use_llvm_cpu(self.old_te_must_use_llvm_cpu)
- torch._C._jit_set_texpr_parallel_cpu_enabled(self.old_fuse_parallel)
def assertLastGraphAllFused(self):
self.assertAllFused(torch.jit.last_executed_optimized_graph())
torch._C._debug_set_fusion_group_inlining(False)
self.old_te_must_use_llvm_cpu = torch._C._jit_get_te_must_use_llvm_cpu()
torch._C._jit_set_te_must_use_llvm_cpu(False)
- # TODO: CPU fuser currently is disabled when multithreading.
- self.old_fuse_parallel = torch._C._jit_texpr_parallel_cpu_enabled()
- torch._C._jit_set_texpr_parallel_cpu_enabled(True)
self.devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
torch._C._jit_override_can_fuse_on_cpu(self.old_cpu_fuser_state)
torch._C._debug_set_fusion_group_inlining(self.old_fusion_inlining)
torch._C._jit_set_te_must_use_llvm_cpu(self.old_te_must_use_llvm_cpu)
- torch._C._jit_set_texpr_parallel_cpu_enabled(self.old_fuse_parallel)
def assertLastGraphAllFused(self):
self.assertAllFused(torch.jit.last_executed_optimized_graph())
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
-#include <ATen/Parallel.h>
#include <ATen/core/interned_strings.h>
#include <ATen/record_function.h>
#include <c10/util/FunctionRef.h>
} // namespace tensorexpr
static bool texpr_fuser_enabled_ = true;
-static bool texpr_parallel_cpu_enabled = false;
-
-bool texprParallelCPUEnabled() {
- return texpr_parallel_cpu_enabled;
-}
-
-void setTexprParallelCPUEnabled(bool val) {
- texpr_parallel_cpu_enabled = val;
-}
void setTensorExprFuserEnabled(bool val) {
texpr_fuser_enabled_ = val;
return false;
}
if (device->is_cpu()) {
- // CPU fusion is only supported for single-thread.
- if (!canFuseOnCPU()) {
- return false;
- }
- if (at::get_num_threads() == 1 || texprParallelCPUEnabled()) {
- return true;
- }
- return false;
+ return canFuseOnCPU();
} else if (device->is_cuda()) {
return canFuseOnGPU();
} else if (device->is_xpu()) {
TORCH_API bool tensorExprFuserEnabled();
TORCH_API bool setTexprReductionsEnabled(bool value);
TORCH_API bool texprReductionsEnabled();
-TORCH_API bool texprParallelCPUEnabled();
-TORCH_API void setTexprParallelCPUEnabled(bool val);
TORCH_API void RemoveProfileNodesAndSpecializeTypes(
std::shared_ptr<Graph>& graph);
.def("_jit_texpr_set_fallback_allowed", &tensorexpr::setFallbackAllowed)
.def("_jit_set_texpr_reductions_enabled", &setTexprReductionsEnabled)
.def("_jit_texpr_reductions_enabled", &texprReductionsEnabled)
- .def("_jit_set_texpr_parallel_cpu_enabled", &setTexprParallelCPUEnabled)
- .def("_jit_texpr_parallel_cpu_enabled", &texprParallelCPUEnabled)
.def(
"_jit_set_te_generate_block_code",
[](bool gen_block_code) {
#include <torch/csrc/jit/tensorexpr/kernel.h>
#include <ATen/ExpandUtils.h>
+#include <ATen/Parallel.h>
#include <ATen/TensorGeometry.h>
#include <c10/util/irange.h>
#include <c10/util/string_utils.h>
}
}
+// Compute the trip count of a loop if it is a constant.
+// Returns nullopt when (stop - start) does not simplify to a compile-time
+// integer constant.
+c10::optional<int64_t> tripCount(ForPtr loop) {
+ // Fold (stop - start); the cast pins the expression to int64 so a
+ // constant result folds to a LongImm node.
+ auto tc = IRSimplifier::simplify(
+ cast<int64_t>(ExprHandle(loop->stop()) - ExprHandle(loop->start())));
+ if (auto val = to<LongImm>(tc.node())) {
+ return val->value();
+ }
+ // Bounds are not compile-time constant: trip count unknown.
+ return c10::nullopt;
+}
+
+// Prune innermost loops until iterations satisfies a minimum grain size.
+static void pruneByGrainSize(std::vector<ForPtr>& loops) {
+ // Minimum number of iterations worth parallelizing; nests smaller than
+ // this keep their inner loops serial.
+ constexpr int64_t minGrainSize = 32768;
+ int64_t grainSize = 1;
+ // Walk from the innermost loop outward, accumulating the product of
+ // constant trip counts; stop at the first loop with unknown bounds.
+ for (int64_t i = loops.size(); i > 0; i--) {
+ auto tc = tripCount(loops[i - 1]);
+ if (!tc) {
+ break;
+ }
+ grainSize *= *tc;
+ // While the accumulated work is still below the grain size, drop the
+ // current innermost candidate. pop_back matches the back-to-front
+ // iteration, so index i-1 stays valid on the next pass.
+ if (grainSize < minGrainSize) {
+ loops.pop_back();
+ }
+ }
+}
+
+// Retain enough outermost loops to fill the number of threads.
+static void pruneByThreadCount(std::vector<ForPtr>& loops) {
+ int64_t trips = 1;
+ auto threads = at::get_num_threads();
+ auto it = loops.begin();
+ // Keep outermost loops until their combined (constant) trip count covers
+ // the available threads, or until a loop's bound is not a constant.
+ for (; it != loops.end(); it++) {
+ if (trips >= threads) {
+ break;
+ }
+ auto tc = tripCount(*it);
+ if (!tc) {
+ break;
+ }
+ trips *= *tc;
+ }
+ // Trim the remaining loops from the candidate list only; the loops
+ // themselves remain in the IR untouched.
+ loops.erase(it, loops.end());
+}
+
+// Flatten and parallelize outer loops, subject to a minimum number of elements
+// in the inner loop, and a maximum level of thread-level parallelism in the
+// outer loops.
+// `bufs` is any iterable of output buffers; one loop nest is considered per
+// buffer.
+template <typename Bufs>
+static void parallelizeOuterLoops(LoopNest& l, Bufs&& bufs) {
+ for (auto const& buf : bufs) {
+ auto loops = l.getLoopStmtsFor(buf);
+ // Narrow the candidate set: enough inner work per thread, and no more
+ // outer loops than threads can use.
+ pruneByGrainSize(loops);
+ pruneByThreadCount(loops);
+
+ // There are no loops to parallelize; give up.
+ if (loops.size() == 0) {
+ continue;
+ }
+ // The loop nest contains a reduction; give up.
+ auto reductions = NodeFinder<ReduceOp>::find(loops[0]);
+ if (reductions.size() > 0) {
+ continue;
+ }
+ // The loop nest has loop carried dependences; give up.
+ if (LoopNest::hasLoopCarriedDependence(loops[0])) {
+ continue;
+ }
+ // Try to flatten the outer loops and parallelize them if successful.
+ // A single candidate needs no flattening; flatten() may leave
+ // `flattened` null on failure, in which case nothing is parallelized.
+ ForPtr flattened = nullptr;
+ if (loops.size() == 1) {
+ flattened = loops[0];
+ } else {
+ LoopNest::flatten(loops, &flattened);
+ }
+ if (flattened) {
+ flattened->set_parallel();
+ }
+ }
+}
+
StmtPtr TensorExprKernel::transformLoops(BackendType backendType, StmtPtr st) {
torch::jit::tensorexpr::LoopNest l(st, bufOutputs_);
GRAPH_DEBUG("Original Stmt:\n", std::to_string(l.root_stmt()), "\n");
if (backendType == kLLVMCodeGen) {
fuseAllLoops(l.root_stmt());
GRAPH_DEBUG("after fuse", *l.root_stmt());
+ parallelizeOuterLoops(l, bufOutputs_);
+ GRAPH_DEBUG("after parallelize", *l.root_stmt());
}
if (backendType == kCudaCodeGen) {
}
l.prepareForCodegen();
+ GRAPH_DEBUG("after prepareForCodegen", *l.root_stmt());
+ l.simplify();
+ GRAPH_DEBUG("after simplification", *l.root_stmt());
if (backendType == kLLVMCodeGen && !hasReduction) {
l.vectorizeInnerLoops();
+ GRAPH_DEBUG("after vectorization", *l.root_stmt());
}
StmtPtr stmt = l.root_stmt();
}
};
+extern "C" {
+// Trampoline that runs `func(index, packed_data)` for each index in
+// [start, stop) on ATen's intra-op thread pool. extern "C" keeps the symbol
+// unmangled — presumably so the JIT can resolve it by name; confirm against
+// the codegen that emits the call.
typedef void (*ParallelCallee)(int index, int8_t* packed_data);
-void DispatchParallel(int8_t* func, int start, int stop, int8_t* packed_data) {
+void DispatchParallel(
+ int8_t* func,
+ int start,
+ int stop,
+ int8_t* packed_data) noexcept {
// TODO: preserve the func type.
- ParallelCallee callee = reinterpret_cast<ParallelCallee>(func);
- at::parallel_for(start, stop, 1, [&](int64_t f_begin, int64_t f_end) {
- for (int index = f_begin; index < f_end; index++) {
- callee(index, packed_data);
- }
- });
+ try {
+ ParallelCallee callee = reinterpret_cast<ParallelCallee>(func);
+ at::parallel_for(start, stop, 1, [&](int64_t f_begin, int64_t f_end) {
+ for (int index = f_begin; index < f_end; index++) {
+ callee(index, packed_data);
+ }
+ });
+ } catch (...) {
+ // The function is declared noexcept, so no exception may escape.
+ // NOTE(review): errors from parallel_for/callee are silently dropped
+ // here — confirm failures are surfaced some other way.
+ }
+}
}
} // namespace tensorexpr
module_->getOrInsertFunction("DispatchParallel", dispatcher_fntype);
llvm::Function* dispatcher =
llvm::cast<llvm::Function>(dispatcher_callee.getCallee());
+ dispatcher->addFnAttr(llvm::Attribute::NoUnwind);
irb_.CreateCall(
dispatcher, {func_value, start, stop, packed_caller_args_ptr});
value_ = llvm::ConstantInt::get(IntTy_, 0);
namespace jit {
namespace tensorexpr {
-void DispatchParallel(int8_t* func, int start, int stop, int8_t* packed_data);
+// Dispatch `func(index, packed_data)` over [start, stop) in parallel.
+// extern "C" + noexcept: the symbol has unmangled C linkage and promises not
+// to unwind — presumably required by JIT-generated callers; confirm against
+// the codegen side.
+extern "C" {
+void DispatchParallel(
+ int8_t* func,
+ int start,
+ int stop,
+ int8_t* packed_data) noexcept;
+}
inline std::string formatError(llvm::Error&& err, const char* msg) {
static constexpr char* defaultErrorMsg = "Unexpected failure in LLVM JIT";
});
}
+ // Vectorize a modulo node: collect both operands, then rebuild
+ // lhs % rhs from the (possibly vectorized) operands via try_vectorize,
+ // mirroring the handling of the other binary ops in this mutator.
+ ExprPtr mutate(ModPtr v) override {
+ std::vector<ExprPtr> inputs = {v->lhs(), v->rhs()};
+ return try_vectorize(v, inputs, [&]() {
+ return ExprHandle(inputs[0]) % ExprHandle(inputs[1]);
+ });
+ }
+
ExprPtr mutate(AndPtr v) override {
std::vector<ExprPtr> inputs = {v->lhs(), v->rhs()};
return try_vectorize(v, inputs, [&]() {