return getType(t.unsafeGetTensorImpl());
}
+LegacyTHDispatcher& getLegacyTHDispatcher(TensorOptions options) {
+ return globalContext().getLegacyTHDispatcher(
+ options.backend(), typeMetaToScalarType(options.dtype()));
+}
+
+LegacyTHDispatcher& getLegacyTHDispatcher(const TensorImpl* impl) {
+ Backend backend = tensorTypeIdToBackend(impl->type_id());
+ return globalContext().getLegacyTHDispatcher(
+ backend, typeMetaToScalarType(impl->dtype()));
+}
+
// Exposes TH's default allocator as ATen's CPU allocator.
Allocator* getCPUAllocator() {
  auto* allocator = getTHDefaultAllocator();
  return allocator;
}
#include "ATen/Type.h"
#include "ATen/TypeExtendedInterface.h"
#include "ATen/Utils.h"
+#include "ATen/LegacyTHDispatch.h"
+#include "ATen/LegacyTHDispatcher.h"
#include "ATen/core/ATenGeneral.h"
#include "ATen/core/Generator.h"
#include "ATen/core/LegacyTypeDispatch.h"
// Looks up the Type for (backend, scalar type, variable-ness) in the legacy
// type dispatch table and upcasts it to the extended interface.
TypeExtendedInterface & getType(Backend p, ScalarType s, bool is_variable) {
  auto& base_type = globalLegacyTypeDispatch().getType(p, s, is_variable);
  return static_cast<TypeExtendedInterface&>(base_type);
}
+ LegacyTHDispatcher& getLegacyTHDispatcher(Backend p, ScalarType s) {
+ return globalLegacyTHDispatch().getLegacyTHDispatcher(p, s);
+ }
// The passed in Type must be delete'able
// TODO: Just make it take a unique_ptr
// NOTE(review): the body below appears truncated by the diff — the opening
// call (presumably globalLegacyTypeDispatch().registerType(b, s, ...)) is
// missing; only its final argument and the closing ');' survive.  Verify
// against the upstream file before editing this function.
void registerType(Backend b, ScalarType s, Type* t) {
  LegacyTypeDispatch::TypeUniquePtr{t, LegacyTypeDeleter([](Type* p) { delete p; }) });
}
+ void registerLegacyTHDispatcher(Backend b, ScalarType s, LegacyTHDispatcher* t) {
+ globalLegacyTHDispatch().registerDispatcher(b, s,
+ LegacyTHDispatch::LegacyTHDispatcherUniquePtr{t, LegacyTHDispatcherDeleter([](LegacyTHDispatcher* p) { delete p; }) });
+ }
+
// NOTE(review): this body looks like spliced diff context — it calls
// getNonVariableType(Backend::HIP, s) where 's' is not declared in this
// scope, and that call yields a Type, not the declared Generator&.
// Presumably the real implementation returns a per-device default generator
// from the context; reconstruct against the upstream file before trusting
// this fragment.
Generator & defaultGenerator(DeviceType device_type) {
  // Lazily bring up the CUDA/HIP backends before handing out a generator.
  initCUDAIfNeeded(device_type);
  initHIPIfNeeded(device_type);
  return getNonVariableType(Backend::HIP, s);
}
+CAFFE2_API LegacyTHDispatcher& getLegacyTHDispatcher(TensorOptions options);
+CAFFE2_API LegacyTHDispatcher& getLegacyTHDispatcher(const Tensor&);
+
// Convenience wrapper: queries the global context for CUDA availability.
static inline bool hasCUDA() {
  auto& ctx = globalContext();
  return ctx.hasCUDA();
}
--- /dev/null
+#include <ATen/LegacyTHDispatch.h>
+
+namespace at {
+
+// TODO: This could be bad juju if someone calls globalContext() in the
+// destructor of an object with static lifetime.
+LegacyTHDispatch & globalLegacyTHDispatch() {
+ static LegacyTHDispatch singleton;
+ return singleton;
+}
+
+}
--- /dev/null
+#pragma once
+
+// LegacyTHDispatcher is the legacy mechanism for dispatching directly
+// to TH/THNN/THC/THCUNN functions in ATen, which is essentially a giant virtual
+// dispatch table for every TH function we support dynamically dispatching over.
+//
+// NB: We do not actually dispatch to *operators* here, the usual pattern is for
+// ATen operators to call this mechanism for their implementation, but the
+// operator itself is declared separately (e.g. as a native function "wrapper").
+//
+// Q: Why don't we just use LegacyTypeDispatch here?
+// A: Mainly separation of concerns:
+// 1) Type is for implementation of operators, which requires codegen of
+// Variables, JIT, etc. That is handled by the native function "wrappers";
+// just calling into TH does not require that.
+// 2) Type does not require scalar-specific dispatch, whereas calling into TH
+// does. Thus, this separation allows us to evolve operator dispatch
+// separately (i.e. to use the C10 dispatcher) from details of how to
+// call TH functionality.
+//
// The implementation here is very similar to the LegacyTypeDispatch design, with
// the following simplifications:
+// 1) This is not required for a mobile build, so does not have to live in /core.
+// 2) Because these only contain function implementations, we do not have to
+// handle the Variable/Tensor split; that is handled at the native function
+// "wrapper" level.
// 3) Because an operator must have been previously dispatched via the Type
//    mechanism, we do not need to handle device initialization. This means it is
//    WRONG to call directly into these functions without first going through
//    Type dispatch (i.e. the usual operator -> Type -> LegacyTHDispatch pattern).
+// 4) Because an operator must have been previously dispatched via the Type
+// mechanism, we do not need to handle undefined Tensors.
+//
+// NB: We don't use Registry for this, because we don't want to
+// pay for a hash table lookup every time we do an operation.
+//
+// NB: we can delete this when we don't call into any TH implementations.
+
#include <memory>

#include <c10/core/Backend.h>
#include <c10/core/ScalarType.h>
#include <ATen/LegacyTHDispatcher.h>
+
+namespace at {
+
+struct Type;
+
+struct CAFFE2_API LegacyTHDispatcherDeleter {
+ using LegacyTHDispatcherDeleterFun = void(LegacyTHDispatcher*);
+ LegacyTHDispatcherDeleterFun *fn_ = nullptr;
+ LegacyTHDispatcherDeleter() {}
+ /* implicit */ LegacyTHDispatcherDeleter(LegacyTHDispatcherDeleterFun *fn) : fn_(fn) {}
+ void operator()(LegacyTHDispatcher * ptr) {
+ if (fn_) {
+ (*fn_)(ptr);
+ }
+ }
+};
+
+class CAFFE2_API LegacyTHDispatch {
+ public:
+ using LegacyTHDispatcherUniquePtr = std::unique_ptr<LegacyTHDispatcher, LegacyTHDispatcherDeleter>;
+ // WARNING: This function has the precondition that you have
+ // initialized the type you want to call. This initialization
+ // step is generally done by Context, or assumed because you
+ // have a Tensor and thus the Type of that Tensor must already
+ // be initialized.
+
+ void registerDispatcher(Backend b, ScalarType s, LegacyTHDispatcherUniquePtr&& t) {
+ dispatcher_registry[static_cast<int>(b)][static_cast<int>(s)] = std::move(t);
+ }
+
+ LegacyTHDispatcher* getLegacyTHDispatcherRaw(Backend p, ScalarType s) {
+ return dispatcher_registry[static_cast<int>(p)][static_cast<int>(s)].get();
+ }
+
+ LegacyTHDispatcher & getLegacyTHDispatcher(Backend p, ScalarType s) {
+ auto* type = getLegacyTHDispatcherRaw(p, s);
+ if (!type) AT_ERROR(toString(p), toString(s), "THDispatcher is not enabled.");
+ return *type;
+ }
+private:
+ // NB: dispatcher_registry has nullptr for all CUDA backends until
+ // CUDA initialization has occurred
+ LegacyTHDispatcherUniquePtr dispatcher_registry
+ [static_cast<int>(Backend::NumOptions)]
+ [static_cast<int>(ScalarType::NumOptions)];
+};
+
+CAFFE2_API LegacyTHDispatch& globalLegacyTHDispatch();
+
+} // namespace at
# Code-generation templates, loaded once at import time from TEMPLATE_PATH.
TYPE_DEFAULT_H = CodeTemplate.from_file(TEMPLATE_PATH + "/TypeDefault.h")
TYPE_DEFAULT_CPP = CodeTemplate.from_file(TEMPLATE_PATH + "/TypeDefault.cpp")
LEGACY_TH_DISPATCHER_H = CodeTemplate.from_file(TEMPLATE_PATH + "/LegacyTHDispatcher.h")

REGISTER_CPU_H = CodeTemplate.from_file(TEMPLATE_PATH + "/RegisterCPU.h")
REGISTER_CPU_CPP = CodeTemplate.from_file(TEMPLATE_PATH + "/RegisterCPU.cpp")
# NOTE(review): the statements below appear spliced together from different
# parts of the generator script by the diff; their ordering here may not
# reflect the real file.
for f in core_files:
    core_file_manager.will_write(f)
# Declare every output file up front so the file manager knows the full set
# of expected outputs before any writes happen.
files = ['Declarations.yaml', 'TypeExtendedInterface.h', 'TypeDefault.cpp', 'TypeDefault.h',
         'LegacyTHDispatcher.h',
         'Functions.h', 'NativeFunctions.h', 'RegisterCPU.cpp', 'RegisterCPU.h']
for f in files:
    file_manager.will_write(f)
# Render each template against the accumulated top-level environment.
file_manager.write('TypeDefault.h', TYPE_DEFAULT_H, top_env)
file_manager.write('TypeDefault.cpp', TYPE_DEFAULT_CPP, top_env)
file_manager.write('LegacyTHDispatcher.h', LEGACY_TH_DISPATCHER_H, top_env)

file_manager.write('RegisterCPU.h', REGISTER_CPU_H, top_env)
file_manager.write('RegisterCPU.cpp', REGISTER_CPU_CPP, top_env)
--- /dev/null
+#pragma once
+
+#include <c10/core/TensorTypeIdRegistration.h>
+
namespace at {

// Base class for the per-(backend, scalar type) dispatcher objects stored in
// LegacyTHDispatch.  It carries only the TensorTypeId it was constructed
// with; the virtual destructor indicates concrete dispatchers derive from it.
struct CAFFE2_API LegacyTHDispatcher {
  // NOTE(review): 'is_undefined' is accepted but never stored or read here —
  // presumably kept to mirror the Type constructor signature; confirm.
  explicit LegacyTHDispatcher(TensorTypeId type_id, bool is_undefined)
    : type_id_(type_id) {}

  virtual ~LegacyTHDispatcher() {}

protected:
  TensorTypeId type_id_;
};

} // namespace at