#endif
}
+PTThreadPool::PTThreadPool(
+ std::size_t pool_size,
+ int numa_node_id)
+ : c10::ThreadPool(pool_size, numa_node_id) {}
+
+void PTThreadPool::init_thread() {
+ c10::setThreadName("PTThreadPool");
+ at::init_num_threads();
+}
+
+namespace {
+
+std::shared_ptr<TaskThreadPoolBase> createC10ThreadPool(
+ int device_id,
+ int pool_size,
+ bool create_new) {
+ static std::shared_ptr<TaskThreadPoolBase> pool =
+ std::make_shared<PTThreadPool>(pool_size);
+ // For now, the only accepted device id is 0
+ // for the JIT inter-op pool (CPU),
+ AT_ASSERT(device_id == 0);
+ // we use the shared thread pool
+ AT_ASSERT(!create_new);
+ // and the size does not change
+ AT_ASSERT(pool->size() == pool_size);
+ return pool;
}
+
+} // namespace
+
+C10_REGISTER_CREATOR(ThreadPoolRegistry, C10, createC10ThreadPool);
+
+} // namespace at
#pragma once
#include <ATen/ATen.h>
+#include <c10/core/thread_pool.h>
+
#include <atomic>
#include <cstddef>
#include <exception>
}
// Called during new thread initialization
-C10_API void init_num_threads();
+CAFFE2_API void init_num_threads();
// Sets the number of threads to be used in parallel region
-C10_API void set_num_threads(size_t);
+CAFFE2_API void set_num_threads(size_t);
// Returns the number of threads used in parallel region
-C10_API size_t get_num_threads();
+CAFFE2_API size_t get_num_threads();
// Returns the current thread number (starting from 0)
// in the current parallel region, or 0 in the sequential region
}
}
+class CAFFE2_API PTThreadPool : public c10::ThreadPool {
+ public:
+ explicit PTThreadPool(
+ std::size_t pool_size,
+ int numa_node_id = -1);
+
+ void init_thread() override;
+};
+
} // namespace at
int,
int,
bool);
-
-namespace {
-
-std::shared_ptr<TaskThreadPoolBase> createC10ThreadPool(
- int device_id,
- int pool_size,
- bool create_new) {
- static std::shared_ptr<TaskThreadPoolBase> pool =
- std::make_shared<ThreadPool>(pool_size);
- return pool;
-}
-
-} // namespace
-
-C10_REGISTER_CREATOR(ThreadPoolRegistry, C10, createC10ThreadPool);
-
} // namespace c10