Replace resize_dim with set_sizes_and_strides in THTensor_(squeeze) (#18059)
authorHuitong Qiu <huitong@fb.com>
Mon, 18 Mar 2019 15:49:44 +0000 (08:49 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Mon, 18 Mar 2019 15:52:58 +0000 (08:52 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18059

Replace resize_dim() with set_sizes_and_strides() in `THTensor_(squeeze)` in aten/src/TH/generic/THTensor.cpp and `THCTensor_(squeeze)` in aten/src/THC/generic/THCTensor.cpp

Reviewed By: ezyang

Differential Revision: D14471066

fbshipit-source-id: 1c8c412ff09246c4df6843736e3bf0279bfadea8

aten/src/TH/generic/THTensor.cpp
aten/src/THC/generic/THCTensor.cpp

index 4a5ae3b..bfd5c9f 100644 (file)
@@ -430,28 +430,23 @@ void THTensor_(unfold)(THTensor *self, THTensor *src, int dimension, int64_t siz
 /* we have to handle the case where the result is a number */
 void THTensor_(squeeze)(THTensor *self, THTensor *src)
 {
-  int ndim = 0;
-  int d;
-
   if(!src)
     src = self;
 
   THTensor_(set)(self, src);
 
-  for(d = 0; d < src->dim(); d++)
+  std::vector<int64_t> newSize;
+  std::vector<int64_t> newStride;
+  for(int d = 0; d < src->dim(); ++d)
   {
     if(src->size(d) != 1)
     {
-      if(d != ndim)
-      {
-        self->set_size(ndim, src->size(d));
-        self->set_stride(ndim, src->stride(d));
-      }
-      ndim++;
+      newSize.push_back(src->size(d));
+      newStride.push_back(src->stride(d));
     }
   }
 
-  self->resize_dim(ndim);
+  self->set_sizes_and_strides(newSize, newStride);
 }
 
 void THTensor_(squeeze1d)(THTensor *self, THTensor *src, int dimension)
index 05be40c..6e00629 100644 (file)
@@ -447,28 +447,23 @@ void THCTensor_(unfold)(THCState *state, THCTensor *self, THCTensor *src, int di
 /* we have to handle the case where the result is a number */
 void THCTensor_(squeeze)(THCState *state, THCTensor *self, THCTensor *src)
 {
-  int ndim = 0;
-  int d;
-
   if(!src)
     src = self;
 
   THCTensor_(set)(state, self, src);
 
-  for(d = 0; d < src->dim(); d++)
+  std::vector<int64_t> newSize;
+  std::vector<int64_t> newStride;
+  for(int d = 0; d < src->dim(); ++d)
   {
     if(src->size(d) != 1)
     {
-      if(d != ndim)
-      {
-        self->set_size(ndim, src->size(d));
-        self->set_stride(ndim, src->stride(d));
-      }
-      ndim++;
+      newSize.push_back(src->size(d));
+      newStride.push_back(src->stride(d));
     }
   }
 
-  self->resize_dim(ndim);
+  self->set_sizes_and_strides(newSize, newStride);
 }
 
 void THCTensor_(squeeze1d)(THCState *state, THCTensor *self, THCTensor *src, int dimension)