Avoid unnecessary CPU-to-GPU copy of torch.load with CUDA (#17297)
authorLuca Wehrstedt <lcw@fb.com>
Thu, 21 Feb 2019 09:24:56 +0000 (01:24 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 21 Feb 2019 09:32:19 +0000 (01:32 -0800)
commit29f4f8f048b03808bd1c7b6917b1825997b4abc3
tree00d01fea6c7486319c982e9c8a439910b1a0ce89
parent2c302b6ea68eb9024f5f1b9cdc41e3bf39556fa2
Avoid unnecessary CPU-to-GPU copy of torch.load with CUDA (#17297)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17297

When `torch.load` needs to load a tensor, no matter which device it will be end up being loaded on, it first creates a CPU storage for it of the necessary size. This storage is allocated but it's not "set" yet, hence no data is written to it: it exists in the kernel's memory map, but it's not resident and doesn't take up physical pages. Then, this storage is passed to the `map_location` function (if the parameter is a string, a device or a map, PyTorch builds that function automatically). The default map for CUDA consists effectively in `lambda storage, _: storage.cuda()` (I omitted the code needed to pick the correct device). This creates a GPU storage and copies over the data of the CPU storage. *This step is unnecessary as we're copying uninitialized memory*. (Surprisingly enough, though, it appears the kernel is smart enough that reading from the unpaged CPU memory doesn't cause it to become paged.) Once `map_location` returns a storage residing on the correct target device, `torch.load` resumes reading the file and copying the tensor's content over into the storage. This will overwrite the content that had previously been written to it, which confirms that the above copy was pointless.

A way to avoid this useless copy is to just create and return a new empty storage on the target GPU, instead of "transforming" the original one.

This does indeed increase the performance:
```
In [5]: torch.save(torch.rand(100, 100, 100), "/tmp/tensor")

In [6]: %timeit torch.load("/tmp/tensor", map_location="cuda")
1.55 ms ± 111 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

In [7]: %timeit torch.load("/tmp/tensor", map_location=lambda storage, _: torch.cuda.FloatStorage(storage.size()))
1.03 ms ± 44 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
```

Credit for this diff is shared with adamlerer and fmassa.

Differential Revision: D14147673

fbshipit-source-id: a58d4bc0d894ca03a008499334fc2cdd4cc91e9f
torch/serialization.py