Fix half tensor printing plus speedup large tensor printing (#14418)
authorFrancisco Massa <fvsmassa@gmail.com>
Wed, 28 Nov 2018 14:11:08 +0000 (06:11 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 28 Nov 2018 14:13:06 +0000 (06:13 -0800)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/14344 and https://github.com/pytorch/pytorch/issues/6863

The slowdown was due to the fact that we were only summarizing the tensor (for computing the number of digits to print) if its first dimension was larger than the threshold. It now goes over all the dimensions.

Some quick runtime analysis:

Before this PR:
```python
In [1]: import torch; a = torch.rand(1, 1700, 34, 50)

In [2]: %timeit str(a)
13.6 s ± 84.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
```

After this PR

```python
In [1]: import torch; a = torch.rand(1, 1700, 34, 50)

In [2]: %timeit str(a)
2.08 ms ± 395 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

In [3]: b = a.cuda()

In [4]: %timeit str(b)
8.39 ms ± 45.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14418

Reviewed By: weiyangfb

Differential Revision: D13226950

Pulled By: soumith

fbshipit-source-id: 19eb4b855db4c8f891d0925a9c56ae8a2824bb23

test/test_torch.py
torch/_tensor_str.py

index a7e6fb3..69bafcb 100644 (file)
@@ -8335,6 +8335,10 @@ class _TestTorchMixin(object):
             obj = t(100, 100).fill_(1)
             obj.__repr__()
             str(obj)
+        # test half tensor
+        obj = torch.rand(100, 100, device='cpu').half()
+        obj.__repr__()
+        str(obj)
         for t in torch._storage_classes:
             if t.is_cuda and not torch.cuda.is_available():
                 continue
@@ -8389,6 +8393,13 @@ tensor([ 0.0000e+00, 9.8813e-324, 9.8813e-323, 1.0000e+307, 1.0000e+308,
         self.assertEqual(x.__repr__(), str(x))
         self.assertExpectedInline(str(x), '''tensor([0., 0., 0.,  ..., 0., 0., 0.])''')
 
+        # test internal summary function
+        x = torch.rand(1, 20, 5, 30)
+        summary = torch._tensor_str.get_summarized_data(x)
+        self.assertEqual(summary.shape, (1, 6, 5, 6))
+        first_and_last = [0, 1, 2, -3, -2, -1]
+        self.assertEqual(summary, x[:, first_and_last][..., first_and_last])
+
         # test device
         if torch.cuda.is_available():
             x = torch.tensor([123], device='cuda:0')
index e0c6612..a00c32e 100644 (file)
@@ -190,6 +190,8 @@ def _tensor_str(self, indent):
         return '[]'
 
     summarize = self.numel() > PRINT_OPTS.threshold
+    if self.dtype is torch.float16:
+        self = self.float()
     formatter = _Formatter(get_summarized_data(self) if summarize else self)
     return _tensor_str_with_formatter(self, indent, formatter, summarize)
 
@@ -220,12 +222,12 @@ def get_summarized_data(self):
         else:
             return self
     if self.size(0) > 2 * PRINT_OPTS.edgeitems:
-        start = [get_summarized_data(self[i]).reshape(-1) for i in range(0, PRINT_OPTS.edgeitems)]
-        end = ([get_summarized_data(self[i]).reshape(-1)
+        start = [self[i] for i in range(0, PRINT_OPTS.edgeitems)]
+        end = ([self[i]
                for i in range(len(self) - PRINT_OPTS.edgeitems, len(self))])
-        return torch.cat((start + end))
+        return torch.stack([get_summarized_data(x) for x in (start + end)])
     else:
-        return self
+        return torch.stack([get_summarized_data(x) for x in self])
 
 
 def _str(self):