From ae91156e5dcda265f797606bc2e1fa0e46b29d1d Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Fri, 4 Jan 2019 15:48:21 -0800
Subject: [PATCH] Tensor method rename dims()->sizes() - 1/2

Summary: Codemod generated with clangr shard mode, 25 files per diff.

Reviewed By: BIT-silence

Differential Revision: D13581782

fbshipit-source-id: b16b4198e100617769d84aa599bf141117cfbe5b
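For reference, the mechanical rewrite the codemod applies looks like the
following. This is an illustrative sketch distilled from the hunks below
(tensor names X, data, mask are placeholders), not an exhaustive list of
the rewrite rules:

  // Before: shape access through the old dims() accessor.
  auto shape = X.dims().vec();
  CAFFE_ENFORCE(data.dims()[0] == mask.dims()[0]);

  // After: dims() becomes sizes(); single-dimension reads become size(i).
  auto shape = X.sizes().vec();
  CAFFE_ENFORCE(data.size(0) == mask.size(0));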
---
 caffe2/operators/boolean_mask_ops.cc       |  2 +-
 caffe2/operators/boolean_mask_ops.cu       |  6 +++---
 caffe2/operators/cross_entropy_op.cu       | 29 ++++++++++++++---------------
 caffe2/operators/deform_conv_op_impl.h     |  8 ++++----
 caffe2/operators/distance_op.cu            | 12 ++++++------
 caffe2/operators/dropout_op.cu             |  9 +++------
 caffe2/operators/gather_op.cuh             |  2 +-
 caffe2/operators/integral_image_op.cu      |  4 ++--
 caffe2/operators/lengths_tile_op.cu        |  2 +-
 caffe2/operators/pack_segments.cu          |  4 ++--
 caffe2/operators/pool_op.cu                |  8 ++++----
 caffe2/operators/pool_op_cudnn.cu          |  8 ++++----
 caffe2/operators/resize_op.cu              |  4 ++--
 caffe2/operators/reverse_packed_segs_op.cu | 12 +++++-------
 caffe2/operators/roi_pool_op.cu            |  4 ++--
 caffe2/operators/sequence_ops.cu           |  6 +++---
 caffe2/operators/slice_op.cu               | 30 +++++++++++++++---------------
 caffe2/operators/sparse_to_dense_op.cu     |  2 +-
 caffe2/operators/top_k.cu                  |  2 +-
 caffe2/operators/upsample_op.cu            |  4 ++--
 caffe2/operators/weighted_sample_op.cu     |  4 ++--
 modules/detectron/ps_roi_pool_op.cu        |  2 +-
 22 files changed, 79 insertions(+), 85 deletions(-)

diff --git a/caffe2/operators/boolean_mask_ops.cc b/caffe2/operators/boolean_mask_ops.cc
index f23e655..839fd0e 100644
--- a/caffe2/operators/boolean_mask_ops.cc
+++ b/caffe2/operators/boolean_mask_ops.cc
@@ -52,7 +52,7 @@ bool BooleanMaskOp<CPUContext>::RunOnDevice() {
   auto* dataOut = Output(0);
   CAFFE_ENFORCE(data.dim() >= 1);
   CAFFE_ENFORCE_EQ(mask.dim(), 1);
-  CAFFE_ENFORCE(data.sizes()[0] == mask.sizes()[0]);
+  CAFFE_ENFORCE(data.size(0) == mask.size(0));
 
   const auto* maskPtr = mask.template data<bool>();
   int numOutputs = 0;
diff --git a/caffe2/operators/boolean_mask_ops.cu b/caffe2/operators/boolean_mask_ops.cu
index bfed461..3e8c6ce 100644
--- a/caffe2/operators/boolean_mask_ops.cu
+++ b/caffe2/operators/boolean_mask_ops.cu
@@ -35,10 +35,10 @@ class BooleanMaskOp<CUDAContext> final : public Operator<CUDAContext> {
 
     CAFFE_ENFORCE(src.ndim() >= 1);
     CAFFE_ENFORCE_EQ(mask.ndim(), 1);
-    CAFFE_ENFORCE(src.dims()[0] == mask.dims()[0]);
+    CAFFE_ENFORCE(src.size(0) == mask.size(0));
 
     const auto* maskData = mask.data<bool>();
-    const auto outerSize = mask.dims()[0];
+    const auto outerSize = mask.size(0);
     indices_.Resize(outerSize);
     auto* indicesData = indices_.mutable_data<int64_t>();
 
@@ -76,7 +76,7 @@ class BooleanMaskOp<CUDAContext> final : public Operator<CUDAContext> {
     context_.CopyToCPU(1, numOfOutputData, &numOfOutput);
 
     indices_.Resize(numOfOutput);
-    std::vector<int64_t> dims = src.dims().vec();
+    std::vector<int64_t> dims = src.sizes().vec();
     dims[0] = numOfOutput;
     dest->Resize(dims);
     auto* destData = (uint8_t*)dest->raw_mutable_data(src.meta());
diff --git a/caffe2/operators/cross_entropy_op.cu b/caffe2/operators/cross_entropy_op.cu
index 2fde322..ab2a702 100644
--- a/caffe2/operators/cross_entropy_op.cu
+++ b/caffe2/operators/cross_entropy_op.cu
@@ -113,8 +113,7 @@ __global__ void MakeTwoClassGradientKernel(
 template <>
 bool MakeTwoClassOp<float, CUDAContext>::RunOnDevice() {
   auto& X = Input(0);
-
-  auto shape = X.dims().vec();
+  auto shape = X.sizes().vec();
   shape.push_back(2);
   CAFFE_ENFORCE_LT(X.size(), std::numeric_limits<int>::max() / 2);
   auto* Y = Output(0, shape, at::dtype<float>());
@@ -131,8 +130,7 @@ bool MakeTwoClassOp<float, CUDAContext>::RunOnDevice() {
 template <>
 bool MakeTwoClassGradientOp<float, CUDAContext>::RunOnDevice() {
   auto& dY = Input(0);
-
-  auto shape = dY.dims().vec();
+  auto shape = dY.sizes().vec();
   CAFFE_ENFORCE_GE(shape.size(), 1);
   CAFFE_ENFORCE_EQ(shape.back(), 2);
   shape.pop_back();
@@ -244,13 +242,14 @@ template <>
 bool SigmoidCrossEntropyWithLogitsOp<float, CUDAContext>::RunOnDevice() {
   auto& logits = Input(0);
   auto& targets = Input(1);
-  CAFFE_ENFORCE(logits.dims() == targets.dims());
-  const auto inner_size = logits.ndim() > 0 ? logits.dims().back() : 1;
+  CAFFE_ENFORCE(logits.sizes() == targets.sizes());
+  const auto inner_size = logits.ndim() > 0 ? logits.sizes().back() : 1;
   const auto outer_size = logits.size() / inner_size;
 
   std::vector<int64_t> dims;
   if (logits.dim() != 0) {
-    dims = std::vector<int64_t>(logits.dims().begin(), logits.dims().end() - 1);
+    dims =
+        std::vector<int64_t>(logits.sizes().begin(), logits.sizes().end() - 1);
   }
   auto* out = Output(0, dims, at::dtype<float>());
   auto* out_ptr = out->template mutable_data<float>();
@@ -284,8 +283,8 @@ bool SigmoidCrossEntropyWithLogitsGradientOp<float, CUDAContext>::
   auto& g = Input(0);
   auto& logits = Input(1);
   auto& targets = Input(2);
-  CAFFE_ENFORCE(logits.dims() == targets.dims());
-  const auto inner_size = logits.ndim() > 0 ? logits.dims().back() : 1;
+  CAFFE_ENFORCE(logits.sizes() == targets.sizes());
+  const auto inner_size = logits.ndim() > 0 ? logits.sizes().back() : 1;
   const auto outer_size = logits.size() / inner_size;
   CAFFE_ENFORCE(g.size() == outer_size);
@@ -363,9 +362,9 @@ bool WeightedSigmoidCrossEntropyWithLogitsOp<float, CUDAContext>::
   auto& logits = Input(0);
   auto& targets = Input(1);
   auto& weights = Input(2);
-  CAFFE_ENFORCE(logits.dims() == targets.dims());
-  CAFFE_ENFORCE(weights.dims() == targets.dims());
-  const auto inner_size = logits.ndim() > 0 ? logits.dims().back() : 1;
+  CAFFE_ENFORCE(logits.sizes() == targets.sizes());
+  CAFFE_ENFORCE(weights.sizes() == targets.sizes());
+  const auto inner_size = logits.ndim() > 0 ? logits.sizes().back() : 1;
   const auto outer_size = logits.size() / inner_size;
 
   std::vector<int64_t> dims;
@@ -396,9 +395,9 @@ bool WeightedSigmoidCrossEntropyWithLogitsGradientOp<float, CUDAContext>::
   auto& logits = Input(1);
   auto& targets = Input(2);
   auto& weights = Input(3);
-  CAFFE_ENFORCE(logits.dims() == targets.dims());
-  CAFFE_ENFORCE(weights.dims() == targets.dims());
-  const auto inner_size = logits.ndim() > 0 ? logits.dims().back() : 1;
+  CAFFE_ENFORCE(logits.sizes() == targets.sizes());
+  CAFFE_ENFORCE(weights.sizes() == targets.sizes());
+  const auto inner_size = logits.ndim() > 0 ? logits.sizes().back() : 1;
   const auto outer_size = logits.size() / inner_size;
   CAFFE_ENFORCE(g.size() == outer_size);
diff --git a/caffe2/operators/deform_conv_op_impl.h b/caffe2/operators/deform_conv_op_impl.h
index cad6f9c..5b3517e 100644
--- a/caffe2/operators/deform_conv_op_impl.h
+++ b/caffe2/operators/deform_conv_op_impl.h
@@ -275,7 +275,7 @@ bool DeformConvGradientOp<T, Context>::RunOnDeviceWithOrderNCHW() {
   // The col buffer is stored in CHW order as well - kernel_dim, and the
   // height and width.
   vector<int> img_shape;
-  img_shape.assign(X.dims().begin() + 1, X.dims().end());
+  img_shape.assign(X.sizes().begin() + 1, X.sizes().end());
   vector<int> col_buffer_shape;
   col_buffer_shape.push_back(C * kernel_dims_size);
   col_buffer_shape.insert(
@@ -341,20 +341,20 @@ bool DeformConvGradientOp<T, Context>::RunOnDeviceWithOrderNCHW() {
         col_buffer_data,
         Xdata,
         offset_data,
-        X.dims(),
+        X.sizes(),
         col_buffer_shape,
         doffset_data);
 
     // Gradient with respect to input data
     if (dXdata) {
       DeformableCol2im(
-          col_buffer_data, offset_data, X.dims(), col_buffer_shape, dXdata);
+          col_buffer_data, offset_data, X.sizes(), col_buffer_shape, dXdata);
       dXdata += input_offset * group_;
     }
 
     // Gradient with respect to filter
     DeformableIm2col(
-        Xdata, offset_data, X.dims(), col_buffer_shape, col_buffer_data);
+        Xdata, offset_data, X.sizes(), col_buffer_shape, col_buffer_data);
 
     for (int group_id = 0; group_id < group_; ++group_id) {
       math::Gemm<T, Context>(
diff --git a/caffe2/operators/distance_op.cu b/caffe2/operators/distance_op.cu
index bfafb15..2dc51cc 100644
--- a/caffe2/operators/distance_op.cu
+++ b/caffe2/operators/distance_op.cu
@@ -43,9 +43,9 @@ bool SquaredL2DistanceOp<float, CUDAContext>::RunOnDevice() {
         X.dim32(i),
         Y.dim32(i),
         "Mismatch in dimensions",
-        X.dims(),
+        X.sizes(),
         " / ",
-        Y.dims());
+        Y.sizes());
   }
   int N = X.ndim() > 0 ? X.dim32(0) : 1;
   int D = X.size() / N;
@@ -89,9 +89,9 @@ bool SquaredL2DistanceGradientOp<float, CUDAContext>::RunOnDevice() {
         X.dim32(i),
         Y.dim32(i),
         "Mismatch on dimensions: ",
-        X.dims(),
+        X.sizes(),
         " / ",
-        Y.dims());
+        Y.sizes());
   }
   CAFFE_ENFORCE_EQ(dDistance.ndim(), 1);
   CAFFE_ENFORCE_EQ(dDistance.dim32(0), N);
@@ -221,9 +221,9 @@ bool L1DistanceGradientOp<float, CUDAContext>::RunOnDevice() {
         X.dim32(i),
         Y.dim32(i),
         "Mismatch on dimensions: ",
-        X.dims(),
+        X.sizes(),
         " / ",
-        Y.dims());
+        Y.sizes());
   }
   CAFFE_ENFORCE_EQ(dDistance.ndim(), 1);
   CAFFE_ENFORCE_EQ(dDistance.dim32(0), N);
diff --git a/caffe2/operators/dropout_op.cu b/caffe2/operators/dropout_op.cu
index 09ef3de..9d7d54d 100644
--- a/caffe2/operators/dropout_op.cu
+++ b/caffe2/operators/dropout_op.cu
@@ -21,8 +21,7 @@ __global__ void DropoutKernel(
 template <>
 bool DropoutOp<float, CUDAContext>::RunOnDevice() {
   auto& X = Input(0);
-
-  auto* Y = Output(0, X.dims(), at::dtype<float>());
+  auto* Y = Output(0, X.sizes(), at::dtype<float>());
   if (is_test_) {
     if (Y != &X) {
       context_.CopySameDevice<float>(
@@ -34,8 +33,7 @@ bool DropoutOp<float, CUDAContext>::RunOnDevice() {
     // boolean numbers, we will generate into dY and write the result to
     // mask.
     float* Ydata = Y->template mutable_data<float>();
-
-    auto* mask = Output(1, X.dims(), at::dtype<bool>());
+    auto* mask = Output(1, X.sizes(), at::dtype<bool>());
     CAFFE_ENFORCE(X.data<float>() != Ydata, "In-place GPU dropout is broken");
     CURAND_ENFORCE(
         curandGenerateUniform(context_.curand_generator(), Ydata, X.size()));
@@ -69,8 +67,7 @@ __global__ void DropoutGradientKernel(
 template <>
 bool DropoutGradientOp<float, CUDAContext>::RunOnDevice() {
   auto& dY = Input(0);
-
-  auto* dX = Output(0, dY.dims(), at::dtype<float>());
+  auto* dX = Output(0, dY.sizes(), at::dtype<float>());
   if (is_test_) {
     if (dX != &dY) {
       context_.CopySameDevice<float>(
diff --git a/caffe2/operators/gather_op.cuh b/caffe2/operators/gather_op.cuh
index ed7c67c..16a85e3 100644
--- a/caffe2/operators/gather_op.cuh
+++ b/caffe2/operators/gather_op.cuh
@@ -62,7 +62,7 @@ static bool gather_impl_cuda(
   // New shape:
   // [data dims before axis] + [indices dims] + [data dims after axis]
   vector<int64_t> shape =
-      calc_output_shape_vector(data.dims(), indices.dims(), axis);
+      calc_output_shape_vector(data.sizes(), indices.sizes(), axis);
   Tensor* output = op->Output(outputIdx, shape, at::dtype(dataType));
   float* out = static_cast<float*>(output->raw_mutable_data(dataType));
diff --git a/caffe2/operators/integral_image_op.cu b/caffe2/operators/integral_image_op.cu
index 4155a1f..1e2b710 100644
--- a/caffe2/operators/integral_image_op.cu
+++ b/caffe2/operators/integral_image_op.cu
@@ -124,7 +124,7 @@ bool IntegralImageOp<float, CUDAContext>::RunOnDevice() {
 
   // Input is (N, C, H, W)
   // Output is (N, C, H + 1, W + 1)
-  vector<int64_t> out_shape(X.dims().vec());
+  vector<int64_t> out_shape(X.sizes().vec());
   out_shape[2] += 1; // H + 1 output size
   out_shape[3] += 1; // W + 1 output size
   auto* Y = Output(0, out_shape, at::dtype<float>());
@@ -172,7 +172,7 @@ bool IntegralImageGradientOp<float, CUDAContext>::RunOnDevice() {
   // Row pass reduces shape of dY from (N, C, H + 1, W + 1)
   // to (N, C, H + 1, W)
   // Col pass reduces shape to (N, C, H, W)
-  vector<int64_t> row_pass_shape(dY.dims().vec());
+  vector<int64_t> row_pass_shape(dY.sizes().vec());
   row_pass_shape[3] -= 1;
   row_pass_buffer_.Resize(row_pass_shape);
   const int chans = row_pass_buffer_.dim32(1);
diff --git a/caffe2/operators/lengths_tile_op.cu b/caffe2/operators/lengths_tile_op.cu
index 04fea15..5ea88ee 100644
--- a/caffe2/operators/lengths_tile_op.cu
+++ b/caffe2/operators/lengths_tile_op.cu
@@ -37,7 +37,7 @@ bool LengthsTileOp<CUDAContext>::RunOnDevice() {
   math::Sum<int32_t, CPUContext>(
       lengths_size, lengths_data, &total_length, &cpuContext);
 
-  auto shape = data.dims().vec();
+  auto shape = data.sizes().vec();
   shape[0] = total_length;
   auto* output = Output(0, shape, at::dtype<float>());
diff --git a/caffe2/operators/pack_segments.cu b/caffe2/operators/pack_segments.cu
index 2a2a40d..0ea9dc8 100644
--- a/caffe2/operators/pack_segments.cu
+++ b/caffe2/operators/pack_segments.cu
@@ -219,7 +219,7 @@ bool PackSegmentsOp<CUDAContext>::DoRunWithType2() {
   }
 
   // create output tensor
-  auto shape = data.dims().vec(); // Shape of out is batch_size x max_len x ...
+  auto shape = data.sizes().vec(); // Shape of out is batch_size x max_len x ...
   shape[0] = max_length;
   shape.insert(shape.begin(), lengths.size());
   out->Resize(shape);
@@ -310,7 +310,7 @@ bool UnpackSegmentsOp<CUDAContext>::DoRunWithType2() {
       context_);
 
   // create output tensor
-  auto shape = data.dims().vec();
+  auto shape = data.sizes().vec();
   CAFFE_ENFORCE_EQ(
       shape[0], lengths.dim(0), "LENGTH should match DATA in dimension 0");
   shape.erase(shape.begin());
diff --git a/caffe2/operators/pool_op.cu b/caffe2/operators/pool_op.cu
index 8a4360b..55c59cb 100644
--- a/caffe2/operators/pool_op.cu
+++ b/caffe2/operators/pool_op.cu
@@ -730,7 +730,7 @@ bool PoolGradientOp<float, CUDAContext, AveragePool>::
   CAFFE_ENFORCE_EQ(dY.dim32(1), X.dim32(1));
   auto* dX = Output(0);
   dX->ResizeLike(X);
-  vector<int> dims(X.dims().begin() + 2, X.dims().end());
+  vector<int> dims(X.sizes().begin() + 2, X.sizes().end());
   ConvPoolOpBase<CUDAContext>::ComputePads(dims);
   switch (kernel_.size()) {
     case 1:
@@ -814,7 +814,7 @@ bool PoolGradientOp<float, CUDAContext, AveragePool>::
   CAFFE_ENFORCE_EQ(X.dim32(X.ndim() - 1), dY.dim32(dY.ndim() - 1));
   auto* dX = Output(0);
   dX->ResizeLike(X);
-  vector<int> dims(X.dims().begin() + 1, X.dims().end() - 1);
+  vector<int> dims(X.sizes().begin() + 1, X.sizes().end() - 1);
   ConvPoolOpBase<CUDAContext>::ComputePads(dims);
   switch (kernel_.size()) {
     case 1:
@@ -1565,7 +1565,7 @@ bool PoolGradientOp<float, CUDAContext, MaxPool>::RunOnDeviceWithOrderNCHW() {
   CAFFE_ENFORCE_EQ(dY.ndim(), X.ndim());
   auto* dX = Output(0);
   dX->ResizeLike(X);
-  vector<int> dims(X.dims().begin() + 2, X.dims().end());
+  vector<int> dims(X.sizes().begin() + 2, X.sizes().end());
   ConvPoolOpBase<CUDAContext>::ComputePads(dims);
   switch (kernel_.size()) {
     case 1:
@@ -1654,7 +1654,7 @@ bool PoolGradientOp<float, CUDAContext, MaxPool>::RunOnDeviceWithOrderNHWC() {
   CAFFE_ENFORCE_EQ(dY.ndim(), X.ndim());
   auto* dX = Output(0);
   dX->ResizeLike(X);
-  vector<int> dims(X.dims().begin() + 1, X.dims().end() - 1);
+  vector<int> dims(X.sizes().begin() + 1, X.sizes().end() - 1);
   ConvPoolOpBase<CUDAContext>::ComputePads(dims);
   switch (kernel_.size()) {
     case 1:
diff --git a/caffe2/operators/pool_op_cudnn.cu b/caffe2/operators/pool_op_cudnn.cu
index 292438e..0a6ede3 100644
--- a/caffe2/operators/pool_op_cudnn.cu
+++ b/caffe2/operators/pool_op_cudnn.cu
@@ -217,10 +217,10 @@ class CuDNNPoolOp : public ConvPoolOpBase<CUDAContext> {
       }
     }
 
-    if (cudnn_input_dims_ != X.dims()) {
+    if (cudnn_input_dims_ != X.sizes()) {
       // Dimensions changed; we will need to re-initialize things.
       VLOG(1) << "Changing the cudnn descriptor configurations.";
-      cudnn_input_dims_ = X.dims().vec();
+      cudnn_input_dims_ = X.sizes().vec();
       setTensorDescriptor<T>(X.ndim(), order_, N, C, H, W, D, bottom_desc_);
       setTensorDescriptor<T>(
           Y->ndim(), order_, N, C, H_out, W_out, D_out, top_desc_);
@@ -420,10 +420,10 @@ class CuDNNPoolGradientOp : public ConvPoolOpBase<CUDAContext> {
       CAFFE_THROW("Unsupported kernel size :", kernel_.size());
     }
 
-    if (cudnn_input_dims_ != X.dims()) {
+    if (cudnn_input_dims_ != X.sizes()) {
      // Dimensions changed; we will need to re-initialize things.
       VLOG(1) << "Changing the cudnn descriptor configurations.";
-      cudnn_input_dims_ = X.dims().vec();
+      cudnn_input_dims_ = X.sizes().vec();
       setTensorDescriptor<T>(X.ndim(), order_, N, C, H, W, D, bottom_desc_);
       setTensorDescriptor<T>(
           Y.ndim(), order_, N, C, H_out, W_out, D_out, top_desc_);
diff --git a/caffe2/operators/resize_op.cu b/caffe2/operators/resize_op.cu
index 271a51b..d9fbd88 100644
--- a/caffe2/operators/resize_op.cu
+++ b/caffe2/operators/resize_op.cu
@@ -75,7 +75,7 @@ bool ResizeNearestOp<float, CUDAContext>::RunOnDevice() {
   const auto& X = Input(0);
 
-  const auto inputDims = X.dims();
+  const auto inputDims = X.sizes();
   CAFFE_ENFORCE_EQ(4, inputDims.size());
   const int batch_size = X.dim32(0), num_channels = X.dim32(1),
             input_height = X.dim32(2), input_width = X.dim32(3);
@@ -118,7 +118,7 @@ bool ResizeNearestGradientOp<float, CUDAContext>::RunOnDevice() {
   const auto& X = Input(1);
 
-  const auto inputDims = dY.dims();
+  const auto inputDims = dY.sizes();
   CAFFE_ENFORCE_EQ(4, inputDims.size());
   const int batch_size = dY.dim32(0), num_channels = dY.dim32(1),
             input_height = dY.dim32(2), input_width = dY.dim32(3);
diff --git a/caffe2/operators/reverse_packed_segs_op.cu b/caffe2/operators/reverse_packed_segs_op.cu
index 6f700aa..7d09d3c 100644
--- a/caffe2/operators/reverse_packed_segs_op.cu
+++ b/caffe2/operators/reverse_packed_segs_op.cu
@@ -58,15 +58,13 @@ void ReversePackedSegsOp<CUDAContext>::DoRunWithLengthType() {
       "segments, embeddings>");
   CAFFE_ENFORCE(lengths.ndim() == 1, "LENGTH should be 1-D");
 
-
-  const auto shape = data.dims();
-  auto* output = Output(0, shape, at::dtype<T>());
+  auto* output = Output(0, data.sizes(), at::dtype<T>());
 
-  const auto max_length = data.dims()[0];
-  const auto batch_size = data.dims()[1];
-  const auto block_size = data.dims()[2];
+  const auto max_length = data.size(0);
+  const auto batch_size = data.size(1);
+  const auto block_size = data.size(2);
   CAFFE_ENFORCE(
-      lengths.dims()[0] == batch_size,
+      lengths.sizes()[0] == batch_size,
       "lenths size should be"
       " equal to batch size");
diff --git a/caffe2/operators/roi_pool_op.cu b/caffe2/operators/roi_pool_op.cu
index db18b3f..65b1702 100644
--- a/caffe2/operators/roi_pool_op.cu
+++ b/caffe2/operators/roi_pool_op.cu
@@ -135,7 +135,7 @@ bool RoIPoolOp<float, CUDAContext>::RunOnDevice() {
     // mutable_data calls are needed to allocate the tensors
     Y->template mutable_data<float>();
     if (!is_test_) {
-      A->Resize(Y->dims());
+      A->Resize(Y->sizes());
       A->template mutable_data<int>();
     }
     return true;
@@ -143,7 +143,7 @@ bool RoIPoolOp<float, CUDAContext>::RunOnDevice() {
 
   Y->Resize(R.dim32(0), X.dim32(1), pooled_height_, pooled_width_);
   if (!is_test_) {
-    A->Resize(Y->dims());
+    A->Resize(Y->sizes());
   }
   int output_size = Y->size();
   int* argmax_data = is_test_ ? nullptr : A->template mutable_data<int>();
diff --git a/caffe2/operators/sequence_ops.cu b/caffe2/operators/sequence_ops.cu
index 7435eac..0a91573 100644
--- a/caffe2/operators/sequence_ops.cu
+++ b/caffe2/operators/sequence_ops.cu
@@ -234,9 +234,9 @@ template <typename T>
 bool RemovePaddingOp<CUDAContext>::DoRunWithType() {
   const auto& in = Input(0);
   CAFFE_ENFORCE_GE(in.ndim(), 1);
-  const int32_t outer_size = in.dims()[0];
+  const int32_t outer_size = in.sizes()[0];
   const auto block_size = std::accumulate(
-      in.dims().begin() + 1, in.dims().end(), 1, std::multiplies<int32_t>());
+      in.sizes().begin() + 1, in.sizes().end(), 1, std::multiplies<int32_t>());
 
   // if no lengths is provided, assume it is a single full-span entry
   const int32_t* lengths_ptr = nullptr;
@@ -247,7 +247,7 @@ bool RemovePaddingOp<CUDAContext>::DoRunWithType() {
     lengths_size = lengths.size();
   }
 
-  auto out_dims = in.dims().vec();
+  auto out_dims = in.sizes().vec();
   out_dims[0] -= (startPaddingWidth_ + endPaddingWidth_) * lengths_size;
   auto* out = Output(0, out_dims, at::dtype<T>());
   const auto* in_ptr = in.template data<T>();
diff --git a/caffe2/operators/slice_op.cu b/caffe2/operators/slice_op.cu
index 7d888f2..2c9ba89 100644
--- a/caffe2/operators/slice_op.cu
+++ b/caffe2/operators/slice_op.cu
@@ -73,23 +73,23 @@ bool SliceImplGpu(
   for (int i = 0; i < data.ndim(); ++i) {
     if (i >= starts.size()) {
       starts_idx[i] = 0;
-      ends_idx[i] = data.dims()[i];
+      ends_idx[i] = data.size(i);
       continue;
     }
-    if (data.dims()[i] > 0) {
+    if (data.size(i) > 0) {
       auto start = starts_data[i];
       auto end = ends_data[i];
       if (start < 0) {
-        start = data.dims()[i] + 1 + start;
+        start = data.sizes()[i] + 1 + start;
       }
       if (end < 0) {
-        end = data.dims()[i] + 1 + end;
+        end = data.sizes()[i] + 1 + end;
       }
-      if (start > data.dims()[i]) {
-        start = data.dims()[i];
+      if (start > data.sizes()[i]) {
+        start = data.sizes()[i];
       }
-      if (end > data.dims()[i]) {
-        end = data.dims()[i];
+      if (end > data.sizes()[i]) {
+        end = data.sizes()[i];
       }
       CAFFE_ENFORCE_GE(start, 0);
       CAFFE_ENFORCE_GE(end, 0);
@@ -115,7 +115,7 @@
   // for now only supports slicing in 1 dimension
   int dim = -1;
   for (int i = 0; i < data.ndim(); ++i) {
-    if (starts_idx[i] > 0 || ends_idx[i] < data.dims()[i]) {
+    if (starts_idx[i] > 0 || ends_idx[i] < data.sizes()[i]) {
       CAFFE_ENFORCE_EQ(
           dim, -1, "Currently only possible to slice in 1 dimension.");
       dim = i;
@@ -130,13 +130,13 @@
     return true;
   }
   int unit = std::accumulate(
-      data.dims().begin() + dim + 1,
-      data.dims().end(),
+      data.sizes().begin() + dim + 1,
+      data.sizes().end(),
       1,
       std::multiplies<int>());
   int num_blocks = std::accumulate(
-      data.dims().begin(),
-      data.dims().begin() + dim,
+      data.sizes().begin(),
+      data.sizes().begin() + dim,
       1,
       std::multiplies<int>());
   if (!backward) {
@@ -154,7 +154,7 @@
   size_t src_nbytes = data.nbytes();
   size_t dst_nbytes = output->nbytes();
 
-  size_t src_block_size = unit * data.dims()[dim];
+  size_t src_block_size = unit * data.sizes()[dim];
   size_t dst_block_size = unit * (ends_idx[dim] - starts_idx[dim]);
   size_t src_offset = unit * starts_idx[dim];
@@ -187,7 +187,7 @@
   size_t dst_nbytes = gdata->nbytes();
 
   size_t src_block_size = unit * (ends_idx[dim] - starts_idx[dim]);
-  size_t dst_block_size = unit * data.dims()[dim];
+  size_t dst_block_size = unit * data.sizes()[dim];
   size_t dst_offset = unit * starts_idx[dim];
 
   if (num_blocks == 0 || dst_block_size == 0) {
diff --git a/caffe2/operators/sparse_to_dense_op.cu b/caffe2/operators/sparse_to_dense_op.cu
index dc85ac3..d10899f 100644
--- a/caffe2/operators/sparse_to_dense_op.cu
+++ b/caffe2/operators/sparse_to_dense_op.cu
@@ -45,7 +45,7 @@ namespace caffe2 {
     const int output_first_dim =
         GetOutputFirstDim(sparse_indices_vec, sparse_indices_len);
 
-    auto shape = sparse_values.dims().vec();
+    auto shape = sparse_values.sizes().vec();
     shape[0] = output_first_dim;
     auto* output = Output(0, shape, at::dtype<T>());
diff --git a/caffe2/operators/top_k.cu b/caffe2/operators/top_k.cu
index 1b6bad7..703b679 100644
--- a/caffe2/operators/top_k.cu
+++ b/caffe2/operators/top_k.cu
@@ -187,7 +187,7 @@ bool TopKCudaOp<T, Context>::RunOnDevice() {
   auto* indices = Output(1);
   auto* flatten_indices = OutputSize() > 2 ? Output(2) : nullptr;
 
-  at::IntList input_dims = input.dims();
+  at::IntList input_dims = input.sizes();
   if (axis_ == -1) {
     axis_ = input_dims.size() - 1;
   }
diff --git a/caffe2/operators/upsample_op.cu b/caffe2/operators/upsample_op.cu
index 8402840..d1d4f7c 100644
--- a/caffe2/operators/upsample_op.cu
+++ b/caffe2/operators/upsample_op.cu
@@ -175,7 +175,7 @@ template <>
 bool UpsampleBilinearOp<float, CUDAContext>::RunOnDevice() {
   const auto& X = Input(0);
 
-  const auto inputDims = X.dims();
+  const auto inputDims = X.sizes();
   CAFFE_ENFORCE_EQ(4, inputDims.size());
   const int batch_size = X.dim32(0), num_channels = X.dim32(1),
             input_height = X.dim32(2), input_width = X.dim32(3);
@@ -220,7 +220,7 @@ bool UpsampleBilinearGradientOp<float, CUDAContext>::RunOnDevice() {
   const auto& dY = Input(0);
   const auto& X = Input(1);
 
-  const auto inputDims = dY.dims();
+  const auto inputDims = dY.sizes();
   CAFFE_ENFORCE_EQ(4, inputDims.size());
   const int batch_size = dY.dim32(0);
   const int num_channels = dY.dim32(1);
diff --git a/caffe2/operators/weighted_sample_op.cu b/caffe2/operators/weighted_sample_op.cu
index 83b7842..2e98692 100644
--- a/caffe2/operators/weighted_sample_op.cu
+++ b/caffe2/operators/weighted_sample_op.cu
@@ -64,8 +64,8 @@ bool WeightedSampleOp<float, CUDAContext>::RunOnDevice() {
     if (OutputSize() == 2) {
      auto& in_val = Input(1);
       CAFFE_ENFORCE_EQ(
-          in_weights.dims(),
-          in_val.dims(),
+          in_weights.sizes(),
+          in_val.sizes(),
           "The sampling weights tensor and the sampling values tensor must have the same dimensions.");
       in_val_data = in_val.data<float>();
diff --git a/modules/detectron/ps_roi_pool_op.cu b/modules/detectron/ps_roi_pool_op.cu
index dcd099b..186d0a1 100644
--- a/modules/detectron/ps_roi_pool_op.cu
+++ b/modules/detectron/ps_roi_pool_op.cu
@@ -247,7 +247,7 @@ bool PSRoIPoolOp<float, CUDAContext>::RunOnDevice() {
   // mapping_channel
   auto* Y = Output(
       0,
       {R.dim32(0), output_dim_, pooled_height_, pooled_width_},
       at::dtype<float>());
-  auto* A = Output(1, Y->dims(), at::dtype<int>());
+  auto* A = Output(1, Y->sizes(), at::dtype<int>());
   int output_size = Y->size();
   PSRoIPoolForward<<