Tensor reinitialization codemod - 2/5 (#15947)
author     Jerry Zhang <jerryzh@fb.com>
           Fri, 11 Jan 2019 22:55:56 +0000 (14:55 -0800)
committer  Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
           Fri, 11 Jan 2019 23:05:01 +0000 (15:05 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15947

Codemod generated with clangr shard mode, 25 files per diff.
To eliminate partially initialized Tensors, we split the initialization of
local Tensor variables into two steps: first declare an uninitialized Tensor,
then call `ReinitializeTensor` to initialize it.
Motivation: https://github.com/pytorch/pytorch/pull/12407
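
For illustration, a minimal sketch of the before/after pattern this codemod
applies (the operator class and the member `scratch_` are hypothetical; the
`ReinitializeTensor` helper and the option-building calls are the ones used
throughout the diffs below):

#include "caffe2/core/tensor.h"  // caffe2::Tensor, caffe2::ReinitializeTensor

namespace caffe2 {

// Before: the member is constructed with only a device, so it stays
// partially initialized (no dtype, no sizes) until the first
// Resize + mutable_data pair at run time.
class OpSketchBefore {
 public:
  void Run(int64_t n) {
    scratch_.Resize(n);                            // step 1: set sizes
    float* data = scratch_.mutable_data<float>();  // step 2: set dtype, allocate
    // ... fill data ...
  }

 private:
  Tensor scratch_{CUDA};  // device only; dtype and sizes still unknown
};

// After: declare the member uninitialized and fully initialize it in one
// call, supplying sizes, dtype, and device together.
class OpSketchAfter {
 public:
  void Run(int64_t n) {
    ReinitializeTensor(&scratch_, {n}, at::dtype<float>().device(CUDA));
    float* data = scratch_.mutable_data<float>();
    // ... fill data ...
  }

 private:
  Tensor scratch_;  // declared uninitialized
};

} // namespace caffe2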

Reviewed By: smessmer

Differential Revision: D13586732

fbshipit-source-id: 5295ab27ca0155f96a4fccf9c0ba8a609101ba24

13 files changed:
caffe2/image/image_input_op.h
caffe2/operators/boolean_mask_ops.cu
caffe2/operators/boolean_unmask_ops.cu
caffe2/operators/channel_backprop_stats_op.cu
caffe2/operators/channel_backprop_stats_op.h
caffe2/operators/channel_stats_op.cu
caffe2/operators/channel_stats_op.h
caffe2/operators/conv_op.h
caffe2/operators/conv_op_impl.h
caffe2/operators/conv_transpose_op.h
caffe2/operators/conv_transpose_op_impl.h
caffe2/operators/deform_conv_op.h
caffe2/operators/deform_conv_op_impl.h

diff --git a/caffe2/image/image_input_op.h b/caffe2/image/image_input_op.h
index f5db8b3..fef4b70 100644
@@ -89,8 +89,8 @@ class ImageInputOp final
 
   unique_ptr<db::DBReader> owned_reader_;
   const db::DBReader* reader_;
-  Tensor prefetched_image_{CPU};
-  Tensor prefetched_label_{CPU};
+  Tensor prefetched_image_;
+  Tensor prefetched_label_;
   vector<Tensor> prefetched_additional_outputs_;
   Tensor prefetched_image_on_device_;
   Tensor prefetched_label_on_device_;
@@ -120,8 +120,8 @@ class ImageInputOp final
   int crop_;
   std::vector<float> mean_;
   std::vector<float> std_;
-  Tensor mean_gpu_{Context::GetDeviceType()};
-  Tensor std_gpu_{Context::GetDeviceType()};
+  Tensor mean_gpu_;
+  Tensor std_gpu_;
   bool mirror_;
   bool is_test_;
   bool use_caffe_datum_;
@@ -377,16 +377,24 @@ ImageInputOp<Context>::ImageInputOp(
   for (int i = 0; i < num_decode_threads_; ++i) {
     randgen_per_thread_.emplace_back(meta_randgen());
   }
-  prefetched_image_.Resize(
-      int64_t(batch_size_),
-      int64_t(crop_),
-      int64_t(crop_),
-      int64_t(color_ ? 3 : 1));
+  ReinitializeTensor(
+      &prefetched_image_,
+      {int64_t(batch_size_),
+       int64_t(crop_),
+       int64_t(crop_),
+       int64_t(color_ ? 3 : 1)},
+      at::dtype<uint8_t>().device(CPU));
+  std::vector<int64_t> sizes;
   if (label_type_ != SINGLE_LABEL && label_type_ != SINGLE_LABEL_WEIGHTED) {
-    prefetched_label_.Resize(int64_t(batch_size_), int64_t(num_labels_));
+    sizes = std::vector<int64_t>{int64_t(batch_size_), int64_t(num_labels_)};
   } else {
-    prefetched_label_.Resize(vector<int64_t>(1, batch_size_));
+    sizes = std::vector<int64_t>{batch_size_};
   }
+  // data type for prefetched_label_ is actually not known here..
+  ReinitializeTensor(
+      &prefetched_label_,
+      sizes,
+      at::dtype<int>().device(CPU));
 
   for (int i = 0; i < additional_output_sizes_.size(); ++i) {
     prefetched_additional_outputs_on_device_.emplace_back();
@@ -1256,8 +1264,14 @@ bool ImageInputOp<Context>::CopyPrefetched() {
     // TODO: support color jitter and color lighting in gpu_transform
     if (gpu_transform_) {
       if (!mean_std_copied_) {
-        mean_gpu_.Resize(mean_.size());
-        std_gpu_.Resize(std_.size());
+        ReinitializeTensor(
+            &mean_gpu_,
+            {static_cast<int64_t>(mean_.size())},
+            at::dtype<float>().device(Context::GetDeviceType()));
+        ReinitializeTensor(
+            &std_gpu_,
+            {static_cast<int64_t>(std_.size())},
+            at::dtype<float>().device(Context::GetDeviceType()));
 
         context_.template CopyFromCPU<float>(
             mean_.size(),
diff --git a/caffe2/operators/boolean_mask_ops.cu b/caffe2/operators/boolean_mask_ops.cu
index ae5758a..d953e8f 100644
@@ -39,7 +39,8 @@ class BooleanMaskOp<CUDAContext> final : public Operator<CUDAContext> {
 
     const auto* maskData = mask.data<bool>();
     const auto outerSize = mask.size(0);
-    indices_.Resize(outerSize);
+    ReinitializeTensor(
+        &indices_, {outerSize}, at::dtype<int64_t>().device(CUDA));
     auto* indicesData = indices_.mutable_data<int64_t>();
 
     size_t numBytes = 0;
@@ -57,7 +58,8 @@ class BooleanMaskOp<CUDAContext> final : public Operator<CUDAContext> {
     auto numint64_t =
         static_cast<int64_t>((numBytes + sizeof(int64_t) - 1) / sizeof(int64_t));
     // allocate one more int64_t at the end of scratch for storing numOfOutput
-    scratch_.Resize(numint64_t + 1);
+    ReinitializeTensor(
+        &scratch_, {numint64_t + 1}, at::dtype<int64_t>().device(CUDA));
     auto* scratchData = scratch_.mutable_data<int64_t>();
     auto* numOfOutputData = scratchData + numint64_t;
 
@@ -108,8 +110,8 @@ class BooleanMaskOp<CUDAContext> final : public Operator<CUDAContext> {
   }
 
  private:
-  Tensor indices_{CUDA};
-  Tensor scratch_{CUDA};
+  Tensor indices_;
+  Tensor scratch_;
 };
 
 REGISTER_CUDA_OPERATOR(BooleanMask, BooleanMaskOp<CUDAContext>);
diff --git a/caffe2/operators/boolean_unmask_ops.cu b/caffe2/operators/boolean_unmask_ops.cu
index 2dfc4a1..21d3275 100644
@@ -60,11 +60,13 @@ class BooleanUnmaskOp<CUDAContext> final : public Operator<CUDAContext> {
     out->Resize(maskSize);
     auto* dest = (char*)out->raw_mutable_data(meta);
 
-    hostMasks_.Resize(numMasks);
+    ReinitializeTensor(&hostMasks_, {numMasks}, at::dtype<bool*>().device(CPU));
     auto* hostMasksData = hostMasks_.mutable_data<bool*>();
-    hostValues_.Resize(numMasks);
+    ReinitializeTensor(
+        &hostValues_, {numMasks}, at::dtype<char*>().device(CPU));
     auto* hostValuesData = hostValues_.mutable_data<char*>();
-    hostValueSizes_.Resize(numMasks);
+    ReinitializeTensor(
+        &hostValueSizes_, {numMasks}, at::dtype<int>().device(CPU));
     auto* hostValueSizesData = hostValueSizes_.mutable_data<int>();
     for (int i = 0; i < numMasks; ++i) {
       auto& mask = Input(i * 2);
@@ -81,7 +83,7 @@ class BooleanUnmaskOp<CUDAContext> final : public Operator<CUDAContext> {
     values_.CopyFrom(hostValues_);
     valueSizes_.CopyFrom(hostValueSizes_);
 
-    indices_.Resize(maskSize);
+    ReinitializeTensor(&indices_, {maskSize}, at::dtype<int>().device(CUDA));
     auto* indicesData = indices_.mutable_data<int>();
 
     ComputeIndicesKernel<<<
@@ -109,14 +111,14 @@ class BooleanUnmaskOp<CUDAContext> final : public Operator<CUDAContext> {
   }
 
  private:
-  Tensor indices_{CUDA};
+  Tensor indices_;
   Tensor masks_{CUDA};
   Tensor values_{CUDA};
   Tensor valueSizes_{CUDA};
 
-  Tensor hostMasks_{CPU};
-  Tensor hostValues_{CPU};
-  Tensor hostValueSizes_{CPU};
+  Tensor hostMasks_;
+  Tensor hostValues_;
+  Tensor hostValueSizes_;
 };
 
 REGISTER_CUDA_OPERATOR(BooleanUnmask, BooleanUnmaskOp<CUDAContext>);
diff --git a/caffe2/operators/channel_backprop_stats_op.cu b/caffe2/operators/channel_backprop_stats_op.cu
index e5fc470..1dc2a64 100644
@@ -161,9 +161,6 @@ bool ChannelBackpropStatsOp<CUDAContext>::RunOnDevice() {
   const int W = X.ndim() > 3 ? X.dim32(3) : 1;
   const int D = X.ndim() > 4 ? X.dim32(4) : 1;
 
-  
-  
-
   const auto Xarr = X.data<float>();
   const auto dYarr = dY.data<float>();
   const auto meanArr = mean.data<float>();
@@ -177,8 +174,10 @@ bool ChannelBackpropStatsOp<CUDAContext>::RunOnDevice() {
   const auto numBlocksPerChannel = CAFFE_GET_BLOCKS(valsPerChannel);
   const auto numBlocksTotal = numBlocksPerChannel * N * C;
 
-  dBiasScratch_.Resize(numBlocksTotal);
-  dScaleScratch_.Resize(numBlocksTotal);
+  ReinitializeTensor(
+      &dBiasScratch_, {numBlocksTotal}, at::dtype<float>().device(CUDA));
+  ReinitializeTensor(
+      &dScaleScratch_, {numBlocksTotal}, at::dtype<float>().device(CUDA));
 
   ChannelBackpropStatsBlockKernel<CAFFE_CUDA_NUM_THREADS>
       <<<numBlocksTotal, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
diff --git a/caffe2/operators/channel_backprop_stats_op.h b/caffe2/operators/channel_backprop_stats_op.h
index ce0e089..05a13bd 100644
@@ -23,8 +23,8 @@ class ChannelBackpropStatsOp : public Operator<Context> {
   INPUT_TAGS(INPUT, SAVED_MEAN, SAVED_INV_STDDEV, OUTPUT_GRAD);
   OUTPUT_TAGS(SCALE_GRAD, BIAS_GRAD);
 
-  Tensor dBiasScratch_{Context::GetDeviceType()};
-  Tensor dScaleScratch_{Context::GetDeviceType()};
+  Tensor dBiasScratch_;
+  Tensor dScaleScratch_;
 };
 
 } // namespace caffe2
diff --git a/caffe2/operators/channel_stats_op.cu b/caffe2/operators/channel_stats_op.cu
index 8e473b3..b4e8772 100644
@@ -154,17 +154,16 @@ bool ChannelStatsOp<CUDAContext>::RunOnDevice() {
   const int W = X.ndim() > 3 ? X.dim32(3) : 1;
   const int D = X.ndim() > 4 ? X.dim32(4) : 1;
 
-  
-  
-
   const auto X_arr = X.data<float>();
   const auto valsPerChannel = H * W * D;
 
   const auto numBlocksPerChannel = CAFFE_GET_BLOCKS(valsPerChannel);
   const auto numBlocksTotal = numBlocksPerChannel * N * C;
 
-  sumScratch_.Resize(numBlocksTotal);
-  sumsqScratch_.Resize(numBlocksTotal);
+  ReinitializeTensor(
+      &sumScratch_, {numBlocksTotal}, at::dtype<float>().device(CUDA));
+  ReinitializeTensor(
+      &sumsqScratch_, {numBlocksTotal}, at::dtype<float>().device(CUDA));
 
   auto sum = Output(SUM, {C}, at::dtype<float>());
   auto sumsq = Output(SUMSQ, {C}, at::dtype<float>());
diff --git a/caffe2/operators/channel_stats_op.h b/caffe2/operators/channel_stats_op.h
index 0ccb885..532938c 100644
@@ -23,8 +23,8 @@ class ChannelStatsOp : public Operator<Context> {
   INPUT_TAGS(INPUT);
   OUTPUT_TAGS(SUM, SUMSQ);
 
-  Tensor sumScratch_{Context::GetDeviceType()};
-  Tensor sumsqScratch_{Context::GetDeviceType()};
+  Tensor sumScratch_;
+  Tensor sumsqScratch_;
 };
 
 } // namespace caffe2
diff --git a/caffe2/operators/conv_op.h b/caffe2/operators/conv_op.h
index d760642..f4ee1b2 100644
@@ -85,8 +85,8 @@ class ConvGradientOp final : public ConvPoolOpBase<Context> {
   bool RunOnDeviceWithOrderNHWC() override;
 
  private:
-  Tensor col_buffer_{Context::GetDeviceType()};
-  Tensor bias_multiplier_{Context::GetDeviceType()};
+  Tensor col_buffer_;
+  Tensor bias_multiplier_;
   Tensor img_shape_device_{Context::GetDeviceType()};
   Tensor col_buffer_shape_device_{Context::GetDeviceType()};
   bool no_bias_;
diff --git a/caffe2/operators/conv_op_impl.h b/caffe2/operators/conv_op_impl.h
index 14baf15..29d98c0 100644
@@ -521,9 +521,18 @@ bool ConvGradientOp<T, Context>::RunOnDeviceWithOrderNCHW() {
   col_buffer_shape.push_back(C / group_ * kernel_dims_size);
   col_buffer_shape.insert(
       col_buffer_shape.end(), output_dims.begin(), output_dims.end());
-  col_buffer_.Resize(col_buffer_shape);
+  vector<int64_t> col_buffer_shape_64;
+  std::copy(
+      col_buffer_shape.cbegin(),
+      col_buffer_shape.cend(),
+      std::back_inserter(col_buffer_shape_64));
+  ReinitializeTensor(
+      &col_buffer_,
+      col_buffer_shape_64,
+      at::dtype<T>().device(Context::GetDeviceType()));
 
   if (kernel_.size() != 2) {
+    // TODO: SetDeviceTensor accept vector<int64_t>
     SetDeviceTensor(img_shape, &img_shape_device_);
     SetDeviceTensor(col_buffer_shape, &col_buffer_shape_device_);
   }
@@ -542,15 +551,16 @@ bool ConvGradientOp<T, Context>::RunOnDeviceWithOrderNCHW() {
   T* dbias_data = nullptr;
   if (!no_bias_) {
     auto* dbias = Output(BIAS_OR_INPUT_GRAD, {M}, at::dtype<T>());
-    if (bias_multiplier_.numel() != output_image_size) {
-      // If the helper bias multiplier is not M, reshape and fill it with one.
-      bias_multiplier_.Resize(vector<int64_t>(1, output_image_size));
-      math::Set<T, Context>(
-          output_image_size,
-          static_cast<T>(1),
-          bias_multiplier_.template mutable_data<T>(),
-          &context_);
-    }
+    // Removed the check for whether bias_multiplier_ has correct size or not
+    ReinitializeTensor(
+        &bias_multiplier_,
+        vector<int64_t>(1, output_image_size),
+        at::dtype<T>().device(Context::GetDeviceType()));
+    math::Set<T, Context>(
+        output_image_size,
+        static_cast<T>(1),
+        bias_multiplier_.template mutable_data<T>(),
+        &context_);
     dbias_data = dbias->template mutable_data<T>();
     math::Set<T, Context>(dbias->numel(), 0, dbias_data, &context_);
   }
@@ -726,7 +736,15 @@ bool ConvGradientOp<T, Context>::RunOnDeviceWithOrderNHWC() {
   vector<int> col_buffer_shape(output_dims.size() + 1);
   std::copy(output_dims.cbegin(), output_dims.cend(), col_buffer_shape.begin());
   col_buffer_shape.back() = C * kernel_dims_size;
-  col_buffer_.Resize(col_buffer_shape);
+  vector<int64_t> col_buffer_shape_64;
+  std::copy(
+      col_buffer_shape.cbegin(),
+      col_buffer_shape.cend(),
+      std::back_inserter(col_buffer_shape_64));
+  ReinitializeTensor(
+      &col_buffer_,
+      col_buffer_shape_64,
+      at::dtype<T>().device(Context::GetDeviceType()));
 
   if (kernel_.size() != 2) {
     SetDeviceTensor(img_shape, &img_shape_device_);
@@ -748,15 +766,16 @@ bool ConvGradientOp<T, Context>::RunOnDeviceWithOrderNHWC() {
     auto* dbias = Output(BIAS_OR_INPUT_GRAD, {M}, at::dtype<T>());
     dbias_data = dbias->template mutable_data<T>();
     math::Set<T, Context>(dbias->numel(), 0, dbias_data, &context_);
-    if (bias_multiplier_.numel() != output_image_size) {
-      // If the helper bias multiplier is not M, reshape and fill it with one.
-      bias_multiplier_.Resize(vector<int64_t>(1, output_image_size));
-      math::Set<T, Context>(
-          output_image_size,
-          static_cast<T>(1),
-          bias_multiplier_.template mutable_data<T>(),
-          &context_);
-    }
+    // Removed the check for whether bias_multiplier_ has correct size or not
+    ReinitializeTensor(
+        &bias_multiplier_,
+        vector<int64_t>(1, output_image_size),
+        at::dtype<T>().device(Context::GetDeviceType()));
+    math::Set<T, Context>(
+        output_image_size,
+        static_cast<T>(1),
+        bias_multiplier_.template mutable_data<T>(),
+        &context_);
   }
 
   for (int image_id = 0; image_id < N; ++image_id) {
diff --git a/caffe2/operators/conv_transpose_op.h b/caffe2/operators/conv_transpose_op.h
index b8d875d..b743b30 100644
@@ -18,8 +18,8 @@ class ConvTransposeOp final : public ConvTransposeUnpoolBase<Context> {
   bool RunOnDeviceWithOrderNHWC() override;
 
  private:
-  Tensor col_buffer_{Context::GetDeviceType()};
-  Tensor bias_multiplier_{Context::GetDeviceType()};
+  Tensor col_buffer_;
+  Tensor bias_multiplier_;
   // Input: X, W, b
   // Output: Y
   INPUT_TAGS(INPUT, FILTER, BIAS);
@@ -41,8 +41,8 @@ class ConvTransposeGradientOp final : public ConvTransposeUnpoolBase<Context> {
   bool RunOnDeviceWithOrderNHWC() override;
 
  private:
-  Tensor col_buffer_{Context::GetDeviceType()};
-  Tensor bias_multiplier_{Context::GetDeviceType()};
+  Tensor col_buffer_;
+  Tensor bias_multiplier_;
   const bool no_bias_;
   // input: X, W, dY
   // output: dW, optionally db and dX
diff --git a/caffe2/operators/conv_transpose_op_impl.h b/caffe2/operators/conv_transpose_op_impl.h
index 993bfc9..a073e7b 100644
@@ -44,15 +44,16 @@ bool ConvTransposeOp<T, Context>::RunOnDeviceWithOrderNCHW() {
     CAFFE_ENFORCE(
         bias.dim32(0) == C,
         "bias dimension must be equal to output channel number");
-    if (bias_multiplier_.numel() != output_image_size) {
-      bias_multiplier_.Resize(vector<int64_t>(1, output_image_size));
+    ReinitializeTensor(
+        &bias_multiplier_,
+        {1, output_image_size},
+        at::dtype<T>().device(Context::GetDeviceType()));
       T* bm_data = bias_multiplier_.template mutable_data<T>();
       math::Set<T, Context>(
           output_image_size,
           static_cast<T>(1),
           bm_data,
           &context_);
-    }
   }
 
   const T* Xdata = X.template data<T>();
@@ -60,8 +61,7 @@ bool ConvTransposeOp<T, Context>::RunOnDeviceWithOrderNCHW() {
   T* Ydata = Y->template mutable_data<T>();
 
   auto f = [&](Tensor* col_buffer) {
-    col_buffer->Resize(
-        vector<int64_t>{C, this->kernel_h(), this->kernel_w(), H, W});
+    ReinitializeTensor(col_buffer, vector<int64_t>{C, this->kernel_h(), this->kernel_w(), H, W}, at::dtype<T>().device(Context::GetDeviceType()));
     T* col_buffer_data = col_buffer->template mutable_data<T>();
     for (auto image_id = 0; image_id < N; ++image_id) {
       // Weight term
@@ -166,23 +166,27 @@ bool ConvTransposeOp<T, Context>::RunOnDeviceWithOrderNHWC() {
     CAFFE_ENFORCE(
         bias.dim32(0) == C,
         "bias dimension must be equal to output channel number");
-    if (bias_multiplier_.numel() != output_image_size) {
-      bias_multiplier_.Resize(vector<int64_t>(1, output_image_size));
+    // TODO(jerryzh): is it OK to remove the check of whether numel is output_image_size
+    ReinitializeTensor(
+        &bias_multiplier_,
+        {1, output_image_size},
+        at::dtype<T>().device(Context::GetDeviceType()));
       T* bm_data = bias_multiplier_.template mutable_data<T>();
       math::Set<T, Context>(
           output_image_size,
           static_cast<T>(1),
           bm_data,
           &context_);
-    }
   }
   const T* Xdata = X.template data<T>();
   const T* filter_data = filter.template data<T>();
   T* Ydata = Y->template mutable_data<T>();
 
   auto f = [&](Tensor* /*col_buffer*/) {
-    col_buffer_.Resize(
-        vector<int64_t>{H, W, this->kernel_h(), this->kernel_w(), C});
+    ReinitializeTensor(
+        &col_buffer_,
+        vector<int64_t>{H, W, this->kernel_h(), this->kernel_w(), C},
+        at::dtype<T>().device(Context::GetDeviceType()));
     T* col_buffer_data = col_buffer_.template mutable_data<T>();
     for (auto image_id = 0; image_id < N; ++image_id) {
       // Weight term
@@ -269,20 +273,24 @@ bool ConvTransposeGradientOp<T, Context>::RunOnDeviceWithOrderNCHW() {
   const int kernel_dim = C * this->kernel_h() * this->kernel_w();
   const int output_image_size = dY.dim32(2) * dY.dim32(3);
   // The col buffer is stored in CHW order as well
-  col_buffer_.Resize(
-      vector<int64_t>{C, this->kernel_h(), this->kernel_w(), H, W});
+  ReinitializeTensor(
+      &col_buffer_,
+      vector<int64_t>{C, this->kernel_h(), this->kernel_w(), H, W},
+      at::dtype<T>().device(Context::GetDeviceType()));
   if (!no_bias_) {
     auto* dbias = Output(BIAS_OR_INPUT_GRAD);
     dbias->Resize(C);
-    if (bias_multiplier_.numel() != output_image_size) {
-      bias_multiplier_.Resize(1, output_image_size);
-      T* bm_data = bias_multiplier_.template mutable_data<T>();
-      math::Set<T, Context>(
-          output_image_size,
-          static_cast<T>(1),
-          bm_data,
-          &context_);
-    }
+    // TODO(jerryzh): is it OK to remove the check of whether numel is output_image_size
+    ReinitializeTensor(
+        &bias_multiplier_,
+        {1, output_image_size},
+        at::dtype<T>().device(Context::GetDeviceType()));
+    T* bm_data = bias_multiplier_.template mutable_data<T>();
+    math::Set<T, Context>(
+        output_image_size,
+        static_cast<T>(1),
+        bm_data,
+        &context_);
   }
   T* col_buffer_data = col_buffer_.template mutable_data<T>();
   const T* Xdata = X.template data<T>();
@@ -422,20 +430,24 @@ bool ConvTransposeGradientOp<T, Context>::RunOnDeviceWithOrderNHWC() {
   const int kernel_dim = C * this->kernel_h() * this->kernel_w();
   const int output_image_size = dY.dim32(1) * dY.dim32(2);
   // The col buffer is stored in HWC order as well
-  col_buffer_.Resize(
-      vector<int64_t>{H, W, this->kernel_h(), this->kernel_w(), C});
+  ReinitializeTensor(
+      &col_buffer_,
+      vector<int64_t>{H, W, this->kernel_h(), this->kernel_w(), C},
+      at::dtype<T>().device(Context::GetDeviceType()));
   if (!no_bias_) {
     auto* dbias = Output(BIAS_OR_INPUT_GRAD);
     dbias->Resize(C);
-    if (bias_multiplier_.numel() != output_image_size) {
-      bias_multiplier_.Resize(1, output_image_size);
-      T* bm_data = bias_multiplier_.template mutable_data<T>();
-      math::Set<T, Context>(
-          output_image_size,
-          static_cast<T>(1),
-          bm_data,
-          &context_);
-    }
+    // TODO(jerryzh): is it OK to remove the check of whether numel is output_image_size
+    ReinitializeTensor(
+        &bias_multiplier_,
+        {1, output_image_size},
+        at::dtype<T>().device(Context::GetDeviceType()));
+    T* bm_data = bias_multiplier_.template mutable_data<T>();
+    math::Set<T, Context>(
+        output_image_size,
+        static_cast<T>(1),
+        bm_data,
+        &context_);
   }
   T* col_buffer_data = col_buffer_.template mutable_data<T>();
   const T* Xdata = X.template data<T>();
diff --git a/caffe2/operators/deform_conv_op.h b/caffe2/operators/deform_conv_op.h
index e0f2ade..2a21a0d 100644
@@ -71,7 +71,7 @@ class DeformConvOp final : public DeformConvOpBase<T, Context> {
 
  private:
   Tensor col_buffer_{Context::GetDeviceType()};
-  Tensor bias_multiplier_{Context::GetDeviceType()};
+  Tensor bias_multiplier_;
   Tensor img_shape_device_{Context::GetDeviceType()};
   Tensor col_buffer_shape_device_{Context::GetDeviceType()};
   // Input: X, o, W, b
@@ -96,8 +96,8 @@ class DeformConvGradientOp final : public DeformConvOpBase<T, Context> {
   bool RunOnDeviceWithOrderNCHW() override;
 
  private:
-  Tensor col_buffer_{Context::GetDeviceType()};
-  Tensor bias_multiplier_{Context::GetDeviceType()};
+  Tensor col_buffer_;
+  Tensor bias_multiplier_;
   Tensor img_shape_device_{Context::GetDeviceType()};
   Tensor col_buffer_shape_device_{Context::GetDeviceType()};
   bool no_bias_;
diff --git a/caffe2/operators/deform_conv_op_impl.h b/caffe2/operators/deform_conv_op_impl.h
index 4e333ed..94dea27 100644
@@ -119,7 +119,10 @@ bool DeformConvOp<T, Context>::RunOnDeviceWithOrderNCHW() {
       // If the helper bias multiplier is not image size, reshape and fill it
       // with
       // one.
-      bias_multiplier_.Resize(vector<int64_t>(1, output_image_size));
+      ReinitializeTensor(
+          &bias_multiplier_,
+          vector<int64_t>(1, output_image_size),
+          at::dtype<T>().device(Context::GetDeviceType()));
       math::Set<T, Context>(
           output_image_size,
           static_cast<T>(1),
@@ -280,7 +283,10 @@ bool DeformConvGradientOp<T, Context>::RunOnDeviceWithOrderNCHW() {
   col_buffer_shape.push_back(C * kernel_dims_size);
   col_buffer_shape.insert(
       col_buffer_shape.end(), output_dims.begin(), output_dims.end());
-  col_buffer_.Resize(col_buffer_shape);
+  ReinitializeTensor(
+      &col_buffer_,
+      col_buffer_shape,
+      at::dtype<T>().device(Context::GetDeviceType()));
 
   const int col_buffer_offset = col_buffer_.size() / group_;
 
@@ -301,7 +307,10 @@ bool DeformConvGradientOp<T, Context>::RunOnDeviceWithOrderNCHW() {
     auto* dbias = Output(BIAS_OR_INPUT_GRAD, {M}, at::dtype<T>());
     if (bias_multiplier_.size() != output_image_size) {
       // If the helper bias multiplier is not M, reshape and fill it with one.
-      bias_multiplier_.Resize(vector<int64_t>(1, output_image_size));
+      ReinitializeTensor(
+          &bias_multiplier_,
+          vector<int64_t>(1, output_image_size),
+          at::dtype<T>().device(Context::GetDeviceType()));
       math::Set<T, Context>(
           output_image_size,
           static_cast<T>(1),