From 246f5c412eea453d42245ff496316934c01bda73 Mon Sep 17 00:00:00 2001 From: Junjie Bai Date: Mon, 1 Apr 2019 14:30:09 -0700 Subject: [PATCH] Revert "Tensor construction codemod(raw_mutable_data) (#16373)" (#18680) Summary: This reverts commit d73c830e236f5b980e5c91914b818d150b60278c. We have observed a significant perf drop when training ResNext101 with multiple AMD GPUs: Before: https://ci.pytorch.org/jenkins/job/caffe2-builds/job/py2-clang7-rocmdeb-ubuntu16.04-bench/1636/console 2 GPUs ResNext training got 150\~160 imgs/sec 4 GPUs ResNext training got 270\~280 imgs/sec After: https://ci.pytorch.org/jenkins/job/caffe2-builds/job/py2-clang7-rocmdeb-ubuntu16.04-bench/1637/console Both 2 and 4 GPUs ResNext training drop to 110\~120 imgs/sec Similar perf drops are seen on ResNet50 training jobs as well. Pull Request resolved: https://github.com/pytorch/pytorch/pull/18680 Differential Revision: D14702941 Pulled By: bddppq fbshipit-source-id: 828141805afc23f25c08d4a2eb6d4b99f817c128 --- caffe2/operators/boolean_mask_ops.cc | 3 +- caffe2/operators/boolean_mask_ops.cu | 5 ++- caffe2/operators/boolean_unmask_ops.cc | 13 ++++--- caffe2/operators/boolean_unmask_ops.cu | 9 +++-- caffe2/operators/concat_split_op.h | 10 +++-- caffe2/operators/conditional_op.cc | 3 +- caffe2/operators/copy_op.cu | 8 ++-- caffe2/operators/copy_op.h | 7 ++-- caffe2/operators/crf_viterbi_op.cc | 43 +++++++++------------- caffe2/operators/dataset_ops.cc | 34 +++++++---------- caffe2/operators/ensure_cpu_output_op.h | 5 +-- caffe2/operators/flatten_op.h | 6 +-- caffe2/operators/gather_ranges_to_dense_op.h | 3 +- caffe2/operators/lengths_tile_op.cc | 3 +- caffe2/operators/pack_segments.cc | 3 +- caffe2/operators/pack_segments.cu | 16 +++++--- caffe2/operators/partition_ops.h | 16 +++++--- caffe2/operators/prepend_dim_op.h | 10 +++-- caffe2/operators/remove_data_blocks_op.h | 3 +- caffe2/operators/reservoir_sampling.cc | 1 - caffe2/operators/reshape_op.h | 5 +-- caffe2/operators/sequence_ops.cc | 7 ++-- 
caffe2/operators/text_file_reader.cc | 4 +- caffe2/operators/tile_op.h | 6 ++- caffe2/operators/utility_ops.h | 16 +++++--- .../server/fully_connected_dnnlowp_op.cc | 3 +- caffe2/queue/rebatching_queue.cc | 3 +- 27 files changed, 129 insertions(+), 116 deletions(-) diff --git a/caffe2/operators/boolean_mask_ops.cc b/caffe2/operators/boolean_mask_ops.cc index 11b3015..dae6427 100644 --- a/caffe2/operators/boolean_mask_ops.cc +++ b/caffe2/operators/boolean_mask_ops.cc @@ -50,6 +50,7 @@ template <> bool BooleanMaskOp::RunOnDevice() { auto& data = Input(0); auto& mask = Input(1); + auto* dataOut = Output(0); CAFFE_ENFORCE(data.dim() >= 1); CAFFE_ENFORCE_EQ(mask.dim(), 1); CAFFE_ENFORCE(data.size(0) == mask.size(0)); @@ -65,7 +66,7 @@ bool BooleanMaskOp::RunOnDevice() { std::vector outShape; outShape.push_back(numOutputs); outShape.insert(outShape.end(), data.sizes().begin() + 1, data.sizes().end()); - auto* dataOut = Output(0, outShape, at::dtype(data.dtype())); + dataOut->Resize(outShape); auto* outPtr = (char*)dataOut->raw_mutable_data(data.dtype()); int64_t* out_vec = nullptr; diff --git a/caffe2/operators/boolean_mask_ops.cu b/caffe2/operators/boolean_mask_ops.cu index d809dff..4cbf368 100644 --- a/caffe2/operators/boolean_mask_ops.cu +++ b/caffe2/operators/boolean_mask_ops.cu @@ -31,6 +31,7 @@ class BooleanMaskOp final : public Operator { bool RunOnDevice() override { const auto& src = Input(0); const auto& mask = Input(1); + auto* dest = Output(0); CAFFE_ENFORCE(src.dim() >= 1); CAFFE_ENFORCE_EQ(mask.dim(), 1); @@ -79,8 +80,8 @@ class BooleanMaskOp final : public Operator { indices_.Resize(numOfOutput); std::vector dims = src.sizes().vec(); dims[0] = numOfOutput; - auto* dest = Output(0, dims, at::dtype(src.dtype())); - auto* destData = (uint8_t*)dest->raw_mutable_data(src.dtype()); + dest->Resize(dims); + auto* destData = (uint8_t*)dest->raw_mutable_data(src.meta()); const auto* srcData = (uint8_t*)src.raw_data(); if (OutputSize() == 2) { diff --git 
a/caffe2/operators/boolean_unmask_ops.cc b/caffe2/operators/boolean_unmask_ops.cc index 75b33fd..55e2449 100644 --- a/caffe2/operators/boolean_unmask_ops.cc +++ b/caffe2/operators/boolean_unmask_ops.cc @@ -8,10 +8,11 @@ template <> bool BooleanUnmaskOp::RunOnDevice() { int maskSize = Input(0).numel(); int numMasks = InputSize() / 2; - auto& valueDtype = Input(1).dtype(); + auto& valueMeta = Input(1).dtype(); - auto* valuesOut = Output(0, maskSize, at::dtype(valueDtype)); - auto* valuesOutPtr = (char*)valuesOut->raw_mutable_data(valueDtype); + auto* valuesOut = Output(0); + valuesOut->Resize(maskSize); + auto* valuesOutPtr = (char*)valuesOut->raw_mutable_data(valueMeta); std::vector nextValueIndices(numMasks, 0); for (int maskOffset = 0; maskOffset < maskSize; ++maskOffset) { @@ -29,9 +30,9 @@ bool BooleanUnmaskOp::RunOnDevice() { if (maskPtr[maskOffset]) { auto& valueIndex = nextValueIndices[maskIndex]; CAFFE_ENFORCE_LT(valueIndex, values.numel()); - auto* src = valuesPtr + (valueIndex++) * valueDtype.itemsize(); - auto* dst = valuesOutPtr + maskOffset * valueDtype.itemsize(); - std::copy(src, src + valueDtype.itemsize(), dst); + auto* src = valuesPtr + (valueIndex++) * valueMeta.itemsize(); + auto* dst = valuesOutPtr + maskOffset * valueMeta.itemsize(); + std::copy(src, src + valueMeta.itemsize(), dst); maskFound = true; break; } diff --git a/caffe2/operators/boolean_unmask_ops.cu b/caffe2/operators/boolean_unmask_ops.cu index 27de5ae..ce90dab 100644 --- a/caffe2/operators/boolean_unmask_ops.cu +++ b/caffe2/operators/boolean_unmask_ops.cu @@ -54,10 +54,11 @@ class BooleanUnmaskOp final : public Operator { bool RunOnDevice() override { int maskSize = Input(0).numel(); int numMasks = InputSize() / 2; - const auto& dtype = Input(1).dtype(); + const auto& meta = Input(1).meta(); - auto* out = Output(0, maskSize, at::dtype(dtype)); - auto* dest = (char*)out->raw_mutable_data(dtype); + auto* out = Output(0); + out->Resize(maskSize); + auto* dest = 
(char*)out->raw_mutable_data(meta); ReinitializeTensor(&hostMasks_, {numMasks}, at::dtype().device(CPU)); auto* hostMasksData = hostMasks_.mutable_data(); @@ -100,7 +101,7 @@ class BooleanUnmaskOp final : public Operator { context_.cuda_stream()>>>( numMasks, maskSize, - dtype.itemsize(), + meta.itemsize(), indicesData, values_.data(), valueSizesData, diff --git a/caffe2/operators/concat_split_op.h b/caffe2/operators/concat_split_op.h index 4382e89..47ed663 100644 --- a/caffe2/operators/concat_split_op.h +++ b/caffe2/operators/concat_split_op.h @@ -177,11 +177,12 @@ bool SplitOp::RunOnDevice() { } size_t input_offset = 0; for (int i = 0; i < OutputSize(); ++i) { + auto* output = Output(i); auto axis_dim = add_axis_ ? 1 : axis_data[i]; if (!add_axis_) { output_dims[canonical_axis] = axis_data[i]; } - auto* output = Output(i, output_dims, at::dtype(input.dtype())); + output->Resize(output_dims); math::CopyMatrix( input.itemsize(), before, @@ -222,11 +223,12 @@ bool SplitByLengthsOp::RunOnDevice() { int after = input.size_from_dim(canonical_axis + 1); size_t input_offset = 0; for (int i = 0; i < OutputSize(); ++i) { + auto* output = Output(i); const auto* axis_offset = axis_data + length_length / OutputSize() * i; auto axis_dim = std::accumulate( axis_offset, axis_offset + length_length / OutputSize(), 0); output_dims[canonical_axis] = axis_dim; - auto* output = Output(i, output_dims, at::dtype(input.dtype())); + output->Resize(output_dims); math::CopyMatrix( input.itemsize(), before, @@ -244,6 +246,8 @@ bool SplitByLengthsOp::RunOnDevice() { template bool ConcatOp::RunOnDevice() { + auto* output = Output(0); + // We can override default options(Context::GetDeviceType()) // by explictly passing in device type we want Tensor* split = Output( @@ -310,7 +314,7 @@ bool ConcatOp::RunOnDevice() { } else { output_dims[canonical_axis] = output_channels; } - auto* output = Output(0, output_dims, at::dtype(input_zero.dtype())); + output->Resize(output_dims); size_t 
output_offset = 0; for (int i = 0; i < InputSize(); ++i) { auto& input = Input(i); diff --git a/caffe2/operators/conditional_op.cc b/caffe2/operators/conditional_op.cc index 780efe3..6097912 100644 --- a/caffe2/operators/conditional_op.cc +++ b/caffe2/operators/conditional_op.cc @@ -23,8 +23,9 @@ bool ConditionalOp::RunOnDevice() { CAFFE_ENFORCE(innerSize * dataF.dtype().itemsize() == innerSizeBytes); // initialize output shape + auto* dataOut = Output(0); const auto* condPtr = condition.template data(); - auto* dataOut = Output(0, dataT.sizes(), at::dtype(dataT.dtype())); + dataOut->ResizeLike(dataT); auto* outPtr = (char*)dataOut->raw_mutable_data(dataT.dtype()); // perform conditional op along first dimension diff --git a/caffe2/operators/copy_op.cu b/caffe2/operators/copy_op.cu index bc1d734..2091524 100644 --- a/caffe2/operators/copy_op.cu +++ b/caffe2/operators/copy_op.cu @@ -13,14 +13,14 @@ class CopyOnDeviceLikeOp bool RunOnDevice() override { auto& input = Input(0); - auto* output = OperatorBase::OutputTensor( - 0, input.sizes(), at::dtype(input.dtype()).device(CUDA)); + auto* output = OperatorBase::Output(0, CUDA); CUDAContext context(GetGPUIDForPointer(Input(1).raw_data())); + output->ResizeLike(input); context.template CopyItems( - input.dtype(), + input.meta(), input.numel(), input.raw_data(), - output->raw_mutable_data(input.dtype())); + output->raw_mutable_data(input.meta())); return true; } }; diff --git a/caffe2/operators/copy_op.h b/caffe2/operators/copy_op.h index 86e6c45..8ccbcb9 100644 --- a/caffe2/operators/copy_op.h +++ b/caffe2/operators/copy_op.h @@ -14,10 +14,9 @@ class CopyOp : public Operator { bool RunOnDevice() override { auto& input = this->template Input(0, SrcContext::GetDeviceType()); - auto* output = this->OutputTensor( - 0, - input.sizes(), - at::dtype(input.dtype()).device(DstContext::GetDeviceType())); + auto* output = + this->template Output(0, DstContext::GetDeviceType()); + output->ResizeLike(input); this->context_.template 
CopyItems( input.dtype(), input.numel(), diff --git a/caffe2/operators/crf_viterbi_op.cc b/caffe2/operators/crf_viterbi_op.cc index c99c632..07630a6 100644 --- a/caffe2/operators/crf_viterbi_op.cc +++ b/caffe2/operators/crf_viterbi_op.cc @@ -94,16 +94,13 @@ class ViterbiPathOp : public Operator { auto block_size = predictions.numel() / predictions.size(0); auto block_bytesize = predictions.size_from_dim(1) * predictions.dtype().itemsize(); - Tensor backpointers = - caffe2::empty(predictions.sizes(), at::dtype().device(CPU)); - - Tensor trellis = caffe2::empty( - std::vector{block_size}, - at::dtype(predictions.dtype()).device(CPU)); - Tensor dpMat = - caffe2::empty(transitions.sizes(), at::dtype().device(CPU)); - Tensor dpMax = caffe2::empty( - std::vector{block_size}, at::dtype().device(CPU)); + Tensor backpointers(CPU); + backpointers.ResizeLike(predictions); + + Tensor trellis(std::vector{block_size}, CPU); + Tensor dpMat(CPU); + dpMat.ResizeLike(transitions); + Tensor dpMax(std::vector{block_size}, CPU); GatherRow(predictions, 0, block_size, block_bytesize, &trellis); for (auto i = 1; i < seqLen; i++) { AddColToMat(transitions, trellis, &dpMat); @@ -123,10 +120,8 @@ class ViterbiPathOp : public Operator { &context_); } - Tensor tMax = - caffe2::empty(std::vector{1}, at::dtype().device(CPU)); - Tensor tArgMax = caffe2::empty( - std::vector{1}, at::dtype().device(CPU)); + Tensor tMax(std::vector{1}, CPU); + Tensor tArgMax(std::vector{1}, CPU); ColwiseMaxAndArg( trellis.template data(), 1, @@ -136,9 +131,7 @@ class ViterbiPathOp : public Operator { std::vector viterbiVec; viterbiVec.push_back(tArgMax.template data()[0]); - Tensor bpEntry = caffe2::empty( - std::vector{block_size}, - at::dtype(backpointers.dtype()).device(CPU)); + Tensor bpEntry(std::vector{block_size}, CPU); block_bytesize = backpointers.size_from_dim(1) * backpointers.dtype().itemsize(); for (auto i = seqLen - 1; i > 0; i--) { @@ -159,14 +152,14 @@ class SwapBestPathOp : public Operator { : 
Operator(std::forward(args)...) {} bool RunOnDevice() override { auto& data = Input(0); - auto& newBestIndicies = Input(1); + auto& newBestIdicies = Input(1); CAFFE_ENFORCE( - data.dim() == 2 && newBestIndicies.dim() == 1, + data.dim() == 2 && newBestIdicies.dim() == 1, "predictions should be a 2D matrix and bestPath should be 1D vector"); CAFFE_ENFORCE( - data.size(0) == newBestIndicies.size(0), + data.size(0) == newBestIdicies.size(0), "predictions and bestPath dimensions not matching"); auto* updatedData = Output(0, data.sizes(), at::dtype()); @@ -174,10 +167,10 @@ class SwapBestPathOp : public Operator { context_.CopyItemsSameDevice( data.dtype(), data.numel(), data.template data(), outData); - Tensor bestScores = - caffe2::empty(newBestIndicies.sizes(), at::dtype().device(CPU)); - Tensor oldBestIndices = caffe2::empty( - newBestIndicies.sizes(), at::dtype().device(CPU)); + Tensor bestScores(CPU); + bestScores.ResizeLike(newBestIdicies); + Tensor oldBestIndices(CPU); + oldBestIndices.ResizeLike(newBestIdicies); ColwiseMaxAndArg( data.template data(), @@ -189,7 +182,7 @@ class SwapBestPathOp : public Operator { auto block_size = data.numel() / data.size(0); const int32_t* oldBestIdx = oldBestIndices.template data(); - const int32_t* newIdx = newBestIndicies.template data(); + const int32_t* newIdx = newBestIdicies.template data(); for (auto i = 0; i < data.dim32(0); i++) { std::swap( diff --git a/caffe2/operators/dataset_ops.cc b/caffe2/operators/dataset_ops.cc index 73004d0..78fe003 100644 --- a/caffe2/operators/dataset_ops.cc +++ b/caffe2/operators/dataset_ops.cc @@ -319,11 +319,7 @@ class PackRecordsOp : public Operator { Output(0)->Resize(walker.size()); // Output(0)->raw_mutable_data(TypeMeta::Make())); - auto* dst = Output( - 0, - {static_cast(walker.size())}, - at::dtype()) - ->template mutable_data(); + auto* dst = Output(0)->template mutable_data(); for (int batchId = 0; batchId < walker.size(); ++batchId) { dst[batchId] = std::make_shared>(); @@ -399,8 
+395,8 @@ class UnPackRecordsOp : public Operator { // Resize to the final output size std::vector destinations(numTensors); for (int i = 0; i < numTensors; ++i) { - auto* output = Output(i, {outputDims[i]}, at::dtype(*metas[i])); - destinations[i] = output->raw_mutable_data(*metas[i]); + Output(i)->Resize(outputDims[i]); + destinations[i] = Output(i)->raw_mutable_data(*metas[i]); } for (int i = 0; i < numRows; ++i) { @@ -521,9 +517,10 @@ class ReadNextBatchOp : public Operator { auto innerSize = in.size_from_dim(1); outDim = in.sizes().vec(); outDim[0] = size; + auto* out = Output(i); + out->Resize(outDim); void* src = (char*)in.raw_data() + offset * innerSize * in.dtype().itemsize(); - auto* out = Output(i, {outDim}, at::dtype(in.dtype())); void* dst = out->raw_mutable_data(in.dtype()); // create the tensor if (out->numel() == 0) { continue; @@ -728,7 +725,8 @@ class ReadRandomBatchOp : public Operator { idx++; } idx = idxbegin; // reSet - auto* out = Output(i, {outDim}, at::dtype(in.dtype())); + auto* out = Output(i); + out->Resize(outDim); if (out->numel() == 0) { continue; } @@ -775,13 +773,13 @@ class AppendOp final : public Operator { bool RunOnDevice() override { auto& a = Input(0); auto& b = Input(1); - auto* c = Output(0, a.sizes(), at::dtype(a.dtype())); + auto* c = Output(0); CAFFE_ENFORCE(b.dim() >= 1); if (a.numel() == 0 && a.size(0) == 0) { c->CopyFrom(b); return true; } - CAFFE_ENFORCE(IsInputOutputAlias(0, 0), "First argument must be in-place."); + CAFFE_ENFORCE(&a == c, "First argument must be in-place."); CAFFE_ENFORCE(c->dim() == b.dim()); CAFFE_ENFORCE(b.dim() == c->dim()); CAFFE_ENFORCE(a.dtype() == b.dtype()); @@ -815,14 +813,13 @@ class AtomicAppendOp final : public Operator { for (int i = 0; i < numFields; ++i) { auto& a = Input(1 + i); auto& b = Input(1 + i + numFields); - auto* c = Output(i, a.sizes(), at::dtype(a.dtype())); + auto* c = Output(i); CAFFE_ENFORCE(b.dim() >= 1); if (a.numel() == 0) { continue; } CAFFE_ENFORCE( - 
IsInputOutputAlias(1 + i, i), - "Appended-to arguments must be in-place."); + (void*)&a == (void*)c, "Appended-to arguments must be in-place."); CAFFE_ENFORCE(c->dim() == b.dim()); CAFFE_ENFORCE(b.dim() == c->dim()); CAFFE_ENFORCE(a.dtype() == b.dtype()); @@ -835,8 +832,7 @@ class AtomicAppendOp final : public Operator { for (int i = 0; i < numFields; ++i) { auto& a = Input(1 + i); auto& b = Input(1 + i + numFields); - // Can we create Tensor with numel() == 0? - auto* c = Output(i, a.sizes(), at::dtype(a.dtype())); + auto* c = Output(i); if (a.numel() == 0 && a.size(0) == 0) { c->CopyFrom(b); continue; @@ -896,6 +892,7 @@ class ConcatTensorVectorOp final : public Operator { const TensorVectorPtr& tensorVector = OperatorBase::Input(TENSOR_VECTOR); + auto* tensor = Output(TENSOR); CAFFE_ENFORCE(!tensorVector->empty()); vector outputDims(tensorVector->at(0).sizes().vec()); @@ -909,8 +906,7 @@ class ConcatTensorVectorOp final : public Operator { outputDims[0] += tensorVector->at(i).sizes()[0]; } - auto* tensor = - Output(TENSOR, outputDims, at::dtype(tensorVector->at(0).dtype())); + tensor->Resize(outputDims); int64_t offset = 0; auto* dst = (char*)tensor->raw_mutable_data(tensorVector->at(0).dtype()); @@ -1025,8 +1021,6 @@ class TrimDatasetOp : public Operator { // trim each column to the offset for (int col = 0; col < walker.fields().size(); ++col) { auto newOuterSize = walker.fields().at(col).offset(); - // TODO: Remove call to Output(col) since it - // returns partially initialized Tensor Output(col)->ShrinkTo(newOuterSize); } return true; diff --git a/caffe2/operators/ensure_cpu_output_op.h b/caffe2/operators/ensure_cpu_output_op.h index 6d06e1f..04255ed 100644 --- a/caffe2/operators/ensure_cpu_output_op.h +++ b/caffe2/operators/ensure_cpu_output_op.h @@ -33,10 +33,9 @@ class EnsureCPUOutputOp : public Operator { template bool CopyWithContext() { // Output is always on CPU + auto* output = this->template Output(0, CPU); auto& input = this->template Input(0, 
InputContext::GetDeviceType()); - // TODO: is it possible to use OutputTensorCopyFrom? - auto* output = this->OutputTensor( - 0, input.sizes(), at::dtype(input.dtype()).device(CPU)); + output->ResizeLike(input); context_.CopyItemsToCPU( input.dtype(), input.numel(), diff --git a/caffe2/operators/flatten_op.h b/caffe2/operators/flatten_op.h index f840347..401e6fb 100644 --- a/caffe2/operators/flatten_op.h +++ b/caffe2/operators/flatten_op.h @@ -17,12 +17,10 @@ class FlattenOp : public Operator { bool RunOnDevice() override { auto& input = Input(0); + auto* output = Output(0); CAFFE_ENFORCE_GE( input.dim(), axis_, "The rank of the tensor must be >= axis."); - auto* output = Output( - 0, - {input.size_to_dim(axis_), input.size_from_dim(axis_)}, - at::dtype(input.dtype())); + output->Resize(input.size_to_dim(axis_), input.size_from_dim(axis_)); context_.CopyItemsSameDevice( input.dtype(), input.numel(), diff --git a/caffe2/operators/gather_ranges_to_dense_op.h b/caffe2/operators/gather_ranges_to_dense_op.h index 60947c7..ee8c09a 100644 --- a/caffe2/operators/gather_ranges_to_dense_op.h +++ b/caffe2/operators/gather_ranges_to_dense_op.h @@ -66,8 +66,9 @@ class GatherRangesToDenseOp final : public Operator { vector outputDims{batchSize, 0}; vector outputRawData; for (int i = 0; i < OutputSize(); ++i) { + auto* output = Output(i); outputDims[1] = lengths_[i]; - auto* output = Output(i, outputDims, at::dtype(data.dtype())); + output->Resize(outputDims); char* ptr = static_cast(output->raw_mutable_data(data.dtype())); memset(ptr, 0, output->nbytes()); outputRawData.push_back(ptr); diff --git a/caffe2/operators/lengths_tile_op.cc b/caffe2/operators/lengths_tile_op.cc index 10178e7..5e3f167 100644 --- a/caffe2/operators/lengths_tile_op.cc +++ b/caffe2/operators/lengths_tile_op.cc @@ -6,6 +6,7 @@ template <> bool LengthsTileOp::RunOnDevice() { auto& data = Input(DATA); auto& lengths = Input(LENGTHS); + auto* output = Output(0); CAFFE_ENFORCE_EQ(lengths.dim(), 1, "LENGTHS must 
be 1-D"); CAFFE_ENFORCE_GE(data.dim(), 1, "DATA should be at least 1-D"); @@ -25,7 +26,7 @@ bool LengthsTileOp::RunOnDevice() { auto shape = data.sizes().vec(); shape[0] = total_length; - auto* output = Output(0, shape, at::dtype(data.dtype())); + output->Resize(shape); auto block_bytesize = data.size_from_dim(1) * data.dtype().itemsize(); auto src = static_cast(data.raw_data()); diff --git a/caffe2/operators/pack_segments.cc b/caffe2/operators/pack_segments.cc index ac52c1b..8e82ed0 100644 --- a/caffe2/operators/pack_segments.cc +++ b/caffe2/operators/pack_segments.cc @@ -116,6 +116,7 @@ template bool UnpackSegmentsOp::DoRunWithType2() { const auto& data = Input(DATA); const auto& lengths = Input(LENGTHS); + auto* output = Output(0); CAFFE_ENFORCE_GE(data.dim(), 2, "DATA should be at least 2-D"); CAFFE_ENFORCE_EQ(lengths.dim(), 1, "LENGTH should be 1-D"); @@ -134,7 +135,7 @@ bool UnpackSegmentsOp::DoRunWithType2() { shape[0], lengths.size(0), "LENGTH should match DATA in dimension 0"); shape.erase(shape.begin()); shape[0] = total_l; - auto* output = Output(0, shape, at::dtype(data.dtype())); + output->Resize(shape); // create output tensor auto* out = static_cast(output->raw_mutable_data(data.dtype())); if (!(data.size(0) && data.size(1))) { diff --git a/caffe2/operators/pack_segments.cu b/caffe2/operators/pack_segments.cu index 3e6a2b6..213bb60 100644 --- a/caffe2/operators/pack_segments.cu +++ b/caffe2/operators/pack_segments.cu @@ -179,6 +179,11 @@ bool PackSegmentsOp::DoRunWithType2() { int64_t num_seq = lengths.dim(0); const Data_T* data_ptr = data.data(); const T* lengths_ptr = lengths.data(); + auto* out = Output(0); + Tensor* presence_mask = nullptr; + if (return_presence_mask_) { + presence_mask = Output(1); + } CAFFE_ENFORCE_GE(data.dim(), 1, "DATA should be at least 1-D"); CAFFE_ENFORCE_EQ(lengths.dim(), 1, "LENGTH should be 1-D"); @@ -209,7 +214,7 @@ bool PackSegmentsOp::DoRunWithType2() { bool* presence_mask_data = nullptr; if (return_presence_mask_) 
{ std::vector presence_shape{lengths.numel(), max_length}; - auto* presence_mask = Output(1, presence_shape, at::dtype()); + presence_mask->Resize(presence_shape); presence_mask_data = presence_mask->template mutable_data(); } @@ -217,8 +222,8 @@ bool PackSegmentsOp::DoRunWithType2() { auto shape = data.sizes().vec(); // Shape of out is batch_size x max_len x ... shape[0] = max_length; shape.insert(shape.begin(), lengths.numel()); - auto* out = Output(0, shape, at::dtype(data.dtype())); - Data_T* out_ptr = static_cast(out->raw_mutable_data(data.dtype())); + out->Resize(shape); + Data_T* out_ptr = static_cast(out->raw_mutable_data(data.meta())); // Return empty out (with the proper shape) if first dim is 0. if (!data.dim(0)) { @@ -260,6 +265,7 @@ bool UnpackSegmentsOp::DoRunWithType2() { int64_t num_seq = lengths.dim(0); const Data_T* data_ptr = data.data(); const T* lengths_ptr = lengths.data(); + auto* out = Output(0); CAFFE_ENFORCE_GE(data.dim(), 1, "DATA should be at least 1-D"); CAFFE_ENFORCE_EQ(lengths.dim(), 1, "LENGTH should be 1-D"); @@ -309,8 +315,8 @@ bool UnpackSegmentsOp::DoRunWithType2() { shape[0], lengths.dim(0), "LENGTH should match DATA in dimension 0"); shape.erase(shape.begin()); shape[0] = num_cell; - auto* out = Output(0, shape, at::dtype(data.dtype())); - Data_T* out_ptr = static_cast(out->raw_mutable_data(data.dtype())); + out->Resize(shape); + Data_T* out_ptr = static_cast(out->raw_mutable_data(data.meta())); // Return empty out (with the proper shape) if any of the dimensions is 0. 
if (data.dim(0) == 0 || data.dim(1) == 0) { diff --git a/caffe2/operators/partition_ops.h b/caffe2/operators/partition_ops.h index f5f7dc9..fa8a27c 100644 --- a/caffe2/operators/partition_ops.h +++ b/caffe2/operators/partition_ops.h @@ -60,7 +60,8 @@ class GatherByKeyOp : public Operator { } CAFFE_ENFORCE_EQ(keysTensor.numel(), totalSize); - auto* outTensor = Output(0, outShape, at::dtype(meta)); + auto* outTensor = Output(0); + outTensor->Resize(outShape); auto* outData = static_cast(outTensor->raw_mutable_data(meta)); const auto blockSize = outTensor->size_from_dim(1); @@ -163,8 +164,9 @@ class PartitionOpBase : public Operator { input.sizes().begin() + main_input.dim() - 1, input.sizes().end()); for (int j = 0; j < partitions; ++j) { int out_idx = i + j * inputSize; + auto output = Output(out_idx); shape[0] = counts_[j]; - auto output = Output(out_idx, shape, at::dtype(input.dtype())); + output->Resize(shape); out_datas_[out_idx] = output->raw_mutable_data(input.dtype()); } } @@ -254,12 +256,13 @@ class LengthsPartitionOp : public PartitionOpBase { // Specialization when partitions == 1 which just becomes a copy. 
for (int i = 0; i < InputSize(); ++i) { auto& input = Input(i); - auto* output = Output(i, input.sizes(), at::dtype(input.dtype())); + auto& output = *Output(i); + output.ResizeLike(input); context_.CopyItemsSameDevice( input.dtype(), input.numel(), input.raw_data(), - output->raw_mutable_data(input.dtype())); + output.raw_mutable_data(input.dtype())); } return true; } @@ -277,8 +280,9 @@ class LengthsPartitionOp : public PartitionOpBase { const int32_t* lengths_data = length_input.template data(); out_length_.resize(partitions); for (int i = 0; i < partitions; ++i) { - auto* output = Output(i * InputSize(), elements, at::dtype()); - out_length_[i] = output->template mutable_data(); + auto& output = *Output(i * InputSize()); + output.Resize(elements); + out_length_[i] = output.template mutable_data(); } int total_length = 0; diff --git a/caffe2/operators/prepend_dim_op.h b/caffe2/operators/prepend_dim_op.h index 5cc8d64..2df840d 100644 --- a/caffe2/operators/prepend_dim_op.h +++ b/caffe2/operators/prepend_dim_op.h @@ -23,6 +23,7 @@ class PrependDimOp : public Operator { bool RunOnDevice() override { auto& input = Input(0); + auto* output = Output(0); CAFFE_ENFORCE(input.dim() > 0, "Input must be at least 1D."); CAFFE_ENFORCE( @@ -36,9 +37,9 @@ class PrependDimOp : public Operator { for (int i = 1; i < input.sizes().size(); ++i) { actual_new_shape[i + 1] = input.size(i); } - auto* output = Output(0, actual_new_shape, at::dtype(input.dtype())); + output->Resize(actual_new_shape); - if (!IsInputOutputAlias(0, 0)) { + if (output != &input) { // If we are not doing in-place computation, a copy is needed. 
context_.CopyItemsSameDevice( input.dtype(), @@ -63,6 +64,7 @@ class MergeDimOp : public Operator { bool RunOnDevice() override { auto& input = Input(0); + auto* output = Output(0); CAFFE_ENFORCE(input.dim() > 1, "Input must be at least 2D."); @@ -71,9 +73,9 @@ class MergeDimOp : public Operator { for (int i = 1; i < input.sizes().size() - 1; ++i) { actual_new_shape[i] = input.size(i + 1); } - auto* output = Output(0, actual_new_shape, at::dtype(input.dtype())); + output->Resize(actual_new_shape); - if (!IsInputOutputAlias(0, 0)) { + if (output != &input) { // If we are not doing in-place computation, a copy is needed. context_.CopyItemsSameDevice( input.dtype(), diff --git a/caffe2/operators/remove_data_blocks_op.h b/caffe2/operators/remove_data_blocks_op.h index 9303904..5f409bf 100644 --- a/caffe2/operators/remove_data_blocks_op.h +++ b/caffe2/operators/remove_data_blocks_op.h @@ -52,9 +52,10 @@ class RemoveDataBlocksOp final : public Operator { ind_vec.erase(std::unique(ind_vec.begin(), ind_vec.end()), ind_vec.end()); indices_size = ind_vec.size(); + auto* output = Output(0); auto shape = data.sizes().vec(); shape[0] -= indices_size; - auto* output = Output(0, shape, at::dtype(data.dtype())); + output->Resize(shape); char* out_ptr = (char*)output->raw_mutable_data(data.dtype()); ind_vec.insert(ind_vec.begin(), -1); diff --git a/caffe2/operators/reservoir_sampling.cc b/caffe2/operators/reservoir_sampling.cc index 206ec90..0e125dd 100644 --- a/caffe2/operators/reservoir_sampling.cc +++ b/caffe2/operators/reservoir_sampling.cc @@ -23,7 +23,6 @@ class ReservoirSamplingOp final : public Operator { auto& mutex = OperatorBase::Input>(MUTEX); std::lock_guard guard(*mutex); - // TODO: separate diff for this auto* output = Output(RESERVOIR); const auto& input = Input(DATA); diff --git a/caffe2/operators/reshape_op.h b/caffe2/operators/reshape_op.h index 0ad70e7..c7f4b16 100644 --- a/caffe2/operators/reshape_op.h +++ b/caffe2/operators/reshape_op.h @@ -30,8 +30,7 @@ class 
ReshapeOp : public Operator { template bool DoRunWithType() { - DoRunWithTypeImpl( - Input(0), Output(0, Input(0).sizes(), Input(0).dtype())); + DoRunWithTypeImpl(Input(0), Output(0)); return true; } @@ -124,7 +123,7 @@ class ReshapeOp : public Operator { } output->Resize(actual_new_shape); - if (!IsInputOutputAlias(0, 0)) { + if (output != &input) { // If we are not doing in-place computation, a copy is needed. context_.CopyItemsSameDevice( input.dtype(), diff --git a/caffe2/operators/sequence_ops.cc b/caffe2/operators/sequence_ops.cc index b1e2a8a..dfb01ad 100644 --- a/caffe2/operators/sequence_ops.cc +++ b/caffe2/operators/sequence_ops.cc @@ -192,15 +192,16 @@ bool PadEmptySamplesOp::RunOnDevice() { features.size(0) == sumLen, "FEATURE and LENGTH should be consistent"); const auto block_size = features.size_from_dim(1); + auto* out_features = Output(1 + k); auto outDim = features.sizes().vec(); outDim.at(0) += needPadding; - auto* out_features = Output(1 + k, outDim, at::dtype(features.dtype())); + out_features->Resize(outDim); auto dst = static_cast(out_features->raw_mutable_data(features.dtype())); auto src_base = static_cast(features.raw_data()); // copy data and add padding index as zero - Tensor zero = - caffe2::empty({block_size}, at::dtype(features.dtype()).device(CPU)); + Tensor zero{CPU}; + zero.Resize(block_size); auto zeroPtr = static_cast(zero.raw_mutable_data(features.dtype())); memset(zeroPtr, 0, zero.nbytes()); int start_dest = 0; diff --git a/caffe2/operators/text_file_reader.cc b/caffe2/operators/text_file_reader.cc index cbe3d81..238ab6a 100644 --- a/caffe2/operators/text_file_reader.cc +++ b/caffe2/operators/text_file_reader.cc @@ -110,8 +110,8 @@ class TextFileReaderReadOp : public Operator { // it. 
std::vector datas(numFields); for (int i = 0; i < numFields; ++i) { - auto* output = Output(i, batchSize_, at::dtype(instance->fieldMetas[i])); - datas[i] = (char*)output->raw_mutable_data(instance->fieldMetas[i]); + Output(i)->Resize(batchSize_); + datas[i] = (char*)Output(i)->raw_mutable_data(instance->fieldMetas[i]); } int rowsRead = 0; diff --git a/caffe2/operators/tile_op.h b/caffe2/operators/tile_op.h index 72cd56d..ad0b924 100644 --- a/caffe2/operators/tile_op.h +++ b/caffe2/operators/tile_op.h @@ -74,12 +74,13 @@ class TileOp final : public Operator { } const auto& X = Input(0); + auto* Y = Output(0); const int axis = X.canonical_axis_index(axis_); // reshape output to be input tiled along the axis std::vector Y_dims = X.sizes().vec(); Y_dims[axis] *= tiles_; - auto* Y = Output(0, Y_dims, at::dtype()); + Y->Resize(Y_dims); // size up to (and not including) axis const int outer_size = X.size_to_dim(axis); @@ -178,13 +179,14 @@ class TileGradientOp final : public Operator { } const auto& dY = Input(0); + auto* dX = Output(0); const int axis = dY.canonical_axis_index(axis_); // reshape output to be input "untiled" along the axis std::vector X_dims = dY.sizes().vec(); CAFFE_ENFORCE_EQ(X_dims[axis] % tiles_, 0); X_dims[axis] /= tiles_; - auto* dX = Output(0, X_dims, at::dtype()); + dX->Resize(X_dims); // size up to (and not including) axis const int outer_size = dX->size_to_dim(axis); diff --git a/caffe2/operators/utility_ops.h b/caffe2/operators/utility_ops.h index af79f44..2b38d1b 100644 --- a/caffe2/operators/utility_ops.h +++ b/caffe2/operators/utility_ops.h @@ -235,9 +235,10 @@ class FlattenToVecOp : public Operator { bool RunOnDevice() override { auto& input = Input(0); + auto* output = Output(0); CAFFE_ENFORCE_GE( input.dim(), 1, "The rank of the tensor must be >= 1."); - auto* output = Output(0, {input.numel()}, at::dtype(input.dtype())); + output->Resize(input.numel()); context_.CopyItemsSameDevice( input.dtype(), @@ -258,8 +259,9 @@ class ResizeLikeOp 
: public Operator { bool RunOnDevice() override { auto& input0 = Input(0); auto& input1 = Input(1); + auto* output = Output(0); CAFFE_ENFORCE_EQ(input0.numel(), input1.numel()); - auto* output = Output(0, input1.sizes(), at::dtype(input0.dtype())); + output->ResizeLike(Input(1)); context_.CopyItemsSameDevice( input0.dtype(), input0.numel(), @@ -1048,6 +1050,8 @@ class GatherRangesOp : public Operator { bool DoRunWithType() { auto& data = Input(DATA); auto& ranges = Input(RANGES); + auto* outputData = Output(0); + auto* outputLengths = Output(1); auto batchSize = ranges.size(0); CAFFE_ENFORCE(data.dim() == 1, "Data has to be 1-D"); @@ -1059,7 +1063,7 @@ class GatherRangesOp : public Operator { auto* rawData = static_cast(data.raw_data()); auto* rangesData = ranges.template data(); - auto* outputLengths = Output(1, {batchSize}, at::dtype()); + outputLengths->Resize(batchSize); auto* outputLengthsPtr = outputLengths->template mutable_data(); size_t start = 0; size_t blockSize = ranges.size_from_dim(1); @@ -1070,8 +1074,7 @@ class GatherRangesOp : public Operator { } size_t outputSize = accumulate(rangesData, 0, ranges.numel()); - auto* outputData = - Output(0, {static_cast(outputSize)}, at::dtype(data.dtype())); + outputData->Resize(outputSize); auto outputRawData = static_cast(outputData->raw_mutable_data(data.dtype())); @@ -1127,6 +1130,7 @@ class LengthsGatherOp : public Operator { auto& items = Input(ITEMS); auto& lengths = Input(LENGTHS); auto& indices = Input(INDICES); + auto* output = Output(0); CAFFE_ENFORCE_GE(items.dim(), 1, "ITEMS should be at least 1-D"); CAFFE_ENFORCE_EQ(lengths.dim(), 1, "LENGTHS should be 1-D"); @@ -1143,7 +1147,7 @@ class LengthsGatherOp : public Operator { } auto shape = items.sizes().vec(); shape[0] = total_length; - auto* output = Output(0, {shape}, at::dtype(items.dtype())); + output->Resize(shape); offsets_.clear(); int64_t running_offset = 0; diff --git a/caffe2/quantization/server/fully_connected_dnnlowp_op.cc 
b/caffe2/quantization/server/fully_connected_dnnlowp_op.cc index f697429..5eee0a9 100644 --- a/caffe2/quantization/server/fully_connected_dnnlowp_op.cc +++ b/caffe2/quantization/server/fully_connected_dnnlowp_op.cc @@ -83,7 +83,8 @@ bool FullyConnectedDNNLowPOp::RunOnDevice() { } auto* Y_ref = fp32_op->Output(0); - auto* Y = OutputTensorCPU_(0, Y_ref->sizes(), at::dtype(Y_ref->dtype())); + auto* Y = OutputTensorCPU_(0); + Y->ResizeLike(*Y_ref); fp32_op->context_.CopyItemsSameDevice( Y_ref->dtype(), Y_ref->size(), diff --git a/caffe2/queue/rebatching_queue.cc b/caffe2/queue/rebatching_queue.cc index e4015b9..4cd54e0 100644 --- a/caffe2/queue/rebatching_queue.cc +++ b/caffe2/queue/rebatching_queue.cc @@ -84,8 +84,7 @@ std::vector> split( CAFFE_ENFORCE_EQ(input.sizes().at(0), outputSize); for (int i = 0; i < outputSize; ++i) { - outputs[i].push_back( - caffe2::empty(outputDims, at::dtype(input.dtype()).device(CPU))); + outputs[i].push_back(Tensor(outputDims, CPU)); context.CopyItemsToCPU( input.dtype(), innerSize, -- 2.7.4