Fix _force_data_dependency for scalar inputs
authorA. Unique TensorFlower <gardener@tensorflow.org>
Wed, 28 Mar 2018 02:58:53 +0000 (19:58 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 28 Mar 2018 03:01:36 +0000 (20:01 -0700)
PiperOrigin-RevId: 190715033

tensorflow/contrib/layers/python/layers/rev_block_lib.py

index 0b38c0c..e49589d 100644 (file)
@@ -33,6 +33,7 @@ import numpy as np
 from six.moves import xrange  # pylint: disable=redefined-builtin
 
 from tensorflow.contrib.framework.python import ops as contrib_framework_ops
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import function
 from tensorflow.python.framework import ops as framework_ops
 from tensorflow.python.layers import base
@@ -660,7 +661,9 @@ def _force_data_dependency(first_compute, then_compute):
     if x.get_shape().ndims is None:
       raise ValueError("Rank of Tensor %s must be known" % x)
     ndims = x.get_shape().ndims
-    return array_ops.reshape(array_ops.slice(x, [0] * ndims, [1] * ndims), [])
+    begin = framework_ops.convert_to_tensor([0] * ndims, dtype=dtypes.int32)
+    size = framework_ops.convert_to_tensor([1] * ndims, dtype=dtypes.int32)
+    return array_ops.reshape(array_ops.slice(x, begin, size), [])
 
   first_compute_sum = math_ops.add_n(
       [_first_element(x) for x in first_compute if x is not None])