Replace std::unordered_map<c10::Device, c10::Device> with DeviceMap (#64393)
authorPavel Belevich <pbelevich@fb.com>
Thu, 2 Sep 2021 07:57:39 +0000 (00:57 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Thu, 2 Sep 2021 08:36:19 +0000 (01:36 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64393

cc pietern mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse agolynski SciPioneer H-Huang mrzzd cbalioglu gcramer23

Test Plan: Imported from OSS

Reviewed By: rohan-varma

Differential Revision: D30708384

Pulled By: pbelevich

fbshipit-source-id: 1c565727e4f09cd9e560874dd90aa403470b4a97

torch/csrc/distributed/autograd/functions/recvrpc_backward.cpp
torch/csrc/distributed/autograd/functions/recvrpc_backward.h
torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.cpp
torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h
torch/csrc/distributed/autograd/utils.cpp
torch/csrc/distributed/autograd/utils.h
torch/csrc/distributed/rpc/request_callback_no_python.cpp
torch/csrc/distributed/rpc/rpc_agent.h
torch/csrc/distributed/rpc/testing/faulty_tensorpipe_agent.cpp
torch/csrc/distributed/rpc/testing/faulty_tensorpipe_agent.h
torch/csrc/distributed/rpc/utils.cpp

index 0d82c07..a492d98 100644 (file)
@@ -15,7 +15,7 @@ RecvRpcBackward::RecvRpcBackward(
     const AutogradMetadata& autogradMetadata,
     ContextPtr autogradContext,
     rpc::worker_id_t fromWorkerId,
-    std::unordered_map<c10::Device, c10::Device> deviceMap)
+    rpc::DeviceMap deviceMap)
     : autogradMetadata_(autogradMetadata),
       // NOLINTNEXTLINE(performance-move-const-arg)
       autogradContext_(std::move(autogradContext)),
index 46bdb29..6e6678b 100644 (file)
@@ -23,7 +23,7 @@ class TORCH_API RecvRpcBackward : public torch::autograd::Node {
       const AutogradMetadata& autogradMetadata,
       std::shared_ptr<DistAutogradContext> autogradContext,
       rpc::worker_id_t fromWorkerId,
-      std::unordered_map<c10::Device, c10::Device> deviceMap);
+      rpc::DeviceMap deviceMap);
 
   torch::autograd::variable_list apply(
       torch::autograd::variable_list&& grads) override;
@@ -41,7 +41,7 @@ class TORCH_API RecvRpcBackward : public torch::autograd::Node {
   rpc::worker_id_t fromWorkerId_;
 
   // Device mapping for tensors sent over RPC.
-  const std::unordered_map<c10::Device, c10::Device> deviceMap_;
+  const rpc::DeviceMap deviceMap_;
 };
 
 } // namespace autograd
index 4d84e99..b8d28f7 100644 (file)
@@ -19,7 +19,7 @@ RpcWithAutograd::RpcWithAutograd(
     MessageType messageType,
     const AutogradMetadata& autogradMetadata,
     c10::intrusive_ptr<rpc::Message> wrappedMessage,
-    std::unordered_map<c10::Device, c10::Device> deviceMap)
+    rpc::DeviceMap deviceMap)
     : fromWorkerId_(fromWorkerId),
       messageType_(messageType),
       autogradMetadata_(autogradMetadata),
@@ -39,7 +39,7 @@ RpcWithAutograd::RpcWithAutograd(
     std::unique_ptr<RpcCommandBase> wrappedRpc,
     MessageType wrappedMessageType,
     std::vector<torch::Tensor> tensors,
-    std::unordered_map<c10::Device, c10::Device> deviceMap)
+    rpc::DeviceMap deviceMap)
     : fromWorkerId_(fromWorkerId),
       messageType_(messageType),
       autogradMetadata_(autogradMetadata),
@@ -112,7 +112,7 @@ std::unique_ptr<RpcWithAutograd> RpcWithAutograd::fromMessage(
   auto c10DeviceMap = tupleElements[4].to<c10::Dict<std::string, std::string>>();
 
   // Convert to regular map.
-  std::unordered_map<c10::Device, c10::Device> deviceMap;
+  rpc::DeviceMap deviceMap;
   for (const auto& mapEntry : c10DeviceMap) {
     deviceMap.insert({mapEntry.key(), mapEntry.value()});
   }
@@ -169,7 +169,7 @@ rpc::worker_id_t RpcWithAutograd::fromWorkerId() const {
   return fromWorkerId_;
 }
 
-const std::unordered_map<c10::Device, c10::Device>& RpcWithAutograd::
+const rpc::DeviceMap& RpcWithAutograd::
     deviceMap() {
   return deviceMap_;
 }
index 1884cc9..6d0b611 100644 (file)
@@ -19,7 +19,7 @@ class TORCH_API RpcWithAutograd final : public rpc::RpcCommandBase {
       rpc::MessageType messageType,
       const AutogradMetadata& autogradMetadata,
       c10::intrusive_ptr<rpc::Message> wrappedMessage,
-      std::unordered_map<c10::Device, c10::Device> deviceMap = {});
+      rpc::DeviceMap deviceMap = {});
 
   // Used when receiving an RPC over the wire.
   RpcWithAutograd(
@@ -29,7 +29,7 @@ class TORCH_API RpcWithAutograd final : public rpc::RpcCommandBase {
       std::unique_ptr<rpc::RpcCommandBase> wrappedRpc,
       rpc::MessageType wrappedMessageType,
       std::vector<torch::Tensor> tensors,
-      std::unordered_map<c10::Device, c10::Device> deviceMap = {});
+      rpc::DeviceMap deviceMap = {});
 
   c10::intrusive_ptr<rpc::Message> toMessageImpl() && override;
 
@@ -55,7 +55,7 @@ class TORCH_API RpcWithAutograd final : public rpc::RpcCommandBase {
   rpc::worker_id_t fromWorkerId() const;
 
   // Retrieve the device map.
-  const std::unordered_map<c10::Device, c10::Device>& deviceMap();
+  const rpc::DeviceMap& deviceMap();
 
  private:
   // WorkerId from which this RPC originated. This is necessary for knowing
@@ -90,7 +90,7 @@ class TORCH_API RpcWithAutograd final : public rpc::RpcCommandBase {
   std::vector<torch::Tensor> tensors_;
 
   // Device mapping for tensors that are sent across an RPC to another node.
-  std::unordered_map<c10::Device, c10::Device> deviceMap_;
+  rpc::DeviceMap deviceMap_;
 };
 
 } // namespace autograd
index 4e29bfc..9db4076 100644 (file)
@@ -53,7 +53,7 @@ ContextPtr addRecvRpcBackward(
     const AutogradMetadata& autogradMetadata,
     std::vector<torch::Tensor>& tensors,
     rpc::worker_id_t fromWorkerId,
-    const std::unordered_map<c10::Device, c10::Device>& deviceMap) {
+    const rpc::DeviceMap& deviceMap) {
   // Initialize autograd context if necessary.
   auto& autogradContainer = DistAutogradContainer::getInstance();
   auto autogradContext =
@@ -105,7 +105,7 @@ c10::intrusive_ptr<Message> getMessageWithAutograd(
     c10::intrusive_ptr<torch::distributed::rpc::Message> wrappedRpcMsg,
     MessageType msgType,
     bool forceGradRecording,
-    const std::unordered_map<c10::Device, c10::Device>& deviceMap) {
+    const rpc::DeviceMap& deviceMap) {
   auto& autogradContainer = DistAutogradContainer::getInstance();
 
   // If there is no valid context and no tensor requires grads, send original
index fae675d..94883ce 100644 (file)
@@ -31,7 +31,7 @@ TORCH_API ContextPtr addRecvRpcBackward(
     const AutogradMetadata& autogradMetadata,
     std::vector<torch::Tensor>& tensors,
     rpc::worker_id_t fromWorkerId,
-    const std::unordered_map<c10::Device, c10::Device>& deviceMap);
+    const rpc::DeviceMap& deviceMap);
 
 // This method is a wrapper utility used internally to wrap autograd info
 // and attach autograd function for each type of rpc call if it has valid
@@ -44,7 +44,7 @@ TORCH_API c10::intrusive_ptr<rpc::Message> getMessageWithAutograd(
     c10::intrusive_ptr<rpc::Message> wrappedRpcMsg,
     rpc::MessageType msgType,
     bool forceGradRecording = false,
-    const std::unordered_map<c10::Device, c10::Device>& deviceMap =
+    const rpc::DeviceMap& deviceMap =
         {});
 
 // Send message after autograd checking
index 5eada8d..9e16061 100644 (file)
@@ -290,7 +290,7 @@ c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::
 
   // Need to reverse the device map for the backward pass of distributed
   // autograd.
-  std::unordered_map<c10::Device, c10::Device> reverseDeviceMap;
+  DeviceMap reverseDeviceMap;
   for (const auto& mapEntry : rpcWithAutograd.deviceMap()) {
     reverseDeviceMap.insert({mapEntry.second, mapEntry.first});
   }
index a83e77b..7cd228e 100644 (file)
@@ -164,7 +164,7 @@ class TORCH_API RpcAgent {
       const WorkerInfo& to,
       c10::intrusive_ptr<Message> message,
       const float rpcTimeoutSeconds = kUnsetRpcTimeout,
-      const std::unordered_map<c10::Device, c10::Device>& deviceMap = {}) = 0;
+      const DeviceMap& deviceMap = {}) = 0;
 
   // Retries sending the message up to maxRetries times until an ACK is
   // receieved. The duration between consecutive sends is increased over
index 72d4d5d..a2e0525 100644 (file)
@@ -67,7 +67,7 @@ c10::intrusive_ptr<JitFuture> FaultyTensorPipeAgent::send(
     const WorkerInfo& to,
     c10::intrusive_ptr<Message> message,
     const float rpcTimeoutSeconds,
-    const std::unordered_map<c10::Device, c10::Device>& /* unused */) {
+    const DeviceMap& /* unused */) {
   // We only fail control messages that have been specified by the test case.
   // For all other messages, we just send them without any failures.
   if (!shouldFailMessage(message->type())) {
index 5d60597..e69a76c 100644 (file)
@@ -53,7 +53,7 @@ class TORCH_API FaultyTensorPipeAgent : public TensorPipeAgent {
       const WorkerInfo& to,
       c10::intrusive_ptr<Message> message,
       const float rpcTimeoutSeconds = torch::distributed::rpc::kUnsetRpcTimeout,
-      const std::unordered_map<c10::Device, c10::Device>& deviceMap = {})
+      const DeviceMap& deviceMap = {})
       override;
 
   // Add delay to writes
index 615abbf..820ec31 100644 (file)
@@ -177,7 +177,7 @@ std::unique_ptr<RpcCommandBase> deserializeResponse(
 
       // Need to reverse the device map for the backward pass of distributed
       // autograd.
-      std::unordered_map<c10::Device, c10::Device> reverseDeviceMap;
+      DeviceMap reverseDeviceMap;
       for (const auto& mapEntry : rpcWithAutograd.deviceMap()) {
         reverseDeviceMap.insert({mapEntry.second, mapEntry.first});
       }