From f1f7c16c9037b817f41d7b166d669aa226dc8220 Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Thu, 13 Dec 2018 13:33:13 -0800
Subject: [PATCH] Tensor construction codemod(ResizeLike) - 4/7 (#15088)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15088

Codemod generated with clangr shard mode, 25 files per diff,
motivation: https://github.com/pytorch/pytorch/pull/12407

Reviewed By: ezyang

Differential Revision: D13419682

fbshipit-source-id: 3e59403bc1c0e71e5cb66df932ed0c6a0a72e643
---
 caffe2/operators/cast_op.cc                       |  4 ++--
 caffe2/operators/ceil_op.h                        |  4 ++--
 caffe2/operators/channel_shuffle_op.cc            | 16 ++++++++--------
 caffe2/operators/clip_op.cc                       |  8 ++++----
 caffe2/operators/conv_op_cudnn.cc                 | 17 ++++++++++-------
 caffe2/operators/conv_op_impl.h                   | 18 ++++++++++--------
 caffe2/operators/conv_transpose_op_cudnn.cc       | 11 +++++++----
 caffe2/operators/conv_transpose_op_impl.h         | 18 ++++++++++--------
 caffe2/operators/cosine_embedding_criterion_op.cc |  7 +++----
 caffe2/operators/crf_viterbi_op.cc                |  3 +--
 caffe2/operators/cross_entropy_op.cc              | 14 ++++++--------
 caffe2/operators/distance_op.cc                   | 21 +++++++++------------
 caffe2/operators/distance_op.h                    | 14 ++++++--------
 caffe2/operators/elementwise_div_gradient_op.cc   | 14 ++++++++------
 caffe2/operators/elementwise_linear_op.cc         | 12 ++++--------
 caffe2/operators/elementwise_logical_ops.h        |  8 ++++----
 caffe2/operators/elementwise_ops.cc               |  6 +++---
 caffe2/operators/elementwise_ops.h                | 14 ++++++++------
 caffe2/operators/elu_op_cudnn.cc                  |  8 ++++----
 caffe2/operators/ensure_clipped_op.h              |  4 ++--
 caffe2/operators/expand_op.h                      |  4 ++--
 caffe2/operators/find_op.h                        |  4 ++--
 caffe2/operators/floor_op.h                       |  4 ++--
 caffe2/operators/fully_connected_op.h             | 11 +++--------
 caffe2/operators/group_norm_op.h                  | 14 ++++++--------
 25 files changed, 126 insertions(+), 132 deletions(-)
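
Every hunk below applies the same mechanical transformation; a rough sketch of
the before/after pattern is given here for orientation ("float" merely stands
in for whatever element type a particular operator actually uses):

    // Before: fetch the output tensor, then copy the input's shape onto it.
    auto* Y = Output(0);
    Y->ResizeLike(X);
    float* Ydata = Y->template mutable_data<float>();

    // After: construct the output with its shape and dtype in a single call,
    // via the Output(index, sizes, options) overload used throughout this diff.
    auto* Y = Output(0, X.sizes(), at::dtype<float>());
    float* Ydata = Y->template mutable_data<float>();
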
diff --git a/caffe2/operators/cast_op.cc b/caffe2/operators/cast_op.cc
index 82b6355..f479ed4 100644
--- a/caffe2/operators/cast_op.cc
+++ b/caffe2/operators/cast_op.cc
@@ -20,8 +20,8 @@ template <>
 template <typename DstType, typename SrcType>
 bool CastOp<CPUContext>::DoRunWithType() {
   auto& input = Input(0);
-  auto* output = Output(0);
-  output->ResizeLike(input);
+
+  auto* output = Output(0, input.sizes(), at::dtype<DstType>());
   const auto* data = input.template data<SrcType>();
   auto* out = output->template mutable_data<DstType>();
   auto N = input.numel();
diff --git a/caffe2/operators/ceil_op.h b/caffe2/operators/ceil_op.h
index 6ca1de9..3283fbe 100644
--- a/caffe2/operators/ceil_op.h
+++ b/caffe2/operators/ceil_op.h
@@ -16,8 +16,8 @@ class CeilOp final : public Operator<Context> {
   bool RunOnDevice() override {
     auto& X = Input(0);
-    auto* Y = Output(0);
-    Y->ResizeLike(X);
+
+    auto* Y = Output(0, X.sizes(), at::dtype<float>());
     const float* Xdata = X.template data<float>();
     float* Ydata = Y->template mutable_data<float>();
diff --git a/caffe2/operators/channel_shuffle_op.cc b/caffe2/operators/channel_shuffle_op.cc
index 0a7ae97..c3e8f88 100644
--- a/caffe2/operators/channel_shuffle_op.cc
+++ b/caffe2/operators/channel_shuffle_op.cc
@@ -66,8 +66,8 @@ void RunChannelShuffleNHWC(
 template <>
 bool ChannelShuffleOp<float, CPUContext>::RunOnDeviceWithOrderNCHW() {
   const auto& X = Input(0);
-  auto* Y = Output(0);
-  Y->ResizeLike(X);
+
+  auto* Y = Output(0, X.sizes(), at::dtype<float>());
   const int N = X.dim32(0);
   const int C = X.dim32(1);
   const int G = group_;
@@ -83,8 +83,8 @@ bool ChannelShuffleOp<float, CPUContext>::RunOnDeviceWithOrderNCHW() {
 template <>
 bool ChannelShuffleOp<float, CPUContext>::RunOnDeviceWithOrderNHWC() {
   const auto& X = Input(0);
-  auto* Y = Output(0);
-  Y->ResizeLike(X);
+
+  auto* Y = Output(0, X.sizes(), at::dtype<float>());
   const int ndim = X.dim();
   const int N = X.dim32(0);
   const int C = X.dim32(ndim - 1);
@@ -101,8 +101,8 @@ bool ChannelShuffleOp<float, CPUContext>::RunOnDeviceWithOrderNHWC() {
 template <>
 bool ChannelShuffleGradientOp<float, CPUContext>::RunOnDeviceWithOrderNCHW() {
   const auto& dY = Input(0);
-  auto* dX = Output(0);
-  dX->ResizeLike(dY);
+
+  auto* dX = Output(0, dY.sizes(), at::dtype<float>());
   const int N = dY.dim32(0);
   const int C = dY.dim32(1);
   const int G = group_;
@@ -118,8 +118,8 @@ bool ChannelShuffleGradientOp<float, CPUContext>::RunOnDeviceWithOrderNCHW() {
 template <>
 bool ChannelShuffleGradientOp<float, CPUContext>::RunOnDeviceWithOrderNHWC() {
   const auto& dY = Input(0);
-  auto* dX = Output(0);
-  dX->ResizeLike(dY);
+
+  auto* dX = Output(0, dY.sizes(), at::dtype<float>());
   const int ndim = dY.dim();
   const int N = dY.dim32(0);
   const int C = dY.dim32(ndim - 1);
diff --git a/caffe2/operators/clip_op.cc b/caffe2/operators/clip_op.cc
index 80f4451..4b79040 100644
--- a/caffe2/operators/clip_op.cc
+++ b/caffe2/operators/clip_op.cc
@@ -6,8 +6,8 @@ namespace caffe2 {
 template <>
 bool ClipOp<float, CPUContext>::RunOnDevice() {
   auto& X = Input(0);
-  auto* Y = Output(0);
-  Y->ResizeLike(X);
+
+  auto* Y = Output(0, X.sizes(), at::dtype<float>());
   EigenVectorMap<float>(Y->template mutable_data<float>(), Y->numel()) =
       ConstEigenVectorMap<float>(X.data<float>(), X.numel())
           .cwiseMax(min_)
@@ -19,10 +19,10 @@ template <>
 bool ClipGradientOp<float, CPUContext>::RunOnDevice() {
   auto& Y = Input(0);
   auto& dY = Input(1);
-  auto* dX = Output(0);
+
   CAFFE_ENFORCE_GE(Y.numel(), 0);
   CAFFE_ENFORCE_EQ(dY.numel(), Y.numel());
-  dX->ResizeLike(Y);
+  auto* dX = Output(0, Y.sizes(), at::dtype<float>());
   const float* Ydata = Y.data<float>();
   const float* dYdata = dY.data<float>();
   float* dXdata = dX->template mutable_data<float>();
diff --git a/caffe2/operators/conv_op_cudnn.cc b/caffe2/operators/conv_op_cudnn.cc
index 1250f1b..9c96536 100644
--- a/caffe2/operators/conv_op_cudnn.cc
+++ b/caffe2/operators/conv_op_cudnn.cc
@@ -880,7 +880,6 @@ bool CudnnConvGradientOp::DoRunWithType() {
   auto& X = Input(INPUT);
   auto& filter = Input(FILTER);
   auto& dY = Input(OUTPUT_GRAD);
-  auto* dfilter = Output(FILTER_GRAD);
 
   CAFFE_ENFORCE(X.dim() >= 3 && X.dim() <= 5);
   CAFFE_ENFORCE(filter.dim() >= 3 && filter.dim() <= 5);
@@ -945,7 +944,7 @@ bool CudnnConvGradientOp::DoRunWithType() {
   } else {
     CAFFE_THROW("Unsupported kernel size:", kernel_.size());
   }
-  dfilter->ResizeLike(filter);
+  auto* dfilter = Output(FILTER_GRAD, filter.sizes(), at::dtype<T_DW>());
 
   // Set up the cudnn algorithms & workspace if necessary
   bool input_changed = (X.sizes() != cudnn_input_dims_);
@@ -1173,9 +1172,10 @@ bool CudnnConvGradientOp::DoRunWithType() {
             data_perf_stat;
         cudnn_wrapper_.with_cudnn_state(
             cudnn_state_, [&](CuDNNState* state) {
-              auto* dX =
-                  Output(no_bias_ ? BIAS_OR_INPUT_GRAD : INPUT_GRAD);
-              dX->ResizeLike(X);
+              auto* dX = Output(
+                  no_bias_ ? BIAS_OR_INPUT_GRAD : INPUT_GRAD,
+                  X.sizes(),
+                  at::dtype<T_DX>());
               const T_W* filter_data = filter.template data<T_W>();
               const T_DY* dYdata = dY.template data<T_DY>();
               T_DX* dXdata = dX->template mutable_data<T_DX>();
@@ -1335,8 +1335,11 @@ bool CudnnConvGradientOp::DoRunWithType() {
             dfilter->template mutable_data<T_DW>()));
         if (OutputSize() == 3 || (no_bias_ && (OutputSize() == 2))) {
           // Compute the gradient w.r.t. the input.
-          auto* dX = Output(no_bias_ ? BIAS_OR_INPUT_GRAD : INPUT_GRAD);
-          dX->ResizeLike(X);
+
+          auto* dX = Output(
+              no_bias_ ? BIAS_OR_INPUT_GRAD : INPUT_GRAD,
+              X.sizes(),
+              at::dtype<T_DX>());
           CUDNN_ENFORCE(cudnnConvolutionBackwardData(
               state->cudnn_handle(),
               cudnnTypeWrapper::kOne(),
diff --git a/caffe2/operators/conv_op_impl.h b/caffe2/operators/conv_op_impl.h
index 6ea6e7a..f87a4af 100644
--- a/caffe2/operators/conv_op_impl.h
+++ b/caffe2/operators/conv_op_impl.h
@@ -481,7 +481,7 @@ bool ConvGradientOp<T, Context>::RunOnDeviceWithOrderNCHW() {
   auto& X = Input(INPUT);
   auto& filter = Input(FILTER);
   auto& dY = Input(OUTPUT_GRAD);
-  auto* dfilter = Output(FILTER_GRAD);
+
   const int N = X.dim32(0), C = X.dim32(1);
 
   const vector<int> input_dims = this->GetDims(X);
@@ -503,7 +503,7 @@ bool ConvGradientOp<T, Context>::RunOnDeviceWithOrderNCHW() {
   }
   CAFFE_ENFORCE_EQ(M % group_, 0);
-  dfilter->ResizeLike(filter);
+  auto* dfilter = Output(FILTER_GRAD, filter.sizes(), at::dtype<T>());
   // The dimension of each kernel
   const int kernel_dim = C / group_ * kernel_dims_size;
   // The offset corresponding to a single input image, and a single output
@@ -623,8 +623,9 @@ bool ConvGradientOp<T, Context>::RunOnDeviceWithOrderNCHW() {
   }
   if (OutputSize() == 3 || (no_bias_ && (OutputSize() == 2))) {
     // Compute the gradient w.r.t. the input.
-    auto* dX = Output(no_bias_ ? BIAS_OR_INPUT_GRAD : INPUT_GRAD);
-    dX->ResizeLike(X);
+
+    auto* dX = Output(
+        no_bias_ ? BIAS_OR_INPUT_GRAD : INPUT_GRAD, X.sizes(), at::dtype<T>());
     T* dXdata = dX->template mutable_data<T>();
     dYdata = dY.template data<T>();
     for (int image_id = 0; image_id < N; ++image_id) {
@@ -688,7 +689,7 @@ bool ConvGradientOp<T, Context>::RunOnDeviceWithOrderNHWC() {
   auto& X = Input(INPUT);
   auto& filter = Input(FILTER);
   auto& dY = Input(OUTPUT_GRAD);
-  auto* dfilter = Output(FILTER_GRAD);
+
   const int N = X.dim32(0), C = X.dim32(X.dim() - 1);
 
   const vector<int> input_dims = this->GetDims(X);
@@ -710,7 +711,7 @@ bool ConvGradientOp<T, Context>::RunOnDeviceWithOrderNHWC() {
   }
   CAFFE_ENFORCE_EQ(M % group_, 0);
-  dfilter->ResizeLike(filter);
+  auto* dfilter = Output(FILTER_GRAD, filter.sizes(), at::dtype<T>());
   // The dimension of each kernel
   const int kernel_dim = C / group_ * kernel_dims_size;
   // The offset corresponding to a single input image, and a single output
@@ -830,8 +831,9 @@ bool ConvGradientOp<T, Context>::RunOnDeviceWithOrderNHWC() {
 
   if (OutputSize() == 3 || (no_bias_ && (OutputSize() == 2))) {
     // Compute the gradient w.r.t. the input.
-    auto* dX = Output(no_bias_ ? BIAS_OR_INPUT_GRAD : INPUT_GRAD);
-    dX->ResizeLike(X);
+
+    auto* dX = Output(
+        no_bias_ ? BIAS_OR_INPUT_GRAD : INPUT_GRAD, X.sizes(), at::dtype<T>());
     T* dXdata = dX->template mutable_data<T>();
     for (int image_id = 0; image_id < N; ++image_id) {
       // Compute gradient into col_buffer.
diff --git a/caffe2/operators/conv_transpose_op_cudnn.cc b/caffe2/operators/conv_transpose_op_cudnn.cc
index a5df470..4ec5a6f 100644
--- a/caffe2/operators/conv_transpose_op_cudnn.cc
+++ b/caffe2/operators/conv_transpose_op_cudnn.cc
@@ -368,7 +368,7 @@ bool CudnnConvTransposeGradientOp<T>::RunOnDevice() {
   auto& X = Input(INPUT);
   auto& filter = Input(FILTER);
   auto& dY = Input(OUTPUT_GRAD);
-  auto* dfilter = Output(FILTER_GRAD);
+
   CAFFE_ENFORCE_EQ(X.dim(), 4);
   CAFFE_ENFORCE_EQ(filter.dim(), 4);
   int C = 0;
@@ -413,7 +413,7 @@ bool CudnnConvTransposeGradientOp<T>::RunOnDevice() {
   }
   // Since we only handle LegacyPadding::NOTSET, we don't need to
   // compute padding.
-  dfilter->ResizeLike(filter);
+  auto* dfilter = Output(FILTER_GRAD, filter.sizes(), at::dtype<T>());
 
   // Set up the cudnn algorithms & workspace if necessary
   bool input_changed = (X.sizes() != cudnn_input_dims_);
@@ -644,8 +644,11 @@ bool CudnnConvTransposeGradientOp<T>::RunOnDevice() {
 
     if (OutputSize() == 3 || (no_bias_ && (OutputSize() == 2))) {
       // Compute the gradient w.r.t. the input.
-      auto* dX = Output(no_bias_ ? BIAS_OR_INPUT_GRAD : INPUT_GRAD);
-      dX->ResizeLike(X);
+
+      auto* dX = Output(
+          no_bias_ ? BIAS_OR_INPUT_GRAD : INPUT_GRAD,
+          X.sizes(),
+          at::dtype<T>());
       CUDNN_ENFORCE(cudnnConvolutionForward(
           state->cudnn_handle(),
           cudnnTypeWrapper<T>::kOne(),
diff --git a/caffe2/operators/conv_transpose_op_impl.h b/caffe2/operators/conv_transpose_op_impl.h
index 86a7318..993bfc9 100644
--- a/caffe2/operators/conv_transpose_op_impl.h
+++ b/caffe2/operators/conv_transpose_op_impl.h
@@ -250,7 +250,7 @@ bool ConvTransposeGradientOp<T, Context>::RunOnDeviceWithOrderNCHW() {
   auto& X = Input(INPUT);
   auto& filter = Input(FILTER);
   auto& dY = Input(OUTPUT_GRAD);
-  auto* dfilter = Output(FILTER_GRAD);
+
   const int N = X.dim32(0), M = X.dim32(1), H = X.dim32(2), W = X.dim32(3);
   // We only handle LegacyPadding::NOTSET case and ignore cases of
   // LegacyPadding::VALID and LegacyPadding::SAME
@@ -264,7 +264,7 @@ bool ConvTransposeGradientOp<T, Context>::RunOnDeviceWithOrderNCHW() {
   CAFFE_ENFORCE(
       filter.dim32(3) == this->kernel_w(),
       "filter width must be equal to kernel width");
-  dfilter->ResizeLike(filter);
+  auto* dfilter = Output(FILTER_GRAD, filter.sizes(), at::dtype<T>());
   const int kernel_dim = C * this->kernel_h() * this->kernel_w();
   const int output_image_size = dY.dim32(2) * dY.dim32(3);
@@ -353,8 +353,9 @@ bool ConvTransposeGradientOp<T, Context>::RunOnDeviceWithOrderNCHW() {
   // Compute gradients w.r.t. the input
   // Since we have changed dYdata in the above loop, we will need to reset.
   dYdata = dY.template data<T>();
-  auto* dX = Output(no_bias_ ? BIAS_OR_INPUT_GRAD : INPUT_GRAD);
-  dX->ResizeLike(X);
+
+  auto* dX = Output(
+      no_bias_ ? BIAS_OR_INPUT_GRAD : INPUT_GRAD, X.sizes(), at::dtype<T>());
   T* dXdata = dX->template mutable_data<T>();
   for (auto image_id = 0; image_id < N; ++image_id) {
     // Im2Col.
@@ -402,7 +403,7 @@ bool ConvTransposeGradientOp<T, Context>::RunOnDeviceWithOrderNHWC() {
   auto& X = Input(INPUT);
   auto& filter = Input(FILTER);
   auto& dY = Input(OUTPUT_GRAD);
-  auto* dfilter = Output(FILTER_GRAD);
+
   const int N = X.dim32(0), H = X.dim32(1), W = X.dim32(2), M = X.dim32(3);
   // We only handle LegacyPadding::NOTSET case and ignore cases of
   // LegacyPadding::VALID and LegacyPadding::SAME
@@ -416,7 +417,7 @@ bool ConvTransposeGradientOp<T, Context>::RunOnDeviceWithOrderNHWC() {
       filter.dim32(2) == this->kernel_w(),
       "filter width must be equal to kernel width");
   const int C = filter.dim32(3);
-  dfilter->ResizeLike(filter);
+  auto* dfilter = Output(FILTER_GRAD, filter.sizes(), at::dtype<T>());
   const int kernel_dim = C * this->kernel_h() * this->kernel_w();
   const int output_image_size = dY.dim32(1) * dY.dim32(2);
@@ -505,8 +506,9 @@ bool ConvTransposeGradientOp<T, Context>::RunOnDeviceWithOrderNHWC() {
   // Compute gradients w.r.t. the input
   // Since we have changed dYdata in the above loop, we will need to reset.
   dYdata = dY.template data<T>();
-  auto* dX = Output(no_bias_ ? BIAS_OR_INPUT_GRAD : INPUT_GRAD);
-  dX->ResizeLike(X);
+
+  auto* dX = Output(
+      no_bias_ ? BIAS_OR_INPUT_GRAD : INPUT_GRAD, X.sizes(), at::dtype<T>());
   T* dXdata = dX->template mutable_data<T>();
   for (auto image_id = 0; image_id < N; ++image_id) {
     // Im2Col.
diff --git a/caffe2/operators/cosine_embedding_criterion_op.cc b/caffe2/operators/cosine_embedding_criterion_op.cc
index 3900cf1..a956a67 100644
--- a/caffe2/operators/cosine_embedding_criterion_op.cc
+++ b/caffe2/operators/cosine_embedding_criterion_op.cc
@@ -10,11 +10,11 @@ template <>
 bool CosineEmbeddingCriterionOp<CPUContext>::RunOnDevice() {
   auto& S = Input(0);
   auto& Y = Input(1);
-  auto* output = Output(0);
+
   CAFFE_ENFORCE(
       S.numel() == Y.numel(),
       "The embedding and label should have the same size.");
-  output->ResizeLike(S);
+  auto* output = Output(0, S.sizes(), at::dtype<float>());
 
   const float* Sdata = S.data<float>();
   const int* Ydata = Y.data<int>();
@@ -31,9 +31,8 @@ bool CosineEmbeddingCriterionGradientOp<CPUContext>::RunOnDevice() {
   auto& S = Input(0);
   auto& Y = Input(1);
   auto& dOutput = Input(2);
-  auto* dS = Output(0);
 
-  dS->ResizeLike(S);
+  auto* dS = Output(0, S.sizes(), at::dtype<float>());
   const float* Sdata = S.data<float>();
   const int* Ydata = Y.data<int>();
diff --git a/caffe2/operators/crf_viterbi_op.cc b/caffe2/operators/crf_viterbi_op.cc
index 279534e..7042470 100644
--- a/caffe2/operators/crf_viterbi_op.cc
+++ b/caffe2/operators/crf_viterbi_op.cc
@@ -151,7 +151,6 @@ class SwapBestPathOp : public Operator<CPUContext> {
   bool RunOnDevice() override {
     auto& data = Input(0);
     auto& newBestIdicies = Input(1);
-    auto* updatedData = Output(0);
 
     CAFFE_ENFORCE(
         data.dim() == 2 && newBestIdicies.dim() == 1,
@@ -161,7 +160,7 @@ class SwapBestPathOp : public Operator<CPUContext> {
         data.size(0) == newBestIdicies.size(0),
         "predictions and bestPath dimensions not matching");
-    updatedData->ResizeLike(data);
+    auto* updatedData = Output(0, data.sizes(), at::dtype<float>());
     float* outData = updatedData->template mutable_data<float>();
     context_.CopyItemsSameDevice(
         data.dtype(), data.numel(), data.template data<float>(), outData);
diff --git a/caffe2/operators/cross_entropy_op.cc b/caffe2/operators/cross_entropy_op.cc
index aee828b..1b35341 100644
--- a/caffe2/operators/cross_entropy_op.cc
+++ b/caffe2/operators/cross_entropy_op.cc
@@ -120,8 +120,7 @@ bool SigmoidCrossEntropyWithLogitsGradientOp<float, CPUContext>::RunOnDevice() {
   const auto outer_size = logits.numel() / inner_size;
   CAFFE_ENFORCE(g.numel() == outer_size);
-  auto* out = Output(0);
-  out->ResizeLike(logits);
+  auto* out = Output(0, logits.sizes(), at::dtype<float>());
   auto* out_ptr = out->template mutable_data<float>();
   auto* logits_ptr = logits.data<float>();
@@ -198,8 +197,7 @@ bool WeightedSigmoidCrossEntropyWithLogitsGradientOp<float, CPUContext>::
   const auto outer_size = logits.numel() / inner_size;
   CAFFE_ENFORCE(g.numel() == outer_size);
-  auto* out = Output(0);
-  out->ResizeLike(logits);
+  auto* out = Output(0, logits.sizes(), at::dtype<float>());
   auto* out_ptr = out->template mutable_data<float>();
   auto* logits_ptr = logits.data<float>();
@@ -225,7 +223,7 @@ bool LabelCrossEntropyGradientOp<float, CPUContext>::RunOnDevice() {
   auto& X = Input(0);
   auto& label = Input(1);
   auto& dY = Input(2);
-  auto* dX = Output(0);
+
   int N, D;
   if (X.dim() > 1) {
     N = X.dim32(0);
@@ -239,7 +237,7 @@ bool LabelCrossEntropyGradientOp<float, CPUContext>::RunOnDevice() {
   CAFFE_ENFORCE_EQ(label.dim32(0), N);
   CAFFE_ENFORCE_EQ(dY.dim(), 1);
   CAFFE_ENFORCE_EQ(dY.dim32(0), N);
-  dX->ResizeLike(X);
+  auto* dX = Output(0, X.sizes(), at::dtype<float>());
   math::Set<float, CPUContext>(
       dX->numel(), 0.f, dX->template mutable_data<float>(), &context_);
   const float* Xdata = X.data<float>();
@@ -333,7 +331,7 @@ bool CrossEntropyGradientOp<float, CPUContext>::RunOnDevice() {
   auto& X = Input(0);
   auto& label = Input(1);
   auto& dY = Input(2);
-  auto* dX = Output(0);
+
   int N, D;
   if (X.dim() > 1) {
     N = X.dim32(0);
@@ -347,7 +345,7 @@ bool CrossEntropyGradientOp<float, CPUContext>::RunOnDevice() {
   CAFFE_ENFORCE_EQ(label.dim32(0), N);
   CAFFE_ENFORCE_EQ(dY.dim(), 1);
   CAFFE_ENFORCE_EQ(dY.dim32(0), N);
-  dX->ResizeLike(X);
+  auto* dX = Output(0, X.sizes(), at::dtype<float>());
   math::Set<float, CPUContext>(
       dX->numel(), 0.f, dX->template mutable_data<float>(), &context_);
   const float* Xdata = X.data<float>();
diff --git a/caffe2/operators/distance_op.cc b/caffe2/operators/distance_op.cc
index b7ec618..3c0d5ac 100644
--- a/caffe2/operators/distance_op.cc
+++ b/caffe2/operators/distance_op.cc
@@ -66,8 +66,7 @@ bool L1DistanceGradientOp<float, CPUContext>::RunOnDevice() {
   auto& X = Input(0);
   auto& Y = Input(1);
   auto& dDistance = Input(2);
-  auto* dX = Output(0);
-  auto* dY = Output(1);
+
   CAFFE_ENFORCE_EQ(X.dim(), Y.dim());
   for (int i = 0; i < X.dim(); ++i) {
     CAFFE_ENFORCE_EQ(X.dim32(i), Y.dim32(i));
@@ -80,8 +79,8 @@ bool L1DistanceGradientOp<float, CPUContext>::RunOnDevice() {
   }
   CAFFE_ENFORCE(dDistance.dim() == 1);
   CAFFE_ENFORCE(dDistance.dim32(0) == N);
-  dX->ResizeLike(X);
-  dY->ResizeLike(Y);
+  auto* dX = Output(0, X.sizes(), at::dtype<float>());
+  auto* dY = Output(1, Y.sizes(), at::dtype<float>());
 
   for (int i = 0; i < N; ++i) {
     auto offset = i * D;
@@ -143,8 +142,7 @@ bool CosineSimilarityGradientOp<float, CPUContext>::RunOnDevice() {
   auto& X = Input(X_IN);
   auto& Y = Input(Y_IN);
   auto& dCos = Input(DER_COS_IN);
-  auto* dX = Output(DER_X_OUT);
-  auto* dY = Output(DER_Y_OUT);
+
   const int N = X.dim() > 0 ? X.dim32(0) : 1;
   const int D = X.size_from_dim(1);
   CAFFE_ENFORCE(X.dim() == Y.dim());
@@ -153,8 +151,8 @@ bool CosineSimilarityGradientOp<float, CPUContext>::RunOnDevice() {
   }
   CAFFE_ENFORCE(dCos.dim() == 1);
   CAFFE_ENFORCE(dCos.dim32(0) == N);
-  dX->ResizeLike(X);
-  dY->ResizeLike(Y);
+  auto* dX = Output(DER_X_OUT, X.sizes(), at::dtype<float>());
+  auto* dY = Output(DER_Y_OUT, Y.sizes(), at::dtype<float>());
   const auto* X_data = X.template data<float>();
   const auto* Y_data = Y.template data<float>();
@@ -260,8 +258,7 @@ bool DotProductGradientOp<float, CPUContext>::RunOnDevice() {
   auto& X = Input(X_IN);
   auto& Y = Input(Y_IN);
   auto& dDot = Input(DER_DOT_IN);
-  auto* dX = Output(DER_X_OUT);
-  auto* dY = Output(DER_Y_OUT);
+
   int N, D;
   if (X.numel() > 0) {
     N = X.dim() > 0 ? X.dim32(0) : 1;
@@ -276,8 +273,8 @@ bool DotProductGradientOp<float, CPUContext>::RunOnDevice() {
   }
   CAFFE_ENFORCE(dDot.dim() == 1);
   CAFFE_ENFORCE(dDot.dim32(0) == N);
-  dX->ResizeLike(X);
-  dY->ResizeLike(Y);
+  auto* dX = Output(DER_X_OUT, X.sizes(), at::dtype<float>());
+  auto* dY = Output(DER_Y_OUT, Y.sizes(), at::dtype<float>());
   const auto* X_data = X.template data<float>();
   const auto* Y_data = Y.template data<float>();
diff --git a/caffe2/operators/distance_op.h b/caffe2/operators/distance_op.h
index 11bb4b2..c36fff4 100644
--- a/caffe2/operators/distance_op.h
+++ b/caffe2/operators/distance_op.h
@@ -31,8 +31,7 @@ class SquaredL2DistanceGradientOp final : public Operator<Context> {
     auto& X = Input(0);
    auto& Y = Input(1);
     auto& dDistance = Input(2);
-    auto* dX = Output(0);
-    auto* dY = Output(1);
+
     int N = X.dim() > 0 ? X.dim32(0) : 1;
     int D = N > 0 ? X.numel() / N : 0;
     CAFFE_ENFORCE(X.dim() == Y.dim());
@@ -41,8 +40,8 @@ class SquaredL2DistanceGradientOp final : public Operator<Context> {
     }
     CAFFE_ENFORCE(dDistance.dim() == 1);
     CAFFE_ENFORCE(dDistance.dim32(0) == N);
-    dX->ResizeLike(X);
-    dY->ResizeLike(Y);
+    auto* dX = Output(0, X.sizes(), at::dtype<T>());
+    auto* dY = Output(1, Y.sizes(), at::dtype<T>());
     math::Sub<T, Context>(
         X.numel(),
         X.template data<T>(),
@@ -190,8 +189,7 @@ class DotProductWithPaddingGradientOp final : public Operator<Context> {
     auto& X = Input(X_IN);
     auto& Y = Input(Y_IN);
     auto& dDot = Input(DER_DOT_IN);
-    auto* dX = Output(DER_X_OUT);
-    auto* dY = Output(DER_Y_OUT);
+
     int N, D, DX, DY, restD;
     if (X.numel() > 0) {
      N = X.dim() > 0 ? X.dim32(0) : 1;
@@ -209,8 +207,8 @@ class DotProductWithPaddingGradientOp final : public Operator<Context> {
     CAFFE_ENFORCE_EQ(X.dim32(0), Y.dim32(0));
     CAFFE_ENFORCE_EQ(dDot.dim(), 1);
     CAFFE_ENFORCE_EQ(dDot.dim32(0), N);
-    dX->ResizeLike(X);
-    dY->ResizeLike(Y);
+    auto* dX = Output(DER_X_OUT, X.sizes(), at::dtype<T>());
+    auto* dY = Output(DER_Y_OUT, Y.sizes(), at::dtype<T>());
     const auto* X_data = X.template data<T>();
     const auto* Y_data = Y.template data<T>();
diff --git a/caffe2/operators/elementwise_div_gradient_op.cc b/caffe2/operators/elementwise_div_gradient_op.cc
index f0339ad..e9d9e30 100644
--- a/caffe2/operators/elementwise_div_gradient_op.cc
+++ b/caffe2/operators/elementwise_div_gradient_op.cc
@@ -173,14 +173,14 @@ class BinaryElementwiseWithArgsGradientOp<
   template <typename T>
   bool DoRunWithType() {
-    auto* dA = Output(0);
-    auto* dB = Output(1);
     const T* dC_data = nullptr;
     const T* A_data = nullptr;
     const T* B_data = nullptr;
     const T* C_data = nullptr;
     std::vector<int> A_dims;
     std::vector<int> B_dims;
+    at::IntList dA_sizes;
+    at::IntList dB_sizes;
     if (InputSize() == 3) {
       const auto& B = Input(0);
       const auto& C = Input(1);
@@ -207,8 +207,8 @@ class BinaryElementwiseWithArgsGradientOp<
       B_data = B.template data<T>();
       C_data = C.template data<T>();
       dC_data = dC.template data<T>();
-      dA->ResizeLike(C);
-      dB->ResizeLike(B);
+      dA_sizes = C.sizes();
+      dB_sizes = B.sizes();
     } else {
       const auto& dC = Input(0);
       const auto& A = Input(1);
@@ -237,9 +237,11 @@ class BinaryElementwiseWithArgsGradientOp<
       A_data = A.template data<T>();
       B_data = B.template data<T>();
       C_data = C.template data<T>();
-      dA->ResizeLike(A);
-      dB->ResizeLike(B);
+      dA_sizes = A.sizes();
+      dB_sizes = B.sizes();
     }
+    auto* dA = Output(0, dA_sizes, at::dtype<T>());
+    auto* dB = Output(1, dB_sizes, at::dtype<T>());
     auto* dA_data = dA->template mutable_data<T>();
     auto* dB_data = dB->template mutable_data<T>();
     return functor_.Backward(
diff --git a/caffe2/operators/elementwise_linear_op.cc b/caffe2/operators/elementwise_linear_op.cc
index b8ad7b9..92e205e 100644
--- a/caffe2/operators/elementwise_linear_op.cc
+++ b/caffe2/operators/elementwise_linear_op.cc
@@ -7,7 +7,6 @@ bool ElementwiseLinearOp<float, CPUContext>::RunOnDevice(){
   const auto& X = Input(0);
   const auto& a = Input(1);
   const auto& b = Input(2);
-  auto* Y = Output(0);
 
   const auto canonical_axis = X.canonical_axis_index(axis_);
   const int N = X.size_to_dim(canonical_axis);
@@ -18,7 +17,7 @@ bool ElementwiseLinearOp<float, CPUContext>::RunOnDevice(){
   CAFFE_ENFORCE_EQ(b.dim(), 1, b.dim());
   CAFFE_ENFORCE_EQ(b.size(0), D, b.dim());
-  Y->ResizeLike(X);
+  auto* Y = Output(0, X.sizes(), at::dtype<float>());
 
   const float* X_data = X.data<float>();
   const float* a_data = a.data<float>();
@@ -48,12 +47,9 @@ bool ElementwiseLinearGradientOp<float, CPUContext>::RunOnDevice(){
   CAFFE_ENFORCE_EQ(a.dim(), 1, a.dim());
   CAFFE_ENFORCE_EQ(a.size(0), D, a.dim());
-  auto* g_X = Output(0);
-  auto *g_a = Output(1);
-  auto *g_b = Output(2);
-  g_X->ResizeLike(X);
-  g_a->ResizeLike(a);
-  g_b->ResizeLike(a);
+  auto* g_X = Output(0, X.sizes(), at::dtype<float>());
+  auto* g_a = Output(1, a.sizes(), at::dtype<float>());
+  auto* g_b = Output(2, a.sizes(), at::dtype<float>());
 
   const float* g_o_data = g_o.data<float>();
   const float* X_data = X.data<float>();
diff --git a/caffe2/operators/elementwise_logical_ops.h b/caffe2/operators/elementwise_logical_ops.h
index 4b74327..43d064d 100644
--- a/caffe2/operators/elementwise_logical_ops.h
+++ b/caffe2/operators/elementwise_logical_ops.h
@@ -32,7 +32,7 @@ class WhereOp final : public Operator<Context> {
     auto& select = Input(0);
     auto& left = Input(1);
     auto& right = Input(2);
-    auto* output = Output(0);
+
     if (enable_broadcast_) {
       CAFFE_ENFORCE_EQ(select.dim(), 1);
       CAFFE_ENFORCE_EQ(select.size(0), right.size(0));
@@ -41,7 +41,7 @@
       CAFFE_ENFORCE_EQ(select.sizes(), left.sizes());
       CAFFE_ENFORCE_EQ(select.sizes(), right.sizes());
     }
-    output->ResizeLike(left);
+    auto* output = Output(0, left.sizes(), at::dtype<T>());
 
     const bool* select_data = select.template data<bool>();
     const T* left_data = left.template data<T>();
@@ -147,8 +147,8 @@ class IsMemberOfOp final : public Operator<Context> {
   template <typename T>
   bool DoRunWithType() {
     auto& input = Input(0);
-    auto* output = Output(0);
-    output->ResizeLike(input);
+
+    auto* output = Output(0, input.sizes(), at::dtype<bool>());
 
     if (!values_.has_values()) {
       values_.set(this->template GetRepeatedArgument<T>(VALUE_TAG));
diff --git a/caffe2/operators/elementwise_ops.cc b/caffe2/operators/elementwise_ops.cc
index 017d7b6..846d6e9 100644
--- a/caffe2/operators/elementwise_ops.cc
+++ b/caffe2/operators/elementwise_ops.cc
@@ -99,9 +99,9 @@ template <typename T>
 bool SumReduceLikeOp<CPUContext>::DoRunWithType() {
   const auto& A = Input(0);
   const auto& B = Input(1);
-  auto* C = Output(0);
-  CAFFE_ENFORCE(&B != C, "In-place is not allowed.");
-  C->ResizeLike(B);
+
+  CAFFE_ENFORCE(!IsInputOutputAlias(1, 0), "In-place is not allowed.");
+  auto* C = Output(0, B.sizes(), at::dtype<T>());
   const T* Adata = A.template data<T>();
   auto* Cdata = C->template mutable_data<T>();
   if (B.numel() == 1) {
diff --git a/caffe2/operators/elementwise_ops.h b/caffe2/operators/elementwise_ops.h
index 2a42dd2..2cc32df 100644
--- a/caffe2/operators/elementwise_ops.h
+++ b/caffe2/operators/elementwise_ops.h
@@ -52,8 +52,9 @@ class UnaryElementwiseWithArgsOp final : public Operator<Context> {
   template <typename T>
   bool DoRunWithType() {
     const auto& X = Input(0);
-    auto* Y = Output(0);
-    Y->ResizeLike(X);
+
+    auto* Y = Output(
+        0, X.sizes(), at::dtype<typename OutputTypeMap::template type<T>>());
     return functor_(
         X.numel(),
         X.template data<T>(),
@@ -261,8 +262,7 @@ class BinaryElementwiseWithArgsGradientOp final : public Operator<Context> {
     const auto& dC = Input(0);
     const auto& A = Input(1);
     const auto& B = Input(2);
-    auto* dA = Output(0);
-    auto* dB = Output(1);
+
     vector<int> A_dims;
     vector<int> B_dims;
     if (legacy_broadcast_) {
@@ -292,8 +292,10 @@ class BinaryElementwiseWithArgsGradientOp final : public Operator<Context> {
         dC.template data<typename GradientTypeMap::template type<T>>();
     const T* A_data = A.template data<T>();
     const T* B_data = B.template data<T>();
-    dA->ResizeLike(A);
-    dB->ResizeLike(B);
+    auto* dA = Output(
+        0, A.sizes(), at::dtype<typename GradientTypeMap::template type<T>>());
+    auto* dB = Output(
+        1, B.sizes(), at::dtype<typename GradientTypeMap::template type<T>>());
     auto* dA_data =
         dA->template mutable_data<typename GradientTypeMap::template type<T>>();
     auto* dB_data =
diff --git a/caffe2/operators/elu_op_cudnn.cc b/caffe2/operators/elu_op_cudnn.cc
index bbfbeb5..b3bc299 100644
--- a/caffe2/operators/elu_op_cudnn.cc
+++ b/caffe2/operators/elu_op_cudnn.cc
@@ -27,8 +27,8 @@ class CuDNNActivationOp<CUDNN_ACTIVATION_ELU> final
   template <typename T>
   bool DoRunWithType() {
     const auto& X = Input(0);
-    auto* Y = Output(0);
-    Y->ResizeLike(X);
+
+    auto* Y = Output(0, X.sizes(), at::dtype<T>());
     if (X.numel() == 0) {
       Y->template mutable_data<T>();
       return true;
@@ -74,8 +74,8 @@ class CuDNNActivationGradientOp<CUDNN_ACTIVATION_ELU> final
   bool DoRunWithType() {
     const auto& Y = Input(0);
     const auto& dY = Input(1);
-    auto* dX = Output(0);
-    dX->ResizeLike(Y);
+
+    auto* dX = Output(0, Y.sizes(), at::dtype<T>());
     if (Y.numel() == 0) {
       dX->template mutable_data<T>();
       return true;
diff --git a/caffe2/operators/ensure_clipped_op.h b/caffe2/operators/ensure_clipped_op.h
index 66c5702..a30009a 100644
--- a/caffe2/operators/ensure_clipped_op.h
+++ b/caffe2/operators/ensure_clipped_op.h
@@ -33,8 +33,8 @@ class EnsureClippedOp final : public Operator<Context> {
           this, Input(INDICES));
     } else {
       auto& X = Input(PARAM);
-      auto* Y = Output(OUTPUT_PARAM);
-      Y->ResizeLike(X);
+
+      auto* Y = Output(OUTPUT_PARAM, X.sizes(), at::dtype<float>());
       EigenVectorMap<float>(Y->template mutable_data<float>(), Y->numel()) =
           ConstEigenVectorMap<float>(X.template data<float>(), X.numel())
               .cwiseMax(min_)
diff --git a/caffe2/operators/expand_op.h b/caffe2/operators/expand_op.h
index 7c60456..30860ba 100644
--- a/caffe2/operators/expand_op.h
+++ b/caffe2/operators/expand_op.h
@@ -82,11 +82,11 @@ class ExpandGradientOp final : public Operator<Context> {
   bool DoRunWithType() {
     const auto& dY = Input(0);
     const auto& X = Input(1);
-    auto* dX = Output(0);
+
     const int ndim = dY.dim();
     const std::vector<int> dX_dims(X.sizes().cbegin(), X.sizes().cend());
     const std::vector<int> dY_dims(dY.sizes().cbegin(), dY.sizes().cend());
-    dX->ResizeLike(X);
+    auto* dX = Output(0, X.sizes(), at::dtype<T>());
     std::vector<int> axes;
     const int offset = ndim - X.dim();
     for (int i = 0; i < ndim; i++) {
diff --git a/caffe2/operators/find_op.h b/caffe2/operators/find_op.h
index 5aa5e70..54f089d 100644
--- a/caffe2/operators/find_op.h
+++ b/caffe2/operators/find_op.h
@@ -28,8 +28,8 @@ class FindOp final : public Operator<Context> {
   bool DoRunWithType() {
     auto& idx = Input(0);
     auto& needles = Input(1);
-    auto* res_indices = Output(0);
-    res_indices->ResizeLike(needles);
+
+    auto* res_indices = Output(0, needles.sizes(), at::dtype<T>());
     const T* idx_data = idx.template data<T>();
     const T* needles_data = needles.template data<T>();
diff --git a/caffe2/operators/floor_op.h b/caffe2/operators/floor_op.h
index fee7304..6af9b41 100644
--- a/caffe2/operators/floor_op.h
+++ b/caffe2/operators/floor_op.h
@@ -16,8 +16,8 @@ class FloorOp final : public Operator<Context> {
   bool RunOnDevice() override {
     auto& X = Input(0);
-    auto* Y = Output(0);
-    Y->ResizeLike(X);
+
+    auto* Y = Output(0, X.sizes(), at::dtype<float>());
     const float* Xdata = X.template data<float>();
     float* Ydata = Y->template mutable_data<float>();
diff --git a/caffe2/operators/fully_connected_op.h b/caffe2/operators/fully_connected_op.h
index 12133ce..97931ea 100644
--- a/caffe2/operators/fully_connected_op.h
+++ b/caffe2/operators/fully_connected_op.h
@@ -207,9 +207,7 @@ class FullyConnectedGradientOp : public Operator<Context> {
     CAFFE_ENFORCE(M * K == X.numel(), dimErrorString());
     CAFFE_ENFORCE(K * N == W.numel(), dimErrorString());
-    auto* dW = Output(0);
-
-    dW->ResizeLike(W);
+    auto* dW = Output(0, W.sizes(), at::dtype<T_DW>());
     auto* db = Output(1, {N}, at::dtype<T_DB>());
 
     if (X.numel() == 0) {
@@ -226,9 +224,7 @@ class FullyConnectedGradientOp : public Operator<Context> {
           &context_);
 
       if (OutputSize() == 3) {
-        auto* dX = Output(2);
-        dX->ResizeLike(X);
-        dX->template mutable_data<T_DX>();
+        Output(2, X.sizes(), at::dtype<T_DX>());
       }
 
       return true;
@@ -278,8 +274,7 @@ class FullyConnectedGradientOp : public Operator<Context> {
     // Compute dX
     if (OutputSize() == 3) {
-      auto* dX = Output(2);
-      dX->ResizeLike(X);
+      auto* dX = Output(2, X.sizes(), at::dtype<T_DX>());
      math::Gemm<T_DX, Context, Engine>(
           CblasNoTrans,
          TransposeWeight ? CblasNoTrans : CblasTrans,
diff --git a/caffe2/operators/group_norm_op.h b/caffe2/operators/group_norm_op.h
index 7f12c8a..b2b750f 100644
--- a/caffe2/operators/group_norm_op.h
+++ b/caffe2/operators/group_norm_op.h
@@ -47,8 +47,8 @@ class GroupNormOp final : public Operator<Context> {
     CAFFE_ENFORCE_EQ(beta.numel(), C);
     const int G = group_;
     const int D = C / G;
-    auto* Y = Output(OUTPUT);
-    Y->ResizeLike(X);
+
+    auto* Y = Output(OUTPUT, X.sizes(), at::dtype<T>());
     T* mu_data = nullptr;
     T* rsig_data = nullptr;
     if (OutputSize() == 3) {
@@ -218,12 +218,10 @@ class GroupNormGradientOp final : public Operator<Context> {
     CAFFE_ENFORCE_EQ(beta.numel(), C);
     const int G = group_;
     const int D = C / G;
-    auto* dX = Output(INPUT_GRAD);
-    auto* dgamma = Output(GAMMA_GRAD);
-    auto* dbeta = Output(BETA_GRAD);
-    dX->ResizeLike(X);
-    dgamma->ResizeLike(gamma);
-    dbeta->ResizeLike(beta);
+
+    auto* dX = Output(INPUT_GRAD, X.sizes(), at::dtype<T>());
+    auto* dgamma = Output(GAMMA_GRAD, gamma.sizes(), at::dtype<T>());
+    auto* dbeta = Output(BETA_GRAD, beta.sizes(), at::dtype<T>());
     return RunOnDeviceImpl(
         N,
         G,
-- 
2.7.4