Split LegacyDeviceTypeInit from LegacyTypeDispatch. (#14723)
authorGregory Chanan <gchanan@fb.com>
Wed, 5 Dec 2018 01:48:25 +0000 (17:48 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 5 Dec 2018 01:51:37 +0000 (17:51 -0800)
Summary:
The goal here is to have LegacyTHDispatch call into this as well, so LegacyTypeDispatch and LegacyTHDispatch don't have cross dependencies.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14723

Reviewed By: ezyang

Differential Revision: D13314017

Pulled By: gchanan

fbshipit-source-id: 8761cb4af2b2269d2e755203e073bfdba535b8c0

aten/src/ATen/Context.cpp
aten/src/ATen/core/LegacyDeviceTypeInit.cpp [new file with mode: 0644]
aten/src/ATen/core/LegacyDeviceTypeInit.h [new file with mode: 0644]
aten/src/ATen/core/LegacyTypeDispatch.cpp
aten/src/ATen/core/LegacyTypeDispatch.h

index c94336e..1d5be47 100644 (file)
@@ -126,8 +126,8 @@ Allocator* getCPUAllocator() {
   return getTHDefaultAllocator();
 }
 
-struct LegacyTypeInit : public LegacyTypeInitInterface {
-  LegacyTypeInit(LegacyTypeInitArgs) {}
+struct LegacyDeviceTypeInit : public LegacyDeviceTypeInitInterface {
+  LegacyDeviceTypeInit(LegacyDeviceTypeInitArgs) {}
   void initCPU() const override {
     globalContext();
   }
@@ -141,6 +141,6 @@ struct LegacyTypeInit : public LegacyTypeInitInterface {
     globalContext().lazyInitComplex();
   }
 };
-REGISTER_LEGACY_TYPE_INIT(LegacyTypeInit);
+REGISTER_LEGACY_TYPE_INIT(LegacyDeviceTypeInit);
 
 }
diff --git a/aten/src/ATen/core/LegacyDeviceTypeInit.cpp b/aten/src/ATen/core/LegacyDeviceTypeInit.cpp
new file mode 100644 (file)
index 0000000..37b7905
--- /dev/null
@@ -0,0 +1,23 @@
+#include <ATen/core/LegacyDeviceTypeInit.h>
+
+namespace at {
+
+C10_DEFINE_REGISTRY(
+    LegacyDeviceTypeInitRegistry,
+    LegacyDeviceTypeInitInterface,
+    LegacyDeviceTypeInitArgs)
+
+const LegacyDeviceTypeInitInterface& getLegacyDeviceTypeInit() {
+  static std::unique_ptr<LegacyDeviceTypeInitInterface> legacy_device_type_init;
+  static std::once_flag once;
+  std::call_once(once, [] {
+    legacy_device_type_init = LegacyDeviceTypeInitRegistry()->Create("LegacyDeviceTypeInit", LegacyDeviceTypeInitArgs{});
+    if (!legacy_device_type_init) {
+      legacy_device_type_init =
+          std::unique_ptr<LegacyDeviceTypeInitInterface>(new LegacyDeviceTypeInitInterface());
+    }
+  });
+  return *legacy_device_type_init;
+}
+
+}
diff --git a/aten/src/ATen/core/LegacyDeviceTypeInit.h b/aten/src/ATen/core/LegacyDeviceTypeInit.h
new file mode 100644 (file)
index 0000000..3fcc3bb
--- /dev/null
@@ -0,0 +1,40 @@
+#pragma once
+
+// The legacy mechanism for initializing device types; this is used by
+// both LegacyTypeDispatch and LegacyTHDispatch.
+
+#include <c10/DeviceType.h>
+#include <c10/macros/Macros.h>
+#include <c10/util/Registry.h>
+#include <ATen/core/ScalarType.h>
+
+namespace at {
+
+struct CAFFE2_API LegacyDeviceTypeInitInterface {
+  virtual ~LegacyDeviceTypeInitInterface() {}
+  virtual void initCPU() const {
+    AT_ERROR("cannot use CPU without ATen library");
+  }
+  virtual void initCUDA() const {
+    AT_ERROR("cannot use CUDA without ATen CUDA library");
+  }
+  virtual void initHIP() const {
+    AT_ERROR("cannot use HIP without ATen HIP library");
+  }
+  virtual void initComplex() const {
+    AT_ERROR("cannot use complex without ATen Complex library");
+  }
+};
+
+struct CAFFE2_API LegacyDeviceTypeInitArgs {};
+
+C10_DECLARE_REGISTRY(
+    LegacyDeviceTypeInitRegistry,
+    LegacyDeviceTypeInitInterface,
+    LegacyDeviceTypeInitArgs);
+#define REGISTER_LEGACY_TYPE_INIT(clsname) \
+  C10_REGISTER_CLASS(LegacyDeviceTypeInitRegistry, clsname, clsname)
+
+CAFFE2_API const LegacyDeviceTypeInitInterface& getLegacyDeviceTypeInit();
+
+} // namespace at
index 56c19cd..f20062d 100644 (file)
@@ -9,22 +9,4 @@ LegacyTypeDispatch & globalLegacyTypeDispatch() {
   return singleton;
 }
 
-C10_DEFINE_REGISTRY(
-    LegacyTypeInitRegistry,
-    LegacyTypeInitInterface,
-    LegacyTypeInitArgs)
-
-const LegacyTypeInitInterface& getLegacyTypeInit() {
-  static std::unique_ptr<LegacyTypeInitInterface> legacy_type_init;
-  static std::once_flag once;
-  std::call_once(once, [] {
-    legacy_type_init = LegacyTypeInitRegistry()->Create("LegacyTypeInit", LegacyTypeInitArgs{});
-    if (!legacy_type_init) {
-      legacy_type_init =
-          std::unique_ptr<LegacyTypeInitInterface>(new LegacyTypeInitInterface());
-    }
-  });
-  return *legacy_type_init;
-}
-
 }
index 53950f1..2b88cb7 100644 (file)
 #include <c10/core/ScalarType.h>
 #include <ATen/core/VariableHooksInterface.h>
 #include <c10/util/Exception.h>
+#include <ATen/core/LegacyDeviceTypeInit.h>
 #include <ATen/core/TensorImpl.h>
 
 namespace at {
 
-struct CAFFE2_API LegacyTypeInitInterface {
-  virtual ~LegacyTypeInitInterface() {}
-  virtual void initCPU() const {
-    AT_ERROR("cannot use CPU without ATen library");
-  }
-  virtual void initCUDA() const {
-    AT_ERROR("cannot use CUDA without ATen CUDA library");
-  }
-  virtual void initHIP() const {
-    AT_ERROR("cannot use HIP without ATen HIP library");
-  }
-  virtual void initComplex() const {
-    AT_ERROR("cannot use complex without ATen Complex library");
-  }
-};
-struct CAFFE2_API LegacyTypeInitArgs {};
-C10_DECLARE_REGISTRY(
-    LegacyTypeInitRegistry,
-    LegacyTypeInitInterface,
-    LegacyTypeInitArgs);
-#define REGISTER_LEGACY_TYPE_INIT(clsname) \
-  C10_REGISTER_CLASS(LegacyTypeInitRegistry, clsname, clsname)
-
-CAFFE2_API const LegacyTypeInitInterface& getLegacyTypeInit();
-
 struct Type;
 
 struct CAFFE2_API LegacyTypeDeleter {
@@ -133,15 +109,15 @@ private:
     static std::once_flag cuda_once;
     if (p == DeviceType::CPU) {
       std::call_once(cpu_once, [] {
-        getLegacyTypeInit().initCPU();
+        getLegacyDeviceTypeInit().initCPU();
       });
     } else if (p == DeviceType::CUDA) {
       std::call_once(cuda_once, [] {
-        getLegacyTypeInit().initCUDA();
+        getLegacyDeviceTypeInit().initCUDA();
       });
     } else if (p == DeviceType::HIP) {
       std::call_once(cuda_once, [] {
-        getLegacyTypeInit().initHIP();
+        getLegacyDeviceTypeInit().initHIP();
       });
     }
   }
@@ -150,7 +126,7 @@ private:
     // Only complex may need initialization
     if (isComplexType(s)) {
       std::call_once(once, [] {
-        getLegacyTypeInit().initComplex();
+        getLegacyDeviceTypeInit().initComplex();
       });
     }
   }