Make DType, TensorShape, and Dimension "reducable" for pickling purposes.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 12 Apr 2018 11:40:42 +0000 (04:40 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 12 Apr 2018 11:43:24 +0000 (04:43 -0700)
PiperOrigin-RevId: 192591402

tensorflow/python/framework/dtypes.py
tensorflow/python/framework/dtypes_test.py
tensorflow/python/framework/tensor_shape.py
tensorflow/python/framework/tensor_shape_test.py

index a31c424..51ff517 100644 (file)
@@ -297,6 +297,9 @@ class DType(object):
   def __hash__(self):
     return self._type_enum
 
+  def __reduce__(self):
+    return as_dtype, (self.name,)
+
   @property
   def size(self):
     if (self._type_enum == types_pb2.DT_VARIANT or
index e49e2fd..e55783b 100644 (file)
@@ -295,6 +295,15 @@ class TypesTest(test_util.TensorFlowTestCase):
     self.assertNotEqual(dtypes.int32, int)
     self.assertNotEqual(dtypes.float64, 2.1)
 
+  def testReduce(self):
+    for enum in dtypes._TYPE_TO_STRING:
+      dtype = dtypes.DType(enum)
+      ctor, args = dtype.__reduce__()
+      self.assertEquals(ctor, dtypes.as_dtype)
+      self.assertEquals(args, (dtype.name,))
+      reconstructed = ctor(*args)
+      self.assertEquals(reconstructed, dtype)
+
 
 if __name__ == "__main__":
   googletest.main()
index af2a5b1..00f256c 100644 (file)
@@ -456,6 +456,9 @@ class Dimension(object):
     else:
       return self._value >= other.value
 
+  def __reduce__(self):
+    return Dimension, (self._value,)
+
 
 def as_dimension(value):
   """Converts the given value to a Dimension.
@@ -928,6 +931,9 @@ class TensorShape(object):
       return True
     return self._dims != other.dims
 
+  def __reduce__(self):
+    return TensorShape, (self._dims,)
+
 
 def as_shape(shape):
   """Converts the given object to a TensorShape."""
index 4e8ce4d..498574e 100644 (file)
@@ -192,6 +192,14 @@ class DimensionTest(test_util.TensorFlowTestCase):
     self.assertEqual(nine % 4, 1)
     self.assertEqual(4 % nine, 4)
 
+  def testReduce(self):
+    dim = tensor_shape.Dimension(5)
+    ctor, args = dim.__reduce__()
+    self.assertEquals(ctor, tensor_shape.Dimension)
+    self.assertEquals(args, (5,))
+    reconstructed = ctor(*args)
+    self.assertEquals(reconstructed, dim)
+
 
 class ShapeTest(test_util.TensorFlowTestCase):
 
@@ -417,5 +425,15 @@ class ShapeTest(test_util.TensorFlowTestCase):
     self.assertAllEqual([2, None, 4], tensor_shape.TensorShape(
         (2, None, 4)).as_list())
 
+  def testReduce(self):
+    shape = tensor_shape.TensorShape([2, 3])
+    ctor, args = shape.__reduce__()
+    self.assertEquals(ctor, tensor_shape.TensorShape)
+    self.assertEquals(args, ([tensor_shape.Dimension(2),
+                              tensor_shape.Dimension(3)],))
+    reconstructed = ctor(*args)
+    self.assertEquals(reconstructed, shape)
+
+
 if __name__ == "__main__":
   googletest.main()