From e3840419ecffec315a08e58a48bc710ca2c2e9b4 Mon Sep 17 00:00:00 2001 From: Roy Li Date: Fri, 30 Nov 2018 11:10:25 -0800 Subject: [PATCH] Move cuda copy to aten (#13348) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/13348 Move cross device, cpu to device, device to cpu copies to aten. Most of it is a direct port, main difference is that we dispatch from a single _copy_ function for copies. Reviewed By: ezyang Differential Revision: D12850690 fbshipit-source-id: c2e3f336796b4ae38be6027d2ec131a274a6aa8c --- aten/src/ATen/copy_wrapper.py | 18 +- aten/src/ATen/native/Copy.cpp | 7 +- aten/src/ATen/native/Copy.h | 2 - aten/src/ATen/native/cuda/Copy.cu | 248 +++++++++++++++++++++++++++ aten/src/ATen/native/native_functions.yaml | 11 ++ aten/src/TH/generic/THTensor.cpp | 5 +- aten/src/TH/generic/THTensorEvenMoreMath.cpp | 5 +- aten/src/TH/generic/THTensorLapack.cpp | 12 +- aten/src/TH/generic/THTensorMath.cpp | 15 +- aten/src/TH/generic/THTensorMoreMath.cpp | 13 +- aten/src/THC/THCTensorCopy.cu | 189 -------------------- aten/src/THC/generic/THCStorageCopy.cpp | 4 +- aten/src/THC/generic/THCStorageCopy.cu | 2 +- aten/src/THC/generic/THCTensorCopy.cpp | 122 ------------- aten/src/THC/generic/THCTensorCopy.cu | 34 ++-- aten/src/THC/generic/THCTensorCopy.h | 30 ---- aten/src/THC/generic/THCTensorMasked.cu | 10 +- aten/src/THCUNN/generic/SparseLinear.cu | 12 +- 18 files changed, 312 insertions(+), 427 deletions(-) create mode 100644 aten/src/ATen/native/cuda/Copy.cu diff --git a/aten/src/ATen/copy_wrapper.py b/aten/src/ATen/copy_wrapper.py index 38c18ff..0a9bbb7 100644 --- a/aten/src/ATen/copy_wrapper.py +++ b/aten/src/ATen/copy_wrapper.py @@ -31,14 +31,8 @@ CUDA_INCLUDES = """\ # in both cases, we unconditionally cast both tensors (and rely # on the surrounding code to establish the necessary invariants.) -COPY_CPU = CodeTemplate("""\ -_copy_(dst, src); -""") - COPY = CodeTemplate("""\ -${THTensor}_copy${cuda}${src_scalar_name}(${state,}\ -dst.unsafeGetTensorImpl(), \ -src.unsafeGetTensorImpl()); +_copy_(dst, src); """) COPY_ASYNC_CPU = CodeTemplate("""\ @@ -142,10 +136,7 @@ def create_one_copy(dst_type, all_types): if dst_type['ScalarType'] == src_type['ScalarType']: if dst_type['Backend'] == 'CUDA' and src_type['Backend'] == 'CPU': copies.append(COPY_ASYNC_CPU.substitute(body_env)) - if dst_type['Backend'] == 'CPU' and src_type['Backend'] == 'CPU': - copies.append(COPY_CPU.substitute()) - else: - copies.append(COPY.substitute(body_env)) + copies.append(COPY.substitute()) copy_body.append(CASE.substitute(body_env, copies=copies)) @@ -211,10 +202,7 @@ def create_one_copy_from(src_type, all_types): # function if dst_type['Backend'] == 'CPU' and src_type['Backend'] == 'CUDA': copies.append(COPY_ASYNC_CUDA.substitute(body_env)) - if dst_type['Backend'] == 'CPU' and src_type['Backend'] == 'CPU': - copies.append(COPY_CPU.substitute()) - else: - copies.append(COPY.substitute(body_env)) + copies.append(COPY.substitute()) copy_body.append(CASE.substitute(body_env, copies=copies)) diff --git a/aten/src/ATen/native/Copy.cpp b/aten/src/ATen/native/Copy.cpp index 96d6e5a..336a764 100644 --- a/aten/src/ATen/native/Copy.cpp +++ b/aten/src/ATen/native/Copy.cpp @@ -19,6 +19,7 @@ void _copy__cpu(at::Tensor& self, const at::Tensor& src) { template void _copy__cpu(at::Tensor& self, const at::Tensor& src) { + AT_CHECK(self.numel() == src.numel(), "sizes do not match"); AT_DISPATCH_ALL_TYPES_AND_HALF(src.type(), "_copy__cpu", [&]() { _copy__cpu(self, src); }); @@ -37,6 +38,10 @@ namespace at { namespace native { Tensor& _copy__cpu(Tensor& self, const Tensor& src) { + if (src.is_cuda()) { + _copy_from(src, self); + return self; + } AT_DISPATCH_ALL_TYPES_AND_HALF( self.type(), "_copy__cpu", [&]() { ::_copy__cpu(self, src); }); return self; @@ -95,7 +100,7 @@ void _copy_same_type_transpose_(Tensor& self, const Tensor& src) { }); } -void _copy_same_type_(Tensor& self, const Tensor& src) { +void _copy_same_type__cpu(Tensor& self, const Tensor& src) { if (self.is_same(src)) { return; } diff --git a/aten/src/ATen/native/Copy.h b/aten/src/ATen/native/Copy.h index 6bd27be..6382bce 100644 --- a/aten/src/ATen/native/Copy.h +++ b/aten/src/ATen/native/Copy.h @@ -43,7 +43,5 @@ struct inter_copy_type { template using inter_copy_type_t = typename inter_copy_type::type; -void _copy_same_type_(Tensor& self, const Tensor& src); - } // namespace native } // namespace at diff --git a/aten/src/ATen/native/cuda/Copy.cu b/aten/src/ATen/native/cuda/Copy.cu new file mode 100644 index 0000000..241f107 --- /dev/null +++ b/aten/src/ATen/native/cuda/Copy.cu @@ -0,0 +1,248 @@ +#include "ATen/ATen.h" +#include "ATen/Context.h" +#include "ATen/Dispatch.h" +#include "ATen/NativeFunctions.h" +#include "ATen/cuda/CUDAApplyUtils.cuh" +#include "ATen/cuda/CUDAContext.h" +#include "ATen/cuda/CUDAEvent.h" +#include "ATen/cuda/CUDAStream.h" +#include "ATen/native/Copy.h" + +namespace { + +using namespace at; +using namespace at::cuda; + +// Copy operator for the pointwise apply kernel +template +struct CopyOp { + static void apply(Tensor& dst, const Tensor& src) { + CUDA_tensor_apply2( + dst, src, [] __device__(dst_T & dst_val, const src_T& src_val) { +#if __CUDA_ARCH__ >= 350 + dst_val = static_cast( + static_cast>(__ldg(&src_val))); +#else + dst_val = static_cast(static_cast>(src_val)); +#endif + }); + } +}; + +// device-to-device copy, does type conversion +template +void copy_device_to_device(Tensor& dst, const Tensor& src) { + auto numel = dst.numel(); + if (dst.is_same(src) || numel == 0) { + return; + } + + // We can memcpy the memory if: + // -both tensors are contiguous; or, + // -there is only one element to copy; or, + // -FIXME: if both tensors have matching size and stride arrays, and no + // holes within (in other words, there is some permutation that can be applied + // to the size/strides such that the resulting tensor is + // contiguous). + // -AND: both tensors have the same type. + bool same_type = std::is_same::value; + bool memcpy_eligible = + ((src.is_contiguous() && dst.is_contiguous()) || (numel == 1)) && + same_type; + + Device src_device = src.device(); + Device dst_device = dst.device(); + + CUDAGuard device_guard(src_device); + + // Try to enable p2p access. This also handles the case src_device == + // dst_device. + bool p2pEnabled = THCState_getPeerToPeerAccess( + globalContext().getTHCState(), src_device.index(), dst_device.index()); + + // We always perform the copy on the source device, using the + // current stream on the source device. + // If the copy is on the default stream, then we fully synchronize + // both src and dst's default streams for completion of the + // copy. We have to explicitly do this for non-contig copies. + // This mimics the behavior of cross-device cudaMemcpyAsync on + // the default stream. + // If the copy is not on the default stream, then it is up to the + // user to add needed synchronization on the dst device, since the + // stream on the dst device that wishes to synchronize may not be + // the same index as the one on the src device. + CUDAStream copy_stream = getCurrentCUDAStream(src_device.index()); + if (src_device != dst_device && copy_stream == nullptr) { + // This is a cross-device copy on the default stream. We perform a + // two-way barrier between both devices' default streams before + // the copy. This ensures that any write-after-write and + // write-after-read dependencies on the destination side are + // handled, so that no one is operating on the dst memory when + // we perform the copy. + // src waits on dst barrier (src already waits on src) + CUDAEvent dst_ready; + device_guard.set_device(dst_device); + dst_ready.record(getDefaultCUDAStream(dst_device.index())); + + device_guard.set_device(src_device); + dst_ready.block(copy_stream); + } + + if (memcpy_eligible) { + // Perform the copy + AT_CUDA_CHECK(cudaMemcpyAsync( + dst.data(), + src.data(), + numel * sizeof(dst_T), + cudaMemcpyDeviceToDevice, + copy_stream)); + } else { + // Non-contiguous copy or a type-conversion copy + + // We avoid creating temporary memory copies if possible. + // If both src and dst are on the same device, or if they are on + // different devices and p2p access is enabled, perform the copy + // by a pointwise copy kernel. + // Otherwise, we'll have to make contiguous (which will in fact + // invoke copy() again), and then perform the copy. + // FIXME: might want to consider only running the pointwise kernel + // if both src and dst innermost dimensions are contiguous. If + // they are not, then taking the hit of the memory allocation/free + // might be worth it to avoid non-coalesced reads or writes. + if (p2pEnabled) { + CopyOp::apply(dst, src); + } else { + // GPUs can't access each other directly, but the tensors + // involved are non-contiguous and/or are different types. + + // Make sure the src is contiguous and in the same type as dst + Tensor src_contig; + if (same_type) { + src_contig = src.contiguous(); + } else { + // Types are different + // Copy into the new format, contiguous, on the source device + src_contig = at::empty_like(dst, src.options().dtype(dst.dtype())); + + CopyOp::apply(src_contig, src); + } + + // Make sure the dst is contiguous + device_guard.set_device(dst_device); + Tensor dst_contig = dst.contiguous(); + + // Now, we are ready for a cross-device memcpy of contiguous + // data, of the same layout and type + device_guard.set_device(src_device); + + AT_CUDA_CHECK(cudaMemcpyAsync( + dst_contig.data(), + src_contig.data(), + numel * sizeof(dst_T), + cudaMemcpyDeviceToDevice, + copy_stream)); + + if (!dst.is_contiguous()) { + copy_device_to_device(dst, dst_contig); + } + } + } + + if (src_device != dst_device && copy_stream == nullptr) { + // dst waits on src barrier (dst already waits on dst). We cannot + // operate on dst's copy until the copy is complete. + + // Still on src_device, record default stream event + CUDAEvent src_ready; + src_ready.record(copy_stream); + + device_guard.set_device(dst_device); + src_ready.block(getDefaultCUDAStream(dst_device.index())); + } + + AT_CUDA_CHECK(cudaGetLastError()); +} + +void copy_from_cpu(Tensor& dst, const Tensor& src) { + Tensor dst_contig = dst.contiguous(); + Tensor src_contig = src.contiguous(); + + CUDAStream stream = getCurrentCUDAStream(); + + AT_CUDA_CHECK(cudaMemcpyAsync( + dst_contig.data_ptr(), + src_contig.data_ptr(), + src.numel() * src.dtype().itemsize(), + cudaMemcpyHostToDevice, + stream)); + AT_CUDA_CHECK(cudaStreamSynchronize(stream)); + AT_DISPATCH_ALL_TYPES_AND_HALF(src.type(), "copy_from_cpu", [&]() { + copy_device_to_device(dst, dst_contig); + }); +} + +void copy_to_cpu(Tensor& dst, const Tensor& src) { + Tensor dst_contig = dst.contiguous(); + Tensor src_contig = src.contiguous(); + + CUDAGuard device_guard(src.device()); + CUDAStream stream = getCurrentCUDAStream(); + + AT_CUDA_CHECK(cudaMemcpyAsync( + dst_contig.data_ptr(), + src_contig.data_ptr(), + src.numel() * src.dtype().itemsize(), + cudaMemcpyDeviceToHost, + stream)); + AT_CUDA_CHECK(cudaStreamSynchronize(stream)); + _copy_same_type_(dst, dst_contig); +} + +template +void _copy__cuda(Tensor& dst, const Tensor& src) { + AT_CHECK(dst.numel() == src.numel(), "sizes do not match"); + AT_DISPATCH_ALL_TYPES_AND_HALF(src.type(), "_copy__cuda", [&]() { + if (dst.is_cuda() && src.is_cuda()) { + copy_device_to_device(dst, src); + } else if (dst.is_cuda()) { + if (std::is_same::value) { + copy_from_cpu(dst, src); + } else { + // Do a dtype converting copy on the CPU, then copy to device + Tensor srcf = at::empty_like(src, src.options().dtype(dst.dtype())); + _copy_(srcf, src); + copy_from_cpu(dst, srcf); + } + } else { + if (std::is_same::value) { + copy_to_cpu(dst, src); + } else { + // Copy to CPU as the same dtype, then do a dtype converting copy + Tensor srcf = at::empty_like(src, dst.options().dtype(src.dtype())); + copy_to_cpu(srcf, src); + _copy_(dst, srcf); + } + } + }); +} + +} // namespace + +namespace at { +namespace native { + +Tensor& _copy__cuda(Tensor& self, const Tensor& src) { + AT_DISPATCH_ALL_TYPES_AND_HALF(self.type(), "_copy__cuda", [&]() { + ::_copy__cuda(self, src); + }); + return self; +} + +Tensor _copy_from_cuda(const Tensor& self, const Tensor& dst) { + Tensor dst_ = dst; + _copy__cuda(dst_, self); + return dst; +} + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index f2bf9ca..656962b 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -443,6 +443,17 @@ cpu_half: True dispatch: CPU: _copy__cpu + CUDA: _copy__cuda + +- func: _copy_from(Tensor self, Tensor dst) -> Tensor + cpu_half: True + dispatch: + CUDA: _copy_from_cuda + +- func: _copy_same_type_(Tensor self, Tensor src) -> void + cpu_half: True + dispatch: + CPU: _copy_same_type__cpu - func: cos(Tensor self) -> Tensor variants: function, method diff --git a/aten/src/TH/generic/THTensor.cpp b/aten/src/TH/generic/THTensor.cpp index 498fa75..d853464 100644 --- a/aten/src/TH/generic/THTensor.cpp +++ b/aten/src/TH/generic/THTensor.cpp @@ -3,7 +3,6 @@ #else #include -#include #include /**** access methods ****/ @@ -158,7 +157,7 @@ THTensor *THTensor_(newClone)(THTensor *self) THTensor_(resizeAs)(tensor, self); at::Tensor tensor_wrap = THTensor_wrap(tensor); at::Tensor self_wrap = THTensor_wrap(self); - at::native::_copy_same_type_(tensor_wrap, self_wrap); + at::_copy_same_type_(tensor_wrap, self_wrap); return tensor; } @@ -583,7 +582,7 @@ void THTensor_(freeCopyTo)(THTensor *self, THTensor *dst) if(self != dst) { at::Tensor dst_wrap = THTensor_wrap(dst); at::Tensor self_wrap = THTensor_wrap(self); - at::native::_copy_same_type_(dst_wrap, self_wrap); + at::_copy_same_type_(dst_wrap, self_wrap); } THTensor_(free)(self); diff --git a/aten/src/TH/generic/THTensorEvenMoreMath.cpp b/aten/src/TH/generic/THTensorEvenMoreMath.cpp index bf33a2b..cc7b2cc 100644 --- a/aten/src/TH/generic/THTensorEvenMoreMath.cpp +++ b/aten/src/TH/generic/THTensorEvenMoreMath.cpp @@ -3,7 +3,6 @@ #else #include -#include void THTensor_(fill)(THTensor *r_, scalar_t value) { @@ -222,7 +221,7 @@ void THTensor_(indexSelect)(THTensor *tensor, THTensor *src, int dim, THLongTens THTensor_(select)(sSlice, src, dim, index_data[i] - TH_INDEX_BASE); at::Tensor tSlice_wrap = THTensor_wrap(tSlice); at::Tensor sSlice_wrap = THTensor_wrap(sSlice); - at::native::_copy_same_type_(tSlice_wrap, sSlice_wrap); + at::_copy_same_type_(tSlice_wrap, sSlice_wrap); c10::raw::intrusive_ptr::decref(tSlice); c10::raw::intrusive_ptr::decref(sSlice); } @@ -255,7 +254,7 @@ void THTensor_(indexCopy)(THTensor *tensor, int dim, THLongTensor *index, THTens THTensor_(select)(sSlice, src, dim, i); at::Tensor tSlice_wrap = THTensor_wrap(tSlice); at::Tensor sSlice_wrap = THTensor_wrap(sSlice); - at::native::_copy_same_type_(tSlice_wrap, sSlice_wrap); + at::_copy_same_type_(tSlice_wrap, sSlice_wrap); } c10::raw::intrusive_ptr::decref(tSlice); diff --git a/aten/src/TH/generic/THTensorLapack.cpp b/aten/src/TH/generic/THTensorLapack.cpp index 302ced6..cc538b5 100644 --- a/aten/src/TH/generic/THTensorLapack.cpp +++ b/aten/src/TH/generic/THTensorLapack.cpp @@ -2,8 +2,6 @@ #define TH_GENERIC_FILE "generic/THTensorLapack.cpp" #else -#include - /* Check if self is transpose of a contiguous matrix */ @@ -85,14 +83,14 @@ static THTensor *THTensor_(cloneColumnMajorNrows)(THTensor *self, THTensor *src, if (src->size(0) == nrows) { at::Tensor result_wrap = THTensor_wrap(result); at::Tensor src_wrap = THTensor_wrap(src); - at::native::_copy_same_type_(result_wrap, src_wrap); + at::_copy_same_type_(result_wrap, src_wrap); } else { view = THTensor_(newNarrow)(result, 0, 0, src->size(0)); at::Tensor view_wrap = THTensor_wrap(view); at::Tensor src_wrap = THTensor_wrap(src); - at::native::_copy_same_type_(view_wrap, src_wrap); + at::_copy_same_type_(view_wrap, src_wrap); c10::raw::intrusive_ptr::decref(view); } return result; @@ -538,7 +536,7 @@ void THTensor_(gesdd2)(THTensor *ru_, THTensor *rs_, THTensor *rv_, THTensor *ra THTensor_(resizeAs)(rv_, rvf_); at::Tensor rv__wrap = THTensor_wrap(rv_); at::Tensor rvf__wrap = THTensor_wrap(rvf_); - at::native::_copy_same_type_(rv__wrap, rvf__wrap); + at::_copy_same_type_(rv__wrap, rvf__wrap); c10::raw::intrusive_ptr::decref(rvf_); } else { THTensor_(zero)(ru_); @@ -1018,7 +1016,7 @@ void THTensor_(btrifact)(THTensor *ra_, THIntTensor *rpivots_, THIntTensor *rinf THTensor_(resizeAs)(ra_, a); at::Tensor ra__wrap = THTensor_wrap(ra_); at::Tensor a_wrap = THTensor_wrap(a); - at::native::_copy_same_type_(ra__wrap, a_wrap); + at::_copy_same_type_(ra__wrap, a_wrap); } int m = a->size(1); @@ -1101,7 +1099,7 @@ void THTensor_(btrisolve)(THTensor *rb_, THTensor *b, THTensor *atf, THIntTensor THTensor_(resizeAs)(rb_, b); at::Tensor rb__wrap = THTensor_wrap(rb_); at::Tensor b_wrap = THTensor_wrap(b); - at::native::_copy_same_type_(rb__wrap, b_wrap); + at::_copy_same_type_(rb__wrap, b_wrap); } int64_t num_batches = atf->size(0); diff --git a/aten/src/TH/generic/THTensorMath.cpp b/aten/src/TH/generic/THTensorMath.cpp index bb6e9da..8348c3a 100644 --- a/aten/src/TH/generic/THTensorMath.cpp +++ b/aten/src/TH/generic/THTensorMath.cpp @@ -3,7 +3,6 @@ #else #include -#include // HEY YOU! // @@ -217,7 +216,7 @@ void THTensor_(pow)(THTensor *r_, THTensor *t, scalar_t value) if(value == 1) { at::Tensor r__wrap = THTensor_wrap(r_); at::Tensor t_wrap = THTensor_wrap(t); - at::native::_copy_same_type_(r__wrap, t_wrap); + at::_copy_same_type_(r__wrap, t_wrap); } else if(value == 2){ THTensor_(cmul)(r_, t, t); @@ -741,7 +740,7 @@ void THTensor_(addcmul)(THTensor *r_, THTensor *t, scalar_t value, THTensor *src THTensor_(resizeAs)(r_, t); at::Tensor r__wrap = THTensor_wrap(r_); at::Tensor t_wrap = THTensor_wrap(t); - at::native::_copy_same_type_(r__wrap, t_wrap); + at::_copy_same_type_(r__wrap, t_wrap); } int64_t r_Size = THTensor_(nElement)(r_); int64_t src1Size = THTensor_(nElement)(src1); @@ -779,7 +778,7 @@ void THTensor_(addcdiv)(THTensor *r_, THTensor *t, scalar_t value, THTensor *src THTensor_(resizeAs)(r_, t); at::Tensor r__wrap = THTensor_wrap(r_); at::Tensor t_wrap = THTensor_wrap(t); - at::native::_copy_same_type_(r__wrap, t_wrap); + at::_copy_same_type_(r__wrap, t_wrap); } int64_t r_Size = THTensor_(nElement)(r_); int64_t src1Size = THTensor_(nElement)(src1); @@ -836,7 +835,7 @@ void THTensor_(addmv)(THTensor *r_, scalar_t beta, THTensor *t, scalar_t alpha, THTensor_(resizeAs)(r_, t); at::Tensor r__wrap = THTensor_wrap(r_); at::Tensor t_wrap = THTensor_wrap(t); - at::native::_copy_same_type_(r__wrap, t_wrap); + at::_copy_same_type_(r__wrap, t_wrap); } auto r_stride = THTensor_strideLegacyNoScalars(r_, 0); @@ -957,7 +956,7 @@ void THTensor_(addmm)(THTensor *r_, scalar_t beta, THTensor *t, scalar_t alpha, if (beta != 0.0) { at::Tensor r__wrap = THTensor_wrap(r_); at::Tensor t_wrap = THTensor_wrap(t); - at::native::_copy_same_type_(r__wrap, t_wrap); + at::_copy_same_type_(r__wrap, t_wrap); } } @@ -1095,7 +1094,7 @@ void THTensor_(addr)(THTensor *r_, scalar_t beta, THTensor *t, scalar_t alpha, T THTensor_(resizeAs)(r_, t); at::Tensor r__wrap = THTensor_wrap(r_); at::Tensor t_wrap = THTensor_wrap(t); - at::native::_copy_same_type_(r__wrap, t_wrap); + at::_copy_same_type_(r__wrap, t_wrap); } if(beta == 0) { @@ -1160,7 +1159,7 @@ void THTensor_(addbmm)(THTensor *result, scalar_t beta, THTensor *t, scalar_t al if (beta != 0.0) { at::Tensor result_wrap = THTensor_wrap(result); at::Tensor t_wrap = THTensor_wrap(t); - at::native::_copy_same_type_(result_wrap, t_wrap); + at::_copy_same_type_(result_wrap, t_wrap); } } diff --git a/aten/src/TH/generic/THTensorMoreMath.cpp b/aten/src/TH/generic/THTensorMoreMath.cpp index f728a27..bbbd4ad 100644 --- a/aten/src/TH/generic/THTensorMoreMath.cpp +++ b/aten/src/TH/generic/THTensorMoreMath.cpp @@ -4,7 +4,6 @@ #include #include -#include void THTensor_(baddbmm)(THTensor *result, scalar_t beta, THTensor *t, scalar_t alpha, THTensor *batch1, THTensor *batch2) { @@ -32,7 +31,7 @@ void THTensor_(baddbmm)(THTensor *result, scalar_t beta, THTensor *t, scalar_t a if (beta != 0.0) { at::Tensor result_wrap = THTensor_wrap(result); at::Tensor t_wrap = THTensor_wrap(t); - at::native::_copy_same_type_(result_wrap, t_wrap); + at::_copy_same_type_(result_wrap, t_wrap); } } @@ -117,7 +116,7 @@ void THTensor_(max)(THTensor *values_, THLongTensor *indices_, THTensor *t, int THTensor *t0 = THTensor_(newSelect)(t, dimension, 0); at::Tensor values__wrap = THTensor_wrap(values_); at::Tensor t0_wrap = THTensor_wrap(t0); - at::native::_copy_same_type_(values__wrap, t0_wrap); + at::_copy_same_type_(values__wrap, t0_wrap); c10::raw::intrusive_ptr::decref(t0); } else { THTensor_(fill)(values_, THTensor_(get1d)(t, 0)); @@ -200,7 +199,7 @@ void THTensor_(min)(THTensor *values_, THLongTensor *indices_, THTensor *t, int THTensor *t0 = THTensor_(newSelect)(t, dimension, 0); at::Tensor values__wrap = THTensor_wrap(values_); at::Tensor t0_wrap = THTensor_wrap(t0); - at::native::_copy_same_type_(values__wrap, t0_wrap); + at::_copy_same_type_(values__wrap, t0_wrap); c10::raw::intrusive_ptr::decref(t0); } else { THTensor_(fill)(values_, THTensor_(get1d)(t, 0)); @@ -906,7 +905,7 @@ void THTensor_(sort)(THTensor *rt_, THLongTensor *ri_, THTensor *t, int dimensio THTensor_(resizeAs)(rt_, t); at::Tensor rt__wrap = THTensor_wrap(rt_); at::Tensor t_wrap = THTensor_wrap(t); - at::native::_copy_same_type_(rt__wrap, t_wrap); + at::_copy_same_type_(rt__wrap, t_wrap); THLongTensor_resize(ri_, t->sizes(), {}); if(descendingOrder) @@ -1422,7 +1421,7 @@ void THTensor_(catArray)(THTensor *result, THTensor **inputs, int numInputs, int THTensor_(narrow)(nt, NULL, dimension, offset, dimSize); at::Tensor nt__wrap = THTensor_wrap(nt); at::Tensor inputs_wrap = THTensor_wrap(inputs[j]); - at::native::_copy_same_type_(nt__wrap, inputs_wrap); + at::_copy_same_type_(nt__wrap, inputs_wrap); c10::raw::intrusive_ptr::decref(nt); offset += dimSize; } @@ -2078,7 +2077,7 @@ void THTensor_(renorm)(THTensor *res, THTensor *src, scalar_t value, int dimensi { at::Tensor rowR_wrap = THTensor_wrap(rowR); at::Tensor rowS_wrap = THTensor_wrap(rowS); - at::native::_copy_same_type_(rowR_wrap, rowS_wrap); + at::_copy_same_type_(rowR_wrap, rowS_wrap); } } diff --git a/aten/src/THC/THCTensorCopy.cu b/aten/src/THC/THCTensorCopy.cu index 844539c..a0757c1 100644 --- a/aten/src/THC/THCTensorCopy.cu +++ b/aten/src/THC/THCTensorCopy.cu @@ -4,12 +4,6 @@ #include "THCTensorCopy.hpp" #include -inline int curGPU() { - int curDev; - THCudaCheck(cudaGetDevice(&curDev)); - return curDev; -} - // Copy operator for the pointwise apply kernel template struct CopyOp { @@ -22,188 +16,5 @@ struct CopyOp { } }; -// Copy for the same type to the same type -template -void THC_copyTensor(THCState* state, THCTensor* dst, THCTensor* src) { - - ptrdiff_t totalElements = THCTensor_nElement(state, dst); - - THArgCheck(totalElements == - THCTensor_nElement(state, src), - 2, "sizes do not match"); - - if (THCTensor_nDimensionLegacyAll(state, dst) == 0) { - // Zero-dim tensor; copy nothing - return; - } - - // We can memcpy the memory if: - // -both tensors are contiguous; or, - // -there is only one element to copy; or, - // -FIXME: if both tensors have matching size and stride arrays, and no - // holes within (in other words, there is some permutation that can be applied - // to the size/strides such that the resulting tensor is - // contiguous). - // -AND: both tensors have the same type. - bool sameType = std::is_same::value; - bool srcContig = src->is_contiguous(); - bool dstContig = dst->is_contiguous(); - bool memcpyEligible = - ((srcContig && dstContig) || (totalElements == 1)) && sameType; - - int srcDev = THCTensor_getDevice(state, src); - int dstDev = THCTensor_getDevice(state, dst); - int oldDev = curGPU(); - - // Try to enable p2p access. This also handles the case srcDev == dstDev. - bool p2pEnabled = THCState_getPeerToPeerAccess(state, srcDev, dstDev); - - // We always perform the copy on the source device, using the - // current stream on the source device. - // If the copy is on the default stream, then we fully synchronize - // both src and dst's default streams for completion of the - // copy. We have to explicitly do this for non-contig copies. - // This mimics the behavior of cross-device cudaMemcpyAsync on - // the default stream. - // If the copy is not on the default stream, then it is up to the - // user to add needed synchronization on the dst device, since the - // stream on the dst device that wishes to synchronize may not be - // the same index as the one on the src device. - cudaStream_t copyStream = THCState_getCurrentStreamOnDevice(state, srcDev); - if (srcDev != dstDev && copyStream == NULL) { - // This is a cross-device copy on the default stream. We perform a - // two-way barrier between both devices' default streams before - // the copy. This ensures that any write-after-write and - // write-after-read dependencies on the destination side are - // handled, so that no one is operating on the dst memory when - // we perform the copy. - // src waits on dst barrier (src already waits on src) - cudaEvent_t dstReady; - THCudaCheck(cudaSetDevice(dstDev)); - THCudaCheck(cudaEventCreateWithFlags(&dstReady, cudaEventDisableTiming)); - THCudaCheck(cudaEventRecord(dstReady, NULL)); - - THCudaCheck(cudaSetDevice(srcDev)); - THCudaCheck(cudaStreamWaitEvent(NULL, dstReady, 0)); - THCudaCheck(cudaEventDestroy(dstReady)); - } else if (srcDev != oldDev) { - THCudaCheck(cudaSetDevice(srcDev)); - } - - // We are now on srcDev - if (memcpyEligible) { - // Perform the copy - THCudaCheck(cudaMemcpyAsync( - dst->template data(), - src->template data(), - totalElements * - sizeof(ScalarTypeDst), - cudaMemcpyDeviceToDevice, - copyStream)); - } else { - // Non-contiguous copy or a type-conversion copy - - // We avoid creating temporary memory copies if possible. - // If both src and dst are on the same device, or if they are on - // different devices and p2p access is enabled, perform the copy - // by a pointwise copy kernel. - // Otherwise, we'll have to make contiguous (which will in fact - // invoke copy() again), and then perform the copy. - // FIXME: might want to consider only running the pointwise kernel - // if both src and dst innermost dimensions are contiguous. If - // they are not, then taking the hit of the memory allocation/free - // might be worth it to avoid non-coalesced reads or writes. - if (p2pEnabled) { - bool succ = - THC_pointwiseApply2( - state, dst, src, - CopyOp()); - - THArgCheck(succ, 2, CUTORCH_DIM_WARNING); - } else { - // GPUs can't access each other directly, but the tensors - // involved are non-contiguous and/or are different types. - - // Make sure the src is contiguous and in the same type as dst - THCudaCheck(cudaSetDevice(srcDev)); - THCTensor* srcContig = NULL; - - if (sameType) { - srcContig = THCTensor_newContiguous(state, src); - - } else { - // Types are different - // Copy into the new format, contiguous, on the source device - srcContig = THCTensor_new(state, caffe2::TypeMeta::Make()); - THCTensor_resizeAs(state, srcContig, dst); - - bool succ = - THC_pointwiseApply2( - state, srcContig, src, - CopyOp()); - - THArgCheck(succ, 2, CUTORCH_DIM_WARNING); - } - - // Make sure the dst is contiguous - THCudaCheck(cudaSetDevice(dstDev)); - THCTensor* dstContig = THCTensor_newContiguous(state, dst); - - // Now, we are ready for a cross-device memcpy of contiguous - // data, of the same layout and type - THCudaCheck(cudaSetDevice(srcDev)); - - THCudaCheck(cudaMemcpyAsync( - dstContig->template data(), - srcContig->template data(), - totalElements * - sizeof(ScalarTypeDst), - cudaMemcpyDeviceToDevice, - copyStream)); - - // We are done with the src - THCTensor_free(state, srcContig); - - if (dst != dstContig) { - THCTensor_freeCopyTo(state, dstContig, dst); - } else { - THCTensor_free(state, dstContig); - } - - // We're still on srcDev at this point - } - } - - if (srcDev != dstDev && copyStream == NULL) { - // dst waits on src barrier (dst already waits on dst). We cannot - // operate on dst's copy until the copy is complete. - - // Still on srcDev, record default stream event - cudaEvent_t srcReady; - THCudaCheck(cudaEventCreateWithFlags(&srcReady, cudaEventDisableTiming)); - THCudaCheck(cudaEventRecord(srcReady, NULL)); - - THCudaCheck(cudaSetDevice(dstDev)); - THCudaCheck(cudaStreamWaitEvent(NULL, srcReady, 0)); - THCudaCheck(cudaEventDestroy(srcReady)); - - // We are now on dstDev (right above). Restore prior device from dst - if (dstDev != oldDev) { - THCudaCheck(cudaSetDevice(oldDev)); - } - } else { - // We are still on srcDev. Restore prior device from src - if (srcDev != oldDev) { - THCudaCheck(cudaSetDevice(oldDev)); - } - } - - THCudaCheck(cudaGetLastError()); -} - #include "generic/THCTensorCopy.cu" #include "THCGenerateAllTypes.h" diff --git a/aten/src/THC/generic/THCStorageCopy.cpp b/aten/src/THC/generic/THCStorageCopy.cpp index 397cb7c..a8f236a 100644 --- a/aten/src/THC/generic/THCStorageCopy.cpp +++ b/aten/src/THC/generic/THCStorageCopy.cpp @@ -21,7 +21,7 @@ void THCStorage_(copy##TYPEC)(THCState *state, THCStorage *self, struct TH##TYPE THCTensor_(newWithStorage1d)(state, self, 0, self->numel(), 1); \ struct TH##TYPEC##Tensor* srcTensor = \ TH##TYPEC##Tensor_newWithStorage1d(src, 0, src->numel(), 1); \ - THCTensor_(copy##TYPEC)(state, selfTensor, srcTensor); \ + THCTensor_(copy)(state, selfTensor, srcTensor); \ TH##TYPEC##Tensor_free(srcTensor); \ THCTensor_(free)(state, selfTensor); \ } @@ -53,7 +53,7 @@ void TH_CONCAT_4(TH,TYPEC,Storage_copyCuda,Real)(THCState *state, TH##TYPEC##Sto TH##TYPEC##Tensor_newWithStorage1d(self, 0, self->numel(), 1); \ struct THCTensor* srcTensor = \ THCTensor_(newWithStorage1d)(state, src, 0, src->numel(), 1); \ - TH_CONCAT_4(TH,TYPEC,Tensor_copyCuda,Real)(state, selfTensor, srcTensor); \ + THCTensor_(copy)(state, selfTensor, srcTensor); \ THCTensor_(free)(state, srcTensor); \ TH##TYPEC##Tensor_free(selfTensor); \ } diff --git a/aten/src/THC/generic/THCStorageCopy.cu b/aten/src/THC/generic/THCStorageCopy.cu index 74ff1e1..ba5a477 100644 --- a/aten/src/THC/generic/THCStorageCopy.cu +++ b/aten/src/THC/generic/THCStorageCopy.cu @@ -15,7 +15,7 @@ void THCStorage_(copyCuda##TYPEC)(THCState *state, THCStorage *self, struct THCu THCTensor* selfTensor = THCTensor_(newWithStorage1d)(state, self, 0, self->numel(), 1); \ struct THCuda##TYPECUDA##Tensor* srcTensor = \ THCuda##TYPECUDA##Tensor_newWithStorage1d(state, src, 0, src->numel(), 1); \ - THCTensor_(copyCuda##TYPEC)(state, selfTensor, srcTensor); \ + THCTensor_(copy)(state, selfTensor, srcTensor); \ THCuda##TYPECUDA##Tensor_free(state, srcTensor); \ THCTensor_(free)(state, selfTensor); \ } diff --git a/aten/src/THC/generic/THCTensorCopy.cpp b/aten/src/THC/generic/THCTensorCopy.cpp index 0d3e867..2ad8769 100644 --- a/aten/src/THC/generic/THCTensorCopy.cpp +++ b/aten/src/THC/generic/THCTensorCopy.cpp @@ -2,128 +2,6 @@ #define THC_GENERIC_FILE "generic/THCTensorCopy.cpp" #else -/* specific methods */ - -void THCTensor_(copyCPU)(THCState *state, THCTensor *self, struct THTensor *src) -{ - THArgCheck(THCTensor_(nElement)(state, self) == THTensor_(nElement)(src), 2, "sizes do not match"); - - { - THCTensor *selfc = THCTensor_(newContiguous)(state, self); - src = THTensor_(newContiguous)(src); - - cudaStream_t stream = THCState_getCurrentStream(state); - THCudaCheck(cudaMemcpyAsync(THCTensor_(data)(state,selfc), - src->data(), - THTensor_(nElement)(src) * sizeof(scalar_t), - cudaMemcpyHostToDevice, - stream)); - THCudaCheck(cudaStreamSynchronize(stream)); - - c10::raw::intrusive_ptr::decref(src); - THCTensor_(freeCopyTo)(state, selfc, self); - } -} - -#define IMPLEMENT_TH_CUDA_TENSOR_COPY(TYPEC) \ - void THCTensor_(copy##TYPEC)( \ - THCState * state, THCTensor * self, struct TH##TYPEC##Tensor * src) { \ - THArgCheck( \ - THCTensor_(nElement)(state, self) == TH##TYPEC##Tensor_nElement(src), \ - 2, \ - "sizes do not match"); \ - if (THCTypeIdx_(Real) == THCTypeIdx_(TYPEC)) { \ - THCTensor_(copyCPU)( \ - state, self, (THTensor*)src); /* cast just removes warnings */ \ - } else { \ - at::Tensor srcf_wrap = \ - at::empty(src->sizes(), caffe2::TypeMeta::Make()); \ - at::Tensor src_wrap = THTensor_wrap(src); \ - \ - at::_copy_(srcf_wrap, src_wrap); \ - THCTensor_(copyCPU)(state, self, srcf_wrap.unsafeGetTensorImpl()); \ - } \ - } - -IMPLEMENT_TH_CUDA_TENSOR_COPY(Byte) -IMPLEMENT_TH_CUDA_TENSOR_COPY(Char) -IMPLEMENT_TH_CUDA_TENSOR_COPY(Short) -IMPLEMENT_TH_CUDA_TENSOR_COPY(Int) -IMPLEMENT_TH_CUDA_TENSOR_COPY(Long) -IMPLEMENT_TH_CUDA_TENSOR_COPY(Float) -IMPLEMENT_TH_CUDA_TENSOR_COPY(Double) -IMPLEMENT_TH_CUDA_TENSOR_COPY(Half) - -/* copyCuda */ - -void THTensor_(copyCuda)(THCState *state, THTensor *self, struct THCTensor *src) -{ - THArgCheck(THTensor_(nElement)(self) == THCTensor_(nElement)(state, src), 2, "sizes do not match"); - - { - THTensor *selfc = THTensor_(newContiguous)(self); - int tensorDevice = THCTensor_(getDevice)(state, src); - int currentDevice; - THCudaCheck(cudaGetDevice(¤tDevice)); - - if (currentDevice != tensorDevice) { - THCudaCheck(cudaSetDevice(tensorDevice)); - } - src = THCTensor_(newContiguous)(state, src); - - cudaStream_t stream = THCState_getCurrentStream(state); - THCudaCheck(cudaMemcpyAsync(selfc->data(), - THCTensor_(data)(state, src), - THCTensor_(nElement)(state, src) * sizeof(scalar_t), - cudaMemcpyDeviceToHost, - stream)); - THCudaCheck(cudaStreamSynchronize(stream)); - - if (currentDevice != tensorDevice) { - THCudaCheck(cudaSetDevice(currentDevice)); - } - - THCTensor_(free)(state, src); - THTensor_(freeCopyTo)(selfc, self); - } -} - -#define IMPLEMENT_TH_CUDA_TENSOR_COPY_TO(TYPEC) \ - void TH_CONCAT_4(TH, TYPEC, Tensor_copyCuda, Real)( \ - THCState * state, TH##TYPEC##Tensor * self, struct THCTensor * src) { \ - THArgCheck( \ - TH##TYPEC##Tensor_nElement(self) == THCTensor_(nElement)(state, src), \ - 2, \ - "sizes do not match"); \ - if (THCTypeIdx_(Real) == THCTypeIdx_(TYPEC)) { \ - THTensor_(copyCuda)( \ - state, \ - (THTensor*)self, \ - src); /* cast just removes compiler warning */ \ - } else { \ - at::Tensor srcf_wrap = \ - at::empty(src->sizes(), caffe2::TypeMeta::Make()); \ - at::Tensor self_wrap = THTensor_wrap(self); \ - \ - THTensor_(copyCuda)(state, srcf_wrap.unsafeGetTensorImpl(), src); \ - at::_copy_(self_wrap, srcf_wrap); \ - } \ - } - -IMPLEMENT_TH_CUDA_TENSOR_COPY_TO(Byte) -IMPLEMENT_TH_CUDA_TENSOR_COPY_TO(Char) -IMPLEMENT_TH_CUDA_TENSOR_COPY_TO(Short) -IMPLEMENT_TH_CUDA_TENSOR_COPY_TO(Int) -IMPLEMENT_TH_CUDA_TENSOR_COPY_TO(Long) -IMPLEMENT_TH_CUDA_TENSOR_COPY_TO(Float) -IMPLEMENT_TH_CUDA_TENSOR_COPY_TO(Double) -IMPLEMENT_TH_CUDA_TENSOR_COPY_TO(Half) - -void THCTensor_(copyCuda)(THCState *state, THCTensor *self, THCTensor *src) -{ - THCTensor_(copy)(state, self, src); -} - void THCTensor_(copyAsyncCPU)(THCState *state, THCTensor *self, struct THTensor *src) { THArgCheck(THCTensor_(nElement)(state, self) == THTensor_(nElement)(src), 2, "sizes do not match"); diff --git a/aten/src/THC/generic/THCTensorCopy.cu b/aten/src/THC/generic/THCTensorCopy.cu index 2c05c74..f5d80d2 100644 --- a/aten/src/THC/generic/THCTensorCopy.cu +++ b/aten/src/THC/generic/THCTensorCopy.cu @@ -4,7 +4,9 @@ void THCTensor_(copy)(THCState* state, THCTensor* dst, THCTensor* src) { if (dst == src) return; - THC_copyTensor(state, dst, src); + at::Tensor dst_wrap = THTensor_wrap(dst); + at::Tensor src_wrap = THTensor_wrap(src); + at::_copy_(dst_wrap, src_wrap); } template <> @@ -12,7 +14,9 @@ THCTensor *THCTensor_newClone(THCState *state, THCTensor *self) { THCTensor* tensor = THCTensor_new(state, THTensor_getStoragePtr(self)->dtype()); THCTensor_resizeAs(state, tensor, self); - THC_copyTensor(state, tensor, self); + at::Tensor tensor_wrap = THTensor_wrap(tensor); + at::Tensor self_wrap = THTensor_wrap(self); + at::_copy_(tensor_wrap, self_wrap); return tensor; } @@ -30,8 +34,11 @@ THCTensor *THCTensor_newContiguous(THCState *state, THCTensor *self) template <> void THCTensor_freeCopyTo(THCState *state, THCTensor *self, THCTensor *dst) { - if(self != dst) - THC_copyTensor(state, dst, self); + if(self != dst) { + at::Tensor dst_wrap = THTensor_wrap(dst); + at::Tensor self_wrap = THTensor_wrap(self); + at::_copy_(dst_wrap, self_wrap); + } THCTensor_free(state, self); } @@ -54,23 +61,4 @@ void THCTensor_(copyIgnoringOverlaps)(THCState* state, THCTensor* dst, THCTensor THCTensor_copyIgnoringOverlaps(state, dst, src); } -#define IMPLEMENT_THC_CUDA_TENSOR_COPY(TYPEC, TYPECUDA, SCALARC) \ - void THCTensor_(copyCuda##TYPEC)(THCState *state, \ - THCTensor *self, \ - THCuda##TYPECUDA##Tensor *src) { \ - THC_copyTensor(state, self, src); \ - } - -IMPLEMENT_THC_CUDA_TENSOR_COPY(Byte, Byte, uint8_t) -IMPLEMENT_THC_CUDA_TENSOR_COPY(Char, Char, int8_t) -IMPLEMENT_THC_CUDA_TENSOR_COPY(Short, Short, int16_t) -IMPLEMENT_THC_CUDA_TENSOR_COPY(Int, Int, int32_t) -IMPLEMENT_THC_CUDA_TENSOR_COPY(Long, Long, int64_t) -// THCudaTensor aka the non-existent THCudaFloatTensor -IMPLEMENT_THC_CUDA_TENSOR_COPY(Float, , float) -IMPLEMENT_THC_CUDA_TENSOR_COPY(Double, Double, double) -IMPLEMENT_THC_CUDA_TENSOR_COPY(Half, Half, at::Half) - -#undef IMPLEMENT_THC_CUDA_TENSOR_COPY - #endif diff --git a/aten/src/THC/generic/THCTensorCopy.h b/aten/src/THC/generic/THCTensorCopy.h index 3004359..e793322 100644 --- a/aten/src/THC/generic/THCTensorCopy.h +++ b/aten/src/THC/generic/THCTensorCopy.h @@ -4,36 +4,6 @@ THC_API void THCTensor_(copy)(THCState *state, THCTensor *self, THCTensor *src); THC_API void THCTensor_(copyIgnoringOverlaps)(THCState *state, THCTensor *self, THCTensor *src); -THC_API void THCTensor_(copyByte)(THCState *state, THCTensor *self, THByteTensor *src); -THC_API void THCTensor_(copyChar)(THCState *state, THCTensor *self, THCharTensor *src); -THC_API void THCTensor_(copyShort)(THCState *state, THCTensor *self, THShortTensor *src); -THC_API void THCTensor_(copyInt)(THCState *state, THCTensor *self, THIntTensor *src); -THC_API void THCTensor_(copyLong)(THCState *state, THCTensor *self, THLongTensor *src); -THC_API void THCTensor_(copyFloat)(THCState *state, THCTensor *self, THFloatTensor *src); -THC_API void THCTensor_(copyDouble)(THCState *state, THCTensor *self, THDoubleTensor *src); -THC_API void THCTensor_(copyHalf)(THCState *state, THCTensor *self, struct THHalfTensor *src); - -THC_API void THCTensor_(copyCudaByte)(THCState *state, THCTensor *dst, struct THCudaByteTensor *src); -THC_API void THCTensor_(copyCudaChar)(THCState *state, THCTensor *dst, struct THCudaCharTensor *src); -THC_API void THCTensor_(copyCudaShort)(THCState *state, THCTensor *dst, struct THCudaShortTensor *src); -THC_API void THCTensor_(copyCudaInt)(THCState *state, THCTensor *dst, struct THCudaIntTensor *src); -THC_API void THCTensor_(copyCudaLong)(THCState *state, THCTensor *dst, struct THCudaLongTensor *src); -THC_API void THCTensor_(copyCudaFloat)(THCState *state, THCTensor *dst, struct THCudaTensor *src); -THC_API void THCTensor_(copyCudaDouble)(THCState *state, THCTensor *dst, struct THCudaDoubleTensor *src); -THC_API void THCTensor_(copyCudaHalf)(THCState *state, THCTensor *dst, struct THCudaHalfTensor *src); - -THC_API void TH_CONCAT_2(THByteTensor_copyCuda , Real) (THCState *state, THByteTensor *self, THCTensor *src); -THC_API void TH_CONCAT_2(THCharTensor_copyCuda , Real) (THCState *state, THCharTensor *self, THCTensor *src); -THC_API void TH_CONCAT_2(THShortTensor_copyCuda , Real) (THCState *state, THShortTensor *self, THCTensor *src); -THC_API void TH_CONCAT_2(THIntTensor_copyCuda , Real) (THCState *state, THIntTensor *self, THCTensor *src); -THC_API void TH_CONCAT_2(THLongTensor_copyCuda , Real) (THCState *state, THLongTensor *self, THCTensor *src); -THC_API void TH_CONCAT_2(THFloatTensor_copyCuda , Real) (THCState *state, THFloatTensor *self, THCTensor *src); -THC_API void TH_CONCAT_2(THDoubleTensor_copyCuda, Real) (THCState *state, THDoubleTensor *self, THCTensor *src); -THC_API void TH_CONCAT_2(THHalfTensor_copyCuda, Real) (THCState *state, THHalfTensor *self, THCTensor *src); -THC_API void THCTensor_(copyCuda) (THCState *state, THCTensor *self, THCTensor *src); - -THC_API void THTensor_(copyCuda) (THCState *state, THTensor *self, THCTensor *src); -THC_API void THCTensor_(copyCPU) (THCState *state, THCTensor *self, THTensor *src); THC_API void THCTensor_(copyAsyncCPU)(THCState *state, THCTensor *self, THTensor *src); THC_API void THTensor_(copyAsyncCuda)(THCState *state, THTensor *self, THCTensor *src); diff --git a/aten/src/THC/generic/THCTensorMasked.cu b/aten/src/THC/generic/THCTensorMasked.cu index 684ce31..86647c0 100644 --- a/aten/src/THC/generic/THCTensorMasked.cu +++ b/aten/src/THC/generic/THCTensorMasked.cu @@ -24,7 +24,7 @@ void THCTensor_(maskedFillByte)(THCState* state, { THCAssertSameGPU(THCTensor_(checkGPU)(state, 1, tensor)); THCudaByteTensor* maskCuda = THCudaByteTensor_newWithSize(state, mask->sizes(), {}); - THCudaByteTensor_copyByte(state, maskCuda, mask); + THCTensor_(copy)(state, maskCuda, mask); THCTensor_(maskedFill)(state, tensor, maskCuda, value); THCudaByteTensor_free(state, maskCuda); } @@ -56,7 +56,7 @@ void THCTensor_(maskedCopy)(THCState* state, THCudaLongTensor* maskLong = THCudaLongTensor_new(state); at::IntList maskSizes = mask->sizes(); THCudaLongTensor_resize(state, maskLong, maskSizes, {}); - THCudaLongTensor_copyCudaByte(state, maskLong, mask); + THCTensor_(copy)(state, maskLong, mask); // Use a prefix sum to determine the output locations of the masked elements THCudaLongTensor* maskPrefixSum = THCudaLongTensor_new(state); @@ -99,7 +99,7 @@ void THCTensor_(maskedCopyByte)(THCState* state, THCTensor *tensor, THByteTensor *mask, THCTensor *src) { THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, tensor, src)); THCudaByteTensor* maskCuda = THCudaByteTensor_newWithSize(state, mask->sizes(), {}); - THCudaByteTensor_copyByte(state, maskCuda, mask); + THCTensor_(copy)(state, maskCuda, mask); THCTensor_(maskedCopy)(state, tensor, maskCuda, src); THCudaByteTensor_free(state, maskCuda); } @@ -126,7 +126,7 @@ void THCTensor_(maskedSelect)(THCState* state, THCudaLongTensor* maskLong = THCudaLongTensor_new(state); at::IntList maskSizes = mask->sizes(); THCudaLongTensor_resize(state, maskLong, maskSizes, {}); - THCudaLongTensor_copyCudaByte(state, maskLong, mask); + THCTensor_(copy)(state, maskLong, mask); // Use a prefix sum to determine the output locations of the masked elements THCudaLongTensor* maskPrefixSum = THCudaLongTensor_new(state); @@ -171,7 +171,7 @@ void THCTensor_(maskedSelectByte)(THCState* state, { THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, tensor, src)); THCudaByteTensor* maskCuda = THCudaByteTensor_newWithSize(state, mask->sizes(), {}); - THCudaByteTensor_copyByte(state, maskCuda, mask); + THCTensor_(copy)(state, maskCuda, mask); THCTensor_(maskedSelect)(state, tensor, src, maskCuda); THCudaByteTensor_free(state, maskCuda); } diff --git a/aten/src/THCUNN/generic/SparseLinear.cu b/aten/src/THCUNN/generic/SparseLinear.cu index 04025b9..f04d693 100644 --- a/aten/src/THCUNN/generic/SparseLinear.cu +++ b/aten/src/THCUNN/generic/SparseLinear.cu @@ -18,13 +18,7 @@ static bool THNN_(checkSize1D)(THCTensor* t, int64_t size0) } static inline void THNN_(copyCudaFloatingType)(THCState *state, THCudaIntTensor *buf, THCTensor *t) { - #ifdef THC_REAL_IS_FLOAT - THCudaIntTensor_copyCudaFloat(state, buf, t); - #elif defined(THC_REAL_IS_DOUBLE) - THCudaIntTensor_copyCudaDouble(state, buf, t); - #elif defined(THC_REAL_IS_HALF) - THCudaIntTensor_copyCudaHalf(state, buf, t); - #endif + THCTensor_(copy)(state, buf, t); } void THNN_(SparseLinear_updateOutput)( @@ -71,7 +65,7 @@ void THNN_(SparseLinear_updateOutput)( THCTensor_(select)(state, sel, input, 1, 1); THNN_(copyCudaFloatingType)(state, colInds, sel); THCTensor_(select)(state, sel, input, 1, 2); - THCTensor_(copyCuda)(state, values, sel); + THCTensor_(copy)(state, values, sel); init_cusparse(); cusparseXcoo2csr(cusparse_handle, @@ -171,7 +165,7 @@ void THNN_(SparseLinear_accGradParameters)( THCTensor_(select)(state, sel, buf, 1, 1); THNN_(copyCudaFloatingType)(state, colbuf, sel); THCTensor_(select)(state, sel, buf, 1, 2); - THCTensor_(copyCuda)(state, values, sel); + THCTensor_(copy)(state, values, sel); init_cusparse(); // Secretly coo2csc -- 2.7.4