From 9f35185b565a6dcd4f749ecc531abce35f0c1f0a Mon Sep 17 00:00:00 2001
From: Ilia Cherniavskii
Date: Tue, 16 Apr 2019 18:21:36 -0700
Subject: [PATCH] Initialize intra-op threads in JIT thread pool (#19058)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19058
ghimport-source-id: 53e87df8d93459259854a17d4de3348e463622dc

Differential Revision: D14849624

Pulled By: ilia-cher

fbshipit-source-id: 5043a1d4330e38857c8e04c547526a3ba5b30fa9
---
NOTE(review): line structure restored from the hunk headers and diffstat;
no tokens were changed. Angle-bracketed text (#include targets, the
template arguments of std::shared_ptr / std::make_shared, the author
e-mail) appears to have been stripped from this copy -- TODO restore it
from the upstream commit before applying.

 aten/src/ATen/Parallel.cpp | 32 ++++++++++++++++++++++++++++++++
 aten/src/ATen/Parallel.h   | 17 ++++++++++++++---
 c10/core/thread_pool.cpp   | 16 ----------------
 3 files changed, 46 insertions(+), 19 deletions(-)

diff --git a/aten/src/ATen/Parallel.cpp b/aten/src/ATen/Parallel.cpp
index 3345e09..b79bb99 100644
--- a/aten/src/ATen/Parallel.cpp
+++ b/aten/src/ATen/Parallel.cpp
@@ -60,4 +60,36 @@ size_t get_num_threads() {
 #endif
 }
 
+PTThreadPool::PTThreadPool(
+    std::size_t pool_size,
+    int numa_node_id)
+    : c10::ThreadPool(pool_size, numa_node_id) {}
+
+void PTThreadPool::init_thread() {
+  c10::setThreadName("PTThreadPool");
+  at::init_num_threads();
+}
+
+namespace {
+
+std::shared_ptr createC10ThreadPool(
+    int device_id,
+    int pool_size,
+    bool create_new) {
+  static std::shared_ptr pool =
+      std::make_shared(pool_size);
+  // For now, the only accepted device id is 0
+  // for the JIT inter-op pool (CPU),
+  AT_ASSERT(device_id == 0);
+  // we use the shared thread pool
+  AT_ASSERT(!create_new);
+  // and the size does not change
+  AT_ASSERT(pool->size() == pool_size);
+  return pool;
 }
+
+} // namespace
+
+C10_REGISTER_CREATOR(ThreadPoolRegistry, C10, createC10ThreadPool);
+
+} // namespace at
diff --git a/aten/src/ATen/Parallel.h b/aten/src/ATen/Parallel.h
index 0d00a9e..d7a3b89 100644
--- a/aten/src/ATen/Parallel.h
+++ b/aten/src/ATen/Parallel.h
@@ -1,5 +1,7 @@
 #pragma once
 #include
+#include
+
 #include
 #include
 #include
@@ -23,13 +25,13 @@ inline int64_t divup(int64_t x, int64_t y) {
 }
 
 // Called during new thread initialization
-C10_API void init_num_threads();
+CAFFE2_API void init_num_threads();
 
 // Sets the number of threads to be used in parallel region
-C10_API void set_num_threads(size_t);
+CAFFE2_API void set_num_threads(size_t);
 
 // Returns the number of threads used in parallel region
-C10_API size_t get_num_threads();
+CAFFE2_API size_t get_num_threads();
 
 // Returns the current thread number (starting from 0)
 // in the current parallel region, or 0 in the sequential region
@@ -141,4 +143,13 @@ inline scalar_t parallel_reduce(
   }
 }
 
+class CAFFE2_API PTThreadPool : public c10::ThreadPool {
+ public:
+  explicit PTThreadPool(
+      std::size_t pool_size,
+      int numa_node_id = -1);
+
+  void init_thread() override;
+};
+
 } // namespace at
diff --git a/c10/core/thread_pool.cpp b/c10/core/thread_pool.cpp
index d1afaa1..cc13566 100644
--- a/c10/core/thread_pool.cpp
+++ b/c10/core/thread_pool.cpp
@@ -140,20 +140,4 @@ C10_DEFINE_SHARED_REGISTRY(
     int,
     int,
     bool);
-
-namespace {
-
-std::shared_ptr createC10ThreadPool(
-    int device_id,
-    int pool_size,
-    bool create_new) {
-  static std::shared_ptr pool =
-      std::make_shared(pool_size);
-  return pool;
-}
-
-} // namespace
-
-C10_REGISTER_CREATOR(ThreadPoolRegistry, C10, createC10ThreadPool);
-
 } // namespace c10
-- 
2.7.4