From b12f34e8c2ae0a183abc48e65815480bf4c44fbe Mon Sep 17 00:00:00 2001 From: Garrett Cramer Date: Thu, 2 Sep 2021 16:11:10 -0700 Subject: [PATCH] update rpc tensorpipe logic for sparse tensors (#62960) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/62960 A bug was filed a few years ago for sending sparse tensor over rpc #30807. This pr updates rpc/tensorpipe logic for CUDA sparse tensors. During the serialization process, the pickler.cpp implementation breaks down the sparse tensor into two tensors and metadata. torch/csrc/distributed/rpc/tensorpipe_agent.cpp needs to be updated because it does not have logic sparse tensors. It pushes a single device for a sparse tensor. This is wrong because after the sparse tensor has been serialized, there will be two tensors. The second tensor will not have a device. This will cause the second tensor to have the wrong target device. tensorpipe_utils.cpp needs to be updated because deserialization happens after the data is received on the target pipe. This takes the two tensors and metadata sent and rebuilds the sparse tensor. There will be two tpDescriptors but only one tensor after deserialization. The logic is updated to verify the sparse tensor is on the correct device using the first tpDescriptor. This pr also updates ivalue.cpp and ivalue.h to support more paths for Sparse COO tensors. I tested these changes by adding sparse tests to rpc_test.py and dist_autograd_test.py. Test Plan: Imported from OSS Reviewed By: gchanan Differential Revision: D30717285 Pulled By: gcramer23 fbshipit-source-id: daee9a56764550f56b131f9dd8e74e23113d6714 --- aten/src/ATen/core/ivalue.cpp | 41 ++- aten/src/ATen/core/ivalue.h | 9 +- torch/csrc/distributed/rpc/tensorpipe_agent.cpp | 9 +- torch/csrc/distributed/rpc/tensorpipe_utils.cpp | 10 +- .../distributed/rpc/dist_autograd_test.py | 312 +++++++++++---------- .../testing/_internal/distributed/rpc/rpc_test.py | 98 ++++--- 6 files changed, 273 insertions(+), 206 deletions(-) diff --git a/aten/src/ATen/core/ivalue.cpp b/aten/src/ATen/core/ivalue.cpp index 1404e01..b81c50f 100644 --- a/aten/src/ATen/core/ivalue.cpp +++ b/aten/src/ATen/core/ivalue.cpp @@ -946,36 +946,25 @@ getClassConverter() { } // Needs to be in this .cpp file to access the full definition of PyObjectHolder -std::vector> ivalue::Future::extractStorages( - const at::IValue& value) { +std::vector> ivalue::Future:: + extractStorages(const at::IValue& value) { std::vector> weakStorageImpls; // getSubValues works poorly on Python objects: it only works if they can be // converted to a "regular" IValue type hence, for example, it doesn't support // custom subclasses. Thus, instead, we extract the tensors through pickling. + // Sparse tensors do not have storage. Instead, a sparse tensor + // contains two tensors indices and values, and both contain storage. if (value.isPyObject()) { std::vector tensors = value.toPyObjectHolder()->extractTensors(); - size_t num_storages = 0; - for (const at::Tensor& tensor : tensors) { + weakStorageImpls.reserve(2 * tensors.size()); + for (const auto& tensor : tensors) { if (tensor.is_sparse()) { - // Sparse tensor is indices and values. Both are tensors - // and contain storage. Therefore num_storages needs to be - // incremented by 2. - num_storages += 2; + weakStorageImpls.push_back( + tensor._indices().storage().getWeakStorageImpl()); + weakStorageImpls.push_back( + tensor._values().storage().getWeakStorageImpl()); } else { - // A dense/strided tensor contains 1 storage. - num_storages += 1; - } - } - weakStorageImpls.reserve(num_storages); - for (const at::Tensor& tensor : tensors) { - if (tensor.is_sparse()) { - // Sparse tensor is indices and values. Both are tensors - // and contain storage. - weakStorageImpls.push_back(tensor.indices().storage().getWeakStorageImpl()); - weakStorageImpls.push_back(tensor.values().storage().getWeakStorageImpl()); - } else { - // A dense/strided tensor contains 1 storage weakStorageImpls.push_back(tensor.storage().getWeakStorageImpl()); } } @@ -986,7 +975,15 @@ std::vector> ivalue::Future::extractSt value.getSubValues(sub_values); for (const at::IValue& sub_value : sub_values) { if (sub_value.isTensor()) { - weakStorageImpls.push_back(sub_value.toTensor().storage().getWeakStorageImpl()); + auto& tensor = sub_value.toTensor(); + if (tensor.is_sparse()) { + weakStorageImpls.push_back( + tensor._indices().storage().getWeakStorageImpl()); + weakStorageImpls.push_back( + tensor._values().storage().getWeakStorageImpl()); + } else { + weakStorageImpls.push_back(tensor.storage().getWeakStorageImpl()); + } } } } diff --git a/aten/src/ATen/core/ivalue.h b/aten/src/ATen/core/ivalue.h index 188a619..6574187 100644 --- a/aten/src/ATen/core/ivalue.h +++ b/aten/src/ATen/core/ivalue.h @@ -872,14 +872,17 @@ struct TORCH_API IValue final { struct HashAliasedIValue { size_t operator()(const IValue& val) const { if (val.isTensor()) { - if (val.toTensor().is_mkldnn()) { + auto& tensor = val.toTensor(); + if (tensor.is_mkldnn() || tensor.is_sparse()) { // MKLDNN tensors dont have storage and dont create views // or aliasing so we can just use Tensor pointer, TODO: find way // to use mkldnn storage - return reinterpret_cast(val.toTensor().unsafeGetTensorImpl()); + // Sparse tensors don't have storage use unsafeGetTensorImpl + // instead of using the storage of indices or values. + return reinterpret_cast(tensor.unsafeGetTensorImpl()); } else { return reinterpret_cast( - val.toTensor().storage().unsafeGetStorageImpl()); + tensor.storage().unsafeGetStorageImpl()); } } // If it is not a Tensor, then two mutable IValues alias each other only diff --git a/torch/csrc/distributed/rpc/tensorpipe_agent.cpp b/torch/csrc/distributed/rpc/tensorpipe_agent.cpp index 8e7ad18..3769db0 100644 --- a/torch/csrc/distributed/rpc/tensorpipe_agent.cpp +++ b/torch/csrc/distributed/rpc/tensorpipe_agent.cpp @@ -48,7 +48,7 @@ std::vector getDevicesForTensors( "Request device mapping is not available for destination ", remoteName); std::vector devices; - devices.reserve(tensors.size()); + devices.reserve(2 * tensors.size()); bool hasMappedDevice = false; for (const auto& t : tensors) { if (t.device().is_cpu()) { @@ -67,7 +67,12 @@ std::vector getDevicesForTensors( " for device ", t.device(), " but received a tensor on that device."); - devices.push_back(deviceIter->second); + if (t.is_sparse()) { + devices.push_back(deviceIter->second); + devices.push_back(deviceIter->second); + } else { + devices.push_back(deviceIter->second); + } hasMappedDevice = true; } } diff --git a/torch/csrc/distributed/rpc/tensorpipe_utils.cpp b/torch/csrc/distributed/rpc/tensorpipe_utils.cpp index ee66f31..aa21fdf 100644 --- a/torch/csrc/distributed/rpc/tensorpipe_utils.cpp +++ b/torch/csrc/distributed/rpc/tensorpipe_utils.cpp @@ -311,8 +311,9 @@ c10::intrusive_ptr tensorpipeDeserialize( tensors.emplace_back(std::move(t)); } - for (const auto i : c10::irange(tpDescriptor.tensors.size())) { - auto& tensor = tpDescriptor.tensors[i]; + size_t tpDescriptorIndex = 0; + for (size_t i = 0; i < tensors.size(); i++) { + auto& tensor = tpDescriptor.tensors[tpDescriptorIndex]; if (tensor.targetDevice.has_value() && tensor.targetDevice->type == tensorpipe::kCudaDeviceType) { TORCH_INTERNAL_ASSERT( @@ -326,6 +327,11 @@ c10::intrusive_ptr tensorpipeDeserialize( ", but got it on ", tensors[i].device()); } + if (tensors[i].is_sparse()) { + tpDescriptorIndex += 2; + } else { + tpDescriptorIndex += 1; + } } return c10::make_intrusive( diff --git a/torch/testing/_internal/distributed/rpc/dist_autograd_test.py b/torch/testing/_internal/distributed/rpc/dist_autograd_test.py index fba5030..2ba25a5 100644 --- a/torch/testing/_internal/distributed/rpc/dist_autograd_test.py +++ b/torch/testing/_internal/distributed/rpc/dist_autograd_test.py @@ -78,14 +78,20 @@ def create_tensor(): return torch.ones((3, 3), requires_grad=True) -def build_sparse_tensor(coalesce=False, requires_grad=True, dtype=torch.float32): +def build_sparse_tensor(coalesce=False, requires_grad=True, dtype=torch.float32, device=None): i = [[0, 1, 1], [2, 0, 2]] v = [3.2, 4.1, 5.3] - tensor = torch.sparse_coo_tensor(i, v, (3, 3), requires_grad=requires_grad, dtype=dtype) + tensor = torch.sparse_coo_tensor(i, v, (3, 3), requires_grad=requires_grad, dtype=dtype, device=device) if coalesce: tensor = tensor.coalesce() return tensor +def build_sparse_one_gradient(dtype=torch.float32): + i = [[0, 1, 1], [2, 0, 2]] + v = [1, 1, 1] + tensor = torch.sparse_coo_tensor(i, v, (3, 3), dtype=dtype) + return tensor + @torch.jit.script def create_torchscript_tensor() -> torch.Tensor: @@ -104,6 +110,9 @@ def my_rref_add(rref_t1, t2): ret = torch.add(rref_t1.local_value(), t2) return ret +def my_sum(t): + return torch.sparse.sum(t) if t.is_sparse else t.sum() + @torch.jit.script def my_script_add(t1, t2): @@ -159,13 +168,10 @@ def _all_contexts_cleaned_up(timeout_seconds=10): # This function creates a dis atugorad context, run rpc_sync on the given ps, # and then blocks until the ps has verified the grads are correctly accumulated. -def _run_trainer(rref_t1, t2, ps, rank_diff, sparse): +def _run_trainer(rref_t1, t2, ps, rank_diff): with dist_autograd.context() as context_id: ret = rpc.rpc_sync(ps, my_rref_add, args=(rref_t1, t2)) - if sparse: - loss = torch.sparse.sum(ret) - else: - loss = ret.sum() + loss = my_sum(ret) dist_autograd.backward(context_id, [loss]) # prevent deleting dist autograd context rpc.rpc_sync(ps, _set_rpc_done, args=(context_id, rank_diff)) @@ -173,13 +179,10 @@ def _run_trainer(rref_t1, t2, ps, rank_diff, sparse): # This function is the same as _run_trainer, except rpc calls torchscript # function "my_script_ref_add" instead of python funciton "my_rref_add" -def _run_trainer_torchscript(rref_t1, t2, ps, rank_diff, sparse): +def _run_trainer_torchscript(rref_t1, t2, ps, rank_diff): with dist_autograd.context() as context_id: ret = rpc.rpc_sync(ps, my_script_ref_add, args=(rref_t1, t2)) - if sparse: - loss = torch.sparse.sum(ret) - else: - loss = ret.sum() + loss = my_sum(ret) dist_autograd.backward(context_id, [loss]) # prevent deleting dist autograd context rpc.rpc_sync(ps, _set_rpc_done, args=(context_id, rank_diff)) @@ -990,25 +993,19 @@ class DistAutogradTest(CommonDistAutogradTest): def _backward_no_grad_on_tensor(self, t1, t2, sparse): with dist_autograd.context() as context_id: - loss = rpc.rpc_sync( + ret = rpc.rpc_sync( worker_name(self._next_rank()), torch.add, args=(t1, t2)) - if sparse: - loss = torch.sparse.sum(loss) - else: - loss = loss.sum() + loss = my_sum(ret) dist_autograd.backward(context_id, [loss], retain_graph=True) self.assertIsNone(t1.grad) self.assertIsNone(t2.grad) # Now populate .grad with local autograd engine and # verify dist autograd doesn't mess with it. - loss_local = torch.add(t1, t2) - if sparse: - loss_local = torch.sparse.sum(loss_local) - else: - loss_local = loss_local.sum() + ret = torch.add(t1, t2) + loss_local = my_sum(ret) loss_local.backward() self.assertIsNotNone(t1.grad) self.assertIsNotNone(t2.grad) @@ -1043,10 +1040,7 @@ class DistAutogradTest(CommonDistAutogradTest): ret = self._exec_func_with_dst( dst, exec_mode, torch.add, t1, t2 ) - if sparse: - loss = torch.sparse.sum(ret) - else: - loss = ret.sum() + loss = my_sum(ret) ret = self._verify_backwards( exec_mode, [loss], context_id, local_grads, t1, t2 ) @@ -1099,10 +1093,7 @@ class DistAutogradTest(CommonDistAutogradTest): # tensor lives on the rref owner. def _backward_rref(self, callee, rref_owner, t1, t2, local_grads, sparse): local_ret = torch.add(t1, t2) - if sparse: - local_ret = torch.sparse.sum(local_ret) - else: - local_ret = local_ret.sum() + local_ret = my_sum(local_ret) local_ret.backward() with dist_autograd.context() as context_id: if sparse: @@ -1120,10 +1111,7 @@ class DistAutogradTest(CommonDistAutogradTest): callee, my_nested_rref_add, args=(rref_owner, rref_t1, t2) ) ret = rref.to_here() - if sparse: - ret = torch.sparse.sum(ret) - else: - ret = ret.sum() + ret = my_sum(ret) dist_autograd.backward(context_id, [ret]) # verify grads on caller @@ -1238,10 +1226,7 @@ class DistAutogradTest(CommonDistAutogradTest): t2 = torch.zeros((3, 3), requires_grad=True) local_ret = torch.add(t1, t2) - if sparse: - torch.sparse.sum(local_ret).backward() - else: - local_ret.sum().backward() + my_sum(local_ret).backward() # create rref on self rref_t1 = rpc.remote( @@ -1257,7 +1242,7 @@ class DistAutogradTest(CommonDistAutogradTest): rpc.rpc_async( worker_name((self.rank + rank_diff) % self.world_size), trainer_fn, - args=(rref_t1, t2, worker_name(self.rank), rank_diff, sparse), + args=(rref_t1, t2, worker_name(self.rank), rank_diff), ) ) @@ -1309,7 +1294,7 @@ class DistAutogradTest(CommonDistAutogradTest): self._test_trainer_ps(create_torchscript_tensor, _run_trainer_torchscript, False) - def _backward_multiple_round_trips(self, t1, t2, t3, t4, t5, local_grads, sparse): + def _backward_multiple_round_trips(self, t1, t2, t3, t4, t5, local_grads): for exec_mode in [ExecMode.LOCAL, ExecMode.RPC_SYNC, ExecMode.REMOTE]: with dist_autograd.context() as context_id: # Multiple RPCs between different nodes. @@ -1317,7 +1302,7 @@ class DistAutogradTest(CommonDistAutogradTest): val = self._exec_func(exec_mode, torch.mul, t3, val) s1 = self._exec_func(exec_mode, torch.stack, (t4, val)) s2 = self._exec_func(exec_mode, torch.stack, (t5, val)) - if sparse: + if s1.is_sparse: val = self._exec_func(exec_mode, torch.mul, s1, s2) val = self._exec_func(exec_mode, torch.mul, val, val) loss = torch.sparse.sum(val) @@ -1339,8 +1324,7 @@ class DistAutogradTest(CommonDistAutogradTest): torch.rand((3, 3), requires_grad=True), torch.rand((3, 3)), torch.rand((3, 3), requires_grad=True), - None, - False + None ) @dist_init @@ -1351,8 +1335,7 @@ class DistAutogradTest(CommonDistAutogradTest): build_sparse_tensor(requires_grad=True), build_sparse_tensor(requires_grad=False), build_sparse_tensor(requires_grad=True), - None, - True + None ) @dist_init @@ -1589,15 +1572,12 @@ class DistAutogradTest(CommonDistAutogradTest): exec_mode, [r1, r2, r3, r4], context_id, local_grads, t1, t2 ) - def _backward_different_dtypes(self, t1, t2, sparse): + def _backward_different_dtypes(self, t1, t2): local_grads = None for exec_mode in [ExecMode.LOCAL, ExecMode.REMOTE]: with dist_autograd.context() as context_id: loss = self._exec_func(exec_mode, torch.add, t1, t2) - if sparse: - loss = torch.sparse.sum(loss) - else: - loss = loss.sum() + loss = my_sum(loss) local_grads = self._verify_backwards( exec_mode, [loss], context_id, local_grads, t1, t2 ) @@ -1606,29 +1586,24 @@ class DistAutogradTest(CommonDistAutogradTest): def test_backward_different_dtypes(self): self._backward_different_dtypes( torch.rand((3, 3), requires_grad=True, dtype=torch.float32), - torch.rand((3, 3), requires_grad=True, dtype=torch.float64), - False + torch.rand((3, 3), requires_grad=True, dtype=torch.float64) ) @dist_init def test_backward_different_dtypes_sparse(self): self._backward_different_dtypes( build_sparse_tensor(requires_grad=True, dtype=torch.float32), - build_sparse_tensor(requires_grad=True, dtype=torch.float64), - True + build_sparse_tensor(requires_grad=True, dtype=torch.float64) ) # Run the same code locally and with dist autograd and verify gradients # are same. - def _backward_simple_python_udf(self, t1, t2, sparse): + def _backward_simple_python_udf(self, t1, t2): local_grads = None for exec_mode in [ExecMode.LOCAL, ExecMode.REMOTE]: with dist_autograd.context() as context_id: ret = self._exec_func(exec_mode, my_py_add, t1, t2) - if sparse: - loss = torch.sparse.sum(ret) - else: - loss = ret.sum() + loss = my_sum(ret) local_grads = self._verify_backwards( exec_mode, [loss], context_id, local_grads, t1, t2 ) @@ -1637,21 +1612,19 @@ class DistAutogradTest(CommonDistAutogradTest): def test_backward_simple_python_udf(self): self._backward_simple_python_udf( torch.rand(3, 3, requires_grad=True), - torch.rand(3, 3, requires_grad=True), - False + torch.rand(3, 3, requires_grad=True) ) @dist_init def test_backward_simple_python_udf_sparse(self): self._backward_simple_python_udf( build_sparse_tensor(requires_grad=True), - build_sparse_tensor(requires_grad=True), - True + build_sparse_tensor(requires_grad=True) ) # Run the same code locally and with dist autograd and verify gradients # are same. - def _backward_simple_script_call(self, t1, t2, sparse): + def _backward_simple_script_call(self, t1, t2): local_grads = None for exec_mode in [ ExecMode.LOCAL, @@ -1661,10 +1634,7 @@ class DistAutogradTest(CommonDistAutogradTest): ]: with dist_autograd.context() as context_id: forward_ret = self._exec_func(exec_mode, my_script_add, t1, t2) - if sparse: - loss = torch.sparse.sum(forward_ret) - else: - loss = forward_ret.sum() + loss = my_sum(forward_ret) ret = self._verify_backwards( exec_mode, [loss], context_id, local_grads, t1, t2 ) @@ -1674,16 +1644,14 @@ class DistAutogradTest(CommonDistAutogradTest): def test_backward_simple_script_call(self): self._backward_simple_script_call( torch.rand(3, 3, requires_grad=True), - torch.rand(3, 3, requires_grad=True), - False + torch.rand(3, 3, requires_grad=True) ) @dist_init def test_backward_simple_script_call_sparse(self): self._backward_simple_script_call( build_sparse_tensor(requires_grad=True), - build_sparse_tensor(requires_grad=True), - True + build_sparse_tensor(requires_grad=True) ) @staticmethod @@ -1796,28 +1764,22 @@ class DistAutogradTest(CommonDistAutogradTest): res = rpc.rpc_sync(worker_name(dst), my_py_add, args=(t3, t4)) return t1 * t2 * t3 * t4 * res - def _backwards_nested_python_udf(self, t1, t2, sparse): + def _backwards_nested_python_udf(self, t1, t2): t3 = t1 * t2 t4 = t1 + t2 res = t3 + t4 - loss = t1 * t2 * t3 * t4 * res - if sparse: - loss = torch.sparse.sum(loss) - else: - loss = loss.sum() + ret = t1 * t2 * t3 * t4 * res + loss = my_sum(ret) torch.autograd.backward([loss]) # Now run distributed autograd. with dist_autograd.context() as context_id: - loss = rpc.rpc_sync( + ret = rpc.rpc_sync( worker_name(self._next_rank()), DistAutogradTest._nested_python_udf, args=(t1, t2, self._next_rank()), ) - if sparse: - loss = torch.sparse.sum(loss) - else: - loss = loss.sum() + loss = my_sum(ret) dist_autograd.backward(context_id, [loss]) grads = dist_autograd.get_gradients(context_id) self.assertEqual(t1.grad, grads[t1]) @@ -1828,8 +1790,7 @@ class DistAutogradTest(CommonDistAutogradTest): # Run equivalent of _nested_python_udf locally. self._backwards_nested_python_udf( torch.rand(3, 3, requires_grad=True), - torch.rand(3, 3, requires_grad=True), - False + torch.rand(3, 3, requires_grad=True) ) @dist_init @@ -1837,8 +1798,7 @@ class DistAutogradTest(CommonDistAutogradTest): # Run equivalent of _nested_python_udf locally. self._backwards_nested_python_udf( build_sparse_tensor(requires_grad=True), - build_sparse_tensor(requires_grad=True), - True + build_sparse_tensor(requires_grad=True) ) _test_clean_context_backward_context_id = None @@ -1986,17 +1946,14 @@ class DistAutogradTest(CommonDistAutogradTest): else: return t1 * t2 - def _mixed_requires_grad(self, t1, t2, sparse): + def _mixed_requires_grad(self, t1, t2): for exec_mode in [ExecMode.RPC_SYNC, ExecMode.REMOTE]: with dist_autograd.context() as context_id: ret = self._exec_func( exec_mode, DistAutogradTest._mixed_requires_grad_operaton, t1, t2 ) self.assertEqual(t1 * t2, ret) - if sparse: - loss = torch.sparse.sum(ret) - else: - loss = ret.sum() + loss = my_sum(ret) dist_autograd.backward(context_id, [loss]) self.assertTrue(t1.requires_grad) self.assertFalse(t2.requires_grad) @@ -2009,16 +1966,14 @@ class DistAutogradTest(CommonDistAutogradTest): def test_mixed_requires_grad(self): self._mixed_requires_grad( torch.rand(3, 3, requires_grad=True), - torch.rand(3, 3, requires_grad=False), - False + torch.rand(3, 3, requires_grad=False) ) @dist_init def test_mixed_requires_grad_sparse(self): self._mixed_requires_grad( build_sparse_tensor(requires_grad=True), - build_sparse_tensor(requires_grad=False), - True + build_sparse_tensor(requires_grad=False) ) class TestDebugInfoFunc(Function): @@ -2160,17 +2115,14 @@ class DistAutogradTest(CommonDistAutogradTest): def _test_nested_backward_accumulate_grads(t1, t2, dst_rank): return rpc.rpc_sync(worker_name(dst_rank), torch.add, args=(t1, t2)) - def _nested_backward_accumulate_grads(self, t1, t2, sparse): + def _nested_backward_accumulate_grads(self, t1, t2): with dist_autograd.context() as context_id: ret = rpc.rpc_sync( worker_name(self._next_rank()), DistAutogradTest._test_nested_backward_accumulate_grads, args=(t1, t2, self._next_rank()), ) - if sparse: - loss = torch.sparse.sum(ret) - else: - loss = ret.sum() + loss = my_sum(ret) # Run backward twice. dist_autograd.backward(context_id, [loss], retain_graph=True) dist_autograd.backward(context_id, [loss]) @@ -2179,28 +2131,23 @@ class DistAutogradTest(CommonDistAutogradTest): def test_nested_backward_accumulate_grads(self): self._nested_backward_accumulate_grads( torch.rand(3, 3, requires_grad=True), - torch.rand(3, 3, requires_grad=True), - False + torch.rand(3, 3, requires_grad=True) ) @dist_init def test_nested_backward_accumulate_grads_sparse(self): self._nested_backward_accumulate_grads( build_sparse_tensor(requires_grad=True), - build_sparse_tensor(requires_grad=True), - True + build_sparse_tensor(requires_grad=True) ) - def _multiple_backward(self, t1, t2, sparse): + def _multiple_backward(self, t1, t2): with dist_autograd.context() as context_id: - loss = rpc.rpc_sync( + ret = rpc.rpc_sync( worker_name(self._next_rank()), torch.add, args=(t1, t2)) - if sparse: - loss = torch.sparse.sum(loss) - else: - loss = loss.sum() + loss = my_sum(ret) # Run backward in a loop multiple times. for i in range(1000): dist_autograd.backward(context_id, [loss], retain_graph=True) @@ -2209,16 +2156,14 @@ class DistAutogradTest(CommonDistAutogradTest): def test_multiple_backward(self): self._multiple_backward( torch.rand(3, 3, requires_grad=True), - torch.rand(3, 3, requires_grad=True), - False + torch.rand(3, 3, requires_grad=True) ) @dist_init def test_multiple_backward_sparse(self): self._multiple_backward( build_sparse_tensor(requires_grad=True), - build_sparse_tensor(requires_grad=True), - True + build_sparse_tensor(requires_grad=True) ) @dist_init(clean_shutdown=False) @@ -2524,15 +2469,13 @@ class DistAutogradTest(CommonDistAutogradTest): class CudaDistAutogradTest(CommonDistAutogradTest): - @skip_if_lt_x_gpu(1) - @dist_init - def test_gpu_simple(self): - t1 = torch.rand(3, 3, requires_grad=True, device="cuda:0") - t2 = torch.rand(3, 3, requires_grad=True, device="cuda:0") - (t1 + t2).sum().backward() + + def _gpu_simple(self, t1, t2): + my_sum(t1 + t2).backward() with dist_autograd.context() as context_id: t3 = t1 + t2 - dist_autograd.backward(context_id, [t3.sum()]) + loss = my_sum(t3) + dist_autograd.backward(context_id, [loss]) grads = dist_autograd.get_gradients(context_id) self.assertEqual(2, len(grads)) self.assertEqual(t1.grad, grads[t1]) @@ -2540,9 +2483,22 @@ class CudaDistAutogradTest(CommonDistAutogradTest): @skip_if_lt_x_gpu(1) @dist_init - def test_gpu_to_cpu_continuation(self): - t1 = torch.rand(3, 3, requires_grad=True, device="cuda:0") - t2 = torch.rand(3, 3, requires_grad=True) + def test_gpu_simple(self): + self._gpu_simple( + torch.rand(3, 3, requires_grad=True, device="cuda:0"), + torch.rand(3, 3, requires_grad=True, device="cuda:0") + ) + + @skip_if_lt_x_gpu(1) + @dist_init + def test_gpu_simple_sparse(self): + self._gpu_simple( + build_sparse_tensor(requires_grad=True, device="cuda:0"), + build_sparse_tensor(requires_grad=True, device="cuda:0") + ) + + + def _gpu_to_cpu_continuation(self, t1, t2): # Run a few iterations. for i in range(3): t1.grad = None @@ -2557,16 +2513,29 @@ class CudaDistAutogradTest(CommonDistAutogradTest): t6 = t5.cuda(0) + t4 t7 = self._exec_func(exec_mode, torch.add, t6.cpu(), t5) # Autograd graph consists of CPU -> GPU -> CPU execution. + loss = my_sum(t7) ret = self._verify_backwards( - exec_mode, [t7.sum()], context_id, local_grads, t1, t2 + exec_mode, [loss], context_id, local_grads, t1, t2 ) local_grads = ret if ret else local_grads @skip_if_lt_x_gpu(1) @dist_init - def test_gpu_to_cpu_continuation_gpu_root(self): - t1 = torch.rand(3, 3, requires_grad=True, device="cuda:0") - t2 = torch.rand(3, 3, requires_grad=True) + def test_gpu_to_cpu_continuation(self): + self._gpu_to_cpu_continuation( + torch.rand(3, 3, requires_grad=True, device="cuda:0"), + torch.rand(3, 3, requires_grad=True) + ) + + @skip_if_lt_x_gpu(1) + @dist_init + def test_gpu_to_cpu_continuation_sparse(self): + self._gpu_to_cpu_continuation( + build_sparse_tensor(requires_grad=True, device="cuda:0"), + build_sparse_tensor(requires_grad=True) + ) + + def _gpu_to_cpu_continuation_gpu_root(self, t1, t2): # Run a few iterations. for i in range(3): t1.grad = None @@ -2580,11 +2549,28 @@ class CudaDistAutogradTest(CommonDistAutogradTest): t5 = self._exec_func(exec_mode, torch.add, t4.cpu(), t2) t6 = t5.cuda(0) + t4 # Autograd graph consists of CPU -> GPU -> CPU execution. + loss = my_sum(t6) ret = self._verify_backwards( - exec_mode, [t6.sum()], context_id, local_grads, t1, t2 + exec_mode, [loss], context_id, local_grads, t1, t2 ) local_grads = ret if ret else local_grads + @skip_if_lt_x_gpu(1) + @dist_init + def test_gpu_to_cpu_continuation_gpu_root(self): + self._gpu_to_cpu_continuation_gpu_root( + torch.rand(3, 3, requires_grad=True, device="cuda:0"), + torch.rand(3, 3, requires_grad=True) + ) + + @skip_if_lt_x_gpu(1) + @dist_init + def test_gpu_to_cpu_continuation_gpu_root_sparse(self): + self._gpu_to_cpu_continuation_gpu_root( + build_sparse_tensor(requires_grad=True, device="cuda:0"), + build_sparse_tensor(requires_grad=True) + ) + class FaultyAgentDistAutogradTest(RpcAgentTestFixture): # Reusing a simplified helper function from DistAutogradTest to ensure @@ -2646,8 +2632,7 @@ class WrapperModule(nn.Module): class TensorPipeCudaDistAutogradTest(RpcAgentTestFixture): - @skip_if_lt_x_gpu(4) - def test_device_maps_backward_pass(self): + def _device_maps_backward_pass(self, t1, t2): options = self.rpc_backend_options dst = worker_name((self.rank + 1) % self.world_size) @@ -2662,19 +2647,36 @@ class TensorPipeCudaDistAutogradTest(RpcAgentTestFixture): rpc_backend_options=options, ) - t1 = torch.rand(10, device=self.rank, requires_grad=True) - t2 = torch.rand(10, device=self.rank, requires_grad=True) with dist_autograd.context() as context_id: res = rpc.rpc_sync(dst, torch.add, args=(t1, t2)) - dist_autograd.backward(context_id, [res.sum()]) + loss = my_sum(res) + dist_autograd.backward(context_id, [loss]) grads = dist_autograd.get_gradients(context_id) - self.assertEqual(torch.ones(10), grads[t1]) - self.assertEqual(torch.ones(10), grads[t2]) + if t1.is_sparse: + self.assertEqual(build_sparse_one_gradient(), grads[t1]) + self.assertEqual(build_sparse_one_gradient(), grads[t2]) + else: + self.assertEqual(torch.ones(10), grads[t1]) + self.assertEqual(torch.ones(10), grads[t2]) self.assertEqual(t1.device, grads[t1].device) self.assertEqual(t2.device, grads[t2].device) rpc.shutdown() + @skip_if_lt_x_gpu(4) + def test_device_maps_backward_pass(self): + self._device_maps_backward_pass( + torch.rand(10, requires_grad=True, device=self.rank), + torch.ones(10, requires_grad=True, device=self.rank) + ) + + @skip_if_lt_x_gpu(4) + def test_device_maps_backward_pass_sparse(self): + self._device_maps_backward_pass( + build_sparse_tensor(requires_grad=True, device=self.rank), + build_sparse_tensor(requires_grad=True, device=self.rank) + ) + class MyRemoteCompute(torch.nn.Module): def __init__(self): super().__init__() @@ -2691,9 +2693,7 @@ class TensorPipeCudaDistAutogradTest(RpcAgentTestFixture): def forward(self, input): return self.next_stage.rpc_sync().forward(input) - @skip_if_lt_x_gpu(4) - def test_dist_autograd_sync_streams(self): - + def _dist_autograd_sync_streams(self, sparse): options = self.rpc_backend_options dst = worker_name((self.rank + 1) % self.world_size) @@ -2711,17 +2711,20 @@ class TensorPipeCudaDistAutogradTest(RpcAgentTestFixture): remote_compute = rpc.remote(dst, TensorPipeCudaDistAutogradTest.MyRemoteCompute) local_compute = TensorPipeCudaDistAutogradTest.MyLocalCompute(remote_compute) for _ in range(10): - input = torch.rand([1000, 10000], device=self.rank, requires_grad=True) + if sparse: + input = build_sparse_tensor(requires_grad=True, device=self.rank) + else: + input = torch.rand([1000, 10000], device=self.rank, requires_grad=True) # Run local autograd result = input * 2.0 r = random.random() - loss = result.sum() * r + loss = my_sum(result) * r loss.backward() # Run distributed autograd with dist_autograd.context() as context_id: result = local_compute(input) - loss = result.sum() * r + loss = my_sum(result) * r dist_autograd.backward(context_id, [loss]) # Compare grads. @@ -2731,7 +2734,14 @@ class TensorPipeCudaDistAutogradTest(RpcAgentTestFixture): rpc.shutdown() @skip_if_lt_x_gpu(4) - def test_gradients_synchronizations(self): + def test_dist_autograd_sync_streams(self): + self._dist_autograd_sync_streams(False) + + @skip_if_lt_x_gpu(4) + def test_dist_autograd_sync_streams_sparse(self): + self._dist_autograd_sync_streams(True) + + def _gradients_synchronizations(self, x): options = self.rpc_backend_options for peer_rank in range(self.world_size): options.set_device_map(worker_name(peer_rank), {self.rank: peer_rank}) @@ -2755,8 +2765,8 @@ class TensorPipeCudaDistAutogradTest(RpcAgentTestFixture): WrapperModule, args=(layers[rank - 1], rank) )) + x = x.to(0) - x = torch.randn(5000, 2000).to(0) # local iteration local_model = nn.Sequential(*local_layers) local_model(x).sum().backward() @@ -2778,3 +2788,15 @@ class TensorPipeCudaDistAutogradTest(RpcAgentTestFixture): self.assertEqual(g1, g2) rpc.shutdown() + + @skip_if_lt_x_gpu(4) + def test_gradients_synchronizations(self): + self._gradients_synchronizations( + torch.randn(5000, 2000) + ) + + @skip_if_lt_x_gpu(4) + def test_gradients_synchronizations_sparse(self): + self._gradients_synchronizations( + torch.randn(5000, 2000).to_sparse() + ) diff --git a/torch/testing/_internal/distributed/rpc/rpc_test.py b/torch/testing/_internal/distributed/rpc/rpc_test.py index e0ef915..23759f1 100644 --- a/torch/testing/_internal/distributed/rpc/rpc_test.py +++ b/torch/testing/_internal/distributed/rpc/rpc_test.py @@ -194,6 +194,14 @@ class MyClass: return torch.add(self.a, my_tensor_arg) +def _run_func_in_mode(to, fn, mode, args=None, kwargs=None): + if mode == RPCExecMode.SYNC: + return rpc.rpc_sync(to, fn, args=args, kwargs=kwargs) + elif mode == RPCExecMode.ASYNC: + return rpc.rpc_async(to, fn, args=args, kwargs=kwargs).wait() + elif mode == RPCExecMode.REMOTE: + return rpc.remote(to, fn, args=args, kwargs=kwargs).to_here() + def _call_method_on_rref(method, rref, *args, **kwargs): return method(rref.local_value(), *args, **kwargs) @@ -736,7 +744,7 @@ class RpcTest(RpcAgentTestFixture): # Test dense tensor for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]: - ret = self._run_func_in_mode(dst_rank, torch.add, exec_mode, args=(torch.ones(2, 2), 1)) + ret = _run_func_in_mode(dst_rank, torch.add, exec_mode, args=(torch.ones(2, 2), 1)) self.assertEqual(ret, torch.ones(2, 2) + 1) # Test sparse tensor @@ -744,32 +752,32 @@ class RpcTest(RpcAgentTestFixture): x = build_sparse_tensor() y = build_sparse_tensor() expected_tensor = (x + y) - ret = self._run_func_in_mode(dst_rank, torch.add, exec_mode, args=(x, y)) + ret = _run_func_in_mode(dst_rank, torch.add, exec_mode, args=(x, y)) self.assertEqual(expected_tensor, ret) for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]: x = build_sparse_tensor(coalesce=True) y = build_sparse_tensor(coalesce=True) expected_tensor = (x + y) - ret = self._run_func_in_mode(dst_rank, torch.add, exec_mode, args=(x, y)) + ret = _run_func_in_mode(dst_rank, torch.add, exec_mode, args=(x, y)) self.assertEqual(expected_tensor, ret) # Test invalid ranks for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]: with self.assertRaises(RuntimeError): - self._run_func_in_mode(self.world_size + 1, torch.add, exec_mode, args=(torch.ones(2, 2), 1)) + _run_func_in_mode(self.world_size + 1, torch.add, exec_mode, args=(torch.ones(2, 2), 1)) for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]: with self.assertRaises(RuntimeError): - self._run_func_in_mode(-1, torch.add, exec_mode, args=(torch.ones(2, 2), 1)) + _run_func_in_mode(-1, torch.add, exec_mode, args=(torch.ones(2, 2), 1)) for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]: with self.assertRaises(ValueError): - self._run_func_in_mode(dst_rank + 0.5, torch.add, exec_mode, args=(torch.ones(2, 2), 1)) + _run_func_in_mode(dst_rank + 0.5, torch.add, exec_mode, args=(torch.ones(2, 2), 1)) for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]: with self.assertRaises(ValueError): - self._run_func_in_mode(dst_rank - 0.5, torch.add, exec_mode, args=(torch.ones(2, 2), 1)) + _run_func_in_mode(dst_rank - 0.5, torch.add, exec_mode, args=(torch.ones(2, 2), 1)) def _self_py_udf_remote(self, worker_info, x, y, z): rref = rpc.remote(worker_info, my_function, args=(x, y, z)) @@ -4025,17 +4033,9 @@ class RpcTest(RpcAgentTestFixture): def test_future_nested_callback(self): self._test_future_cb(add_use_future_nested_cb) - def _run_func_in_mode(self, to, fn, mode, args=None, kwargs=None): - if mode == RPCExecMode.SYNC: - return rpc.rpc_sync(to, fn, args=args, kwargs=kwargs) - elif mode == RPCExecMode.ASYNC: - return rpc.rpc_async(to, fn, args=args, kwargs=kwargs).wait() - elif mode == RPCExecMode.REMOTE: - return rpc.remote(to, fn, args=args, kwargs=kwargs).to_here() - def _test_async_function_raise(self, mode): with self.assertRaisesRegex(RuntimeError, "Expected error"): - self._run_func_in_mode( + _run_func_in_mode( worker_name((self.rank + 1) % self.world_size), async_raise_func, mode @@ -4059,7 +4059,7 @@ class RpcTest(RpcAgentTestFixture): "torch\\.futures\\.Future object," ) with self.assertRaisesRegex(RuntimeError, errMsg): - self._run_func_in_mode( + _run_func_in_mode( worker_name((self.rank + 1) % self.world_size), async_wrong_type, mode @@ -4090,7 +4090,7 @@ class RpcTest(RpcAgentTestFixture): dst2 = worker_name((self.rank + 2) % self.world_size) args = (dst2, torch.ones(2, 2), 1, 2) - ret = self._run_func_in_mode(dst1, fn, mode, args=args) + ret = _run_func_in_mode(dst1, fn, mode, args=args) self.assertEqual(ret, torch.ones(2, 2) + 3) @dist_init @@ -4183,7 +4183,7 @@ class RpcTest(RpcAgentTestFixture): num = 20 step = 3 args = (dst2, torch.ones(2, 2), num, step) - ret = self._run_func_in_mode(dst1, fn, mode, args=args) + ret = _run_func_in_mode(dst1, fn, mode, args=args) self.assertEqual(ret, torch.ones(2, 2) + num * step) @dist_init @@ -4227,7 +4227,7 @@ class RpcTest(RpcAgentTestFixture): RuntimeError, "Can not pickle torch.futures.Future" ): - self._run_func_in_mode( + _run_func_in_mode( worker_name((self.rank + 1) % self.world_size), return_future, mode @@ -5217,13 +5217,33 @@ class TensorPipeAgentCudaRpcTest(RpcAgentTestFixture): rpc_backend_options=options, ) - ret = rpc.rpc_sync( - dst, - TensorPipeAgentCudaRpcTest._gpu_add, - args=(torch.zeros(2).to(0), torch.ones(2).to(0)) - ) - self.assertEqual(ret.device, torch.device(1)) - self.assertEqual(ret, (torch.zeros(2) + torch.ones(2)).to(1)) + # Test dense tensor + for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]: + x = torch.ones(2, 2) + y = torch.ones(2, 2) + expected_tensor = (x + y) + ret = _run_func_in_mode(dst, TensorPipeAgentCudaRpcTest._gpu_add, exec_mode, args=(x.to(0), y.to(0))) + self.assertEqual(ret.device, torch.device(1)) + self.assertEqual(ret, expected_tensor.to(1)) + + # Test sparse tensor uncoalesced + for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]: + x = build_sparse_tensor() + y = build_sparse_tensor() + expected_tensor = (x + y) + ret = _run_func_in_mode(dst, TensorPipeAgentCudaRpcTest._gpu_add, exec_mode, args=(x.to(0), y.to(0))) + self.assertEqual(ret.device, torch.device(1)) + self.assertEqual(ret, expected_tensor.to(1)) + + # Test sparse tensor coalesced + for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]: + x = build_sparse_tensor().coalesce() + y = build_sparse_tensor().coalesce() + expected_tensor = (x + y) + ret = _run_func_in_mode(dst, TensorPipeAgentCudaRpcTest._gpu_add, exec_mode, args=(x.to(0), y.to(0))) + self.assertEqual(ret.device, torch.device(1)) + self.assertEqual(ret, expected_tensor.to(1)) + rpc.shutdown() @staticmethod @@ -5722,8 +5742,7 @@ class TensorPipeAgentCudaRpcTest(RpcAgentTestFixture): def test_device_maps_missing_config_remote_response(self): self._test_device_maps_missing_config_response(RPCExecMode.REMOTE) - @skip_if_lt_x_gpu(2) - def test_device_maps_remote(self): + def _device_maps_remote(self, x, y, expected): options = self.rpc_backend_options dst = worker_name((self.rank + 1) % self.world_size) options.set_device_map(dst, {1: 0}) @@ -5739,14 +5758,29 @@ class TensorPipeAgentCudaRpcTest(RpcAgentTestFixture): rref = rpc.remote( dst, TensorPipeAgentCudaRpcTest._add_to_gpu, - args=(torch.zeros(2), 1) + args=(x, y) ) - self.assertEqual(rref.to_here().device.index, 1) - self.assertEqual(rref.to_here(), torch.ones(2).to(1)) + self.assertEqual(rref.to_here(), expected.to(1)) rpc.shutdown() + @skip_if_lt_x_gpu(2) + def test_device_maps_remote(self): + self._device_maps_remote( + torch.ones(3, 3), + torch.ones(3, 3), + torch.ones(3, 3) + torch.ones(3, 3) + ) + + @skip_if_lt_x_gpu(2) + def test_device_maps_remote_sparse(self): + self._device_maps_remote( + build_sparse_tensor(), + build_sparse_tensor(), + build_sparse_tensor() + build_sparse_tensor() + ) + @staticmethod def _slow_add_on_user_stream(x, y): s0 = torch.cuda.current_stream(x.device) -- 2.7.4