if shape.ndims is None:
raise ValueError("At least one of the tensors in 'tensors' must have "
"statically known rank.")
- if len(shape) > 1:
+ if len(shape) != 1:
reshaped = []
for t in tensors:
with ops.colocate_with(t):
chunks_by_dev)
if pad_len > 0:
output_tensors = _strip_padding(output_tensors, pad_len)
- if len(shape) > 1:
+ if len(shape) != 1:
output_tensors = _reshape_tensors(output_tensors, shape)
return output_tensors
if un_op:
reduced_shards = [un_op(t) for t in reduced_shards]
output_tensors = _build_recursive_hd_scatter(reduced_shards, devices)
- if len(shape) > 1:
+ if len(shape) != 1:
output_tensors = _reshape_tensors(output_tensors, shape)
return output_tensors
reduced_shards = _build_shuffle_gather(input_tensors, gather_devices,
red_op, un_op)
output_tensors = _build_shuffle_scatter(reduced_shards, dst_devices)
- if len(shape) > 1:
+ if len(shape) != 1:
output_tensors = _reshape_tensors(output_tensors, shape)
return output_tensors
dst_tensors.append(array_ops.identity(broadcast_src))
down_values[w] = dst_tensors
output_tensors = [v for sublist in down_values for v in sublist]
- if len(shape) > 1:
+ if len(shape) != 1:
output_tensors = _reshape_tensors(output_tensors, shape)
return output_tensors
for w in range(0, num_workers):
output_tensors += _build_shuffle_scatter(
[level_2_output[w]], per_worker_devices[w])
- if len(shape) > 1:
+ if len(shape) != 1:
output_tensors = _reshape_tensors(output_tensors, shape)
return output_tensors
def _buildInitialVars(self, shape, dev_list):
values = []
num_devices = len(dev_list)
- dim = np.prod(shape)
+ dim = np.prod(shape) if shape else 1
for d in range(0, num_devices):
with ops.device(dev_list[d]):
npt = np.zeros(shape).astype(np.float32)
(num_workers, num_gpus, shape, subdiv, elapsed))
def testRingAllReduce(self):
+ self._testRingAllReduce(1, 2, [], 1)
self._testRingAllReduce(1, 2, [8], 1)
self._testRingAllReduce(1, 2, [4, 4], 1)
self._testRingAllReduce(6, 1, [8], 1)
"elapsed=%f" % (num_workers, num_gpus, shape, elapsed))
def testShuffleAllReduce(self):
+ self._testShuffleAllReduce(1, 2, [], 1)
self._testShuffleAllReduce(1, 2, [8], 1)
self._testShuffleAllReduce(1, 2, [4, 4], 1)
self._testShuffleAllReduce(1, 8, [32], 1)