In contrib/all_reduce raise a ValueError if the input tensors
authorA. Unique TensorFlower <gardener@tensorflow.org>
Wed, 28 Mar 2018 18:34:32 +0000 (11:34 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 28 Mar 2018 18:36:34 +0000 (11:36 -0700)
do not have fully-defined shapes.

PiperOrigin-RevId: 190804146

tensorflow/contrib/all_reduce/python/all_reduce.py
tensorflow/contrib/all_reduce/python/all_reduce_test.py

index 6658f0d..8add2aa 100644 (file)
@@ -38,16 +38,15 @@ def _flatten_tensors(tensors):
     shape: the original shape of each element of input tensors
 
   Raises:
-    ValueError: tensors are empty or non-isomorphic.
+    ValueError: tensors are empty or non-isomorphic or have unknown shape.
   """
   if not tensors:
     raise ValueError("tensors cannot be empty")
   shape = tensors[0].shape
   for tensor in tensors:
     shape = shape.merge_with(tensor.shape)
-  if shape.ndims is None:
-    raise ValueError("At least one of the tensors in 'tensors' must have "
-                     "statically known rank.")
+  if not shape.is_fully_defined():
+    raise ValueError("Tensors must have statically known shape.")
   if len(shape) != 1:
     reshaped = []
     for t in tensors:
index 47bab0a..b3f5d92 100644 (file)
@@ -36,6 +36,12 @@ from tensorflow.python.platform import tf_logging
 
 class AllReduceTest(test_util.TensorFlowTestCase):
 
+  def testFlattenTensorsShapesDefined(self):
+    x = array_ops.placeholder(types_pb2.DT_FLOAT, [None])
+    with self.assertRaisesRegexp(ValueError,
+                                 "must have statically known shape"):
+      ar._flatten_tensors([x, x])
+
   def testRingPermutations(self):
     # 0 devices
     pred_by_c_d, rank_by_c_d = ar._ring_permutations(1, 0, [])