# rref tensor equals the given grad.
def _compare_owner_value(context_id, rref, grad):
grads = dist_autograd.get_gradients(context_id)
- return torch.equal(grads[rref.local_value()], grad)
+ x = grads[rref.local_value()]
+ if x.is_sparse:
+ assert grad.is_sparse
+ x = x.to_dense()
+ grad = grad.to_dense()
+ else:
+ assert not grad.is_sparse
+ return torch.equal(x, grad)
def create_tensor():
return torch.ones((3, 3), requires_grad=True)
+def build_sparse_tensor(coalesce=False, requires_grad=True, dtype=torch.float32):
+ i = [[0, 1, 1], [2, 0, 2]]
+ v = [3.2, 4.1, 5.3]
+ tensor = torch.sparse_coo_tensor(i, v, (3, 3), requires_grad=requires_grad, dtype=dtype)
+ if coalesce:
+ tensor = tensor.coalesce()
+ return tensor
+
+
@torch.jit.script
def create_torchscript_tensor() -> torch.Tensor:
return torch.ones((3, 3)).requires_grad_()
# This function creates a dist autograd context, runs rpc_sync on the given ps,
# and then blocks until the ps has verified the grads are correctly accumulated.
-def _run_trainer(rref_t1, t2, ps, rank_diff):
+def _run_trainer(rref_t1, t2, ps, rank_diff, sparse):
with dist_autograd.context() as context_id:
ret = rpc.rpc_sync(ps, my_rref_add, args=(rref_t1, t2))
- dist_autograd.backward(context_id, [ret.sum()])
+ if sparse:
+ loss = torch.sparse.sum(ret)
+ else:
+ loss = ret.sum()
+ dist_autograd.backward(context_id, [loss])
# prevent deleting dist autograd context
rpc.rpc_sync(ps, _set_rpc_done, args=(context_id, rank_diff))
rpc.rpc_sync(ps, _check_rpc_done, args=(0,))
# This function is the same as _run_trainer, except that rpc calls the torchscript
# function "my_script_ref_add" instead of the Python function "my_rref_add".
-def _run_trainer_torchscript(rref_t1, t2, ps, rank_diff):
+def _run_trainer_torchscript(rref_t1, t2, ps, rank_diff, sparse):
with dist_autograd.context() as context_id:
ret = rpc.rpc_sync(ps, my_script_ref_add, args=(rref_t1, t2))
- dist_autograd.backward(context_id, [ret.sum()])
+ if sparse:
+ loss = torch.sparse.sum(ret)
+ else:
+ loss = ret.sum()
+ dist_autograd.backward(context_id, [loss])
# prevent deleting dist autograd context
rpc.rpc_sync(ps, _set_rpc_done, args=(context_id, rank_diff))
rpc.rpc_sync(ps, _check_rpc_done, args=(0,))
"torch::distributed::autograd::RecvRpcBackward", next_funcs[0][0].name()
)
- def _test_graph(self, fn, exec_mode):
+ def _test_graph(self, fn, exec_mode, sparse):
dst_rank = (self.rank + 1) % self.world_size
initialize_pg(self.file_init_method, self.rank, self.world_size)
with dist_autograd.context() as context_id:
- t1 = torch.ones(3, 3, requires_grad=True)
- t2 = torch.zeros(3, 3, requires_grad=True)
+ if sparse:
+ t1 = build_sparse_tensor()
+ t2 = build_sparse_tensor()
+ else:
+ t1 = torch.ones(3, 3, requires_grad=True)
+ t2 = torch.zeros(3, 3, requires_grad=True)
if ExecMode.RPC_SYNC == exec_mode:
ret = rpc.rpc_sync(worker_name(dst_rank), fn, args=(t1, t2))
elif ExecMode.REMOTE == exec_mode:
@dist_init
def test_graph_for_builtin_call(self):
- self._test_graph(torch.add, ExecMode.RPC_SYNC)
+ self._test_graph(torch.add, ExecMode.RPC_SYNC, False)
+
+ @dist_init
+ def test_graph_for_builtin_call_sparse(self):
+ self._test_graph(torch.add, ExecMode.RPC_SYNC, True)
@dist_init
def test_graph_for_python_call(self):
- self._test_graph(my_py_add, ExecMode.RPC_SYNC)
+ self._test_graph(my_py_add, ExecMode.RPC_SYNC, False)
+
+ @dist_init
+ def test_graph_for_python_call_sparse(self):
+ self._test_graph(my_py_add, ExecMode.RPC_SYNC, True)
@dist_init
def test_graph_for_builtin_remote_call(self):
- self._test_graph(torch.add, ExecMode.REMOTE)
+ self._test_graph(torch.add, ExecMode.REMOTE, False)
+
+ @dist_init
+ def test_graph_for_builtin_remote_call_sparse(self):
+ self._test_graph(torch.add, ExecMode.REMOTE, True)
@dist_init
def test_graph_for_python_remote_call(self):
- self._test_graph(my_py_add, ExecMode.REMOTE)
+ self._test_graph(my_py_add, ExecMode.REMOTE, False)
+
+ @dist_init
+ def test_graph_for_python_remote_call_sparse(self):
+ self._test_graph(my_py_add, ExecMode.REMOTE, True)
# 3-layer nested calls
- def _test_graph_for_py_nested_call(self, exec_mode):
+ def _test_graph_for_py_nested_call(self, exec_mode, sparse):
dst_rank = (self.rank + 1) % self.world_size
initialize_pg(self.file_init_method, self.rank, self.world_size)
with dist_autograd.context() as context_id:
- t1 = torch.ones(3, 3, requires_grad=True)
- t2 = torch.zeros(3, 3, requires_grad=True)
+ if sparse:
+ t1 = build_sparse_tensor(requires_grad=True)
+ t2 = build_sparse_tensor(requires_grad=True)
+ else:
+ t1 = torch.ones(3, 3, requires_grad=True)
+ t2 = torch.zeros(3, 3, requires_grad=True)
nest_dst_rank = (dst_rank + 1) % self.world_size
if ExecMode.RPC_SYNC == exec_mode:
ret = rpc.rpc_sync(
@dist_init
def test_graph_for_py_nested_call(self):
- self._test_graph_for_py_nested_call(ExecMode.RPC_SYNC)
+ self._test_graph_for_py_nested_call(ExecMode.RPC_SYNC, False)
+
+ @dist_init
+ def test_graph_for_py_nested_call_sparse(self):
+ self._test_graph_for_py_nested_call(ExecMode.RPC_SYNC, True)
@dist_init
def test_graph_for_py_nested_remote_call(self):
- self._test_graph_for_py_nested_call(ExecMode.REMOTE)
+ self._test_graph_for_py_nested_call(ExecMode.REMOTE, False)
+
+ @dist_init
+ def test_graph_for_py_nested_remote_call_sparse(self):
+ self._test_graph_for_py_nested_call(ExecMode.REMOTE, True)
# Rank0->Rank1->Rank0
- def _test_graph_for_py_nested_call_itself(self, exec_mode):
+ def _test_graph_for_py_nested_call_itself(self, exec_mode, sparse):
dst_rank = (self.rank + 1) % self.world_size
initialize_pg(self.file_init_method, self.rank, self.world_size)
with dist_autograd.context() as context_id:
- t1 = torch.ones(3, 3, requires_grad=True)
- t2 = torch.zeros(3, 3, requires_grad=True)
+ if sparse:
+ t1 = build_sparse_tensor(requires_grad=True)
+ t2 = build_sparse_tensor(requires_grad=True)
+ else:
+ t1 = torch.ones(3, 3, requires_grad=True)
+ t2 = torch.zeros(3, 3, requires_grad=True)
if ExecMode.RPC_SYNC == exec_mode:
ret = rpc.rpc_sync(
worker_name(dst_rank),
@dist_init
def test_graph_for_py_nested_call_itself(self):
- self._test_graph_for_py_nested_call_itself(ExecMode.RPC_SYNC)
+ self._test_graph_for_py_nested_call_itself(ExecMode.RPC_SYNC, False)
+
+ @dist_init
+ def test_graph_for_py_nested_call_itself_sparse(self):
+ self._test_graph_for_py_nested_call_itself(ExecMode.RPC_SYNC, True)
@dist_init
def test_graph_for_py_nested_remote_call_itself(self):
- self._test_graph_for_py_nested_call_itself(ExecMode.REMOTE)
+ self._test_graph_for_py_nested_call_itself(ExecMode.REMOTE, False)
+
+ @dist_init
+ def test_graph_for_py_nested_remote_call_itself_sparse(self):
+ self._test_graph_for_py_nested_call_itself(ExecMode.REMOTE, True)
- def _test_no_graph_with_tensors_not_require_grad(self, exec_mode):
+ def _test_no_graph_with_tensors_not_require_grad(self, exec_mode, sparse):
initialize_pg(self.file_init_method, self.rank, self.world_size)
dst_rank = (self.rank + 1) % self.world_size
with dist_autograd.context() as context_id:
- t1 = torch.ones(3, 3, requires_grad=False)
- t2 = torch.zeros(3, 3, requires_grad=False)
+ if sparse:
+ t1 = build_sparse_tensor(requires_grad=False)
+ t2 = build_sparse_tensor(requires_grad=False)
+ else:
+ t1 = torch.ones(3, 3, requires_grad=False)
+ t2 = torch.zeros(3, 3, requires_grad=False)
if ExecMode.RPC_SYNC == exec_mode:
ret = rpc.rpc_sync(
worker_name(dst_rank), torch.add, args=(t1, t2)
@dist_init
def test_no_graph_with_tensors_not_require_grad(self):
- self._test_no_graph_with_tensors_not_require_grad(ExecMode.RPC_SYNC)
+ self._test_no_graph_with_tensors_not_require_grad(ExecMode.RPC_SYNC, False)
+
+ @dist_init
+ def test_no_graph_with_tensors_not_require_grad_sparse(self):
+ self._test_no_graph_with_tensors_not_require_grad(ExecMode.RPC_SYNC, True)
@dist_init
def test_no_graph_with_tensors_not_require_grad_remote(self):
- self._test_no_graph_with_tensors_not_require_grad(ExecMode.REMOTE)
+ self._test_no_graph_with_tensors_not_require_grad(ExecMode.REMOTE, False)
+
+ @dist_init
+ def test_no_graph_with_tensors_not_require_grad_remote_sparse(self):
+ self._test_no_graph_with_tensors_not_require_grad(ExecMode.REMOTE, True)
def _test_grad_only_on_return_value(self, exec_mode):
initialize_pg(self.file_init_method, self.rank, self.world_size)
def test_grad_only_on_return_value_remote(self):
self._test_grad_only_on_return_value(ExecMode.REMOTE)
- def _test_rpc_complex_args(self, exec_mode):
+ def _test_rpc_complex_args(self, exec_mode, sparse):
with dist_autograd.context() as context_id:
num_tensors = 10
tensors = []
for i in range(num_tensors):
- tensors.append(torch.ones(3, 3, requires_grad=(i % 2 == 0)))
-
+ if sparse:
+ tensor = build_sparse_tensor(requires_grad=(i % 2 == 0))
+ else:
+ tensor = torch.ones(3, 3, requires_grad=(i % 2 == 0))
+ tensors.append(tensor)
dst_rank = self._next_rank()
if ExecMode.RPC_SYNC == exec_mode:
ret = rpc.rpc_sync(
@dist_init
def test_rpc_complex_args(self):
- self._test_rpc_complex_args(ExecMode.RPC_SYNC)
+ self._test_rpc_complex_args(ExecMode.RPC_SYNC, False)
+
+ @dist_init
+ def test_rpc_complex_args_sparse(self):
+ self._test_rpc_complex_args(ExecMode.RPC_SYNC, True)
@dist_init
def test_remote_complex_args(self):
- self._test_rpc_complex_args(ExecMode.REMOTE)
+ self._test_rpc_complex_args(ExecMode.REMOTE, False)
+
+ @dist_init
+ def test_remote_complex_args_sparse(self):
+ self._test_rpc_complex_args(ExecMode.REMOTE, True)
def context_cleanup_test_helper(self, rpc_args, func, nested=False):
initialize_pg(self.file_init_method, self.rank, self.world_size)
self.context_cleanup_test_helper(rpc_args=(t1, t2), func=torch.add)
@dist_init
+ def test_context_cleanup_tensor_with_grad_sparse(self):
+ t1 = build_sparse_tensor(requires_grad=True)
+ t2 = build_sparse_tensor(requires_grad=True)
+ self.context_cleanup_test_helper(rpc_args=(t1, t2), func=torch.add)
+
+ @dist_init
def test_context_cleanup_tensor_no_grad(self):
t1 = torch.ones(3, 3, requires_grad=False)
self.context_cleanup_test_helper(rpc_args=(t1, t1), func=torch.add)
@dist_init
+ def test_context_cleanup_tensor_no_grad_sparse(self):
+ t1 = build_sparse_tensor(requires_grad=False)
+ self.context_cleanup_test_helper(rpc_args=(t1, t1), func=torch.add)
+
+ @dist_init
def test_context_cleanup_no_tensors(self):
self.context_cleanup_test_helper(rpc_args=(1, 1), func=my_scalar_add)
)
@dist_init
+ def test_context_cleanup_nested_rpc_sparse(self):
+ t1 = build_sparse_tensor(requires_grad=True)
+ t2 = build_sparse_tensor(requires_grad=True)
+ dst_rank = (self.rank + 1) % self.world_size
+ args = (t1, t2, dst_rank, self.world_size, 0)
+ self.context_cleanup_test_helper(
+ rpc_args=args, func=my_py_nested_call, nested=True
+ )
+
+ @dist_init
def test_worker_ids_recorded(self):
dst_ranks = {rank for rank in range(self.world_size) if rank != self.rank}
with dist_autograd.context() as context_id:
worker_name(self._next_rank()), torch.matmul, args=(t1, t2)
)
- @dist_init
- def test_backward_no_grad_on_tensor(self):
- t1 = torch.rand((3, 3), requires_grad=True)
- t2 = torch.rand((3, 3), requires_grad=True)
+ def _backward_no_grad_on_tensor(self, t1, t2, sparse):
with dist_autograd.context() as context_id:
loss = rpc.rpc_sync(
worker_name(self._next_rank()),
torch.add,
- args=(t1, t2)).sum()
-
+ args=(t1, t2))
+ if sparse:
+ loss = torch.sparse.sum(loss)
+ else:
+ loss = loss.sum()
dist_autograd.backward(context_id, [loss], retain_graph=True)
self.assertIsNone(t1.grad)
self.assertIsNone(t2.grad)
# Now populate .grad with local autograd engine and
# verify dist autograd doesn't mess with it.
- loss_local = torch.add(t1, t2).sum()
+ loss_local = torch.add(t1, t2)
+ if sparse:
+ loss_local = torch.sparse.sum(loss_local)
+ else:
+ loss_local = loss_local.sum()
loss_local.backward()
self.assertIsNotNone(t1.grad)
self.assertIsNotNone(t2.grad)
self.assertEqual(t1_grad_before, t1.grad)
self.assertEqual(t2_grad_before, t2.grad)
- def _test_backward_simple(self, dst):
- # Run the same code locally and with dist autograd and verify gradients
- # are same.
- local_grads = None
- t1 = torch.rand((3, 3), requires_grad=True)
- t2 = torch.rand((3, 3), requires_grad=True)
+ @dist_init
+ def test_backward_no_grad_on_tensor(self):
+ self._backward_no_grad_on_tensor(
+ torch.rand((3, 3), requires_grad=True),
+ torch.rand((3, 3), requires_grad=True),
+ False
+ )
+
+ @dist_init
+ def test_backward_no_grad_on_tensor_sparse(self):
+ self._backward_no_grad_on_tensor(
+ build_sparse_tensor(requires_grad=True),
+ build_sparse_tensor(requires_grad=True),
+ True
+ )
+
+ # Run the same code locally and with dist autograd and verify gradients
+ # are the same.
+ def _backward_simple(self, dst, t1, t2, local_grads, sparse):
for exec_mode in [ExecMode.LOCAL, ExecMode.RPC_SYNC, ExecMode.REMOTE]:
with dist_autograd.context() as context_id:
ret = self._exec_func_with_dst(
dst, exec_mode, torch.add, t1, t2
)
- loss = ret.sum()
+ if sparse:
+ loss = torch.sparse.sum(ret)
+ else:
+ loss = ret.sum()
ret = self._verify_backwards(
exec_mode, [loss], context_id, local_grads, t1, t2
)
@dist_init
def test_backward_simple(self):
- self._test_backward_simple(self._next_rank())
+ self._backward_simple(
+ self._next_rank(),
+ torch.rand((3, 3), requires_grad=True),
+ torch.rand((3, 3), requires_grad=True),
+ None,
+ False
+ )
+
+ @dist_init
+ def test_backward_simple_sparse(self):
+ self._backward_simple(
+ self._next_rank(),
+ build_sparse_tensor(requires_grad=True),
+ build_sparse_tensor(requires_grad=True),
+ None,
+ True
+ )
@dist_init
def test_backward_simple_self(self):
- self._test_backward_simple(self.rank)
+ self._backward_simple(
+ self.rank,
+ torch.rand((3, 3), requires_grad=True),
+ torch.rand((3, 3), requires_grad=True),
+ None,
+ False
+ )
+
+ @dist_init
+ def test_backward_simple_self_sparse(self):
+ self._backward_simple(
+ self.rank,
+ build_sparse_tensor(requires_grad=True),
+ build_sparse_tensor(requires_grad=True),
+ None,
+ True
+ )
# The current rank first creates a tensor on the rref_owner, and then passes
# the rref with another tensor to the callee to run either my_rref_add or
# my_nested_rref_add, depending on whether the callee is the rref owner.
# The grad of tensor lives on the current rank, and the grad of the rref
# tensor lives on the rref owner.
- def _test_backward_rref(self, callee, rref_owner):
- local_grads = None
- t1 = torch.ones((3, 3), requires_grad=True)
- t2 = torch.zeros((3, 3), requires_grad=True)
-
+ def _backward_rref(self, callee, rref_owner, t1, t2, local_grads, sparse):
local_ret = torch.add(t1, t2)
- local_ret.sum().backward()
+ if sparse:
+ local_ret = torch.sparse.sum(local_ret)
+ else:
+ local_ret = local_ret.sum()
+ local_ret.backward()
with dist_autograd.context() as context_id:
- rref_t1 = rpc.remote(
- rref_owner, _torch_ones, args=((3, 3),), kwargs={"requires_grad": True}
- )
-
+ if sparse:
+ rref_t1 = rpc.remote(
+ rref_owner, build_sparse_tensor, args=(False, True,)
+ )
+ else:
+ rref_t1 = rpc.remote(
+ rref_owner, _torch_ones, args=((3, 3),), kwargs={"requires_grad": True}
+ )
if callee == rref_owner:
rref = rpc.remote(callee, my_rref_add, args=(rref_t1, t2))
else:
callee, my_nested_rref_add, args=(rref_owner, rref_t1, t2)
)
ret = rref.to_here()
- dist_autograd.backward(context_id, [ret.sum()])
+ if sparse:
+ ret = torch.sparse.sum(ret)
+ else:
+ ret = ret.sum()
+ dist_autograd.backward(context_id, [ret])
# verify grads on caller
grads = dist_autograd.get_gradients(context_id)
def test_backward_rref(self):
callee = worker_name(self._next_rank())
rref_owner = callee
- self._test_backward_rref(callee, rref_owner)
+ self._backward_rref(
+ callee,
+ rref_owner,
+ torch.rand((3, 3), requires_grad=True),
+ torch.rand((3, 3), requires_grad=True),
+ None,
+ False
+ )
+
+ @dist_init
+ def test_backward_rref_sparse(self):
+ callee = worker_name(self._next_rank())
+ rref_owner = callee
+ self._backward_rref(
+ callee,
+ rref_owner,
+ build_sparse_tensor(requires_grad=True),
+ build_sparse_tensor(requires_grad=True),
+ None,
+ True
+ )
@dist_init
def test_backward_rref_multi(self):
if self.rank > 0:
callee = "worker0"
rref_owner = callee
- self._test_backward_rref(callee, rref_owner)
+ self._backward_rref(
+ callee,
+ rref_owner,
+ torch.rand((3, 3), requires_grad=True),
+ torch.rand((3, 3), requires_grad=True),
+ None,
+ False
+ )
+
+ @dist_init
+ def test_backward_rref_multi_sparse(self):
+ if self.rank > 0:
+ callee = "worker0"
+ rref_owner = callee
+ self._backward_rref(
+ callee,
+ rref_owner,
+ build_sparse_tensor(requires_grad=True),
+ build_sparse_tensor(requires_grad=True),
+ None,
+ True
+ )
@dist_init
def test_backward_rref_nested(self):
callee = worker_name((self.rank + 1) % self.world_size)
rref_owner = worker_name((self.rank + 2) % self.world_size)
- self._test_backward_rref(callee, rref_owner)
+ self._backward_rref(
+ callee,
+ rref_owner,
+ torch.rand((3, 3), requires_grad=True),
+ torch.rand((3, 3), requires_grad=True),
+ None,
+ False
+ )
+
+ @dist_init
+ def test_backward_rref_nested_sparse(self):
+ callee = worker_name((self.rank + 1) % self.world_size)
+ rref_owner = worker_name((self.rank + 2) % self.world_size)
+ self._backward_rref(
+ callee,
+ rref_owner,
+ build_sparse_tensor(requires_grad=True),
+ build_sparse_tensor(requires_grad=True),
+ None,
+ True
+ )
# In this test, every rank will serve as a parameter server (ps) and a
# driver, and then kick off trainers on the other three ranks. So, we have:
#
# These four test ps-trainer groups run on completely separate autograd
# graphs, but they share the same set of underlying RpcAgents.
- def _test_trainer_ps(self, create_ref_fn, trainer_fn):
- local_grads = None
- t1 = torch.ones((3, 3), requires_grad=True)
- t2 = torch.zeros((3, 3), requires_grad=True)
+ def _test_trainer_ps(self, create_ref_fn, trainer_fn, sparse):
+ if sparse:
+ t1 = build_sparse_tensor(requires_grad=True)
+ t2 = build_sparse_tensor(requires_grad=True)
+ else:
+ t1 = torch.ones((3, 3), requires_grad=True)
+ t2 = torch.zeros((3, 3), requires_grad=True)
local_ret = torch.add(t1, t2)
- local_ret.sum().backward()
+ if sparse:
+ torch.sparse.sum(local_ret).backward()
+ else:
+ local_ret.sum().backward()
# create rref on self
rref_t1 = rpc.remote(
rpc.rpc_async(
worker_name((self.rank + rank_diff) % self.world_size),
trainer_fn,
- args=(rref_t1, t2, worker_name(self.rank), rank_diff),
+ args=(rref_t1, t2, worker_name(self.rank), rank_diff, sparse),
)
)
@dist_init
def test_trainer_ps(self):
- self._test_trainer_ps(create_tensor, _run_trainer)
+ self._test_trainer_ps(
+ create_tensor,
+ _run_trainer,
+ False
+ )
+
+ @dist_init
+ def test_trainer_ps_sparse(self):
+ self._test_trainer_ps(
+ build_sparse_tensor,
+ _run_trainer,
+ True
+ )
@dist_init
def test_trainer_ps_torchscript_functions(self):
import torch.distributed.rpc.api as api
api._ignore_rref_leak = True
- self._test_trainer_ps(create_torchscript_tensor, _run_trainer_torchscript)
-
- @dist_init
- def test_backward_multiple_round_trips(self):
- local_grads = None
- t1 = torch.rand((3, 3), requires_grad=True)
- t2 = torch.rand((3, 3))
- t3 = torch.rand((3, 3), requires_grad=True)
- t4 = torch.rand((3, 3))
- t5 = torch.rand((3, 3), requires_grad=True)
+ self._test_trainer_ps(create_torchscript_tensor, _run_trainer_torchscript, False)
+ def _backward_multiple_round_trips(self, t1, t2, t3, t4, t5, local_grads, sparse):
for exec_mode in [ExecMode.LOCAL, ExecMode.RPC_SYNC, ExecMode.REMOTE]:
with dist_autograd.context() as context_id:
# Multiple RPCs between different nodes.
val = self._exec_func(exec_mode, torch.mul, t3, val)
s1 = self._exec_func(exec_mode, torch.stack, (t4, val))
s2 = self._exec_func(exec_mode, torch.stack, (t5, val))
- val = self._exec_func(exec_mode, torch.bmm, s1, s2)
- val = self._exec_func(exec_mode, torch.matmul, val, val)
- loss = val.sum()
+ if sparse:
+ val = self._exec_func(exec_mode, torch.mul, s1, s2)
+ val = self._exec_func(exec_mode, torch.mul, val, val)
+ loss = torch.sparse.sum(val)
+ else:
+ val = self._exec_func(exec_mode, torch.bmm, s1, s2)
+ val = self._exec_func(exec_mode, torch.matmul, val, val)
+ loss = val.sum()
ret = self._verify_backwards(
exec_mode, [loss], context_id, local_grads, t1, t2, t3, t4, t5
local_grads = ret if ret else local_grads
@dist_init
+ def test_backward_multiple_round_trips(self):
+ self._backward_multiple_round_trips(
+ torch.rand((3, 3), requires_grad=True),
+ torch.rand((3, 3)),
+ torch.rand((3, 3), requires_grad=True),
+ torch.rand((3, 3)),
+ torch.rand((3, 3), requires_grad=True),
+ None,
+ False
+ )
+
+ @dist_init
+ def test_backward_multiple_round_trips_sparse(self):
+ self._backward_multiple_round_trips(
+ build_sparse_tensor(requires_grad=True),
+ build_sparse_tensor(requires_grad=False),
+ build_sparse_tensor(requires_grad=True),
+ build_sparse_tensor(requires_grad=False),
+ build_sparse_tensor(requires_grad=True),
+ None,
+ True
+ )
+
+ @dist_init
def test_backward_different_tensor_dims(self):
local_grads = None
t1 = torch.rand((4, 6), requires_grad=True)
exec_mode, [r1, r2, r3, r4], context_id, local_grads, t1, t2
)
- @dist_init
- def test_backward_different_dtypes(self):
+ def _backward_different_dtypes(self, t1, t2, sparse):
local_grads = None
- t1 = torch.rand((3, 3), requires_grad=True, dtype=torch.float32)
- t2 = torch.rand((3, 3), requires_grad=True, dtype=torch.float64)
for exec_mode in [ExecMode.LOCAL, ExecMode.REMOTE]:
with dist_autograd.context() as context_id:
- loss = self._exec_func(exec_mode, torch.add, t1, t2).sum()
-
+ loss = self._exec_func(exec_mode, torch.add, t1, t2)
+ if sparse:
+ loss = torch.sparse.sum(loss)
+ else:
+ loss = loss.sum()
local_grads = self._verify_backwards(
exec_mode, [loss], context_id, local_grads, t1, t2
)
@dist_init
- def test_backward_simple_python_udf(self):
- # Run the same code locally and with dist autograd and verify gradients
- # are same.
+ def test_backward_different_dtypes(self):
+ self._backward_different_dtypes(
+ torch.rand((3, 3), requires_grad=True, dtype=torch.float32),
+ torch.rand((3, 3), requires_grad=True, dtype=torch.float64),
+ False
+ )
+
+ @dist_init
+ def test_backward_different_dtypes_sparse(self):
+ self._backward_different_dtypes(
+ build_sparse_tensor(requires_grad=True, dtype=torch.float32),
+ build_sparse_tensor(requires_grad=True, dtype=torch.float64),
+ True
+ )
+
+ # Run the same code locally and with dist autograd and verify gradients
+ # are the same.
+ def _backward_simple_python_udf(self, t1, t2, sparse):
local_grads = None
- t1 = torch.rand((3, 3), requires_grad=True)
- t2 = torch.rand((3, 3), requires_grad=True)
for exec_mode in [ExecMode.LOCAL, ExecMode.REMOTE]:
with dist_autograd.context() as context_id:
ret = self._exec_func(exec_mode, my_py_add, t1, t2)
- loss = ret.sum()
+ if sparse:
+ loss = torch.sparse.sum(ret)
+ else:
+ loss = ret.sum()
local_grads = self._verify_backwards(
exec_mode, [loss], context_id, local_grads, t1, t2
)
@dist_init
- def test_backward_simple_script_call(self):
- # Run the same code locally and with dist autograd and verify gradients
- # are same.
+ def test_backward_simple_python_udf(self):
+ self._backward_simple_python_udf(
+ torch.rand(3, 3, requires_grad=True),
+ torch.rand(3, 3, requires_grad=True),
+ False
+ )
+
+ @dist_init
+ def test_backward_simple_python_udf_sparse(self):
+ self._backward_simple_python_udf(
+ build_sparse_tensor(requires_grad=True),
+ build_sparse_tensor(requires_grad=True),
+ True
+ )
+
+ # Run the same code locally and with dist autograd and verify gradients
+ # are the same.
+ def _backward_simple_script_call(self, t1, t2, sparse):
local_grads = None
- t1 = torch.rand((3, 3), requires_grad=True)
- t2 = torch.rand((3, 3), requires_grad=True)
for exec_mode in [
ExecMode.LOCAL,
ExecMode.RPC_SYNC,
]:
with dist_autograd.context() as context_id:
forward_ret = self._exec_func(exec_mode, my_script_add, t1, t2)
- loss = forward_ret.sum()
+ if sparse:
+ loss = torch.sparse.sum(forward_ret)
+ else:
+ loss = forward_ret.sum()
ret = self._verify_backwards(
exec_mode, [loss], context_id, local_grads, t1, t2
)
local_grads = ret if ret else local_grads
+ @dist_init
+ def test_backward_simple_script_call(self):
+ self._backward_simple_script_call(
+ torch.rand(3, 3, requires_grad=True),
+ torch.rand(3, 3, requires_grad=True),
+ False
+ )
+
+ @dist_init
+ def test_backward_simple_script_call_sparse(self):
+ self._backward_simple_script_call(
+ build_sparse_tensor(requires_grad=True),
+ build_sparse_tensor(requires_grad=True),
+ True
+ )
+
@staticmethod
def _complex_python_udf(t1, t2):
t3 = torch.nn.functional.linear(t1, t2)
t3 = t1 * t2
t4 = t1 + t2
res = rpc.rpc_sync(worker_name(dst), my_py_add, args=(t3, t4))
- return torch.linalg.multi_dot([t1, t2, t3, t4, res])
+ return t1 * t2 * t3 * t4 * res
- @dist_init
- def test_backwards_nested_python_udf(self):
- # Run equivalent of _nested_python_udf locally.
- t1 = torch.rand((3, 3), requires_grad=True)
- t2 = torch.rand((3, 3), requires_grad=True)
+ def _backwards_nested_python_udf(self, t1, t2, sparse):
t3 = t1 * t2
t4 = t1 + t2
res = t3 + t4
- loss = torch.linalg.multi_dot([t1, t2, t3, t4, res]).sum()
+ loss = t1 * t2 * t3 * t4 * res
+ if sparse:
+ loss = torch.sparse.sum(loss)
+ else:
+ loss = loss.sum()
torch.autograd.backward([loss])
# Now run distributed autograd.
DistAutogradTest._nested_python_udf,
args=(t1, t2, self._next_rank()),
)
- dist_autograd.backward(context_id, [loss.sum()])
-
+ if sparse:
+ loss = torch.sparse.sum(loss)
+ else:
+ loss = loss.sum()
+ dist_autograd.backward(context_id, [loss])
grads = dist_autograd.get_gradients(context_id)
self.assertEqual(t1.grad, grads[t1])
self.assertEqual(t2.grad, grads[t2])
+ @dist_init
+ def test_backwards_nested_python_udf(self):
+ # Run equivalent of _nested_python_udf locally.
+ self._backwards_nested_python_udf(
+ torch.rand(3, 3, requires_grad=True),
+ torch.rand(3, 3, requires_grad=True),
+ False
+ )
+
+ @dist_init
+ def test_backwards_nested_python_udf_sparse(self):
+ # Run equivalent of _nested_python_udf locally.
+ self._backwards_nested_python_udf(
+ build_sparse_tensor(requires_grad=True),
+ build_sparse_tensor(requires_grad=True),
+ True
+ )
+
_test_clean_context_backward_context_id = None
class MyBackwardFunc(Function):
def _get_grad(cls, embedding_rref, context_id):
embedding = embedding_rref.local_value()
grad_map = dist_autograd.get_gradients(context_id)
- # Can't send sparse tensors over RPC: https://github.com/pytorch/pytorch/issues/30807
- return grad_map[embedding.weight].to_dense()
+ return grad_map[embedding.weight]
@dist_init
def test_embedding_bag_with_no_grad_tensors(self):
args=(remote_embedding, context_id),
)
- self.assertEqual(local_grad.to_dense(), remote_grad)
+ self.assertEqual(local_grad, remote_grad)
@classmethod
- def _mixed_requires_grad(cls, t1, t2):
+ def _mixed_requires_grad_operaton(cls, t1, t2):
if t2.requires_grad:
return t1 - t2
else:
return t1 * t2
- @dist_init
- def test_mixed_requires_grad(self):
+ def _mixed_requires_grad(self, t1, t2, sparse):
for exec_mode in [ExecMode.RPC_SYNC, ExecMode.REMOTE]:
- t1 = torch.rand((3, 3), requires_grad=True)
- t2 = torch.rand((3, 3), requires_grad=False)
with dist_autograd.context() as context_id:
ret = self._exec_func(
- exec_mode, DistAutogradTest._mixed_requires_grad, t1, t2
+ exec_mode, DistAutogradTest._mixed_requires_grad_operaton, t1, t2
)
self.assertEqual(t1 * t2, ret)
- dist_autograd.backward(context_id, [ret.sum()])
+ if sparse:
+ loss = torch.sparse.sum(ret)
+ else:
+ loss = ret.sum()
+ dist_autograd.backward(context_id, [loss])
self.assertTrue(t1.requires_grad)
self.assertFalse(t2.requires_grad)
grads = dist_autograd.get_gradients(context_id)
self.assertNotIn(t2, grads)
self.assertEqual(t2, grads[t1])
+ @dist_init
+ def test_mixed_requires_grad(self):
+ self._mixed_requires_grad(
+ torch.rand(3, 3, requires_grad=True),
+ torch.rand(3, 3, requires_grad=False),
+ False
+ )
+
+ @dist_init
+ def test_mixed_requires_grad_sparse(self):
+ self._mixed_requires_grad(
+ build_sparse_tensor(requires_grad=True),
+ build_sparse_tensor(requires_grad=False),
+ True
+ )
+
class TestDebugInfoFunc(Function):
@staticmethod
def forward(ctx, input):
@staticmethod
def _test_nested_backward_accumulate_grads(t1, t2, dst_rank):
- return rpc.rpc_sync(worker_name(dst_rank), torch.matmul, args=(t1, t2))
+ return rpc.rpc_sync(worker_name(dst_rank), torch.add, args=(t1, t2))
- @dist_init
- def test_nested_backward_accumulate_grads(self):
- t1 = torch.rand((3, 3), requires_grad=True)
- t2 = torch.rand((3, 3), requires_grad=True)
+ def _nested_backward_accumulate_grads(self, t1, t2, sparse):
with dist_autograd.context() as context_id:
- loss = rpc.rpc_sync(
+ ret = rpc.rpc_sync(
worker_name(self._next_rank()),
DistAutogradTest._test_nested_backward_accumulate_grads,
args=(t1, t2, self._next_rank()),
- ).sum()
-
+ )
+ if sparse:
+ loss = torch.sparse.sum(ret)
+ else:
+ loss = ret.sum()
# Run backward twice.
dist_autograd.backward(context_id, [loss], retain_graph=True)
dist_autograd.backward(context_id, [loss])
@dist_init
- def test_multiple_backward(self):
- t1 = torch.rand((3, 3), requires_grad=True)
- t2 = torch.rand((3, 3), requires_grad=True)
+ def test_nested_backward_accumulate_grads(self):
+ self._nested_backward_accumulate_grads(
+ torch.rand(3, 3, requires_grad=True),
+ torch.rand(3, 3, requires_grad=True),
+ False
+ )
+
+ @dist_init
+ def test_nested_backward_accumulate_grads_sparse(self):
+ self._nested_backward_accumulate_grads(
+ build_sparse_tensor(requires_grad=True),
+ build_sparse_tensor(requires_grad=True),
+ True
+ )
+
+ def _multiple_backward(self, t1, t2, sparse):
with dist_autograd.context() as context_id:
loss = rpc.rpc_sync(
worker_name(self._next_rank()),
torch.add,
- args=(t1, t2)).sum()
-
+ args=(t1, t2))
+ if sparse:
+ loss = torch.sparse.sum(loss)
+ else:
+ loss = loss.sum()
# Run backward in a loop multiple times.
for i in range(1000):
dist_autograd.backward(context_id, [loss], retain_graph=True)
+ @dist_init
+ def test_multiple_backward(self):
+ self._multiple_backward(
+ torch.rand(3, 3, requires_grad=True),
+ torch.rand(3, 3, requires_grad=True),
+ False
+ )
+
+ @dist_init
+ def test_multiple_backward_sparse(self):
+ self._multiple_backward(
+ build_sparse_tensor(requires_grad=True),
+ build_sparse_tensor(requires_grad=True),
+ True
+ )
+
@dist_init(clean_shutdown=False)
def test_multiple_backward_with_errors(self):
initialize_pg(self.file_init_method, self.rank, self.world_size)
def run_nested_pickle(pickle_cls_instance, tensor):
return pickle_cls_instance.t + tensor
-def build_sparse_tensor():
+def build_sparse_tensor(coalesce=False):
i = [[0, 1, 1], [2, 0, 2]]
v = [3, 4, 5]
- return torch.sparse_coo_tensor(i, v, (2, 3))
+ tensor = torch.sparse_coo_tensor(i, v, (2, 3))
+ if coalesce:
+ tensor = tensor.coalesce()
+ return tensor
def build_complex_tensors():
a = torch.ones(3, 3)
def my_tensor_function(a, b):
return a + b
+def my_container_sum(a):
+ result = a[0]
+ for tensor in a[1:]:
+ result += tensor
+ return result
+
def my_sleep_func(seconds=1):
time.sleep(seconds)
return rpc.rpc_sync(dst, torch.add, args=(torch.ones(2, 2), 1))
+def nested_rpc_sparse(dst):
+ return rpc.rpc_sync(
+ dst,
+ torch.add,
+ args=(build_sparse_tensor(), build_sparse_tensor())
+ )
+
+
def multi_layer_nested_async_rpc(dst, world_size, ttl):
# this method returns immediately without blocking the callee, but will
# generate additional requests.
)
+def nested_rref_sparse(dst):
+ return (
+ rpc.remote(
+ dst,
+ torch.add,
+ args=(build_sparse_tensor(), build_sparse_tensor())
+ ),
+ rpc.remote(
+ dst,
+ torch.add,
+ args=(build_sparse_tensor(), build_sparse_tensor())
+ ),
+ )
+
+
def nested_remote(dst):
rref = rpc.remote(dst, torch.add, args=(torch.ones(2, 2), 3))
return rref.to_here()
+def nested_remote_sparse(dst):
+ rref = rpc.remote(dst, torch.add, args=(build_sparse_tensor(), build_sparse_tensor()))
+ return rref.to_here()
+
def rref_forward_chain(dst, world_size, rref, ttl):
if ttl > 0:
return 0
+def heavy_rpc_sparse(tensor):
+ for i in range(1, 100):
+ tensor *= i
+ tensor = tensor / (i + 1)
+ return 0
+
@torch.jit.script
def heavy_rpc_torchscript(tensor):
for i in range(1, 100):
load_tests = load_tests
+class MyEmbeddingBagModel(torch.nn.Module):
+    """Module holding one ``EmbeddingBag`` (10 embeddings of dimension 10) as ``self.eb``."""
+
+    def __init__(self, sparse):
+        super().__init__()
+        # ``sparse`` is forwarded unchanged to the EmbeddingBag constructor.
+        self.eb = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=10, sparse=sparse)
+
+    def forward(self, x):
+        """Return the EmbeddingBag output for the index tensor ``x``."""
+        bags = self.eb(x)
+        return bags
+
+
+class MyParameterServer:
+    """Sums tensors submitted through ``average`` and publishes ``total / trainers``.
+
+    Once ``trainers`` submissions have been counted for the current
+    iteration, every pending future is completed with the quotient and the
+    same quotient is stored in ``gradient``.
+    """
+
+    def __init__(self, trainers):
+        # Lock taken around every state change made in ``average``.
+        self.lock = Lock()
+        # Number of submissions that completes a round.
+        self.trainers = trainers
+        self.iteration = 0
+        self.updates = 0
+        self.futures = []
+        self.total = None
+        self.gradient = None
+
+    @staticmethod
+    def get_gradient(rref):
+        """Return the ``gradient`` attribute of the server held by ``rref``."""
+        server = rref.local_value()
+        return server.gradient
+
+    @staticmethod
+    @rpc.functions.async_execution
+    def average(rref, riteration, tensor):
+        """Add ``tensor`` to the running total and return a future for the quotient.
+
+        A ``riteration`` larger than the stored iteration resets the update
+        counter and drops previously stored futures. The running total
+        itself is not reset.
+        """
+        server = rref.local_value()
+        fut = torch.futures.Future()
+        with server.lock:
+            if riteration > server.iteration:
+                server.iteration = riteration
+                server.updates = 0
+                server.futures.clear()
+            server.futures.append(fut)
+            if server.total is not None:
+                server.total += tensor
+            else:
+                server.total = tensor
+            server.updates += 1
+            if server.trainers == server.updates:
+                server.gradient = server.total / float(server.trainers)
+                # Each waiter receives its own freshly computed quotient.
+                for pending in server.futures:
+                    pending.set_result(server.total / float(server.trainers))
+        return fut
+
+
class RpcTest(RpcAgentTestFixture):
@dist_init
def test_worker_id(self):
def test_send_to_rank(self):
dst_rank = (self.rank + 1) % self.world_size
+ # Test dense tensor
for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]:
ret = self._run_func_in_mode(dst_rank, torch.add, exec_mode, args=(torch.ones(2, 2), 1))
self.assertEqual(ret, torch.ones(2, 2) + 1)
+ # Test sparse tensor
+ for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]:
+ x = build_sparse_tensor()
+ y = build_sparse_tensor()
+ expected_tensor = (x + y)
+ ret = self._run_func_in_mode(dst_rank, torch.add, exec_mode, args=(x, y))
+ self.assertEqual(expected_tensor, ret)
+
+ for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]:
+ x = build_sparse_tensor(coalesce=True)
+ y = build_sparse_tensor(coalesce=True)
+ expected_tensor = (x + y)
+ ret = self._run_func_in_mode(dst_rank, torch.add, exec_mode, args=(x, y))
+ self.assertEqual(expected_tensor, ret)
+
# Test invalid ranks
for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]:
with self.assertRaises(RuntimeError):
with self.assertRaises(ValueError):
self._run_func_in_mode(dst_rank - 0.5, torch.add, exec_mode, args=(torch.ones(2, 2), 1))
+ def _self_py_udf_remote(self, worker_info, x, y, z):
+     """Run ``my_function(x, y, z)`` on ``worker_info`` via ``rpc.remote`` and
+     assert the fetched value equals ``x + y + z``."""
+     fetched = rpc.remote(worker_info, my_function, args=(x, y, z)).to_here()
+     self.assertEqual(fetched, x + y + z)
+
@dist_init
def test_self_py_udf_remote(self):
- self_worker_info = rpc.get_worker_info()
- rref = rpc.remote(self_worker_info, my_function, args=(torch.ones(2, 2), 1, 3))
- self.assertEqual(rref.to_here(), torch.ones(2, 2) + 1 + 3)
+ self._self_py_udf_remote(
+ rpc.get_worker_info(),
+ torch.ones(2, 2),
+ 1,
+ 3
+ )
+
+ @dist_init
+ def test_self_py_udf_remote_sparse(self):
+     """Run ``_self_py_udf_remote`` against this worker with three sparse tensors."""
+     me = rpc.get_worker_info()
+     x = build_sparse_tensor()
+     y = build_sparse_tensor()
+     z = build_sparse_tensor()
+     self._self_py_udf_remote(me, x, y, z)
+
- def _test_self_remote_rref_as_rpc_arg(self, dst):
+ def _self_remote_rref_as_rpc_arg(self, dst, x, y, z):
self_worker_info = rpc.get_worker_info()
- rref = rpc.remote(self_worker_info, my_function, args=(torch.ones(2, 2), 1, 3))
- fut = rpc.rpc_async(dst, add_rref_to_value, args=(rref, torch.ones(2, 2)))
- ret = rpc.rpc_sync(dst, add_rref_to_value, args=(rref, torch.ones(2, 2) + 1))
- self.assertEqual(ret, torch.ones(2, 2) + 1 + 3 + torch.ones(2, 2) + 1)
- self.assertEqual(fut.wait(), torch.ones(2, 2) + 1 + 3 + torch.ones(2, 2))
+ rref = rpc.remote(self_worker_info, my_function, args=(x, y, z))
+ fut = rpc.rpc_async(dst, add_rref_to_value, args=(rref, x))
+ ret = rpc.rpc_sync(dst, add_rref_to_value, args=(rref, x + y))
+ self.assertEqual(ret, x + y + z + x + y)
+ self.assertEqual(fut.wait(), x + y + z + x)
@dist_init
def test_self_remote_rref_as_rpc_arg(self):
dst = worker_name((self.rank + 1) % self.world_size)
- self._test_self_remote_rref_as_rpc_arg(dst)
+ self._self_remote_rref_as_rpc_arg(
+ dst,
+ torch.ones(2, 2),
+ 1,
+ 3
+ )
+
+ @dist_init
+ def test_self_remote_rref_as_rpc_arg_sparse(self):
+     """Run ``_self_remote_rref_as_rpc_arg`` toward the next rank with sparse tensors."""
+     next_rank = (self.rank + 1) % self.world_size
+     self._self_remote_rref_as_rpc_arg(
+         worker_name(next_rank),
+         x=build_sparse_tensor(),
+         y=build_sparse_tensor(),
+         z=build_sparse_tensor(),
+     )
@dist_init
def test_self_remote_rref_as_self_rpc_arg(self):
- self._test_self_remote_rref_as_rpc_arg(rpc.get_worker_info())
+ self._self_remote_rref_as_rpc_arg(
+ rpc.get_worker_info(),
+ torch.ones(2, 2),
+ 1,
+ 3
+ )
- def _test_self_remote_rref_as_remote_arg(self, dst):
+ @dist_init
+ def test_self_remote_rref_as_self_rpc_arg_sparse(self):
+     """Run ``_self_remote_rref_as_rpc_arg`` against this worker with sparse tensors."""
+     me = rpc.get_worker_info()
+     self._self_remote_rref_as_rpc_arg(
+         me,
+         x=build_sparse_tensor(),
+         y=build_sparse_tensor(),
+         z=build_sparse_tensor(),
+     )
+
+ def _self_remote_rref_as_remote_arg(self, dst, x, y, z):
self_worker_info = rpc.get_worker_info()
- rref = rpc.remote(self_worker_info, my_function, args=(torch.ones(2, 2), 1, 3))
- ret_rref = rpc.remote(dst, add_rref_to_value, args=(rref, torch.ones(2, 2)))
+ rref = rpc.remote(self_worker_info, my_function, args=(x, y, z))
+ ret_rref = rpc.remote(dst, add_rref_to_value, args=(rref, x))
self.assertEqual(
- ret_rref.to_here(), torch.ones(2, 2) + 1 + 3 + torch.ones(2, 2)
+ ret_rref.to_here(), x + y + z + x
)
@dist_init
def test_self_remote_rref_as_remote_arg(self):
dst = worker_name((self.rank + 1) % self.world_size)
- self._test_self_remote_rref_as_remote_arg(dst)
+ self._self_remote_rref_as_remote_arg(
+ dst,
+ torch.ones(2, 2),
+ 1,
+ 3
+ )
+
+ @dist_init
+ def test_self_remote_rref_as_remote_arg_sparse(self):
+     """Run ``_self_remote_rref_as_remote_arg`` toward the next rank with sparse tensors."""
+     next_rank = (self.rank + 1) % self.world_size
+     self._self_remote_rref_as_remote_arg(
+         worker_name(next_rank),
+         x=build_sparse_tensor(),
+         y=build_sparse_tensor(),
+         z=build_sparse_tensor(),
+     )
+
+ @dist_init
+ def test_self_remote_rref_as_self_remote_arg(self):
+     """Run ``_self_remote_rref_as_remote_arg`` against this worker with a dense tensor and two ints."""
+     me = rpc.get_worker_info()
+     self._self_remote_rref_as_remote_arg(me, x=torch.ones(2, 2), y=1, z=3)
+
+ @dist_init
+ def test_self_remote_rref_as_self_remote_arg_sparse(self):
+     """Run ``_self_remote_rref_as_remote_arg`` against this worker with sparse tensors."""
+     me = rpc.get_worker_info()
+     self._self_remote_rref_as_remote_arg(
+         me,
+         x=build_sparse_tensor(),
+         y=build_sparse_tensor(),
+         z=build_sparse_tensor(),
+     )
@dist_init
def test_rref_proxy_non_exist(self):
def test_rref_proxy_class_self(self):
self._test_rref_proxy_class(rpc.get_worker_info())
- @dist_init
- def test_self_remote_rref_as_self_remote_arg(self):
- self._test_self_remote_rref_as_remote_arg(rpc.get_worker_info())
-
@mock.patch.object(torch.distributed.autograd, "_init")
@mock.patch.object(torch.distributed.rpc.api, "_set_and_start_rpc_agent")
@dist_init(setup_rpc=False)
)
rpc.shutdown()
- def test_world_size_one(self):
+ def _world_size_one(self, a, b):
if self.rank == 0:
rpc.init_rpc(
name="me",
rpc_backend_options=self.rpc_backend_options,
)
- expect = torch.ones(2, 2) * 2
- result = rpc.rpc_sync(
- "me",
- my_tensor_function,
- args=(torch.ones(2, 2), torch.ones(2, 2))
- )
- self.assertEqual(expect, result)
-
- expect = torch.ones(3, 3) * 2
- result = rpc.rpc_async(
- "me",
- my_tensor_function,
- args=(torch.ones(3, 3), torch.ones(3, 3))
- ).wait()
- self.assertEqual(expect, result)
+ def _rpc_sync(x, y):
+ expect = x * 2
+ result = rpc.rpc_sync(
+ "me",
+ my_tensor_function,
+ args=(x, y)
+ )
+ self.assertEqual(expect, result)
+
+ def _rpc_async(x, y):
+ expect = x * 2
+ result = rpc.rpc_async(
+ "me",
+ my_tensor_function,
+ args=(x, y)
+ ).wait()
+ self.assertEqual(expect, result)
+
+ def _remote(x, y):
+ expect = x * 2
+ result = rpc.remote(
+ "me",
+ my_tensor_function,
+ args=(x, y)
+ ).to_here()
+ self.assertEqual(expect, result)
- expect = torch.ones(4, 4) * 2
- result = rpc.remote(
- "me",
- my_tensor_function,
- args=(torch.ones(4, 4), torch.ones(4, 4))
- ).to_here()
- self.assertEqual(expect, result)
+ _rpc_sync(a, b)
+ _rpc_async(a, b)
+ _remote(a, b)
rpc.shutdown()
+ def test_world_size_one(self):
+     """Run ``_world_size_one`` with two 2x2 tensors of ones."""
+     a = torch.ones(2, 2)
+     b = torch.ones(2, 2)
+     self._world_size_one(a, b)
+
+ def test_world_size_one_sparse(self):
+     """Run ``_world_size_one`` with two freshly built sparse tensors."""
+     a = build_sparse_tensor()
+     b = build_sparse_tensor()
+     self._world_size_one(a, b)
+
@dist_init(setup_rpc=False)
def test_invalid_names(self):
from torch.distributed.rpc import WorkerInfo
ret = rpc.rpc_sync(worker_name(dst_rank), torch.nonzero, args=(x,))
self.assertEqual(ret, x.nonzero())
- @dist_init
- def test_multi_rpc(self):
+ def _multi_rpc(self, sparse):
dst_rank = (self.rank + 1) % self.world_size
for i in range(20):
n = i + self.rank + 1
+ if sparse:
+ x = build_sparse_tensor() * n
+ y = build_sparse_tensor() * n
+ else:
+ x = torch.ones(2, 2)
+ y = torch.ones(2, 2)
ret = rpc.rpc_sync(
worker_name(dst_rank),
torch.add,
- args=(torch.ones(n, n), torch.ones(n, n)),
+ args=(x, y),
)
- self.assertEqual(ret, torch.ones(n, n) * 2)
+ self.assertEqual(ret, x * 2)
+
+ @dist_init
+ def test_multi_rpc(self):
+     """Run ``_multi_rpc`` with ``sparse`` disabled."""
+     self._multi_rpc(sparse=False)
+
+ @dist_init
+ def test_multi_rpc_sparse(self):
+     """Run ``_multi_rpc`` with ``sparse`` enabled."""
+     self._multi_rpc(sparse=True)
@dist_init
def test_future_wait_twice(self):
with self.assertRaisesRegex(ValueError, "Expected error"):
fut.wait()
- def _run_uneven_workload(self, num_repeat=30):
+ def _run_uneven_workload(self, f, x, num_repeat=30):
# worker0 drives and waits for worker1 and worker2
# throughout the test.
if self.rank == 0:
dst = "worker1"
futs = []
for _ in range(num_repeat):
- fut = rpc.rpc_async(dst, heavy_rpc, args=(torch.ones(100, 100),))
+ fut = rpc.rpc_async(dst, f, args=(x,))
futs.append(fut)
for fut in torch.futures.collect_all(futs).wait():
dst = "worker2"
futs = []
for _ in range(num_repeat):
- fut = rpc.rpc_async(dst, heavy_rpc, args=(torch.ones(100, 100),))
+ fut = rpc.rpc_async(dst, f, args=(x,))
futs.append(fut)
for val in torch.futures.wait_all(futs):
self.assertEqual(val, 0)
- def test_wait_all_workers(self):
+ def _wait_all_workers(self, f, x):
initialize_pg(self.file_init_method, self.rank, self.world_size)
rpc.init_rpc(
name="worker%d" % self.rank,
rpc_backend_options=self.rpc_backend_options,
)
- self._run_uneven_workload()
+ self._run_uneven_workload(f, x)
# worker0 calls this at the end after waiting for RPC responses.
# worker1/2 calls this immediately and has some works after it.
dist.barrier()
rpc.shutdown(graceful=False)
- def test_wait_all_workers_twice(self):
+ def test_wait_all_workers_dense(self):
+     """Run ``_wait_all_workers`` with ``heavy_rpc`` on a 100x100 tensor of ones."""
+     payload = torch.ones(100, 100)
+     self._wait_all_workers(f=heavy_rpc, x=payload)
+
+ def test_wait_all_workers_sparse(self):
+     """Run ``_wait_all_workers`` with ``heavy_rpc_sparse`` on a sparse tensor."""
+     payload = build_sparse_tensor()
+     self._wait_all_workers(f=heavy_rpc_sparse, x=payload)
+
+ def _wait_all_workers_twice(self, f, x):
initialize_pg(self.file_init_method, self.rank, self.world_size)
rpc.init_rpc(
name="worker%d" % self.rank,
rpc_backend_options=self.rpc_backend_options,
)
- self._run_uneven_workload()
+ self._run_uneven_workload(f, x)
# worker0 calls this at the end after waiting for RPC responses.
# worker1/2 calls this immediately and has some works after it.
dist.barrier()
rpc.shutdown(graceful=False)
+ def test_wait_all_workers_twice_dense(self):
+     """Run ``_wait_all_workers_twice`` with ``heavy_rpc`` on a 100x100 tensor of ones."""
+     payload = torch.ones(100, 100)
+     self._wait_all_workers_twice(f=heavy_rpc, x=payload)
+
+ def test_wait_all_workers_twice_sparse(self):
+     """Run ``_wait_all_workers_twice`` with ``heavy_rpc_sparse`` on a sparse tensor."""
+     payload = build_sparse_tensor()
+     self._wait_all_workers_twice(f=heavy_rpc_sparse, x=payload)
+
@dist_init
def test_all_gather(self):
info = rpc.get_worker_info()
@dist_init
def test_graceful_shutdown_with_uneven_workload(self):
"""Test graceful termination."""
- self._run_uneven_workload()
+ self._run_uneven_workload(heavy_rpc, torch.ones(100, 100))
@dist_init(setup_rpc=False)
def test_shutdown_followed_by_rpc(self):
self.assertEqual(ret, my_complex_tensor_function(a, b, c))
@dist_init
+ def test_py_sparse_tensors_in_container(self):
+     """Sum a list of two sparse tensors on the next worker via ``rpc_sync``
+     and compare with the sum computed locally."""
+     dst = worker_name((self.rank + 1) % self.world_size)
+     tensors = [build_sparse_tensor(), build_sparse_tensor()]
+     remote_sum = rpc.rpc_sync(dst, my_container_sum, args=(tensors,))
+     self.assertEqual(remote_sum, my_container_sum(tensors))
+
+ @dist_init
def test_py_nested_pickle(self):
n = self.rank + 1
dst_rank = n % self.world_size
else:
self.assertTrue(False, "expected raise_func_escape to raise ValueError.")
- @dist_init
- def test_nested_rpc(self):
+ def _nested_rpc(self, f, expected):
n = self.rank + 1
dst_rank = n % self.world_size
ret = rpc.rpc_sync(
worker_name(dst_rank),
- nested_rpc,
+ f,
args=(worker_name(self.rank),),
)
- self.assertEqual(ret, torch.ones(2, 2) + 1)
+ self.assertEqual(ret, expected)
+
+ @dist_init
+ def test_nested_rpc(self):
+     """Run ``_nested_rpc`` with ``nested_rpc``; the expected value is ones(2, 2) + 1."""
+     want = torch.ones(2, 2) + 1
+     self._nested_rpc(f=nested_rpc, expected=want)
+
+ @dist_init
+ def test_nested_rpc_sparse(self):
+     """Run ``_nested_rpc`` with ``nested_rpc_sparse``; the expected value is a sparse tensor times 2."""
+     want = build_sparse_tensor() * 2
+     self._nested_rpc(f=nested_rpc_sparse, expected=want)
def _stress_test_rpc(self, f, repeat=1000, args=()):
n = self.rank + 1
self._stress_test_rpc(heavy_rpc, repeat=20, args=(torch.ones(100, 100),))
@dist_init
+ def test_stress_heavy_rpc_sparse(self):
+     """Run ``_stress_test_rpc`` with ``heavy_rpc_sparse``, 20 repeats, one sparse tensor argument."""
+     payload = (build_sparse_tensor(),)
+     self._stress_test_rpc(heavy_rpc_sparse, repeat=20, args=payload)
+
+ @dist_init
def test_stress_heavy_rpc_torchscript(self):
self._stress_test_rpc(heavy_rpc_torchscript, repeat=20, args=(torch.ones(100, 100),))
- @dist_init
- def test_builtin_remote_ret(self):
+ def _builtin_remote_ret(self, x, y, expected):
n = self.rank + 1
dst_rank = n % self.world_size
rref = rpc.remote(
worker_name(dst_rank),
torch.add,
- args=(torch.ones(n, n), torch.ones(n, n)),
+ args=(x, y),
)
- self.assertEqual(rref.to_here(), torch.ones(n, n) * 2)
+ self.assertEqual(rref.to_here(), expected)
@dist_init
- def test_builtin_remote_self(self):
+ def test_builtin_remote_ret(self):
+ self._builtin_remote_ret(
+ torch.ones(2, 2),
+ torch.ones(2, 2),
+ torch.ones(2, 2) * 2
+ )
+
+ @dist_init
+ def test_builtin_remote_ret_sparse(self):
+     """Run ``_builtin_remote_ret`` with two sparse operands; expected is a sparse tensor times 2."""
+     x = build_sparse_tensor()
+     y = build_sparse_tensor()
+     want = build_sparse_tensor() * 2
+     self._builtin_remote_ret(x, y, want)
+
+ def _builtin_remote_self(self, x, y, expected):
rref = rpc.remote(
worker_name(self.rank),
torch.add,
- args=(torch.ones(2, 2), torch.ones(2, 2)),
+ args=(x, y),
+ )
+ self.assertEqual(rref.local_value(), expected)
+
+ @dist_init
+ def test_builtin_remote_self(self):
+     """Run ``_builtin_remote_self`` with two ones(2, 2) operands; expected is ones(2, 2) * 2."""
+     x = torch.ones(2, 2)
+     y = torch.ones(2, 2)
+     want = torch.ones(2, 2) * 2
+     self._builtin_remote_self(x, y, want)
+
+ @dist_init
+ def test_builtin_remote_self_sparse(self):
+ self._builtin_remote_self(
+ build_sparse_tensor(),
+ build_sparse_tensor(),
+ build_sparse_tensor() * 2
)
- self.assertEqual(rref.local_value(), torch.ones(2, 2) * 2)
- def _test_multi_remote_call(self, fn, args_fn=lambda x: (), kwargs_fn=lambda x: {}):
+ def _test_multi_remote_call(self, fn, sparse, args_fn=lambda x, y: (), kwargs_fn=lambda x, y: {}):
m = 10
n = self.rank + 1
dst_rank = n % self.world_size
rpc.remote(
worker_name(dst_rank),
fn,
- args=args_fn(n),
- kwargs=kwargs_fn(n),
+ args=args_fn(n, sparse),
+ kwargs=kwargs_fn(n, sparse),
)
)
- expected.append(fn(*args_fn(n), **kwargs_fn(n)))
+ expected.append(fn(*args_fn(n, sparse), **kwargs_fn(n, sparse)))
for i in range(m):
self.assertEqual(rrefs[i].to_here(), expected[i])
+ @staticmethod
+ def _multi_args_fn(n, sparse=False):
+     """Return a pair of operands: two sparse tensors when ``sparse`` is
+     true, otherwise two ``n`` x ``n`` tensors of ones."""
+     if not sparse:
+         return (torch.ones(n, n), torch.ones(n, n))
+     return (build_sparse_tensor(), build_sparse_tensor())
+
@dist_init
def test_multi_builtin_remote_ret(self):
- def args_fn(n):
- return (torch.ones(n, n), torch.ones(n, n))
+ self._test_multi_remote_call(
+ torch.add, False,
+ args_fn=RpcTest._multi_args_fn
+ )
- self._test_multi_remote_call(torch.add, args_fn=args_fn)
+ @dist_init
+ def test_multi_builtin_remote_ret_sparse(self):
+     """Run ``_test_multi_remote_call`` for ``torch.add`` with ``sparse`` enabled."""
+     make_args = RpcTest._multi_args_fn
+     self._test_multi_remote_call(torch.add, True, args_fn=make_args)
@dist_init
def test_py_udf_remote(self):
)
self.assertEqual(rref.to_here(), my_function(n, n + 1, n + 2))
- @dist_init
- def test_multi_py_udf_remote(self):
- def kwargs_fn(n):
+ @staticmethod
+ def _multi_kwargs_fn(n, sparse=False):
+ if sparse:
+ return {
+ "a": build_sparse_tensor(),
+ "b": build_sparse_tensor(),
+ "c": build_sparse_tensor()
+ }
+ else:
return {"a": torch.ones(n, n), "b": torch.ones(n, n), "c": torch.ones(n, n)}
- self._test_multi_remote_call(my_function, kwargs_fn=kwargs_fn)
+ @dist_init
+ def test_multi_py_udf_remote(self):
+     """Run ``_test_multi_remote_call`` for ``my_function`` with ``sparse`` disabled."""
+     make_kwargs = RpcTest._multi_kwargs_fn
+     self._test_multi_remote_call(my_function, False, kwargs_fn=make_kwargs)
@dist_init
- def test_py_rref_args(self):
+ def test_multi_py_udf_remote_sparse(self):
+ self._test_multi_remote_call(
+ my_function,
+ True,
+ kwargs_fn=RpcTest._multi_kwargs_fn
+ )
+
+ def _py_rref_args(self, a, b, x, y, expected):
n = self.rank + 1
dst_rank = n % self.world_size
rref_a = rpc.remote(
- worker_name(dst_rank), torch.add, args=(torch.ones(n, n), 2)
+ worker_name(dst_rank), torch.add, args=(a, b)
)
rref_b = rpc.remote(
- worker_name(dst_rank), torch.add, args=(torch.ones(n, n), 1)
+ worker_name(dst_rank), torch.add, args=(x, y)
)
rref_c = rpc.remote(
worker_name(dst_rank), my_rref_function, args=(rref_a, rref_b)
)
- self.assertEqual(rref_c.to_here(), torch.ones(n, n) + 4)
+ self.assertEqual(rref_c.to_here(), expected)
@dist_init
- def test_py_rref_args_user_share(self):
+ def test_py_rref_args(self):
+ self._py_rref_args(
+ torch.ones(2, 2),
+ 1,
+ torch.ones(2, 2),
+ 2,
+ torch.ones(2, 2) * 2 + 3)
+
+ @dist_init
+ def test_py_rref_args_sparse(self):
+     """Run ``_py_rref_args`` with four sparse operands; expected is a sparse tensor times 4."""
+     a = build_sparse_tensor()
+     b = build_sparse_tensor()
+     x = build_sparse_tensor()
+     y = build_sparse_tensor()
+     want = build_sparse_tensor() * 4
+     self._py_rref_args(a, b, x, y, want)
+
+ def _py_rref_args_user_share(self, a, b, c, x, y, z, expected):
n = self.rank + 1
owner_rank = n % self.world_size
user_rank = (n + 1) % self.world_size
rref_a = rpc.remote(
- worker_name(owner_rank), my_function, args=(torch.ones(n, n), 2, 0)
+ worker_name(owner_rank), my_function, args=(a, b, c)
)
rref_b = rpc.remote(
- worker_name(owner_rank), my_function, args=(torch.ones(n, n), 1, 0)
+ worker_name(owner_rank), my_function, args=(x, y, z)
)
rref_c = rpc.remote(
worker_name(user_rank), my_rref_function, args=(rref_a, rref_b)
)
- self.assertEqual(rref_c.to_here(), torch.ones(n, n) + 4)
+ self.assertEqual(rref_c.to_here(), expected)
@dist_init
- def test_py_rpc_rref_args(self):
+ def test_py_rref_args_user_share(self):
+ self._py_rref_args_user_share(
+ torch.ones(2, 2),
+ 1,
+ 2,
+ torch.ones(2, 2),
+ 3,
+ 4,
+ torch.ones(2, 2) * 2 + 10
+ )
+
+ @dist_init
+ def test_py_rref_args_user_share_sparse(self):
+     """Run ``_py_rref_args_user_share`` with six sparse operands; expected is a sparse tensor times 6."""
+     a = build_sparse_tensor()
+     b = build_sparse_tensor()
+     c = build_sparse_tensor()
+     x = build_sparse_tensor()
+     y = build_sparse_tensor()
+     z = build_sparse_tensor()
+     want = build_sparse_tensor() * 6
+     self._py_rref_args_user_share(a, b, c, x, y, z, want)
+
+ def _py_rpc_rref_args(self, a, b, c, x, y, z, expected):
n = self.rank + 1
dst_rank = n % self.world_size
rref_a = rpc.remote(
- worker_name(dst_rank), my_function, args=(torch.ones(n, n), 2, 0)
+ worker_name(dst_rank), my_function, args=(a, b, c)
)
rref_b = rpc.remote(
- worker_name(dst_rank), my_function, args=(torch.ones(n, n), 1, 0)
+ worker_name(dst_rank), my_function, args=(x, y, z)
)
c = rpc.rpc_sync(
worker_name(dst_rank), my_rref_function, args=(rref_a, rref_b)
)
+ self.assertEqual(c, expected)
- self.assertEqual(c, torch.ones(n, n) + 4)
+ @dist_init
+ def test_py_rpc_rref_args(self):
+     """Run ``_py_rpc_rref_args`` with dense operands (ones + 1 + 2 and
+     ones + 3 + 4); expected is ones(2, 2) * 2 + 10."""
+     first = torch.ones(2, 2)
+     second = torch.ones(2, 2)
+     want = torch.ones(2, 2) * 2 + 10
+     self._py_rpc_rref_args(first, 1, 2, second, 3, 4, want)
@dist_init
- def test_nested_remote(self):
+ def test_py_rpc_rref_args_sparse(self):
+ self._py_rpc_rref_args(
+ build_sparse_tensor(),
+ build_sparse_tensor(),
+ build_sparse_tensor(),
+ build_sparse_tensor(),
+ build_sparse_tensor(),
+ build_sparse_tensor(),
+ build_sparse_tensor() * 6
+ )
+
+ def _nested_remote(self, f, expected):
n = self.rank + 1
dst_rank1 = n % self.world_size
dst_rank2 = (n + 1) % self.world_size
rref = rpc.remote(
worker_name(dst_rank1),
- nested_remote,
+ f,
args=(worker_name(dst_rank2),),
)
- self.assertEqual(rref.to_here(), torch.ones(2, 2) + 3)
+ self.assertEqual(rref.to_here(), expected)
@dist_init
- def test_nested_rref(self):
+ def test_nested_remote(self):
+ self._nested_remote(
+ nested_remote,
+ torch.ones(2, 2) + 3
+ )
+
+ @dist_init
+ def test_nested_remote_sparse(self):
+     """Run ``_nested_remote`` with ``nested_remote_sparse``; expected is the sum of two sparse tensors."""
+     want = build_sparse_tensor() + build_sparse_tensor()
+     self._nested_remote(f=nested_remote_sparse, expected=want)
+
+ def _nested_rref(self, f, expected1, expected2):
n = self.rank + 1
dst_rank1 = n % self.world_size
dst_rank2 = (n + 1) % self.world_size
rref_of_rrefs = rpc.remote(
worker_name(dst_rank1),
- nested_rref,
+ f,
args=(worker_name(dst_rank2),),
)
rrefs = rref_of_rrefs.to_here()
self.assertEqual(len(rrefs), 2)
- self.assertEqual(rrefs[0].to_here(), torch.ones(2, 2) + 1)
- self.assertEqual(rrefs[1].to_here(), torch.ones(2, 2) + 2)
+ self.assertEqual(rrefs[0].to_here(), expected1)
+ self.assertEqual(rrefs[1].to_here(), expected2)
@dist_init
- def test_nested_rref_stress(self):
+ def test_nested_rref(self):
+ self._nested_rref(
+ nested_rref,
+ torch.ones(2, 2) + 1,
+ torch.ones(2, 2) + 2
+ )
+
+ @dist_init
+ def test_nested_rref_sparse(self):
+     """Run ``_nested_rref`` with ``nested_rref_sparse``; both expected values are a sparse tensor times 2."""
+     want_first = build_sparse_tensor() * 2
+     want_second = build_sparse_tensor() * 2
+     self._nested_rref(nested_rref_sparse, want_first, want_second)
+
+ def _nested_rref_stress(self, f, expected1, expected2):
n = self.rank + 1
dst_rank1 = n % self.world_size
dst_rank2 = (n + 1) % self.world_size
all_rrefs.append(
rpc.remote(
worker_name(dst_rank1),
- nested_rref,
+ f,
args=(worker_name(dst_rank2),),
)
)
rref_of_rrefs = all_rrefs[i]
rrefs = rref_of_rrefs.to_here()
self.assertEqual(len(rrefs), 2)
- self.assertEqual(rrefs[0].to_here(), torch.ones(2, 2) + 1)
- self.assertEqual(rrefs[1].to_here(), torch.ones(2, 2) + 2)
+ self.assertEqual(rrefs[0].to_here(), expected1)
+ self.assertEqual(rrefs[1].to_here(), expected2)
+
+ @dist_init
+ def test_nested_rref_stress(self):
+     """Run ``_nested_rref_stress`` with ``nested_rref``; expected values are ones(2, 2) + 1 and ones(2, 2) + 2."""
+     want_first = torch.ones(2, 2) + 1
+     want_second = torch.ones(2, 2) + 2
+     self._nested_rref_stress(nested_rref, want_first, want_second)
+
+ @dist_init
+ def test_nested_rref_stress_sparse(self):
+     """Run ``_nested_rref_stress`` with ``nested_rref_sparse``; both expected values are a sparse tensor times 2."""
+     want_first = build_sparse_tensor() * 2
+     want_second = build_sparse_tensor() * 2
+     self._nested_rref_stress(nested_rref_sparse, want_first, want_second)
@dist_init
def test_multi_layer_nested_async_rpc(self):
dist.barrier()
+ def _trainer_func(self, rref, sparse):
+     """Train a ``MyEmbeddingBagModel`` for 10 steps against the server in ``rref``.
+
+     Each step backpropagates an MSE loss, sends the gradient of the
+     model's first parameter to ``average`` on the server, waits for the
+     returned value, and asserts that it equals the server's ``gradient``
+     (both densified and cast to double when sparse).
+
+     Args:
+         rref: RRef whose ``average`` and ``get_gradient`` are invoked over RPC.
+         sparse: forwarded to ``MyEmbeddingBagModel`` as ``sparse``.
+     """
+     m = MyEmbeddingBagModel(sparse=sparse)
+     loss_fn = nn.MSELoss()
+     for i in range(10):
+         # torch.rand(...).long() truncates every sample to 0, which would
+         # only ever look up embedding row 0; draw indices over all 10 rows.
+         indices = torch.randint(0, 10, (10, 10))
+         outputs = m(indices)
+         loss_fn(outputs, torch.rand(10, 10)).backward()
+         gradient = next(m.parameters()).grad
+         fut = rref.rpc_async().average(rref, i, gradient)
+         gradient = fut.wait()
+         if gradient.is_sparse:
+             gradient = gradient.to_dense().double()
+         ps_gradient = rref.rpc_sync().get_gradient(rref)
+         if ps_gradient.is_sparse:
+             ps_gradient = ps_gradient.to_dense().double()
+         self.assertTrue(torch.equal(gradient, ps_gradient))
+
+ def _my_parameter_server(self, sparse):
+     """Create a ``MyParameterServer`` RRef and run ``_trainer_func`` on every other rank.
+
+     The server is built for ``world_size - 1`` trainers; one async RPC is
+     issued per remaining rank and all of them are awaited.
+     """
+     ps_rref = RRef(MyParameterServer(self.world_size - 1))
+     futures = [
+         rpc.rpc_async(
+             worker_name((self.rank + index) % self.world_size),
+             self._trainer_func,
+             args=(ps_rref, sparse),
+         )
+         for index in range(1, self.world_size)
+     ]
+     torch.futures.wait_all(futures)
+
+ @dist_init
+ def test_my_parameter_server(self):
+     """Run ``_my_parameter_server`` with ``sparse`` disabled."""
+     self._my_parameter_server(sparse=False)
+
+ @dist_init
+ def test_my_parameter_server_sparse(self):
+     """Run ``_my_parameter_server`` with ``sparse`` enabled."""
+     self._my_parameter_server(sparse=True)
+
class CudaRpcTest(RpcAgentTestFixture):