THCTensor cleanup (#65369)
author: Natalia Gimelshein <ngimel@fb.com>
Tue, 21 Sep 2021 17:26:19 +0000 (10:26 -0700)
committer: Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Tue, 21 Sep 2021 17:28:19 +0000 (10:28 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/65369

Reviewed By: bhosmer

Differential Revision: D31071406

Pulled By: ngimel

fbshipit-source-id: bbc3f2781003333641524aeb692b944fd3ad8d7a

aten/src/THC/CMakeLists.txt
aten/src/THC/THCTensor.cpp
aten/src/THC/THCTensor.hpp
aten/src/THC/generic/THCTensor.cpp
aten/src/THC/generic/THCTensor.h
aten/src/THC/generic/THCTensor.hpp [deleted file]

index f030149..11642aa 100644 (file)
@@ -70,7 +70,6 @@ install(FILES
           generic/THCTensor.cpp
           generic/THCTensor.cu
           generic/THCTensor.h
-          generic/THCTensor.hpp
           generic/THCStorageCopy.cpp
           generic/THCStorageCopy.cu
           generic/THCStorageCopy.h
index 9180577..47b3d5d 100644 (file)
 
 #include <ATen/native/cuda/Resize.cuh>
 
-int THCTensor_nDimension(THCState *state, const THCTensor *self) {
-  return THTensor_nDimension(self);
-}
-
-int THCTensor_nDimensionLegacyNoScalars(THCState *state, const THCTensor *self) {
-  return THTensor_nDimensionLegacyNoScalars(self);
-}
-
-int THCTensor_nDimensionLegacyAll(THCState *state, const THCTensor *self) {
-  return THTensor_nDimensionLegacyAll(self);
-}
-
-int64_t THCTensor_size(THCState *state, const THCTensor *self, int dim) {
-  THArgCheck((dim >= 0) && (dim < self->dim()), 2, "out of range");
-  return self->size(dim);
-}
-
-int64_t THCTensor_sizeLegacyNoScalars(THCState *state, const THCTensor *self, int dim) {
-  return THTensor_sizeLegacyNoScalars(self, dim);
-}
-
-
-int64_t THCTensor_stride(THCState *state, const THCTensor *self, int dim) {
-  THArgCheck((dim >= 0) && (dim < self->dim()), 2, "out of range");
-  return self->stride(dim);
-}
-
-int64_t THCTensor_strideLegacyNoScalars(THCState *state, const THCTensor *self, int dim) {
-  return THTensor_strideLegacyNoScalars(self, dim);
-}
-
 THCTensor *THCTensor_new(THCState *state, caffe2::TypeMeta type_meta) {
   auto scalar_type = at::typeMetaToScalarType(type_meta);
   switch (scalar_type) {
@@ -81,17 +50,6 @@ THCTensor *THCTensor_new(THCState *state, caffe2::TypeMeta type_meta) {
   }
 }
 
-void THCTensor_resize(THCState *state, THCTensor *self, at::IntArrayRef size, at::IntArrayRef stride) {
-  if(stride.data()) {
-    THArgCheck(stride.size() == size.size(), 3, "invalid stride");
-  }
-
-#ifdef DEBUG
-  THAssert(size.size() <= INT_MAX);
-#endif
-  THCTensor_resizeNd(state, self, size.size(), size.data(), stride.data());
-}
-
 void THCTensor_resizeAs(THCState *state, THCTensor *self, THCTensor *src) {
   int isSame = 0;
   int d;
@@ -123,17 +81,6 @@ void THCTensor_resizeNd(THCState *state, THCTensor *self, int nDimension, const
   at::native::resize_impl_cuda_(self, sizes, strides, /*device_guard=*/false);
 }
 
-void THCTensor_set(THCState *state, THCTensor *self, THCTensor *src)
-{
-  if(self != src)
-    THCTensor_setStorage(state,
-                         self,
-                         THTensor_getStoragePtr(src),
-                         src->storage_offset(),
-                         src->sizes(),
-                         src->strides());
-}
-
 void THCTensor_setStorage(THCState *state, THCTensor *self, THCStorage *storage_, ptrdiff_t storageOffset_, at::IntArrayRef size_, at::IntArrayRef stride_)
 {
   c10::raw::intrusive_ptr::incref(storage_);
@@ -141,90 +88,6 @@ void THCTensor_setStorage(THCState *state, THCTensor *self, THCStorage *storage_
                            storageOffset_, size_, stride_);
 }
 
-void THCTensor_squeeze1d(THCState *state, THCTensor *self, THCTensor *src, int dimension)
-{
-  int d;
-
-  if(!src)
-    src = self;
-
-  THArgCheck(dimension < src->dim(), 3, "dimension out of range");
-
-  THCTensor_set(state, self, src);
-
-  if(src->size(dimension) == 1)
-  {
-    at::DimVector newSize(static_cast<size_t>(self->dim() - 1));
-    at::DimVector newStride(static_cast<size_t>(self->dim() - 1));
-    for (d = 0; d < dimension; d++)
-    {
-      newSize[d] = self->size(d);
-      newStride[d] = self->stride(d);
-    }
-
-    for(d = dimension; d < self->dim()-1; d++)
-    {
-      newSize[d] = self->size(d+1);
-      newStride[d] = self->stride(d+1);
-    }
-    self->set_sizes_and_strides(newSize, newStride);
-  }
-}
-
-void THCTensor_unsqueeze1d(THCState *state, THCTensor *self, THCTensor *src, int dimension)
-{
-  int d;
-
-  if(!src)
-    src = self;
-
-  THArgCheck((dimension >= 0) && (dimension <= src->dim()), 3, "dimension out of range");
-
-  THCTensor_set(state, self, src);
-
-  at::DimVector newSize(static_cast<size_t>(/* size */ self->dim()+1));
-  at::DimVector newStride(static_cast<size_t>(/* size */ self->dim()+1));
-
-  for(d = self->dim(); d > dimension; d--)
-  {
-    newSize[d] = self->size(d-1);
-    newStride[d] = self->stride(d-1);
-  }
-  if (dimension < self->dim())
-  {
-    newStride[dimension] = self->size(dimension) * self->stride(dimension);
-  }
-  else
-  {
-    newStride[dimension] = 1;
-  }
-  newSize[dimension] = 1;
-  for(d = dimension - 1; d >= 0; d--)
-  {
-    newSize[d] = self->size(d);
-    newStride[d] = self->stride(d);
-  }
-  self->set_sizes_and_strides(newSize, newStride);
-}
-
-bool THCTensor_allContiguous(THCState *state, THCTensor **inputs, int numInputs) {
-  THAssert(numInputs > 0);
-  for (int i = 0; i < numInputs; ++i) {
-    if (!inputs[i]->is_contiguous()) {
-      return false;
-    }
-  }
-  return true;
-}
-
-ptrdiff_t THCTensor_nElement(THCState *state, const THCTensor *self) {
-  if(THTensor_nDimensionLegacyAll(self) == 0) {
-    return 0;
-  } else {
-    return self->numel();
-  }
-}
-
 // NB: It is INVALID to call this on an UndefinedTensor
 void THCTensor_retain(THCState *state, THCTensor *self) {
   c10::raw::intrusive_ptr::incref(self);
@@ -238,89 +101,3 @@ int THCTensor_getDevice(THCState* state, const THCTensor* tensor) {
   if (!THTensor_getStoragePtr(tensor)) return -1;
   return THCStorage_getDevice(state, THTensor_getStoragePtr(tensor));
 }
-
-bool THCTensor_allSameDevice(THCState* state, THCTensor ** inputs, int numInputs) {
-  THAssert(numInputs > 0);
-  int device = THCTensor_getDevice(state, inputs[0]);
-  for (int i = 1; i < numInputs; ++i) {
-    if (THCTensor_getDevice(state, inputs[i]) != device) {
-      return false;
-    }
-  }
-  return true;
-}
-
-bool THCTensor_canUse32BitIndexMath(THCState* state, const THCTensor* t, ptrdiff_t max_elem) {
-  ptrdiff_t elements = THCTensor_nElement(state, t);
-  if (elements >= max_elem) {
-    return false;
-  }
-  if (t->dim() == 0) {
-    return true;
-  }
-
-  ptrdiff_t offset = 0;
-  ptrdiff_t linearId = elements - 1;
-
-  for (int i = THCTensor_nDimensionLegacyAll(state, t) - 1; i >= 0; --i) {
-    ptrdiff_t curDimIndex =
-      linearId % THCTensor_size(state, t, i);
-    ptrdiff_t curDimOffset = curDimIndex *
-      THCTensor_stride(state, t, i);
-    offset += curDimOffset;
-    linearId /= THCTensor_size(state, t, i);
-  }
-
-  if (offset >= max_elem) {
-    return false;
-  }
-
-  return true;
-}
-
-bool THCTensor_all32BitIndexable(THCState* state, THCTensor** inputs, int numInputs) {
-  for (int i = 0; i < numInputs; ++i) {
-    if (!THCTensor_canUse32BitIndexMath(state, inputs[i])) {
-      return false;
-    }
-  }
-  return true;
-}
-
-/* Due to the resize semantics of ops with `out=` keywords, if       */ \
-/* the output `tensor` has the same shape as the output of the       */ \
-/* reduction operation, then any noncontiguities in the output       */ \
-/* `tensor` should be preserved. This needs to be special cased b/c  */ \
-/* otherwise, when keepdim=False, the implementations of reduction   */ \
-/* ops resize `tensor` to the reduced size with keepdim=True, and    */ \
-/* then later squeeze `tensor` to the correct output size, breaking  */ \
-/* the contiguity guarantees of the resize semantics.                */ \
-void THCTensor_preserveReduceDimSemantics(THCState *state, THCTensor *tensor,
-                                          int in_dims, int64_t dimension, int keepdim) {
-  int out_dims = THCTensor_nDimensionLegacyAll(state, tensor);
-  if (out_dims > 0 && !keepdim && out_dims == in_dims - 1) {
-    THCTensor_unsqueeze1d(state, tensor, tensor, dimension);
-  }
-}
-
-namespace {
-
-struct SizeAndStride {
-  int64_t size;
-  int64_t stride;
-};
-
-/*
- A comparator that will sort SizeAndStride structs by stride,
- in ascending order.
- */
-int compareSizeAndStride(const void* a, const void* b) {
-  const SizeAndStride* aS = (const SizeAndStride*) a;
-  const SizeAndStride* bS = (const SizeAndStride*) b;
-
-  if (aS->stride < bS->stride) return -1;
-  if (aS->stride == bS->stride) return 0;
-  return 1;
-}
-
-}
index 4be966e..a8f34ec 100644 (file)
@@ -8,40 +8,13 @@
 #include <THC/THCStorage.hpp>
 #include <THC/THCGeneral.hpp>
 
-#include <atomic>
 #include <ATen/ATen.h>
 
-// See [NOTE: nDimension vs nDimensionLegacyNoScalars vs nDimensionLegacyAll]
-TORCH_CUDA_CU_API int THCTensor_nDimension(
-    THCState* state,
-    const THCTensor* self);
-TORCH_CUDA_CU_API int THCTensor_nDimensionLegacyNoScalars(
-    THCState* state,
-    const THCTensor* self);
-TORCH_CUDA_CU_API int THCTensor_nDimensionLegacyAll(
-    THCState* state,
-    const THCTensor* self);
-
-TORCH_CUDA_CU_API int64_t
-THCTensor_size(THCState* state, const THCTensor* self, int dim);
-TORCH_CUDA_CU_API int64_t
-THCTensor_sizeLegacyNoScalars(THCState* state, const THCTensor* self, int dim);
-TORCH_CUDA_CU_API int64_t
-THCTensor_stride(THCState* state, const THCTensor* self, int dim);
-TORCH_CUDA_CU_API int64_t THCTensor_strideLegacyNoScalars(
-    THCState* state,
-    const THCTensor* self,
-    int dim);
 
 TORCH_CUDA_CU_API THCTensor* THCTensor_new(
     THCState* state,
     caffe2::TypeMeta type_meta);
 
-TORCH_CUDA_CU_API void THCTensor_resize(
-    THCState* state,
-    THCTensor* tensor,
-    at::IntArrayRef size,
-    at::IntArrayRef stride);
 TORCH_CUDA_CU_API void THCTensor_resizeNd(
     THCState* state,
     THCTensor* tensor,
@@ -53,10 +26,6 @@ TORCH_CUDA_CU_API void THCTensor_resizeAs(
     THCTensor* tensor,
     THCTensor* src);
 
-TORCH_CUDA_CU_API void THCTensor_set(
-    THCState* state,
-    THCTensor* self,
-    THCTensor* src);
 TORCH_CUDA_CU_API void THCTensor_setStorage(
     THCState* state,
     THCTensor* self,
@@ -65,60 +34,9 @@ TORCH_CUDA_CU_API void THCTensor_setStorage(
     at::IntArrayRef size_,
     at::IntArrayRef stride_);
 
-TORCH_CUDA_CU_API void THCTensor_squeeze1d(
-    THCState* state,
-    THCTensor* self,
-    THCTensor* src,
-    int dimension_);
-TORCH_CUDA_CU_API void THCTensor_unsqueeze1d(
-    THCState* state,
-    THCTensor* self,
-    THCTensor* src,
-    int dimension_);
-
-TORCH_CUDA_CU_API bool THCTensor_allContiguous(
-    THCState* state,
-    THCTensor** inputs,
-    int numInputs);
-TORCH_CUDA_CU_API ptrdiff_t
-THCTensor_nElement(THCState* state, const THCTensor* self);
-
 TORCH_CUDA_CU_API void THCTensor_retain(THCState* state, THCTensor* self);
 TORCH_CUDA_CU_API void THCTensor_free(THCState* state, THCTensor* self);
 
 TORCH_CUDA_CU_API int THCTensor_getDevice(
     THCState* state,
     const THCTensor* tensor);
-TORCH_CUDA_CU_API bool THCTensor_allSameDevice(
-    THCState* state,
-    THCTensor** inputs,
-    int numInputs);
-
-/* Can we use 32 bit math for indexing? */
-TORCH_CUDA_CU_API bool THCTensor_canUse32BitIndexMath(
-    THCState* state,
-    const THCTensor* t,
-    ptrdiff_t max_elem = INT32_MAX);
-/* Are all tensors 32-bit indexable? */
-TORCH_CUDA_CU_API bool THCTensor_all32BitIndexable(
-    THCState* state,
-    THCTensor** inputs,
-    int numInputs);
-TORCH_CUDA_CU_API void THCTensor_preserveReduceDimSemantics(
-    THCState* state,
-    THCTensor* tensor,
-    int in_dims,
-    int64_t dimension,
-    int keepdim);
-
-#include <THC/generic/THCTensor.hpp>
-#include <THC/THCGenerateAllTypes.h>
-
-#include <THC/generic/THCTensor.hpp>
-#include <THC/THCGenerateComplexTypes.h>
-
-#include <THC/generic/THCTensor.hpp>
-#include <THC/THCGenerateBoolType.h>
-
-#include <THC/generic/THCTensor.hpp>
-#include <THC/THCGenerateBFloat16Type.h>
index 531fa13..212a37d 100644 (file)
@@ -11,54 +11,6 @@ THCStorage *THCTensor_(storage)(THCState *state, const THCTensor *self)
   return THTensor_getStoragePtr(self);
 }
 
-ptrdiff_t THCTensor_(storageOffset)(THCState *state, const THCTensor *self)
-{
-  return self->storage_offset();
-}
-
-int THCTensor_(nDimension)(THCState *state, const THCTensor *self)
-{
-  return THCTensor_nDimension(state, self);
-}
-
-int THCTensor_(nDimensionLegacyNoScalars)(THCState *state, const THCTensor *self)
-{
-  return THCTensor_nDimensionLegacyNoScalars(state, self);
-}
-
-int THCTensor_(nDimensionLegacyAll)(THCState *state, const THCTensor *self)
-{
-  return THCTensor_nDimensionLegacyAll(state, self);
-}
-
-int64_t THCTensor_(size)(THCState *state, const THCTensor *self, int dim)
-{
-  return THCTensor_size(state, self, dim);
-}
-
-int64_t THCTensor_(sizeLegacyNoScalars)(THCState *state, const THCTensor *self, int dim)
-{
-  return THTensor_sizeLegacyNoScalars(self, dim);
-}
-
-int64_t THCTensor_(stride)(THCState *state, const THCTensor *self, int dim)
-{
-  return THCTensor_stride(state, self, dim);
-}
-
-int64_t THCTensor_(strideLegacyNoScalars)(THCState *state, const THCTensor *self, int dim)
-{
-  return THTensor_strideLegacyNoScalars(self, dim);
-}
-
-scalar_t *THCTensor_(data)(THCState *state, const THCTensor *self)
-{
-  if(THTensor_getStoragePtr(self))
-    return (THCStorage_(data)(state, THTensor_getStoragePtr(self))+self->storage_offset());
-  else
-    return NULL;
-}
-
 /**** creation methods ****/
 
 /* Empty init */
@@ -72,11 +24,6 @@ THCTensor *THCTensor_(new)(THCState *state)
       .release();
 }
 
-/* Pointer-copy init */
-THCTensor *THCTensor_(newWithTensor)(THCState *state, THCTensor *tensor)
-{
-  return at::native::alias(THTensor_wrap(tensor)).unsafeReleaseTensorImpl();
-}
 
 THCTensor *THCTensor_(newWithStorage1d)(THCState *state, THCStorage *storage, ptrdiff_t storageOffset,
                                int64_t size0, int64_t stride0)
@@ -87,393 +34,13 @@ THCTensor *THCTensor_(newWithStorage1d)(THCState *state, THCStorage *storage, pt
                        at::DispatchKey::CUDA,
                        caffe2::TypeMeta::Make<scalar_t>())
                        .release();
-  THCTensor_(setStorage)(state, self, storage, storageOffset, {size0}, {stride0});
-
-  return self;
-}
-
-THCTensor *THCTensor_(newWithSize)(THCState *state, at::IntArrayRef size, at::IntArrayRef stride)
-{
-  TORCH_INTERNAL_ASSERT(false, "this function should not be called and is in the process of being removed");
-}
-
-THCTensor *THCTensor_(newWithSize1d)(THCState *state, int64_t size0)
-{
-  THCStorage *new_storage = THCStorage_(new)(state);
-  THCTensor* self =
-      c10::make_intrusive<at::TensorImpl, at::UndefinedTensorImpl>(
-          c10::intrusive_ptr<at::StorageImpl>::reclaim(new_storage),
-          at::DispatchKey::CUDA,
-          caffe2::TypeMeta::Make<scalar_t>())
-          .release();
-  THCTensor_(setStorage)(state, self, new_storage, 0, {size0}, {});
+  THCTensor_setStorage(state, self, storage, storageOffset, {size0}, {stride0});
 
   return self;
 }
 
-THCTensor *THCTensor_(newClone)(THCState *state, THCTensor *self)
-{
-  // already available in Aten as at::clone()
-  THCTensor *tensor = THCTensor_(new)(state);
-  at::Tensor tensor_wrap = THTensor_wrap(tensor);
-  at::Tensor self_wrap = THTensor_wrap(self);
-  tensor_wrap.resize_as_(self_wrap);
-  THCTensor_(copy)(state, tensor, self);
-  return tensor;
-}
-
-THCTensor *THCTensor_(newContiguous)(THCState *state, THCTensor *self)
-{
-  if(!THCTensor_(isContiguous)(state, self)) {
-    return THCTensor_(newClone)(state, self);
-  } else {
-    THCTensor_(retain)(state, self);
-    return self;
-  }
-}
-
-THCTensor *THCTensor_(newSelect)(THCState *state, THCTensor *tensor, int dimension_, int64_t sliceIndex_)
-{
-  THCTensor *self = THCTensor_(newWithTensor)(state, tensor);
-  THCTensor_(select)(state, self, NULL, dimension_, sliceIndex_);
-  return self;
-}
-
-THCTensor *THCTensor_(newNarrow)(THCState *state, THCTensor *tensor, int dimension_, int64_t firstIndex_, int64_t size_)
-{
-  THCTensor *self = THCTensor_(newWithTensor)(state, tensor);
-  THCTensor_(narrow)(state, self, NULL, dimension_, firstIndex_, size_);
-  return self;
-}
-
-THCTensor *THCTensor_(newTranspose)(THCState *state, THCTensor *tensor, int dimension1_, int dimension2_)
-{
-  THCTensor *self = THCTensor_(newWithTensor)(state, tensor);
-  THCTensor_(transpose)(state, self, NULL, dimension1_, dimension2_);
-  return self;
-}
-
-// Collapses the first two dimensions of a tensor.
-// Assumes the input tensor is contiguous.
-THCTensor *THCTensor_(newFoldBatchDim)(THCState *state, THCTensor *input) {
-  int in_dims = THCTensor_(nDimensionLegacyAll)(state, input);
-  THArgCheck(in_dims >= 2, 1, "Tensor needs to have at least two dimensions");
-  THArgCheck(THCTensor_(isContiguous)(state, input), 1,
-             "Tensor must be contiguous");
-  std::vector<int64_t> new_size(in_dims - 1);
-  new_size[0] = THCTensor_(size)(state, input, 0) * THCTensor_(size)(state, input, 1);
-  for (int i = 2; i < in_dims; i++) {
-    new_size[i - 1] = THCTensor_(size)(state, input, i);
-  }
-  THCTensor *output = at::native::view(THTensor_wrap(input), new_size).unsafeReleaseTensorImpl();
-  return output;
-}
-
-/* Resize */
-void THCTensor_(resize)(THCState *state, THCTensor *self, at::IntArrayRef size, at::IntArrayRef stride)
-{
-  THCTensor_resize(state, self, size, stride);
-}
-
-void THCTensor_(resizeAs)(THCState *state, THCTensor *self, THCTensor *src)
-{
-  // already available in Aten as at::resize_as_()
-  THCTensor_resizeAs(state, self, src);
-}
-
-void THCTensor_(resize0d)(THCState *state, THCTensor *tensor)
-{
-  THCTensor_resizeNd(state, tensor, 0, {}, nullptr);
-}
-
-void THCTensor_(resize1d)(THCState *state, THCTensor *tensor, int64_t size0)
-{
-  int64_t size[1] = {size0};
-  THCTensor_resizeNd(state, tensor, 1, size, nullptr);
-}
-
-void THCTensor_(resize2d)(THCState *state, THCTensor *tensor, int64_t size0, int64_t size1)
-{
-  int64_t size[2] = {size0, size1};
-  THCTensor_resizeNd(state, tensor, 2, size, nullptr);
-}
-
-void THCTensor_(resize3d)(THCState *state, THCTensor *tensor, int64_t size0, int64_t size1, int64_t size2)
-{
-  int64_t size[3] = {size0, size1, size2};
-  THCTensor_resizeNd(state, tensor, 3, size, nullptr);
-}
-
-void THCTensor_(resize4d)(THCState *state, THCTensor *self, int64_t size0, int64_t size1, int64_t size2, int64_t size3)
-{
-  int64_t size[4] = {size0, size1, size2, size3};
-  THCTensor_resizeNd(state, self, 4, size, nullptr);
-}
-
-void THCTensor_(resize5d)(THCState *state, THCTensor *self, int64_t size0, int64_t size1, int64_t size2, int64_t size3, int64_t size4)
-{
-  int64_t size[5] = {size0, size1, size2, size3, size4};
-  THCTensor_resizeNd(state, self, 5, size, nullptr);
-}
-
-void THCTensor_(set)(THCState *state, THCTensor *self, THCTensor *src)
-{
-  THCTensor_set(state, self, src);
-}
-
-void THCTensor_(setStorage)(THCState *state, THCTensor *self, THCStorage *storage_, ptrdiff_t storageOffset_, at::IntArrayRef size_, at::IntArrayRef stride_) {
-  THCTensor_setStorage(state, self, storage_, storageOffset_, size_, stride_);
-}
-
-void THCTensor_(narrow)(THCState *state, THCTensor *self, THCTensor *src, int dimension, int64_t firstIndex, int64_t size)
-{
-  if(!src)
-    src = self;
-
-  THArgCheck( (dimension >= 0) && (dimension < src->dim()), 3, "out of range");
-  THArgCheck( firstIndex >= 0, 4, "out of range");
-  THArgCheck( size >= 0, 5, "out of range");
-  THArgCheck(firstIndex+size <= src->size(dimension), 5, "out of range");
-
-  THCTensor_(set)(state, self, src);
-
-  if (firstIndex > 0) {
-    self->set_storage_offset(self->storage_offset() + firstIndex*self->stride(dimension));
-  }
-
-  self->set_size(dimension, size);
-}
-
-void THCTensor_(select)(THCState *state, THCTensor *self, THCTensor *src, int dimension, int64_t sliceIndex)
-{
-  int d;
-
-  if(!src)
-    src = self;
-
-  THArgCheck(src->dim() > 0, 1, "cannot select on a 0-dim tensor");
-  THArgCheck((dimension >= 0) && (dimension < src->dim()), 3, "out of range");
-  THArgCheck((sliceIndex >= 0) && (sliceIndex < src->size(dimension)), 4, "out of range");
-
-  THCTensor_(set)(state, self, src);
-  THCTensor_(narrow)(state, self, NULL, dimension, sliceIndex, 1);
-
-  std::vector<int64_t> newSize(self->dim()-1);
-  std::vector<int64_t> newStride(self->dim()-1);
-
-  for (d = 0; d < dimension; d++)
-  {
-    newSize[d] = self->size(d);
-    newStride[d] = self->stride(d);
-  }
-
-  for(d = dimension; d < self->dim()-1; d++)
-  {
-    newSize[d] = self->size(d+1);
-    newStride[d] = self->stride(d+1);
-  }
-  self->set_sizes_and_strides(newSize, newStride);
-}
-
-void THCTensor_(transpose)(THCState *state, THCTensor *self, THCTensor *src, int dimension1, int dimension2)
-{
-  int64_t z;
-
-  if(!src)
-    src = self;
-
-  THArgCheck( (dimension1 >= 0) && (dimension1 < THTensor_nDimensionLegacyNoScalars(src)), 1, "out of range");
-  THArgCheck( (dimension2 >= 0) && (dimension2 < THTensor_nDimensionLegacyNoScalars(src)), 2, "out of range");
-
-  THCTensor_(set)(state, self, src);
-
-  if(dimension1 == dimension2)
-    return;
-
-  z = self->stride(dimension1);
-  self->set_stride(dimension1, self->stride(dimension2));
-  self->set_stride(dimension2, z);
-  z = self->size(dimension1);
-  self->set_size(dimension1, self->size(dimension2));
-  self->set_size(dimension2, z);
-}
-
-void THCTensor_(squeeze1d)(THCState *state, THCTensor *self, THCTensor *src, int dimension)
-{
-  THCTensor_squeeze1d(state, self, src, dimension);
-}
-
-void THCTensor_(unsqueeze1d)(THCState *state, THCTensor *self, THCTensor *src, int dimension)
-{
-  THCTensor_unsqueeze1d(state, self, src, dimension);
-}
-
-int THCTensor_(isContiguous)(THCState *state, const THCTensor *self)
-{
-  return self->is_contiguous();
-}
-
-int THCTensor_(isSameSizeAs)(THCState *state, const THCTensor *self, const THCTensor* src)
-{
-  int d;
-  if (self->dim() != src->dim())
-    return 0;
-  for(d = 0; d < self->dim(); ++d)
-  {
-    if(self->size(d) != src->size(d))
-      return 0;
-  }
-  return 1;
-}
-
-ptrdiff_t THCTensor_(nElement)(THCState *state, const THCTensor *self)
-{
-  return THCTensor_nElement(state, self);
-}
-
-void THCTensor_(retain)(THCState *state, THCTensor *self)
-{
-  THCTensor_retain(state, self);
-}
-
 void THCTensor_(free)(THCState *state, THCTensor *self)
 {
   THCTensor_free(state, self);
 }
-
-void THCTensor_(freeCopyTo)(THCState *state, THCTensor *self, THCTensor *dst)
-{
-  if(self != dst)
-    THCTensor_(copy)(state, dst, self);
-
-  THCTensor_(free)(state, self);
-}
-
-/*******************************************************************************/
-
-void THCTensor_(resizeNd)(THCState *state, THCTensor *self, int nDimension, const int64_t *size, const int64_t *stride)
-{
-  THCTensor_resizeNd(state, self, nDimension, size, stride);
-}
-
-void THCTensor_(set0d)(THCState *state, THCTensor *tensor, scalar_t value)
-{
-  THArgCheck(THTensor_nDimension(tensor) == 0, 1, "tensor must have no dimensions");
-  THCStorage_(set)(state, THTensor_getStoragePtr(tensor), tensor->storage_offset(), value);
-}
-
-
-scalar_t THCTensor_(get0d)(THCState *state, const THCTensor *tensor)
-{
-  THArgCheck(THTensor_nDimension(tensor) == 0, 1, "tensor must have no dimensions dimension");
-  return THCStorage_(get)(state, THTensor_getStoragePtr(tensor), tensor->storage_offset());
-}
-
-void THCTensor_(set1d)(THCState *state, THCTensor *tensor, int64_t x0, scalar_t value)
-{
-  THArgCheck(THTensor_nDimensionLegacyNoScalars(tensor) == 1, 1, "tensor must have one dimension");
-  THArgCheck( (x0 >= 0) && (x0 < THTensor_sizeLegacyNoScalars(tensor, 0)), 2, "out of range");
-  THCStorage_(set)(state, THTensor_getStoragePtr(tensor), tensor->storage_offset()+x0*THTensor_strideLegacyNoScalars(tensor, 0), value);
-}
-
-scalar_t THCTensor_(get1d)(THCState *state, const THCTensor *tensor, int64_t x0)
-{
-  THArgCheck(THTensor_nDimensionLegacyNoScalars(tensor) == 1, 1, "tensor must have one dimension");
-  THArgCheck( (x0 >= 0) && (x0 < THTensor_sizeLegacyNoScalars(tensor, 0)), 2, "out of range");
-  return THCStorage_(get)(state, THTensor_getStoragePtr(tensor), tensor->storage_offset()+x0*THTensor_strideLegacyNoScalars(tensor, 0));
-}
-
-void THCTensor_(set2d)(THCState *state, THCTensor *tensor, int64_t x0, int64_t x1, scalar_t value)
-{
-  THArgCheck(tensor->dim() == 2, 1, "tensor must have two dimensions");
-  THArgCheck((x0 >= 0) && (x0 < tensor->size(0)) && (x1 >= 0) && (x1 < tensor->size(1)), 2, "out of range");
-  THCStorage_(set)(state, THTensor_getStoragePtr(tensor), tensor->storage_offset()+x0*tensor->stride(0)+x1*tensor->stride(1), value);
-}
-
-scalar_t THCTensor_(get2d)(THCState *state, const THCTensor *tensor, int64_t x0, int64_t x1)
-{
-  THArgCheck(tensor->dim() == 2, 1, "tensor must have two dimensions");
-  THArgCheck((x0 >= 0) && (x0 < tensor->size(0)) && (x1 >= 0) && (x1 < tensor->size(1)), 2, "out of range");
-  return THCStorage_(get)(state, THTensor_getStoragePtr(tensor), tensor->storage_offset()+x0*tensor->stride(0)+x1*tensor->stride(1));
-}
-
-void THCTensor_(set3d)(THCState *state, THCTensor *tensor, int64_t x0, int64_t x1, int64_t x2, scalar_t value)
-{
-  THArgCheck(tensor->dim() == 3, 1, "tensor must have three dimensions");
-  THArgCheck( (x0 >= 0) && (x0 < tensor->size(0)) && (x1 >= 0) && (x1 < tensor->size(1)) && (x2 >= 0) && (x2 < tensor->size(2)), 2, "out of range");
-  THCStorage_(set)(state, THTensor_getStoragePtr(tensor), tensor->storage_offset()+x0*tensor->stride(0)+x1*tensor->stride(1)+x2*tensor->stride(2), value);
-}
-
-scalar_t THCTensor_(get3d)(THCState *state, const THCTensor *tensor, int64_t x0, int64_t x1, int64_t x2)
-{
-  THArgCheck(tensor->dim() == 3, 1, "tensor must have three dimensions");
-  THArgCheck( (x0 >= 0) && (x0 < tensor->size(0)) && (x1 >= 0) && (x1 < tensor->size(1)) && (x2 >= 0) && (x2 < tensor->size(2)), 2, "out of range");
-  return THCStorage_(get)(state, THTensor_getStoragePtr(tensor), tensor->storage_offset()+x0*tensor->stride(0)+x1*tensor->stride(1)+x2*tensor->stride(2));
-}
-
-void THCTensor_(set4d)(THCState *state, THCTensor *tensor, int64_t x0, int64_t x1, int64_t x2, int64_t x3, scalar_t value)
-{
-  THArgCheck(tensor->dim() == 4, 1, "tensor must have four dimensions");
-  THArgCheck((x0 >= 0) && (x0 < tensor->size(0)) && (x1 >= 0) && (x1 < tensor->size(1)) && (x2 >= 0) && (x2 < tensor->size(2)) && (x3 >= 0) && (x3 < tensor->size(3)), 2, "out of range");
-  THCStorage_(set)(state, THTensor_getStoragePtr(tensor), tensor->storage_offset()+x0*tensor->stride(0)+x1*tensor->stride(1)+x2*tensor->stride(2)+x3*tensor->stride(3), value);
-}
-
-scalar_t THCTensor_(get4d)(THCState *state, const THCTensor *tensor, int64_t x0, int64_t x1, int64_t x2, int64_t x3)
-{
-  THArgCheck(tensor->dim() == 4, 1, "tensor must have four dimensions");
-  THArgCheck((x0 >= 0) && (x0 < tensor->size(0)) && (x1 >= 0) && (x1 < tensor->size(1)) && (x2 >= 0) && (x2 < tensor->size(2)) && (x3 >= 0) && (x3 < tensor->size(3)), 2, "out of range");
-  return THCStorage_(get)(state, THTensor_getStoragePtr(tensor), tensor->storage_offset()+x0*tensor->stride(0)+x1*tensor->stride(1)+x2*tensor->stride(2)+x3*tensor->stride(3));
-}
-
-int THCTensor_(checkGPU)(THCState *state, unsigned int nTensors, ...)
-{
-  int curDev = -1;
-  THCudaCheck(cudaGetDevice(&curDev));
-  va_list args;
-  va_start(args, nTensors);
-  int valid = 1;
-  for (unsigned int i = 0; i < nTensors; i++) {
-    THCTensor* tensor = va_arg(args, THCTensor*);
-    if (tensor == NULL) {
-      continue;
-    }
-
-    const int tensorDev = THCTensor_(getDevice)(state, tensor);
-
-    // Skips CPU tensors
-    if (tensorDev == -1) { continue; }
-
-    // Checks all tensors are on the same device
-    if (tensorDev != curDev) {
-      valid = 0;
-      break;
-    }
-  }
-
-  va_end(args);
-  return valid;
-}
-
-THCDescBuff THCTensor_(sizeDesc)(THCState *state, const THCTensor *tensor) {
-  const int L = THC_DESC_BUFF_LEN;
-  THCDescBuff buf;
-  char *str = buf.str;
-  int n = 0;
-  n += snprintf(str, L-n, "[");
-  int i;
-  for(i = 0; i < tensor->dim(); i++) {
-    if(n >= L) break;
-    n += snprintf(str+n, L-n, "%" PRId64, tensor->size(i));
-    if(i < tensor->dim()-1) {
-      n += snprintf(str+n, L-n, " x ");
-    }
-  }
-  if(n < L - 2) {
-    snprintf(str+n, L-n, "]");
-  } else {
-    snprintf(str+L-5, 5, "...]");
-  }
-  return buf;
-}
-
 #endif
index f09a97e..b33c04b 100644 (file)
 /**** access methods ****/
 TORCH_CUDA_CU_API THCStorage* THCTensor_(
     storage)(THCState* state, const THCTensor* self);
-TORCH_CUDA_CU_API ptrdiff_t
-    THCTensor_(storageOffset)(THCState* state, const THCTensor* self);
-
-// See [NOTE: nDimension vs nDimensionLegacyNoScalars vs nDimensionLegacyAll]
-TORCH_CUDA_CU_API int THCTensor_(
-    nDimension)(THCState* state, const THCTensor* self);
-TORCH_CUDA_CU_API int THCTensor_(
-    nDimensionLegacyNoScalars)(THCState* state, const THCTensor* self);
-TORCH_CUDA_CU_API int THCTensor_(
-    nDimensionLegacyAll)(THCState* state, const THCTensor* self);
-
-TORCH_CUDA_CU_API int64_t
-    THCTensor_(size)(THCState* state, const THCTensor* self, int dim);
-TORCH_CUDA_CU_API int64_t THCTensor_(
-    sizeLegacyNoScalars)(THCState* state, const THCTensor* self, int dim);
-TORCH_CUDA_CU_API int64_t
-    THCTensor_(stride)(THCState* state, const THCTensor* self, int dim);
-TORCH_CUDA_CU_API int64_t THCTensor_(
-    strideLegacyNoScalars)(THCState* state, const THCTensor* self, int dim);
-TORCH_CUDA_CU_API scalar_t* THCTensor_(
-    data)(THCState* state, const THCTensor* self);
-
-TORCH_CUDA_CU_API void THCTensor_(
-    setFlag)(THCState* state, THCTensor* self, const char flag);
-TORCH_CUDA_CU_API void THCTensor_(
-    clearFlag)(THCState* state, THCTensor* self, const char flag);
-
 /**** creation methods ****/
 TORCH_CUDA_CU_API THCTensor* THCTensor_(new)(THCState* state);
-TORCH_CUDA_CU_API THCTensor* THCTensor_(
-    newWithTensor)(THCState* state, THCTensor* tensor);
 TORCH_CUDA_CU_API THCTensor* THCTensor_(newWithStorage1d)(
     THCState* state,
     THCStorage* storage_,
@@ -60,175 +31,6 @@ TORCH_CUDA_CU_API THCTensor* THCTensor_(newWithStorage1d)(
     int64_t size0_,
     int64_t stride0_);
 
-/* stride might be NULL */
-TORCH_CUDA_CU_API THCTensor* THCTensor_(
-    newWithSize1d)(THCState* state, int64_t size0_);
-
-TORCH_CUDA_CU_API THCTensor* THCTensor_(
-    newClone)(THCState* state, THCTensor* self);
-TORCH_CUDA_CU_API THCTensor* THCTensor_(
-    newContiguous)(THCState* state, THCTensor* tensor);
-TORCH_CUDA_CU_API THCTensor* THCTensor_(newSelect)(
-    THCState* state,
-    THCTensor* tensor,
-    int dimension_,
-    int64_t sliceIndex_);
-TORCH_CUDA_CU_API THCTensor* THCTensor_(newNarrow)(
-    THCState* state,
-    THCTensor* tensor,
-    int dimension_,
-    int64_t firstIndex_,
-    int64_t size_);
-TORCH_CUDA_CU_API THCTensor* THCTensor_(newTranspose)(
-    THCState* state,
-    THCTensor* tensor,
-    int dimension1_,
-    int dimension2_);
-TORCH_CUDA_CU_API THCTensor* THCTensor_(
-    newFoldBatchDim)(THCState* state, THCTensor* input);
-
-// resize* methods simply resize the storage. So they may not retain the current data at current indices.
-// This is especially likely to happen when the tensor is not contiguous. In general, if you still need the
-// values, unless you are doing some size and stride tricks, do not use resize*.
-TORCH_CUDA_CU_API void THCTensor_(resizeNd)(
-    THCState* state,
-    THCTensor* tensor,
-    int nDimension,
-    const int64_t* size,
-    const int64_t* stride);
-TORCH_CUDA_CU_API void THCTensor_(
-    resizeAs)(THCState* state, THCTensor* tensor, THCTensor* src);
-TORCH_CUDA_CU_API void THCTensor_(resize0d)(THCState* state, THCTensor* tensor);
-TORCH_CUDA_CU_API void THCTensor_(
-    resize1d)(THCState* state, THCTensor* tensor, int64_t size0_);
-TORCH_CUDA_CU_API void THCTensor_(resize2d)(
-    THCState* state,
-    THCTensor* tensor,
-    int64_t size0_,
-    int64_t size1_);
-TORCH_CUDA_CU_API void THCTensor_(resize3d)(
-    THCState* state,
-    THCTensor* tensor,
-    int64_t size0_,
-    int64_t size1_,
-    int64_t size2_);
-TORCH_CUDA_CU_API void THCTensor_(resize4d)(
-    THCState* state,
-    THCTensor* tensor,
-    int64_t size0_,
-    int64_t size1_,
-    int64_t size2_,
-    int64_t size3_);
-TORCH_CUDA_CU_API void THCTensor_(resize5d)(
-    THCState* state,
-    THCTensor* tensor,
-    int64_t size0_,
-    int64_t size1_,
-    int64_t size2_,
-    int64_t size3_,
-    int64_t size4_);
-
-TORCH_CUDA_CU_API void THCTensor_(
-    set)(THCState* state, THCTensor* self, THCTensor* src);
-
-TORCH_CUDA_CU_API void THCTensor_(narrow)(
-    THCState* state,
-    THCTensor* self,
-    THCTensor* src,
-    int dimension_,
-    int64_t firstIndex_,
-    int64_t size_);
-TORCH_CUDA_CU_API void THCTensor_(select)(
-    THCState* state,
-    THCTensor* self,
-    THCTensor* src,
-    int dimension_,
-    int64_t sliceIndex_);
-TORCH_CUDA_CU_API void THCTensor_(transpose)(
-    THCState* state,
-    THCTensor* self,
-    THCTensor* src,
-    int dimension1_,
-    int dimension2_);
-
-TORCH_CUDA_CU_API void THCTensor_(squeeze1d)(
-    THCState* state,
-    THCTensor* self,
-    THCTensor* src,
-    int dimension_);
-TORCH_CUDA_CU_API void THCTensor_(unsqueeze1d)(
-    THCState* state,
-    THCTensor* self,
-    THCTensor* src,
-    int dimension_);
-
-TORCH_CUDA_CU_API int THCTensor_(
-    isContiguous)(THCState* state, const THCTensor* self);
-TORCH_CUDA_CU_API int THCTensor_(
-    isSameSizeAs)(THCState* state, const THCTensor* self, const THCTensor* src);
-TORCH_CUDA_CU_API ptrdiff_t
-    THCTensor_(nElement)(THCState* state, const THCTensor* self);
-
-TORCH_CUDA_CU_API void THCTensor_(retain)(THCState* state, THCTensor* self);
 TORCH_CUDA_CU_API void THCTensor_(free)(THCState* state, THCTensor* self);
-TORCH_CUDA_CU_API void THCTensor_(
-    freeCopyTo)(THCState* state, THCTensor* self, THCTensor* dst);
-
-/* Slow access methods [check everything] */
-TORCH_CUDA_CU_API void THCTensor_(
-    set0d)(THCState* state, THCTensor* tensor, scalar_t value);
-TORCH_CUDA_CU_API void THCTensor_(
-    set1d)(THCState* state, THCTensor* tensor, int64_t x0, scalar_t value);
-TORCH_CUDA_CU_API void THCTensor_(set2d)(
-    THCState* state,
-    THCTensor* tensor,
-    int64_t x0,
-    int64_t x1,
-    scalar_t value);
-TORCH_CUDA_CU_API void THCTensor_(set3d)(
-    THCState* state,
-    THCTensor* tensor,
-    int64_t x0,
-    int64_t x1,
-    int64_t x2,
-    scalar_t value);
-TORCH_CUDA_CU_API void THCTensor_(set4d)(
-    THCState* state,
-    THCTensor* tensor,
-    int64_t x0,
-    int64_t x1,
-    int64_t x2,
-    int64_t x3,
-    scalar_t value);
-
-TORCH_CUDA_CU_API scalar_t
-    THCTensor_(get0d)(THCState* state, const THCTensor* tensor);
-TORCH_CUDA_CU_API scalar_t
-    THCTensor_(get1d)(THCState* state, const THCTensor* tensor, int64_t x0);
-TORCH_CUDA_CU_API scalar_t THCTensor_(
-    get2d)(THCState* state, const THCTensor* tensor, int64_t x0, int64_t x1);
-TORCH_CUDA_CU_API scalar_t THCTensor_(get3d)(
-    THCState* state,
-    const THCTensor* tensor,
-    int64_t x0,
-    int64_t x1,
-    int64_t x2);
-TORCH_CUDA_CU_API scalar_t THCTensor_(get4d)(
-    THCState* state,
-    const THCTensor* tensor,
-    int64_t x0,
-    int64_t x1,
-    int64_t x2,
-    int64_t x3);
-
-/* CUDA-specific functions */
-TORCH_CUDA_CU_API int THCTensor_(
-    getDevice)(THCState* state, const THCTensor* self);
-TORCH_CUDA_CU_API int THCTensor_(
-    checkGPU)(THCState* state, unsigned int nTensors, ...);
-
-/* debug methods */
-TORCH_CUDA_CU_API THCDescBuff
-    THCTensor_(sizeDesc)(THCState* state, const THCTensor* tensor);
 
 #endif
diff --git a/aten/src/THC/generic/THCTensor.hpp b/aten/src/THC/generic/THCTensor.hpp
deleted file mode 100644 (file)
index fb7ae79..0000000
+++ /dev/null
@@ -1,25 +0,0 @@
-#ifndef THC_GENERIC_FILE
-#define THC_GENERIC_FILE "THC/generic/THCTensor.hpp"
-#else
-
-// STOP!!! Thinking of including this header directly?  Please
-// read Note [TH abstraction violation]
-
-// NOTE: functions exist here only to support dispatch via Declarations.cwrap.  You probably don't want to put
-// new functions in here, they should probably be un-genericized.
-
-TORCH_CUDA_CU_API void THCTensor_(setStorage)(
-    THCState* state,
-    THCTensor* self,
-    THCStorage* storage_,
-    ptrdiff_t storageOffset_,
-    at::IntArrayRef size_,
-    at::IntArrayRef stride_);
-
-TORCH_CUDA_CU_API void THCTensor_(resize)(
-    THCState* state,
-    THCTensor* self,
-    at::IntArrayRef size,
-    at::IntArrayRef stride);
-
-#endif