From 9e1f4ba1244c196d3e1d94c3142c17c4906fdda5 Mon Sep 17 00:00:00 2001 From: Gregory Chanan Date: Tue, 4 Dec 2018 14:41:03 -0800 Subject: [PATCH] Introduce LegacyTHDispatcher for dispatching to TH functions. (#14754) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/14754 This isn't hooked up to anything yet, this is just putting the skeleton in place. The idea here is that the functions generated via Declarations.cwrap and nn.yaml are not actually operators, they are implementation details of operators, and thus don't need to participate in VariableType, JIT dispatch generation. So, we will split these functions out from the usual Type/operator hierarchy; for now the dispatch will be done by a Type-like class called LegacyTHDispatcher. Once this is done this probably means we can collapse Type to be backend-specific, not Type/ScalarType specific, because all the ScalarType specific code will live in the LegacyTHDispatcher. Reviewed By: ezyang Differential Revision: D13321605 fbshipit-source-id: 25d1bbc9827a42d6ab5d69aabbad3eac72bf364c --- aten/src/ATen/Context.cpp | 11 ++++ aten/src/ATen/Context.h | 13 ++++ aten/src/ATen/LegacyTHDispatch.cpp | 12 ++++ aten/src/ATen/LegacyTHDispatch.h | 91 ++++++++++++++++++++++++++++ aten/src/ATen/gen.py | 5 ++ aten/src/ATen/templates/LegacyTHDispatcher.h | 18 ++++++ 6 files changed, 150 insertions(+) create mode 100644 aten/src/ATen/LegacyTHDispatch.cpp create mode 100644 aten/src/ATen/LegacyTHDispatch.h create mode 100644 aten/src/ATen/templates/LegacyTHDispatcher.h diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp index 296dc23..c94336e 100644 --- a/aten/src/ATen/Context.cpp +++ b/aten/src/ATen/Context.cpp @@ -111,6 +111,17 @@ TypeExtendedInterface& getType(const Tensor& t) { return getType(t.unsafeGetTensorImpl()); } +LegacyTHDispatcher& getLegacyTHDispatcher(TensorOptions options) { + return globalContext().getLegacyTHDispatcher( + options.backend(), typeMetaToScalarType(options.dtype())); +} + +LegacyTHDispatcher& getLegacyTHDispatcher(const TensorImpl* impl) { + Backend backend = tensorTypeIdToBackend(impl->type_id()); + return globalContext().getLegacyTHDispatcher( + backend, typeMetaToScalarType(impl->dtype())); +} + Allocator* getCPUAllocator() { return getTHDefaultAllocator(); } diff --git a/aten/src/ATen/Context.h b/aten/src/ATen/Context.h index 01e408f..5456a85 100644 --- a/aten/src/ATen/Context.h +++ b/aten/src/ATen/Context.h @@ -4,6 +4,8 @@ #include "ATen/Type.h" #include "ATen/TypeExtendedInterface.h" #include "ATen/Utils.h" +#include "ATen/LegacyTHDispatch.h" +#include "ATen/LegacyTHDispatcher.h" #include "ATen/core/ATenGeneral.h" #include "ATen/core/Generator.h" #include "ATen/core/LegacyTypeDispatch.h" @@ -39,6 +41,9 @@ class CAFFE2_API Context { TypeExtendedInterface & getType(Backend p, ScalarType s, bool is_variable) { return static_cast(globalLegacyTypeDispatch().getType(p, s, is_variable)); } + LegacyTHDispatcher& getLegacyTHDispatcher(Backend p, ScalarType s) { + return globalLegacyTHDispatch().getLegacyTHDispatcher(p, s); + } // The passed in Type must be delete'able // TODO: Just make it take a unique_ptr void registerType(Backend b, ScalarType s, Type* t) { @@ -46,6 +51,11 @@ class CAFFE2_API Context { LegacyTypeDispatch::TypeUniquePtr{t, LegacyTypeDeleter([](Type* p) { delete p; }) }); } + void registerLegacyTHDispatcher(Backend b, ScalarType s, LegacyTHDispatcher* t) { + globalLegacyTHDispatch().registerDispatcher(b, s, + LegacyTHDispatch::LegacyTHDispatcherUniquePtr{t, LegacyTHDispatcherDeleter([](LegacyTHDispatcher* p) { delete p; }) }); + } + Generator & defaultGenerator(DeviceType device_type) { initCUDAIfNeeded(device_type); initHIPIfNeeded(device_type); @@ -182,6 +192,9 @@ static inline TypeExtendedInterface& HIP(ScalarType s) { return getNonVariableType(Backend::HIP, s); } +CAFFE2_API LegacyTHDispatcher& getLegacyTHDispatcher(TensorOptions options); +CAFFE2_API LegacyTHDispatcher& getLegacyTHDispatcher(const Tensor&); + static inline bool hasCUDA() { return globalContext().hasCUDA(); } diff --git a/aten/src/ATen/LegacyTHDispatch.cpp b/aten/src/ATen/LegacyTHDispatch.cpp new file mode 100644 index 0000000..6223874 --- /dev/null +++ b/aten/src/ATen/LegacyTHDispatch.cpp @@ -0,0 +1,12 @@ +#include + +namespace at { + +// TODO: This could be bad juju if someone calls globalContext() in the +// destructor of an object with static lifetime. +LegacyTHDispatch & globalLegacyTHDispatch() { + static LegacyTHDispatch singleton; + return singleton; +} + +} diff --git a/aten/src/ATen/LegacyTHDispatch.h b/aten/src/ATen/LegacyTHDispatch.h new file mode 100644 index 0000000..b21184b --- /dev/null +++ b/aten/src/ATen/LegacyTHDispatch.h @@ -0,0 +1,91 @@ +#pragma once + +// LegacyTHDispatcher is the legacy mechanism for dispatching directly +// to TH/THNN/THC/THCUNN functions in ATen, which is essentially a giant virtual +// dispatch table for every TH function we support dynamically dispatching over. +// +// NB: We do not actually dispatch to *operators* here, the usual pattern is for +// ATen operators to call this mechanism for their implementation, but the +// operator itself is declared separately (e.g. as a native function "wrapper"). +// +// Q: Why don't we just use LegacyTypeDispatch here? +// A: Mainly separation of concerns: +// 1) Type is for implementation of operators, which requires codegen of +// Variables, JIT, etc. That is handled by the native function "wrappers"; +// just calling into TH does not require that. +// 2) Type does not require scalar-specific dispatch, whereas calling into TH +// does. Thus, this separation allows us to evolve operator dispatch +// separately (i.e. to use the C10 dispatcher) from details of how to +// call TH functionality. +// +// The implmentation here is very similar to the LegacyTypeDispatch design, with +// the following simplications: +// 1) This is not required for a mobile build, so does not have to live in /core. +// 2) Because these only contain function implementations, we do not have to +// handle the Variable/Tensor split; that is handled at the native function +// "wrapper" level. +// 3) Because an operator must have been previously dispatched via the Type +// mechanism, we do need to handle device initialization. This means it is +// WRONG to call directly into these functions without first going through +// Type dispatch (i.e. the usual operator -> Type -> LegacyTHDispatch pattern). +// 4) Because an operator must have been previously dispatched via the Type +// mechanism, we do not need to handle undefined Tensors. +// +// NB: We don't use Registry for this, because we don't want to +// pay for a hash table lookup every time we do an operation. +// +// NB: we can delete this when we don't call into any TH implementations. + +#include +#include +#include + +namespace at { + +struct Type; + +struct CAFFE2_API LegacyTHDispatcherDeleter { + using LegacyTHDispatcherDeleterFun = void(LegacyTHDispatcher*); + LegacyTHDispatcherDeleterFun *fn_ = nullptr; + LegacyTHDispatcherDeleter() {} + /* implicit */ LegacyTHDispatcherDeleter(LegacyTHDispatcherDeleterFun *fn) : fn_(fn) {} + void operator()(LegacyTHDispatcher * ptr) { + if (fn_) { + (*fn_)(ptr); + } + } +}; + +class CAFFE2_API LegacyTHDispatch { + public: + using LegacyTHDispatcherUniquePtr = std::unique_ptr; + // WARNING: This function has the precondition that you have + // initialized the type you want to call. This initialization + // step is generally done by Context, or assumed because you + // have a Tensor and thus the Type of that Tensor must already + // be initialized. + + void registerDispatcher(Backend b, ScalarType s, LegacyTHDispatcherUniquePtr&& t) { + dispatcher_registry[static_cast(b)][static_cast(s)] = std::move(t); + } + + LegacyTHDispatcher* getLegacyTHDispatcherRaw(Backend p, ScalarType s) { + return dispatcher_registry[static_cast(p)][static_cast(s)].get(); + } + + LegacyTHDispatcher & getLegacyTHDispatcher(Backend p, ScalarType s) { + auto* type = getLegacyTHDispatcherRaw(p, s); + if (!type) AT_ERROR(toString(p), toString(s), "THDispatcher is not enabled."); + return *type; + } +private: + // NB: dispatcher_registry has nullptr for all CUDA backends until + // CUDA initialization has occurred + LegacyTHDispatcherUniquePtr dispatcher_registry + [static_cast(Backend::NumOptions)] + [static_cast(ScalarType::NumOptions)]; +}; + +CAFFE2_API LegacyTHDispatch& globalLegacyTHDispatch(); + +} // namespace at diff --git a/aten/src/ATen/gen.py b/aten/src/ATen/gen.py index 2898d2e..9f463f6 100644 --- a/aten/src/ATen/gen.py +++ b/aten/src/ATen/gen.py @@ -119,6 +119,8 @@ TYPE_EXTENDED_INTERFACE_H = CodeTemplate.from_file(TEMPLATE_PATH + "/TypeExtende TYPE_DEFAULT_H = CodeTemplate.from_file(TEMPLATE_PATH + "/TypeDefault.h") TYPE_DEFAULT_CPP = CodeTemplate.from_file(TEMPLATE_PATH + "/TypeDefault.cpp") +LEGACY_TH_DISPATCHER_H = CodeTemplate.from_file(TEMPLATE_PATH + "/LegacyTHDispatcher.h") + REGISTER_CPU_H = CodeTemplate.from_file(TEMPLATE_PATH + "/RegisterCPU.h") REGISTER_CPU_CPP = CodeTemplate.from_file(TEMPLATE_PATH + "/RegisterCPU.cpp") @@ -344,6 +346,7 @@ def declare_outputs(): for f in core_files: core_file_manager.will_write(f) files = ['Declarations.yaml', 'TypeExtendedInterface.h', 'TypeDefault.cpp', 'TypeDefault.h', + 'LegacyTHDispatcher.h', 'Functions.h', 'NativeFunctions.h', 'RegisterCPU.cpp', 'RegisterCPU.h'] for f in files: file_manager.will_write(f) @@ -423,6 +426,8 @@ def generate_outputs(): file_manager.write('TypeDefault.h', TYPE_DEFAULT_H, top_env) file_manager.write('TypeDefault.cpp', TYPE_DEFAULT_CPP, top_env) + file_manager.write('LegacyTHDispatcher.h', LEGACY_TH_DISPATCHER_H, top_env) + file_manager.write('RegisterCPU.h', REGISTER_CPU_H, top_env) file_manager.write('RegisterCPU.cpp', REGISTER_CPU_CPP, top_env) diff --git a/aten/src/ATen/templates/LegacyTHDispatcher.h b/aten/src/ATen/templates/LegacyTHDispatcher.h new file mode 100644 index 0000000..73f1f49 --- /dev/null +++ b/aten/src/ATen/templates/LegacyTHDispatcher.h @@ -0,0 +1,18 @@ +#pragma once + +#include + +namespace at { + +struct CAFFE2_API LegacyTHDispatcher { + explicit LegacyTHDispatcher(TensorTypeId type_id, bool is_undefined) + : type_id_(type_id) {} + + virtual ~LegacyTHDispatcher() {} + +protected: + TensorTypeId type_id_; +}; + +} // namespace th + -- 2.7.4