Add support for scalars in `tf.contrib.all_reduce`.
author     A. Unique TensorFlower <gardener@tensorflow.org>
Mon, 12 Feb 2018 18:34:20 +0000 (10:34 -0800)
committer  TensorFlower Gardener <gardener@tensorflow.org>
Mon, 12 Feb 2018 18:41:59 +0000 (10:41 -0800)
PiperOrigin-RevId: 185398372
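
With the old `> 1` check, rank-0 (scalar) inputs skipped the flatten step in
all_reduce.py; relaxing it to `!= 1` flattens them to rank 1 before the
reduction passes and reshapes the outputs back afterwards, which is what the
new `shape=[]` test cases exercise. A minimal sketch of the newly supported
case, assuming the TF 1.x contrib signature
build_ring_all_reduce(input_tensors, num_workers, num_subchunks, gpu_perm,
red_op, un_op=None) and illustrative device strings:

    # Illustrative only -- device placement and values are assumptions.
    import tensorflow as tf
    from tensorflow.contrib.all_reduce.python import all_reduce

    # One scalar (rank-0) tensor per device; previously these bypassed
    # _flatten_tensors and so were never reshaped for the ring pass.
    scalars = []
    for i, dev in enumerate(["/gpu:0", "/gpu:1"]):
      with tf.device(dev):
        scalars.append(tf.constant(float(i + 1)))

    summed = all_reduce.build_ring_all_reduce(
        scalars, num_workers=1, num_subchunks=1,
        gpu_perm=[0, 1], red_op=tf.add)
    # Each element of `summed` is again a scalar holding 1.0 + 2.0 = 3.0.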

tensorflow/contrib/all_reduce/python/all_reduce.py
tensorflow/contrib/all_reduce/python/all_reduce_test.py

diff --git a/tensorflow/contrib/all_reduce/python/all_reduce.py b/tensorflow/contrib/all_reduce/python/all_reduce.py
index 28f60b3..16617b7 100644
--- a/tensorflow/contrib/all_reduce/python/all_reduce.py
+++ b/tensorflow/contrib/all_reduce/python/all_reduce.py
@@ -48,7 +48,7 @@ def _flatten_tensors(tensors):
   if shape.ndims is None:
     raise ValueError("At least one of the tensors in 'tensors' must have "
                      "statically known rank.")
-  if len(shape) > 1:
+  if len(shape) != 1:
     reshaped = []
     for t in tensors:
       with ops.colocate_with(t):
@@ -289,7 +289,7 @@ def build_ring_all_reduce(input_tensors, num_workers, num_subchunks,
                                        chunks_by_dev)
   if pad_len > 0:
     output_tensors = _strip_padding(output_tensors, pad_len)
-  if len(shape) > 1:
+  if len(shape) != 1:
     output_tensors = _reshape_tensors(output_tensors, shape)
   return output_tensors
 
@@ -466,7 +466,7 @@ def build_recursive_hd_all_reduce(input_tensors, red_op, un_op=None):
   if un_op:
     reduced_shards = [un_op(t) for t in reduced_shards]
   output_tensors = _build_recursive_hd_scatter(reduced_shards, devices)
-  if len(shape) > 1:
+  if len(shape) != 1:
     output_tensors = _reshape_tensors(output_tensors, shape)
   return output_tensors
 
@@ -578,7 +578,7 @@ def build_shuffle_all_reduce(input_tensors, gather_devices, red_op, un_op=None):
   reduced_shards = _build_shuffle_gather(input_tensors, gather_devices,
                                          red_op, un_op)
   output_tensors = _build_shuffle_scatter(reduced_shards, dst_devices)
-  if len(shape) > 1:
+  if len(shape) != 1:
     output_tensors = _reshape_tensors(output_tensors, shape)
   return output_tensors
 
@@ -752,7 +752,7 @@ def _build_nccl_hybrid(input_tensors, red_op, upper_level_f):
         dst_tensors.append(array_ops.identity(broadcast_src))
     down_values[w] = dst_tensors
   output_tensors = [v for sublist in down_values for v in sublist]
-  if len(shape) > 1:
+  if len(shape) != 1:
     output_tensors = _reshape_tensors(output_tensors, shape)
   return output_tensors
 
@@ -831,7 +831,7 @@ def _build_shuffle_hybrid(input_tensors, gather_devices, red_op, upper_level_f):
   for w in range(0, num_workers):
     output_tensors += _build_shuffle_scatter(
         [level_2_output[w]], per_worker_devices[w])
-  if len(shape) > 1:
+  if len(shape) != 1:
     output_tensors = _reshape_tensors(output_tensors, shape)
   return output_tensors
 
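The six `len(shape) != 1` checks above cover one flatten site and five restore
sites, all sharing the same pattern: flatten to rank 1, run the reduction
pass, then restore the original shape. A hedged illustration of that round
trip for a scalar (reshaping with [-1] is assumed to match what
_flatten_tensors does internally; the actual reshape call is not shown in
this diff):

    import tensorflow as tf

    scalar = tf.constant(3.0)        # rank 0, shape []
    flat = tf.reshape(scalar, [-1])  # rank 1, shape [1] -- safe to chunk and pad
    restored = tf.reshape(flat, [])  # back to the original scalar shape
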
diff --git a/tensorflow/contrib/all_reduce/python/all_reduce_test.py b/tensorflow/contrib/all_reduce/python/all_reduce_test.py
index 0802b27..47bab0a 100644
--- a/tensorflow/contrib/all_reduce/python/all_reduce_test.py
+++ b/tensorflow/contrib/all_reduce/python/all_reduce_test.py
@@ -119,7 +119,7 @@ class AllReduceTest(test_util.TensorFlowTestCase):
   def _buildInitialVars(self, shape, dev_list):
     values = []
     num_devices = len(dev_list)
-    dim = np.prod(shape)
+    dim = np.prod(shape) if shape else 1
     for d in range(0, num_devices):
       with ops.device(dev_list[d]):
         npt = np.zeros(shape).astype(np.float32)
@@ -164,6 +164,7 @@ class AllReduceTest(test_util.TensorFlowTestCase):
                     (num_workers, num_gpus, shape, subdiv, elapsed))
 
   def testRingAllReduce(self):
+    self._testRingAllReduce(1, 2, [], 1)
     self._testRingAllReduce(1, 2, [8], 1)
     self._testRingAllReduce(1, 2, [4, 4], 1)
     self._testRingAllReduce(6, 1, [8], 1)
@@ -192,6 +193,7 @@ class AllReduceTest(test_util.TensorFlowTestCase):
                     "elapsed=%f" % (num_workers, num_gpus, shape, elapsed))
 
   def testShuffleAllReduce(self):
+    self._testShuffleAllReduce(1, 2, [], 1)
     self._testShuffleAllReduce(1, 2, [8], 1)
     self._testShuffleAllReduce(1, 2, [4, 4], 1)
     self._testShuffleAllReduce(1, 8, [32], 1)