Remove Context dependency from Tensor class (#14269)
authorDmytro Dzhulgakov <dzhulgakov@fb.com>
Wed, 28 Nov 2018 23:43:22 +0000 (15:43 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 28 Nov 2018 23:45:38 +0000 (15:45 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14269

Removes reference to Context proper and instead adds a bool argument for async copy (the same as `copy_`)

For CopyFrom - I haven't tweaked all callsites yet. Instead I rely on a terrible hack that pointer to context is implicitly converted to bool when passed, haha :) It's not a good code and I propose to fix it in a follow up diff (maybe using clangr tooling).

Reviewed By: ezyang

Differential Revision: D13117981

fbshipit-source-id: 7cb1dc2ba6a4c50ac26614f45ab8318ea96e3138

22 files changed:
aten/src/ATen/core/TensorImpl.h
caffe2/core/tensor.cc
caffe2/core/tensor.h
caffe2/experiments/operators/tt_pad_op.h
caffe2/mobile/contrib/ios/mpscnn/mpscnn.mm
caffe2/operators/box_with_nms_limit_op.cc
caffe2/operators/dataset_ops.cc
caffe2/operators/expand_squeeze_dims_op.h
caffe2/operators/generate_proposals_op.cc
caffe2/operators/last_n_window_collector.cc
caffe2/operators/mean_op.h
caffe2/operators/onnx_while_op.h
caffe2/operators/reservoir_sampling.cc
caffe2/operators/rmac_regions_op.cc
caffe2/operators/sequence_ops.h
caffe2/operators/slice_op.cu
caffe2/operators/slice_op.h
caffe2/operators/stop_gradient.h
caffe2/operators/utility_ops.cu
caffe2/operators/utility_ops.h
caffe2/queue/queue_ops.h
caffe2/video/video_input_op.h

index 59edcaf..a4ebc36 100644 (file)
@@ -917,9 +917,9 @@ struct CAFFE2_API TensorImpl : public c10::intrusive_ptr_target {
    * a tensor on CPU and then CopyFrom a CUDA tensor, that will to a
    * CUDA-to-CPU transfer).
    *
-   * If the function is invoked without `context` the copy would be synchronous
+   * 'async' parameter triggers async copy for CUDA tensors
    */
-  void CopyFrom(const TensorImpl& src, at::BaseContext* context = nullptr) {
+  void CopyFrom(const TensorImpl& src, bool async = false) {
     AT_ASSERT(!is_variable());
     AT_ASSERTM(
         src.is_contiguous(),
@@ -978,7 +978,7 @@ struct CAFFE2_API TensorImpl : public c10::intrusive_ptr_target {
             src.device(),
             new_data,
             device(),
-            context != nullptr);
+            async);
       }
     }
   }
@@ -991,8 +991,10 @@ struct CAFFE2_API TensorImpl : public c10::intrusive_ptr_target {
    * elements, in which case this tensors' capacity is grown at a factor of
    * growthPct. This ensures that Extend runs on an amortized O(1) time
    * complexity.
+   *
+   * This op is auto-asynchronous if the underlying device (CUDA) supports it.
    */
-  void Extend(int64_t num, float growthPct, at::BaseContext* context) {
+  void Extend(int64_t num, float growthPct) {
     AT_ASSERT(sizes_.size() >= 1u);
     AT_ASSERTM(num >= 0, "`num` must be non-negative for Extend");
     AT_ASSERTM(
@@ -1022,8 +1024,6 @@ struct CAFFE2_API TensorImpl : public c10::intrusive_ptr_target {
     auto oldDims = sizes_;
     Resize(newCapacity);
     auto* newData = raw_mutable_data(data_type_);
-    AT_ASSERTM(
-        context != nullptr, "Context must be provided to Extend the tensor");
     if (data_type_.copy()) {
       AT_ASSERTM(
           device_type() == ::at::DeviceType::CPU,
index 8030ffe..79e751c 100644 (file)
@@ -159,7 +159,7 @@ void ReinitializeAndCopyFrom(
     Tensor* t,
     at::TensorOptions options,
     const Tensor& src,
-    BaseContext* context) {
+    bool async) {
   auto device_type = options.device().type();
   CAFFE_ENFORCE(t != nullptr, "Target tensor ptr is null.");
   if (!*t || device_type != t->GetDeviceType()) {
@@ -172,7 +172,7 @@ void ReinitializeAndCopyFrom(
       t->dtype(),
       " to: ",
       src.dtype());
-  t->CopyFrom(src, context);
+  t->CopyFrom(src, async);
 }
 
 namespace {
index 50ee348..4015422 100644 (file)
@@ -97,23 +97,22 @@ class CAFFE2_API Tensor final {
     return impl_.get()->GetDevice();
   }
 
-  void CopyFrom(const Tensor& src, BaseContext* context = nullptr) const {
-    impl_.get()->CopyFrom(*src.impl_.get(), context);
+  void CopyFrom(const Tensor& src, bool async = false) const {
+    impl_.get()->CopyFrom(*src.impl_.get(), async);
   }
 
   /**
    * @brief Extend the outer-most dimension of this tensor
    *        to dimension of `num`.
    */
-  void ExtendTo(int64_t num, float growthPct, BaseContext* context) const {
+  void ExtendTo(int64_t num, float growthPct) const {
     CAFFE_ENFORCE_GE_WITH_CALLER(impl_->dim(), 1);
     CAFFE_ENFORCE_GE_WITH_CALLER(growthPct, 0);
-    CAFFE_ENFORCE(context != nullptr, "Context must be provided.");
-    Extend(num - impl_->size(0), growthPct, context);
+    Extend(num - impl_->size(0), growthPct);
   }
 
-  void Extend(int64_t num, float growthPct, BaseContext* context) const {
-    impl_.get()->Extend(num, growthPct, context);
+  void Extend(int64_t num, float growthPct) const {
+    impl_.get()->Extend(num, growthPct);
   }
 
   /**
@@ -451,7 +450,7 @@ CAFFE2_API void ReinitializeAndCopyFrom(
     Tensor* t,
     at::TensorOptions options,
     const Tensor& src,
-    BaseContext* context = nullptr);
+    bool async = false);
 
 CAFFE_DECLARE_PREALLOCATED_KNOWN_TYPE(12, Tensor)
 
index 57e0d4e..e25159d 100644 (file)
@@ -52,7 +52,7 @@ class TTPadOp final : public Operator<Context> {
       int64_t padded_dim0 = (X_dim0 / scale_ + 1) * scale_;
       auto dim0_diff = padded_dim0 - X_dim0;
       // set growthPct to the upper bound percentage: (100 * scale_ / X_dim0)
-      X_pad->Extend(dim0_diff, 100 * scale_ / X_dim0, &context_);
+      X_pad->Extend(dim0_diff, 100 * scale_ / X_dim0);
 
       auto* X_pad_data = X_pad->template mutable_data<T>();
       int64_t X_size = X_dim0 * X_dim1;
index 4e8cb62..aa6a547 100644 (file)
@@ -2302,8 +2302,8 @@ class MPSCNNGenerateProposalsCPPOp final : public Operator<CPUContext> {
       int csz = im_i_boxes.rows();
       int cur_start_idx = out_rois->dim(0);
 
-      out_rois->Extend(csz, 50, &context_);
-      out_rois_probs->Extend(csz, 50, &context_);
+      out_rois->Extend(csz, 50);
+      out_rois_probs->Extend(csz, 50);
 
       // write rois
       Eigen::Map<ERArrXXf> cur_rois(
index 2b83e19..18646b4 100644 (file)
@@ -167,9 +167,9 @@ bool BoxWithNMSLimitOp<CPUContext>::RunOnDevice() {
 
     // Write results
     int cur_start_idx = out_scores->size(0);
-    out_scores->Extend(total_keep_count, 50, &context_);
-    out_boxes->Extend(total_keep_count, 50, &context_);
-    out_classes->Extend(total_keep_count, 50, &context_);
+    out_scores->Extend(total_keep_count, 50);
+    out_boxes->Extend(total_keep_count, 50);
+    out_classes->Extend(total_keep_count, 50);
 
     int cur_out_idx = 0;
     for (int j = 1; j < num_classes; j++) {
@@ -202,7 +202,7 @@ bool BoxWithNMSLimitOp<CPUContext>::RunOnDevice() {
     }
 
     if (out_keeps) {
-      out_keeps->Extend(total_keep_count, 50, &context_);
+      out_keeps->Extend(total_keep_count, 50);
 
       Eigen::Map<EArrXi> out_keeps_arr(
           out_keeps->template mutable_data<int>() + cur_start_idx,
index 3a074de..4a8efc4 100644 (file)
@@ -776,7 +776,7 @@ class AppendOp final : public Operator<Context> {
       CAFFE_ENFORCE(a.sizes()[i] == b.sizes()[i]);
     }
     auto oldSize = c->numel();
-    c->Extend(b.sizes()[0], kDatasetGrowthPct, &context_);
+    c->Extend(b.sizes()[0], kDatasetGrowthPct);
     auto* dst = (char*)c->raw_mutable_data() + oldSize * b.dtype().itemsize();
     context_.CopyItemsSameDevice(b.dtype(), b.numel(), b.raw_data(), dst);
     return true;
@@ -826,7 +826,7 @@ class AtomicAppendOp final : public Operator<Context> {
         continue;
       }
       auto oldSize = c->numel();
-      c->Extend(b.sizes()[0], kDatasetGrowthPct, &context_);
+      c->Extend(b.sizes()[0], kDatasetGrowthPct);
       auto* dst = (char*)c->raw_mutable_data() + oldSize * b.dtype().itemsize();
       context_.CopyItemsSameDevice(b.dtype(), b.numel(), b.raw_data(), dst);
     }
index f9c8798..89ff9c0 100644 (file)
@@ -26,7 +26,7 @@ class ExpandDimsOp : public Operator<Context> {
   bool RunOnDevice() override {
     auto& input = Input(0);
     auto* output = Output(0);
-    output->CopyFrom(input, &context_);
+    output->CopyFrom(input, true /*async*/);
     if (dims_.empty()) {
       return true;
     }
@@ -70,7 +70,7 @@ class SqueezeOp : public Operator<Context> {
   bool RunOnDevice() override {
     auto& input = Input(0);
     auto* output = Output(0);
-    output->CopyFrom(input, &context_);
+    output->CopyFrom(input, true /*async*/);
 
     CAFFE_ENFORCE_GT(
         input.dim(),
index 0646f27..ade6bfe 100644 (file)
@@ -284,8 +284,8 @@ bool GenerateProposalsOp<CPUContext>::RunOnDevice() {
   for (int i = 0; i < num_images; i++) {
     roi_counts += im_boxes[i].rows();
   }
-  out_rois->Extend(roi_counts, 50, &context_);
-  out_rois_probs->Extend(roi_counts, 50, &context_);
+  out_rois->Extend(roi_counts, 50);
+  out_rois_probs->Extend(roi_counts, 50);
   float* out_rois_ptr = out_rois->template mutable_data<float>();
   float* out_rois_probs_ptr = out_rois_probs->template mutable_data<float>();
   for (int i = 0; i < num_images; i++) {
index b98d028..2b0695f 100644 (file)
@@ -71,7 +71,7 @@ class LastNWindowCollectorOp : public Operator<Context> {
     if (num_entries == 0) {
       if (!output_initialized) {
         // Get both shape and meta
-        output->CopyFrom(input, &context_);
+        output->CopyFrom(input, true /*async*/);
       }
       return true;
     }
@@ -83,7 +83,7 @@ class LastNWindowCollectorOp : public Operator<Context> {
 
     // output_num is >= output_batch_size
     if (output_num > output_batch_size) {
-      output->ExtendTo(output_num, 50, &context_);
+      output->ExtendTo(output_num, 50);
     }
 
     auto* output_data =
index 413a0f3..0a5d072 100644 (file)
@@ -23,7 +23,7 @@ class MeanOp final : public Operator<Context> {
     auto* output = Output(0);
 
     output->ResizeLike(input0);
-    output->CopyFrom(input0, &context_);
+    output->CopyFrom(input0, true /*async*/);
 
     if (InputSize() == 1) {
       return true;
@@ -102,7 +102,7 @@ class MeanGradientOp : public Operator<Context> {
     for (int i = 1; i < num_inputs; i++) {
       auto* cur_dX = Output(i);
       cur_dX->ResizeLike(dY);
-      cur_dX->CopyFrom(*dX0, &context_);
+      cur_dX->CopyFrom(*dX0, true /*async*/);
     }
 
     return true;
index 4614b57..eeb45bb 100644 (file)
@@ -171,7 +171,7 @@ class ONNXWhileOp final : public Operator<Context> {
                 scan_outputs_sizes[i],
                 "Size of scan output changed across iterations");
             dims.insert(dims.begin(), itr);
-            scan_output_target->Extend(1, 100, &context_);
+            scan_output_target->Extend(1, 100);
 
             int64_t timestep_size = 1;
             for (const int64_t t : scan_outputs_sizes[i]) {
index 287a77d..285dbba 100644 (file)
@@ -103,9 +103,9 @@ class ReservoirSamplingOp final : public Operator<Context> {
     auto output_num =
         std::min<size_t>(numToCollect_, output_batch_size + num_to_copy);
     // output_num is >= output_batch_size
-    output->ExtendTo(output_num, 50, &context_);
+    output->ExtendTo(output_num, 50);
     if (pos_to_object) {
-      pos_to_object->ExtendTo(output_num, 50, &context_);
+      pos_to_object->ExtendTo(output_num, 50);
     }
 
     auto* output_data =
index 458afde..ab04906 100644 (file)
@@ -58,7 +58,7 @@ bool RMACRegionsOp<CPUContext>::RunOnDevice() {
         (l + Hd - 1 > 0) ? ((H - region_size) / (1.0 * (l + Hd - 1))) : 0;
 
     int cur_rows = output->dim32(0);
-    output->Extend((l + Wd) * (l + Hd), 50, &context_);
+    output->Extend((l + Wd) * (l + Hd), 50);
     auto* outputData = output->template mutable_data<float>() + cur_rows * 5;
 
     for (int i = 0; i < l + Wd; ++i) {
@@ -87,7 +87,7 @@ bool RMACRegionsOp<CPUContext>::RunOnDevice() {
 
   // Replicate regions for all items in batch
   int num_rois = output->dim32(0);
-  output->Extend((batch_size - 1) * num_rois, 50, &context_);
+  output->Extend((batch_size - 1) * num_rois, 50);
   auto* outputData = output->template mutable_data<float>();
   for (int b = 1; b < batch_size; ++b) {
     // Copy all rois
index b521aa9..0b41da8 100644 (file)
@@ -120,9 +120,9 @@ class RemovePaddingOp final : public Operator<Context> {
 
   bool RunOnDevice() override {
     if (startPaddingWidth_ == 0 && endPaddingWidth_ == 0) {
-      Output(0)->CopyFrom(Input(0), &context_);
+      Output(0)->CopyFrom(Input(0), true /*async*/);
       if (OutputSize() == 2) {
-        Output(1)->CopyFrom(Input(1), &context_);
+        Output(1)->CopyFrom(Input(1), true /*async*/);
       }
       return true;
     }
@@ -160,9 +160,9 @@ class AddPaddingOp final : public Operator<Context> {
 
   bool RunOnDevice() override {
     if (startPaddingWidth_ == 0 && endPaddingWidth_ == 0) {
-      Output(0)->CopyFrom(Input(0), &context_);
+      Output(0)->CopyFrom(Input(0), true /*async*/);
       if (OutputSize() == 2) {
-        Output(1)->CopyFrom(Input(1), &context_);
+        Output(1)->CopyFrom(Input(1), true /*async*/);
       }
       return true;
     }
index 8ddb204..7d888f2 100644 (file)
@@ -123,9 +123,9 @@ bool SliceImplGpu(
   }
   if (dim == -1) {
     if (!backward) {
-      output->CopyFrom(data, context);
+      output->CopyFrom(data, true /*async*/);
     } else {
-      gdata->CopyFrom(*go, context);
+      gdata->CopyFrom(*go, true /*async*/);
     }
     return true;
   }
index 2e07beb..eb9193f 100644 (file)
@@ -85,9 +85,9 @@ bool SliceImpl(
   }
   if (dim == -1) {
     if (!backward) {
-      output->CopyFrom(data, context);
+      output->CopyFrom(data, true /*async*/);
     } else {
-      gdata->CopyFrom(*go, context);
+      gdata->CopyFrom(*go, true /*async*/);
     }
     return true;
   }
index e05cd11..68bbad6 100644 (file)
@@ -14,7 +14,7 @@ class StopGradientOp : public Operator<Context> {
     const auto& in = Input(0);
     auto* out = Output(0);
     if (out != &in) {
-      out->CopyFrom(in, &context_);
+      out->CopyFrom(in, true /*async*/);
     }
     return true;
   }
index 868e849..0d9bb32 100644 (file)
@@ -130,7 +130,7 @@ bool NanCheckOp<CUDAContext>::RunOnDevice() {
   // This op should act as an identity matrix if we don't find any NaNs/infs.
   // Copy over the data if we are not doing this in-place.
   if (&X != Y) {
-    Y->CopyFrom(X, &context_);
+    Y->CopyFrom(X, true /*async*/);
   }
   return true;
 }
index 09f39d8..0e03bfb 100644 (file)
@@ -196,7 +196,7 @@ class EnsureDenseOp final : public Operator<Context> {
     // allow the output to be copied from the input
     if (&input != output) {
       output->ResizeLike(input);
-      output->CopyFrom(input, &context_);
+      output->CopyFrom(input, true /*async*/);
     }
     return true;
   }
@@ -257,7 +257,7 @@ class SumOp : public Operator<Context> {
     auto& input0 = Input(0);
     auto* output = Output(0);
     if (InputSize() == 1) {
-      output->CopyFrom(input0, &context_);
+      output->CopyFrom(input0, true /*async*/);
       return true;
     }
     output->ResizeLike(input0);
index d5681e5..1050049 100644 (file)
@@ -160,7 +160,7 @@ class SafeDequeueBlobsOp final : public Operator<Context> {
               size,
               " total columns");
 
-          out->Extend(in.sizes()[0], kTensorGrowthPct, &context_);
+          out->Extend(in.sizes()[0], kTensorGrowthPct);
           auto* dst =
               (char*)out->raw_mutable_data() + oldSize * in.dtype().itemsize();
           context_.template CopyItems<Context, Context>(
index b2dffc2..a7855f6 100644 (file)
@@ -808,14 +808,17 @@ bool VideoInputOp<Context>::Prefetch() {
   // prefetch function as well.
   if (!std::is_same<Context, CPUContext>::value) {
     if (get_rgb_) {
-      prefetched_clip_rgb_on_device_.CopyFrom(prefetched_clip_rgb_, &context_);
+      prefetched_clip_rgb_on_device_.CopyFrom(
+          prefetched_clip_rgb_, true /*async*/);
     }
     if (get_optical_flow_) {
-      prefetched_clip_of_on_device_.CopyFrom(prefetched_clip_of_, &context_);
+      prefetched_clip_of_on_device_.CopyFrom(
+          prefetched_clip_of_, true /*async*/);
     }
-    prefetched_label_on_device_.CopyFrom(prefetched_label_, &context_);
+    prefetched_label_on_device_.CopyFrom(prefetched_label_, true /*async*/);
     if (get_video_id_) {
-      prefetched_video_id_on_device_.CopyFrom(prefetched_video_id_, &context_);
+      prefetched_video_id_on_device_.CopyFrom(
+          prefetched_video_id_, true /*async*/);
     }
   }
   return true;
@@ -828,34 +831,34 @@ bool VideoInputOp<Context>::CopyPrefetched() {
     auto* clip_rgb_output =
         OperatorBase::Output<Tensor>(index++, Context::GetDeviceType());
     if (std::is_same<Context, CPUContext>::value) {
-      clip_rgb_output->CopyFrom(prefetched_clip_rgb_, &context_);
+      clip_rgb_output->CopyFrom(prefetched_clip_rgb_, true /*async*/);
     } else {
-      clip_rgb_output->CopyFrom(prefetched_clip_rgb_on_device_, &context_);
+      clip_rgb_output->CopyFrom(prefetched_clip_rgb_on_device_, true /*async*/);
     }
   }
   if (get_optical_flow_) {
     auto* clip_of_output =
         OperatorBase::Output<Tensor>(index++, Context::GetDeviceType());
     if (std::is_same<Context, CPUContext>::value) {
-      clip_of_output->CopyFrom(prefetched_clip_of_, &context_);
+      clip_of_output->CopyFrom(prefetched_clip_of_, true /*async*/);
     } else {
-      clip_of_output->CopyFrom(prefetched_clip_of_on_device_, &context_);
+      clip_of_output->CopyFrom(prefetched_clip_of_on_device_, true /*async*/);
     }
   }
   auto* label_output =
       OperatorBase::Output<Tensor>(index++, Context::GetDeviceType());
   if (std::is_same<Context, CPUContext>::value) {
-    label_output->CopyFrom(prefetched_label_, &context_);
+    label_output->CopyFrom(prefetched_label_, true /*async*/);
   } else {
-    label_output->CopyFrom(prefetched_label_on_device_, &context_);
+    label_output->CopyFrom(prefetched_label_on_device_, true /*async*/);
   }
   if (get_video_id_) {
     auto* video_id_output =
         OperatorBase::Output<Tensor>(index, Context::GetDeviceType());
     if (std::is_same<Context, CPUContext>::value) {
-      video_id_output->CopyFrom(prefetched_video_id_, &context_);
+      video_id_output->CopyFrom(prefetched_video_id_, true /*async*/);
     } else {
-      video_id_output->CopyFrom(prefetched_video_id_on_device_, &context_);
+      video_id_output->CopyFrom(prefetched_video_id_on_device_, true /*async*/);
     }
   }
   return true;