From 07c49916223624d24a522035649648d32924795d Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Fri, 4 Jan 2019 13:23:21 -0800
Subject: [PATCH] Tensor construction codemod - 2/2 (#15600)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15600

Codemod generated with clangr shard mode, 25 files per diff,
motivation: https://github.com/pytorch/pytorch/pull/12407

Reviewed By: dzhulgakov

Differential Revision: D13542455

fbshipit-source-id: 8a3b15b0a1f81565f34e309114e1c3e1f7f65a3c
---
 caffe2/operators/segment_reduction_op_gpu.cu       | 38 +++++++++-------------
 caffe2/operators/sequence_ops.cu                   | 15 +++------
 caffe2/operators/softmax_ops.cu                    | 11 ++++---
 caffe2/operators/sparse_to_dense_op.cu             |  4 +--
 caffe2/operators/summarize_op.cu                   |  3 +-
 caffe2/operators/unique_ops.cu                     |  6 ++--
 caffe2/operators/upsample_op.cu                    | 12 ++++---
 caffe2/operators/utility_ops.h                     | 11 +++----
 caffe2/operators/weighted_sample_op.cu             | 14 +++-----
 modules/detectron/ps_roi_pool_op.cu                |  8 ++---
 modules/detectron/roi_pool_f_op.cu                 | 14 +++-----
 modules/detectron/sample_as_op.cu                  |  4 +--
 modules/detectron/select_smooth_l1_loss_op.cu      |  4 +--
 modules/detectron/sigmoid_cross_entropy_loss_op.cu |  4 +--
 modules/detectron/sigmoid_focal_loss_op.cu         |  4 +--
 modules/detectron/smooth_l1_loss_op.cu             |  4 +--
 modules/detectron/softmax_focal_loss_op.cu         |  8 ++---
 modules/detectron/spatial_narrow_as_op.cu          |  9 ++---
 18 files changed, 77 insertions(+), 96 deletions(-)

diff --git a/caffe2/operators/segment_reduction_op_gpu.cu b/caffe2/operators/segment_reduction_op_gpu.cu
index bf6c69e..0e5e0c5 100644
--- a/caffe2/operators/segment_reduction_op_gpu.cu
+++ b/caffe2/operators/segment_reduction_op_gpu.cu
@@ -430,7 +430,6 @@ class CUDASparseLengthsSumOp : public Operator<CUDAContext> {
   bool DoRunWithType() {
     auto& dataInput = Input(0);
     auto& lengthsInput = Input(LENGTHS);
-    auto* output = Output(0);
 
     CAFFE_ENFORCE_EQ(1, lengthsInput.ndim(), "LENGTHS must be a vector");
     const int64_t dataSize = dataInput.dim(0);
@@ -441,7 +440,7 @@ class CUDASparseLengthsSumOp : public Operator<CUDAContext> {
 
     auto shape = dataInput.dims().vec();
     shape[0] = outputSize;
-    output->Resize(shape);
+    auto* output = Output(0, shape, at::dtype<T>());
     T* out_data = output->template mutable_data<T>();
 
     if (len_length <= 0) {
@@ -551,7 +550,6 @@ class CUDASparseLengthsMeanOp : public Operator<CUDAContext> {
   bool DoRunWithType() {
     auto& dataInput = Input(0);
     auto& lengthsInput = Input(LENGTHS);
-    auto* output = Output(0);
 
     CAFFE_ENFORCE_EQ(1, lengthsInput.ndim(), "LENGTHS must be a vector");
     const int64_t dataSize = dataInput.dim(0);
@@ -562,7 +560,7 @@ class CUDASparseLengthsMeanOp : public Operator<CUDAContext> {
 
     auto shape = dataInput.dims().vec();
     shape[0] = outputSize;
-    output->Resize(shape);
+    auto* output = Output(0, shape, at::dtype<T>());
     T* out_data = output->template mutable_data<T>();
 
     if (len_length <= 0) {
@@ -673,7 +671,6 @@ class CUDASparseLengthsMaxOp : public Operator<CUDAContext> {
   bool DoRunWithType() {
     auto& dataInput = Input(0);
     auto& lengthsInput = Input(LENGTHS);
-    auto* output = Output(0);
 
     CAFFE_ENFORCE_EQ(1, lengthsInput.ndim(), "LENGTHS must be a vector");
     const int64_t dataSize = dataInput.dim(0);
@@ -684,7 +681,7 @@ class CUDASparseLengthsMaxOp : public Operator<CUDAContext> {
 
     auto shape = dataInput.dims().vec();
     shape[0] = outputSize;
-    output->Resize(shape);
+    auto* output = Output(0, shape, at::dtype<T>());
 
     if (len_length <= 0) {
       // return early to avoid invalid empty kernel
@@ -804,7 +801,6 @@ class CUDASparseLengthsWeightedSumOp : public Operator<CUDAContext> {
     auto& weightsInput = Input(WEIGHTS);
     auto& indicesInput = Input(INDICES);
     auto& lengthsInput = Input(LENGTHS);
-    auto* output = Output(0);
 
     CAFFE_ENFORCE_EQ(1, weightsInput.ndim(), "WEIGHTS must be a vector");
     CAFFE_ENFORCE_EQ(1, indicesInput.ndim(), "INDICES must be a vector");
@@ -818,7 +814,7 @@ class CUDASparseLengthsWeightedSumOp : public Operator<CUDAContext> {
 
     auto shape = dataInput.dims().vec();
     shape[0] = outputSize;
-    output->Resize(shape);
+    auto* output = Output(0, shape, at::dtype<T>());
     T* out_data = output->template mutable_data<T>();
 
     if (len_length <= 0) {
@@ -940,7 +936,6 @@ class CUDAUnsortedSegmentSumOp : public Operator<CUDAContext> {
   bool RunOnDevice() override {
     auto& data = Input(0);
     auto& segment_ids = Input(1);
-    auto* output = Output(0);
 
     if (segment_ids.size() == 0 || data.size() == 0) {
       // Special handling for empty input
@@ -948,8 +943,7 @@ class CUDAUnsortedSegmentSumOp : public Operator<CUDAContext> {
       if (dims.size() > 0) {
         dims[0] = 0;
       }
-      output->Resize(dims);
-      output->template mutable_data<T>();
+      Output(0, dims, at::dtype<T>());
       return true;
     }
@@ -995,7 +989,7 @@ class CUDAUnsortedSegmentSumOp : public Operator<CUDAContext> {
 
     auto dims = data.dims().vec();
     dims[0] = K + 1;
-    output->Resize(dims);
+    auto* output = Output(0, dims, at::dtype<T>());
 
     // Clear the output as we will be accumulating the values
     math::Set<T, CUDAContext>(
@@ -1300,7 +1294,7 @@ class CUDASparseLengthsSumGradientWithIndicesOp : public Operator<CUDAContext> {
     auto& segmentGradsInput = Input(0);
     auto& lengthsInput = Input(1);
     auto& indicesInput = Input(2);
-    auto* dataGradsOutput = Output(0);
+
     CAFFE_ENFORCE_EQ(1, lengthsInput.ndim(), "LENGTHS must be a vector");
     const int len_length = lengthsInput.dim(0);
@@ -1310,7 +1304,7 @@ class CUDASparseLengthsSumGradientWithIndicesOp : public Operator<CUDAContext> {
     auto shape = segmentGradsInput.dims().vec();
     int output_0dim = indicesInput.dim(0);
     shape[0] = output_0dim;
-    dataGradsOutput->Resize(shape);
+    auto* dataGradsOutput = Output(0, shape, at::dtype<T>());
     T* out_data = dataGradsOutput->template mutable_data<T>();
 
     if (len_length <= 0) {
@@ -1379,7 +1373,7 @@ class CUDASparseLengthsMeanGradientWithIndicesOp
     auto& segmentGradsInput = Input(0);
     auto& lengthsInput = Input(1);
     auto& indicesInput = Input(2);
-    auto* dataGradsOutput = Output(0);
+
     CAFFE_ENFORCE_EQ(1, lengthsInput.ndim(), "LENGTHS must be a vector");
     const int len_length = lengthsInput.dim(0);
@@ -1389,7 +1383,7 @@ class CUDASparseLengthsMeanGradientWithIndicesOp
     auto shape = segmentGradsInput.dims().vec();
     int output_0dim = indicesInput.dim(0);
     shape[0] = output_0dim;
-    dataGradsOutput->Resize(shape);
+    auto* dataGradsOutput = Output(0, shape, at::dtype<T>());
     T* out_data = dataGradsOutput->template mutable_data<T>();
 
     if (len_length <= 0) {
@@ -1459,7 +1453,7 @@ class CUDASparseLengthsWeightedSumGradientWithIndicesOp
     auto& segmentGradsInput = Input(1);
     auto& lengthsInput = Input(2);
     auto& indicesInput = Input(3);
-    auto* dataGradsOutput = Output(0);
+
     CAFFE_ENFORCE_EQ(1, lengthsInput.ndim(), "LENGTHS must be a vector");
     CAFFE_ENFORCE_EQ(1, weightsInput.ndim(), "WEIGHTS must be a vector");
@@ -1470,7 +1464,7 @@ class CUDASparseLengthsWeightedSumGradientWithIndicesOp
     auto shape = segmentGradsInput.dims().vec();
     int output_0dim = indicesInput.dim(0);
     shape[0] = output_0dim;
-    dataGradsOutput->Resize(shape);
+    auto* dataGradsOutput = Output(0, shape, at::dtype<T>());
     T* out_data = dataGradsOutput->template mutable_data<T>();
     if (len_length <= 0) {
       // return early to avoid invalid empty kernel
@@ -1597,7 +1591,7 @@ class CUDALengthsMaxWithMainInputAndForwardOutputGradientOp
     auto& lengthsInput = Input(2);
     auto& dataInput = Input(3);
     auto& dataOutput = Input(0); // based on CPU version
-    auto* dataGradsOutput = Output(0);
+
     CAFFE_ENFORCE_EQ(1, lengthsInput.ndim(), "LENGTHS must be a vector");
     int len_length = lengthsInput.dim(0);
     CAFFE_ENFORCE(segmentGradsInput.ndim() > 0);
@@ -1616,7 +1610,7 @@ class CUDALengthsMaxWithMainInputAndForwardOutputGradientOp
         inclusive_scan_length_buffer_.template data<int>();
 
     auto shape = dataInput.dims().vec();
-    dataGradsOutput->Resize(shape);
+    auto* dataGradsOutput = Output(0, shape, at::dtype<T>());
 
     const T* in_data = segmentGradsInput.template data<T>();
     T* out_data = dataGradsOutput->template mutable_data<T>();
@@ -1692,7 +1686,7 @@ class CUDASparseLengthsIndicesInGradientWeightedSumWithMainInputGradientOp
     auto& lengthsInput = Input(2);
     auto& dataInput = Input(3);
     auto& indicesInput = Input(4);
-    auto* dataGradsOutput = Output(0);
+
     auto* weightGradsOutput = Output(1);
     CAFFE_ENFORCE_EQ(1, lengthsInput.ndim(), "LENGTHS must be a vector");
     CAFFE_ENFORCE_EQ(1, weightsInput.ndim(), "WEIGHTS must be a vector");
@@ -1704,7 +1698,7 @@ class CUDASparseLengthsIndicesInGradientWeightedSumWithMainInputGradientOp
     auto shape = segmentGradsInput.dims().vec();
     int output_0dim = indicesInput.dim(0);
     shape[0] = output_0dim;
-    dataGradsOutput->Resize(shape);
+    auto* dataGradsOutput = Output(0, shape, at::dtype<T>());
     weightGradsOutput->ResizeLike(indicesInput);
     T* out_data_grads = dataGradsOutput->template mutable_data<T>();
     T* out_weight_grads = weightGradsOutput->template mutable_data<T>();
diff --git a/caffe2/operators/sequence_ops.cu b/caffe2/operators/sequence_ops.cu
index 6e9abcd..7435eac 100644
--- a/caffe2/operators/sequence_ops.cu
+++ b/caffe2/operators/sequence_ops.cu
@@ -202,8 +202,7 @@ bool AddPaddingOp<CUDAContext>::MakePadding(
 
   int32_t* lengths_out_ptr = nullptr;
   if (OutputSize() > 1) {
-    auto* lengths_out = Output(1);
-    lengths_out->Resize(lengths_size);
+    auto* lengths_out = Output(1, {lengths_size}, at::dtype<int32_t>());
     lengths_out_ptr = lengths_out->template mutable_data<int32_t>();
   }
@@ -248,12 +247,9 @@ bool RemovePaddingOp<CUDAContext>::DoRunWithType() {
     lengths_size = lengths.size();
   }
 
-  auto* out = Output(0);
-  {
-    auto out_dims = in.dims().vec();
-    out_dims[0] -= (startPaddingWidth_ + endPaddingWidth_) * lengths_size;
-    out->Resize(std::move(out_dims));
-  }
+  auto out_dims = in.dims().vec();
+  out_dims[0] -= (startPaddingWidth_ + endPaddingWidth_) * lengths_size;
+  auto* out = Output(0, out_dims, at::dtype<T>());
   const auto* in_ptr = in.template data<T>();
   auto* out_ptr = out->template mutable_data<T>();
@@ -272,8 +268,7 @@ bool RemovePaddingOp<CUDAContext>::DoRunWithType() {
 
   int32_t* lengths_out_ptr = nullptr;
   if (OutputSize() > 1) {
-    auto* lengths_out = Output(1);
-    lengths_out->Resize(lengths_size);
+    auto* lengths_out = Output(1, {lengths_size}, at::dtype<int32_t>());
     lengths_out_ptr = lengths_out->template mutable_data<int32_t>();
   }
diff --git a/caffe2/operators/softmax_ops.cu b/caffe2/operators/softmax_ops.cu
index 1945b59..81955f9 100644
--- a/caffe2/operators/softmax_ops.cu
+++ b/caffe2/operators/softmax_ops.cu
@@ -285,9 +285,8 @@ bool SoftmaxWithLossOp<float, CUDAContext>::RunOnDevice() {
   auto& X = Input(0); // Logits
   auto& T = Input(1); // Labels / targets
   auto* P = Output(0); // Probabilities from softmax
-  auto* avg_loss = Output(1); // Average loss
-  const float* weights = (InputSize() > 2 ? Input(2).data<float>() : NULL);
+  const float* weights = (InputSize() > 2 ? Input(2).data<float>() : NULL);
 
   const auto canonical_axis = X.canonical_axis_index(axis_);
   int N, D;
   N = X.size_to_dim(canonical_axis); // batch size
@@ -308,7 +307,8 @@ bool SoftmaxWithLossOp<float, CUDAContext>::RunOnDevice() {
     }
   }
 
-  avg_loss->Resize(vector<int64_t>());
+  auto* avg_loss =
+      Output(1, vector<int64_t>(), at::dtype<float>()); // Average loss
   if (losses_.size() != N) {
     losses_.Resize(N);
   }
@@ -392,7 +392,7 @@ bool SpatialSoftmaxWithLossOp<float, CUDAContext>::RunOnDevice() {
   auto& X = Input(0); // Logits
   auto& T = Input(1); // Labels / targets
   auto* P = Output(0); // Probabilities from softmax
-  auto* avg_loss = Output(1); // Average loss
+
   const float* weights = (InputSize() > 2 ? Input(2).data<float>() : NULL);
   int N, D;
   N = X.dim32(0);
@@ -423,7 +423,8 @@ bool SpatialSoftmaxWithLossOp<float, CUDAContext>::RunOnDevice() {
       context_.cuda_stream()>>>(N, D, W, H, Xdata, Pdata);
 
   // Cross entropy
-  avg_loss->Resize(vector<int64_t>());
+  auto* avg_loss =
+      Output(1, vector<int64_t>(), at::dtype<float>()); // Average loss
   float* avg_loss_data = avg_loss->template mutable_data<float>();
   math::Set<float, CUDAContext>(1, 0.0f, avg_loss_data, &context_);
diff --git a/caffe2/operators/sparse_to_dense_op.cu b/caffe2/operators/sparse_to_dense_op.cu
index 648c688..dc85ac3 100644
--- a/caffe2/operators/sparse_to_dense_op.cu
+++ b/caffe2/operators/sparse_to_dense_op.cu
@@ -47,8 +47,8 @@ namespace caffe2 {
 
     auto shape = sparse_values.dims().vec();
     shape[0] = output_first_dim;
-    auto* output = Output(0);
-    output->Resize(shape);
+
+    auto* output = Output(0, shape, at::dtype<TData>());
     TData* output_data = output->template mutable_data<TData>();
     math::Set<TData, CUDAContext>(output->size(), TData(0), output_data, &context_);
diff --git a/caffe2/operators/summarize_op.cu b/caffe2/operators/summarize_op.cu
index 13c1a1b..f51cb05 100644
--- a/caffe2/operators/summarize_op.cu
+++ b/caffe2/operators/summarize_op.cu
@@ -96,8 +96,7 @@ bool SummarizeOp<float, CUDAContext>::RunOnDevice() {
             << standard_deviation << std::endl;
   }
   if (OutputSize()) {
-    auto* Y = Output(0);
-    Y->Resize(4);
+    auto* Y = Output(0, {4}, at::dtype<float>());
     float output_buffer[NUM_STATS] = {result.min, result.max, result.mean, standard_deviation};
     context_.CopyFromCPU(
diff --git a/caffe2/operators/unique_ops.cu b/caffe2/operators/unique_ops.cu
index 90252bf..125870d 100644
--- a/caffe2/operators/unique_ops.cu
+++ b/caffe2/operators/unique_ops.cu
@@ -54,7 +54,6 @@ bool UniqueOp<CUDAContext>::DoRunWithType() {
   // use dim32 to enforce that it's fine to have remapping of type int
   int N = inputTensor.dim32(0);
   CAFFE_ENFORCE_EQ(inputTensor.ndim(), 1, "Input should be a vector");
-  auto* uniqueTensor = Output(UNIQUE);
 
   int* remapping = nullptr;
   if (REMAPPING < OutputSize()) {
@@ -65,8 +64,7 @@ bool UniqueOp<CUDAContext>::DoRunWithType() {
 
   if (N <= 0) {
     // if the input is empty, we have nothing to do, not even launch kernel.
-    uniqueTensor->Resize(0);
-    T* unique = uniqueTensor->template mutable_data<T>();
+    /* auto* uniqueTensor = */ Output(UNIQUE, {0}, at::dtype<T>());
     return true;
   }
@@ -112,7 +110,7 @@ bool UniqueOp<CUDAContext>::DoRunWithType() {
       order2.begin());
   int K = new_last.first - buffer;
 
-  uniqueTensor->Resize(K);
+  auto* uniqueTensor = Output(UNIQUE, {K}, at::dtype<T>());
   T* unique = uniqueTensor->template mutable_data<T>();
   context_.CopyItemsSameDevice(thrust_unique_buffer_.meta(), K, buffer, unique);
diff --git a/caffe2/operators/upsample_op.cu b/caffe2/operators/upsample_op.cu
index 29afab7..8402840 100644
--- a/caffe2/operators/upsample_op.cu
+++ b/caffe2/operators/upsample_op.cu
@@ -174,7 +174,6 @@ __global__ void UpsampleBilinearGradientKernel(
 template <>
 bool UpsampleBilinearOp<float, CUDAContext>::RunOnDevice() {
   const auto& X = Input(0);
-  auto* Y = Output(0);
 
   const auto inputDims = X.dims();
   CAFFE_ENFORCE_EQ(4, inputDims.size());
@@ -191,7 +190,10 @@ bool UpsampleBilinearOp<float, CUDAContext>::RunOnDevice() {
   }
   int output_width = input_width * width_scale_;
   int output_height = input_height * height_scale_;
-  Y->Resize(batch_size, num_channels, output_height, output_width);
+  auto* Y = Output(
+      0,
+      {batch_size, num_channels, output_height, output_width},
+      at::dtype<float>());
   const auto size = Y->size();
 
   UpsampleBilinearKernel<<<
@@ -217,7 +219,6 @@ template <>
 bool UpsampleBilinearGradientOp<float, CUDAContext>::RunOnDevice() {
   const auto& dY = Input(0);
   const auto& X = Input(1);
-  auto* dX = Output(0);
 
   const auto inputDims = dY.dims();
   CAFFE_ENFORCE_EQ(4, inputDims.size());
@@ -236,7 +237,10 @@ bool UpsampleBilinearGradientOp<float, CUDAContext>::RunOnDevice() {
     height_scale_ = scales_data[0];
     width_scale_ = scales_data[1];
   }
-  dX->Resize(batch_size, num_channels, output_height, output_width);
+  auto* dX = Output(
+      0,
+      {batch_size, num_channels, output_height, output_width},
+      at::dtype<float>());
 
   math::Set<float, CUDAContext>(
       dX->size(), 0.0f, dX->mutable_data<float>(), &context_);
diff --git a/caffe2/operators/utility_ops.h b/caffe2/operators/utility_ops.h
index d80f4ec..0bc30e1 100644
--- a/caffe2/operators/utility_ops.h
+++ b/caffe2/operators/utility_ops.h
@@ -955,9 +955,8 @@ class SizeOp : public Operator<Context> {
 
   bool RunOnDevice() override {
     auto& input = Input(0);
-    auto* output = Output(0);
 
-    output->Resize(vector<int64_t>());
+    auto* output = Output(0, vector<int64_t>(), at::dtype<int64_t>());
     auto* output_data = output->template mutable_data<int64_t>();
 
     auto size = input.numel();
@@ -1269,15 +1268,13 @@ class RangeOp : public Operator<Context> {
     } else {
       length = static_cast<int>(ceil(diff / step));
     }
-    auto* output = Output(0);
+
    // Match numpy's behavior here.
     if (length <= 0) {
-      output->Resize(0);
-      // Called for the side effect of setting the data.
-      output->template mutable_data<T>();
+      Output(0, {0}, at::dtype<T>());
       return true;
     } else {
-      output->Resize(length);
+      auto* output = Output(0, {length}, at::dtype<T>());
       return DoRunOnDevice<T>(start, step, output);
     }
   }
diff --git a/caffe2/operators/weighted_sample_op.cu b/caffe2/operators/weighted_sample_op.cu
index ba44868..83b7842 100644
--- a/caffe2/operators/weighted_sample_op.cu
+++ b/caffe2/operators/weighted_sample_op.cu
@@ -48,12 +48,12 @@ bool WeightedSampleOp<float, CUDAContext>::RunOnDevice() {
       "The number of tensors of the input and the output must be the same.");
   auto& in_weights = Input(0);
-  auto* out_idx = Output(0);
+
   int batch_size = in_weights.dim(0);
   int weights_dim = in_weights.dim(1);
 
   if (batch_size > 0 && weights_dim > 0) {
-    out_idx->Resize(batch_size, 1);
+    auto* out_idx = Output(0, {batch_size, 1}, at::dtype<int>());
     unif_samples_.Resize(batch_size);
     const float* in_weights_data = in_weights.data<float>();
@@ -69,8 +69,7 @@
           "The sampling weights tensor and the sampling values tensor must have the same dimensions.");
       in_val_data = in_val.data<float>();
 
-      auto* out_val = Output(1);
-      out_val->Resize(batch_size, 1);
+      auto* out_val = Output(1, {batch_size, 1}, at::dtype<float>());
       out_val_data = out_val->template mutable_data<float>();
     }
@@ -91,12 +90,9 @@
         out_idx_data,
         out_val_data);
   } else {
-    out_idx->Resize(0);
-    out_idx->template mutable_data<int>();
+    /* out_idx = */ Output(0, {0}, at::dtype<int>());
     if (OutputSize() == 2) {
-      auto* out_val = Output(1);
-      out_val->Resize(0);
-      out_val->template mutable_data<float>();
+      /* out_val = */ Output(1, {0}, at::dtype<float>());
     }
   }
diff --git a/modules/detectron/ps_roi_pool_op.cu b/modules/detectron/ps_roi_pool_op.cu
index 2e713a7..dcd099b 100644
--- a/modules/detectron/ps_roi_pool_op.cu
+++ b/modules/detectron/ps_roi_pool_op.cu
@@ -243,11 +243,11 @@ template<>
 bool PSRoIPoolOp<float, CUDAContext>::RunOnDevice() {
   auto& X = Input(0); // Input data to pool
   auto& R = Input(1); // RoIs
-  auto* Y = Output(0); // PSRoI pooled data
-  auto* A = Output(1); // mapping_channel
+  // PSRoI pooled data
+  // mapping_channel
 
-  Y->Resize(R.dim32(0), output_dim_, pooled_height_, pooled_width_);
-  A->Resize(Y->dims());
+  auto* Y = Output(0, {R.dim32(0), output_dim_, pooled_height_, pooled_width_}, at::dtype<float>());
+  auto* A = Output(1, Y->dims(), at::dtype<int>());
   int output_size = Y->size();
   PSRoIPoolForward<float><<<CAFFE_GET_BLOCKS(output_size),
diff --git a/modules/detectron/roi_pool_f_op.cu b/modules/detectron/roi_pool_f_op.cu
--- a/modules/detectron/roi_pool_f_op.cu
+++ b/modules/detectron/roi_pool_f_op.cu
@@ ... @@ template<>
 bool RoIPoolFOp<float, CUDAContext>::RunOnDevice() {
   auto& X = Input(0); // Input data to pool
   auto& R = Input(1); // RoIs
-  auto* Y = Output(0); // RoI pooled data
-  auto* A = Output(1); // argmaxes
 
   if (R.size() == 0) {
     // Handle empty rois
-    Y->Resize(0, X.dim32(1), pooled_height_, pooled_width_);
-    A->Resize(0, X.dim32(1), pooled_height_, pooled_width_);
-    // The following mutable_data calls are needed to allocate the tensors
-    Y->mutable_data<float>();
-    A->mutable_data<int>();
+    std::vector<int64_t> sizes = {0, X.dim32(1), pooled_height_, pooled_width_};
+    /* auto* Y = */ Output(0, sizes, at::dtype<float>());
+    /* auto* A = */ Output(1, sizes, at::dtype<int>());
     return true;
   }
 
-  Y->Resize(R.dim32(0), X.dim32(1), pooled_height_, pooled_width_);
-  A->Resize(Y->dims());
+  auto* Y = Output(0, {R.dim32(0), X.dim32(1), pooled_height_, pooled_width_}, at::dtype<float>()); // RoI pooled data
+  auto* A = Output(1, Y->sizes(), at::dtype<int>()); // argmaxes
   int output_size = Y->size();
   RoIPoolFForward<float><<<CAFFE_GET_BLOCKS(output_size),
diff --git a/modules/detectron/sample_as_op.cu b/modules/detectron/sample_as_op.cu
--- a/modules/detectron/sample_as_op.cu
+++ b/modules/detectron/sample_as_op.cu
@@ ... @@ template <>
 bool SampleAsOp<float, CUDAContext>::RunOnDevice() {
   auto& X = Input(0); // Input data to be sliced
   auto& L = Input(1); // Target data that provide the identity
-  auto* Y = Output(0); // Sliced data (Y.dim32(0) = num of (L > 0))
+  // Sliced data (Y.dim32(0) = num of (L > 0))
 
   CAFFE_ENFORCE(
       X.dim32(0) == L.dim32(0),
@@ -60,7 +60,7 @@ bool SampleAsOp<float, CUDAContext>::RunOnDevice() {
   // resize Y
   vector<int64_t> out_shape(X.dims().vec());
   out_shape[0] = count;
-  Y->Resize(out_shape);
+  auto* Y = Output(0, out_shape, at::dtype<float>());
 
   const int len = X.size() / X.dim32(0);
diff --git a/modules/detectron/select_smooth_l1_loss_op.cu b/modules/detectron/select_smooth_l1_loss_op.cu
index 259f892..5e2526f 100644
--- a/modules/detectron/select_smooth_l1_loss_op.cu
+++ b/modules/detectron/select_smooth_l1_loss_op.cu
@@ -97,9 +97,9 @@ bool SelectSmoothL1LossOp<float, CUDAContext>::RunOnDevice() {
   auto& L = Input(2);
   // total number of fg boxes across all FPN levels: scalar
   auto& S = Input(3);
 
-  auto* avg_loss = Output(0);
-  avg_loss->Resize(vector<int64_t>());
+
+  auto* avg_loss = Output(0, vector<int64_t>(), at::dtype<float>());
   if (Y.size() == 0){
     math::Set<float, CUDAContext>(
         1, static_cast<float>(0), avg_loss->mutable_data<float>(), &context_);
diff --git a/modules/detectron/sigmoid_cross_entropy_loss_op.cu b/modules/detectron/sigmoid_cross_entropy_loss_op.cu
index a8b6390..2f9e6bf 100644
--- a/modules/detectron/sigmoid_cross_entropy_loss_op.cu
+++ b/modules/detectron/sigmoid_cross_entropy_loss_op.cu
@@ -69,7 +69,7 @@ template <>
 bool SigmoidCrossEntropyLossOp<float, CUDAContext>::RunOnDevice() {
   auto& X = Input(0);
   auto& T = Input(1);
-  auto* avg_loss = Output(0);
+
   CAFFE_ENFORCE(
       X.size() == T.size(),
@@ -79,7 +79,7 @@
       " vs. ",
       T.size(),
       ")");
-  avg_loss->Resize(vector<int64_t>());
+  auto* avg_loss = Output(0, vector<int64_t>(), at::dtype<float>());
   counts_.ResizeLike(X);
   losses_.ResizeLike(X);
   normalizer_.Resize(vector<int64_t>());
diff --git a/modules/detectron/sigmoid_focal_loss_op.cu b/modules/detectron/sigmoid_focal_loss_op.cu
index 2630cf3..ff0028b 100644
--- a/modules/detectron/sigmoid_focal_loss_op.cu
+++ b/modules/detectron/sigmoid_focal_loss_op.cu
@@ -118,14 +118,14 @@ bool SigmoidFocalLossOp<float, CUDAContext>::RunOnDevice() {
   // Number of positive examples: scalar
   auto& wp = Input(2);
   // output avg Sigmoid focal loss as mentioned in RetinaNet paper
-  auto* avg_loss = Output(0);
+
   int N = X.dim32(0);
   int D = X.dim32(1);
   int H = X.dim32(2);
   int W = X.dim32(3);
 
-  avg_loss->Resize(vector<int64_t>());
+  auto* avg_loss = Output(0, vector<int64_t>(), at::dtype<float>());
   losses_.ResizeLike(X);
   float* avg_loss_data = avg_loss->mutable_data<float>();
diff --git a/modules/detectron/smooth_l1_loss_op.cu b/modules/detectron/smooth_l1_loss_op.cu
index 30aadc5..eee88cf 100644
--- a/modules/detectron/smooth_l1_loss_op.cu
+++ b/modules/detectron/smooth_l1_loss_op.cu
@@ -66,7 +66,7 @@ bool SmoothL1LossOp<float, CUDAContext>::RunOnDevice() {
   auto& Y = Input(1);
   auto& alpha_in = Input(2);
   auto& alpha_out = Input(3);
-  auto* avg_loss = Output(0);
+
   int N = Y.dim32(0);
   // Require the same number of elements along axis 0 (batch size), but
@@ -78,7 +78,7 @@ bool SmoothL1LossOp<float, CUDAContext>::RunOnDevice() {
   CAFFE_ENFORCE_EQ(Y_hat.size(), alpha_in.size());
   CAFFE_ENFORCE_EQ(Y_hat.size(), alpha_out.size());
 
-  avg_loss->Resize(vector<int64_t>());
+  auto* avg_loss = Output(0, vector<int64_t>(), at::dtype<float>());
   buff_.ResizeLike(Y);
 
   // Difference
diff --git a/modules/detectron/softmax_focal_loss_op.cu b/modules/detectron/softmax_focal_loss_op.cu
index 72b24ae..9b758b4 100644
--- a/modules/detectron/softmax_focal_loss_op.cu
+++ b/modules/detectron/softmax_focal_loss_op.cu
@@ -147,8 +147,8 @@ bool SoftmaxFocalLossOp<float, CUDAContext>::RunOnDevice() {
   auto& X = Input(0); // Logits
   auto& T = Input(1); // Labels
   auto& wp = Input(2); // num of foregound
-  auto* avg_loss = Output(0); // average loss as output
-  auto* P = Output(1); // softmax probability, going to be re-used in gradient
+  // average loss as output
+  // softmax probability, going to be re-used in gradient
 
   int N = X.dim32(0);
   int D = X.dim32(1);
@@ -157,8 +157,8 @@
   int A = D / num_classes_;
 
   losses_.Resize(N * A * H * W);
-  P->Resize(N * D * H * W);
-  avg_loss->Resize(vector<int64_t>());
+  auto* P = Output(1, {N * D * H * W}, at::dtype<float>());
+  auto* avg_loss = Output(0, vector<int64_t>(), at::dtype<float>());
   math::Set<float, CUDAContext>(
       avg_loss->size(), 0.f, avg_loss->mutable_data<float>(), &context_);
   math::Set<float, CUDAContext>(
diff --git a/modules/detectron/spatial_narrow_as_op.cu b/modules/detectron/spatial_narrow_as_op.cu
index 1ee1cbc..9c4df31 100644
--- a/modules/detectron/spatial_narrow_as_op.cu
+++ b/modules/detectron/spatial_narrow_as_op.cu
@@ -76,17 +76,17 @@ bool SpatialNarrowAsOp<CUDAContext>::DoRunWithType() {
   // Narrows input 0 (A) spatially to match input 1 (B)
   auto& A = Input(0);
   auto& B = Input(1);
-  auto* C = Output(0);
+
   CAFFE_ENFORCE_EQ(A.dim32(0), B.dim32(0), "Input dim 0 must be equal.");
+  std::vector<int64_t> sizes;
   if (A.ndim() == B.ndim()) {
     CAFFE_ENFORCE_EQ(A.dim32(1), B.dim32(1), "Input dim 1 must be equal.");
     CAFFE_ENFORCE_GE(
         A.dim32(2), B.dim32(2), "Input 0 height must be >= input 1 height.");
     CAFFE_ENFORCE_GE(
         A.dim32(3), B.dim32(3), "Input 0 width must be >= input 1 width.");
-
-    C->ResizeLike(B);
+    sizes = B.sizes().vec();
   } else {
     // For (N, H, W) case
     CAFFE_ENFORCE_EQ(A.ndim() - 1, B.ndim(), "Dimension mismatch.");
@@ -94,8 +94,9 @@
         A.dim32(2), B.dim32(1), "Input 0 height must be >= input 1 height.");
     CAFFE_ENFORCE_GE(
         A.dim32(3), B.dim32(2), "Input 0 width must be >= input 1 width.");
-    C->Resize(A.dim32(0), A.dim32(1), B.dim32(1), B.dim32(2));
+    sizes = {A.dim32(0), A.dim32(1), B.dim32(1), B.dim32(2)};
   }
+  auto* C = Output(0, sizes, at::dtype<T>());
   int out_width = C->dim32(3);
   int out_height = C->dim32(2);
   int in_width = A.dim32(3);
-- 
2.7.4
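
Every hunk above applies the same mechanical rewrite. A minimal sketch of the before/after pattern, using a hypothetical operator body and placeholder names (shape, T), assuming the Caffe2 Operator API of this period:

    // Before the codemod: the output tensor is fetched, resized, and only
    // allocated when mutable_data<T>() is first called.
    auto* output = Output(0);
    output->Resize(shape);
    T* out_data = output->template mutable_data<T>();

    // After the codemod: shape and dtype are passed to Output(), so the
    // tensor is constructed with its size and type in a single call.
    auto* output = Output(0, shape, at::dtype<T>());
    T* out_data = output->template mutable_data<T>();

Where an output is only allocated and never written through a named pointer (empty-input early returns), the new call is made without binding the result, which is why several hunks read /* auto* Y = */ Output(...).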