Initialize intra-op threads in JIT thread pool (#19058)
authorIlia Cherniavskii <iliacher@fb.com>
Wed, 17 Apr 2019 01:21:36 +0000 (18:21 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 17 Apr 2019 01:27:22 +0000 (18:27 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19058
ghimport-source-id: 53e87df8d93459259854a17d4de3348e463622dc

Differential Revision: D14849624

Pulled By: ilia-cher

fbshipit-source-id: 5043a1d4330e38857c8e04c547526a3ba5b30fa9

aten/src/ATen/Parallel.cpp
aten/src/ATen/Parallel.h
c10/core/thread_pool.cpp

index 3345e09..b79bb99 100644 (file)
@@ -60,4 +60,36 @@ size_t get_num_threads() {
 #endif
 }
 
+PTThreadPool::PTThreadPool(
+    std::size_t pool_size,
+    int numa_node_id)
+    : c10::ThreadPool(pool_size, numa_node_id) {}
+
+void PTThreadPool::init_thread() {
+  c10::setThreadName("PTThreadPool");
+  at::init_num_threads();
+}
+
+namespace {
+
+std::shared_ptr<TaskThreadPoolBase> createC10ThreadPool(
+    int device_id,
+    int pool_size,
+    bool create_new) {
+  static std::shared_ptr<TaskThreadPoolBase> pool =
+      std::make_shared<PTThreadPool>(pool_size);
+  // For now, the only accepted device id is 0
+  // for the JIT inter-op pool (CPU),
+  AT_ASSERT(device_id == 0);
+  // we use the shared thread pool
+  AT_ASSERT(!create_new);
+  // and the size does not change
+  AT_ASSERT(pool->size() == pool_size);
+  return pool;
 }
+
+} // namespace
+
+C10_REGISTER_CREATOR(ThreadPoolRegistry, C10, createC10ThreadPool);
+
+} // namespace at
index 0d00a9e..d7a3b89 100644 (file)
@@ -1,5 +1,7 @@
 #pragma once
 #include <ATen/ATen.h>
+#include <c10/core/thread_pool.h>
+
 #include <atomic>
 #include <cstddef>
 #include <exception>
@@ -23,13 +25,13 @@ inline int64_t divup(int64_t x, int64_t y) {
 }
 
 // Called during new thread initialization
-C10_API void init_num_threads();
+CAFFE2_API void init_num_threads();
 
 // Sets the number of threads to be used in parallel region
-C10_API void set_num_threads(size_t);
+CAFFE2_API void set_num_threads(size_t);
 
 // Returns the number of threads used in parallel region
-C10_API size_t get_num_threads();
+CAFFE2_API size_t get_num_threads();
 
 // Returns the current thread number (starting from 0)
 // in the current parallel region, or 0 in the sequential region
@@ -141,4 +143,13 @@ inline scalar_t parallel_reduce(
   }
 }
 
+class CAFFE2_API PTThreadPool : public c10::ThreadPool {
+ public:
+  explicit PTThreadPool(
+      std::size_t pool_size,
+      int numa_node_id = -1);
+
+  void init_thread() override;
+};
+
 } // namespace at
index d1afaa1..cc13566 100644 (file)
@@ -140,20 +140,4 @@ C10_DEFINE_SHARED_REGISTRY(
     int,
     int,
     bool);
-
-namespace {
-
-std::shared_ptr<TaskThreadPoolBase> createC10ThreadPool(
-    int device_id,
-    int pool_size,
-    bool create_new) {
-  static std::shared_ptr<TaskThreadPoolBase> pool =
-      std::make_shared<ThreadPool>(pool_size);
-  return pool;
-}
-
-} // namespace
-
-C10_REGISTER_CREATOR(ThreadPoolRegistry, C10, createC10ThreadPool);
-
 } // namespace c10