Replace SwitchToDevice(0) with SwitchToDevice() (#15126)
authorEdward Yang <ezyang@fb.com>
Mon, 17 Dec 2018 23:09:40 +0000 (15:09 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Mon, 17 Dec 2018 23:15:00 +0000 (15:15 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15126

I want to make people stop manufacturing StreamId from thin air,
and a first step is to make people use the default stream.

Reviewed By: dzhulgakov

Differential Revision: D13432922

fbshipit-source-id: 9f0d8d70646c50d979bde5ba3c3addeebac48a3d

caffe2/core/blob_serialization.cc
caffe2/core/context_gpu.h
caffe2/core/context_gpu_test.cc
caffe2/core/operator.h
caffe2/operators/prefetch_op.h

index 13eab5b..e421719 100644 (file)
@@ -437,7 +437,7 @@ static std::unique_ptr<BaseContext> ContextFromProto(
 
 Tensor EmptyTensorFromProto(const TensorProto& tensor_proto) {
   auto context = ContextFromProto(tensor_proto);
-  context->SwitchToDevice(0);
+  context->SwitchToDevice();
   if (NumelFromTensorProto(tensor_proto) == 0 &&
       tensor_proto.data_type() == TensorProto_DataType_UNDEFINED) {
     // TODO: remove when serialization of dtype uninitialized tensor is removed
@@ -455,7 +455,7 @@ Tensor EmptyTensorFromProto(const TensorProto& tensor_proto) {
 void TensorDeserializer::Deserialize(const BlobProto& blob_proto, Blob* blob) {
   auto tensor_proto = blob_proto.tensor();
   auto context = ContextFromProto(tensor_proto);
-  context->SwitchToDevice(0);
+  context->SwitchToDevice();
   if (NumelFromTensorProto(tensor_proto) == 0 &&
       tensor_proto.data_type() == TensorProto_DataType_UNDEFINED) {
     // TODO: remove after empty Tensor serialization is forbidden
@@ -486,7 +486,7 @@ void TensorDeserializer::DeserializeToTensor(
   auto uniq_ptr = ContextFromProto(tensor_proto);
   // since CopyFromProtoAsIs accepts BaseContext*
   auto context = uniq_ptr.get();
-  context->SwitchToDevice(0);
+  context->SwitchToDevice();
 
   int64_t chunkBegin = 0;
   auto chunkEnd = tensor->numel();
index f0be5ff..d50c745 100644 (file)
@@ -196,6 +196,7 @@ class CAFFE2_CUDA_API CUDAContext final : public BaseContext {
     CaffeCudaSetDevice(gpu_id_);
   }
 
+  // void SwitchToDevice()
   using BaseContext::SwitchToDevice;
 
   inline void WaitEvent(const Event& ev) override {
index 9960e6a..b59dcdb 100644 (file)
@@ -85,7 +85,7 @@ TEST(CUDAContextTest, TestSameThreadTempObject) {
   if (!HasCudaGPU())
     return;
   CUDAContext context_outer(0); // gpu id
-  context_outer.SwitchToDevice(0); // logical stream id
+  context_outer.SwitchToDevice();
 
   if (NumCudaDevices() >= 2) {
     auto before_stream = context_outer.cuda_stream();
@@ -95,7 +95,7 @@ TEST(CUDAContextTest, TestSameThreadTempObject) {
     context_different_device.SwitchToDevice(10);
 
     // go back
-    context_outer.SwitchToDevice(0); // logical stream id
+    context_outer.SwitchToDevice();
     EXPECT_EQ(context_outer.cuda_stream(), before_stream);
 
     // do nothing - infers the current device and stream
index ec33363..3e56e87 100644 (file)
@@ -567,7 +567,7 @@ class Operator : public OperatorBase {
       : OperatorBase(operator_def, ws), context_(operator_def.device_option()) {
     // In the constructor, we switch to the device so that the child class
     // constructors will run on that device.
-    context_.SwitchToDevice(0);
+    context_.SwitchToDevice();
   }
   explicit Operator(
       const c10::FunctionSchema& fn_schema,
@@ -576,7 +576,7 @@ class Operator : public OperatorBase {
       : OperatorBase(fn_schema, inputs, outputs) {
     // In the constructor, we switch to the device so that the child class
     // constructors will run on that device.
-    context_.SwitchToDevice(0);
+    context_.SwitchToDevice();
   }
   ~Operator() noexcept override {}
 
index 6876bd8..ee1d9f5 100644 (file)
@@ -32,7 +32,7 @@ class PrefetchOperator : public OperatorBase {
         prefetch_success_(true),
         finalize_(false),
         no_prefetch_(GetSingleArgument<bool>("no_prefetch", false)) {
-    context_.SwitchToDevice(0);
+    context_.SwitchToDevice();
   }
 
   virtual ~PrefetchOperator() noexcept {
@@ -63,7 +63,7 @@ class PrefetchOperator : public OperatorBase {
 
   bool Run(int /* unused */ /*stream_id*/) override {
     if (no_prefetch_) {
-      context_.SwitchToDevice(0);
+      context_.SwitchToDevice();
       bool result = Prefetch() && CopyPrefetched();
       context_.FinishDeviceComputation();
       return result;
@@ -75,7 +75,7 @@ class PrefetchOperator : public OperatorBase {
       prefetch_thread_.reset(
           new std::thread([this] { this->PrefetchWorker(); }));
     }
-    context_.SwitchToDevice(0);
+    context_.SwitchToDevice();
     std::unique_lock<std::mutex> lock(prefetch_access_mutex_);
     while (!prefetched_)
       consumer_.wait(lock);