from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.contrib.framework.python import ops as contrib_framework_ops
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops as framework_ops
from tensorflow.python.layers import base
if x.get_shape().ndims is None:
raise ValueError("Rank of Tensor %s must be known" % x)
ndims = x.get_shape().ndims
- return array_ops.reshape(array_ops.slice(x, [0] * ndims, [1] * ndims), [])
+ begin = framework_ops.convert_to_tensor([0] * ndims, dtype=dtypes.int32)
+ size = framework_ops.convert_to_tensor([1] * ndims, dtype=dtypes.int32)
+ return array_ops.reshape(array_ops.slice(x, begin, size), [])
first_compute_sum = math_ops.add_n(
[_first_element(x) for x in first_compute if x is not None])