Tensor type checking and informative error messages for torch.distributed (#14204)
authorTeng Li <tengli@fb.com>
Tue, 20 Nov 2018 02:25:00 +0000 (18:25 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 20 Nov 2018 02:30:54 +0000 (18:30 -0800)
commitb807970aeaf3bd26068f347a01196b66c0a21bb5
treebe9b6e89d342be91a6d53342448ce7a56ddcf124
parent7d1db89ef9edd3ea844c08d682fd634e0c34d813
Tensor type checking and informative error messages for torch.distributed (#14204)

Summary:
This will address https://github.com/pytorch/pytorch/issues/13574

This error message should be more informative to the user for all the non-multiGPU ops, since we python binding to multi-gpu ops always.

test_distributed should cover all. Also tested both RunTime errors.

```
>>> a = torch.ByteTensor([])
>>> b = [a, a]
>>> dist.all_reduce(b)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/private/home/tengli/pytorch/torch/distributed/distributed_c10d.py", line 809, in all_reduce
    _check_single_tensor(tensor, "tensor")
  File "/private/home/tengli/pytorch/torch/distributed/distributed_c10d.py", line 207, in _check_single_tensor
    "to be a torch.Tensor type".format(param_name))
RuntimeError: Invalid function argument. Expecting parameter: tensor to be a torch.Tensor type

>>> b = ["b"]
>>> dist.all_gather(b, a)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/private/home/tengli/pytorch/torch/distributed/distributed_c10d.py", line 1006, in all_gather
    _check_tensor_list(tensor_list, "tensor_list")
  File "/private/home/tengli/pytorch/torch/distributed/distributed_c10d.py", line 225, in _check_tensor_list
    "to be a List[torch.Tensor] type".format(param_name))
RuntimeError: Invalid function argument. Expecting parameter: tensor_list to be a List[torch.Tensor] type
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14204

Differential Revision: D13131526

Pulled By: teng-li

fbshipit-source-id: bca3d881e41044a013a6b90fa187e722b9dd45f2
torch/distributed/distributed_c10d.py