#include <bitset>
#include "ATen/Utils.h"
+#include "ATen/native/Copy.h"
#include <c10/util/C++17.h>
#if defined(__GNUC__)
void convert(const src_T *src, dst_T *dst, int64_t n) {
#pragma unroll
for (int64_t i = 0; i < n; i++) {
- *dst = static_cast<dst_T>(*src);
+ *dst = static_cast<dst_T>(
+ static_cast<at::native::inter_copy_type_t<dst_T>>(*src));
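+ // Convert via at::native::inter_copy_type_t (defined in ATen/native/Copy.h)
+ // so cross-dtype copies go through an explicit intermediate type.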
src++;
dst++;
}
return _mm256_xor_pd(a, b);
}
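+// Same-dtype fast path: copy doubles with unaligned 256-bit loads/stores and
+// finish the tail with a scalar loop.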
+template <>
+void convert(const double* src, double* dst, int64_t n) {
+ int64_t i;
+#pragma unroll
+ for (i = 0; i <= (n - Vec256<double>::size); i += Vec256<double>::size) {
+ _mm256_storeu_pd(dst + i, _mm256_loadu_pd(src + i));
+ }
+#pragma unroll
+ for (; i < n; i++) {
+ dst[i] = src[i];
+ }
+}
+
#ifdef __AVX2__
template <>
Vec256<double> inline fmadd(const Vec256<double>& a, const Vec256<double>& b, const Vec256<double>& c) {
return _mm256_xor_ps(a, b);
}
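+// Same-dtype fast path for float, mirroring the double specialization above.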
+template <>
+void convert(const float* src, float* dst, int64_t n) {
+ int64_t i;
+#pragma unroll
+ for (i = 0; i <= (n - Vec256<float>::size); i += Vec256<float>::size) {
+ _mm256_storeu_ps(dst + i, _mm256_loadu_ps(src + i));
+ }
+#pragma unroll
+ for (; i < n; i++) {
+ dst[i] = src[i];
+ }
+}
+
#ifdef __AVX2__
template <>
Vec256<float> inline fmadd(const Vec256<float>& a, const Vec256<float>& b, const Vec256<float>& c) {
-#include "Copy.h"
+#include "ATen/native/Copy.h"
#include "ATen/ATen.h"
#include "ATen/CPUApplyUtils.h"
#include "ATen/Dispatch.h"
#include "ATen/NativeFunctions.h"
+#include "ATen/native/cpu/CopyKernel.h"
namespace {
template <typename self_T>
void _copy__cpu(at::Tensor& self, const at::Tensor& src) {
- AT_DISPATCH_ALL_TYPES_AND_HALF(
- src.type(), "_copy__cpu", [&]() { _copy__cpu<self_T, scalar_t>(self, src); });
+ AT_DISPATCH_ALL_TYPES_AND_HALF(src.type(), "_copy__cpu", [&]() {
+ _copy__cpu<self_T, scalar_t>(self, src);
+ });
+}
+
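+// The blocked-transpose path below only applies when self is contiguous, src
+// is a non-empty column-major (i.e. transposed) 2-D matrix, and the copy is
+// large enough to amortize the blocking overhead.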
+bool copy_transpose_valid(const at::Tensor& self, const at::Tensor& src) {
+ const int MIN_SZ = 60 * 60;
+ return self.is_contiguous() && src.numel() != 0 && src.dim() == 2 &&
+ src.stride(0) == 1 && src.stride(1) == src.size(0) &&
+ self.numel() >= MIN_SZ;
}
} // namespace
return self;
}
+// Special-case copy where self is contiguous and src is a transposed matrix.
+// This can be generalized to most copies, but it's trickier.
+void _copy_same_type_transpose_(Tensor& self, const Tensor& src) {
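+ // Tile sizes carried over from TH: 120x120 for byte tensors, 60x60 otherwise.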
+ int64_t BLOCK_SZ;
+ if (self.scalar_type() == kByte) {
+ BLOCK_SZ = 120;
+ } else {
+ BLOCK_SZ = 60;
+ }
+ Tensor buf = empty({BLOCK_SZ, BLOCK_SZ}, self.options());
+
+ AT_DISPATCH_ALL_TYPES_AND_HALF(
+ self.type(), "_copy_same_type_transpose_", [&]() {
+ scalar_t* sp = src.data<scalar_t>();
+ scalar_t* rp = self.data<scalar_t>();
+ scalar_t* bp = buf.data<scalar_t>();
+
+ int64_t NR = src.size(0);
+ int64_t NC = src.size(1);
+ for (int64_t R = 0; R < NR; R += BLOCK_SZ) {
+ for (int64_t C = 0; C < NC; C += BLOCK_SZ) {
+ scalar_t* spo = sp + R + C * NR;
+ scalar_t* rpo = rp + C + R * NC;
+
+ int nr = std::min(NR - R, BLOCK_SZ);
+ int nc = std::min(NC - C, BLOCK_SZ);
+
+ // 1. copy columns from src to buf
+ for (int c = 0; c < nc; c++) {
+ memcpy(bp + c * BLOCK_SZ, spo + c * NR, nr * sizeof(scalar_t));
+ }
+
+ // 2. transpose buf in place
+ int rc_max = std::max(nr, nc);
+ int rc_min = std::min(nr, nc);
+ for (int r = 0; r < rc_max; r++) {
+ int end = std::min(r, rc_min);
+ for (int c = 0; c < end; c++) {
+ scalar_t tmp = bp[r + BLOCK_SZ * c];
+ bp[r + BLOCK_SZ * c] = bp[r * BLOCK_SZ + c];
+ bp[r * BLOCK_SZ + c] = tmp;
+ }
+ }
+
+ // 3. copy rows from buf to dst
+ for (int r = 0; r < nr; r++) {
+ memcpy(rpo + r * NC, bp + r * BLOCK_SZ, nc * sizeof(scalar_t));
+ }
+ }
+ }
+ });
+}
+
+void _copy_same_type_(Tensor& self, const Tensor& src) {
+ if (self.is_same(src)) {
+ return;
+ }
+
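+ // Fast paths: fully contiguous tensors use the vectorized copy kernel, and a
+ // transposed 2-D source uses the blocked transpose; everything else falls
+ // back to a (possibly parallel) element-wise apply.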
+ bool serial_path = false;
+ if (self.numel() == src.numel()) {
+ if (self.is_contiguous() && src.is_contiguous()) {
+ copy_kernel(kCPU, self, src);
+ } else if (copy_transpose_valid(self, src)) {
+ _copy_same_type_transpose_(self, src);
+ } else {
+#ifdef _OPENMP
+ if (!in_parallel_region()) {
+ AT_DISPATCH_ALL_TYPES_AND_HALF(self.type(), "_copy_same_type_", [&]() {
+ at::CPU_tensor_parallel_apply2<scalar_t, scalar_t>(
+ self, src, [](scalar_t& self_val, const scalar_t& src_val) {
+ self_val = src_val;
+ });
+ });
+ } else {
+ serial_path = true;
+ }
+#else
+ serial_path = true;
+#endif
+ }
+ } else {
+ serial_path = true;
+ }
+
+ if (serial_path) {
+ AT_DISPATCH_ALL_TYPES_AND_HALF(self.type(), "_copy_same_type_", [&]() {
+ at::CPU_tensor_apply2<scalar_t, scalar_t>(
+ self, src, [](scalar_t& self_val, const scalar_t& src_val) {
+ self_val = src_val;
+ });
+ });
+ }
+}
+
+DEFINE_DISPATCH(copy_kernel);
+
} // namespace native
} // namespace at
template <typename T>
using inter_copy_type_t = typename inter_copy_type<T>::type;
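+// Same-dtype copy; also called directly from the remaining TH code.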
+void _copy_same_type_(Tensor& self, const Tensor& src);
+
} // namespace native
} // namespace at
--- /dev/null
+#include <ATen/native/cpu/CopyKernel.h>
+
+#include <ATen/ATen.h>
+#include <ATen/CPUApplyUtils.h>
+#include <ATen/Dispatch.h>
+#include <ATen/cpu/vec256/vec256.h>
+#include <ATen/native/Copy.h>
+
+namespace at {
+namespace native {
+namespace {
+
+// TODO: this number was copied from TH, test to see if it's the right number
+constexpr int64_t COPY_GRAIN_SIZE = 20000;
+
+static void copy_kernel_impl(Tensor& dst, const Tensor& src) {
+ AT_DISPATCH_ALL_TYPES_AND_HALF(dst.type(), "copy_kernel_impl", [&]() {
+ scalar_t* self_ptr = dst.data<scalar_t>();
+ scalar_t* src_ptr = src.data<scalar_t>();
+
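+ // Split the flat copy into COPY_GRAIN_SIZE chunks with parallel_for; each
+ // chunk goes through vec256::convert (vectorized for float/double).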
+ auto sample = [&](int64_t begin, int64_t end) {
+ int64_t len = end - begin;
+ scalar_t* self_seg = self_ptr + begin;
+ scalar_t* src_seg = src_ptr + begin;
+ at::vec256::convert<scalar_t, scalar_t>(src_seg, self_seg, len);
+ };
+
+ parallel_for(0, dst.numel(), COPY_GRAIN_SIZE, sample);
+ });
+}
+
+} // anonymous namespace
+
+REGISTER_DISPATCH(copy_kernel, &copy_kernel_impl);
+
+} // namespace native
+} // namespace at
--- /dev/null
+#pragma once
+
+#include <ATen/ATen.h>
+#include <ATen/native/DispatchStub.h>
+
+namespace at {
+namespace native {
+
+using forward_fn = void (*)(Tensor&, const Tensor&);
+
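+// copy_kernel: contiguous same-dtype copy, dispatched on CPU capability via
+// DispatchStub; the implementation is registered in native/cpu/CopyKernel.cpp.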
+DECLARE_DISPATCH(forward_fn, copy_kernel);
+
+} // namespace native
+} // namespace at
${CMAKE_CURRENT_SOURCE_DIR}/THSize.cpp
${CMAKE_CURRENT_SOURCE_DIR}/THStorageFunctions.cpp
${CMAKE_CURRENT_SOURCE_DIR}/THTensor.cpp
- ${CMAKE_CURRENT_SOURCE_DIR}/THTensorCopy.cpp
${CMAKE_CURRENT_SOURCE_DIR}/THTensorRandom.cpp
${CMAKE_CURRENT_SOURCE_DIR}/THTensorMath.cpp
${CMAKE_CURRENT_SOURCE_DIR}/THTensorMoreMath.cpp
generic/THTensor.hpp
generic/THTensorConv.cpp
generic/THTensorConv.h
- generic/THTensorCopy.cpp
- generic/THTensorCopy.h
generic/THTensorLapack.cpp
generic/THTensorLapack.h
generic/THTensorMath.cpp
#include "generic/THTensor.h"
#include "THGenerateHalfType.h"
-#include "generic/THTensorCopy.h"
-#include "THGenerateAllTypes.h"
-
-#include "generic/THTensorCopy.h"
-#include "THGenerateHalfType.h"
-
/* random numbers */
#include "THRandom.h"
#include "generic/THTensorRandom.h"
+++ /dev/null
-#include "THTensor.hpp"
-#include "THVector.h"
-
-#include <algorithm>
-
-#include "generic/THTensorCopy.cpp"
-#include "THGenerateAllTypes.h"
-
-#include "generic/THTensorCopy.cpp"
-#include "THGenerateHalfType.h"
#else
#include <ATen/InferSize.h>
+#include <ATen/native/Copy.h>
#include <new>
/**** access methods ****/
{
THTensor *tensor = THTensor_(new)();
THTensor_(resizeAs)(tensor, self);
- THTensor_(copy)(tensor, self);
+ at::Tensor tensor_wrap = THTensor_wrap(tensor);
+ at::Tensor self_wrap = THTensor_wrap(self);
+ at::native::_copy_same_type_(tensor_wrap, self_wrap);
return tensor;
}
void THTensor_(freeCopyTo)(THTensor *self, THTensor *dst)
{
- if(self != dst)
- THTensor_(copy)(dst, self);
+ if(self != dst) {
+ at::Tensor dst_wrap = THTensor_wrap(dst);
+ at::Tensor self_wrap = THTensor_wrap(self);
+ at::native::_copy_same_type_(dst_wrap, self_wrap);
+ }
THTensor_(free)(self);
}
+++ /dev/null
-#ifndef TH_GENERIC_FILE
-#define TH_GENERIC_FILE "generic/THTensorCopy.cpp"
-#else
-
-#ifndef _WIN32
-#define PRAGMA(P) _Pragma(#P)
-#else
-#define PRAGMA(P) __pragma(P)
-#endif
-
-#ifdef _OPENMP
-#define TH_OMP_OVERHEAD_THRESHOLD_COPY 20000
-#include <omp.h>
-#endif
-
-int THTensor_(copyTransposeValid)(THTensor *tensor, THTensor *src) {
- const int MIN_SZ = 60 * 60;
- return THTensor_(isContiguous)(tensor) &&
- !src->is_empty() &&
- THTensor_(nDimensionLegacyNoScalars)(src) == 2 &&
- THTensor_(stride)(src, 0) == 1 &&
- THTensor_(stride)(src, 1) == THTensor_(size)(src, 0) &&
- THTensor_(nElement)(tensor) >= MIN_SZ;
-}
-
-// special case copy where tensor is contiguous and src is a transposed matrix
-// This can be generalized to most copies, but it's tricker
-void THTensor_(copyTranspose)(THTensor *tensor, THTensor *src) {
-
-#ifdef TH_REAL_IS_BYTE
- const int64_t BLOCK_SZ = 120;
-#else
- const int64_t BLOCK_SZ = 60;
-#endif
-
- THTensor *buf = THTensor_(newWithSize2d)(BLOCK_SZ, BLOCK_SZ);
- scalar_t *sp = src->data<scalar_t>();
- scalar_t *rp = tensor->data<scalar_t>();
- scalar_t *bp = buf->data<scalar_t>();
-
-
- int64_t NR = THTensor_(size)(src, 0);
- int64_t NC = THTensor_(size)(src, 1);
- for (int64_t R = 0; R < NR; R += BLOCK_SZ) {
- for (int64_t C = 0; C < NC; C += BLOCK_SZ) {
- scalar_t *spo = sp + R + C * NR;
- scalar_t *rpo = rp + C + R * NC;
-
- int nr = std::min(NR - R, BLOCK_SZ);
- int nc = std::min(NC - C, BLOCK_SZ);
-
- // 1. copy columns from src to buf
- for (int c = 0; c < nc; c++) {
- memcpy(bp + c * BLOCK_SZ, spo + c * NR, nr * sizeof(scalar_t));
- }
-
- // 2. transpose buf in place
- int rc_max = std::max(nr, nc);
- int rc_min = std::min(nr, nc);
- for (int r = 0; r < rc_max; r++) {
- int end = std::min(r, rc_min);
- for (int c = 0; c < end; c++) {
- scalar_t tmp = bp[r + BLOCK_SZ * c];
- bp[r + BLOCK_SZ * c] = bp[r * BLOCK_SZ + c];
- bp[r * BLOCK_SZ + c] = tmp;
- }
- }
-
- // 3. copy rows from buf to dst
- for (int r = 0; r < nr; r++) {
- memcpy(rpo + r * NC, bp + r * BLOCK_SZ, nc * sizeof(scalar_t));
- }
- }
- }
- c10::raw::intrusive_ptr::decref(buf);
-}
-
-void THTensor_(copy)(THTensor *tensor, THTensor *src)
-{
- if (tensor == src) return;
- ptrdiff_t tensorSize = THTensor_(nElement)(tensor);
- ptrdiff_t srcSize = THTensor_(nElement)(src);
- int tensorContig = THTensor_(isContiguous)(tensor);
- int srcContig = THTensor_(isContiguous)(src);
-
- int serial_path = 0;
-#ifdef _OPENMP
- int inOMP = omp_in_parallel();
-#endif
- if (tensorSize == srcSize) {
- if ( tensorContig && srcContig) {
- scalar_t *sp = src->data<scalar_t>();
- scalar_t *rp = tensor->data<scalar_t>();
-#ifndef TH_REAL_IS_HALF
-#ifdef _OPENMP
- #pragma omp parallel if ( (tensorSize > TH_OMP_OVERHEAD_THRESHOLD_COPY) && (!inOMP) )
- {
- size_t num_threads = omp_get_num_threads();
- size_t tid = omp_get_thread_num();
- ptrdiff_t offset = tid * (tensorSize / num_threads);
- ptrdiff_t end = (tid == num_threads - 1) ? tensorSize : offset + tensorSize / num_threads;
- ptrdiff_t len = end - offset;
- scalar_t *tensorData = rp + offset;
- scalar_t *srcData = sp + offset;
- THVector_(copy)(tensorData, srcData, len);
- }
-#else
- THVector_(copy)(rp, sp, srcSize);
-#endif
-
-#else
-
-#ifdef _OPENMP
- if ((srcSize > TH_OMP_OVERHEAD_THRESHOLD_COPY) && (!inOMP)) {
- ptrdiff_t i;
- #pragma omp parallel for private (i)
- for(i=0; i<srcSize; i++){
- rp[i] = sp[i];
- }
- } else {
- memcpy(rp, sp, srcSize * sizeof(scalar_t));
- }
-#else
- memcpy(rp, sp, srcSize * sizeof(scalar_t));
-#endif
-
-#endif
-
-#ifndef TH_REAL_IS_HALF
- } else if (THTensor_(copyTransposeValid)(tensor, src)) {
- THTensor_(copyTranspose)(tensor, src);
-#endif
- } else {
-#ifdef _OPENMP
- if (inOMP) {
- serial_path = 1;
- } else {
- TH_TENSOR_APPLY2_OMP(srcSize, tensorContig, srcContig, scalar_t, tensor, scalar_t, src, *tensor_data = *src_data;, TH_OMP_OVERHEAD_THRESHOLD_COPY)
- }
-#else
- serial_path = 1;
-#endif
- }
- } else {
- serial_path = 1;
- }
-
- if (serial_path) {
- TH_TENSOR_APPLY2(scalar_t, tensor, scalar_t, src, *tensor_data = *src_data;)
- }
-}
-
-#endif
+++ /dev/null
-#ifndef TH_GENERIC_FILE
-#define TH_GENERIC_FILE "generic/THTensorCopy.h"
-#else
-
-/* Support for copy between different Tensor types */
-
-TH_API void THTensor_(copy)(THTensor *tensor, THTensor *src);
-
-#endif
#else
#include <TH/generic/THTensorApply.hpp>
+#include <ATen/native/Copy.h>
void THTensor_(fill)(THTensor *r_, scalar_t value)
{
sSlice = THTensor_(new)();
THTensor_(select)(tSlice, tensor, dim, i);
THTensor_(select)(sSlice, src, dim, index_data[i] - TH_INDEX_BASE);
- THTensor_(copy)(tSlice, sSlice);
+ at::Tensor tSlice_wrap = THTensor_wrap(tSlice);
+ at::Tensor sSlice_wrap = THTensor_wrap(sSlice);
+ at::native::_copy_same_type_(tSlice_wrap, sSlice_wrap);
c10::raw::intrusive_ptr::decref(tSlice);
c10::raw::intrusive_ptr::decref(sSlice);
}
{
THTensor_(select)(tSlice, tensor, dim, index_data[i] - TH_INDEX_BASE);
THTensor_(select)(sSlice, src, dim, i);
- THTensor_(copy)(tSlice, sSlice);
+ at::Tensor tSlice_wrap = THTensor_wrap(tSlice);
+ at::Tensor sSlice_wrap = THTensor_wrap(sSlice);
+ at::native::_copy_same_type_(tSlice_wrap, sSlice_wrap);
}
c10::raw::intrusive_ptr::decref(tSlice);
#define TH_GENERIC_FILE "generic/THTensorLapack.cpp"
#else
+#include <ATen/native/Copy.h>
+
/*
Check if self is transpose of a contiguous matrix
*/
THTensor_(resize2d)(result, src->size(1), nrows);
THTensor_(checkTransposed)(result);
- if (src->size(0) == nrows)
- THTensor_(copy)(result, src);
+ if (src->size(0) == nrows) {
+ at::Tensor result_wrap = THTensor_wrap(result);
+ at::Tensor src_wrap = THTensor_wrap(src);
+ at::native::_copy_same_type_(result_wrap, src_wrap);
+ }
else
{
view = THTensor_(newNarrow)(result, 0, 0, src->size(0));
- THTensor_(copy)(view, src);
+ at::Tensor view_wrap = THTensor_wrap(view);
+ at::Tensor src_wrap = THTensor_wrap(src);
+ at::native::_copy_same_type_(view_wrap, src_wrap);
c10::raw::intrusive_ptr::decref(view);
}
return result;
THTensor_(narrow)(rvf_,NULL,1,0,k);
THTensor_(resizeAs)(rv_, rvf_);
- THTensor_(copy)(rv_, rvf_);
+ at::Tensor rv__wrap = THTensor_wrap(rv_);
+ at::Tensor rvf__wrap = THTensor_wrap(rvf_);
+ at::native::_copy_same_type_(rv__wrap, rvf__wrap);
c10::raw::intrusive_ptr::decref(rvf_);
} else {
THTensor_(zero)(ru_);
if (ra_ != a) {
THTensor_(resizeAs)(ra_, a);
- THTensor_(copy)(ra_, a);
+ at::Tensor ra__wrap = THTensor_wrap(ra_);
+ at::Tensor a_wrap = THTensor_wrap(a);
+ at::native::_copy_same_type_(ra__wrap, a_wrap);
}
int m = a->size(1);
if (rb_ != b) {
THTensor_(resizeAs)(rb_, b);
- THTensor_(copy)(rb_, b);
+ at::Tensor rb__wrap = THTensor_wrap(rb_);
+ at::Tensor b_wrap = THTensor_wrap(b);
+ at::native::_copy_same_type_(rb__wrap, b_wrap);
}
int64_t num_batches = atf->size(0);
#else
#include <TH/generic/THTensorApply.hpp>
+#include <ATen/native/Copy.h>
// HEY YOU!
//
void THTensor_(pow)(THTensor *r_, THTensor *t, scalar_t value)
{
THTensor_(resizeAs)(r_, t);
- if(value == 1){
- THTensor_(copy)(r_, t);
+ if(value == 1) {
+ at::Tensor r__wrap = THTensor_wrap(r_);
+ at::Tensor t_wrap = THTensor_wrap(t);
+ at::native::_copy_same_type_(r__wrap, t_wrap);
}
else if(value == 2){
THTensor_(cmul)(r_, t, t);
if(r_ != t)
{
THTensor_(resizeAs)(r_, t);
- THTensor_(copy)(r_, t);
+ at::Tensor r__wrap = THTensor_wrap(r_);
+ at::Tensor t_wrap = THTensor_wrap(t);
+ at::native::_copy_same_type_(r__wrap, t_wrap);
}
int64_t r_Size = THTensor_(nElement)(r_);
int64_t src1Size = THTensor_(nElement)(src1);
if(r_ != t)
{
THTensor_(resizeAs)(r_, t);
- THTensor_(copy)(r_, t);
+ at::Tensor r__wrap = THTensor_wrap(r_);
+ at::Tensor t_wrap = THTensor_wrap(t);
+ at::native::_copy_same_type_(r__wrap, t_wrap);
}
int64_t r_Size = THTensor_(nElement)(r_);
int64_t src1Size = THTensor_(nElement)(src1);
if(r_ != t)
{
THTensor_(resizeAs)(r_, t);
- THTensor_(copy)(r_, t);
+ at::Tensor r__wrap = THTensor_wrap(r_);
+ at::Tensor t_wrap = THTensor_wrap(t);
+ at::native::_copy_same_type_(r__wrap, t_wrap);
}
auto r_stride = THTensor_strideLegacyNoScalars(r_, 0);
{
THTensor_(resizeAs)(r_, t);
if (beta != 0.0) {
- THTensor_(copy)(r_, t);
+ at::Tensor r__wrap = THTensor_wrap(r_);
+ at::Tensor t_wrap = THTensor_wrap(t);
+ at::native::_copy_same_type_(r__wrap, t_wrap);
}
}
if(r_ != t)
{
THTensor_(resizeAs)(r_, t);
- THTensor_(copy)(r_, t);
+ at::Tensor r__wrap = THTensor_wrap(r_);
+ at::Tensor t_wrap = THTensor_wrap(t);
+ at::native::_copy_same_type_(r__wrap, t_wrap);
}
if(beta == 0) {
if (t != result) {
THTensor_(resizeAs)(result, t);
if (beta != 0.0) {
- THTensor_(copy)(result, t);
+ at::Tensor result_wrap = THTensor_wrap(result);
+ at::Tensor t_wrap = THTensor_wrap(t);
+ at::native::_copy_same_type_(result_wrap, t_wrap);
}
}
#include <TH/generic/THTensorApply.hpp>
#include <TH/THGenerator.hpp>
+#include <ATen/native/Copy.h>
void THTensor_(baddbmm)(THTensor *result, scalar_t beta, THTensor *t, scalar_t alpha, THTensor *batch1, THTensor *batch2)
{
if (t != result) {
THTensor_(resizeAs)(result, t);
if (beta != 0.0) {
- THTensor_(copy)(result, t);
+ at::Tensor result_wrap = THTensor_wrap(result);
+ at::Tensor t_wrap = THTensor_wrap(t);
+ at::native::_copy_same_type_(result_wrap, t_wrap);
}
}
} else {
if (THTensor_(nDimensionLegacyAll)(t) > 1) {
THTensor *t0 = THTensor_(newSelect)(t, dimension, 0);
- THTensor_(copy)(values_, t0);
+ at::Tensor values__wrap = THTensor_wrap(values_);
+ at::Tensor t0_wrap = THTensor_wrap(t0);
+ at::native::_copy_same_type_(values__wrap, t0_wrap);
c10::raw::intrusive_ptr::decref(t0);
} else {
THTensor_(fill)(values_, THTensor_(get1d)(t, 0));
} else {
if (THTensor_(nDimensionLegacyAll)(t) > 1) {
THTensor *t0 = THTensor_(newSelect)(t, dimension, 0);
- THTensor_(copy)(values_, t0);
+ at::Tensor values__wrap = THTensor_wrap(values_);
+ at::Tensor t0_wrap = THTensor_wrap(t0);
+ at::native::_copy_same_type_(values__wrap, t0_wrap);
c10::raw::intrusive_ptr::decref(t0);
} else {
THTensor_(fill)(values_, THTensor_(get1d)(t, 0));
dimension + TH_INDEX_BASE);
THTensor_(resizeAs)(rt_, t);
- THTensor_(copy)(rt_, t);
+ at::Tensor rt__wrap = THTensor_wrap(rt_);
+ at::Tensor t_wrap = THTensor_wrap(t);
+ at::native::_copy_same_type_(rt__wrap, t_wrap);
THLongTensor_resize(ri_, t->sizes(), {});
if(descendingOrder)
int64_t dimSize = inputs[j]->size(dimension);
THTensor *nt = THTensor_(newWithTensor)(result);
THTensor_(narrow)(nt, NULL, dimension, offset, dimSize);
- THTensor_(copy)(nt, inputs[j]);
+ at::Tensor nt__wrap = THTensor_wrap(nt);
+ at::Tensor inputs_wrap = THTensor_wrap(inputs[j]);
+ at::native::_copy_same_type_(nt__wrap, inputs_wrap);
c10::raw::intrusive_ptr::decref(nt);
offset += dimSize;
}
)
}
else
- THTensor_(copy)(rowR, rowS);
+ {
+ at::Tensor rowR_wrap = THTensor_wrap(rowR);
+ at::Tensor rowS_wrap = THTensor_wrap(rowS);
+ at::native::_copy_same_type_(rowR_wrap, rowS_wrap);
+ }
}
c10::raw::intrusive_ptr::decref(rowR);
TH_API void THVector_(muls)(scalar_t *y, const scalar_t *x, const scalar_t c, const ptrdiff_t n);
TH_API void THVector_(cdiv)(scalar_t *z, const scalar_t *x, const scalar_t *y, const ptrdiff_t n);
TH_API void THVector_(divs)(scalar_t *y, const scalar_t *x, const scalar_t c, const ptrdiff_t n);
-TH_API void THVector_(copy)(scalar_t *y, const scalar_t *x, const ptrdiff_t n);
TH_API void THVector_(neg)(scalar_t *y, const scalar_t *x, const ptrdiff_t n);
TH_API void THVector_(normal_fill)(scalar_t *data,
const int64_t size,
THVector_(divs_DISPATCHPTR)(y, x, c, n);
}
-static void (*THVector_(copy_DISPATCHPTR))(scalar_t *, const scalar_t *, const ptrdiff_t) = &THVector_(copy_DEFAULT);
-static FunctionDescription THVector_(copy_DISPATCHTABLE)[] = {
- #if defined(USE_AVX)
- #if defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT)
- FUNCTION_IMPL(THVector_(copy_AVX), SIMDExtension_AVX),
- #endif
- #endif
-
- FUNCTION_IMPL(THVector_(copy_DEFAULT), SIMDExtension_DEFAULT)
-};
-void THVector_(copy)(scalar_t *y, const scalar_t *x, const ptrdiff_t n) {
- THVector_(copy_DISPATCHPTR)(y, x, n);
-}
-
static void (*THVector_(normal_fill_DISPATCHPTR))(scalar_t *, const int64_t, THGenerator *, const scalar_t, const scalar_t) = &THVector_(normal_fill_DEFAULT);
static FunctionDescription THVector_(normal_fill_DISPATCHTABLE)[] = {
INIT_DISPATCH_PTR(muls);
INIT_DISPATCH_PTR(cdiv);
INIT_DISPATCH_PTR(divs);
- INIT_DISPATCH_PTR(copy);
INIT_DISPATCH_PTR(normal_fill);
#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE)
#include "AVX.h"
#include <TH/THGeneral.h>
-void THDoubleVector_copy_AVX(double *y, const double *x, const ptrdiff_t n) {
- ptrdiff_t i;
- ptrdiff_t off;
- for (i=0; i<=((n)-8); i+=8) {
- _mm256_storeu_pd(y+i, _mm256_loadu_pd(x+i));
- _mm256_storeu_pd(y+i+4, _mm256_loadu_pd(x+i+4));
- }
- off = (n) - ((n)%8);
- for (i=0; i<((n)%8); i++) {
- y[off+i] = x[off+i];
- }
-}
-
void THDoubleVector_fill_AVX(double *x, const double c, const ptrdiff_t n) {
ptrdiff_t i;
ptrdiff_t off;
}
}
-void THFloatVector_copy_AVX(float *y, const float *x, const ptrdiff_t n) {
- ptrdiff_t i;
- ptrdiff_t off;
- for (i=0; i<=((n)-16); i+=16) {
- _mm256_storeu_ps(y+i, _mm256_loadu_ps(x+i));
- _mm256_storeu_ps(y+i+8, _mm256_loadu_ps(x+i+8));
- }
- off = (n) - ((n)%16);
- for (i=0; i<((n)%16); i++) {
- y[off+i] = x[off+i];
- }
-}
-
void THFloatVector_fill_AVX(float *x, const float c, const ptrdiff_t n) {
ptrdiff_t i;
ptrdiff_t off;
self.assertRaises(RuntimeError, lambda: torch.randn(2, 3, 4).t())
self.assertRaises(RuntimeError, lambda: torch.randn(2, 3, 4).t_())
- # unit test for THTensor_(copyTranspose)
+ # unit test for special case transposed copy (see ATen/native/Copy.cpp for details)
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
def test_big_transpose(self):
t = torch.rand(456, 789)