Move THTensor_(copy) to aten (#13603)
author Roy Li <royboy@fb.com>
Fri, 30 Nov 2018 19:10:25 +0000 (11:10 -0800)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 30 Nov 2018 19:12:54 +0000 (11:12 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13603

Moved the vectorized CPU copy to ATen. The notable changes are mainly in _copy_same_type_.
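
For orientation, the new routing added in ATen/native/Copy.cpp can be summarized as follows (a condensed sketch of _copy_same_type_, not the verbatim implementation; the full version is in the diff below):

  void _copy_same_type_(Tensor& self, const Tensor& src) {
    if (self.is_same(src)) return;
    if (self.is_contiguous() && src.is_contiguous()) {
      copy_kernel(kCPU, self, src);            // vectorized, parallelized kernel
    } else if (copy_transpose_valid(self, src)) {
      _copy_same_type_transpose_(self, src);   // blocked copy for a transposed 2-D src
    } else {
      // element-wise apply: parallel when not already inside an OpenMP
      // parallel region, serial fallback otherwise
    }
  }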

Reviewed By: ezyang

Differential Revision: D12936031

fbshipit-source-id: 00d28813e3160595e73d104f76685e13154971c1

21 files changed:
aten/src/ATen/cpu/vec256/vec256_base.h
aten/src/ATen/cpu/vec256/vec256_double.h
aten/src/ATen/cpu/vec256/vec256_float.h
aten/src/ATen/native/Copy.cpp
aten/src/ATen/native/Copy.h
aten/src/ATen/native/cpu/CopyKernel.cpp [new file with mode: 0644]
aten/src/ATen/native/cpu/CopyKernel.h [new file with mode: 0644]
aten/src/TH/CMakeLists.txt
aten/src/TH/THTensor.h
aten/src/TH/THTensorCopy.cpp [deleted file]
aten/src/TH/generic/THTensor.cpp
aten/src/TH/generic/THTensorCopy.cpp [deleted file]
aten/src/TH/generic/THTensorCopy.h [deleted file]
aten/src/TH/generic/THTensorEvenMoreMath.cpp
aten/src/TH/generic/THTensorLapack.cpp
aten/src/TH/generic/THTensorMath.cpp
aten/src/TH/generic/THTensorMoreMath.cpp
aten/src/TH/generic/THVector.h
aten/src/TH/generic/THVectorDispatch.cpp
aten/src/TH/vector/AVX.cpp
test/test_torch.py

diff --git a/aten/src/ATen/cpu/vec256/vec256_base.h b/aten/src/ATen/cpu/vec256/vec256_base.h
index 0bbb275..69faaec 100644 (file)
@@ -7,6 +7,7 @@
 #include <bitset>
 
 #include "ATen/Utils.h"
+#include "ATen/native/Copy.h"
 #include <c10/util/C++17.h>
 
 #if defined(__GNUC__)
@@ -476,7 +477,8 @@ template <typename src_T, typename dst_T>
 void convert(const src_T *src, dst_T *dst, int64_t n) {
 #pragma unroll
   for (int64_t i = 0; i < n; i++) {
-    *dst = static_cast<dst_T>(*src);
+    *dst = static_cast<dst_T>(
+        static_cast<at::native::inter_copy_type_t<dst_T>>(*src));
     src++;
     dst++;
   }
diff --git a/aten/src/ATen/cpu/vec256/vec256_double.h b/aten/src/ATen/cpu/vec256/vec256_double.h
index b19f04d..fcdca43 100644 (file)
@@ -248,6 +248,19 @@ Vec256<double> inline operator^(const Vec256<double>& a, const Vec256<double>& b
   return _mm256_xor_pd(a, b);
 }
 
+template <>
+void convert(const double* src, double* dst, int64_t n) {
+  int64_t i;
+#pragma unroll
+  for (i = 0; i <= (n - Vec256<double>::size); i += Vec256<double>::size) {
+    _mm256_storeu_pd(dst + i, _mm256_loadu_pd(src + i));
+  }
+#pragma unroll
+  for (; i < n; i++) {
+    dst[i] = src[i];
+  }
+}
+
 #ifdef __AVX2__
 template <>
 Vec256<double> inline fmadd(const Vec256<double>& a, const Vec256<double>& b, const Vec256<double>& c) {
diff --git a/aten/src/ATen/cpu/vec256/vec256_float.h b/aten/src/ATen/cpu/vec256/vec256_float.h
index b477aa5..7d6cc16 100644 (file)
@@ -256,6 +256,19 @@ Vec256<float> inline operator^(const Vec256<float>& a, const Vec256<float>& b) {
   return _mm256_xor_ps(a, b);
 }
 
+template <>
+void convert(const float* src, float* dst, int64_t n) {
+  int64_t i;
+#pragma unroll
+  for (i = 0; i <= (n - Vec256<float>::size); i += Vec256<float>::size) {
+    _mm256_storeu_ps(dst + i, _mm256_loadu_ps(src + i));
+  }
+#pragma unroll
+  for (; i < n; i++) {
+    dst[i] = src[i];
+  }
+}
+
 #ifdef __AVX2__
 template <>
 Vec256<float> inline fmadd(const Vec256<float>& a, const Vec256<float>& b, const Vec256<float>& c) {
diff --git a/aten/src/ATen/native/Copy.cpp b/aten/src/ATen/native/Copy.cpp
index 978ef2c..96d6e5a 100644 (file)
@@ -1,9 +1,10 @@
-#include "Copy.h"
+#include "ATen/native/Copy.h"
 
 #include "ATen/ATen.h"
 #include "ATen/CPUApplyUtils.h"
 #include "ATen/Dispatch.h"
 #include "ATen/NativeFunctions.h"
+#include "ATen/native/cpu/CopyKernel.h"
 
 namespace {
 
@@ -18,8 +19,16 @@ void _copy__cpu(at::Tensor& self, const at::Tensor& src) {
 
 template <typename self_T>
 void _copy__cpu(at::Tensor& self, const at::Tensor& src) {
-  AT_DISPATCH_ALL_TYPES_AND_HALF(
-      src.type(), "_copy__cpu", [&]() { _copy__cpu<self_T, scalar_t>(self, src); });
+  AT_DISPATCH_ALL_TYPES_AND_HALF(src.type(), "_copy__cpu", [&]() {
+    _copy__cpu<self_T, scalar_t>(self, src);
+  });
+}
+
+bool copy_transpose_valid(const at::Tensor& self, const at::Tensor& src) {
+  const int MIN_SZ = 60 * 60;
+  return self.is_contiguous() && src.numel() != 0 && src.dim() == 2 &&
+      src.stride(0) == 1 && src.stride(1) == src.size(0) &&
+      self.numel() >= MIN_SZ;
 }
 
 } // namespace
@@ -33,5 +42,101 @@ Tensor& _copy__cpu(Tensor& self, const Tensor& src) {
   return self;
 }
 
+// special case copy where tensor is contiguous and src is a transposed matrix
+// This can be generalized to most copies, but it's trickier
+void _copy_same_type_transpose_(Tensor& self, const Tensor& src) {
+  int64_t BLOCK_SZ;
+  if (self.scalar_type() == kByte) {
+    BLOCK_SZ = 120;
+  } else {
+    BLOCK_SZ = 60;
+  }
+  Tensor buf = empty({BLOCK_SZ, BLOCK_SZ}, self.options());
+
+  AT_DISPATCH_ALL_TYPES_AND_HALF(
+      self.type(), "_copy_same_type_transpose_", [&]() {
+        scalar_t* sp = src.data<scalar_t>();
+        scalar_t* rp = self.data<scalar_t>();
+        scalar_t* bp = buf.data<scalar_t>();
+
+        int64_t NR = src.size(0);
+        int64_t NC = src.size(1);
+        for (int64_t R = 0; R < NR; R += BLOCK_SZ) {
+          for (int64_t C = 0; C < NC; C += BLOCK_SZ) {
+            scalar_t* spo = sp + R + C * NR;
+            scalar_t* rpo = rp + C + R * NC;
+
+            int nr = std::min(NR - R, BLOCK_SZ);
+            int nc = std::min(NC - C, BLOCK_SZ);
+
+            // 1. copy columns from src to buf
+            for (int c = 0; c < nc; c++) {
+              memcpy(bp + c * BLOCK_SZ, spo + c * NR, nr * sizeof(scalar_t));
+            }
+
+            // 2. transpose buf in place
+            int rc_max = std::max(nr, nc);
+            int rc_min = std::min(nr, nc);
+            for (int r = 0; r < rc_max; r++) {
+              int end = std::min(r, rc_min);
+              for (int c = 0; c < end; c++) {
+                scalar_t tmp = bp[r + BLOCK_SZ * c];
+                bp[r + BLOCK_SZ * c] = bp[r * BLOCK_SZ + c];
+                bp[r * BLOCK_SZ + c] = tmp;
+              }
+            }
+
+            // 3. copy rows from buf to dst
+            for (int r = 0; r < nr; r++) {
+              memcpy(rpo + r * NC, bp + r * BLOCK_SZ, nc * sizeof(scalar_t));
+            }
+          }
+        }
+      });
+}
+
+void _copy_same_type_(Tensor& self, const Tensor& src) {
+  if (self.is_same(src)) {
+    return;
+  }
+
+  bool serial_path = false;
+  if (self.numel() == src.numel()) {
+    if (self.is_contiguous() && src.is_contiguous()) {
+      copy_kernel(kCPU, self, src);
+    } else if (copy_transpose_valid(self, src)) {
+      _copy_same_type_transpose_(self, src);
+    } else {
+#ifdef _OPENMP
+      if (!in_parallel_region()) {
+        AT_DISPATCH_ALL_TYPES_AND_HALF(self.type(), "_copy_same_type_", [&]() {
+          at::CPU_tensor_parallel_apply2<scalar_t, scalar_t>(
+              self, src, [](scalar_t& self_val, const scalar_t& src_val) {
+                self_val = src_val;
+              });
+        });
+      } else {
+        serial_path = true;
+      }
+#else
+      serial_path = true;
+#endif
+    }
+  } else {
+    serial_path = true;
+  }
+
+  if (serial_path) {
+    AT_DISPATCH_ALL_TYPES_AND_HALF(self.type(), "_copy_same_type_", [&]() {
+      at::CPU_tensor_apply2<scalar_t, scalar_t>(
+          self, src, [](scalar_t& self_val, const scalar_t& src_val) {
+            self_val = src_val;
+          });
+    });
+  }
+}
+
+DEFINE_DISPATCH(copy_kernel);
+
 } // namespace native
 } // namespace at
diff --git a/aten/src/ATen/native/Copy.h b/aten/src/ATen/native/Copy.h
index 6382bce..6bd27be 100644 (file)
@@ -43,5 +43,7 @@ struct inter_copy_type<uint8_t> {
 template <typename T>
 using inter_copy_type_t = typename inter_copy_type<T>::type;
 
+void _copy_same_type_(Tensor& self, const Tensor& src);
+
 } // namespace native
 } // namespace at
diff --git a/aten/src/ATen/native/cpu/CopyKernel.cpp b/aten/src/ATen/native/cpu/CopyKernel.cpp
new file mode 100644 (file)
index 0000000..caf0364
--- /dev/null
@@ -0,0 +1,37 @@
+#include <ATen/native/cpu/CopyKernel.h>
+
+#include <ATen/ATen.h>
+#include <ATen/CPUApplyUtils.h>
+#include <ATen/Dispatch.h>
+#include <ATen/cpu/vec256/vec256.h>
+#include <ATen/native/Copy.h>
+
+namespace at {
+namespace native {
+namespace {
+
+// TODO: this number was copied from TH, test to see if it's the right number
+constexpr int64_t COPY_GRAIN_SIZE = 20000;
+
+static void copy_kernel_impl(Tensor& dst, const Tensor& src) {
+  AT_DISPATCH_ALL_TYPES_AND_HALF(dst.type(), "copy_kernel_impl", [&]() {
+    scalar_t* self_ptr = dst.data<scalar_t>();
+    scalar_t* src_ptr = src.data<scalar_t>();
+
+    auto sample = [&](int64_t begin, int64_t end) {
+      int64_t len = end - begin;
+      scalar_t* self_seg = self_ptr + begin;
+      scalar_t* src_seg = src_ptr + begin;
+      at::vec256::convert<scalar_t, scalar_t>(src_seg, self_seg, len);
+    };
+
+    parallel_for(0, dst.numel(), COPY_GRAIN_SIZE, sample);
+  });
+}
+
+} // anonymous namespace
+
+REGISTER_DISPATCH(copy_kernel, &copy_kernel_impl);
+
+} // namespace native
+} // namespace at
diff --git a/aten/src/ATen/native/cpu/CopyKernel.h b/aten/src/ATen/native/cpu/CopyKernel.h
new file mode 100644 (file)
index 0000000..4b29d69
--- /dev/null
@@ -0,0 +1,14 @@
+#pragma once
+
+#include <ATen/ATen.h>
+#include <ATen/native/DispatchStub.h>
+
+namespace at {
+namespace native {
+
+using forward_fn = void (*)(Tensor&, const Tensor&);
+
+DECLARE_DISPATCH(forward_fn, copy_kernel);
+
+} // namespace native
+} // namespace at
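
The new kernel is wired into ATen's DispatchStub mechanism, split across the three files above: CopyKernel.h declares the stub, Copy.cpp defines it, and cpu/CopyKernel.cpp (built once per supported CPU capability) registers the implementation. A minimal sketch of how the pieces connect:

  // CopyKernel.h: declare the stub and its signature
  using forward_fn = void (*)(Tensor&, const Tensor&);
  DECLARE_DISPATCH(forward_fn, copy_kernel);

  // Copy.cpp: define the stub
  DEFINE_DISPATCH(copy_kernel);

  // cpu/CopyKernel.cpp: register the CPU implementation
  REGISTER_DISPATCH(copy_kernel, &copy_kernel_impl);

  // Call sites then invoke it with a device type: copy_kernel(kCPU, self, src);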
diff --git a/aten/src/TH/CMakeLists.txt b/aten/src/TH/CMakeLists.txt
index 7ba65d2..463792b 100644 (file)
@@ -19,7 +19,6 @@ set(ATen_TH_SRCS
   ${CMAKE_CURRENT_SOURCE_DIR}/THSize.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/THStorageFunctions.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/THTensor.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/THTensorCopy.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/THTensorRandom.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/THTensorMath.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/THTensorMoreMath.cpp
@@ -112,8 +111,6 @@ INSTALL(FILES
   generic/THTensor.hpp
   generic/THTensorConv.cpp
   generic/THTensorConv.h
-  generic/THTensorCopy.cpp
-  generic/THTensorCopy.h
   generic/THTensorLapack.cpp
   generic/THTensorLapack.h
   generic/THTensorMath.cpp
diff --git a/aten/src/TH/THTensor.h b/aten/src/TH/THTensor.h
index 3335a6f..1b10ab5 100644 (file)
 #include "generic/THTensor.h"
 #include "THGenerateHalfType.h"
 
-#include "generic/THTensorCopy.h"
-#include "THGenerateAllTypes.h"
-
-#include "generic/THTensorCopy.h"
-#include "THGenerateHalfType.h"
-
 /* random numbers */
 #include "THRandom.h"
 #include "generic/THTensorRandom.h"
diff --git a/aten/src/TH/THTensorCopy.cpp b/aten/src/TH/THTensorCopy.cpp
deleted file mode 100644 (file)
index 482a7b9..0000000
+++ /dev/null
@@ -1,10 +0,0 @@
-#include "THTensor.hpp"
-#include "THVector.h"
-
-#include <algorithm>
-
-#include "generic/THTensorCopy.cpp"
-#include "THGenerateAllTypes.h"
-
-#include "generic/THTensorCopy.cpp"
-#include "THGenerateHalfType.h"
diff --git a/aten/src/TH/generic/THTensor.cpp b/aten/src/TH/generic/THTensor.cpp
index 7b4f949..498fa75 100644 (file)
@@ -3,6 +3,7 @@
 #else
 
 #include <ATen/InferSize.h>
+#include <ATen/native/Copy.h>
 #include <new>
 
 /**** access methods ****/
@@ -155,7 +156,9 @@ THTensor *THTensor_(newClone)(THTensor *self)
 {
   THTensor *tensor = THTensor_(new)();
   THTensor_(resizeAs)(tensor, self);
-  THTensor_(copy)(tensor, self);
+  at::Tensor tensor_wrap = THTensor_wrap(tensor);
+  at::Tensor self_wrap = THTensor_wrap(self);
+  at::native::_copy_same_type_(tensor_wrap, self_wrap);
   return tensor;
 }
 
@@ -577,8 +580,11 @@ void THTensor_(free)(THTensor *self)
 
 void THTensor_(freeCopyTo)(THTensor *self, THTensor *dst)
 {
-  if(self != dst)
-    THTensor_(copy)(dst, self);
+  if(self != dst) {
+    at::Tensor dst_wrap = THTensor_wrap(dst);
+    at::Tensor self_wrap = THTensor_wrap(self);
+    at::native::_copy_same_type_(dst_wrap, self_wrap);
+  }
 
   THTensor_(free)(self);
 }
diff --git a/aten/src/TH/generic/THTensorCopy.cpp b/aten/src/TH/generic/THTensorCopy.cpp
deleted file mode 100644 (file)
index 13d3052..0000000
+++ /dev/null
@@ -1,153 +0,0 @@
-#ifndef TH_GENERIC_FILE
-#define TH_GENERIC_FILE "generic/THTensorCopy.cpp"
-#else
-
-#ifndef _WIN32
-#define PRAGMA(P) _Pragma(#P)
-#else
-#define PRAGMA(P) __pragma(P)
-#endif
-
-#ifdef _OPENMP
-#define TH_OMP_OVERHEAD_THRESHOLD_COPY 20000
-#include <omp.h>
-#endif
-
-int THTensor_(copyTransposeValid)(THTensor *tensor, THTensor *src) {
-  const int MIN_SZ = 60 * 60;
-  return THTensor_(isContiguous)(tensor) &&
-         !src->is_empty() &&
-         THTensor_(nDimensionLegacyNoScalars)(src) == 2 &&
-         THTensor_(stride)(src, 0) == 1 &&
-         THTensor_(stride)(src, 1) == THTensor_(size)(src, 0) &&
-         THTensor_(nElement)(tensor) >= MIN_SZ;
-}
-
-// special case copy where tensor is contiguous and src is a transposed matrix
-// This can be generalized to most copies, but it's tricker
-void THTensor_(copyTranspose)(THTensor *tensor, THTensor *src) {
-
-#ifdef TH_REAL_IS_BYTE
-  const int64_t BLOCK_SZ = 120;
-#else
-  const int64_t BLOCK_SZ = 60;
-#endif
-
-  THTensor *buf = THTensor_(newWithSize2d)(BLOCK_SZ, BLOCK_SZ);
-  scalar_t *sp = src->data<scalar_t>();
-  scalar_t *rp = tensor->data<scalar_t>();
-  scalar_t *bp = buf->data<scalar_t>();
-
-
-  int64_t NR = THTensor_(size)(src, 0);
-  int64_t NC = THTensor_(size)(src, 1);
-  for (int64_t R = 0; R < NR; R += BLOCK_SZ) {
-    for (int64_t C = 0; C < NC; C += BLOCK_SZ) {
-      scalar_t *spo = sp + R + C * NR;
-      scalar_t *rpo = rp + C + R * NC;
-
-      int nr = std::min(NR - R, BLOCK_SZ);
-      int nc = std::min(NC - C, BLOCK_SZ);
-
-      // 1. copy columns from src to buf
-      for (int c = 0; c < nc; c++) {
-        memcpy(bp + c * BLOCK_SZ, spo + c * NR, nr * sizeof(scalar_t));
-      }
-
-      // 2. transpose buf in place
-      int rc_max = std::max(nr, nc);
-      int rc_min = std::min(nr, nc);
-      for (int r = 0; r < rc_max; r++) {
-        int end = std::min(r, rc_min);
-        for (int c = 0; c < end; c++) {
-          scalar_t tmp = bp[r + BLOCK_SZ * c];
-          bp[r + BLOCK_SZ * c] = bp[r * BLOCK_SZ + c];
-          bp[r * BLOCK_SZ + c] = tmp;
-        }
-      }
-
-      // 3. copy rows from buf to dst
-      for (int r = 0; r < nr; r++) {
-        memcpy(rpo + r * NC, bp + r * BLOCK_SZ, nc * sizeof(scalar_t));
-      }
-    }
-  }
-  c10::raw::intrusive_ptr::decref(buf);
-}
-
-void THTensor_(copy)(THTensor *tensor, THTensor *src)
-{
-  if (tensor == src) return;
-  ptrdiff_t tensorSize = THTensor_(nElement)(tensor);
-  ptrdiff_t srcSize = THTensor_(nElement)(src);
-  int tensorContig = THTensor_(isContiguous)(tensor);
-  int srcContig = THTensor_(isContiguous)(src);
-
-  int serial_path = 0;
-#ifdef _OPENMP
-  int inOMP = omp_in_parallel();
-#endif
-  if (tensorSize == srcSize) {
-    if ( tensorContig && srcContig) {
-      scalar_t *sp = src->data<scalar_t>();
-      scalar_t *rp = tensor->data<scalar_t>();
-#ifndef TH_REAL_IS_HALF
-#ifdef _OPENMP
-      #pragma omp parallel if ( (tensorSize > TH_OMP_OVERHEAD_THRESHOLD_COPY) && (!inOMP) )
-      {
-        size_t num_threads = omp_get_num_threads();
-        size_t tid = omp_get_thread_num();
-        ptrdiff_t offset = tid * (tensorSize / num_threads);
-        ptrdiff_t end = (tid == num_threads - 1) ? tensorSize : offset + tensorSize / num_threads;
-        ptrdiff_t len = end - offset;
-        scalar_t *tensorData = rp + offset;
-        scalar_t *srcData = sp + offset;
-        THVector_(copy)(tensorData, srcData, len);
-      }
-#else
-        THVector_(copy)(rp, sp, srcSize);
-#endif
-
-#else
-
-#ifdef _OPENMP
-      if ((srcSize > TH_OMP_OVERHEAD_THRESHOLD_COPY) && (!inOMP)) {
-        ptrdiff_t i;
-        #pragma omp parallel for private (i)
-        for(i=0; i<srcSize; i++){
-          rp[i] = sp[i];
-        }
-      } else {
-        memcpy(rp, sp, srcSize * sizeof(scalar_t));
-      }
-#else
-      memcpy(rp, sp, srcSize * sizeof(scalar_t));
-#endif
-
-#endif
-
-#ifndef TH_REAL_IS_HALF
-    } else if (THTensor_(copyTransposeValid)(tensor, src)) {
-      THTensor_(copyTranspose)(tensor, src);
-#endif
-    } else {
-#ifdef _OPENMP
-      if (inOMP) {
-        serial_path = 1;
-      } else {
-        TH_TENSOR_APPLY2_OMP(srcSize, tensorContig, srcContig, scalar_t, tensor, scalar_t, src, *tensor_data = *src_data;, TH_OMP_OVERHEAD_THRESHOLD_COPY)
-      }
-#else
-      serial_path = 1;
-#endif
-    }
-  } else {
-    serial_path = 1;
-  }
-
-  if (serial_path) {
-    TH_TENSOR_APPLY2(scalar_t, tensor, scalar_t, src, *tensor_data = *src_data;)
-  }
-}
-
-#endif
diff --git a/aten/src/TH/generic/THTensorCopy.h b/aten/src/TH/generic/THTensorCopy.h
deleted file mode 100644 (file)
index a326f36..0000000
+++ /dev/null
@@ -1,9 +0,0 @@
-#ifndef TH_GENERIC_FILE
-#define TH_GENERIC_FILE "generic/THTensorCopy.h"
-#else
-
-/* Support for copy between different Tensor types */
-
-TH_API void THTensor_(copy)(THTensor *tensor, THTensor *src);
-
-#endif
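
With THTensor_(copy) removed, TH call sites wrap their raw THTensor pointers as at::Tensor and call into ATen instead. The remaining hunks below all apply the same pattern:

  // before: THTensor_(copy)(dst, src);
  at::Tensor dst_wrap = THTensor_wrap(dst);
  at::Tensor src_wrap = THTensor_wrap(src);
  at::native::_copy_same_type_(dst_wrap, src_wrap);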
diff --git a/aten/src/TH/generic/THTensorEvenMoreMath.cpp b/aten/src/TH/generic/THTensorEvenMoreMath.cpp
index 7dee7a3..bf33a2b 100644 (file)
@@ -3,6 +3,7 @@
 #else
 
 #include <TH/generic/THTensorApply.hpp>
+#include <ATen/native/Copy.h>
 
 void THTensor_(fill)(THTensor *r_, scalar_t value)
 {
@@ -219,7 +220,9 @@ void THTensor_(indexSelect)(THTensor *tensor, THTensor *src, int dim, THLongTens
       sSlice = THTensor_(new)();
       THTensor_(select)(tSlice, tensor, dim, i);
       THTensor_(select)(sSlice, src, dim, index_data[i] - TH_INDEX_BASE);
-      THTensor_(copy)(tSlice, sSlice);
+      at::Tensor tSlice_wrap = THTensor_wrap(tSlice);
+      at::Tensor sSlice_wrap = THTensor_wrap(sSlice);
+      at::native::_copy_same_type_(tSlice_wrap, sSlice_wrap);
       c10::raw::intrusive_ptr::decref(tSlice);
       c10::raw::intrusive_ptr::decref(sSlice);
     }
@@ -250,7 +253,9 @@ void THTensor_(indexCopy)(THTensor *tensor, int dim, THLongTensor *index, THTens
     {
       THTensor_(select)(tSlice, tensor, dim, index_data[i] - TH_INDEX_BASE);
       THTensor_(select)(sSlice, src, dim, i);
-      THTensor_(copy)(tSlice, sSlice);
+      at::Tensor tSlice_wrap = THTensor_wrap(tSlice);
+      at::Tensor sSlice_wrap = THTensor_wrap(sSlice);
+      at::native::_copy_same_type_(tSlice_wrap, sSlice_wrap);
     }
 
     c10::raw::intrusive_ptr::decref(tSlice);
diff --git a/aten/src/TH/generic/THTensorLapack.cpp b/aten/src/TH/generic/THTensorLapack.cpp
index 98fbcf9..302ced6 100644 (file)
@@ -2,6 +2,8 @@
 #define TH_GENERIC_FILE "generic/THTensorLapack.cpp"
 #else
 
+#include <ATen/native/Copy.h>
+
 /*
 Check if self is transpose of a contiguous matrix
 */
@@ -80,12 +82,17 @@ static THTensor *THTensor_(cloneColumnMajorNrows)(THTensor *self, THTensor *src,
   THTensor_(resize2d)(result, src->size(1), nrows);
   THTensor_(checkTransposed)(result);
 
-  if (src->size(0) == nrows)
-    THTensor_(copy)(result, src);
+  if (src->size(0) == nrows) {
+    at::Tensor result_wrap = THTensor_wrap(result);
+    at::Tensor src_wrap = THTensor_wrap(src);
+    at::native::_copy_same_type_(result_wrap, src_wrap);
+  }
   else
   {
     view = THTensor_(newNarrow)(result, 0, 0, src->size(0));
-    THTensor_(copy)(view, src);
+    at::Tensor view_wrap = THTensor_wrap(view);
+    at::Tensor src_wrap = THTensor_wrap(src);
+    at::native::_copy_same_type_(view_wrap, src_wrap);
     c10::raw::intrusive_ptr::decref(view);
   }
   return result;
@@ -529,7 +536,9 @@ void THTensor_(gesdd2)(THTensor *ru_, THTensor *rs_, THTensor *rv_, THTensor *ra
       THTensor_(narrow)(rvf_,NULL,1,0,k);
 
     THTensor_(resizeAs)(rv_, rvf_);
-    THTensor_(copy)(rv_, rvf_);
+    at::Tensor rv__wrap = THTensor_wrap(rv_);
+    at::Tensor rvf__wrap = THTensor_wrap(rvf_);
+    at::native::_copy_same_type_(rv__wrap, rvf__wrap);
     c10::raw::intrusive_ptr::decref(rvf_);
   } else {
     THTensor_(zero)(ru_);
@@ -1007,7 +1016,9 @@ void THTensor_(btrifact)(THTensor *ra_, THIntTensor *rpivots_, THIntTensor *rinf
 
   if (ra_ != a) {
     THTensor_(resizeAs)(ra_, a);
-    THTensor_(copy)(ra_, a);
+    at::Tensor ra__wrap = THTensor_wrap(ra_);
+    at::Tensor a_wrap = THTensor_wrap(a);
+    at::native::_copy_same_type_(ra__wrap, a_wrap);
   }
 
   int m = a->size(1);
@@ -1088,7 +1099,9 @@ void THTensor_(btrisolve)(THTensor *rb_, THTensor *b, THTensor *atf, THIntTensor
 
   if (rb_ != b) {
     THTensor_(resizeAs)(rb_, b);
-    THTensor_(copy)(rb_, b);
+    at::Tensor rb__wrap = THTensor_wrap(rb_);
+    at::Tensor b_wrap = THTensor_wrap(b);
+    at::native::_copy_same_type_(rb__wrap, b_wrap);
   }
 
   int64_t num_batches = atf->size(0);
diff --git a/aten/src/TH/generic/THTensorMath.cpp b/aten/src/TH/generic/THTensorMath.cpp
index 1a71e97..bb6e9da 100644 (file)
@@ -3,6 +3,7 @@
 #else
 
 #include <TH/generic/THTensorApply.hpp>
+#include <ATen/native/Copy.h>
 
 // HEY YOU!
 //
@@ -213,8 +214,10 @@ void THTensor_(cmul)(THTensor *r_, THTensor *t, THTensor *src)
 void THTensor_(pow)(THTensor *r_, THTensor *t, scalar_t value)
 {
   THTensor_(resizeAs)(r_, t);
-  if(value == 1){
-    THTensor_(copy)(r_, t);
+  if(value == 1) {
+    at::Tensor r__wrap = THTensor_wrap(r_);
+    at::Tensor t_wrap = THTensor_wrap(t);
+    at::native::_copy_same_type_(r__wrap, t_wrap);
   }
   else if(value == 2){
     THTensor_(cmul)(r_, t, t);
@@ -736,7 +739,9 @@ void THTensor_(addcmul)(THTensor *r_, THTensor *t, scalar_t value, THTensor *src
   if(r_ != t)
   {
     THTensor_(resizeAs)(r_, t);
-    THTensor_(copy)(r_, t);
+    at::Tensor r__wrap = THTensor_wrap(r_);
+    at::Tensor t_wrap = THTensor_wrap(t);
+    at::native::_copy_same_type_(r__wrap, t_wrap);
   }
   int64_t r_Size = THTensor_(nElement)(r_);
   int64_t src1Size = THTensor_(nElement)(src1);
@@ -772,7 +777,9 @@ void THTensor_(addcdiv)(THTensor *r_, THTensor *t, scalar_t value, THTensor *src
   if(r_ != t)
   {
     THTensor_(resizeAs)(r_, t);
-    THTensor_(copy)(r_, t);
+    at::Tensor r__wrap = THTensor_wrap(r_);
+    at::Tensor t_wrap = THTensor_wrap(t);
+    at::native::_copy_same_type_(r__wrap, t_wrap);
   }
   int64_t r_Size = THTensor_(nElement)(r_);
   int64_t src1Size = THTensor_(nElement)(src1);
@@ -827,7 +834,9 @@ void THTensor_(addmv)(THTensor *r_, scalar_t beta, THTensor *t, scalar_t alpha,
   if(r_ != t)
   {
     THTensor_(resizeAs)(r_, t);
-    THTensor_(copy)(r_, t);
+    at::Tensor r__wrap = THTensor_wrap(r_);
+    at::Tensor t_wrap = THTensor_wrap(t);
+    at::native::_copy_same_type_(r__wrap, t_wrap);
   }
 
   auto r_stride = THTensor_strideLegacyNoScalars(r_, 0);
@@ -946,7 +955,9 @@ void THTensor_(addmm)(THTensor *r_, scalar_t beta, THTensor *t, scalar_t alpha,
   {
     THTensor_(resizeAs)(r_, t);
     if (beta != 0.0) {
-      THTensor_(copy)(r_, t);
+      at::Tensor r__wrap = THTensor_wrap(r_);
+      at::Tensor t_wrap = THTensor_wrap(t);
+      at::native::_copy_same_type_(r__wrap, t_wrap);
     }
   }
 
@@ -1082,7 +1093,9 @@ void THTensor_(addr)(THTensor *r_, scalar_t beta, THTensor *t, scalar_t alpha, T
   if(r_ != t)
   {
     THTensor_(resizeAs)(r_, t);
-    THTensor_(copy)(r_, t);
+    at::Tensor r__wrap = THTensor_wrap(r_);
+    at::Tensor t_wrap = THTensor_wrap(t);
+    at::native::_copy_same_type_(r__wrap, t_wrap);
   }
 
   if(beta == 0) {
@@ -1145,7 +1158,9 @@ void THTensor_(addbmm)(THTensor *result, scalar_t beta, THTensor *t, scalar_t al
   if (t != result) {
     THTensor_(resizeAs)(result, t);
     if (beta != 0.0) {
-      THTensor_(copy)(result, t);
+      at::Tensor result_wrap = THTensor_wrap(result);
+      at::Tensor t_wrap = THTensor_wrap(t);
+      at::native::_copy_same_type_(result_wrap, t_wrap);
     }
   }
 
diff --git a/aten/src/TH/generic/THTensorMoreMath.cpp b/aten/src/TH/generic/THTensorMoreMath.cpp
index 311d04d..f728a27 100644 (file)
@@ -4,6 +4,7 @@
 
 #include <TH/generic/THTensorApply.hpp>
 #include <TH/THGenerator.hpp>
+#include <ATen/native/Copy.h>
 
 void THTensor_(baddbmm)(THTensor *result, scalar_t beta, THTensor *t, scalar_t alpha, THTensor *batch1, THTensor *batch2)
 {
@@ -29,7 +30,9 @@ void THTensor_(baddbmm)(THTensor *result, scalar_t beta, THTensor *t, scalar_t a
   if (t != result) {
     THTensor_(resizeAs)(result, t);
     if (beta != 0.0) {
-      THTensor_(copy)(result, t);
+      at::Tensor result_wrap = THTensor_wrap(result);
+      at::Tensor t_wrap = THTensor_wrap(t);
+      at::native::_copy_same_type_(result_wrap, t_wrap);
     }
   }
 
@@ -112,7 +115,9 @@ void THTensor_(max)(THTensor *values_, THLongTensor *indices_, THTensor *t, int
   } else {
     if (THTensor_(nDimensionLegacyAll)(t) > 1) {
       THTensor *t0 = THTensor_(newSelect)(t, dimension, 0);
-      THTensor_(copy)(values_, t0);
+      at::Tensor values__wrap = THTensor_wrap(values_);
+      at::Tensor t0_wrap = THTensor_wrap(t0);
+      at::native::_copy_same_type_(values__wrap, t0_wrap);
       c10::raw::intrusive_ptr::decref(t0);
     } else {
       THTensor_(fill)(values_, THTensor_(get1d)(t, 0));
@@ -193,7 +198,9 @@ void THTensor_(min)(THTensor *values_, THLongTensor *indices_, THTensor *t, int
   } else {
     if (THTensor_(nDimensionLegacyAll)(t) > 1) {
       THTensor *t0 = THTensor_(newSelect)(t, dimension, 0);
-      THTensor_(copy)(values_, t0);
+      at::Tensor values__wrap = THTensor_wrap(values_);
+      at::Tensor t0_wrap = THTensor_wrap(t0);
+      at::native::_copy_same_type_(values__wrap, t0_wrap);
       c10::raw::intrusive_ptr::decref(t0);
     } else {
       THTensor_(fill)(values_, THTensor_(get1d)(t, 0));
@@ -897,7 +904,9 @@ void THTensor_(sort)(THTensor *rt_, THLongTensor *ri_, THTensor *t, int dimensio
       dimension + TH_INDEX_BASE);
 
   THTensor_(resizeAs)(rt_, t);
-  THTensor_(copy)(rt_, t);
+  at::Tensor rt__wrap = THTensor_wrap(rt_);
+  at::Tensor t_wrap = THTensor_wrap(t);
+  at::native::_copy_same_type_(rt__wrap, t_wrap);
   THLongTensor_resize(ri_, t->sizes(), {});
 
   if(descendingOrder)
@@ -1411,7 +1420,9 @@ void THTensor_(catArray)(THTensor *result, THTensor **inputs, int numInputs, int
         int64_t dimSize = inputs[j]->size(dimension);
         THTensor *nt = THTensor_(newWithTensor)(result);
         THTensor_(narrow)(nt, NULL, dimension, offset, dimSize);
-        THTensor_(copy)(nt, inputs[j]);
+        at::Tensor nt_wrap = THTensor_wrap(nt);
+        at::Tensor inputs_wrap = THTensor_wrap(inputs[j]);
+        at::native::_copy_same_type_(nt_wrap, inputs_wrap);
         c10::raw::intrusive_ptr::decref(nt);
         offset += dimSize;
       }
@@ -2064,7 +2075,11 @@ void THTensor_(renorm)(THTensor *res, THTensor *src, scalar_t value, int dimensi
       )
     }
     else
-      THTensor_(copy)(rowR, rowS);
+    {
+      at::Tensor rowR_wrap = THTensor_wrap(rowR);
+      at::Tensor rowS_wrap = THTensor_wrap(rowS);
+      at::native::_copy_same_type_(rowR_wrap, rowS_wrap);
+    }
   }
 
   c10::raw::intrusive_ptr::decref(rowR);
diff --git a/aten/src/TH/generic/THVector.h b/aten/src/TH/generic/THVector.h
index 32dfec0..bab2b94 100644 (file)
@@ -12,7 +12,6 @@ TH_API void THVector_(cmul)(scalar_t *z, const scalar_t *x, const scalar_t *y, c
 TH_API void THVector_(muls)(scalar_t *y, const scalar_t *x, const scalar_t c, const ptrdiff_t n);
 TH_API void THVector_(cdiv)(scalar_t *z, const scalar_t *x, const scalar_t *y, const ptrdiff_t n);
 TH_API void THVector_(divs)(scalar_t *y, const scalar_t *x, const scalar_t c, const ptrdiff_t n);
-TH_API void THVector_(copy)(scalar_t *y, const scalar_t *x, const ptrdiff_t n);
 TH_API void THVector_(neg)(scalar_t *y, const scalar_t *x, const ptrdiff_t n);
 TH_API void THVector_(normal_fill)(scalar_t *data,
                                                                   const int64_t size,
diff --git a/aten/src/TH/generic/THVectorDispatch.cpp b/aten/src/TH/generic/THVectorDispatch.cpp
index a65a378..e8f8b45 100644 (file)
@@ -177,20 +177,6 @@ void THVector_(divs)(scalar_t *y, const scalar_t *x, const scalar_t c, const ptr
   THVector_(divs_DISPATCHPTR)(y, x, c, n);
 }
 
-static void (*THVector_(copy_DISPATCHPTR))(scalar_t *, const scalar_t *, const ptrdiff_t) = &THVector_(copy_DEFAULT);
-static FunctionDescription THVector_(copy_DISPATCHTABLE)[] = {
-  #if defined(USE_AVX)
-    #if defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT)
-      FUNCTION_IMPL(THVector_(copy_AVX), SIMDExtension_AVX),
-    #endif
-  #endif
-
-  FUNCTION_IMPL(THVector_(copy_DEFAULT), SIMDExtension_DEFAULT)
-};
-void THVector_(copy)(scalar_t *y, const scalar_t *x, const ptrdiff_t n) {
-  THVector_(copy_DISPATCHPTR)(y, x, n);
-}
-
 
 static void (*THVector_(normal_fill_DISPATCHPTR))(scalar_t *, const int64_t, THGenerator *, const scalar_t, const scalar_t) = &THVector_(normal_fill_DEFAULT);
 static FunctionDescription THVector_(normal_fill_DISPATCHTABLE)[] = {
@@ -240,7 +226,6 @@ struct THVector_(startup) {
     INIT_DISPATCH_PTR(muls);
     INIT_DISPATCH_PTR(cdiv);
     INIT_DISPATCH_PTR(divs);
-    INIT_DISPATCH_PTR(copy);
     INIT_DISPATCH_PTR(normal_fill);
 
 #if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE)
diff --git a/aten/src/TH/vector/AVX.cpp b/aten/src/TH/vector/AVX.cpp
index e5af660..96f8d1f 100644 (file)
@@ -8,19 +8,6 @@
 #include "AVX.h"
 #include <TH/THGeneral.h>
 
-void THDoubleVector_copy_AVX(double *y, const double *x, const ptrdiff_t n) {
-  ptrdiff_t i;
-  ptrdiff_t off;
-  for (i=0; i<=((n)-8); i+=8) {
-    _mm256_storeu_pd(y+i, _mm256_loadu_pd(x+i));
-    _mm256_storeu_pd(y+i+4, _mm256_loadu_pd(x+i+4));
-  }
-  off = (n) - ((n)%8);
-  for (i=0; i<((n)%8); i++) {
-    y[off+i] = x[off+i];
-  }
-}
-
 void THDoubleVector_fill_AVX(double *x, const double c, const ptrdiff_t n) {
   ptrdiff_t i;
   ptrdiff_t off;
@@ -140,19 +127,6 @@ void THDoubleVector_adds_AVX(double *y, const double *x, const double c, const p
   }
 }
 
-void THFloatVector_copy_AVX(float *y, const float *x, const ptrdiff_t n) {
-  ptrdiff_t i;
-  ptrdiff_t off;
-  for (i=0; i<=((n)-16); i+=16) {
-    _mm256_storeu_ps(y+i, _mm256_loadu_ps(x+i));
-    _mm256_storeu_ps(y+i+8, _mm256_loadu_ps(x+i+8));
-  }
-  off = (n) - ((n)%16);
-  for (i=0; i<((n)%16); i++) {
-    y[off+i] = x[off+i];
-  }
-}
-
 void THFloatVector_fill_AVX(float *x, const float c, const ptrdiff_t n) {
   ptrdiff_t i;
   ptrdiff_t off;
diff --git a/test/test_torch.py b/test/test_torch.py
index 33e0fde..d3d5d4c 100644 (file)
@@ -9093,7 +9093,7 @@ tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
         self.assertRaises(RuntimeError, lambda: torch.randn(2, 3, 4).t())
         self.assertRaises(RuntimeError, lambda: torch.randn(2, 3, 4).t_())
 
-    # unit test for THTensor_(copyTranspose)
+    # unit test for special case transposed copy (see ATen/native/Copy.cpp for details)
     @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
     def test_big_transpose(self):
         t = torch.rand(456, 789)
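
For reference, the blocked transpose path is taken whenever the destination is contiguous and the source is a 2-D transposed view large enough to clear the 60 * 60 MIN_SZ threshold; an illustrative C++ snippet (same shape as the test above):

  at::Tensor t = at::rand({456, 789});
  at::Tensor c = t.t().contiguous();  // src transposed, dst contiguous -> blocked copy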