From 7ebdbf82dccea370edda161936cc533c012e690a Mon Sep 17 00:00:00 2001
From: Garrett Cramer
Date: Sun, 29 Aug 2021 11:33:48 -0700
Subject: [PATCH] add support for sending cpu sparse tensors over rpc (#62794)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62794

This PR updates JIT serialization to support pickling sparse COO tensors, and
updates message.cpp to support sparse COO tensors as well. A bug for this was
filed a few years ago: https://github.com/pytorch/pytorch/issues/30807.

I tested the fix by adding sparse tensor tests to rpc_test.py and
dist_autograd_test.py.

cc pietern mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse agolynski SciPioneer H-Huang mrzzd cbalioglu gcramer23 gmagogsfm

Test Plan: Imported from OSS

Reviewed By: soulitzer

Differential Revision: D30608848

Pulled By: gcramer23

fbshipit-source-id: 629ba8e4a3d8365875a709c9b87447c7a71204fb
---
 torch/csrc/distributed/rpc/message.cpp              |  11 +-
 torch/csrc/jit/serialization/pickler.cpp            |  44 ++
 torch/csrc/jit/serialization/pickler.h              |   1 +
 torch/csrc/jit/serialization/unpickler.cpp          |  35 ++
 torch/csrc/jit/serialization/unpickler.h            |   1 +
 .../distributed/rpc/dist_autograd_test.py           | 653 ++++++++++++++++----
 torch/testing/_internal/distributed/rpc/rpc_test.py | 663 +++++++++++++++++----
 7 files changed, 1172 insertions(+), 236 deletions(-)

diff --git a/torch/csrc/distributed/rpc/message.cpp b/torch/csrc/distributed/rpc/message.cpp
index 0277114..7265ed4 100644
--- a/torch/csrc/distributed/rpc/message.cpp
+++ b/torch/csrc/distributed/rpc/message.cpp
@@ -68,10 +68,17 @@ void Message::setId(int64_t id) {
 std::vector<c10::weak_intrusive_ptr<c10::StorageImpl>> Message::getStorages() const {
+  // Sparse tensors do not have storage. Instead, a sparse tensor
+  // contains two tensors, indices and values, and both have storage.
   std::vector<c10::weak_intrusive_ptr<c10::StorageImpl>> storages;
-  storages.reserve(tensors_.size());
+  storages.reserve(2 * tensors_.size());
   for (const auto& tensor : tensors_) {
-    storages.emplace_back(tensor.storage().getWeakStorageImpl());
+    if (tensor.is_sparse()) {
+      storages.emplace_back(tensor._indices().storage().getWeakStorageImpl());
+      storages.emplace_back(tensor._values().storage().getWeakStorageImpl());
+    } else {
+      storages.emplace_back(tensor.storage().getWeakStorageImpl());
+    }
   }
   return storages;
 }
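For context, the getStorages() change works because a sparse COO tensor carries
no storage of its own; its two dense constituents do. A minimal Python
illustration of that layout (an aside for illustration, not part of the patch):

    import torch

    # A sparse COO tensor has no storage of its own (calling .storage() on it
    # raises); the backing memory lives in its indices and values tensors,
    # which is why getStorages() reserves 2 * tensors_.size() slots.
    s = torch.sparse_coo_tensor([[0, 1], [2, 0]], [3.0, 4.0], (3, 3))
    assert s.is_sparse
    indices_storage = s._indices().storage()  # storage of the indices tensor
    values_storage = s._values().storage()    # storage of the values tensor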
diff --git a/torch/csrc/jit/serialization/pickler.cpp b/torch/csrc/jit/serialization/pickler.cpp
index 4a4e866..f465eaf 100644
--- a/torch/csrc/jit/serialization/pickler.cpp
+++ b/torch/csrc/jit/serialization/pickler.cpp
@@ -353,6 +353,44 @@ void Pickler::pushTensor(const IValue& ivalue) {
   }
 }
 
+void Pickler::pushLiteralSparseTensor(const at::Tensor& tensor) {
+  pushGlobal("torch._utils", "_rebuild_sparse_tensor");
+  push(PickleOpCode::MARK);
+  // layout
+  auto layout = static_cast<int>(tensor.layout());
+  pushInt(layout);
+  switch (layout) {
+    case static_cast<int>(c10::Layout::Sparse):
+      // size
+      push(PickleOpCode::MARK);
+      for (auto size : tensor.sizes()) {
+        pushInt(size);
+      }
+      push(PickleOpCode::TUPLE);
+      // requires grad
+      pushIValue(tensor.requires_grad());
+      // indices
+      pushTensor(tensor._indices());
+      // values
+      pushTensor(tensor._values());
+      break;
+    default:
+      TORCH_CHECK(
+          false,
+          "Unsupported sparse tensor layout type in serialization ",
+          static_cast<int>(layout));
+      break;
+  }
+  // backward_hooks
+  pushGlobal("collections", "OrderedDict");
+  push(PickleOpCode::EMPTY_TUPLE);
+  // Construct the collections.OrderedDict for the backward_hooks
+  push(PickleOpCode::REDUCE);
+  push(PickleOpCode::TUPLE);
+  // Call torch._utils._rebuild_sparse_tensor
+  push(PickleOpCode::REDUCE);
+}
+
 void Pickler::pushLiteralTensor(const IValue& ivalue) {
   // In contrast to tensor references, literal tensors are included in the
   // pickle program binary blob. They are written to the file after the STOP
@@ -362,6 +400,12 @@ void Pickler::pushLiteralTensor(const IValue& ivalue) {
   // The format here is the same one used by `torch.save()`. The code for the
   // format can be found in `torch/serialization.py`.
   auto& tensor = ivalue.toTensor();
+
+  if (tensor.is_sparse() || tensor.is_sparse_csr()) {
+    pushLiteralSparseTensor(tensor);
+    return;
+  }
+
   bool quantized = tensor.is_quantized();
   // The arguments to this function are:
   //    storage, storage_offset, size, stride, requires_grad, backward_hooks
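The pickle program emitted above REDUCEs torch._utils._rebuild_sparse_tensor
with the tuple (layout, size, requires_grad, indices, values, backward_hooks).
A rough Python-side sketch of the rebuild this implies (illustration only; the
shipped helper in torch/_utils.py may spell its signature differently):

    import torch

    # Sketch of the rebuild implied by the pickle program above; the argument
    # order mirrors pushLiteralSparseTensor, and the reassembly mirrors the
    # at::_sparse_coo_tensor_unsafe call in the C++ unpickler further below.
    def rebuild_sparse_coo(layout, size, requires_grad, indices, values, backward_hooks):
        # indices and values are ordinary dense tensors, so they unpickle
        # through the normal tensor path; the sparse tensor is reassembled
        # around them without validation.
        t = torch._sparse_coo_tensor_unsafe(indices, values, size)
        return t.requires_grad_(requires_grad)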
diff --git a/torch/csrc/jit/serialization/pickler.h b/torch/csrc/jit/serialization/pickler.h
index ac54ac4..3dc6bef 100644
--- a/torch/csrc/jit/serialization/pickler.h
+++ b/torch/csrc/jit/serialization/pickler.h
@@ -172,6 +172,7 @@ class TORCH_API Pickler {
   void pushTensor(const IValue& ivalue);
   void pushTensorReference(const IValue& ivalue);
   void pushLiteralTensor(const IValue& ivalue);
+  void pushLiteralSparseTensor(const at::Tensor& tensor);
   void pushTuple(const IValue& ivalue);
   void pushString(const std::string& string);
   void pushDevice(const IValue& ivalue);
diff --git a/torch/csrc/jit/serialization/unpickler.cpp b/torch/csrc/jit/serialization/unpickler.cpp
index 581b949..f944387 100644
--- a/torch/csrc/jit/serialization/unpickler.cpp
+++ b/torch/csrc/jit/serialization/unpickler.cpp
@@ -550,6 +550,9 @@ void Unpickler::readGlobal(
     // Unpickle a tensor
     bool quantized = class_name == "_rebuild_qtensor";
     rebuildTensor(quantized);
+  } else if (
+      module_name == "torch._utils" && class_name == "_rebuild_sparse_tensor") {
+    rebuildSparseTensor();
   } else if (module_name == "builtins" && class_name == "complex") {
     globals_.emplace_back([this] {
       auto elems = pop(stack_).toTuple()->elements();
@@ -647,6 +650,38 @@ void Unpickler::readGlobal(
   stack_.emplace_back(int64_t(globals_.size() - 1));
 }
 
+void Unpickler::rebuildSparseTensor() {
+  globals_.emplace_back([this] {
+    auto tup = pop(stack_).toTuple();
+    const auto& elements = tup->elements();
+    size_t idx = 0;
+    auto layout = elements.at(idx++).toInt();
+    at::Tensor result;
+    switch (layout) {
+      case static_cast<int>(c10::Layout::Sparse): {
+        std::vector<int64_t> size = tupleToIntList(elements.at(idx++));
+        bool requires_grad = elements.at(idx++).toBool();
+        auto& indices_tensor = elements.at(idx++).toTensor();
+        auto& values_tensor = elements.at(idx++).toTensor();
+        auto options = values_tensor.options()
+                           .layout(c10::Layout::Sparse)
+                           .requires_grad(requires_grad);
+        result = at::_sparse_coo_tensor_unsafe(
+            indices_tensor, values_tensor, size, options);
+        result = autograd::make_variable(result, options.requires_grad());
+        break;
+      }
+      default:
+        TORCH_CHECK(
+            false,
+            "Unsupported sparse tensor layout type in serialization ",
+            static_cast<int>(layout));
+        break;
+    }
+    stack_.emplace_back(std::move(result));
+  });
+}
+
 void Unpickler::rebuildTensor(bool quantized) {
   globals_.emplace_back([this, quantized] {
     auto tup = pop(stack_).toTuple();
diff --git a/torch/csrc/jit/serialization/unpickler.h b/torch/csrc/jit/serialization/unpickler.h
index f404dee..586ff9c 100644
--- a/torch/csrc/jit/serialization/unpickler.h
+++ b/torch/csrc/jit/serialization/unpickler.h
@@ -108,6 +108,7 @@ class TORCH_API Unpickler {
       const std::string& module_name,
       const std::string& class_name);
   void rebuildTensor(bool quantized);
+  void rebuildSparseTensor();
 #ifdef USE_DISTRIBUTED
   void rebuildRRef();
 #endif
diff --git a/torch/testing/_internal/distributed/rpc/dist_autograd_test.py b/torch/testing/_internal/distributed/rpc/dist_autograd_test.py
index 017a61b..fba5030 100644
--- a/torch/testing/_internal/distributed/rpc/dist_autograd_test.py
+++ b/torch/testing/_internal/distributed/rpc/dist_autograd_test.py
@@ -64,13 +64,29 @@ def _torch_ones(sizes, requires_grad=False):
 # Verifies that the grad of the rref tensor equals the given grad.
 def _compare_owner_value(context_id, rref, grad):
     grads = dist_autograd.get_gradients(context_id)
-    return torch.equal(grads[rref.local_value()], grad)
+    x = grads[rref.local_value()]
+    if x.is_sparse:
+        assert grad.is_sparse
+        x = x.to_dense()
+        grad = grad.to_dense()
+    else:
+        assert not grad.is_sparse
+    return torch.equal(x, grad)
 
 
 def create_tensor():
     return torch.ones((3, 3), requires_grad=True)
 
 
+def build_sparse_tensor(coalesce=False, requires_grad=True, dtype=torch.float32):
+    i = [[0, 1, 1], [2, 0, 2]]
+    v = [3.2, 4.1, 5.3]
+    tensor = torch.sparse_coo_tensor(i, v, (3, 3), requires_grad=requires_grad, dtype=dtype)
+    if coalesce:
+        tensor = tensor.coalesce()
+    return tensor
+
+
 @torch.jit.script
 def create_torchscript_tensor() -> torch.Tensor:
     return torch.ones((3, 3)).requires_grad_()
@@ -143,20 +159,28 @@ def _all_contexts_cleaned_up(timeout_seconds=10):
 
 # This function creates a dist autograd context, runs rpc_sync on the given ps,
 # and then blocks until the ps has verified the grads are correctly accumulated.
-def _run_trainer(rref_t1, t2, ps, rank_diff):
+def _run_trainer(rref_t1, t2, ps, rank_diff, sparse):
     with dist_autograd.context() as context_id:
         ret = rpc.rpc_sync(ps, my_rref_add, args=(rref_t1, t2))
-        dist_autograd.backward(context_id, [ret.sum()])
+        if sparse:
+            loss = torch.sparse.sum(ret)
+        else:
+            loss = ret.sum()
+        dist_autograd.backward(context_id, [loss])
         # prevent deleting dist autograd context
         rpc.rpc_sync(ps, _set_rpc_done, args=(context_id, rank_diff))
         rpc.rpc_sync(ps, _check_rpc_done, args=(0,))
 
 # This function is the same as _run_trainer, except rpc calls torchscript
 # function "my_script_ref_add" instead of python function "my_rref_add"
-def _run_trainer_torchscript(rref_t1, t2, ps, rank_diff):
+def _run_trainer_torchscript(rref_t1, t2, ps, rank_diff, sparse):
     with dist_autograd.context() as context_id:
         ret = rpc.rpc_sync(ps, my_script_ref_add, args=(rref_t1, t2))
-        dist_autograd.backward(context_id, [ret.sum()])
+        if sparse:
+            loss = torch.sparse.sum(ret)
+        else:
+            loss = ret.sum()
+        dist_autograd.backward(context_id, [loss])
         # prevent deleting dist autograd context
         rpc.rpc_sync(ps, _set_rpc_done, args=(context_id, rank_diff))
         rpc.rpc_sync(ps, _check_rpc_done, args=(0,))
@@ -379,14 +403,18 @@ class DistAutogradTest(CommonDistAutogradTest):
             "torch::distributed::autograd::RecvRpcBackward", next_funcs[0][0].name()
         )
 
-    def _test_graph(self, fn, exec_mode):
+    def _test_graph(self, fn, exec_mode, sparse):
         dst_rank = (self.rank + 1) % self.world_size
 
         initialize_pg(self.file_init_method, self.rank, self.world_size)
 
         with dist_autograd.context() as context_id:
-            t1 = torch.ones(3, 3, requires_grad=True)
-            t2 = torch.zeros(3, 3, requires_grad=True)
+            if sparse:
+                t1 = build_sparse_tensor()
+                t2 = build_sparse_tensor()
+            else:
+                t1 = torch.ones(3, 3, requires_grad=True)
+                t2 = torch.zeros(3, 3, requires_grad=True)
             if ExecMode.RPC_SYNC == exec_mode:
                 ret = rpc.rpc_sync(worker_name(dst_rank), fn, args=(t1, t2))
             elif ExecMode.REMOTE == exec_mode:
@@ -436,29 +464,49 @@ class DistAutogradTest(CommonDistAutogradTest):
 
     @dist_init
     def test_graph_for_builtin_call(self):
-        self._test_graph(torch.add, ExecMode.RPC_SYNC)
+        self._test_graph(torch.add, ExecMode.RPC_SYNC, False)
+
+    @dist_init
+    def test_graph_for_builtin_call_sparse(self):
+        self._test_graph(torch.add, ExecMode.RPC_SYNC, True)
 
     @dist_init
     def test_graph_for_python_call(self):
-        self._test_graph(my_py_add, ExecMode.RPC_SYNC)
+        self._test_graph(my_py_add, ExecMode.RPC_SYNC, False)
+
+    @dist_init
+    def
test_graph_for_python_call_sparse(self): + self._test_graph(my_py_add, ExecMode.RPC_SYNC, True) @dist_init def test_graph_for_builtin_remote_call(self): - self._test_graph(torch.add, ExecMode.REMOTE) + self._test_graph(torch.add, ExecMode.REMOTE, False) + + @dist_init + def test_graph_for_builtin_remote_call_sparse(self): + self._test_graph(torch.add, ExecMode.REMOTE, True) @dist_init def test_graph_for_python_remote_call(self): - self._test_graph(my_py_add, ExecMode.REMOTE) + self._test_graph(my_py_add, ExecMode.REMOTE, False) + + @dist_init + def test_graph_for_python_remote_call_sparse(self): + self._test_graph(my_py_add, ExecMode.REMOTE, True) # 3-layer nested calls - def _test_graph_for_py_nested_call(self, exec_mode): + def _test_graph_for_py_nested_call(self, exec_mode, sparse): dst_rank = (self.rank + 1) % self.world_size initialize_pg(self.file_init_method, self.rank, self.world_size) with dist_autograd.context() as context_id: - t1 = torch.ones(3, 3, requires_grad=True) - t2 = torch.zeros(3, 3, requires_grad=True) + if sparse: + t1 = build_sparse_tensor(requires_grad=True) + t2 = build_sparse_tensor(requires_grad=True) + else: + t1 = torch.ones(3, 3, requires_grad=True) + t2 = torch.zeros(3, 3, requires_grad=True) nest_dst_rank = (dst_rank + 1) % self.world_size if ExecMode.RPC_SYNC == exec_mode: ret = rpc.rpc_sync( @@ -531,21 +579,33 @@ class DistAutogradTest(CommonDistAutogradTest): @dist_init def test_graph_for_py_nested_call(self): - self._test_graph_for_py_nested_call(ExecMode.RPC_SYNC) + self._test_graph_for_py_nested_call(ExecMode.RPC_SYNC, False) + + @dist_init + def test_graph_for_py_nested_call_sparse(self): + self._test_graph_for_py_nested_call(ExecMode.RPC_SYNC, True) @dist_init def test_graph_for_py_nested_remote_call(self): - self._test_graph_for_py_nested_call(ExecMode.REMOTE) + self._test_graph_for_py_nested_call(ExecMode.REMOTE, False) + + @dist_init + def test_graph_for_py_nested_remote_call_sparse(self): + self._test_graph_for_py_nested_call(ExecMode.REMOTE, True) # Rank0->Rank1->Rank0 - def _test_graph_for_py_nested_call_itself(self, exec_mode): + def _test_graph_for_py_nested_call_itself(self, exec_mode, sparse): dst_rank = (self.rank + 1) % self.world_size initialize_pg(self.file_init_method, self.rank, self.world_size) with dist_autograd.context() as context_id: - t1 = torch.ones(3, 3, requires_grad=True) - t2 = torch.zeros(3, 3, requires_grad=True) + if sparse: + t1 = build_sparse_tensor(requires_grad=True) + t2 = build_sparse_tensor(requires_grad=True) + else: + t1 = torch.ones(3, 3, requires_grad=True) + t2 = torch.zeros(3, 3, requires_grad=True) if ExecMode.RPC_SYNC == exec_mode: ret = rpc.rpc_sync( worker_name(dst_rank), @@ -610,18 +670,30 @@ class DistAutogradTest(CommonDistAutogradTest): @dist_init def test_graph_for_py_nested_call_itself(self): - self._test_graph_for_py_nested_call_itself(ExecMode.RPC_SYNC) + self._test_graph_for_py_nested_call_itself(ExecMode.RPC_SYNC, False) + + @dist_init + def test_graph_for_py_nested_call_itself_sparse(self): + self._test_graph_for_py_nested_call_itself(ExecMode.RPC_SYNC, True) @dist_init def test_graph_for_py_nested_remote_call_itself(self): - self._test_graph_for_py_nested_call_itself(ExecMode.REMOTE) + self._test_graph_for_py_nested_call_itself(ExecMode.REMOTE, False) + + @dist_init + def test_graph_for_py_nested_remote_call_itself_sparse(self): + self._test_graph_for_py_nested_call_itself(ExecMode.REMOTE, True) - def _test_no_graph_with_tensors_not_require_grad(self, exec_mode): + def 
_test_no_graph_with_tensors_not_require_grad(self, exec_mode, sparse): initialize_pg(self.file_init_method, self.rank, self.world_size) dst_rank = (self.rank + 1) % self.world_size with dist_autograd.context() as context_id: - t1 = torch.ones(3, 3, requires_grad=False) - t2 = torch.zeros(3, 3, requires_grad=False) + if sparse: + t1 = build_sparse_tensor(requires_grad=False) + t2 = build_sparse_tensor(requires_grad=False) + else: + t1 = torch.ones(3, 3, requires_grad=False) + t2 = torch.zeros(3, 3, requires_grad=False) if ExecMode.RPC_SYNC == exec_mode: ret = rpc.rpc_sync( worker_name(dst_rank), torch.add, args=(t1, t2) @@ -656,11 +728,19 @@ class DistAutogradTest(CommonDistAutogradTest): @dist_init def test_no_graph_with_tensors_not_require_grad(self): - self._test_no_graph_with_tensors_not_require_grad(ExecMode.RPC_SYNC) + self._test_no_graph_with_tensors_not_require_grad(ExecMode.RPC_SYNC, False) + + @dist_init + def test_no_graph_with_tensors_not_require_grad_sparse(self): + self._test_no_graph_with_tensors_not_require_grad(ExecMode.RPC_SYNC, True) @dist_init def test_no_graph_with_tensors_not_require_grad_remote(self): - self._test_no_graph_with_tensors_not_require_grad(ExecMode.REMOTE) + self._test_no_graph_with_tensors_not_require_grad(ExecMode.REMOTE, False) + + @dist_init + def test_no_graph_with_tensors_not_require_grad_remote_sparse(self): + self._test_no_graph_with_tensors_not_require_grad(ExecMode.REMOTE, True) def _test_grad_only_on_return_value(self, exec_mode): initialize_pg(self.file_init_method, self.rank, self.world_size) @@ -699,13 +779,16 @@ class DistAutogradTest(CommonDistAutogradTest): def test_grad_only_on_return_value_remote(self): self._test_grad_only_on_return_value(ExecMode.REMOTE) - def _test_rpc_complex_args(self, exec_mode): + def _test_rpc_complex_args(self, exec_mode, sparse): with dist_autograd.context() as context_id: num_tensors = 10 tensors = [] for i in range(num_tensors): - tensors.append(torch.ones(3, 3, requires_grad=(i % 2 == 0))) - + if sparse: + tensor = build_sparse_tensor(requires_grad=(i % 2 == 0)) + else: + tensor = torch.ones(3, 3, requires_grad=(i % 2 == 0)) + tensors.append(tensor) dst_rank = self._next_rank() if ExecMode.RPC_SYNC == exec_mode: ret = rpc.rpc_sync( @@ -739,11 +822,19 @@ class DistAutogradTest(CommonDistAutogradTest): @dist_init def test_rpc_complex_args(self): - self._test_rpc_complex_args(ExecMode.RPC_SYNC) + self._test_rpc_complex_args(ExecMode.RPC_SYNC, False) + + @dist_init + def test_rpc_complex_args_sparse(self): + self._test_rpc_complex_args(ExecMode.RPC_SYNC, True) @dist_init def test_remote_complex_args(self): - self._test_rpc_complex_args(ExecMode.REMOTE) + self._test_rpc_complex_args(ExecMode.REMOTE, False) + + @dist_init + def test_remote_complex_args_sparse(self): + self._test_rpc_complex_args(ExecMode.REMOTE, True) def context_cleanup_test_helper(self, rpc_args, func, nested=False): initialize_pg(self.file_init_method, self.rank, self.world_size) @@ -789,11 +880,22 @@ class DistAutogradTest(CommonDistAutogradTest): self.context_cleanup_test_helper(rpc_args=(t1, t2), func=torch.add) @dist_init + def test_context_cleanup_tensor_with_grad_sparse(self): + t1 = build_sparse_tensor(requires_grad=True) + t2 = build_sparse_tensor(requires_grad=True) + self.context_cleanup_test_helper(rpc_args=(t1, t2), func=torch.add) + + @dist_init def test_context_cleanup_tensor_no_grad(self): t1 = torch.ones(3, 3, requires_grad=False) self.context_cleanup_test_helper(rpc_args=(t1, t1), func=torch.add) @dist_init + def 
test_context_cleanup_tensor_no_grad_sparse(self): + t1 = build_sparse_tensor(requires_grad=False) + self.context_cleanup_test_helper(rpc_args=(t1, t1), func=torch.add) + + @dist_init def test_context_cleanup_no_tensors(self): self.context_cleanup_test_helper(rpc_args=(1, 1), func=my_scalar_add) @@ -808,6 +910,16 @@ class DistAutogradTest(CommonDistAutogradTest): ) @dist_init + def test_context_cleanup_nested_rpc_sparse(self): + t1 = build_sparse_tensor(requires_grad=True) + t2 = build_sparse_tensor(requires_grad=True) + dst_rank = (self.rank + 1) % self.world_size + args = (t1, t2, dst_rank, self.world_size, 0) + self.context_cleanup_test_helper( + rpc_args=args, func=my_py_nested_call, nested=True + ) + + @dist_init def test_worker_ids_recorded(self): dst_ranks = {rank for rank in range(self.world_size) if rank != self.rank} with dist_autograd.context() as context_id: @@ -876,23 +988,27 @@ class DistAutogradTest(CommonDistAutogradTest): worker_name(self._next_rank()), torch.matmul, args=(t1, t2) ) - @dist_init - def test_backward_no_grad_on_tensor(self): - t1 = torch.rand((3, 3), requires_grad=True) - t2 = torch.rand((3, 3), requires_grad=True) + def _backward_no_grad_on_tensor(self, t1, t2, sparse): with dist_autograd.context() as context_id: loss = rpc.rpc_sync( worker_name(self._next_rank()), torch.add, - args=(t1, t2)).sum() - + args=(t1, t2)) + if sparse: + loss = torch.sparse.sum(loss) + else: + loss = loss.sum() dist_autograd.backward(context_id, [loss], retain_graph=True) self.assertIsNone(t1.grad) self.assertIsNone(t2.grad) # Now populate .grad with local autograd engine and # verify dist autograd doesn't mess with it. - loss_local = torch.add(t1, t2).sum() + loss_local = torch.add(t1, t2) + if sparse: + loss_local = torch.sparse.sum(loss_local) + else: + loss_local = loss_local.sum() loss_local.backward() self.assertIsNotNone(t1.grad) self.assertIsNotNone(t2.grad) @@ -903,18 +1019,34 @@ class DistAutogradTest(CommonDistAutogradTest): self.assertEqual(t1_grad_before, t1.grad) self.assertEqual(t2_grad_before, t2.grad) - def _test_backward_simple(self, dst): - # Run the same code locally and with dist autograd and verify gradients - # are same. - local_grads = None - t1 = torch.rand((3, 3), requires_grad=True) - t2 = torch.rand((3, 3), requires_grad=True) + @dist_init + def test_backward_no_grad_on_tensor(self): + self._backward_no_grad_on_tensor( + torch.rand((3, 3), requires_grad=True), + torch.rand((3, 3), requires_grad=True), + False + ) + + @dist_init + def test_backward_no_grad_on_tensor_sparse(self): + self._backward_no_grad_on_tensor( + build_sparse_tensor(requires_grad=True), + build_sparse_tensor(requires_grad=True), + True + ) + + # Run the same code locally and with dist autograd and verify gradients + # are same. 
+ def _backward_simple(self, dst, t1, t2, local_grads, sparse): for exec_mode in [ExecMode.LOCAL, ExecMode.RPC_SYNC, ExecMode.REMOTE]: with dist_autograd.context() as context_id: ret = self._exec_func_with_dst( dst, exec_mode, torch.add, t1, t2 ) - loss = ret.sum() + if sparse: + loss = torch.sparse.sum(ret) + else: + loss = ret.sum() ret = self._verify_backwards( exec_mode, [loss], context_id, local_grads, t1, t2 ) @@ -922,29 +1054,65 @@ class DistAutogradTest(CommonDistAutogradTest): @dist_init def test_backward_simple(self): - self._test_backward_simple(self._next_rank()) + self._backward_simple( + self._next_rank(), + torch.rand((3, 3), requires_grad=True), + torch.rand((3, 3), requires_grad=True), + None, + False + ) + + @dist_init + def test_backward_simple_sparse(self): + self._backward_simple( + self._next_rank(), + build_sparse_tensor(requires_grad=True), + build_sparse_tensor(requires_grad=True), + None, + True + ) @dist_init def test_backward_simple_self(self): - self._test_backward_simple(self.rank) + self._backward_simple( + self.rank, + torch.rand((3, 3), requires_grad=True), + torch.rand((3, 3), requires_grad=True), + None, + False + ) + + @dist_init + def test_backward_simple_self_sparse(self): + self._backward_simple( + self.rank, + build_sparse_tensor(requires_grad=True), + build_sparse_tensor(requires_grad=True), + None, + True + ) # The current rank first creates a tensor on the rref_owner, and then passes # the rref with another tensor to the callee to run either my_rref_add or # my_nested_rref_add, depending on whether the callee is the rref owner. # The grad of tensor lives on the current rank, and the grad of the rref # tensor lives on the rref owner. - def _test_backward_rref(self, callee, rref_owner): - local_grads = None - t1 = torch.ones((3, 3), requires_grad=True) - t2 = torch.zeros((3, 3), requires_grad=True) - + def _backward_rref(self, callee, rref_owner, t1, t2, local_grads, sparse): local_ret = torch.add(t1, t2) - local_ret.sum().backward() + if sparse: + local_ret = torch.sparse.sum(local_ret) + else: + local_ret = local_ret.sum() + local_ret.backward() with dist_autograd.context() as context_id: - rref_t1 = rpc.remote( - rref_owner, _torch_ones, args=((3, 3),), kwargs={"requires_grad": True} - ) - + if sparse: + rref_t1 = rpc.remote( + rref_owner, build_sparse_tensor, args=(False, True,) + ) + else: + rref_t1 = rpc.remote( + rref_owner, _torch_ones, args=((3, 3),), kwargs={"requires_grad": True} + ) if callee == rref_owner: rref = rpc.remote(callee, my_rref_add, args=(rref_t1, t2)) else: @@ -952,7 +1120,11 @@ class DistAutogradTest(CommonDistAutogradTest): callee, my_nested_rref_add, args=(rref_owner, rref_t1, t2) ) ret = rref.to_here() - dist_autograd.backward(context_id, [ret.sum()]) + if sparse: + ret = torch.sparse.sum(ret) + else: + ret = ret.sum() + dist_autograd.backward(context_id, [ret]) # verify grads on caller grads = dist_autograd.get_gradients(context_id) @@ -972,20 +1144,81 @@ class DistAutogradTest(CommonDistAutogradTest): def test_backward_rref(self): callee = worker_name(self._next_rank()) rref_owner = callee - self._test_backward_rref(callee, rref_owner) + self._backward_rref( + callee, + rref_owner, + torch.rand((3, 3), requires_grad=True), + torch.rand((3, 3), requires_grad=True), + None, + False + ) + + @dist_init + def test_backward_rref_sparse(self): + callee = worker_name(self._next_rank()) + rref_owner = callee + self._backward_rref( + callee, + rref_owner, + build_sparse_tensor(requires_grad=True), + 
build_sparse_tensor(requires_grad=True), + None, + True + ) @dist_init def test_backward_rref_multi(self): if self.rank > 0: callee = "worker0" rref_owner = callee - self._test_backward_rref(callee, rref_owner) + self._backward_rref( + callee, + rref_owner, + torch.rand((3, 3), requires_grad=True), + torch.rand((3, 3), requires_grad=True), + None, + False + ) + + @dist_init + def test_backward_rref_multi_sparse(self): + if self.rank > 0: + callee = "worker0" + rref_owner = callee + self._backward_rref( + callee, + rref_owner, + build_sparse_tensor(requires_grad=True), + build_sparse_tensor(requires_grad=True), + None, + True + ) @dist_init def test_backward_rref_nested(self): callee = worker_name((self.rank + 1) % self.world_size) rref_owner = worker_name((self.rank + 2) % self.world_size) - self._test_backward_rref(callee, rref_owner) + self._backward_rref( + callee, + rref_owner, + torch.rand((3, 3), requires_grad=True), + torch.rand((3, 3), requires_grad=True), + None, + False + ) + + @dist_init + def test_backward_rref_nested_sparse(self): + callee = worker_name((self.rank + 1) % self.world_size) + rref_owner = worker_name((self.rank + 2) % self.world_size) + self._backward_rref( + callee, + rref_owner, + build_sparse_tensor(requires_grad=True), + build_sparse_tensor(requires_grad=True), + None, + True + ) # In this test, every rank will serve as a parameter server (ps) and a # driver, and then kicks off trainers on the other three ranks. So, we have: @@ -996,13 +1229,19 @@ class DistAutogradTest(CommonDistAutogradTest): # # These four test ps-trainer groups run on completely separate autograd # graphs, but they share the same set of underlying RpcAgents. - def _test_trainer_ps(self, create_ref_fn, trainer_fn): - local_grads = None - t1 = torch.ones((3, 3), requires_grad=True) - t2 = torch.zeros((3, 3), requires_grad=True) + def _test_trainer_ps(self, create_ref_fn, trainer_fn, sparse): + if sparse: + t1 = build_sparse_tensor(requires_grad=True) + t2 = build_sparse_tensor(requires_grad=True) + else: + t1 = torch.ones((3, 3), requires_grad=True) + t2 = torch.zeros((3, 3), requires_grad=True) local_ret = torch.add(t1, t2) - local_ret.sum().backward() + if sparse: + torch.sparse.sum(local_ret).backward() + else: + local_ret.sum().backward() # create rref on self rref_t1 = rpc.remote( @@ -1018,7 +1257,7 @@ class DistAutogradTest(CommonDistAutogradTest): rpc.rpc_async( worker_name((self.rank + rank_diff) % self.world_size), trainer_fn, - args=(rref_t1, t2, worker_name(self.rank), rank_diff), + args=(rref_t1, t2, worker_name(self.rank), rank_diff, sparse), ) ) @@ -1045,7 +1284,19 @@ class DistAutogradTest(CommonDistAutogradTest): @dist_init def test_trainer_ps(self): - self._test_trainer_ps(create_tensor, _run_trainer) + self._test_trainer_ps( + create_tensor, + _run_trainer, + False + ) + + @dist_init + def test_trainer_ps_sparse(self): + self._test_trainer_ps( + build_sparse_tensor, + _run_trainer, + True + ) @dist_init def test_trainer_ps_torchscript_functions(self): @@ -1056,17 +1307,9 @@ class DistAutogradTest(CommonDistAutogradTest): import torch.distributed.rpc.api as api api._ignore_rref_leak = True - self._test_trainer_ps(create_torchscript_tensor, _run_trainer_torchscript) - - @dist_init - def test_backward_multiple_round_trips(self): - local_grads = None - t1 = torch.rand((3, 3), requires_grad=True) - t2 = torch.rand((3, 3)) - t3 = torch.rand((3, 3), requires_grad=True) - t4 = torch.rand((3, 3)) - t5 = torch.rand((3, 3), requires_grad=True) + 
self._test_trainer_ps(create_torchscript_tensor, _run_trainer_torchscript, False) + def _backward_multiple_round_trips(self, t1, t2, t3, t4, t5, local_grads, sparse): for exec_mode in [ExecMode.LOCAL, ExecMode.RPC_SYNC, ExecMode.REMOTE]: with dist_autograd.context() as context_id: # Multiple RPCs between different nodes. @@ -1074,9 +1317,14 @@ class DistAutogradTest(CommonDistAutogradTest): val = self._exec_func(exec_mode, torch.mul, t3, val) s1 = self._exec_func(exec_mode, torch.stack, (t4, val)) s2 = self._exec_func(exec_mode, torch.stack, (t5, val)) - val = self._exec_func(exec_mode, torch.bmm, s1, s2) - val = self._exec_func(exec_mode, torch.matmul, val, val) - loss = val.sum() + if sparse: + val = self._exec_func(exec_mode, torch.mul, s1, s2) + val = self._exec_func(exec_mode, torch.mul, val, val) + loss = torch.sparse.sum(val) + else: + val = self._exec_func(exec_mode, torch.bmm, s1, s2) + val = self._exec_func(exec_mode, torch.matmul, val, val) + loss = val.sum() ret = self._verify_backwards( exec_mode, [loss], context_id, local_grads, t1, t2, t3, t4, t5 @@ -1084,6 +1332,30 @@ class DistAutogradTest(CommonDistAutogradTest): local_grads = ret if ret else local_grads @dist_init + def test_backward_multiple_round_trips(self): + self._backward_multiple_round_trips( + torch.rand((3, 3), requires_grad=True), + torch.rand((3, 3)), + torch.rand((3, 3), requires_grad=True), + torch.rand((3, 3)), + torch.rand((3, 3), requires_grad=True), + None, + False + ) + + @dist_init + def test_backward_multiple_round_trips_sparse(self): + self._backward_multiple_round_trips( + build_sparse_tensor(requires_grad=True), + build_sparse_tensor(requires_grad=False), + build_sparse_tensor(requires_grad=True), + build_sparse_tensor(requires_grad=False), + build_sparse_tensor(requires_grad=True), + None, + True + ) + + @dist_init def test_backward_different_tensor_dims(self): local_grads = None t1 = torch.rand((4, 6), requires_grad=True) @@ -1317,41 +1589,70 @@ class DistAutogradTest(CommonDistAutogradTest): exec_mode, [r1, r2, r3, r4], context_id, local_grads, t1, t2 ) - @dist_init - def test_backward_different_dtypes(self): + def _backward_different_dtypes(self, t1, t2, sparse): local_grads = None - t1 = torch.rand((3, 3), requires_grad=True, dtype=torch.float32) - t2 = torch.rand((3, 3), requires_grad=True, dtype=torch.float64) for exec_mode in [ExecMode.LOCAL, ExecMode.REMOTE]: with dist_autograd.context() as context_id: - loss = self._exec_func(exec_mode, torch.add, t1, t2).sum() - + loss = self._exec_func(exec_mode, torch.add, t1, t2) + if sparse: + loss = torch.sparse.sum(loss) + else: + loss = loss.sum() local_grads = self._verify_backwards( exec_mode, [loss], context_id, local_grads, t1, t2 ) @dist_init - def test_backward_simple_python_udf(self): - # Run the same code locally and with dist autograd and verify gradients - # are same. + def test_backward_different_dtypes(self): + self._backward_different_dtypes( + torch.rand((3, 3), requires_grad=True, dtype=torch.float32), + torch.rand((3, 3), requires_grad=True, dtype=torch.float64), + False + ) + + @dist_init + def test_backward_different_dtypes_sparse(self): + self._backward_different_dtypes( + build_sparse_tensor(requires_grad=True, dtype=torch.float32), + build_sparse_tensor(requires_grad=True, dtype=torch.float64), + True + ) + + # Run the same code locally and with dist autograd and verify gradients + # are same. 
+ def _backward_simple_python_udf(self, t1, t2, sparse): local_grads = None - t1 = torch.rand((3, 3), requires_grad=True) - t2 = torch.rand((3, 3), requires_grad=True) for exec_mode in [ExecMode.LOCAL, ExecMode.REMOTE]: with dist_autograd.context() as context_id: ret = self._exec_func(exec_mode, my_py_add, t1, t2) - loss = ret.sum() + if sparse: + loss = torch.sparse.sum(ret) + else: + loss = ret.sum() local_grads = self._verify_backwards( exec_mode, [loss], context_id, local_grads, t1, t2 ) @dist_init - def test_backward_simple_script_call(self): - # Run the same code locally and with dist autograd and verify gradients - # are same. + def test_backward_simple_python_udf(self): + self._backward_simple_python_udf( + torch.rand(3, 3, requires_grad=True), + torch.rand(3, 3, requires_grad=True), + False + ) + + @dist_init + def test_backward_simple_python_udf_sparse(self): + self._backward_simple_python_udf( + build_sparse_tensor(requires_grad=True), + build_sparse_tensor(requires_grad=True), + True + ) + + # Run the same code locally and with dist autograd and verify gradients + # are same. + def _backward_simple_script_call(self, t1, t2, sparse): local_grads = None - t1 = torch.rand((3, 3), requires_grad=True) - t2 = torch.rand((3, 3), requires_grad=True) for exec_mode in [ ExecMode.LOCAL, ExecMode.RPC_SYNC, @@ -1360,12 +1661,31 @@ class DistAutogradTest(CommonDistAutogradTest): ]: with dist_autograd.context() as context_id: forward_ret = self._exec_func(exec_mode, my_script_add, t1, t2) - loss = forward_ret.sum() + if sparse: + loss = torch.sparse.sum(forward_ret) + else: + loss = forward_ret.sum() ret = self._verify_backwards( exec_mode, [loss], context_id, local_grads, t1, t2 ) local_grads = ret if ret else local_grads + @dist_init + def test_backward_simple_script_call(self): + self._backward_simple_script_call( + torch.rand(3, 3, requires_grad=True), + torch.rand(3, 3, requires_grad=True), + False + ) + + @dist_init + def test_backward_simple_script_call_sparse(self): + self._backward_simple_script_call( + build_sparse_tensor(requires_grad=True), + build_sparse_tensor(requires_grad=True), + True + ) + @staticmethod def _complex_python_udf(t1, t2): t3 = torch.nn.functional.linear(t1, t2) @@ -1474,17 +1794,17 @@ class DistAutogradTest(CommonDistAutogradTest): t3 = t1 * t2 t4 = t1 + t2 res = rpc.rpc_sync(worker_name(dst), my_py_add, args=(t3, t4)) - return torch.linalg.multi_dot([t1, t2, t3, t4, res]) + return t1 * t2 * t3 * t4 * res - @dist_init - def test_backwards_nested_python_udf(self): - # Run equivalent of _nested_python_udf locally. - t1 = torch.rand((3, 3), requires_grad=True) - t2 = torch.rand((3, 3), requires_grad=True) + def _backwards_nested_python_udf(self, t1, t2, sparse): t3 = t1 * t2 t4 = t1 + t2 res = t3 + t4 - loss = torch.linalg.multi_dot([t1, t2, t3, t4, res]).sum() + loss = t1 * t2 * t3 * t4 * res + if sparse: + loss = torch.sparse.sum(loss) + else: + loss = loss.sum() torch.autograd.backward([loss]) # Now run distributed autograd. @@ -1494,12 +1814,33 @@ class DistAutogradTest(CommonDistAutogradTest): DistAutogradTest._nested_python_udf, args=(t1, t2, self._next_rank()), ) - dist_autograd.backward(context_id, [loss.sum()]) - + if sparse: + loss = torch.sparse.sum(loss) + else: + loss = loss.sum() + dist_autograd.backward(context_id, [loss]) grads = dist_autograd.get_gradients(context_id) self.assertEqual(t1.grad, grads[t1]) self.assertEqual(t2.grad, grads[t2]) + @dist_init + def test_backwards_nested_python_udf(self): + # Run equivalent of _nested_python_udf locally. 
+        self._backwards_nested_python_udf(
+            torch.rand(3, 3, requires_grad=True),
+            torch.rand(3, 3, requires_grad=True),
+            False
+        )
+
+    @dist_init
+    def test_backwards_nested_python_udf_sparse(self):
+        # Run equivalent of _nested_python_udf locally.
+        self._backwards_nested_python_udf(
+            build_sparse_tensor(requires_grad=True),
+            build_sparse_tensor(requires_grad=True),
+            True
+        )
+
     _test_clean_context_backward_context_id = None
 
     class MyBackwardFunc(Function):
@@ -1594,8 +1935,7 @@ class DistAutogradTest(CommonDistAutogradTest):
     def _get_grad(cls, embedding_rref, context_id):
         embedding = embedding_rref.local_value()
         grad_map = dist_autograd.get_gradients(context_id)
-        # Can't send sparse tensors over RPC: https://github.com/pytorch/pytorch/issues/30807
-        return grad_map[embedding.weight].to_dense()
+        return grad_map[embedding.weight]
 
     @dist_init
     def test_embedding_bag_with_no_grad_tensors(self):
@@ -1637,26 +1977,27 @@ class DistAutogradTest(CommonDistAutogradTest):
             args=(remote_embedding, context_id),
         )
 
-        self.assertEqual(local_grad.to_dense(), remote_grad)
+        self.assertEqual(local_grad, remote_grad)
 
     @classmethod
-    def _mixed_requires_grad(cls, t1, t2):
+    def _mixed_requires_grad_operation(cls, t1, t2):
         if t2.requires_grad:
             return t1 - t2
         else:
             return t1 * t2
 
-    @dist_init
-    def test_mixed_requires_grad(self):
+    def _mixed_requires_grad(self, t1, t2, sparse):
         for exec_mode in [ExecMode.RPC_SYNC, ExecMode.REMOTE]:
-            t1 = torch.rand((3, 3), requires_grad=True)
-            t2 = torch.rand((3, 3), requires_grad=False)
             with dist_autograd.context() as context_id:
                 ret = self._exec_func(
-                    exec_mode, DistAutogradTest._mixed_requires_grad, t1, t2
+                    exec_mode, DistAutogradTest._mixed_requires_grad_operation, t1, t2
                 )
                 self.assertEqual(t1 * t2, ret)
-                dist_autograd.backward(context_id, [ret.sum()])
+                if sparse:
+                    loss = torch.sparse.sum(ret)
+                else:
+                    loss = ret.sum()
+                dist_autograd.backward(context_id, [loss])
                 self.assertTrue(t1.requires_grad)
                 self.assertFalse(t2.requires_grad)
                 grads = dist_autograd.get_gradients(context_id)
@@ -1664,6 +2005,22 @@ class DistAutogradTest(CommonDistAutogradTest):
                 self.assertNotIn(t2, grads)
                 self.assertEqual(t2, grads[t1])
 
+    @dist_init
+    def test_mixed_requires_grad(self):
+        self._mixed_requires_grad(
+            torch.rand(3, 3, requires_grad=True),
+            torch.rand(3, 3, requires_grad=False),
+            False
+        )
+
+    @dist_init
+    def test_mixed_requires_grad_sparse(self):
+        self._mixed_requires_grad(
+            build_sparse_tensor(requires_grad=True),
+            build_sparse_tensor(requires_grad=False),
+            True
+        )
+
     class TestDebugInfoFunc(Function):
         @staticmethod
         def forward(ctx, input):
@@ -1801,37 +2158,69 @@ class DistAutogradTest(CommonDistAutogradTest):
 
     @staticmethod
     def _test_nested_backward_accumulate_grads(t1, t2, dst_rank):
-        return rpc.rpc_sync(worker_name(dst_rank), torch.matmul, args=(t1, t2))
+        return rpc.rpc_sync(worker_name(dst_rank), torch.add, args=(t1, t2))
 
-    @dist_init
-    def test_nested_backward_accumulate_grads(self):
-        t1 = torch.rand((3, 3), requires_grad=True)
-        t2 = torch.rand((3, 3), requires_grad=True)
+    def _nested_backward_accumulate_grads(self, t1, t2, sparse):
         with dist_autograd.context() as context_id:
-            loss = rpc.rpc_sync(
+            ret = rpc.rpc_sync(
                 worker_name(self._next_rank()),
                 DistAutogradTest._test_nested_backward_accumulate_grads,
                 args=(t1, t2, self._next_rank()),
-            ).sum()
-
+            )
+            if sparse:
+                loss = torch.sparse.sum(ret)
+            else:
+                loss = ret.sum()
            # Run backward twice.
dist_autograd.backward(context_id, [loss], retain_graph=True) dist_autograd.backward(context_id, [loss]) @dist_init - def test_multiple_backward(self): - t1 = torch.rand((3, 3), requires_grad=True) - t2 = torch.rand((3, 3), requires_grad=True) + def test_nested_backward_accumulate_grads(self): + self._nested_backward_accumulate_grads( + torch.rand(3, 3, requires_grad=True), + torch.rand(3, 3, requires_grad=True), + False + ) + + @dist_init + def test_nested_backward_accumulate_grads_sparse(self): + self._nested_backward_accumulate_grads( + build_sparse_tensor(requires_grad=True), + build_sparse_tensor(requires_grad=True), + True + ) + + def _multiple_backward(self, t1, t2, sparse): with dist_autograd.context() as context_id: loss = rpc.rpc_sync( worker_name(self._next_rank()), torch.add, - args=(t1, t2)).sum() - + args=(t1, t2)) + if sparse: + loss = torch.sparse.sum(loss) + else: + loss = loss.sum() # Run backward in a loop multiple times. for i in range(1000): dist_autograd.backward(context_id, [loss], retain_graph=True) + @dist_init + def test_multiple_backward(self): + self._multiple_backward( + torch.rand(3, 3, requires_grad=True), + torch.rand(3, 3, requires_grad=True), + False + ) + + @dist_init + def test_multiple_backward_sparse(self): + self._multiple_backward( + build_sparse_tensor(requires_grad=True), + build_sparse_tensor(requires_grad=True), + True + ) + @dist_init(clean_shutdown=False) def test_multiple_backward_with_errors(self): initialize_pg(self.file_init_method, self.rank, self.world_size) diff --git a/torch/testing/_internal/distributed/rpc/rpc_test.py b/torch/testing/_internal/distributed/rpc/rpc_test.py index 1a44ef6..e0ef915 100644 --- a/torch/testing/_internal/distributed/rpc/rpc_test.py +++ b/torch/testing/_internal/distributed/rpc/rpc_test.py @@ -209,10 +209,13 @@ def add_rref_to_value(rref, value): def run_nested_pickle(pickle_cls_instance, tensor): return pickle_cls_instance.t + tensor -def build_sparse_tensor(): +def build_sparse_tensor(coalesce=False): i = [[0, 1, 1], [2, 0, 2]] v = [3, 4, 5] - return torch.sparse_coo_tensor(i, v, (2, 3)) + tensor = torch.sparse_coo_tensor(i, v, (2, 3)) + if coalesce: + tensor = tensor.coalesce() + return tensor def build_complex_tensors(): a = torch.ones(3, 3) @@ -238,6 +241,12 @@ def my_function(a, b, c): def my_tensor_function(a, b): return a + b +def my_container_sum(a): + result = a[0] + for tensor in a[1:]: + result += tensor + return result + def my_sleep_func(seconds=1): time.sleep(seconds) @@ -275,6 +284,14 @@ def nested_rpc(dst): return rpc.rpc_sync(dst, torch.add, args=(torch.ones(2, 2), 1)) +def nested_rpc_sparse(dst): + return rpc.rpc_sync( + dst, + torch.add, + args=(build_sparse_tensor(), build_sparse_tensor()) + ) + + def multi_layer_nested_async_rpc(dst, world_size, ttl): # this method returns immediately without blocking the callee, but will # generate additional requests. 
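Before the rpc_test.py changes continue, a hypothetical end-to-end sketch of
what this patch enables (worker names, ranks, and init settings are assumptions
for illustration, not taken from the tests):

    import torch
    import torch.distributed.rpc as rpc

    # With this patch, CPU sparse COO tensors can be passed as RPC arguments
    # and returned as results, which issue #30807 originally asked for.
    def run(rank):
        rpc.init_rpc(f"worker{rank}", rank=rank, world_size=2)
        if rank == 0:
            x = torch.sparse_coo_tensor([[0, 1], [2, 0]], [3.0, 4.0], (3, 3))
            ret = rpc.rpc_sync("worker1", torch.add, args=(x, x))
            assert torch.equal(ret.to_dense(), (x + x).to_dense())
        rpc.shutdown()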
@@ -296,10 +313,29 @@ def nested_rref(dst): ) +def nested_rref_sparse(dst): + return ( + rpc.remote( + dst, + torch.add, + args=(build_sparse_tensor(), build_sparse_tensor()) + ), + rpc.remote( + dst, + torch.add, + args=(build_sparse_tensor(), build_sparse_tensor()) + ), + ) + + def nested_remote(dst): rref = rpc.remote(dst, torch.add, args=(torch.ones(2, 2), 3)) return rref.to_here() +def nested_remote_sparse(dst): + rref = rpc.remote(dst, torch.add, args=(build_sparse_tensor(), build_sparse_tensor())) + return rref.to_here() + def rref_forward_chain(dst, world_size, rref, ttl): if ttl > 0: @@ -328,6 +364,12 @@ def heavy_rpc(tensor): return 0 +def heavy_rpc_sparse(tensor): + for i in range(1, 100): + tensor *= i + tensor = tensor / (i + 1) + return 0 + @torch.jit.script def heavy_rpc_torchscript(tensor): for i in range(1, 100): @@ -600,6 +642,57 @@ class FooBackendOptions(rpc.RpcBackendOptions): load_tests = load_tests +class MyEmbeddingBagModel(torch.nn.Module): + def __init__(self, sparse): + super().__init__() + self.eb = torch.nn.EmbeddingBag( + 10, + 10, + sparse=sparse + ) + + def forward(self, x): + return self.eb(x) + + +class MyParameterServer: + def __init__(self, trainers): + self.lock = Lock() + self.trainers = trainers + self.iteration = 0 + self.updates = 0 + self.futures = [] + self.total = None + self.gradient = None + + @staticmethod + def get_gradient(rref): + return rref.local_value().gradient + + @staticmethod + @rpc.functions.async_execution + def average(rref, riteration, tensor): + self = rref.local_value() + fut = torch.futures.Future() + with self.lock: + if riteration > self.iteration: + self.iteration = riteration + self.updates = 0 + self.futures.clear() + self.futures.append(fut) + if self.total is None: + self.total = tensor + else: + self.total += tensor + self.updates += 1 + if self.trainers == self.updates: + self.gradient = self.total / float(self.trainers) + for fut in self.futures: + result = self.total / float(self.trainers) + fut.set_result(result) + return fut + + class RpcTest(RpcAgentTestFixture): @dist_init def test_worker_id(self): @@ -641,10 +734,26 @@ class RpcTest(RpcAgentTestFixture): def test_send_to_rank(self): dst_rank = (self.rank + 1) % self.world_size + # Test dense tensor for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]: ret = self._run_func_in_mode(dst_rank, torch.add, exec_mode, args=(torch.ones(2, 2), 1)) self.assertEqual(ret, torch.ones(2, 2) + 1) + # Test sparse tensor + for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]: + x = build_sparse_tensor() + y = build_sparse_tensor() + expected_tensor = (x + y) + ret = self._run_func_in_mode(dst_rank, torch.add, exec_mode, args=(x, y)) + self.assertEqual(expected_tensor, ret) + + for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]: + x = build_sparse_tensor(coalesce=True) + y = build_sparse_tensor(coalesce=True) + expected_tensor = (x + y) + ret = self._run_func_in_mode(dst_rank, torch.add, exec_mode, args=(x, y)) + self.assertEqual(expected_tensor, ret) + # Test invalid ranks for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]: with self.assertRaises(RuntimeError): @@ -662,41 +771,120 @@ class RpcTest(RpcAgentTestFixture): with self.assertRaises(ValueError): self._run_func_in_mode(dst_rank - 0.5, torch.add, exec_mode, args=(torch.ones(2, 2), 1)) + def _self_py_udf_remote(self, worker_info, x, y, z): + rref = rpc.remote(worker_info, my_function, args=(x, y, z)) + 
self.assertEqual(rref.to_here(), x + y + z) + @dist_init def test_self_py_udf_remote(self): - self_worker_info = rpc.get_worker_info() - rref = rpc.remote(self_worker_info, my_function, args=(torch.ones(2, 2), 1, 3)) - self.assertEqual(rref.to_here(), torch.ones(2, 2) + 1 + 3) + self._self_py_udf_remote( + rpc.get_worker_info(), + torch.ones(2, 2), + 1, + 3 + ) + + @dist_init + def test_self_py_udf_remote_sparse(self): + self._self_py_udf_remote( + rpc.get_worker_info(), + build_sparse_tensor(), + build_sparse_tensor(), + build_sparse_tensor() + ) + - def _test_self_remote_rref_as_rpc_arg(self, dst): + def _self_remote_rref_as_rpc_arg(self, dst, x, y, z): self_worker_info = rpc.get_worker_info() - rref = rpc.remote(self_worker_info, my_function, args=(torch.ones(2, 2), 1, 3)) - fut = rpc.rpc_async(dst, add_rref_to_value, args=(rref, torch.ones(2, 2))) - ret = rpc.rpc_sync(dst, add_rref_to_value, args=(rref, torch.ones(2, 2) + 1)) - self.assertEqual(ret, torch.ones(2, 2) + 1 + 3 + torch.ones(2, 2) + 1) - self.assertEqual(fut.wait(), torch.ones(2, 2) + 1 + 3 + torch.ones(2, 2)) + rref = rpc.remote(self_worker_info, my_function, args=(x, y, z)) + fut = rpc.rpc_async(dst, add_rref_to_value, args=(rref, x)) + ret = rpc.rpc_sync(dst, add_rref_to_value, args=(rref, x + y)) + self.assertEqual(ret, x + y + z + x + y) + self.assertEqual(fut.wait(), x + y + z + x) @dist_init def test_self_remote_rref_as_rpc_arg(self): dst = worker_name((self.rank + 1) % self.world_size) - self._test_self_remote_rref_as_rpc_arg(dst) + self._self_remote_rref_as_rpc_arg( + dst, + torch.ones(2, 2), + 1, + 3 + ) + + @dist_init + def test_self_remote_rref_as_rpc_arg_sparse(self): + dst = worker_name((self.rank + 1) % self.world_size) + self._self_remote_rref_as_rpc_arg( + dst, + build_sparse_tensor(), + build_sparse_tensor(), + build_sparse_tensor() + ) @dist_init def test_self_remote_rref_as_self_rpc_arg(self): - self._test_self_remote_rref_as_rpc_arg(rpc.get_worker_info()) + self._self_remote_rref_as_rpc_arg( + rpc.get_worker_info(), + torch.ones(2, 2), + 1, + 3 + ) - def _test_self_remote_rref_as_remote_arg(self, dst): + @dist_init + def test_self_remote_rref_as_self_rpc_arg_sparse(self): + self._self_remote_rref_as_rpc_arg( + rpc.get_worker_info(), + build_sparse_tensor(), + build_sparse_tensor(), + build_sparse_tensor() + ) + + def _self_remote_rref_as_remote_arg(self, dst, x, y, z): self_worker_info = rpc.get_worker_info() - rref = rpc.remote(self_worker_info, my_function, args=(torch.ones(2, 2), 1, 3)) - ret_rref = rpc.remote(dst, add_rref_to_value, args=(rref, torch.ones(2, 2))) + rref = rpc.remote(self_worker_info, my_function, args=(x, y, z)) + ret_rref = rpc.remote(dst, add_rref_to_value, args=(rref, x)) self.assertEqual( - ret_rref.to_here(), torch.ones(2, 2) + 1 + 3 + torch.ones(2, 2) + ret_rref.to_here(), x + y + z + x ) @dist_init def test_self_remote_rref_as_remote_arg(self): dst = worker_name((self.rank + 1) % self.world_size) - self._test_self_remote_rref_as_remote_arg(dst) + self._self_remote_rref_as_remote_arg( + dst, + torch.ones(2, 2), + 1, + 3 + ) + + @dist_init + def test_self_remote_rref_as_remote_arg_sparse(self): + dst = worker_name((self.rank + 1) % self.world_size) + self._self_remote_rref_as_remote_arg( + dst, + build_sparse_tensor(), + build_sparse_tensor(), + build_sparse_tensor() + ) + + @dist_init + def test_self_remote_rref_as_self_remote_arg(self): + self._self_remote_rref_as_remote_arg( + rpc.get_worker_info(), + torch.ones(2, 2), + 1, + 3 + ) + + @dist_init + def 
test_self_remote_rref_as_self_remote_arg_sparse(self): + self._self_remote_rref_as_remote_arg( + rpc.get_worker_info(), + build_sparse_tensor(), + build_sparse_tensor(), + build_sparse_tensor() + ) @dist_init def test_rref_proxy_non_exist(self): @@ -816,10 +1004,6 @@ class RpcTest(RpcAgentTestFixture): def test_rref_proxy_class_self(self): self._test_rref_proxy_class(rpc.get_worker_info()) - @dist_init - def test_self_remote_rref_as_self_remote_arg(self): - self._test_self_remote_rref_as_remote_arg(rpc.get_worker_info()) - @mock.patch.object(torch.distributed.autograd, "_init") @mock.patch.object(torch.distributed.rpc.api, "_set_and_start_rpc_agent") @dist_init(setup_rpc=False) @@ -911,7 +1095,7 @@ class RpcTest(RpcAgentTestFixture): ) rpc.shutdown() - def test_world_size_one(self): + def _world_size_one(self, a, b): if self.rank == 0: rpc.init_rpc( name="me", @@ -921,32 +1105,51 @@ class RpcTest(RpcAgentTestFixture): rpc_backend_options=self.rpc_backend_options, ) - expect = torch.ones(2, 2) * 2 - result = rpc.rpc_sync( - "me", - my_tensor_function, - args=(torch.ones(2, 2), torch.ones(2, 2)) - ) - self.assertEqual(expect, result) - - expect = torch.ones(3, 3) * 2 - result = rpc.rpc_async( - "me", - my_tensor_function, - args=(torch.ones(3, 3), torch.ones(3, 3)) - ).wait() - self.assertEqual(expect, result) + def _rpc_sync(x, y): + expect = x * 2 + result = rpc.rpc_sync( + "me", + my_tensor_function, + args=(x, y) + ) + self.assertEqual(expect, result) + + def _rpc_async(x, y): + expect = x * 2 + result = rpc.rpc_async( + "me", + my_tensor_function, + args=(x, y) + ).wait() + self.assertEqual(expect, result) + + def _remote(x, y): + expect = x * 2 + result = rpc.remote( + "me", + my_tensor_function, + args=(x, y) + ).to_here() + self.assertEqual(expect, result) - expect = torch.ones(4, 4) * 2 - result = rpc.remote( - "me", - my_tensor_function, - args=(torch.ones(4, 4), torch.ones(4, 4)) - ).to_here() - self.assertEqual(expect, result) + _rpc_sync(a, b) + _rpc_async(a, b) + _remote(a, b) rpc.shutdown() + def test_world_size_one(self): + self._world_size_one( + torch.ones(2, 2), + torch.ones(2, 2) + ) + + def test_world_size_one_sparse(self): + self._world_size_one( + build_sparse_tensor(), + build_sparse_tensor() + ) + @dist_init(setup_rpc=False) def test_invalid_names(self): from torch.distributed.rpc import WorkerInfo @@ -1027,17 +1230,30 @@ class RpcTest(RpcAgentTestFixture): ret = rpc.rpc_sync(worker_name(dst_rank), torch.nonzero, args=(x,)) self.assertEqual(ret, x.nonzero()) - @dist_init - def test_multi_rpc(self): + def _multi_rpc(self, sparse): dst_rank = (self.rank + 1) % self.world_size for i in range(20): n = i + self.rank + 1 + if sparse: + x = build_sparse_tensor() * n + y = build_sparse_tensor() * n + else: + x = torch.ones(2, 2) + y = torch.ones(2, 2) ret = rpc.rpc_sync( worker_name(dst_rank), torch.add, - args=(torch.ones(n, n), torch.ones(n, n)), + args=(x, y), ) - self.assertEqual(ret, torch.ones(n, n) * 2) + self.assertEqual(ret, x * 2) + + @dist_init + def test_multi_rpc(self): + self._multi_rpc(False) + + @dist_init + def test_multi_rpc_sparse(self): + self._multi_rpc(True) @dist_init def test_future_wait_twice(self): @@ -1053,7 +1269,7 @@ class RpcTest(RpcAgentTestFixture): with self.assertRaisesRegex(ValueError, "Expected error"): fut.wait() - def _run_uneven_workload(self, num_repeat=30): + def _run_uneven_workload(self, f, x, num_repeat=30): # worker0 drives and waits for worker1 and worker2 # throughout the test. 
        if self.rank == 0:
@@ -1063,7 +1279,7 @@
             dst = "worker1"
             futs = []
             for _ in range(num_repeat):
-                fut = rpc.rpc_async(dst, heavy_rpc, args=(torch.ones(100, 100),))
+                fut = rpc.rpc_async(dst, f, args=(x,))
                 futs.append(fut)
 
             for fut in torch.futures.collect_all(futs).wait():
@@ -1075,13 +1291,13 @@
             dst = "worker2"
             futs = []
             for _ in range(num_repeat):
-                fut = rpc.rpc_async(dst, heavy_rpc, args=(torch.ones(100, 100),))
+                fut = rpc.rpc_async(dst, f, args=(x,))
                 futs.append(fut)
 
             for val in torch.futures.wait_all(futs):
                 self.assertEqual(val, 0)
 
-    def test_wait_all_workers(self):
+    def _wait_all_workers(self, f, x):
         initialize_pg(self.file_init_method, self.rank, self.world_size)
         rpc.init_rpc(
             name="worker%d" % self.rank,
@@ -1091,7 +1307,7 @@
             rpc_backend_options=self.rpc_backend_options,
         )
 
-        self._run_uneven_workload()
+        self._run_uneven_workload(f, x)
 
         # worker0 calls this at the end after waiting for RPC responses.
         # worker1/2 calls this immediately and has some work after it.
@@ -1103,7 +1319,13 @@
         dist.barrier()
         rpc.shutdown(graceful=False)
 
-    def test_wait_all_workers_twice(self):
+    def test_wait_all_workers_dense(self):
+        self._wait_all_workers(heavy_rpc, torch.ones(100, 100))
+
+    def test_wait_all_workers_sparse(self):
+        self._wait_all_workers(heavy_rpc_sparse, build_sparse_tensor())
+
+    def _wait_all_workers_twice(self, f, x):
         initialize_pg(self.file_init_method, self.rank, self.world_size)
         rpc.init_rpc(
             name="worker%d" % self.rank,
@@ -1113,7 +1335,7 @@
             rpc_backend_options=self.rpc_backend_options,
         )
 
-        self._run_uneven_workload()
+        self._run_uneven_workload(f, x)
 
         # worker0 calls this at the end after waiting for RPC responses.
         # worker1/2 calls this immediately and has some work after it.
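A note on the loss construction used by the sparse test variants throughout
this patch: scalar losses are built with torch.sparse.sum, the reduction
provided for sparse layouts, and autograd flows back to the sparse leaves. A
self-contained illustration (not from the test suite):

    import torch

    # Mirrors the `loss = torch.sparse.sum(ret)` branches in the tests above.
    t = torch.sparse_coo_tensor([[0, 1], [2, 0]], [3.0, 4.0], (3, 3),
                                requires_grad=True)
    loss = torch.sparse.sum(t)  # 0-dim dense tensor
    loss.backward()
    assert t.grad is not None   # the gradient reached the sparse leaf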
@@ -1126,6 +1348,12 @@ class RpcTest(RpcAgentTestFixture): dist.barrier() rpc.shutdown(graceful=False) + def test_wait_all_workers_twice_dense(self): + self._wait_all_workers_twice(heavy_rpc, torch.ones(100, 100)) + + def test_wait_all_workers_twice_sparse(self): + self._wait_all_workers_twice(heavy_rpc_sparse, build_sparse_tensor()) + @dist_init def test_all_gather(self): info = rpc.get_worker_info() @@ -1211,7 +1439,7 @@ class RpcTest(RpcAgentTestFixture): @dist_init def test_graceful_shutdown_with_uneven_workload(self): """Test graceful termination.""" - self._run_uneven_workload() + self._run_uneven_workload(heavy_rpc, torch.ones(100, 100)) @dist_init(setup_rpc=False) def test_shutdown_followed_by_rpc(self): @@ -2082,6 +2310,16 @@ class RpcTest(RpcAgentTestFixture): self.assertEqual(ret, my_complex_tensor_function(a, b, c)) @dist_init + def test_py_sparse_tensors_in_container(self): + n = self.rank + 1 + dst_rank = n % self.world_size + a = [build_sparse_tensor(), build_sparse_tensor()] + ret = rpc.rpc_sync( + worker_name(dst_rank), my_container_sum, args=(a,) + ) + self.assertEqual(ret, my_container_sum(a)) + + @dist_init def test_py_nested_pickle(self): n = self.rank + 1 dst_rank = n % self.world_size @@ -2137,16 +2375,23 @@ class RpcTest(RpcAgentTestFixture): else: self.assertTrue(False, "expected raise_func_escape to raise ValueError.") - @dist_init - def test_nested_rpc(self): + def _nested_rpc(self, f, expected): n = self.rank + 1 dst_rank = n % self.world_size ret = rpc.rpc_sync( worker_name(dst_rank), - nested_rpc, + f, args=(worker_name(self.rank),), ) - self.assertEqual(ret, torch.ones(2, 2) + 1) + self.assertEqual(ret, expected) + + @dist_init + def test_nested_rpc(self): + self._nested_rpc(nested_rpc, torch.ones(2, 2) + 1) + + @dist_init + def test_nested_rpc_sparse(self): + self._nested_rpc(nested_rpc_sparse, build_sparse_tensor() * 2) def _stress_test_rpc(self, f, repeat=1000, args=()): n = self.rank + 1 @@ -2175,30 +2420,64 @@ class RpcTest(RpcAgentTestFixture): self._stress_test_rpc(heavy_rpc, repeat=20, args=(torch.ones(100, 100),)) @dist_init + def test_stress_heavy_rpc_sparse(self): + self._stress_test_rpc(heavy_rpc_sparse, repeat=20, args=(build_sparse_tensor(),)) + + @dist_init def test_stress_heavy_rpc_torchscript(self): self._stress_test_rpc(heavy_rpc_torchscript, repeat=20, args=(torch.ones(100, 100),)) - @dist_init - def test_builtin_remote_ret(self): + def _builtin_remote_ret(self, x, y, expected): n = self.rank + 1 dst_rank = n % self.world_size rref = rpc.remote( worker_name(dst_rank), torch.add, - args=(torch.ones(n, n), torch.ones(n, n)), + args=(x, y), ) - self.assertEqual(rref.to_here(), torch.ones(n, n) * 2) + self.assertEqual(rref.to_here(), expected) @dist_init - def test_builtin_remote_self(self): + def test_builtin_remote_ret(self): + self._builtin_remote_ret( + torch.ones(2, 2), + torch.ones(2, 2), + torch.ones(2, 2) * 2 + ) + + @dist_init + def test_builtin_remote_ret_sparse(self): + self._builtin_remote_ret( + build_sparse_tensor(), + build_sparse_tensor(), + build_sparse_tensor() * 2 + ) + + def _builtin_remote_self(self, x, y, expected): rref = rpc.remote( worker_name(self.rank), torch.add, - args=(torch.ones(2, 2), torch.ones(2, 2)), + args=(x, y), + ) + self.assertEqual(rref.local_value(), expected) + + @dist_init + def test_builtin_remote_self(self): + self._builtin_remote_self( + torch.ones(2, 2), + torch.ones(2, 2), + torch.ones(2, 2) * 2 + ) + + @dist_init + def test_builtin_remote_self_sparse(self): + self._builtin_remote_self( + 
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor() * 2
         )
-        self.assertEqual(rref.local_value(), torch.ones(2, 2) * 2)

-    def _test_multi_remote_call(self, fn, args_fn=lambda x: (), kwargs_fn=lambda x: {}):
+    def _test_multi_remote_call(self, fn, sparse, args_fn=lambda x, y: (), kwargs_fn=lambda x, y: {}):
         m = 10
         n = self.rank + 1
         dst_rank = n % self.world_size
@@ -2210,21 +2489,35 @@ class RpcTest(RpcAgentTestFixture):
                 rpc.remote(
                     worker_name(dst_rank),
                     fn,
-                    args=args_fn(n),
-                    kwargs=kwargs_fn(n),
+                    args=args_fn(n, sparse),
+                    kwargs=kwargs_fn(n, sparse),
                 )
             )
-            expected.append(fn(*args_fn(n), **kwargs_fn(n)))
+            expected.append(fn(*args_fn(n, sparse), **kwargs_fn(n, sparse)))

         for i in range(m):
             self.assertEqual(rrefs[i].to_here(), expected[i])

+    @staticmethod
+    def _multi_args_fn(n, sparse=False):
+        if sparse:
+            return (build_sparse_tensor(), build_sparse_tensor())
+        else:
+            return (torch.ones(n, n), torch.ones(n, n))
+
     @dist_init
     def test_multi_builtin_remote_ret(self):
-        def args_fn(n):
-            return (torch.ones(n, n), torch.ones(n, n))
+        self._test_multi_remote_call(
+            torch.add, False,
+            args_fn=RpcTest._multi_args_fn
+        )

-        self._test_multi_remote_call(torch.add, args_fn=args_fn)
+    @dist_init
+    def test_multi_builtin_remote_ret_sparse(self):
+        self._test_multi_remote_call(
+            torch.add, True,
+            args_fn=RpcTest._multi_args_fn
+        )

     @dist_init
     def test_py_udf_remote(self):
@@ -2237,82 +2530,177 @@ class RpcTest(RpcAgentTestFixture):
         )
         self.assertEqual(rref.to_here(), my_function(n, n + 1, n + 2))

-    @dist_init
-    def test_multi_py_udf_remote(self):
-        def kwargs_fn(n):
+    @staticmethod
+    def _multi_kwargs_fn(n, sparse=False):
+        if sparse:
+            return {
+                "a": build_sparse_tensor(),
+                "b": build_sparse_tensor(),
+                "c": build_sparse_tensor()
+            }
+        else:
             return {"a": torch.ones(n, n), "b": torch.ones(n, n), "c": torch.ones(n, n)}

-        self._test_multi_remote_call(my_function, kwargs_fn=kwargs_fn)
+    @dist_init
+    def test_multi_py_udf_remote(self):
+        self._test_multi_remote_call(
+            my_function,
+            False,
+            kwargs_fn=RpcTest._multi_kwargs_fn
+        )

     @dist_init
-    def test_py_rref_args(self):
+    def test_multi_py_udf_remote_sparse(self):
+        self._test_multi_remote_call(
+            my_function,
+            True,
+            kwargs_fn=RpcTest._multi_kwargs_fn
+        )
+
+    def _py_rref_args(self, a, b, x, y, expected):
         n = self.rank + 1
         dst_rank = n % self.world_size
         rref_a = rpc.remote(
-            worker_name(dst_rank), torch.add, args=(torch.ones(n, n), 2)
+            worker_name(dst_rank), torch.add, args=(a, b)
         )
         rref_b = rpc.remote(
-            worker_name(dst_rank), torch.add, args=(torch.ones(n, n), 1)
+            worker_name(dst_rank), torch.add, args=(x, y)
         )
         rref_c = rpc.remote(
             worker_name(dst_rank), my_rref_function, args=(rref_a, rref_b)
         )
-        self.assertEqual(rref_c.to_here(), torch.ones(n, n) + 4)
+        self.assertEqual(rref_c.to_here(), expected)

     @dist_init
-    def test_py_rref_args_user_share(self):
+    def test_py_rref_args(self):
+        self._py_rref_args(
+            torch.ones(2, 2),
+            1,
+            torch.ones(2, 2),
+            2,
+            torch.ones(2, 2) * 2 + 3)
+
+    @dist_init
+    def test_py_rref_args_sparse(self):
+        self._py_rref_args(
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor() * 4
+        )
+
+    def _py_rref_args_user_share(self, a, b, c, x, y, z, expected):
         n = self.rank + 1
         owner_rank = n % self.world_size
         user_rank = (n + 1) % self.world_size
         rref_a = rpc.remote(
-            worker_name(owner_rank), my_function, args=(torch.ones(n, n), 2, 0)
+            worker_name(owner_rank), my_function, args=(a, b, c)
         )
         rref_b = rpc.remote(
-            worker_name(owner_rank), my_function, args=(torch.ones(n, n), 1, 0)
+            worker_name(owner_rank), my_function, args=(x, y, z)
         )
         rref_c = rpc.remote(
             worker_name(user_rank), my_rref_function, args=(rref_a, rref_b)
         )
-        self.assertEqual(rref_c.to_here(), torch.ones(n, n) + 4)
+        self.assertEqual(rref_c.to_here(), expected)

     @dist_init
-    def test_py_rpc_rref_args(self):
+    def test_py_rref_args_user_share(self):
+        self._py_rref_args_user_share(
+            torch.ones(2, 2),
+            1,
+            2,
+            torch.ones(2, 2),
+            3,
+            4,
+            torch.ones(2, 2) * 2 + 10
+        )
+
+    @dist_init
+    def test_py_rref_args_user_share_sparse(self):
+        self._py_rref_args_user_share(
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor() * 6
+        )
+
+    def _py_rpc_rref_args(self, a, b, c, x, y, z, expected):
         n = self.rank + 1
         dst_rank = n % self.world_size
         rref_a = rpc.remote(
-            worker_name(dst_rank), my_function, args=(torch.ones(n, n), 2, 0)
+            worker_name(dst_rank), my_function, args=(a, b, c)
         )
         rref_b = rpc.remote(
-            worker_name(dst_rank), my_function, args=(torch.ones(n, n), 1, 0)
+            worker_name(dst_rank), my_function, args=(x, y, z)
        )

         c = rpc.rpc_sync(
             worker_name(dst_rank), my_rref_function, args=(rref_a, rref_b)
         )
+        self.assertEqual(c, expected)

-        self.assertEqual(c, torch.ones(n, n) + 4)
+    @dist_init
+    def test_py_rpc_rref_args(self):
+        self._py_rpc_rref_args(
+            torch.ones(2, 2),
+            1,
+            2,
+            torch.ones(2, 2),
+            3,
+            4,
+            torch.ones(2, 2) * 2 + 10
+        )

     @dist_init
-    def test_nested_remote(self):
+    def test_py_rpc_rref_args_sparse(self):
+        self._py_rpc_rref_args(
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor() * 6
+        )
+
+    def _nested_remote(self, f, expected):
         n = self.rank + 1
         dst_rank1 = n % self.world_size
         dst_rank2 = (n + 1) % self.world_size
         rref = rpc.remote(
             worker_name(dst_rank1),
-            nested_remote,
+            f,
             args=(worker_name(dst_rank2),),
         )
-        self.assertEqual(rref.to_here(), torch.ones(2, 2) + 3)
+        self.assertEqual(rref.to_here(), expected)

     @dist_init
-    def test_nested_rref(self):
+    def test_nested_remote(self):
+        self._nested_remote(
+            nested_remote,
+            torch.ones(2, 2) + 3
+        )
+
+    @dist_init
+    def test_nested_remote_sparse(self):
+        self._nested_remote(
+            nested_remote_sparse,
+            build_sparse_tensor() + build_sparse_tensor()
+        )
+
+    def _nested_rref(self, f, expected1, expected2):
         n = self.rank + 1
         dst_rank1 = n % self.world_size
         dst_rank2 = (n + 1) % self.world_size
         rref_of_rrefs = rpc.remote(
             worker_name(dst_rank1),
-            nested_rref,
+            f,
             args=(worker_name(dst_rank2),),
         )

@@ -2322,11 +2710,26 @@ class RpcTest(RpcAgentTestFixture):
         rrefs = rref_of_rrefs.to_here()
         self.assertEqual(len(rrefs), 2)
-        self.assertEqual(rrefs[0].to_here(), torch.ones(2, 2) + 1)
-        self.assertEqual(rrefs[1].to_here(), torch.ones(2, 2) + 2)
+        self.assertEqual(rrefs[0].to_here(), expected1)
+        self.assertEqual(rrefs[1].to_here(), expected2)

     @dist_init
-    def test_nested_rref_stress(self):
+    def test_nested_rref(self):
+        self._nested_rref(
+            nested_rref,
+            torch.ones(2, 2) + 1,
+            torch.ones(2, 2) + 2
+        )
+
+    @dist_init
+    def test_nested_rref_sparse(self):
+        self._nested_rref(
+            nested_rref_sparse,
+            build_sparse_tensor() * 2,
+            build_sparse_tensor() * 2
+        )
+
+    def _nested_rref_stress(self, f, expected1, expected2):
         n = self.rank + 1
         dst_rank1 = n % self.world_size
         dst_rank2 = (n + 1) % self.world_size
@@ -2335,7 +2738,7 @@ class RpcTest(RpcAgentTestFixture):
             all_rrefs.append(
                 rpc.remote(
                     worker_name(dst_rank1),
-                    nested_rref,
+                    f,
                     args=(worker_name(dst_rank2),),
                 )
             )
@@ -2344,8 +2747,24 @@ class RpcTest(RpcAgentTestFixture):
             rref_of_rrefs = all_rrefs[i]
             rrefs = rref_of_rrefs.to_here()
             self.assertEqual(len(rrefs), 2)
-            self.assertEqual(rrefs[0].to_here(), torch.ones(2, 2) + 1)
-            self.assertEqual(rrefs[1].to_here(), torch.ones(2, 2) + 2)
+            self.assertEqual(rrefs[0].to_here(), expected1)
+            self.assertEqual(rrefs[1].to_here(), expected2)
+
+    @dist_init
+    def test_nested_rref_stress(self):
+        self._nested_rref_stress(
+            nested_rref,
+            torch.ones(2, 2) + 1,
+            torch.ones(2, 2) + 2
+        )
+
+    @dist_init
+    def test_nested_rref_stress_sparse(self):
+        self._nested_rref_stress(
+            nested_rref_sparse,
+            build_sparse_tensor() * 2,
+            build_sparse_tensor() * 2
+        )

     @dist_init
     def test_multi_layer_nested_async_rpc(self):
@@ -4110,6 +4529,46 @@ class RpcTest(RpcAgentTestFixture):

         dist.barrier()

+    def _trainer_func(self, rref, sparse):
+        m = MyEmbeddingBagModel(sparse=sparse)
+        loss_fn = nn.MSELoss()
+        for i in range(10):
+            outputs = m(torch.rand(10, 10).long())
+            loss_fn(outputs, torch.rand(10, 10)).backward()
+            gradient = list(m.parameters())[0].grad
+            fut = rref.rpc_async().average(rref, i, gradient)
+            gradient = fut.wait()
+            if gradient.is_sparse:
+                gradient = gradient.to_dense().double()
+            ps_gradient = rref.rpc_sync().get_gradient(rref)
+            if ps_gradient.is_sparse:
+                ps_gradient = ps_gradient.to_dense().double()
+            self.assertTrue(torch.equal(gradient, ps_gradient))
+
+    def _my_parameter_server(self, sparse):
+        ps_rref = RRef(MyParameterServer(self.world_size - 1))
+        futures = []
+        for index in range(1, self.world_size):
+            futures.append(
+                rpc.rpc_async(
+                    worker_name((self.rank + index) % self.world_size),
+                    self._trainer_func,
+                    args=(
+                        ps_rref,
+                        sparse
+                    ),
+                )
+            )
+        torch.futures.wait_all(futures)
+
+    @dist_init
+    def test_my_parameter_server(self):
+        self._my_parameter_server(False)
+
+    @dist_init
+    def test_my_parameter_server_sparse(self):
+        self._my_parameter_server(True)
+

 class CudaRpcTest(RpcAgentTestFixture):
-- 
2.7.4
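The parameter-server tests at the end (test_my_parameter_server / test_my_parameter_server_sparse) lean on MyEmbeddingBagModel and MyParameterServer, which are also added by this patch but fall outside this excerpt. A rough sketch, under the assumption that the server gathers one gradient per trainer per iteration and then releases every trainer with the average; the details are illustrative, not the patch's exact code:

import threading
import torch
import torch.distributed.rpc as rpc

class MyEmbeddingBagModel(torch.nn.Module):
    def __init__(self, sparse):
        super().__init__()
        # sparse=True makes EmbeddingBag emit sparse gradients, which is what
        # lets these tests exercise sparse tensors over RPC end to end.
        self.eb = torch.nn.EmbeddingBag(10, 10, sparse=sparse)

    def forward(self, x):
        return self.eb(x)

class MyParameterServer:
    def __init__(self, trainers):
        self.lock = threading.Lock()
        self.trainers = trainers
        self.iteration = 0
        self.updates = 0
        self.futures = []
        self.total = None
        self.gradient = None

    @staticmethod
    def get_gradient(rref):
        # Called via rref.rpc_sync().get_gradient(rref) in _trainer_func.
        return rref.local_value().gradient

    @staticmethod
    @rpc.functions.async_execution
    def average(rref, riteration, tensor):
        # Accumulates one gradient per trainer; once all trainers have
        # reported for this iteration, resolves every pending future with
        # the averaged gradient.
        self = rref.local_value()
        fut = torch.futures.Future()
        with self.lock:
            if riteration > self.iteration:
                self.iteration = riteration
                self.updates = 0
                self.futures.clear()
            self.futures.append(fut)
            if self.total is None:
                self.total = tensor
            else:
                self.total = self.total + tensor
            self.updates += 1
            if self.trainers == self.updates:
                self.gradient = self.total / self.trainers
                self.total = None
                for f in self.futures:
                    f.set_result(self.gradient)
        return fut

Because average is decorated with @rpc.functions.async_execution, each trainer's rref.rpc_async().average(...) call blocks only on the returned future, so all trainers can rendezvous on the server without tying up RPC threads, and the scheme works identically whether the accumulated gradients are dense or sparse COO.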