return getTHDefaultAllocator();
}
-struct LegacyTypeInit : public LegacyTypeInitInterface {
- LegacyTypeInit(LegacyTypeInitArgs) {}
+struct LegacyDeviceTypeInit : public LegacyDeviceTypeInitInterface {
+ LegacyDeviceTypeInit(LegacyDeviceTypeInitArgs) {}
void initCPU() const override {
globalContext();
}
globalContext().lazyInitComplex();
}
};
-REGISTER_LEGACY_TYPE_INIT(LegacyTypeInit);
+REGISTER_LEGACY_TYPE_INIT(LegacyDeviceTypeInit);
}
--- /dev/null
+#include <ATen/core/LegacyDeviceTypeInit.h>
+
+namespace at {
+
+C10_DEFINE_REGISTRY(
+ LegacyDeviceTypeInitRegistry,
+ LegacyDeviceTypeInitInterface,
+ LegacyDeviceTypeInitArgs)
+
+const LegacyDeviceTypeInitInterface& getLegacyDeviceTypeInit() {
+ static std::unique_ptr<LegacyDeviceTypeInitInterface> legacy_device_type_init;
+ static std::once_flag once;
+ std::call_once(once, [] {
+ legacy_device_type_init = LegacyDeviceTypeInitRegistry()->Create("LegacyDeviceTypeInit", LegacyDeviceTypeInitArgs{});
+ if (!legacy_device_type_init) {
+ legacy_device_type_init =
+ std::unique_ptr<LegacyDeviceTypeInitInterface>(new LegacyDeviceTypeInitInterface());
+ }
+ });
+ return *legacy_device_type_init;
+}
+
+}
--- /dev/null
+#pragma once
+
+// The legacy mechanism for initializing device types; this is used by
+// both LegacyTypeDispatch and LegacyTHDispatch.
+
+#include <c10/DeviceType.h>
+#include <c10/macros/Macros.h>
+#include <c10/util/Registry.h>
+#include <ATen/core/ScalarType.h>
+
+namespace at {
+
+struct CAFFE2_API LegacyDeviceTypeInitInterface {
+ virtual ~LegacyDeviceTypeInitInterface() {}
+ virtual void initCPU() const {
+ AT_ERROR("cannot use CPU without ATen library");
+ }
+ virtual void initCUDA() const {
+ AT_ERROR("cannot use CUDA without ATen CUDA library");
+ }
+ virtual void initHIP() const {
+ AT_ERROR("cannot use HIP without ATen HIP library");
+ }
+ virtual void initComplex() const {
+ AT_ERROR("cannot use complex without ATen Complex library");
+ }
+};
+
+struct CAFFE2_API LegacyDeviceTypeInitArgs {};
+
+C10_DECLARE_REGISTRY(
+ LegacyDeviceTypeInitRegistry,
+ LegacyDeviceTypeInitInterface,
+ LegacyDeviceTypeInitArgs);
+#define REGISTER_LEGACY_TYPE_INIT(clsname) \
+ C10_REGISTER_CLASS(LegacyDeviceTypeInitRegistry, clsname, clsname)
+
+CAFFE2_API const LegacyDeviceTypeInitInterface& getLegacyDeviceTypeInit();
+
+} // namespace at
return singleton;
}
-C10_DEFINE_REGISTRY(
- LegacyTypeInitRegistry,
- LegacyTypeInitInterface,
- LegacyTypeInitArgs)
-
-const LegacyTypeInitInterface& getLegacyTypeInit() {
- static std::unique_ptr<LegacyTypeInitInterface> legacy_type_init;
- static std::once_flag once;
- std::call_once(once, [] {
- legacy_type_init = LegacyTypeInitRegistry()->Create("LegacyTypeInit", LegacyTypeInitArgs{});
- if (!legacy_type_init) {
- legacy_type_init =
- std::unique_ptr<LegacyTypeInitInterface>(new LegacyTypeInitInterface());
- }
- });
- return *legacy_type_init;
-}
-
}
#include <c10/core/ScalarType.h>
#include <ATen/core/VariableHooksInterface.h>
#include <c10/util/Exception.h>
+#include <ATen/core/LegacyDeviceTypeInit.h>
#include <ATen/core/TensorImpl.h>
namespace at {
-struct CAFFE2_API LegacyTypeInitInterface {
- virtual ~LegacyTypeInitInterface() {}
- virtual void initCPU() const {
- AT_ERROR("cannot use CPU without ATen library");
- }
- virtual void initCUDA() const {
- AT_ERROR("cannot use CUDA without ATen CUDA library");
- }
- virtual void initHIP() const {
- AT_ERROR("cannot use HIP without ATen HIP library");
- }
- virtual void initComplex() const {
- AT_ERROR("cannot use complex without ATen Complex library");
- }
-};
-struct CAFFE2_API LegacyTypeInitArgs {};
-C10_DECLARE_REGISTRY(
- LegacyTypeInitRegistry,
- LegacyTypeInitInterface,
- LegacyTypeInitArgs);
-#define REGISTER_LEGACY_TYPE_INIT(clsname) \
- C10_REGISTER_CLASS(LegacyTypeInitRegistry, clsname, clsname)
-
-CAFFE2_API const LegacyTypeInitInterface& getLegacyTypeInit();
-
struct Type;
struct CAFFE2_API LegacyTypeDeleter {
static std::once_flag cuda_once;
if (p == DeviceType::CPU) {
std::call_once(cpu_once, [] {
- getLegacyTypeInit().initCPU();
+ getLegacyDeviceTypeInit().initCPU();
});
} else if (p == DeviceType::CUDA) {
std::call_once(cuda_once, [] {
- getLegacyTypeInit().initCUDA();
+ getLegacyDeviceTypeInit().initCUDA();
});
} else if (p == DeviceType::HIP) {
std::call_once(cuda_once, [] {
- getLegacyTypeInit().initHIP();
+ getLegacyDeviceTypeInit().initHIP();
});
}
}
// Only complex may need initialization
if (isComplexType(s)) {
std::call_once(once, [] {
- getLegacyTypeInit().initComplex();
+ getLegacyDeviceTypeInit().initComplex();
});
}
}