From a87808de931a31c242bca0c2305ec4af67f08ef2 Mon Sep 17 00:00:00 2001 From: Pritam Damania Date: Tue, 31 Aug 2021 20:19:55 -0700 Subject: [PATCH] Fix bug in ShardedTensorMetadata serde. (#63902) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63902 The 'memory_format' field was not being serialized correctly and used the same encoding for different fields. ghstack-source-id: 137142406 Test Plan: waitforbuildbot Reviewed By: bowangbj Differential Revision: D30527324 fbshipit-source-id: f0f223e2d660ef6e4abae9649d9992acc36e1278 --- .../_sharded_tensor/test_sharded_tensor.py | 50 ++++++++++++++++++++++ torch/distributed/_sharded_tensor/api.py | 19 +++++--- 2 files changed, 63 insertions(+), 6 deletions(-) diff --git a/test/distributed/_sharded_tensor/test_sharded_tensor.py b/test/distributed/_sharded_tensor/test_sharded_tensor.py index 718b594..77e35b7 100644 --- a/test/distributed/_sharded_tensor/test_sharded_tensor.py +++ b/test/distributed/_sharded_tensor/test_sharded_tensor.py @@ -1,6 +1,8 @@ from functools import wraps import math import io +import itertools +import pickle import sys import torch import torch.distributed as dist @@ -123,6 +125,54 @@ def with_comms(func): self.destroy_comms() return wrapper +class TestShardedTensorMetadata(TestCase): + def test_serialize_and_deserialize(self): + shard_metadatas = [ + ShardMetadata( + shard_offsets=[0, 0], + shard_lengths=[5, 5], + placement="rank:0/cuda:0", + ), + ShardMetadata( + shard_offsets=[0, 5], + shard_lengths=[5, 5], + placement="rank:1/cuda:1", + ), + ShardMetadata( + shard_offsets=[5, 0], + shard_lengths=[5, 5], + placement="rank:2/cuda:2", + ), + ShardMetadata( + shard_offsets=[5, 5], + shard_lengths=[5, 5], + placement="rank:3/cuda:3", + ) + ] + + dtypes = [ + torch.float, torch.double, torch.cfloat, torch.cdouble, torch.half, + torch.bfloat16, torch.uint8, torch.int8, torch.short, torch.int, + torch.long, torch.bool] + + layouts = [torch.strided, torch.sparse_coo] + requires_grads = [True, False] + memory_formats = [torch.contiguous_format, torch.channels_last, torch.preserve_format] + pin_memories = [True, False] + + for tensor_properties_input in itertools.product(dtypes, layouts, requires_grads, memory_formats, pin_memories): + dtype, layout, requires_grad, memory_format, pin_memory = tensor_properties_input + + expected_st_metadata = _sharded_tensor.ShardedTensorMetadata( + shard_metadatas, + (10, 10), + _sharded_tensor.TensorProperties(dtype, layout, requires_grad, memory_format, pin_memory) + ) + + pickled_obj = pickle.dumps(expected_st_metadata) + st_metadata = pickle.loads(pickled_obj) + self.assertEqual(expected_st_metadata, st_metadata) + class TestCreateTensorFromParams(TestCase): @sandcastle_skip_if(torch.cuda.device_count() < 1, 'CUDA GPU is needed') def test_empty(self): diff --git a/torch/distributed/_sharded_tensor/api.py b/torch/distributed/_sharded_tensor/api.py index 3b7476d..d6b7a54 100644 --- a/torch/distributed/_sharded_tensor/api.py +++ b/torch/distributed/_sharded_tensor/api.py @@ -70,6 +70,13 @@ class TensorProperties(object): memory_format: torch.memory_format = field(default=torch.contiguous_format) pin_memory: bool = False + +class MEM_FORMAT_ENCODING(Enum): + TORCH_CONTIGUOUS_FORMAT = 0 + TORCH_CHANNELS_LAST = 1 + TORCH_PRESERVE_FORMAT = 2 + + @dataclass class ShardedTensorMetadata(object): """ @@ -93,11 +100,11 @@ class ShardedTensorMetadata(object): # Since torch.memory_format cannot be pickled! memory_format = self.tensor_properties.memory_format if memory_format == torch.contiguous_format: - mem_format_encoding = 0 + mem_format_encoding = MEM_FORMAT_ENCODING.TORCH_CONTIGUOUS_FORMAT elif memory_format == torch.channels_last: - mem_format_encoding = 1 + mem_format_encoding = MEM_FORMAT_ENCODING.TORCH_CHANNELS_LAST elif memory_format == torch.preserve_format: - mem_format_encoding = 1 + mem_format_encoding = MEM_FORMAT_ENCODING.TORCH_PRESERVE_FORMAT else: raise RuntimeError(f'Invalid torch.memory_format: {memory_format}') @@ -118,11 +125,11 @@ class ShardedTensorMetadata(object): ): (self.shards_metadata, self.size, dtype, layout, requires_grad, mem_format_encoding, pin_memory) = state - if mem_format_encoding == 0: + if mem_format_encoding == MEM_FORMAT_ENCODING.TORCH_CONTIGUOUS_FORMAT: memory_format = torch.contiguous_format - elif mem_format_encoding == 1: + elif mem_format_encoding == MEM_FORMAT_ENCODING.TORCH_CHANNELS_LAST: memory_format = torch.channels_last - elif mem_format_encoding == 2: + elif mem_format_encoding == MEM_FORMAT_ENCODING.TORCH_PRESERVE_FORMAT: memory_format = torch.preserve_format else: raise RuntimeError(f'Invalid torch.memory_format encoding: {mem_format_encoding}') -- 2.7.4