add support for sending cpu sparse tensors over rpc (#62794)
author Garrett Cramer <gcramer@fb.com>
Sun, 29 Aug 2021 18:33:48 +0000 (11:33 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Sun, 29 Aug 2021 18:35:00 +0000 (11:35 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62794

This PR updates JIT serialization (Pickler/Unpickler) to support pickling sparse COO tensors.
It also updates message.cpp so that RPC messages can collect the storages of sparse tensors.
This addresses a bug filed a few years ago: https://github.com/pytorch/pytorch/issues/30807.

I tested the fix by adding sparse tensor tests to rpc_test.py and dist_autograd_test.py.
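
For illustration, a minimal sketch of what this enables: sending a CPU sparse COO tensor through rpc_sync. The single-process setup, worker name, and port below are illustrative only and are not part of this PR.

    import os

    import torch
    import torch.distributed.rpc as rpc

    # Single-process RPC group purely for illustration.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    rpc.init_rpc("worker0", rank=0, world_size=1)

    # The same kind of sparse COO tensor the new tests build.
    i = [[0, 1, 1], [2, 0, 2]]
    v = [3.2, 4.1, 5.3]
    x = torch.sparse_coo_tensor(i, v, (3, 3))

    # Before this change an RPC like this failed to serialize the sparse
    # tensor (see the issue linked above); now it round-trips.
    y = rpc.rpc_sync("worker0", torch.add, args=(x, x))
    assert torch.equal(y.to_dense(), (x + x).to_dense())

    rpc.shutdown()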

cc pietern mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse agolynski SciPioneer H-Huang mrzzd cbalioglu gcramer23 gmagogsfm

Test Plan: Imported from OSS

Reviewed By: soulitzer

Differential Revision: D30608848

Pulled By: gcramer23

fbshipit-source-id: 629ba8e4a3d8365875a709c9b87447c7a71204fb

torch/csrc/distributed/rpc/message.cpp
torch/csrc/jit/serialization/pickler.cpp
torch/csrc/jit/serialization/pickler.h
torch/csrc/jit/serialization/unpickler.cpp
torch/csrc/jit/serialization/unpickler.h
torch/testing/_internal/distributed/rpc/dist_autograd_test.py
torch/testing/_internal/distributed/rpc/rpc_test.py

diff --git a/torch/csrc/distributed/rpc/message.cpp b/torch/csrc/distributed/rpc/message.cpp
index 0277114..7265ed4 100644
@@ -68,10 +68,17 @@ void Message::setId(int64_t id) {
 
 std::vector<c10::weak_intrusive_ptr<c10::StorageImpl>> Message::getStorages()
     const {
+  // Sparse tensors do not have storage. Instead, a sparse tensor
+  // contains two dense tensors, indices and values, each with its own storage.
   std::vector<c10::weak_intrusive_ptr<c10::StorageImpl>> storages;
-  storages.reserve(tensors_.size());
+  storages.reserve(2 * tensors_.size());
   for (const auto& tensor : tensors_) {
-    storages.emplace_back(tensor.storage().getWeakStorageImpl());
+    if (tensor.is_sparse()) {
+      storages.emplace_back(tensor._indices().storage().getWeakStorageImpl());
+      storages.emplace_back(tensor._values().storage().getWeakStorageImpl());
+    } else {
+      storages.emplace_back(tensor.storage().getWeakStorageImpl());
+    }
   }
   return storages;
 }
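
As a Python-level illustration of the comment above (not part of the patch): a sparse COO tensor has no storage of its own, so the message collects the storages of its indices and values instead.

    import torch

    t = torch.sparse_coo_tensor([[0, 1, 1], [2, 0, 2]], [3.2, 4.1, 5.3], (3, 3))

    # t.storage() is not supported for sparse tensors; the two dense
    # constituents are what actually own memory, mirroring getStorages().
    indices_storage = t._indices().storage()
    values_storage = t._values().storage()
    print(len(indices_storage), len(values_storage))  # 6 3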
diff --git a/torch/csrc/jit/serialization/pickler.cpp b/torch/csrc/jit/serialization/pickler.cpp
index 4a4e866..f465eaf 100644
@@ -353,6 +353,44 @@ void Pickler::pushTensor(const IValue& ivalue) {
   }
 }
 
+void Pickler::pushLiteralSparseTensor(const at::Tensor& tensor) {
+  pushGlobal("torch._utils", "_rebuild_sparse_tensor");
+  push<PickleOpCode>(PickleOpCode::MARK);
+  // layout
+  auto layout = static_cast<int>(tensor.layout());
+  pushInt(layout);
+  switch (layout) {
+    case static_cast<int>(c10::Layout::Sparse):
+      // size
+      push<PickleOpCode>(PickleOpCode::MARK);
+      for (auto size : tensor.sizes()) {
+        pushInt(size);
+      }
+      push<PickleOpCode>(PickleOpCode::TUPLE);
+      // requires grad
+      pushIValue(tensor.requires_grad());
+      // indices
+      pushTensor(tensor._indices());
+      // values
+      pushTensor(tensor._values());
+      break;
+    default:
+      TORCH_CHECK(
+          false,
+          "Unsupported sparse tensor layout type in serialization ",
+          static_cast<c10::Layout>(layout));
+      break;
+  }
+  // backward_hooks
+  pushGlobal("collections", "OrderedDict");
+  push<PickleOpCode>(PickleOpCode::EMPTY_TUPLE);
+  // Construct the collections.OrderedDict for the backward_hooks
+  push<PickleOpCode>(PickleOpCode::REDUCE);
+  push<PickleOpCode>(PickleOpCode::TUPLE);
+  // Call torch._utils._rebuild_sparse_tensor
+  push<PickleOpCode>(PickleOpCode::REDUCE);
+}
+
 void Pickler::pushLiteralTensor(const IValue& ivalue) {
   // In contrast to tensor references, literal tensors are included in the
   // pickle program binary blob. They are written to the file after the STOP
@@ -362,6 +400,12 @@ void Pickler::pushLiteralTensor(const IValue& ivalue) {
   // The format here is the same one used by `torch.save()`. The code for the
   // format can be found in `torch/serialization.py`.
   auto& tensor = ivalue.toTensor();
+
+  if (tensor.is_sparse() || tensor.is_sparse_csr()) {
+    pushLiteralSparseTensor(tensor);
+    return;
+  }
+
   bool quantized = tensor.is_quantized();
   // The arguments to this function are:
   //    storage, storage_offset, size, stride, requires_grad, backward_hooks
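
For readers of pushLiteralSparseTensor above, a hedged sketch of the argument tuple that the final REDUCE applies to the torch._utils._rebuild_sparse_tensor global, and that Unpickler::rebuildSparseTensor (below) unpacks positionally. The numeric layout value is an assumption and is flagged as such in the code.

    import collections

    import torch

    t = torch.sparse_coo_tensor(
        [[0, 1, 1], [2, 0, 2]], [3.2, 4.1, 5.3], (3, 3), requires_grad=True
    )

    # Assumption: illustrative stand-in for static_cast<int>(c10::Layout::Sparse).
    SPARSE_LAYOUT = 1

    rebuild_args = (
        SPARSE_LAYOUT,              # layout tag checked by pickler and unpickler
        tuple(t.shape),             # sizes, pushed between MARK/TUPLE opcodes
        t.requires_grad,            # bool
        t._indices(),               # dense int64 tensor of shape (sparse_dim, nnz)
        t._values(),                # dense tensor of shape (nnz, *dense_dims)
        collections.OrderedDict(),  # empty backward_hooks placeholder
    )
    assert len(rebuild_args) == 6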
diff --git a/torch/csrc/jit/serialization/pickler.h b/torch/csrc/jit/serialization/pickler.h
index ac54ac4..3dc6bef 100644
@@ -172,6 +172,7 @@ class TORCH_API Pickler {
   void pushTensor(const IValue& ivalue);
   void pushTensorReference(const IValue& ivalue);
   void pushLiteralTensor(const IValue& ivalue);
+  void pushLiteralSparseTensor(const at::Tensor& tensor);
   void pushTuple(const IValue& ivalue);
   void pushString(const std::string& string);
   void pushDevice(const IValue& ivalue);
diff --git a/torch/csrc/jit/serialization/unpickler.cpp b/torch/csrc/jit/serialization/unpickler.cpp
index 581b949..f944387 100644
@@ -550,6 +550,9 @@ void Unpickler::readGlobal(
     // Unpickle a tensor
     bool quantized = class_name == "_rebuild_qtensor";
     rebuildTensor(quantized);
+  } else if (
+      module_name == "torch._utils" && class_name == "_rebuild_sparse_tensor") {
+    rebuildSparseTensor();
   } else if (module_name == "builtins" && class_name == "complex") {
     globals_.emplace_back([this] {
       auto elems = pop(stack_).toTuple()->elements();
@@ -647,6 +650,38 @@ void Unpickler::readGlobal(
   stack_.emplace_back(int64_t(globals_.size() - 1));
 }
 
+void Unpickler::rebuildSparseTensor() {
+  globals_.emplace_back([this] {
+    auto tup = pop(stack_).toTuple();
+    const auto& elements = tup->elements();
+    size_t idx = 0;
+    auto layout = elements.at(idx++).toInt();
+    at::Tensor result;
+    switch (layout) {
+      case static_cast<int>(c10::Layout::Sparse): {
+        std::vector<int64_t> size = tupleToIntList(elements.at(idx++));
+        bool requires_grad = elements.at(idx++).toBool();
+        auto& indices_tensor = elements.at(idx++).toTensor();
+        auto& values_tensor = elements.at(idx++).toTensor();
+        auto options = values_tensor.options()
+                           .layout(c10::Layout::Sparse)
+                           .requires_grad(requires_grad);
+        result = at::_sparse_coo_tensor_unsafe(
+            indices_tensor, values_tensor, size, options);
+        result = autograd::make_variable(result, options.requires_grad());
+        break;
+      }
+      default:
+        TORCH_CHECK(
+            false,
+            "Unsupported sparse tensor layout type in serialization ",
+            static_cast<c10::Layout>(layout));
+        break;
+    }
+    stack_.emplace_back(std::move(result));
+  });
+}
+
 void Unpickler::rebuildTensor(bool quantized) {
   globals_.emplace_back([this, quantized] {
     auto tup = pop(stack_).toTuple();
diff --git a/torch/csrc/jit/serialization/unpickler.h b/torch/csrc/jit/serialization/unpickler.h
index f404dee..586ff9c 100644
@@ -108,6 +108,7 @@ class TORCH_API Unpickler {
       const std::string& module_name,
       const std::string& class_name);
   void rebuildTensor(bool quantized);
+  void rebuildSparseTensor();
 #ifdef USE_DISTRIBUTED
   void rebuildRRef();
 #endif
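
The Python test changes that follow repeatedly branch on a sparse flag and use torch.sparse.sum(ret) in place of ret.sum() for the loss. A minimal local sketch of that pattern, mirroring what the tests do (no new API):

    import torch

    t1 = torch.sparse_coo_tensor(
        [[0, 1, 1], [2, 0, 2]], [3.2, 4.1, 5.3], (3, 3), requires_grad=True
    )
    t2 = torch.sparse_coo_tensor(
        [[0, 1, 1], [2, 0, 2]], [1.0, 2.0, 3.0], (3, 3), requires_grad=True
    )

    ret = torch.add(t1, t2)       # still a sparse COO tensor
    loss = torch.sparse.sum(ret)  # scalar loss, the sparse counterpart of ret.sum()
    loss.backward()
    assert t1.grad is not None and t2.grad is not None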
diff --git a/torch/testing/_internal/distributed/rpc/dist_autograd_test.py b/torch/testing/_internal/distributed/rpc/dist_autograd_test.py
index 017a61b..fba5030 100644
@@ -64,13 +64,29 @@ def _torch_ones(sizes, requires_grad=False):
 # rref tensor equals the given grad.
 def _compare_owner_value(context_id, rref, grad):
     grads = dist_autograd.get_gradients(context_id)
-    return torch.equal(grads[rref.local_value()], grad)
+    x = grads[rref.local_value()]
+    if x.is_sparse:
+        assert grad.is_sparse
+        x = x.to_dense()
+        grad = grad.to_dense()
+    else:
+        assert not grad.is_sparse
+    return torch.equal(x, grad)
 
 
 def create_tensor():
     return torch.ones((3, 3), requires_grad=True)
 
 
+def build_sparse_tensor(coalesce=False, requires_grad=True, dtype=torch.float32):
+    i = [[0, 1, 1], [2, 0, 2]]
+    v = [3.2, 4.1, 5.3]
+    tensor = torch.sparse_coo_tensor(i, v, (3, 3), requires_grad=requires_grad, dtype=dtype)
+    if coalesce:
+        tensor = tensor.coalesce()
+    return tensor
+
+
 @torch.jit.script
 def create_torchscript_tensor() -> torch.Tensor:
     return torch.ones((3, 3)).requires_grad_()
@@ -143,20 +159,28 @@ def _all_contexts_cleaned_up(timeout_seconds=10):
 
 # This function creates a dist autograd context, runs rpc_sync on the given ps,
 # and then blocks until the ps has verified the grads are correctly accumulated.
-def _run_trainer(rref_t1, t2, ps, rank_diff):
+def _run_trainer(rref_t1, t2, ps, rank_diff, sparse):
     with dist_autograd.context() as context_id:
         ret = rpc.rpc_sync(ps, my_rref_add, args=(rref_t1, t2))
-        dist_autograd.backward(context_id, [ret.sum()])
+        if sparse:
+            loss = torch.sparse.sum(ret)
+        else:
+            loss = ret.sum()
+        dist_autograd.backward(context_id, [loss])
         # prevent deleting dist autograd context
         rpc.rpc_sync(ps, _set_rpc_done, args=(context_id, rank_diff))
         rpc.rpc_sync(ps, _check_rpc_done, args=(0,))
 
 # This function is the same as _run_trainer, except rpc calls the torchscript
 # function "my_script_ref_add" instead of the python function "my_rref_add"
-def _run_trainer_torchscript(rref_t1, t2, ps, rank_diff):
+def _run_trainer_torchscript(rref_t1, t2, ps, rank_diff, sparse):
     with dist_autograd.context() as context_id:
         ret = rpc.rpc_sync(ps, my_script_ref_add, args=(rref_t1, t2))
-        dist_autograd.backward(context_id, [ret.sum()])
+        if sparse:
+            loss = torch.sparse.sum(ret)
+        else:
+            loss = ret.sum()
+        dist_autograd.backward(context_id, [loss])
         # prevent deleting dist autograd context
         rpc.rpc_sync(ps, _set_rpc_done, args=(context_id, rank_diff))
         rpc.rpc_sync(ps, _check_rpc_done, args=(0,))
@@ -379,14 +403,18 @@ class DistAutogradTest(CommonDistAutogradTest):
             "torch::distributed::autograd::RecvRpcBackward", next_funcs[0][0].name()
         )
 
-    def _test_graph(self, fn, exec_mode):
+    def _test_graph(self, fn, exec_mode, sparse):
         dst_rank = (self.rank + 1) % self.world_size
 
         initialize_pg(self.file_init_method, self.rank, self.world_size)
 
         with dist_autograd.context() as context_id:
-            t1 = torch.ones(3, 3, requires_grad=True)
-            t2 = torch.zeros(3, 3, requires_grad=True)
+            if sparse:
+                t1 = build_sparse_tensor()
+                t2 = build_sparse_tensor()
+            else:
+                t1 = torch.ones(3, 3, requires_grad=True)
+                t2 = torch.zeros(3, 3, requires_grad=True)
             if ExecMode.RPC_SYNC == exec_mode:
                 ret = rpc.rpc_sync(worker_name(dst_rank), fn, args=(t1, t2))
             elif ExecMode.REMOTE == exec_mode:
@@ -436,29 +464,49 @@ class DistAutogradTest(CommonDistAutogradTest):
 
     @dist_init
     def test_graph_for_builtin_call(self):
-        self._test_graph(torch.add, ExecMode.RPC_SYNC)
+        self._test_graph(torch.add, ExecMode.RPC_SYNC, False)
+
+    @dist_init
+    def test_graph_for_builtin_call_sparse(self):
+        self._test_graph(torch.add, ExecMode.RPC_SYNC, True)
 
     @dist_init
     def test_graph_for_python_call(self):
-        self._test_graph(my_py_add, ExecMode.RPC_SYNC)
+        self._test_graph(my_py_add, ExecMode.RPC_SYNC, False)
+
+    @dist_init
+    def test_graph_for_python_call_sparse(self):
+        self._test_graph(my_py_add, ExecMode.RPC_SYNC, True)
 
     @dist_init
     def test_graph_for_builtin_remote_call(self):
-        self._test_graph(torch.add, ExecMode.REMOTE)
+        self._test_graph(torch.add, ExecMode.REMOTE, False)
+
+    @dist_init
+    def test_graph_for_builtin_remote_call_sparse(self):
+        self._test_graph(torch.add, ExecMode.REMOTE, True)
 
     @dist_init
     def test_graph_for_python_remote_call(self):
-        self._test_graph(my_py_add, ExecMode.REMOTE)
+        self._test_graph(my_py_add, ExecMode.REMOTE, False)
+
+    @dist_init
+    def test_graph_for_python_remote_call_sparse(self):
+        self._test_graph(my_py_add, ExecMode.REMOTE, True)
 
     # 3-layer nested calls
-    def _test_graph_for_py_nested_call(self, exec_mode):
+    def _test_graph_for_py_nested_call(self, exec_mode, sparse):
         dst_rank = (self.rank + 1) % self.world_size
 
         initialize_pg(self.file_init_method, self.rank, self.world_size)
 
         with dist_autograd.context() as context_id:
-            t1 = torch.ones(3, 3, requires_grad=True)
-            t2 = torch.zeros(3, 3, requires_grad=True)
+            if sparse:
+                t1 = build_sparse_tensor(requires_grad=True)
+                t2 = build_sparse_tensor(requires_grad=True)
+            else:
+                t1 = torch.ones(3, 3, requires_grad=True)
+                t2 = torch.zeros(3, 3, requires_grad=True)
             nest_dst_rank = (dst_rank + 1) % self.world_size
             if ExecMode.RPC_SYNC == exec_mode:
                 ret = rpc.rpc_sync(
@@ -531,21 +579,33 @@ class DistAutogradTest(CommonDistAutogradTest):
 
     @dist_init
     def test_graph_for_py_nested_call(self):
-        self._test_graph_for_py_nested_call(ExecMode.RPC_SYNC)
+        self._test_graph_for_py_nested_call(ExecMode.RPC_SYNC, False)
+
+    @dist_init
+    def test_graph_for_py_nested_call_sparse(self):
+        self._test_graph_for_py_nested_call(ExecMode.RPC_SYNC, True)
 
     @dist_init
     def test_graph_for_py_nested_remote_call(self):
-        self._test_graph_for_py_nested_call(ExecMode.REMOTE)
+        self._test_graph_for_py_nested_call(ExecMode.REMOTE, False)
+
+    @dist_init
+    def test_graph_for_py_nested_remote_call_sparse(self):
+        self._test_graph_for_py_nested_call(ExecMode.REMOTE, True)
 
     # Rank0->Rank1->Rank0
-    def _test_graph_for_py_nested_call_itself(self, exec_mode):
+    def _test_graph_for_py_nested_call_itself(self, exec_mode, sparse):
         dst_rank = (self.rank + 1) % self.world_size
 
         initialize_pg(self.file_init_method, self.rank, self.world_size)
 
         with dist_autograd.context() as context_id:
-            t1 = torch.ones(3, 3, requires_grad=True)
-            t2 = torch.zeros(3, 3, requires_grad=True)
+            if sparse:
+                t1 = build_sparse_tensor(requires_grad=True)
+                t2 = build_sparse_tensor(requires_grad=True)
+            else:
+                t1 = torch.ones(3, 3, requires_grad=True)
+                t2 = torch.zeros(3, 3, requires_grad=True)
             if ExecMode.RPC_SYNC == exec_mode:
                 ret = rpc.rpc_sync(
                     worker_name(dst_rank),
@@ -610,18 +670,30 @@ class DistAutogradTest(CommonDistAutogradTest):
 
     @dist_init
     def test_graph_for_py_nested_call_itself(self):
-        self._test_graph_for_py_nested_call_itself(ExecMode.RPC_SYNC)
+        self._test_graph_for_py_nested_call_itself(ExecMode.RPC_SYNC, False)
+
+    @dist_init
+    def test_graph_for_py_nested_call_itself_sparse(self):
+        self._test_graph_for_py_nested_call_itself(ExecMode.RPC_SYNC, True)
 
     @dist_init
     def test_graph_for_py_nested_remote_call_itself(self):
-        self._test_graph_for_py_nested_call_itself(ExecMode.REMOTE)
+        self._test_graph_for_py_nested_call_itself(ExecMode.REMOTE, False)
+
+    @dist_init
+    def test_graph_for_py_nested_remote_call_itself_sparse(self):
+        self._test_graph_for_py_nested_call_itself(ExecMode.REMOTE, True)
 
-    def _test_no_graph_with_tensors_not_require_grad(self, exec_mode):
+    def _test_no_graph_with_tensors_not_require_grad(self, exec_mode, sparse):
         initialize_pg(self.file_init_method, self.rank, self.world_size)
         dst_rank = (self.rank + 1) % self.world_size
         with dist_autograd.context() as context_id:
-            t1 = torch.ones(3, 3, requires_grad=False)
-            t2 = torch.zeros(3, 3, requires_grad=False)
+            if sparse:
+                t1 = build_sparse_tensor(requires_grad=False)
+                t2 = build_sparse_tensor(requires_grad=False)
+            else:
+                t1 = torch.ones(3, 3, requires_grad=False)
+                t2 = torch.zeros(3, 3, requires_grad=False)
             if ExecMode.RPC_SYNC == exec_mode:
                 ret = rpc.rpc_sync(
                     worker_name(dst_rank), torch.add, args=(t1, t2)
@@ -656,11 +728,19 @@ class DistAutogradTest(CommonDistAutogradTest):
 
     @dist_init
     def test_no_graph_with_tensors_not_require_grad(self):
-        self._test_no_graph_with_tensors_not_require_grad(ExecMode.RPC_SYNC)
+        self._test_no_graph_with_tensors_not_require_grad(ExecMode.RPC_SYNC, False)
+
+    @dist_init
+    def test_no_graph_with_tensors_not_require_grad_sparse(self):
+        self._test_no_graph_with_tensors_not_require_grad(ExecMode.RPC_SYNC, True)
 
     @dist_init
     def test_no_graph_with_tensors_not_require_grad_remote(self):
-        self._test_no_graph_with_tensors_not_require_grad(ExecMode.REMOTE)
+        self._test_no_graph_with_tensors_not_require_grad(ExecMode.REMOTE, False)
+
+    @dist_init
+    def test_no_graph_with_tensors_not_require_grad_remote_sparse(self):
+        self._test_no_graph_with_tensors_not_require_grad(ExecMode.REMOTE, True)
 
     def _test_grad_only_on_return_value(self, exec_mode):
         initialize_pg(self.file_init_method, self.rank, self.world_size)
@@ -699,13 +779,16 @@ class DistAutogradTest(CommonDistAutogradTest):
     def test_grad_only_on_return_value_remote(self):
         self._test_grad_only_on_return_value(ExecMode.REMOTE)
 
-    def _test_rpc_complex_args(self, exec_mode):
+    def _test_rpc_complex_args(self, exec_mode, sparse):
         with dist_autograd.context() as context_id:
             num_tensors = 10
             tensors = []
             for i in range(num_tensors):
-                tensors.append(torch.ones(3, 3, requires_grad=(i % 2 == 0)))
-
+                if sparse:
+                    tensor = build_sparse_tensor(requires_grad=(i % 2 == 0))
+                else:
+                    tensor = torch.ones(3, 3, requires_grad=(i % 2 == 0))
+                tensors.append(tensor)
             dst_rank = self._next_rank()
             if ExecMode.RPC_SYNC == exec_mode:
                 ret = rpc.rpc_sync(
@@ -739,11 +822,19 @@ class DistAutogradTest(CommonDistAutogradTest):
 
     @dist_init
     def test_rpc_complex_args(self):
-        self._test_rpc_complex_args(ExecMode.RPC_SYNC)
+        self._test_rpc_complex_args(ExecMode.RPC_SYNC, False)
+
+    @dist_init
+    def test_rpc_complex_args_sparse(self):
+        self._test_rpc_complex_args(ExecMode.RPC_SYNC, True)
 
     @dist_init
     def test_remote_complex_args(self):
-        self._test_rpc_complex_args(ExecMode.REMOTE)
+        self._test_rpc_complex_args(ExecMode.REMOTE, False)
+
+    @dist_init
+    def test_remote_complex_args_sparse(self):
+        self._test_rpc_complex_args(ExecMode.REMOTE, True)
 
     def context_cleanup_test_helper(self, rpc_args, func, nested=False):
         initialize_pg(self.file_init_method, self.rank, self.world_size)
@@ -789,11 +880,22 @@ class DistAutogradTest(CommonDistAutogradTest):
         self.context_cleanup_test_helper(rpc_args=(t1, t2), func=torch.add)
 
     @dist_init
+    def test_context_cleanup_tensor_with_grad_sparse(self):
+        t1 = build_sparse_tensor(requires_grad=True)
+        t2 = build_sparse_tensor(requires_grad=True)
+        self.context_cleanup_test_helper(rpc_args=(t1, t2), func=torch.add)
+
+    @dist_init
     def test_context_cleanup_tensor_no_grad(self):
         t1 = torch.ones(3, 3, requires_grad=False)
         self.context_cleanup_test_helper(rpc_args=(t1, t1), func=torch.add)
 
     @dist_init
+    def test_context_cleanup_tensor_no_grad_sparse(self):
+        t1 = build_sparse_tensor(requires_grad=False)
+        self.context_cleanup_test_helper(rpc_args=(t1, t1), func=torch.add)
+
+    @dist_init
     def test_context_cleanup_no_tensors(self):
         self.context_cleanup_test_helper(rpc_args=(1, 1), func=my_scalar_add)
 
@@ -808,6 +910,16 @@ class DistAutogradTest(CommonDistAutogradTest):
         )
 
     @dist_init
+    def test_context_cleanup_nested_rpc_sparse(self):
+        t1 = build_sparse_tensor(requires_grad=True)
+        t2 = build_sparse_tensor(requires_grad=True)
+        dst_rank = (self.rank + 1) % self.world_size
+        args = (t1, t2, dst_rank, self.world_size, 0)
+        self.context_cleanup_test_helper(
+            rpc_args=args, func=my_py_nested_call, nested=True
+        )
+
+    @dist_init
     def test_worker_ids_recorded(self):
         dst_ranks = {rank for rank in range(self.world_size) if rank != self.rank}
         with dist_autograd.context() as context_id:
@@ -876,23 +988,27 @@ class DistAutogradTest(CommonDistAutogradTest):
                     worker_name(self._next_rank()), torch.matmul, args=(t1, t2)
                 )
 
-    @dist_init
-    def test_backward_no_grad_on_tensor(self):
-        t1 = torch.rand((3, 3), requires_grad=True)
-        t2 = torch.rand((3, 3), requires_grad=True)
+    def _backward_no_grad_on_tensor(self, t1, t2, sparse):
         with dist_autograd.context() as context_id:
             loss = rpc.rpc_sync(
                 worker_name(self._next_rank()),
                 torch.add,
-                args=(t1, t2)).sum()
-
+                args=(t1, t2))
+            if sparse:
+                loss = torch.sparse.sum(loss)
+            else:
+                loss = loss.sum()
             dist_autograd.backward(context_id, [loss], retain_graph=True)
             self.assertIsNone(t1.grad)
             self.assertIsNone(t2.grad)
 
             # Now populate .grad with local autograd engine and
             # verify dist autograd doesn't mess with it.
-            loss_local = torch.add(t1, t2).sum()
+            loss_local = torch.add(t1, t2)
+            if sparse:
+                loss_local = torch.sparse.sum(loss_local)
+            else:
+                loss_local = loss_local.sum()
             loss_local.backward()
             self.assertIsNotNone(t1.grad)
             self.assertIsNotNone(t2.grad)
@@ -903,18 +1019,34 @@ class DistAutogradTest(CommonDistAutogradTest):
             self.assertEqual(t1_grad_before, t1.grad)
             self.assertEqual(t2_grad_before, t2.grad)
 
-    def _test_backward_simple(self, dst):
-        # Run the same code locally and with dist autograd and verify gradients
-        # are same.
-        local_grads = None
-        t1 = torch.rand((3, 3), requires_grad=True)
-        t2 = torch.rand((3, 3), requires_grad=True)
+    @dist_init
+    def test_backward_no_grad_on_tensor(self):
+        self._backward_no_grad_on_tensor(
+            torch.rand((3, 3), requires_grad=True),
+            torch.rand((3, 3), requires_grad=True),
+            False
+        )
+
+    @dist_init
+    def test_backward_no_grad_on_tensor_sparse(self):
+        self._backward_no_grad_on_tensor(
+            build_sparse_tensor(requires_grad=True),
+            build_sparse_tensor(requires_grad=True),
+            True
+        )
+
+    # Run the same code locally and with dist autograd and verify gradients
+    # are same.
+    def _backward_simple(self, dst, t1, t2, local_grads, sparse):
         for exec_mode in [ExecMode.LOCAL, ExecMode.RPC_SYNC, ExecMode.REMOTE]:
             with dist_autograd.context() as context_id:
                 ret = self._exec_func_with_dst(
                     dst, exec_mode, torch.add, t1, t2
                 )
-                loss = ret.sum()
+                if sparse:
+                    loss = torch.sparse.sum(ret)
+                else:
+                    loss = ret.sum()
                 ret = self._verify_backwards(
                     exec_mode, [loss], context_id, local_grads, t1, t2
                 )
@@ -922,29 +1054,65 @@ class DistAutogradTest(CommonDistAutogradTest):
 
     @dist_init
     def test_backward_simple(self):
-        self._test_backward_simple(self._next_rank())
+        self._backward_simple(
+            self._next_rank(),
+            torch.rand((3, 3), requires_grad=True),
+            torch.rand((3, 3), requires_grad=True),
+            None,
+            False
+        )
+
+    @dist_init
+    def test_backward_simple_sparse(self):
+        self._backward_simple(
+            self._next_rank(),
+            build_sparse_tensor(requires_grad=True),
+            build_sparse_tensor(requires_grad=True),
+            None,
+            True
+        )
 
     @dist_init
     def test_backward_simple_self(self):
-        self._test_backward_simple(self.rank)
+        self._backward_simple(
+            self.rank,
+            torch.rand((3, 3), requires_grad=True),
+            torch.rand((3, 3), requires_grad=True),
+            None,
+            False
+        )
+
+    @dist_init
+    def test_backward_simple_self_sparse(self):
+        self._backward_simple(
+            self.rank,
+            build_sparse_tensor(requires_grad=True),
+            build_sparse_tensor(requires_grad=True),
+            None,
+            True
+        )
 
     # The current rank first creates a tensor on the rref_owner, and then passes
     # the rref with another tensor to the callee to run either my_rref_add or
     # my_nested_rref_add, depending on whether the callee is the rref owner.
     # The grad of the tensor lives on the current rank, and the grad of the rref
     # tensor lives on the rref owner.
-    def _test_backward_rref(self, callee, rref_owner):
-        local_grads = None
-        t1 = torch.ones((3, 3), requires_grad=True)
-        t2 = torch.zeros((3, 3), requires_grad=True)
-
+    def _backward_rref(self, callee, rref_owner, t1, t2, local_grads, sparse):
         local_ret = torch.add(t1, t2)
-        local_ret.sum().backward()
+        if sparse:
+            local_ret = torch.sparse.sum(local_ret)
+        else:
+            local_ret = local_ret.sum()
+        local_ret.backward()
         with dist_autograd.context() as context_id:
-            rref_t1 = rpc.remote(
-                rref_owner, _torch_ones, args=((3, 3),), kwargs={"requires_grad": True}
-            )
-
+            if sparse:
+                rref_t1 = rpc.remote(
+                    rref_owner, build_sparse_tensor, args=(False, True,)
+                )
+            else:
+                rref_t1 = rpc.remote(
+                    rref_owner, _torch_ones, args=((3, 3),), kwargs={"requires_grad": True}
+                )
             if callee == rref_owner:
                 rref = rpc.remote(callee, my_rref_add, args=(rref_t1, t2))
             else:
@@ -952,7 +1120,11 @@ class DistAutogradTest(CommonDistAutogradTest):
                     callee, my_nested_rref_add, args=(rref_owner, rref_t1, t2)
                 )
             ret = rref.to_here()
-            dist_autograd.backward(context_id, [ret.sum()])
+            if sparse:
+                ret = torch.sparse.sum(ret)
+            else:
+                ret = ret.sum()
+            dist_autograd.backward(context_id, [ret])
 
             # verify grads on caller
             grads = dist_autograd.get_gradients(context_id)
@@ -972,20 +1144,81 @@ class DistAutogradTest(CommonDistAutogradTest):
     def test_backward_rref(self):
         callee = worker_name(self._next_rank())
         rref_owner = callee
-        self._test_backward_rref(callee, rref_owner)
+        self._backward_rref(
+            callee,
+            rref_owner,
+            torch.rand((3, 3), requires_grad=True),
+            torch.rand((3, 3), requires_grad=True),
+            None,
+            False
+        )
+
+    @dist_init
+    def test_backward_rref_sparse(self):
+        callee = worker_name(self._next_rank())
+        rref_owner = callee
+        self._backward_rref(
+            callee,
+            rref_owner,
+            build_sparse_tensor(requires_grad=True),
+            build_sparse_tensor(requires_grad=True),
+            None,
+            True
+        )
 
     @dist_init
     def test_backward_rref_multi(self):
         if self.rank > 0:
             callee = "worker0"
             rref_owner = callee
-            self._test_backward_rref(callee, rref_owner)
+            self._backward_rref(
+                callee,
+                rref_owner,
+                torch.rand((3, 3), requires_grad=True),
+                torch.rand((3, 3), requires_grad=True),
+                None,
+                False
+            )
+
+    @dist_init
+    def test_backward_rref_multi_sparse(self):
+        if self.rank > 0:
+            callee = "worker0"
+            rref_owner = callee
+            self._backward_rref(
+                callee,
+                rref_owner,
+                build_sparse_tensor(requires_grad=True),
+                build_sparse_tensor(requires_grad=True),
+                None,
+                True
+            )
 
     @dist_init
     def test_backward_rref_nested(self):
         callee = worker_name((self.rank + 1) % self.world_size)
         rref_owner = worker_name((self.rank + 2) % self.world_size)
-        self._test_backward_rref(callee, rref_owner)
+        self._backward_rref(
+            callee,
+            rref_owner,
+            torch.rand((3, 3), requires_grad=True),
+            torch.rand((3, 3), requires_grad=True),
+            None,
+            False
+        )
+
+    @dist_init
+    def test_backward_rref_nested_sparse(self):
+        callee = worker_name((self.rank + 1) % self.world_size)
+        rref_owner = worker_name((self.rank + 2) % self.world_size)
+        self._backward_rref(
+            callee,
+            rref_owner,
+            build_sparse_tensor(requires_grad=True),
+            build_sparse_tensor(requires_grad=True),
+            None,
+            True
+        )
 
     # In this test, every rank will serve as a parameter server (ps) and a
     # driver, and then kick off trainers on the other three ranks. So, we have:
@@ -996,13 +1229,19 @@ class DistAutogradTest(CommonDistAutogradTest):
     #
     # These four test ps-trainer groups run on completely separate autograd
     # graphs, but they share the same set of underlying RpcAgents.
-    def _test_trainer_ps(self, create_ref_fn, trainer_fn):
-        local_grads = None
-        t1 = torch.ones((3, 3), requires_grad=True)
-        t2 = torch.zeros((3, 3), requires_grad=True)
+    def _test_trainer_ps(self, create_ref_fn, trainer_fn, sparse):
+        if sparse:
+            t1 = build_sparse_tensor(requires_grad=True)
+            t2 = build_sparse_tensor(requires_grad=True)
+        else:
+            t1 = torch.ones((3, 3), requires_grad=True)
+            t2 = torch.zeros((3, 3), requires_grad=True)
 
         local_ret = torch.add(t1, t2)
-        local_ret.sum().backward()
+        if sparse:
+            torch.sparse.sum(local_ret).backward()
+        else:
+            local_ret.sum().backward()
 
         # create rref on self
         rref_t1 = rpc.remote(
@@ -1018,7 +1257,7 @@ class DistAutogradTest(CommonDistAutogradTest):
                 rpc.rpc_async(
                     worker_name((self.rank + rank_diff) % self.world_size),
                     trainer_fn,
-                    args=(rref_t1, t2, worker_name(self.rank), rank_diff),
+                    args=(rref_t1, t2, worker_name(self.rank), rank_diff, sparse),
                 )
             )
 
@@ -1045,7 +1284,19 @@ class DistAutogradTest(CommonDistAutogradTest):
 
     @dist_init
     def test_trainer_ps(self):
-        self._test_trainer_ps(create_tensor, _run_trainer)
+        self._test_trainer_ps(
+            create_tensor,
+            _run_trainer,
+            False
+        )
+
+    @dist_init
+    def test_trainer_ps_sparse(self):
+        self._test_trainer_ps(
+            build_sparse_tensor,
+            _run_trainer,
+            True
+        )
 
     @dist_init
     def test_trainer_ps_torchscript_functions(self):
@@ -1056,17 +1307,9 @@ class DistAutogradTest(CommonDistAutogradTest):
         import torch.distributed.rpc.api as api
         api._ignore_rref_leak = True
 
-        self._test_trainer_ps(create_torchscript_tensor, _run_trainer_torchscript)
-
-    @dist_init
-    def test_backward_multiple_round_trips(self):
-        local_grads = None
-        t1 = torch.rand((3, 3), requires_grad=True)
-        t2 = torch.rand((3, 3))
-        t3 = torch.rand((3, 3), requires_grad=True)
-        t4 = torch.rand((3, 3))
-        t5 = torch.rand((3, 3), requires_grad=True)
+        self._test_trainer_ps(create_torchscript_tensor, _run_trainer_torchscript, False)
 
+    def _backward_multiple_round_trips(self, t1, t2, t3, t4, t5, local_grads, sparse):
         for exec_mode in [ExecMode.LOCAL, ExecMode.RPC_SYNC, ExecMode.REMOTE]:
             with dist_autograd.context() as context_id:
                 # Multiple RPCs between different nodes.
@@ -1074,9 +1317,14 @@ class DistAutogradTest(CommonDistAutogradTest):
                 val = self._exec_func(exec_mode, torch.mul, t3, val)
                 s1 = self._exec_func(exec_mode, torch.stack, (t4, val))
                 s2 = self._exec_func(exec_mode, torch.stack, (t5, val))
-                val = self._exec_func(exec_mode, torch.bmm, s1, s2)
-                val = self._exec_func(exec_mode, torch.matmul, val, val)
-                loss = val.sum()
+                if sparse:
+                    val = self._exec_func(exec_mode, torch.mul, s1, s2)
+                    val = self._exec_func(exec_mode, torch.mul, val, val)
+                    loss = torch.sparse.sum(val)
+                else:
+                    val = self._exec_func(exec_mode, torch.bmm, s1, s2)
+                    val = self._exec_func(exec_mode, torch.matmul, val, val)
+                    loss = val.sum()
 
                 ret = self._verify_backwards(
                     exec_mode, [loss], context_id, local_grads, t1, t2, t3, t4, t5
@@ -1084,6 +1332,30 @@ class DistAutogradTest(CommonDistAutogradTest):
                 local_grads = ret if ret else local_grads
 
     @dist_init
+    def test_backward_multiple_round_trips(self):
+        self._backward_multiple_round_trips(
+            torch.rand((3, 3), requires_grad=True),
+            torch.rand((3, 3)),
+            torch.rand((3, 3), requires_grad=True),
+            torch.rand((3, 3)),
+            torch.rand((3, 3), requires_grad=True),
+            None,
+            False
+        )
+
+    @dist_init
+    def test_backward_multiple_round_trips_sparse(self):
+        self._backward_multiple_round_trips(
+            build_sparse_tensor(requires_grad=True),
+            build_sparse_tensor(requires_grad=False),
+            build_sparse_tensor(requires_grad=True),
+            build_sparse_tensor(requires_grad=False),
+            build_sparse_tensor(requires_grad=True),
+            None,
+            True
+        )
+
+    @dist_init
     def test_backward_different_tensor_dims(self):
         local_grads = None
         t1 = torch.rand((4, 6), requires_grad=True)
@@ -1317,41 +1589,70 @@ class DistAutogradTest(CommonDistAutogradTest):
                     exec_mode, [r1, r2, r3, r4], context_id, local_grads, t1, t2
                 )
 
-    @dist_init
-    def test_backward_different_dtypes(self):
+    def _backward_different_dtypes(self, t1, t2, sparse):
         local_grads = None
-        t1 = torch.rand((3, 3), requires_grad=True, dtype=torch.float32)
-        t2 = torch.rand((3, 3), requires_grad=True, dtype=torch.float64)
         for exec_mode in [ExecMode.LOCAL, ExecMode.REMOTE]:
             with dist_autograd.context() as context_id:
-                loss = self._exec_func(exec_mode, torch.add, t1, t2).sum()
-
+                loss = self._exec_func(exec_mode, torch.add, t1, t2)
+                if sparse:
+                    loss = torch.sparse.sum(loss)
+                else:
+                    loss = loss.sum()
                 local_grads = self._verify_backwards(
                     exec_mode, [loss], context_id, local_grads, t1, t2
                 )
 
     @dist_init
-    def test_backward_simple_python_udf(self):
-        # Run the same code locally and with dist autograd and verify gradients
-        # are same.
+    def test_backward_different_dtypes(self):
+        self._backward_different_dtypes(
+            torch.rand((3, 3), requires_grad=True, dtype=torch.float32),
+            torch.rand((3, 3), requires_grad=True, dtype=torch.float64),
+            False
+        )
+
+    @dist_init
+    def test_backward_different_dtypes_sparse(self):
+        self._backward_different_dtypes(
+            build_sparse_tensor(requires_grad=True, dtype=torch.float32),
+            build_sparse_tensor(requires_grad=True, dtype=torch.float64),
+            True
+        )
+
+    # Run the same code locally and with dist autograd and verify gradients
+    # are same.
+    def _backward_simple_python_udf(self, t1, t2, sparse):
         local_grads = None
-        t1 = torch.rand((3, 3), requires_grad=True)
-        t2 = torch.rand((3, 3), requires_grad=True)
         for exec_mode in [ExecMode.LOCAL, ExecMode.REMOTE]:
             with dist_autograd.context() as context_id:
                 ret = self._exec_func(exec_mode, my_py_add, t1, t2)
-                loss = ret.sum()
+                if sparse:
+                    loss = torch.sparse.sum(ret)
+                else:
+                    loss = ret.sum()
                 local_grads = self._verify_backwards(
                     exec_mode, [loss], context_id, local_grads, t1, t2
                 )
 
     @dist_init
-    def test_backward_simple_script_call(self):
-        # Run the same code locally and with dist autograd and verify gradients
-        # are same.
+    def test_backward_simple_python_udf(self):
+        self._backward_simple_python_udf(
+            torch.rand(3, 3, requires_grad=True),
+            torch.rand(3, 3, requires_grad=True),
+            False
+        )
+
+    @dist_init
+    def test_backward_simple_python_udf_sparse(self):
+        self._backward_simple_python_udf(
+            build_sparse_tensor(requires_grad=True),
+            build_sparse_tensor(requires_grad=True),
+            True
+        )
+
+    # Run the same code locally and with dist autograd and verify gradients
+    # are same.
+    def _backward_simple_script_call(self, t1, t2, sparse):
         local_grads = None
-        t1 = torch.rand((3, 3), requires_grad=True)
-        t2 = torch.rand((3, 3), requires_grad=True)
         for exec_mode in [
             ExecMode.LOCAL,
             ExecMode.RPC_SYNC,
@@ -1360,12 +1661,31 @@ class DistAutogradTest(CommonDistAutogradTest):
         ]:
             with dist_autograd.context() as context_id:
                 forward_ret = self._exec_func(exec_mode, my_script_add, t1, t2)
-                loss = forward_ret.sum()
+                if sparse:
+                    loss = torch.sparse.sum(forward_ret)
+                else:
+                    loss = forward_ret.sum()
                 ret = self._verify_backwards(
                     exec_mode, [loss], context_id, local_grads, t1, t2
                 )
                 local_grads = ret if ret else local_grads
 
+    @dist_init
+    def test_backward_simple_script_call(self):
+        self._backward_simple_script_call(
+            torch.rand(3, 3, requires_grad=True),
+            torch.rand(3, 3, requires_grad=True),
+            False
+        )
+
+    @dist_init
+    def test_backward_simple_script_call_sparse(self):
+        self._backward_simple_script_call(
+            build_sparse_tensor(requires_grad=True),
+            build_sparse_tensor(requires_grad=True),
+            True
+        )
+
     @staticmethod
     def _complex_python_udf(t1, t2):
         t3 = torch.nn.functional.linear(t1, t2)
@@ -1474,17 +1794,17 @@ class DistAutogradTest(CommonDistAutogradTest):
         t3 = t1 * t2
         t4 = t1 + t2
         res = rpc.rpc_sync(worker_name(dst), my_py_add, args=(t3, t4))
-        return torch.linalg.multi_dot([t1, t2, t3, t4, res])
+        return t1 * t2 * t3 * t4 * res
 
-    @dist_init
-    def test_backwards_nested_python_udf(self):
-        # Run equivalent of _nested_python_udf locally.
-        t1 = torch.rand((3, 3), requires_grad=True)
-        t2 = torch.rand((3, 3), requires_grad=True)
+    def _backwards_nested_python_udf(self, t1, t2, sparse):
         t3 = t1 * t2
         t4 = t1 + t2
         res = t3 + t4
-        loss = torch.linalg.multi_dot([t1, t2, t3, t4, res]).sum()
+        loss = t1 * t2 * t3 * t4 * res
+        if sparse:
+            loss = torch.sparse.sum(loss)
+        else:
+            loss = loss.sum()
         torch.autograd.backward([loss])
 
         # Now run distributed autograd.
@@ -1494,12 +1814,33 @@ class DistAutogradTest(CommonDistAutogradTest):
                 DistAutogradTest._nested_python_udf,
                 args=(t1, t2, self._next_rank()),
             )
-            dist_autograd.backward(context_id, [loss.sum()])
-
+            if sparse:
+                loss = torch.sparse.sum(loss)
+            else:
+                loss = loss.sum()
+            dist_autograd.backward(context_id, [loss])
             grads = dist_autograd.get_gradients(context_id)
             self.assertEqual(t1.grad, grads[t1])
             self.assertEqual(t2.grad, grads[t2])
 
+    @dist_init
+    def test_backwards_nested_python_udf(self):
+        # Run equivalent of _nested_python_udf locally.
+        self._backwards_nested_python_udf(
+            torch.rand(3, 3, requires_grad=True),
+            torch.rand(3, 3, requires_grad=True),
+            False
+        )
+
+    @dist_init
+    def test_backwards_nested_python_udf_sparse(self):
+        # Run equivalent of _nested_python_udf locally.
+        self._backwards_nested_python_udf(
+            build_sparse_tensor(requires_grad=True),
+            build_sparse_tensor(requires_grad=True),
+            True
+        )
+
     _test_clean_context_backward_context_id = None
 
     class MyBackwardFunc(Function):
@@ -1594,8 +1935,7 @@ class DistAutogradTest(CommonDistAutogradTest):
     def _get_grad(cls, embedding_rref, context_id):
         embedding = embedding_rref.local_value()
         grad_map = dist_autograd.get_gradients(context_id)
-        # Can't send sparse tensors over RPC: https://github.com/pytorch/pytorch/issues/30807
-        return grad_map[embedding.weight].to_dense()
+        return grad_map[embedding.weight]
 
     @dist_init
     def test_embedding_bag_with_no_grad_tensors(self):
@@ -1637,26 +1977,27 @@ class DistAutogradTest(CommonDistAutogradTest):
                 args=(remote_embedding, context_id),
             )
 
-            self.assertEqual(local_grad.to_dense(), remote_grad)
+            self.assertEqual(local_grad, remote_grad)
 
     @classmethod
-    def _mixed_requires_grad(cls, t1, t2):
+    def _mixed_requires_grad_operation(cls, t1, t2):
         if t2.requires_grad:
             return t1 - t2
         else:
             return t1 * t2
 
-    @dist_init
-    def test_mixed_requires_grad(self):
+    def _mixed_requires_grad(self, t1, t2, sparse):
         for exec_mode in [ExecMode.RPC_SYNC, ExecMode.REMOTE]:
-            t1 = torch.rand((3, 3), requires_grad=True)
-            t2 = torch.rand((3, 3), requires_grad=False)
             with dist_autograd.context() as context_id:
                 ret = self._exec_func(
-                    exec_mode, DistAutogradTest._mixed_requires_grad, t1, t2
+                    exec_mode, DistAutogradTest._mixed_requires_grad_operation, t1, t2
                 )
                 self.assertEqual(t1 * t2, ret)
-                dist_autograd.backward(context_id, [ret.sum()])
+                if sparse:
+                    loss = torch.sparse.sum(ret)
+                else:
+                    loss = ret.sum()
+                dist_autograd.backward(context_id, [loss])
                 self.assertTrue(t1.requires_grad)
                 self.assertFalse(t2.requires_grad)
                 grads = dist_autograd.get_gradients(context_id)
@@ -1664,6 +2005,22 @@ class DistAutogradTest(CommonDistAutogradTest):
                 self.assertNotIn(t2, grads)
                 self.assertEqual(t2, grads[t1])
 
+    @dist_init
+    def test_mixed_requires_grad(self):
+        self._mixed_requires_grad(
+            torch.rand(3, 3, requires_grad=True),
+            torch.rand(3, 3, requires_grad=False),
+            False
+        )
+
+    @dist_init
+    def test_mixed_requires_grad_sparse(self):
+        self._mixed_requires_grad(
+            build_sparse_tensor(requires_grad=True),
+            build_sparse_tensor(requires_grad=False),
+            True
+        )
+
     class TestDebugInfoFunc(Function):
         @staticmethod
         def forward(ctx, input):
@@ -1801,37 +2158,69 @@ class DistAutogradTest(CommonDistAutogradTest):
 
     @staticmethod
     def _test_nested_backward_accumulate_grads(t1, t2, dst_rank):
-        return rpc.rpc_sync(worker_name(dst_rank), torch.matmul, args=(t1, t2))
+        return rpc.rpc_sync(worker_name(dst_rank), torch.add, args=(t1, t2))
 
-    @dist_init
-    def test_nested_backward_accumulate_grads(self):
-        t1 = torch.rand((3, 3), requires_grad=True)
-        t2 = torch.rand((3, 3), requires_grad=True)
+    def _nested_backward_accumulate_grads(self, t1, t2, sparse):
         with dist_autograd.context() as context_id:
-            loss = rpc.rpc_sync(
+            ret = rpc.rpc_sync(
                 worker_name(self._next_rank()),
                 DistAutogradTest._test_nested_backward_accumulate_grads,
                 args=(t1, t2, self._next_rank()),
-            ).sum()
-
+            )
+            if sparse:
+                loss = torch.sparse.sum(ret)
+            else:
+                loss = ret.sum()
             # Run backward twice.
             dist_autograd.backward(context_id, [loss], retain_graph=True)
             dist_autograd.backward(context_id, [loss])
 
     @dist_init
-    def test_multiple_backward(self):
-        t1 = torch.rand((3, 3), requires_grad=True)
-        t2 = torch.rand((3, 3), requires_grad=True)
+    def test_nested_backward_accumulate_grads(self):
+        self._nested_backward_accumulate_grads(
+            torch.rand(3, 3, requires_grad=True),
+            torch.rand(3, 3, requires_grad=True),
+            False
+        )
+
+    @dist_init
+    def test_nested_backward_accumulate_grads_sparse(self):
+        self._nested_backward_accumulate_grads(
+            build_sparse_tensor(requires_grad=True),
+            build_sparse_tensor(requires_grad=True),
+            True
+        )
+
+    def _multiple_backward(self, t1, t2, sparse):
         with dist_autograd.context() as context_id:
             loss = rpc.rpc_sync(
                 worker_name(self._next_rank()),
                 torch.add,
-                args=(t1, t2)).sum()
-
+                args=(t1, t2))
+            if sparse:
+                loss = torch.sparse.sum(loss)
+            else:
+                loss = loss.sum()
             # Run backward in a loop multiple times.
             for i in range(1000):
                 dist_autograd.backward(context_id, [loss], retain_graph=True)
 
+    @dist_init
+    def test_multiple_backward(self):
+        self._multiple_backward(
+            torch.rand(3, 3, requires_grad=True),
+            torch.rand(3, 3, requires_grad=True),
+            False
+        )
+
+    @dist_init
+    def test_multiple_backward_sparse(self):
+        self._multiple_backward(
+            build_sparse_tensor(requires_grad=True),
+            build_sparse_tensor(requires_grad=True),
+            True
+        )
+
     @dist_init(clean_shutdown=False)
     def test_multiple_backward_with_errors(self):
         initialize_pg(self.file_init_method, self.rank, self.world_size)
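
The rpc_test.py helpers below gain a coalesce flag; for context, a short local illustration of standard coalescing behavior (not part of this patch): duplicate indices are summed and the result is marked coalesced.

    import torch

    # Duplicate index (0, 2) appears twice; coalesce() sums its values.
    i = [[0, 0, 1], [2, 2, 0]]
    v = [1.0, 2.0, 3.0]
    t = torch.sparse_coo_tensor(i, v, (2, 3))
    assert not t.is_coalesced()
    c = t.coalesce()
    assert c.is_coalesced()
    assert c.to_dense()[0, 2] == 3.0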
diff --git a/torch/testing/_internal/distributed/rpc/rpc_test.py b/torch/testing/_internal/distributed/rpc/rpc_test.py
index 1a44ef6..e0ef915 100644
@@ -209,10 +209,13 @@ def add_rref_to_value(rref, value):
 def run_nested_pickle(pickle_cls_instance, tensor):
     return pickle_cls_instance.t + tensor
 
-def build_sparse_tensor():
+def build_sparse_tensor(coalesce=False):
     i = [[0, 1, 1], [2, 0, 2]]
     v = [3, 4, 5]
-    return torch.sparse_coo_tensor(i, v, (2, 3))
+    tensor = torch.sparse_coo_tensor(i, v, (2, 3))
+    if coalesce:
+        tensor = tensor.coalesce()
+    return tensor
 
 def build_complex_tensors():
     a = torch.ones(3, 3)
@@ -238,6 +241,12 @@ def my_function(a, b, c):
 def my_tensor_function(a, b):
     return a + b
 
+def my_container_sum(a):
+    result = a[0]
+    for tensor in a[1:]:
+        result += tensor
+    return result
+
 
 def my_sleep_func(seconds=1):
     time.sleep(seconds)
@@ -275,6 +284,14 @@ def nested_rpc(dst):
     return rpc.rpc_sync(dst, torch.add, args=(torch.ones(2, 2), 1))
 
 
+def nested_rpc_sparse(dst):
+    return rpc.rpc_sync(
+        dst,
+        torch.add,
+        args=(build_sparse_tensor(), build_sparse_tensor())
+    )
+
+
 def multi_layer_nested_async_rpc(dst, world_size, ttl):
     # this method returns immediately without blocking the callee, but will
     # generate additional requests.
@@ -296,10 +313,29 @@ def nested_rref(dst):
     )
 
 
+def nested_rref_sparse(dst):
+    return (
+        rpc.remote(
+            dst,
+            torch.add,
+            args=(build_sparse_tensor(), build_sparse_tensor())
+        ),
+        rpc.remote(
+            dst,
+            torch.add,
+            args=(build_sparse_tensor(), build_sparse_tensor())
+        ),
+    )
+
+
 def nested_remote(dst):
     rref = rpc.remote(dst, torch.add, args=(torch.ones(2, 2), 3))
     return rref.to_here()
 
+def nested_remote_sparse(dst):
+    rref = rpc.remote(dst, torch.add, args=(build_sparse_tensor(), build_sparse_tensor()))
+    return rref.to_here()
+
 
 def rref_forward_chain(dst, world_size, rref, ttl):
     if ttl > 0:
@@ -328,6 +364,12 @@ def heavy_rpc(tensor):
     return 0
 
 
+def heavy_rpc_sparse(tensor):
+    for i in range(1, 100):
+        tensor *= i
+        tensor = tensor / (i + 1)
+    return 0
+
 @torch.jit.script
 def heavy_rpc_torchscript(tensor):
     for i in range(1, 100):
@@ -600,6 +642,57 @@ class FooBackendOptions(rpc.RpcBackendOptions):
 load_tests = load_tests
 
 
+class MyEmbeddingBagModel(torch.nn.Module):
+    def __init__(self, sparse):
+        super().__init__()
+        self.eb = torch.nn.EmbeddingBag(
+            10,
+            10,
+            sparse=sparse
+        )
+
+    def forward(self, x):
+        return self.eb(x)
+
+
+class MyParameterServer:
+    def __init__(self, trainers):
+        self.lock = Lock()
+        self.trainers = trainers
+        self.iteration = 0
+        self.updates = 0
+        self.futures = []
+        self.total = None
+        self.gradient = None
+
+    @staticmethod
+    def get_gradient(rref):
+        return rref.local_value().gradient
+
+    @staticmethod
+    @rpc.functions.async_execution
+    def average(rref, riteration, tensor):
+        self = rref.local_value()
+        fut = torch.futures.Future()
+        with self.lock:
+            if riteration > self.iteration:
+                self.iteration = riteration
+                self.updates = 0
+                self.futures.clear()
+            self.futures.append(fut)
+            if self.total is None:
+                self.total = tensor
+            else:
+                self.total += tensor
+            self.updates += 1
+            if self.trainers == self.updates:
+                self.gradient = self.total / float(self.trainers)
+                for fut in self.futures:
+                    result = self.total / float(self.trainers)
+                    fut.set_result(result)
+        return fut
+
+
 class RpcTest(RpcAgentTestFixture):
     @dist_init
     def test_worker_id(self):
@@ -641,10 +734,26 @@ class RpcTest(RpcAgentTestFixture):
     def test_send_to_rank(self):
         dst_rank = (self.rank + 1) % self.world_size
 
+        # Test dense tensor
         for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]:
             ret = self._run_func_in_mode(dst_rank, torch.add, exec_mode, args=(torch.ones(2, 2), 1))
             self.assertEqual(ret, torch.ones(2, 2) + 1)
 
+        # Test sparse tensor
+        for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]:
+            x = build_sparse_tensor()
+            y = build_sparse_tensor()
+            expected_tensor = (x + y)
+            ret = self._run_func_in_mode(dst_rank, torch.add, exec_mode, args=(x, y))
+            self.assertEqual(expected_tensor, ret)
+
+        for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]:
+            x = build_sparse_tensor(coalesce=True)
+            y = build_sparse_tensor(coalesce=True)
+            expected_tensor = (x + y)
+            ret = self._run_func_in_mode(dst_rank, torch.add, exec_mode, args=(x, y))
+            self.assertEqual(expected_tensor, ret)
+
         # Test invalid ranks
         for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]:
             with self.assertRaises(RuntimeError):
@@ -662,41 +771,120 @@ class RpcTest(RpcAgentTestFixture):
             with self.assertRaises(ValueError):
                 self._run_func_in_mode(dst_rank - 0.5, torch.add, exec_mode, args=(torch.ones(2, 2), 1))
 
+    def _self_py_udf_remote(self, worker_info, x, y, z):
+        rref = rpc.remote(worker_info, my_function, args=(x, y, z))
+        self.assertEqual(rref.to_here(), x + y + z)
+
     @dist_init
     def test_self_py_udf_remote(self):
-        self_worker_info = rpc.get_worker_info()
-        rref = rpc.remote(self_worker_info, my_function, args=(torch.ones(2, 2), 1, 3))
-        self.assertEqual(rref.to_here(), torch.ones(2, 2) + 1 + 3)
+        self._self_py_udf_remote(
+            rpc.get_worker_info(),
+            torch.ones(2, 2),
+            1,
+            3
+        )
+
+    @dist_init
+    def test_self_py_udf_remote_sparse(self):
+        self._self_py_udf_remote(
+            rpc.get_worker_info(),
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor()
+        )
+
 
-    def _test_self_remote_rref_as_rpc_arg(self, dst):
+    def _self_remote_rref_as_rpc_arg(self, dst, x, y, z):
         self_worker_info = rpc.get_worker_info()
-        rref = rpc.remote(self_worker_info, my_function, args=(torch.ones(2, 2), 1, 3))
-        fut = rpc.rpc_async(dst, add_rref_to_value, args=(rref, torch.ones(2, 2)))
-        ret = rpc.rpc_sync(dst, add_rref_to_value, args=(rref, torch.ones(2, 2) + 1))
-        self.assertEqual(ret, torch.ones(2, 2) + 1 + 3 + torch.ones(2, 2) + 1)
-        self.assertEqual(fut.wait(), torch.ones(2, 2) + 1 + 3 + torch.ones(2, 2))
+        rref = rpc.remote(self_worker_info, my_function, args=(x, y, z))
+        fut = rpc.rpc_async(dst, add_rref_to_value, args=(rref, x))
+        ret = rpc.rpc_sync(dst, add_rref_to_value, args=(rref, x + y))
+        self.assertEqual(ret, x + y + z + x + y)
+        self.assertEqual(fut.wait(), x + y + z + x)
 
     @dist_init
     def test_self_remote_rref_as_rpc_arg(self):
         dst = worker_name((self.rank + 1) % self.world_size)
-        self._test_self_remote_rref_as_rpc_arg(dst)
+        self._self_remote_rref_as_rpc_arg(
+            dst,
+            torch.ones(2, 2),
+            1,
+            3
+        )
+
+    @dist_init
+    def test_self_remote_rref_as_rpc_arg_sparse(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        self._self_remote_rref_as_rpc_arg(
+            dst,
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor()
+        )
 
     @dist_init
     def test_self_remote_rref_as_self_rpc_arg(self):
-        self._test_self_remote_rref_as_rpc_arg(rpc.get_worker_info())
+        self._self_remote_rref_as_rpc_arg(
+            rpc.get_worker_info(),
+            torch.ones(2, 2),
+            1,
+            3
+        )
 
-    def _test_self_remote_rref_as_remote_arg(self, dst):
+    @dist_init
+    def test_self_remote_rref_as_self_rpc_arg_sparse(self):
+        self._self_remote_rref_as_rpc_arg(
+            rpc.get_worker_info(),
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor()
+        )
+
+    def _self_remote_rref_as_remote_arg(self, dst, x, y, z):
         self_worker_info = rpc.get_worker_info()
-        rref = rpc.remote(self_worker_info, my_function, args=(torch.ones(2, 2), 1, 3))
-        ret_rref = rpc.remote(dst, add_rref_to_value, args=(rref, torch.ones(2, 2)))
+        rref = rpc.remote(self_worker_info, my_function, args=(x, y, z))
+        ret_rref = rpc.remote(dst, add_rref_to_value, args=(rref, x))
         self.assertEqual(
-            ret_rref.to_here(), torch.ones(2, 2) + 1 + 3 + torch.ones(2, 2)
+            ret_rref.to_here(), x + y + z + x
         )
 
     @dist_init
     def test_self_remote_rref_as_remote_arg(self):
         dst = worker_name((self.rank + 1) % self.world_size)
-        self._test_self_remote_rref_as_remote_arg(dst)
+        self._self_remote_rref_as_remote_arg(
+            dst,
+            torch.ones(2, 2),
+            1,
+            3
+        )
+
+    @dist_init
+    def test_self_remote_rref_as_remote_arg_sparse(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        self._self_remote_rref_as_remote_arg(
+            dst,
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor()
+        )
+
+    @dist_init
+    def test_self_remote_rref_as_self_remote_arg(self):
+        self._self_remote_rref_as_remote_arg(
+            rpc.get_worker_info(),
+            torch.ones(2, 2),
+            1,
+            3
+        )
+
+    @dist_init
+    def test_self_remote_rref_as_self_remote_arg_sparse(self):
+        self._self_remote_rref_as_remote_arg(
+            rpc.get_worker_info(),
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor()
+        )
 
     @dist_init
     def test_rref_proxy_non_exist(self):
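The *_sparse tests above and below build their inputs with build_sparse_tensor(), a helper this change adds elsewhere in rpc_test.py (outside this hunk). A minimal sketch of such a helper, assuming a small fixed CPU COO tensor, could look like:

    import torch

    # Hypothetical stand-in for the build_sparse_tensor() helper the sparse
    # tests call; the real definition lives elsewhere in this file.
    def build_sparse_tensor():
        indices = [[0, 1, 1], [2, 0, 2]]   # 2 x nnz COO indices
        values = [3.0, 4.0, 5.0]           # nnz values
        return torch.sparse_coo_tensor(indices, values, (2, 3))

The tests only rely on +, scalar * and assertEqual being defined for the returned tensor, so any small CPU sparse COO tensor would do.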
@@ -816,10 +1004,6 @@ class RpcTest(RpcAgentTestFixture):
     def test_rref_proxy_class_self(self):
         self._test_rref_proxy_class(rpc.get_worker_info())
 
-    @dist_init
-    def test_self_remote_rref_as_self_remote_arg(self):
-        self._test_self_remote_rref_as_remote_arg(rpc.get_worker_info())
-
     @mock.patch.object(torch.distributed.autograd, "_init")
     @mock.patch.object(torch.distributed.rpc.api, "_set_and_start_rpc_agent")
     @dist_init(setup_rpc=False)
@@ -911,7 +1095,7 @@ class RpcTest(RpcAgentTestFixture):
             )
         rpc.shutdown()
 
-    def test_world_size_one(self):
+    def _world_size_one(self, a, b):
         if self.rank == 0:
             rpc.init_rpc(
                 name="me",
@@ -921,32 +1105,51 @@ class RpcTest(RpcAgentTestFixture):
                 rpc_backend_options=self.rpc_backend_options,
             )
 
-            expect = torch.ones(2, 2) * 2
-            result = rpc.rpc_sync(
-                "me",
-                my_tensor_function,
-                args=(torch.ones(2, 2), torch.ones(2, 2))
-            )
-            self.assertEqual(expect, result)
-
-            expect = torch.ones(3, 3) * 2
-            result = rpc.rpc_async(
-                "me",
-                my_tensor_function,
-                args=(torch.ones(3, 3), torch.ones(3, 3))
-            ).wait()
-            self.assertEqual(expect, result)
+            def _rpc_sync(x, y):
+                expect = x * 2
+                result = rpc.rpc_sync(
+                    "me",
+                    my_tensor_function,
+                    args=(x, y)
+                )
+                self.assertEqual(expect, result)
+
+            def _rpc_async(x, y):
+                expect = x * 2
+                result = rpc.rpc_async(
+                    "me",
+                    my_tensor_function,
+                    args=(x, y)
+                ).wait()
+                self.assertEqual(expect, result)
+
+            def _remote(x, y):
+                expect = x * 2
+                result = rpc.remote(
+                    "me",
+                    my_tensor_function,
+                    args=(x, y)
+                ).to_here()
+                self.assertEqual(expect, result)
 
-            expect = torch.ones(4, 4) * 2
-            result = rpc.remote(
-                "me",
-                my_tensor_function,
-                args=(torch.ones(4, 4), torch.ones(4, 4))
-            ).to_here()
-            self.assertEqual(expect, result)
+            _rpc_sync(a, b)
+            _rpc_async(a, b)
+            _remote(a, b)
 
             rpc.shutdown()
 
+    def test_world_size_one(self):
+        self._world_size_one(
+            torch.ones(2, 2),
+            torch.ones(2, 2)
+        )
+
+    def test_world_size_one_sparse(self):
+        self._world_size_one(
+            build_sparse_tensor(),
+            build_sparse_tensor()
+        )
+
     @dist_init(setup_rpc=False)
     def test_invalid_names(self):
         from torch.distributed.rpc import WorkerInfo
@@ -1027,17 +1230,30 @@ class RpcTest(RpcAgentTestFixture):
         ret = rpc.rpc_sync(worker_name(dst_rank), torch.nonzero, args=(x,))
         self.assertEqual(ret, x.nonzero())
 
-    @dist_init
-    def test_multi_rpc(self):
+    def _multi_rpc(self, sparse):
         dst_rank = (self.rank + 1) % self.world_size
         for i in range(20):
             n = i + self.rank + 1
+            if sparse:
+                x = build_sparse_tensor() * n
+                y = build_sparse_tensor() * n
+            else:
+                x = torch.ones(n, n)
+                y = torch.ones(n, n)
             ret = rpc.rpc_sync(
                 worker_name(dst_rank),
                 torch.add,
-                args=(torch.ones(n, n), torch.ones(n, n)),
+                args=(x, y),
             )
-            self.assertEqual(ret, torch.ones(n, n) * 2)
+            self.assertEqual(ret, x * 2)
+
+    @dist_init
+    def test_multi_rpc(self):
+        self._multi_rpc(False)
+
+    @dist_init
+    def test_multi_rpc_sparse(self):
+        self._multi_rpc(True)
 
     @dist_init
     def test_future_wait_twice(self):
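The sparse branch of _multi_rpc above scales its inputs with build_sparse_tensor() * n and checks the remote torch.add result against x * 2. A quick, self-contained illustration of the sparse COO arithmetic those assertions rely on (not taken from the patch):

    import torch

    x = torch.sparse_coo_tensor([[0, 1], [1, 0]], [1.0, 2.0], (2, 2))
    y = x * 3                       # scalar * scales the stored values only
    z = torch.add(x, x)             # what the remote torch.add returns
    assert torch.equal(z.to_dense(), (x * 2).to_dense())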
@@ -1053,7 +1269,7 @@ class RpcTest(RpcAgentTestFixture):
             with self.assertRaisesRegex(ValueError, "Expected error"):
                 fut.wait()
 
-    def _run_uneven_workload(self, num_repeat=30):
+    def _run_uneven_workload(self, f, x, num_repeat=30):
         # worker0 drives and waits for worker1 and worker2
         # throughout the test.
         if self.rank == 0:
@@ -1063,7 +1279,7 @@ class RpcTest(RpcAgentTestFixture):
             dst = "worker1"
             futs = []
             for _ in range(num_repeat):
-                fut = rpc.rpc_async(dst, heavy_rpc, args=(torch.ones(100, 100),))
+                fut = rpc.rpc_async(dst, f, args=(x,))
                 futs.append(fut)
 
             for fut in torch.futures.collect_all(futs).wait():
@@ -1075,13 +1291,13 @@ class RpcTest(RpcAgentTestFixture):
             dst = "worker2"
             futs = []
             for _ in range(num_repeat):
-                fut = rpc.rpc_async(dst, heavy_rpc, args=(torch.ones(100, 100),))
+                fut = rpc.rpc_async(dst, f, args=(x,))
                 futs.append(fut)
 
             for val in torch.futures.wait_all(futs):
                 self.assertEqual(val, 0)
 
-    def test_wait_all_workers(self):
+    def _wait_all_workers(self, f, x):
         initialize_pg(self.file_init_method, self.rank, self.world_size)
         rpc.init_rpc(
             name="worker%d" % self.rank,
@@ -1091,7 +1307,7 @@ class RpcTest(RpcAgentTestFixture):
             rpc_backend_options=self.rpc_backend_options,
         )
 
-        self._run_uneven_workload()
+        self._run_uneven_workload(f, x)
 
         # worker0 calls this at the end after waiting for RPC responses.
         # worker1/2 call this immediately and have some work after it.
@@ -1103,7 +1319,13 @@ class RpcTest(RpcAgentTestFixture):
         dist.barrier()
         rpc.shutdown(graceful=False)
 
-    def test_wait_all_workers_twice(self):
+    def test_wait_all_workers_dense(self):
+        self._wait_all_workers(heavy_rpc, torch.ones(100, 100))
+
+    def test_wait_all_workers_sparse(self):
+        self._wait_all_workers(heavy_rpc_sparse, build_sparse_tensor())
+
+    def _wait_all_workers_twice(self, f, x):
         initialize_pg(self.file_init_method, self.rank, self.world_size)
         rpc.init_rpc(
             name="worker%d" % self.rank,
@@ -1113,7 +1335,7 @@ class RpcTest(RpcAgentTestFixture):
             rpc_backend_options=self.rpc_backend_options,
         )
 
-        self._run_uneven_workload()
+        self._run_uneven_workload(f, x)
 
         # worker0 calls this at the end after waiting for RPC responses.
         # worker1/2 call this immediately and have some work after it.
@@ -1126,6 +1348,12 @@ class RpcTest(RpcAgentTestFixture):
         dist.barrier()
         rpc.shutdown(graceful=False)
 
+    def test_wait_all_workers_twice_dense(self):
+        self._wait_all_workers_twice(heavy_rpc, torch.ones(100, 100))
+
+    def test_wait_all_workers_twice_sparse(self):
+        self._wait_all_workers_twice(heavy_rpc_sparse, build_sparse_tensor())
+
     @dist_init
     def test_all_gather(self):
         info = rpc.get_worker_info()
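test_wait_all_workers_sparse and test_wait_all_workers_twice_sparse drive _run_uneven_workload with heavy_rpc_sparse, which is defined elsewhere in this file. Assuming it mirrors heavy_rpc (churn on the argument, then return 0 so the collect_all/wait_all checks above see zeros), a sketch would be:

    import torch

    # Hypothetical sketch of heavy_rpc_sparse; the real helper lives
    # elsewhere in rpc_test.py. Returning 0 is what the assertions expect.
    def heavy_rpc_sparse(tensor):
        for i in range(1, 100):
            tensor = tensor * i          # out-of-place ops keep it sparse
            tensor = tensor / (i + 1)
        return 0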
@@ -1211,7 +1439,7 @@ class RpcTest(RpcAgentTestFixture):
     @dist_init
     def test_graceful_shutdown_with_uneven_workload(self):
         """Test graceful termination."""
-        self._run_uneven_workload()
+        self._run_uneven_workload(heavy_rpc, torch.ones(100, 100))
 
     @dist_init(setup_rpc=False)
     def test_shutdown_followed_by_rpc(self):
@@ -2082,6 +2310,16 @@ class RpcTest(RpcAgentTestFixture):
         self.assertEqual(ret, my_complex_tensor_function(a, b, c))
 
     @dist_init
+    def test_py_sparse_tensors_in_container(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        a = [build_sparse_tensor(), build_sparse_tensor()]
+        ret = rpc.rpc_sync(
+            worker_name(dst_rank), my_container_sum, args=(a,)
+        )
+        self.assertEqual(ret, my_container_sum(a))
+
+    @dist_init
     def test_py_nested_pickle(self):
         n = self.rank + 1
         dst_rank = n % self.world_size
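test_py_sparse_tensors_in_container above round-trips a Python list of sparse tensors and compares the RPC result against my_container_sum run locally. The helper is defined elsewhere in this file; the test only needs it to reduce the container with +, e.g. along these lines:

    # Hypothetical sketch of my_container_sum; the real helper lives
    # elsewhere in rpc_test.py and may differ in detail.
    def my_container_sum(tensors):
        total = tensors[0]
        for t in tensors[1:]:
            total = total + t
        return total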
@@ -2137,16 +2375,23 @@ class RpcTest(RpcAgentTestFixture):
         else:
             self.assertTrue(False, "expected raise_func_escape to raise ValueError.")
 
-    @dist_init
-    def test_nested_rpc(self):
+    def _nested_rpc(self, f, expected):
         n = self.rank + 1
         dst_rank = n % self.world_size
         ret = rpc.rpc_sync(
             worker_name(dst_rank),
-            nested_rpc,
+            f,
             args=(worker_name(self.rank),),
         )
-        self.assertEqual(ret, torch.ones(2, 2) + 1)
+        self.assertEqual(ret, expected)
+
+    @dist_init
+    def test_nested_rpc(self):
+        self._nested_rpc(nested_rpc, torch.ones(2, 2) + 1)
+
+    @dist_init
+    def test_nested_rpc_sparse(self):
+        self._nested_rpc(nested_rpc_sparse, build_sparse_tensor() * 2)
 
     def _stress_test_rpc(self, f, repeat=1000, args=()):
         n = self.rank + 1
@@ -2175,30 +2420,64 @@ class RpcTest(RpcAgentTestFixture):
         self._stress_test_rpc(heavy_rpc, repeat=20, args=(torch.ones(100, 100),))
 
     @dist_init
+    def test_stress_heavy_rpc_sparse(self):
+        self._stress_test_rpc(heavy_rpc_sparse, repeat=20, args=(build_sparse_tensor(),))
+
+    @dist_init
     def test_stress_heavy_rpc_torchscript(self):
         self._stress_test_rpc(heavy_rpc_torchscript, repeat=20, args=(torch.ones(100, 100),))
 
-    @dist_init
-    def test_builtin_remote_ret(self):
+    def _builtin_remote_ret(self, x, y, expected):
         n = self.rank + 1
         dst_rank = n % self.world_size
         rref = rpc.remote(
             worker_name(dst_rank),
             torch.add,
-            args=(torch.ones(n, n), torch.ones(n, n)),
+            args=(x, y),
         )
-        self.assertEqual(rref.to_here(), torch.ones(n, n) * 2)
+        self.assertEqual(rref.to_here(), expected)
 
     @dist_init
-    def test_builtin_remote_self(self):
+    def test_builtin_remote_ret(self):
+        self._builtin_remote_ret(
+            torch.ones(2, 2),
+            torch.ones(2, 2),
+            torch.ones(2, 2) * 2
+        )
+
+    @dist_init
+    def test_builtin_remote_ret_sparse(self):
+        self._builtin_remote_ret(
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor() * 2
+        )
+
+    def _builtin_remote_self(self, x, y, expected):
         rref = rpc.remote(
             worker_name(self.rank),
             torch.add,
-            args=(torch.ones(2, 2), torch.ones(2, 2)),
+            args=(x, y),
+        )
+        self.assertEqual(rref.local_value(), expected)
+
+    @dist_init
+    def test_builtin_remote_self(self):
+        self._builtin_remote_self(
+            torch.ones(2, 2),
+            torch.ones(2, 2),
+            torch.ones(2, 2) * 2
+        )
+
+    @dist_init
+    def test_builtin_remote_self_sparse(self):
+        self._builtin_remote_self(
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor() * 2
         )
-        self.assertEqual(rref.local_value(), torch.ones(2, 2) * 2)
 
-    def _test_multi_remote_call(self, fn, args_fn=lambda x: (), kwargs_fn=lambda x: {}):
+    def _test_multi_remote_call(self, fn, sparse, args_fn=lambda x, y: (), kwargs_fn=lambda x, y: {}):
         m = 10
         n = self.rank + 1
         dst_rank = n % self.world_size
@@ -2210,21 +2489,35 @@ class RpcTest(RpcAgentTestFixture):
                 rpc.remote(
                     worker_name(dst_rank),
                     fn,
-                    args=args_fn(n),
-                    kwargs=kwargs_fn(n),
+                    args=args_fn(n, sparse),
+                    kwargs=kwargs_fn(n, sparse),
                 )
             )
-            expected.append(fn(*args_fn(n), **kwargs_fn(n)))
+            expected.append(fn(*args_fn(n, sparse), **kwargs_fn(n, sparse)))
 
         for i in range(m):
             self.assertEqual(rrefs[i].to_here(), expected[i])
 
+    @staticmethod
+    def _multi_args_fn(n, sparse=False):
+        if sparse:
+            return (build_sparse_tensor(), build_sparse_tensor())
+        else:
+            return (torch.ones(n, n), torch.ones(n, n))
+
     @dist_init
     def test_multi_builtin_remote_ret(self):
-        def args_fn(n):
-            return (torch.ones(n, n), torch.ones(n, n))
+        self._test_multi_remote_call(
+            torch.add, False,
+            args_fn=RpcTest._multi_args_fn
+        )
 
-        self._test_multi_remote_call(torch.add, args_fn=args_fn)
+    @dist_init
+    def test_multi_builtin_remote_ret_sparse(self):
+        self._test_multi_remote_call(
+            torch.add, True,
+            args_fn=RpcTest._multi_args_fn
+        )
 
     @dist_init
     def test_py_udf_remote(self):
@@ -2237,82 +2530,177 @@ class RpcTest(RpcAgentTestFixture):
         )
         self.assertEqual(rref.to_here(), my_function(n, n + 1, n + 2))
 
-    @dist_init
-    def test_multi_py_udf_remote(self):
-        def kwargs_fn(n):
+    @staticmethod
+    def _multi_kwargs_fn(n, sparse=False):
+        if sparse:
+            return {
+                "a": build_sparse_tensor(),
+                "b": build_sparse_tensor(),
+                "c": build_sparse_tensor()
+            }
+        else:
             return {"a": torch.ones(n, n), "b": torch.ones(n, n), "c": torch.ones(n, n)}
 
-        self._test_multi_remote_call(my_function, kwargs_fn=kwargs_fn)
+    @dist_init
+    def test_multi_py_udf_remote(self):
+        self._test_multi_remote_call(
+            my_function,
+            False,
+            kwargs_fn=RpcTest._multi_kwargs_fn
+        )
 
     @dist_init
-    def test_py_rref_args(self):
+    def test_multi_py_udf_remote_sparse(self):
+        self._test_multi_remote_call(
+            my_function,
+            True,
+            kwargs_fn=RpcTest._multi_kwargs_fn
+        )
+
+    def _py_rref_args(self, a, b, x, y, expected):
         n = self.rank + 1
         dst_rank = n % self.world_size
         rref_a = rpc.remote(
-            worker_name(dst_rank), torch.add, args=(torch.ones(n, n), 2)
+            worker_name(dst_rank), torch.add, args=(a, b)
         )
         rref_b = rpc.remote(
-            worker_name(dst_rank), torch.add, args=(torch.ones(n, n), 1)
+            worker_name(dst_rank), torch.add, args=(x, y)
         )
         rref_c = rpc.remote(
             worker_name(dst_rank), my_rref_function, args=(rref_a, rref_b)
         )
-        self.assertEqual(rref_c.to_here(), torch.ones(n, n) + 4)
+        self.assertEqual(rref_c.to_here(), expected)
 
     @dist_init
-    def test_py_rref_args_user_share(self):
+    def test_py_rref_args(self):
+        self._py_rref_args(
+            torch.ones(2, 2),
+            1,
+            torch.ones(2, 2),
+            2,
+            torch.ones(2, 2) * 2 + 3
+        )
+
+    @dist_init
+    def test_py_rref_args_sparse(self):
+        self._py_rref_args(
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor() * 4
+        )
+
+    def _py_rref_args_user_share(self, a, b, c, x, y, z, expected):
         n = self.rank + 1
         owner_rank = n % self.world_size
         user_rank = (n + 1) % self.world_size
         rref_a = rpc.remote(
-            worker_name(owner_rank), my_function, args=(torch.ones(n, n), 2, 0)
+            worker_name(owner_rank), my_function, args=(a, b, c)
         )
         rref_b = rpc.remote(
-            worker_name(owner_rank), my_function, args=(torch.ones(n, n), 1, 0)
+            worker_name(owner_rank), my_function, args=(x, y, z)
         )
         rref_c = rpc.remote(
             worker_name(user_rank), my_rref_function, args=(rref_a, rref_b)
         )
-        self.assertEqual(rref_c.to_here(), torch.ones(n, n) + 4)
+        self.assertEqual(rref_c.to_here(), expected)
 
     @dist_init
-    def test_py_rpc_rref_args(self):
+    def test_py_rref_args_user_share(self):
+        self._py_rref_args_user_share(
+            torch.ones(2, 2),
+            1,
+            2,
+            torch.ones(2, 2),
+            3,
+            4,
+            torch.ones(2, 2) * 2 + 10
+        )
+
+    @dist_init
+    def test_py_rref_args_user_share_sparse(self):
+        self._py_rref_args_user_share(
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor() * 6
+        )
+
+    def _py_rpc_rref_args(self, a, b, c, x, y, z, expected):
         n = self.rank + 1
         dst_rank = n % self.world_size
         rref_a = rpc.remote(
-            worker_name(dst_rank), my_function, args=(torch.ones(n, n), 2, 0)
+            worker_name(dst_rank), my_function, args=(a, b, c)
         )
         rref_b = rpc.remote(
-            worker_name(dst_rank), my_function, args=(torch.ones(n, n), 1, 0)
+            worker_name(dst_rank), my_function, args=(x, y, z)
         )
 
         c = rpc.rpc_sync(
             worker_name(dst_rank), my_rref_function, args=(rref_a, rref_b)
         )
+        self.assertEqual(c, expected)
 
-        self.assertEqual(c, torch.ones(n, n) + 4)
+    @dist_init
+    def test_py_rpc_rref_args(self):
+        self._py_rpc_rref_args(
+            torch.ones(2, 2),
+            1,
+            2,
+            torch.ones(2, 2),
+            3,
+            4,
+            torch.ones(2, 2) * 2 + 10
+        )
 
     @dist_init
-    def test_nested_remote(self):
+    def test_py_rpc_rref_args_sparse(self):
+        self._py_rpc_rref_args(
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor() * 6
+        )
+
+    def _nested_remote(self, f, expected):
         n = self.rank + 1
         dst_rank1 = n % self.world_size
         dst_rank2 = (n + 1) % self.world_size
 
         rref = rpc.remote(
             worker_name(dst_rank1),
-            nested_remote,
+            f,
             args=(worker_name(dst_rank2),),
         )
-        self.assertEqual(rref.to_here(), torch.ones(2, 2) + 3)
+        self.assertEqual(rref.to_here(), expected)
 
     @dist_init
-    def test_nested_rref(self):
+    def test_nested_remote(self):
+        self._nested_remote(
+            nested_remote,
+            torch.ones(2, 2) + 3
+        )
+
+    @dist_init
+    def test_nested_remote_sparse(self):
+        self._nested_remote(
+            nested_remote_sparse,
+            build_sparse_tensor() + build_sparse_tensor()
+        )
+
+    def _nested_rref(self, f, expected1, expected2):
         n = self.rank + 1
         dst_rank1 = n % self.world_size
         dst_rank2 = (n + 1) % self.world_size
         rref_of_rrefs = rpc.remote(
             worker_name(dst_rank1),
-            nested_rref,
+            f,
             args=(worker_name(dst_rank2),),
         )
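The expected values in this hunk (torch.ones(2, 2) * 2 + 3, torch.ones(2, 2) * 2 + 10, build_sparse_tensor() * 4, build_sparse_tensor() * 6) follow from the pre-existing helpers my_function and my_rref_function, whose assumed semantics are:

    def my_function(a, b, c):
        return a + b + c

    def my_rref_function(rref_a, rref_b):
        return rref_a.to_here() + rref_b.to_here()

For example, _py_rref_args_user_share(ones, 1, 2, ones, 3, 4, ...) produces (ones + 3) + (ones + 7) = ones * 2 + 10, matching the dense expectation above.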
 
@@ -2322,11 +2710,26 @@ class RpcTest(RpcAgentTestFixture):
         rrefs = rref_of_rrefs.to_here()
 
         self.assertEqual(len(rrefs), 2)
-        self.assertEqual(rrefs[0].to_here(), torch.ones(2, 2) + 1)
-        self.assertEqual(rrefs[1].to_here(), torch.ones(2, 2) + 2)
+        self.assertEqual(rrefs[0].to_here(), expected1)
+        self.assertEqual(rrefs[1].to_here(), expected2)
 
     @dist_init
-    def test_nested_rref_stress(self):
+    def test_nested_rref(self):
+        self._nested_rref(
+            nested_rref,
+            torch.ones(2, 2) + 1,
+            torch.ones(2, 2) + 2
+        )
+
+    @dist_init
+    def test_nested_rref_sparse(self):
+        self._nested_rref(
+            nested_rref_sparse,
+            build_sparse_tensor() * 2,
+            build_sparse_tensor() * 2
+        )
+
+    def _nested_rref_stress(self, f, expected1, expected2):
         n = self.rank + 1
         dst_rank1 = n % self.world_size
         dst_rank2 = (n + 1) % self.world_size
@@ -2335,7 +2738,7 @@ class RpcTest(RpcAgentTestFixture):
             all_rrefs.append(
                 rpc.remote(
                     worker_name(dst_rank1),
-                    nested_rref,
+                    f,
                     args=(worker_name(dst_rank2),),
                 )
             )
@@ -2344,8 +2747,24 @@ class RpcTest(RpcAgentTestFixture):
             rref_of_rrefs = all_rrefs[i]
             rrefs = rref_of_rrefs.to_here()
             self.assertEqual(len(rrefs), 2)
-            self.assertEqual(rrefs[0].to_here(), torch.ones(2, 2) + 1)
-            self.assertEqual(rrefs[1].to_here(), torch.ones(2, 2) + 2)
+            self.assertEqual(rrefs[0].to_here(), expected1)
+            self.assertEqual(rrefs[1].to_here(), expected2)
+
+    @dist_init
+    def test_nested_rref_stress(self):
+        self._nested_rref_stress(
+            nested_rref,
+            torch.ones(2, 2) + 1,
+            torch.ones(2, 2) + 2
+        )
+
+    @dist_init
+    def test_nested_rref_stress_sparse(self):
+        self._nested_rref_stress(
+            nested_rref_sparse,
+            build_sparse_tensor() * 2,
+            build_sparse_tensor() * 2
+        )
 
     @dist_init
     def test_multi_layer_nested_async_rpc(self):
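test_nested_rref_sparse and test_nested_rref_stress_sparse expect both inner RRefs to resolve to build_sparse_tensor() * 2. The nested_rref_sparse helper is defined elsewhere in this file; assuming it mirrors nested_rref by creating two remote adds on the second destination, a sketch could be:

    import torch
    import torch.distributed.rpc as rpc

    # Hypothetical sketch; reuses the build_sparse_tensor() helper sketched
    # above and returns a pair of RRefs, which is what the tests unpack.
    def nested_rref_sparse(dst):
        return (
            rpc.remote(dst, torch.add,
                       args=(build_sparse_tensor(), build_sparse_tensor())),
            rpc.remote(dst, torch.add,
                       args=(build_sparse_tensor(), build_sparse_tensor())),
        )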
@@ -4110,6 +4529,46 @@ class RpcTest(RpcAgentTestFixture):
 
         dist.barrier()
 
+    def _trainer_func(self, rref, sparse):
+        m = MyEmbeddingBagModel(sparse=sparse)
+        loss_fn = nn.MSELoss()
+        for i in range(10):
+            outputs = m(torch.rand(10, 10).long())
+            loss_fn(outputs, torch.rand(10, 10)).backward()
+            gradient = list(m.parameters())[0].grad
+            fut = rref.rpc_async().average(rref, i, gradient)
+            gradient = fut.wait()
+            if gradient.is_sparse:
+                gradient = gradient.to_dense().double()
+            ps_gradient = rref.rpc_sync().get_gradient(rref)
+            if ps_gradient.is_sparse:
+                ps_gradient = ps_gradient.to_dense().double()
+            self.assertTrue(torch.equal(gradient, ps_gradient))
+
+    def _my_parameter_server(self, sparse):
+        ps_rref = RRef(MyParameterServer(self.world_size - 1))
+        futures = []
+        for index in range(1, self.world_size):
+            futures.append(
+                rpc.rpc_async(
+                    worker_name((self.rank + index) % self.world_size),
+                    self._trainer_func,
+                    args=(
+                        ps_rref,
+                        sparse
+                    ),
+                )
+            )
+        torch.futures.wait_all(futures)
+
+    @dist_init
+    def test_my_parameter_server(self):
+        self._my_parameter_server(False)
+
+    @dist_init
+    def test_my_parameter_server_sparse(self):
+        self._my_parameter_server(True)
+
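_trainer_func above converts both gradients with to_dense().double() before comparing because, with sparse=True, nn.EmbeddingBag produces sparse COO gradients. A small self-contained illustration of that behaviour (independent of the MyEmbeddingBagModel and MyParameterServer helpers defined elsewhere in this file):

    import torch
    import torch.nn as nn

    emb = nn.EmbeddingBag(10, 10, sparse=True)
    out = emb(torch.randint(0, 10, (10, 10)))         # 10 bags of 10 indices
    nn.MSELoss()(out, torch.rand(10, 10)).backward()
    grad = emb.weight.grad
    assert grad.is_sparse                             # sparse COO gradient
    dense = grad.to_dense().double()                  # comparable via torch.equal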
 
 class CudaRpcTest(RpcAgentTestFixture):