from __future__ import print_function
import contextlib
+import weakref
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
# pylint: enable=protected-access
-# pylint: disable=protected-access
-def _eager_write_no_copy(ta, index, value):
- """Writes value into an _EagerTensorArray without creating a new TensorArray.
-
- Args:
- ta: _EagerTensorArray into which to write value.
- index: 0-D. int32 scalar with the index to write to.
- value: N-D. Tensor of type `dtype`. The Tensor to write to this index.
-
- Raises:
- errors_impl.AlreadyExistsError: attempting to overwrite an entry.
- errors_impl.InvalidArgumentError: value dtype does not match `ta`'s dtype.
- errors_impl.OutOfRangeError: `index` is out of bounds.
- ValueError: shape of `value` is not consistent with inferred shape.
- """
-
- if isinstance(index, ops.EagerTensor):
- index = index.numpy()
-
- if index < 0:
- raise errors_impl.OutOfRangeError(
- None, None,
- "Writing to negative indices (index %d) is not allowed." % index)
-
- tensor_array = ta._tensor_array
- size = len(tensor_array)
- if index >= size:
- if not ta._dynamic_size:
- raise errors_impl.OutOfRangeError(
- None, None,
- "Tried to write to index %d but array is not resizeable and size "
- "is: %d" % (index, size))
- tensor_array.extend([None for _ in range(index - size + 1)])
-
- if not isinstance(value, ops.EagerTensor):
- value = constant_op.constant(value)
-
- if ta._infer_shape:
- if ta._element_shape is None:
- ta._element_shape = value.shape
- elif ta._element_shape != value.shape:
- raise ValueError("Incompatible shape for value (%s), expected (%s)" %
- (value.shape.as_list(), ta._element_shape.as_list()))
-
- if ta._dtype != value.dtype:
- raise errors_impl.InvalidArgumentError(
- None, None,
- "TensorArray dtype is %s but Op is trying to write dtype %s" %
- (ta._dtype.name, value.dtype.name))
-
- if ta._tensor_array[index] is not None:
- raise errors_impl.AlreadyExistsError(
- None, None,
- "Could not write to TensorArray index %d because it has already been "
- "written to." % index)
-
- tensor_array[index] = value
-
-# pylint: enable=protected-access
-
-
class _EagerTensorArray(object):
- """Eager-mode implementation of TensorArray.
+ """Eager-compatible implementation of TensorArray.
"""
def __init__(self,
element_shape=None,
colocate_with_first_write_call=True,
name=None):
- """Constructs an Eager mode TensorArray.
+ """Constructs a TensorArray compatible with eager execution.
Args:
dtype: (required) data type of the TensorArray.
ValueError: handle or flow are supplied, or if size is not supplied.
"""
- del (flow, tensor_array_name, name) # not meaningful in Eager
+ del (flow, tensor_array_name, name) # Unused.
if handle is not None:
- raise ValueError("TensorArray handles are not supported in Eager mode.")
+ raise ValueError("TensorArray handles are not supported when eager "
+ "execution is enabled.")
if size is None:
- raise ValueError("Size must be declared for TensorArrays in Eager mode.")
+ raise ValueError("Size must be declared for TensorArrays when eager "
+ "execution is enabled.")
- # These attributes are not meaningful in Eager, but some library functions
- # (e.g., those in control_flow_ops.py) access them to create new tensor
- # arrays; as such, we define them for the sake of compatibility.
+ # These attributes are not meaningful when eager is enabled, but some
+ # library functions (e.g., those in control_flow_ops.py) access them to
+ # create new tensor arrays; as such, we define them for the sake of
+ # compatibility.
self._handle = None
# we assign a dummy value to _flow in case other code assumes it to be
# a Tensor
@property
def flow(self):
- """Flows are not meaningful in Eager; this exists for compatibility."""
+ """For compatibility; flows are not meaningful when eager is enabled."""
return self._flow
@property
@property
def handle(self):
- """Handles are not meaningful in Eager; this exists for compatibility."""
+ """For compatibility; handles are not meaningful when eager is enabled."""
return self._handle
- def _identity_without_array(self):
- """Returns a new TensorArray with the same properties as this Eager one.
-
- NB: Does not set the underlying _tensor_array attribute.
- """
- ta = TensorArray(
- dtype=self._dtype,
- size=len(self._tensor_array),
- dynamic_size=self._dynamic_size,
- clear_after_read=self._clear_after_read,
- handle=self._handle,
- flow=self._flow,
- infer_shape=self._infer_shape,
- element_shape=self._element_shape,
- colocate_with_first_write_call=self._colocate_with_first_write_call)
- ta._implementation._previously_read_indices = self._previously_read_indices # pylint: disable=protected-access
- return ta
-
def identity(self):
"""See TensorArray."""
- ta = self._identity_without_array()
- ta._implementation._tensor_array = [t for t in self._tensor_array] # pylint: disable=protected-access
- return ta
+ return self.parent()
def grad(self, source, flow=None, name=None):
raise NotImplementedError(
- "TensorArray.grad is not supported in Eager mode; Eager's gradient "
- "implementation does not use/need this function to compute gradients "
- "of operations that use TensorArrays.")
+ "TensorArray.grad is not supported when executing eagerly; eager's "
+ "gradient implementation does not use/need this function to compute "
+ "gradients of operations that use TensorArrays.")
def read(self, index, name=None):
"""See TensorArray."""
- del name # not meaningful in Eager mode
+ del name # not meaningful when executing eagerly.
if isinstance(index, ops.EagerTensor):
index = index.numpy()
self._previously_read_indices.append(index)
return tensor
+ def _write(self, index, value):
+    """Writes `value` into the array at position `index`.
+
+ Args:
+ index: 0-D. int32 scalar with the index to write to.
+ value: N-D. Tensor of type `dtype`. The `Tensor` to write to `index`.
+
+ Raises:
+ errors_impl.InvalidArgumentError: `value` dtype does not match dtype.
+ errors_impl.OutOfRangeError: `index` is out of bounds.
+ ValueError: shape of `value` is not consistent with inferred shape.
+ """
+
+ if isinstance(index, ops.EagerTensor):
+ index = index.numpy()
+
+ if index < 0:
+ raise errors_impl.OutOfRangeError(
+ None, None,
+ "Writing to negative indices (index %d) is not allowed." % index)
+
+ size = len(self._tensor_array)
+ if index >= size:
+ if not self._dynamic_size:
+ raise errors_impl.OutOfRangeError(
+ None, None,
+ "Tried to write to index %d but array is not resizeable and size "
+ "is: %d" % (index, size))
+ self._tensor_array.extend([None for _ in range(index - size + 1)])
+
+ if not isinstance(value, ops.EagerTensor):
+ value = constant_op.constant(value)
+
+ if self._infer_shape:
+ if self._element_shape is None:
+ self._element_shape = value.shape
+ elif self._element_shape != value.shape:
+ raise ValueError("Incompatible shape for value (%s), expected (%s)" %
+ (value.shape.as_list(), self._element_shape.as_list()))
+
+ if self._dtype != value.dtype:
+ raise errors_impl.InvalidArgumentError(
+ None, None,
+ "TensorArray dtype is %s but Op is trying to write dtype %s" %
+ (self._dtype.name, value.dtype.name))
+ self._tensor_array[index] = value
+
def write(self, index, value, name=None):
"""See TensorArray."""
- del name # not meaningful in Eager mode
- ta = self.identity()
- _eager_write_no_copy(ta._implementation, index, value) # pylint: disable=protected-access
- return ta
+ del name # not meaningful when executing eagerly.
+ self._write(index, value)
+ return self.parent()
def _maybe_zero(self, ix):
val = self._tensor_array[ix]
def gather(self, indices, name=None):
"""See TensorArray."""
- del name # not meaningful in Eager mode
+ del name # not meaningful when executing eagerly.
return array_ops.stack([self._maybe_zero(i) for i in indices.numpy()])
def concat(self, name=None):
raise ValueError(
"Cannot unstack %d tensors into a TensorArray of static size %d" %
(len(tensors), len(self._tensor_array)))
- ta = self._identity_without_array()
- ta._implementation._tensor_array = tensors # pylint: disable=protected-access
- return ta
+ self._tensor_array = tensors
+ return self.parent()
def scatter(self, indices, value, name=None):
"""See TensorArray."""
- del name # unused in Eager
- ta = self.identity()
+ del name # not meaningful when executing eagerly.
for index, val in zip(indices.numpy(), array_ops.unstack(value)):
- _eager_write_no_copy(ta._implementation, index, val) # pylint: disable=protected-access
- return ta
+      self._write(index, val)
+ return self.parent()
def split(self, value, lengths, name=None):
"""See TensorArray."""
"dynamically resizeable" % (len(self._tensor_array),
lengths.shape[0]))
else:
- ta = self._identity_without_array()
- tensor_array = array_ops.split(value, lengths, name=name)
- ta._implementation._tensor_array = tensor_array # pylint: disable=protected-access
- return ta
+ self._tensor_array = array_ops.split(value, lengths, name=name)
+ return self.parent()
def size(self, name=None):
"""See TensorArray."""
- del name # not meaningful in Eager mode
+ del name # not meaningful when executing eagerly.
return constant_op.constant(len(self._tensor_array))
def close(self, name=None):
- del name # not meaningful in Eager mode
+ del name # not meaningful when executing eagerly.
del self._tensor_array[:]
- return
# TensorArray is designed to hide an underlying implementation object
colocate_with_first_write_call=colocate_with_first_write_call,
name=name)
+ self._implementation.parent = weakref.ref(self)
+
@property
def flow(self):
"""The flow `Tensor` forcing ops leading to this TensorArray state."""