Fix bug in ShardedTensorMetadata serde. (#63902)
author Pritam Damania <pritam.damania@fb.com>
Wed, 1 Sep 2021 03:19:55 +0000 (20:19 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 1 Sep 2021 03:31:14 +0000 (20:31 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63902

The 'memory_format' field was not being serialized correctly: the encoding
logic used the same value for two different memory formats (both
torch.channels_last and torch.preserve_format were encoded as 1), so
torch.preserve_format was silently deserialized as torch.channels_last. The
fix replaces the magic integers with an explicit MEM_FORMAT_ENCODING enum.
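
A minimal sketch of the failure mode, using hypothetical encode/decode
helpers that mirror the old __getstate__/__setstate__ logic:

    import torch

    # Mirrors the buggy encoding side: torch.channels_last and
    # torch.preserve_format both mapped to 1.
    def _encode_memory_format_buggy(memory_format):
        if memory_format == torch.contiguous_format:
            return 0
        elif memory_format == torch.channels_last:
            return 1
        elif memory_format == torch.preserve_format:
            return 1  # bug: collides with torch.channels_last
        raise RuntimeError(f'Invalid torch.memory_format: {memory_format}')

    # Mirrors the decoding side, which expects preserve_format == 2.
    def _decode_memory_format(encoding):
        if encoding == 0:
            return torch.contiguous_format
        elif encoding == 1:
            return torch.channels_last
        elif encoding == 2:
            return torch.preserve_format
        raise RuntimeError(f'Invalid torch.memory_format encoding: {encoding}')

    # preserve_format silently round-trips as channels_last.
    assert _decode_memory_format(
        _encode_memory_format_buggy(torch.preserve_format)) == torch.channels_last

Using an enum instead of raw integers makes such collisions harder to
reintroduce, since each memory format gets a distinct named member.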
ghstack-source-id: 137142406

Test Plan: waitforbuildbot

Reviewed By: bowangbj

Differential Revision: D30527324

fbshipit-source-id: f0f223e2d660ef6e4abae9649d9992acc36e1278

test/distributed/_sharded_tensor/test_sharded_tensor.py
torch/distributed/_sharded_tensor/api.py

diff --git a/test/distributed/_sharded_tensor/test_sharded_tensor.py b/test/distributed/_sharded_tensor/test_sharded_tensor.py
index 718b594..77e35b7 100644
@@ -1,6 +1,8 @@
 from functools import wraps
 import math
 import io
+import itertools
+import pickle
 import sys
 import torch
 import torch.distributed as dist
@@ -123,6 +125,54 @@ def with_comms(func):
         self.destroy_comms()
     return wrapper
 
+class TestShardedTensorMetadata(TestCase):
+    def test_serialize_and_deserialize(self):
+        shard_metadatas = [
+            ShardMetadata(
+                shard_offsets=[0, 0],
+                shard_lengths=[5, 5],
+                placement="rank:0/cuda:0",
+            ),
+            ShardMetadata(
+                shard_offsets=[0, 5],
+                shard_lengths=[5, 5],
+                placement="rank:1/cuda:1",
+            ),
+            ShardMetadata(
+                shard_offsets=[5, 0],
+                shard_lengths=[5, 5],
+                placement="rank:2/cuda:2",
+            ),
+            ShardMetadata(
+                shard_offsets=[5, 5],
+                shard_lengths=[5, 5],
+                placement="rank:3/cuda:3",
+            )
+        ]
+
+        dtypes = [
+            torch.float, torch.double, torch.cfloat, torch.cdouble, torch.half,
+            torch.bfloat16, torch.uint8, torch.int8, torch.short, torch.int,
+            torch.long, torch.bool]
+
+        layouts = [torch.strided, torch.sparse_coo]
+        requires_grads = [True, False]
+        memory_formats = [torch.contiguous_format, torch.channels_last, torch.preserve_format]
+        pin_memories = [True, False]
+
+        for tensor_properties_input in itertools.product(dtypes, layouts, requires_grads, memory_formats, pin_memories):
+            dtype, layout, requires_grad, memory_format, pin_memory = tensor_properties_input
+
+            expected_st_metadata = _sharded_tensor.ShardedTensorMetadata(
+                shard_metadatas,
+                (10, 10),
+                _sharded_tensor.TensorProperties(dtype, layout, requires_grad, memory_format, pin_memory)
+            )
+
+            pickled_obj = pickle.dumps(expected_st_metadata)
+            st_metadata = pickle.loads(pickled_obj)
+            self.assertEqual(expected_st_metadata, st_metadata)
+
 class TestCreateTensorFromParams(TestCase):
     @sandcastle_skip_if(torch.cuda.device_count() < 1, 'CUDA GPU is needed')
     def test_empty(self):
diff --git a/torch/distributed/_sharded_tensor/api.py b/torch/distributed/_sharded_tensor/api.py
index 3b7476d..d6b7a54 100644
@@ -70,6 +70,13 @@ class TensorProperties(object):
     memory_format: torch.memory_format = field(default=torch.contiguous_format)
     pin_memory: bool = False
 
+
+class MEM_FORMAT_ENCODING(Enum):
+    TORCH_CONTIGUOUS_FORMAT = 0
+    TORCH_CHANNELS_LAST = 1
+    TORCH_PRESERVE_FORMAT = 2
+
+
 @dataclass
 class ShardedTensorMetadata(object):
     """
@@ -93,11 +100,11 @@ class ShardedTensorMetadata(object):
         # Since torch.memory_format cannot be pickled!
         memory_format = self.tensor_properties.memory_format
         if memory_format == torch.contiguous_format:
-            mem_format_encoding = 0
+            mem_format_encoding = MEM_FORMAT_ENCODING.TORCH_CONTIGUOUS_FORMAT
         elif memory_format == torch.channels_last:
-            mem_format_encoding = 1
+            mem_format_encoding = MEM_FORMAT_ENCODING.TORCH_CHANNELS_LAST
         elif memory_format == torch.preserve_format:
-            mem_format_encoding = 1
+            mem_format_encoding = MEM_FORMAT_ENCODING.TORCH_PRESERVE_FORMAT
         else:
             raise RuntimeError(f'Invalid torch.memory_format: {memory_format}')
 
@@ -118,11 +125,11 @@ class ShardedTensorMetadata(object):
     ):
         (self.shards_metadata, self.size, dtype, layout, requires_grad, mem_format_encoding, pin_memory) = state
 
-        if mem_format_encoding == 0:
+        if mem_format_encoding == MEM_FORMAT_ENCODING.TORCH_CONTIGUOUS_FORMAT:
             memory_format = torch.contiguous_format
-        elif mem_format_encoding == 1:
+        elif mem_format_encoding == MEM_FORMAT_ENCODING.TORCH_CHANNELS_LAST:
             memory_format = torch.channels_last
-        elif mem_format_encoding == 2:
+        elif mem_format_encoding == MEM_FORMAT_ENCODING.TORCH_PRESERVE_FORMAT:
             memory_format = torch.preserve_format
         else:
             raise RuntimeError(f'Invalid torch.memory_format encoding: {mem_format_encoding}')