From 754bf595ca3ae2cbbc1be1032862119ee64773ac Mon Sep 17 00:00:00 2001 From: Huitong Qiu Date: Mon, 18 Mar 2019 08:49:44 -0700 Subject: [PATCH] Replace resize_dim with set_sizes_and_strides in THTensor_(squeeze) (#18059) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18059 Replace resize_dim() with set_sizes_and_strides() in `THTensor_(squeeze)` in aten/src/TH/generic/THTensor.cpp and `THCTensor_(squeeze)` in aten/src/THC/generic/THCTensor.cpp Reviewed By: ezyang Differential Revision: D14471066 fbshipit-source-id: 1c8c412ff09246c4df6843736e3bf0279bfadea8 --- aten/src/TH/generic/THTensor.cpp | 17 ++++++----------- aten/src/THC/generic/THCTensor.cpp | 17 ++++++----------- 2 files changed, 12 insertions(+), 22 deletions(-) diff --git a/aten/src/TH/generic/THTensor.cpp b/aten/src/TH/generic/THTensor.cpp index 4a5ae3b..bfd5c9f 100644 --- a/aten/src/TH/generic/THTensor.cpp +++ b/aten/src/TH/generic/THTensor.cpp @@ -430,28 +430,23 @@ void THTensor_(unfold)(THTensor *self, THTensor *src, int dimension, int64_t siz /* we have to handle the case where the result is a number */ void THTensor_(squeeze)(THTensor *self, THTensor *src) { - int ndim = 0; - int d; - if(!src) src = self; THTensor_(set)(self, src); - for(d = 0; d < src->dim(); d++) + std::vector newSize; + std::vector newStride; + for(int d = 0; d < src->dim(); ++d) { if(src->size(d) != 1) { - if(d != ndim) - { - self->set_size(ndim, src->size(d)); - self->set_stride(ndim, src->stride(d)); - } - ndim++; + newSize.push_back(src->size(d)); + newStride.push_back(src->stride(d)); } } - self->resize_dim(ndim); + self->set_sizes_and_strides(newSize, newStride); } void THTensor_(squeeze1d)(THTensor *self, THTensor *src, int dimension) diff --git a/aten/src/THC/generic/THCTensor.cpp b/aten/src/THC/generic/THCTensor.cpp index 05be40c..6e00629 100644 --- a/aten/src/THC/generic/THCTensor.cpp +++ b/aten/src/THC/generic/THCTensor.cpp @@ -447,28 +447,23 @@ void THCTensor_(unfold)(THCState *state, THCTensor *self, THCTensor *src, int di /* we have to handle the case where the result is a number */ void THCTensor_(squeeze)(THCState *state, THCTensor *self, THCTensor *src) { - int ndim = 0; - int d; - if(!src) src = self; THCTensor_(set)(state, self, src); - for(d = 0; d < src->dim(); d++) + std::vector newSize; + std::vector newStride; + for(int d = 0; d < src->dim(); ++d) { if(src->size(d) != 1) { - if(d != ndim) - { - self->set_size(ndim, src->size(d)); - self->set_stride(ndim, src->stride(d)); - } - ndim++; + newSize.push_back(src->size(d)); + newStride.push_back(src->stride(d)); } } - self->resize_dim(ndim); + self->set_sizes_and_strides(newSize, newStride); } void THCTensor_(squeeze1d)(THCState *state, THCTensor *self, THCTensor *src, int dimension) -- 2.7.4