Avoid unnecessary CPU-to-GPU copy of torch.load with CUDA (#17297)
authorLuca Wehrstedt <lcw@fb.com>
Thu, 21 Feb 2019 09:24:56 +0000 (01:24 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 21 Feb 2019 09:32:19 +0000 (01:32 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17297

When `torch.load` needs to load a tensor, no matter which device it will be end up being loaded on, it first creates a CPU storage for it of the necessary size. This storage is allocated but it's not "set" yet, hence no data is written to it: it exists in the kernel's memory map, but it's not resident and doesn't take up physical pages. Then, this storage is passed to the `map_location` function (if the parameter is a string, a device or a map, PyTorch builds that function automatically). The default map for CUDA consists effectively in `lambda storage, _: storage.cuda()` (I omitted the code needed to pick the correct device). This creates a GPU storage and copies over the data of the CPU storage. *This step is unnecessary as we're copying uninitialized memory*. (Surprisingly enough, though, it appears the kernel is smart enough that reading from the unpaged CPU memory doesn't cause it to become paged.) Once `map_location` returns a storage residing on the correct target device, `torch.load` resumes reading the file and copying the tensor's content over into the storage. This will overwrite the content that had previously been written to it, which confirms that the above copy was pointless.

A way to avoid this useless copy is to just create and return a new empty storage on the target GPU, instead of "transforming" the original one.

This does indeed increase the performance:
```
In [5]: torch.save(torch.rand(100, 100, 100), "/tmp/tensor")

In [6]: %timeit torch.load("/tmp/tensor", map_location="cuda")
1.55 ms ± 111 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

In [7]: %timeit torch.load("/tmp/tensor", map_location=lambda storage, _: torch.cuda.FloatStorage(storage.size()))
1.03 ms ± 44 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
```

Credit for this diff is shared with adamlerer and fmassa.

Differential Revision: D14147673

fbshipit-source-id: a58d4bc0d894ca03a008499334fc2cdd4cc91e9f

torch/serialization.py

index ef7d415..ee92d8d 100644 (file)
@@ -93,7 +93,12 @@ def validate_cuda_device(location):
 def _cuda_deserialize(obj, location):
     if location.startswith('cuda'):
         device = validate_cuda_device(location)
-        return obj.cuda(device)
+        if getattr(obj, "_torch_load_uninitialized", False):
+            storage_type = getattr(torch.cuda, type(obj).__name__)
+            with torch.cuda.device(device):
+                return storage_type(obj.size())
+        else:
+            return obj.cuda(device)
 
 
 register_package(10, _cpu_tag, _cpu_deserialize)
@@ -525,8 +530,9 @@ def _load(f, map_location, pickle_module, **pickle_load_args):
             data_type, root_key, location, size, view_metadata = data
             location = maybe_decode_ascii(location)
             if root_key not in deserialized_objects:
-                deserialized_objects[root_key] = restore_location(
-                    data_type(size), location)
+                obj = data_type(size)
+                obj._torch_load_uninitialized = True
+                deserialized_objects[root_key] = restore_location(obj, location)
             storage = deserialized_objects[root_key]
             if view_metadata is not None:
                 view_key, offset, view_size = view_metadata