using namespace torch::jit::tensorexpr;
struct WithCPUFuser {
- WithCPUFuser(bool val = true)
- : cpuFuserEnabled(canFuseOnCPU()), parallel(texprParallelCPUEnabled()) {
+ WithCPUFuser(bool val = true) : cpuFuserEnabled(canFuseOnCPU()) {
overrideCanFuseOnCPU(val);
- setTexprParallelCPUEnabled(true);
}
~WithCPUFuser() {
overrideCanFuseOnCPU(cpuFuserEnabled);
- setTexprParallelCPUEnabled(parallel);
}
bool cpuFuserEnabled;
- bool parallel;
};
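// Usage sketch (illustrative, not part of this diff): WithCPUFuser is an RAII
// guard, so a test can enable CPU fusion for its own scope and rely on the
// destructor to restore the previous canFuseOnCPU() setting. The helper name
// below is hypothetical.
void exampleWithCPUFuserScope() {
  WithCPUFuser cf; // overrideCanFuseOnCPU(true) runs here
  // ... build a graph and run the TE fuser pass on it ...
} // previous canFuseOnCPU() value restored when `cf` goes out of scope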
TEST(TEFuserPass, FuserPass_1) {
torch._C._debug_set_fusion_group_inlining(False)
self.old_te_must_use_llvm_cpu = torch._C._jit_get_te_must_use_llvm_cpu()
torch._C._jit_set_te_must_use_llvm_cpu(False)
- self.old_fuse_parallel = torch._C._jit_texpr_parallel_cpu_enabled()
- torch._C._jit_set_texpr_parallel_cpu_enabled(True)
def tearDown(self):
torch._C._jit_set_profiling_executor(self.prev_exec)
torch._C._jit_set_texpr_reductions_enabled(self.old_reduction_enabled)
torch._C._debug_set_fusion_group_inlining(self.old_fusion_inlining)
torch._C._jit_set_te_must_use_llvm_cpu(self.old_te_must_use_llvm_cpu)
- torch._C._jit_set_texpr_parallel_cpu_enabled(self.old_fuse_parallel)
def test_tensor_type_not_determined_by_inputs(self):
@torch.jit.script
self.old_te_must_use_llvm_cpu = torch._C._jit_get_te_must_use_llvm_cpu()
torch._C._jit_set_te_must_use_llvm_cpu(False)
- # TODO: CPU fuser currently is disabled when multithreading.
- self.old_fuse_parallel = torch._C._jit_texpr_parallel_cpu_enabled()
- torch._C._jit_set_texpr_parallel_cpu_enabled(True)
-
self.devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
self.int_dtypes = [
torch.int8,
torch._C._jit_set_texpr_fuser_enabled(self.texpr_fuser_state)
torch._C._jit_set_te_must_use_llvm_cpu(self.old_te_must_use_llvm_cpu)
- torch._C._jit_set_texpr_parallel_cpu_enabled(self.old_fuse_parallel)
def assertLastGraphAllFused(self):
self.assertAllFused(torch.jit.last_executed_optimized_graph())
torch._C._debug_set_fusion_group_inlining(False)
self.old_te_must_use_llvm_cpu = torch._C._jit_get_te_must_use_llvm_cpu()
torch._C._jit_set_te_must_use_llvm_cpu(False)
- # TODO: CPU fuser currently is disabled when multithreading.
- self.old_fuse_parallel = torch._C._jit_texpr_parallel_cpu_enabled()
- torch._C._jit_set_texpr_parallel_cpu_enabled(True)
self.devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
torch._C._jit_override_can_fuse_on_cpu(self.old_cpu_fuser_state)
torch._C._debug_set_fusion_group_inlining(self.old_fusion_inlining)
torch._C._jit_set_te_must_use_llvm_cpu(self.old_te_must_use_llvm_cpu)
- torch._C._jit_set_texpr_parallel_cpu_enabled(self.old_fuse_parallel)
def assertLastGraphAllFused(self):
self.assertAllFused(torch.jit.last_executed_optimized_graph())
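# Usage sketch (illustrative, not part of this diff): a minimal fusion test that
# scripts a small elementwise function, warms up the profiling executor, and
# then checks that the last optimized graph was fully fused. The test name and
# shapes are placeholders.
def test_elementwise_fusion_sketch(self):
    def fn(a, b):
        return (a + b) * b
    scripted = torch.jit.script(fn)
    x = torch.randn(4, 4)
    y = torch.randn(4, 4)
    for _ in range(3):  # run a few times so the profiling executor optimizes
        scripted(x, y)
    self.assertLastGraphAllFused()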
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
-#include <ATen/Parallel.h>
#include <ATen/core/interned_strings.h>
#include <ATen/record_function.h>
#include <c10/util/FunctionRef.h>
} // namespace tensorexpr
static bool texpr_fuser_enabled_ = true;
-static bool texpr_parallel_cpu_enabled = false;
-
-bool texprParallelCPUEnabled() {
- return texpr_parallel_cpu_enabled;
-}
-
-void setTexprParallelCPUEnabled(bool val) {
- texpr_parallel_cpu_enabled = val;
-}
void setTensorExprFuserEnabled(bool val) {
texpr_fuser_enabled_ = val;
}
if (device->is_cpu()) {
- // CPU fusion is only supported for single-thread.
- if (!canFuseOnCPU()) {
- return false;
- }
- if (at::get_num_threads() == 1 || texprParallelCPUEnabled()) {
- return true;
- }
- return false;
+ return canFuseOnCPU();
} else if (device->is_cuda()) {
return canFuseOnGPU();
} else if (device->is_xpu()) {
TORCH_API bool tensorExprFuserEnabled();
TORCH_API bool setTexprReductionsEnabled(bool value);
TORCH_API bool texprReductionsEnabled();
-TORCH_API bool texprParallelCPUEnabled();
-TORCH_API void setTexprParallelCPUEnabled(bool val);
TORCH_API void RemoveProfileNodesAndSpecializeTypes(
std::shared_ptr<Graph>& graph);
.def("_jit_texpr_set_fallback_allowed", &tensorexpr::setFallbackAllowed)
.def("_jit_set_texpr_reductions_enabled", &setTexprReductionsEnabled)
.def("_jit_texpr_reductions_enabled", &texprReductionsEnabled)
- .def("_jit_set_texpr_parallel_cpu_enabled", &setTexprParallelCPUEnabled)
- .def("_jit_texpr_parallel_cpu_enabled", &texprParallelCPUEnabled)
.def(
"_jit_set_te_generate_block_code",
[](bool gen_block_code) {