Remove TensorImpl -> context_base dependency (#14658)
author Sebastian Messmer <messmer@fb.com>
Sat, 8 Dec 2018 00:18:20 +0000 (16:18 -0800)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Sat, 8 Dec 2018 00:23:46 +0000 (16:23 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14658

Remove this dependency by moving at::CopyBytes to c10.
The implementations for at::CopyBytes will have to live in aten/caffe2 for now because they're not unified for CUDA yet.
They'll be moved into c10/backend/xxx later.

Reviewed By: dzhulgakov

Differential Revision: D13288655

fbshipit-source-id: 1c92379345308b3cd39a402779d7b7999613fc0d

aten/src/ATen/core/TensorImpl.h
aten/src/ATen/core/context_base.cpp
aten/src/ATen/core/context_base.h
c10/core/CopyBytes.cpp [new file with mode: 0644]
c10/core/CopyBytes.h [new file with mode: 0644]

diff --git a/aten/src/ATen/core/TensorImpl.h b/aten/src/ATen/core/TensorImpl.h
index 7642354..eaaf3f7 100644 (file)
@@ -9,7 +9,7 @@
 #include <c10/core/TensorOptions.h>
 #include <c10/core/TensorTypeId.h>
 #include <c10/core/TensorTypeIdRegistration.h>
-#include <ATen/core/context_base.h>
+#include <c10/core/CopyBytes.h>
 
 #include <c10/util/Exception.h>
 #include <c10/util/Optional.h>
diff --git a/aten/src/ATen/core/context_base.cpp b/aten/src/ATen/core/context_base.cpp
index 0e1f8b4..f917640 100644 (file)
@@ -11,49 +11,6 @@ C10_DEFINE_TYPED_REGISTRY(
     std::unique_ptr,
     at::Device);
 
-// First dimension of the array is `bool async`: 0 is sync,
-// 1 is async (non-blocking)
-static CopyBytesFunction g_copy_bytes[2][COMPILE_TIME_MAX_DEVICE_TYPES]
-                                     [COMPILE_TIME_MAX_DEVICE_TYPES];
-
-_CopyBytesFunctionRegisterer::_CopyBytesFunctionRegisterer(
-    DeviceType fromType,
-    DeviceType toType,
-    CopyBytesFunction func_sync,
-    CopyBytesFunction func_async) {
-  auto from = static_cast<int>(fromType);
-  auto to = static_cast<int>(toType);
-  if (!func_async) {
-    // default to the sync function
-    func_async = func_sync;
-  }
-  CHECK(
-      g_copy_bytes[0][from][to] == nullptr &&
-      g_copy_bytes[1][from][to] == nullptr)
-      << "Duplicate registration for device type pair "
-      << c10::DeviceTypeName(fromType) << ", " << c10::DeviceTypeName(toType);
-  g_copy_bytes[0][from][to] = func_sync;
-  g_copy_bytes[1][from][to] = func_async;
-}
-
-void CopyBytes(
-    size_t nbytes,
-    const void* src,
-    Device src_device,
-    void* dst,
-    Device dst_device,
-    bool async) {
-  auto ptr = g_copy_bytes[async ? 1 : 0][static_cast<int>(src_device.type())]
-                         [static_cast<int>(dst_device.type())];
-  CAFFE_ENFORCE(
-      ptr,
-      "No function found for copying from ",
-      c10::DeviceTypeName(src_device.type()),
-      " to ",
-      c10::DeviceTypeName(dst_device.type()));
-  ptr(nbytes, src, src_device, dst, dst_device);
-}
-
 } // namespace at
 
 namespace caffe2 {
diff --git a/aten/src/ATen/core/context_base.h b/aten/src/ATen/core/context_base.h
index f27071b..4ea5138 100644 (file)
@@ -11,6 +11,7 @@
 #include <c10/util/typeid.h>
 #include <c10/util/Exception.h>
 #include <c10/util/Registry.h>
+#include <c10/core/CopyBytes.h>
 
 namespace caffe2 {
 class Event;
@@ -156,39 +157,6 @@ inline std::unique_ptr<at::BaseContext> CreateContext(
 
 } // namespace at
 
-// TODO: move it to a separate file in c10 if possible
-namespace at {
-
-using CopyBytesFunction = void (*)(
-    size_t nbytes,
-    const void* src,
-    Device src_device,
-    void* dst,
-    Device dst_device);
-
-struct CAFFE2_API _CopyBytesFunctionRegisterer {
-  _CopyBytesFunctionRegisterer(
-      DeviceType from,
-      DeviceType to,
-      CopyBytesFunction func_sync,
-      CopyBytesFunction func_async = nullptr);
-};
-
-#define REGISTER_COPY_BYTES_FUNCTION(from, to, ...)           \
-  namespace {                                                 \
-  static _CopyBytesFunctionRegisterer C10_ANONYMOUS_VARIABLE( \
-      g_copy_function)(from, to, __VA_ARGS__);                \
-  }
-
-CAFFE2_API void CopyBytes(
-    size_t nbytes,
-    const void* src,
-    Device src_device,
-    void* dst,
-    Device dst_device,
-    bool async);
-} // namespace at
-
 namespace caffe2 {
 
 using at::BaseContext;
diff --git a/c10/core/CopyBytes.cpp b/c10/core/CopyBytes.cpp
new file mode 100644 (file)
index 0000000..83e0885
--- /dev/null
@@ -0,0 +1,49 @@
+#include <c10/core/CopyBytes.h>
+#include <c10/util/Logging.h>
+
+namespace c10 {
+
+// First dimension of the array is `bool async`: 0 is sync,
+// 1 is async (non-blocking)
+static CopyBytesFunction g_copy_bytes[2][COMPILE_TIME_MAX_DEVICE_TYPES]
+                                     [COMPILE_TIME_MAX_DEVICE_TYPES];
+
+_CopyBytesFunctionRegisterer::_CopyBytesFunctionRegisterer(
+    DeviceType fromType,
+    DeviceType toType,
+    CopyBytesFunction func_sync,
+    CopyBytesFunction func_async) {
+  auto from = static_cast<int>(fromType);
+  auto to = static_cast<int>(toType);
+  if (!func_async) {
+    // default to the sync function
+    func_async = func_sync;
+  }
+  CHECK(
+      g_copy_bytes[0][from][to] == nullptr &&
+      g_copy_bytes[1][from][to] == nullptr)
+      << "Duplicate registration for device type pair "
+      << c10::DeviceTypeName(fromType) << ", " << c10::DeviceTypeName(toType);
+  g_copy_bytes[0][from][to] = func_sync;
+  g_copy_bytes[1][from][to] = func_async;
+}
+
+void CopyBytes(
+    size_t nbytes,
+    const void* src,
+    Device src_device,
+    void* dst,
+    Device dst_device,
+    bool async) {
+  auto ptr = g_copy_bytes[async ? 1 : 0][static_cast<int>(src_device.type())]
+                         [static_cast<int>(dst_device.type())];
+  CAFFE_ENFORCE(
+      ptr,
+      "No function found for copying from ",
+      c10::DeviceTypeName(src_device.type()),
+      " to ",
+      c10::DeviceTypeName(dst_device.type()));
+  ptr(nbytes, src, src_device, dst, dst_device);
+}
+
+}
diff --git a/c10/core/CopyBytes.h b/c10/core/CopyBytes.h
new file mode 100644 (file)
index 0000000..eae7139
--- /dev/null
@@ -0,0 +1,44 @@
+#pragma once
+
+#include <c10/Device.h>
+
+namespace c10 {
+
+using CopyBytesFunction = void (*)(
+    size_t nbytes,
+    const void* src,
+    Device src_device,
+    void* dst,
+    Device dst_device);
+
+struct C10_API _CopyBytesFunctionRegisterer {
+  _CopyBytesFunctionRegisterer(
+      DeviceType from,
+      DeviceType to,
+      CopyBytesFunction func_sync,
+      CopyBytesFunction func_async = nullptr);
+};
+
+#define REGISTER_COPY_BYTES_FUNCTION(from, to, ...)           \
+  namespace {                                                 \
+  static _CopyBytesFunctionRegisterer C10_ANONYMOUS_VARIABLE( \
+      g_copy_function)(from, to, __VA_ARGS__);                \
+  }
+
+/*
+ * WARNING: Implementations for this function are currently registered from
+ * ATen and caffe2, not yet from c10. Don't use this unless at least one of
+ * ATen or caffe2 is present as well.
+ * We can't move them yet, because the CUDA implementations aren't unified yet
+ * between ATen and caffe2.
+ * We're planning to move the implementations into c10/backend/xxx
+ * to make c10 self-contained again.
+ */
+C10_API void CopyBytes(
+    size_t nbytes,
+    const void* src,
+    Device src_device,
+    void* dst,
+    Device dst_device,
+    bool async);
+} // namespace c10