Refactor TensorArray to avoid copies and memory allocations when executing eagerly.
authorAkshay Agrawal <akshayka@google.com>
Mon, 7 May 2018 23:16:24 +0000 (16:16 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 8 May 2018 00:50:01 +0000 (17:50 -0700)
With this change, writes to TensorArrays when eager execution is enabled take O(1) time instead of O(n). Additionally, whereas writing to a TensorArray when constructing a graph results in allocating a new Python TensorArray object, writing to a TensorArray with eager enabled no longer performs that allocation (graph construction uses these allocations to ensure correctness of control flow and gradients, but this isn't necessary when executing eagerly). Finally, this change also removes the artificial write-once semantics of TensorArrays when executing eagerly.

PiperOrigin-RevId: 195739572

tensorflow/python/kernel_tests/tensor_array_ops_test.py
tensorflow/python/ops/tensor_array_ops.py

index 918bbd3..c0b36f1 100644 (file)
@@ -438,7 +438,6 @@ class TensorArrayTest(test.TestCase):
           "Tried to read from index 3 but array size is: 3"):
         self.evaluate(ta.read(3))
 
-  @test_util.run_in_graph_and_eager_modes()
   def testTensorArrayWriteMultipleFails(self):
     with self.test_session(use_gpu=True):
       ta = tensor_array_ops.TensorArray(
index d2f45ce..cc92da4 100644 (file)
@@ -20,6 +20,7 @@ from __future__ import division
 from __future__ import print_function
 
 import contextlib
+import weakref
 
 from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
@@ -395,69 +396,8 @@ class _GraphTensorArray(object):
 # pylint: enable=protected-access
 
 
-# pylint: disable=protected-access
-def _eager_write_no_copy(ta, index, value):
-  """Writes value into an _EagerTensorArray without creating a new TensorArray.
-
-  Args:
-    ta: _EagerTensorArray into which to write value.
-    index: 0-D.  int32 scalar with the index to write to.
-    value: N-D.  Tensor of type `dtype`.  The Tensor to write to this index.
-
-  Raises:
-    errors_impl.AlreadyExistsError: attempting to overwrite an entry.
-    errors_impl.InvalidArgumentError: value dtype does not match `ta`'s dtype.
-    errors_impl.OutOfRangeError: `index` is out of bounds.
-    ValueError: shape of `value` is not consistent with inferred shape.
-  """
-
-  if isinstance(index, ops.EagerTensor):
-    index = index.numpy()
-
-  if index < 0:
-    raise errors_impl.OutOfRangeError(
-        None, None,
-        "Writing to negative indices (index %d) is not allowed." % index)
-
-  tensor_array = ta._tensor_array
-  size = len(tensor_array)
-  if index >= size:
-    if not ta._dynamic_size:
-      raise errors_impl.OutOfRangeError(
-          None, None,
-          "Tried to write to index %d but array is not resizeable and size "
-          "is: %d" % (index, size))
-    tensor_array.extend([None for _ in range(index - size + 1)])
-
-  if not isinstance(value, ops.EagerTensor):
-    value = constant_op.constant(value)
-
-  if ta._infer_shape:
-    if ta._element_shape is None:
-      ta._element_shape = value.shape
-    elif ta._element_shape != value.shape:
-      raise ValueError("Incompatible shape for value (%s), expected (%s)" %
-                       (value.shape.as_list(), ta._element_shape.as_list()))
-
-  if ta._dtype != value.dtype:
-    raise errors_impl.InvalidArgumentError(
-        None, None,
-        "TensorArray dtype is %s but Op is trying to write dtype %s" %
-        (ta._dtype.name, value.dtype.name))
-
-  if ta._tensor_array[index] is not None:
-    raise errors_impl.AlreadyExistsError(
-        None, None,
-        "Could not write to TensorArray index %d because it has already been "
-        "written to." % index)
-
-  tensor_array[index] = value
-
-# pylint: enable=protected-access
-
-
 class _EagerTensorArray(object):
-  """Eager-mode implementation of TensorArray.
+  """Eager-compatible implementation of TensorArray.
   """
 
   def __init__(self,
@@ -472,7 +412,7 @@ class _EagerTensorArray(object):
                element_shape=None,
                colocate_with_first_write_call=True,
                name=None):
-    """Constructs an Eager mode TensorArray.
+    """Constructs a TensorArray compatible with eager execution.
 
     Args:
       dtype: (required) data type of the TensorArray.
@@ -495,16 +435,19 @@ class _EagerTensorArray(object):
       ValueError: handle or flow are supplied, or if size is not supplied.
     """
 
-    del (flow, tensor_array_name, name)  # not meaningful in Eager
+    del (flow, tensor_array_name, name)  # Unused.
 
     if handle is not None:
-      raise ValueError("TensorArray handles are not supported in Eager mode.")
+      raise ValueError("TensorArray handles are not supported when eager "
+                       "execution is enabled.")
     if size is None:
-      raise ValueError("Size must be declared for TensorArrays in Eager mode.")
+      raise ValueError("Size must be declared for TensorArrays when eager "
+                       "execution is enabled.")
 
-    # These attributes are not meaningful in Eager, but some library functions
-    # (e.g., those in control_flow_ops.py) access them to create new tensor
-    # arrays; as such, we define them for the sake of compatibility.
+    # These attributes are not meaningful when eager is enabled, but some
+    # library functions (e.g., those in control_flow_ops.py) access them to
+    # create new tensor arrays; as such, we define them for the sake of
+    # compatibility.
     self._handle = None
     # we assign a dummy value to _flow in case other code assumes it to be
     # a Tensor
@@ -525,7 +468,7 @@ class _EagerTensorArray(object):
 
   @property
   def flow(self):
-    """Flows are not meaningful in Eager; this exists for compatibility."""
+    """For compatibility; flows are not meaningful when eager is enabled."""
     return self._flow
 
   @property
@@ -534,42 +477,22 @@ class _EagerTensorArray(object):
 
   @property
   def handle(self):
-    """Handles are not meaningful in Eager; this exists for compatibility."""
+    """For compatibility; handles are not meaningful when eager is enabled."""
     return self._handle
 
-  def _identity_without_array(self):
-    """Returns a new TensorArray with the same properties as this Eager one.
-
-    NB: Does not set the underlying _tensor_array attribute.
-    """
-    ta = TensorArray(
-        dtype=self._dtype,
-        size=len(self._tensor_array),
-        dynamic_size=self._dynamic_size,
-        clear_after_read=self._clear_after_read,
-        handle=self._handle,
-        flow=self._flow,
-        infer_shape=self._infer_shape,
-        element_shape=self._element_shape,
-        colocate_with_first_write_call=self._colocate_with_first_write_call)
-    ta._implementation._previously_read_indices = self._previously_read_indices  # pylint: disable=protected-access
-    return ta
-
   def identity(self):
     """See TensorArray."""
-    ta = self._identity_without_array()
-    ta._implementation._tensor_array = [t for t in self._tensor_array]  # pylint: disable=protected-access
-    return ta
+    return self.parent()
 
   def grad(self, source, flow=None, name=None):
     raise NotImplementedError(
-        "TensorArray.grad is not supported in Eager mode; Eager's gradient "
-        "implementation does not use/need this function to compute gradients "
-        "of operations that use TensorArrays.")
+        "TensorArray.grad is not supported when executing eagerly; eager's "
+        "gradient implementation does not use/need this function to compute "
+        "gradients of operations that use TensorArrays.")
 
   def read(self, index, name=None):
     """See TensorArray."""
-    del name  # not meaningful in Eager mode
+    del name  # not meaningful when executing eagerly.
 
     if isinstance(index, ops.EagerTensor):
       index = index.numpy()
@@ -600,12 +523,58 @@ class _EagerTensorArray(object):
       self._previously_read_indices.append(index)
     return tensor
 
+  def _write(self, index, value):
+    """Writes `value` into index named by `index`.
+
+    Args:
+      index: 0-D.  int32 scalar with the index to write to.
+      value: N-D.  Tensor of type `dtype`.  The `Tensor` to write to `index`.
+
+    Raises:
+      errors_impl.InvalidArgumentError: `value` dtype does not match dtype.
+      errors_impl.OutOfRangeError: `index` is out of bounds.
+      ValueError: shape of `value` is not consistent with inferred shape.
+    """
+
+    if isinstance(index, ops.EagerTensor):
+      index = index.numpy()
+
+    if index < 0:
+      raise errors_impl.OutOfRangeError(
+          None, None,
+          "Writing to negative indices (index %d) is not allowed." % index)
+
+    size = len(self._tensor_array)
+    if index >= size:
+      if not self._dynamic_size:
+        raise errors_impl.OutOfRangeError(
+            None, None,
+            "Tried to write to index %d but array is not resizeable and size "
+            "is: %d" % (index, size))
+      self._tensor_array.extend([None for _ in range(index - size + 1)])
+
+    if not isinstance(value, ops.EagerTensor):
+      value = constant_op.constant(value)
+
+    if self._infer_shape:
+      if self._element_shape is None:
+        self._element_shape = value.shape
+      elif self._element_shape != value.shape:
+        raise ValueError("Incompatible shape for value (%s), expected (%s)" %
+                         (value.shape.as_list(), self._element_shape.as_list()))
+
+    if self._dtype != value.dtype:
+      raise errors_impl.InvalidArgumentError(
+          None, None,
+          "TensorArray dtype is %s but Op is trying to write dtype %s" %
+          (self._dtype.name, value.dtype.name))
+    self._tensor_array[index] = value
+
   def write(self, index, value, name=None):
     """See TensorArray."""
-    del name  # not meaningful in Eager mode
-    ta = self.identity()
-    _eager_write_no_copy(ta._implementation, index, value)  # pylint: disable=protected-access
-    return ta
+    del name  # not meaningful when executing eagerly.
+    self._write(index, value)
+    return self.parent()
 
   def _maybe_zero(self, ix):
     val = self._tensor_array[ix]
@@ -623,7 +592,7 @@ class _EagerTensorArray(object):
 
   def gather(self, indices, name=None):
     """See TensorArray."""
-    del name  # not meaningful in Eager mode
+    del name  # not meaningful when executing eagerly.
     return array_ops.stack([self._maybe_zero(i) for i in indices.numpy()])
 
   def concat(self, name=None):
@@ -651,17 +620,15 @@ class _EagerTensorArray(object):
       raise ValueError(
           "Cannot unstack %d tensors into a TensorArray of static size %d" %
           (len(tensors), len(self._tensor_array)))
-    ta = self._identity_without_array()
-    ta._implementation._tensor_array = tensors  # pylint: disable=protected-access
-    return ta
+    self._tensor_array = tensors
+    return self.parent()
 
   def scatter(self, indices, value, name=None):
     """See TensorArray."""
-    del name  # unused in Eager
-    ta = self.identity()
+    del name  # not meaningful when executing eagerly.
     for index, val in zip(indices.numpy(), array_ops.unstack(value)):
-      _eager_write_no_copy(ta._implementation, index, val)  # pylint: disable=protected-access
-    return ta
+      self._write(index, val)  # pylint: disable=protected-access
+    return self.parent()
 
   def split(self, value, lengths, name=None):
     """See TensorArray."""
@@ -690,20 +657,17 @@ class _EagerTensorArray(object):
           "dynamically resizeable" % (len(self._tensor_array),
                                       lengths.shape[0]))
     else:
-      ta = self._identity_without_array()
-      tensor_array = array_ops.split(value, lengths, name=name)
-      ta._implementation._tensor_array = tensor_array  # pylint: disable=protected-access
-      return ta
+      self._tensor_array = array_ops.split(value, lengths, name=name)
+      return self.parent()
 
   def size(self, name=None):
     """See TensorArray."""
-    del name  # not meaningful in Eager mode
+    del name  # not meaningful when executing eagerly.
     return constant_op.constant(len(self._tensor_array))
 
   def close(self, name=None):
-    del name  # not meaningful in Eager mode
+    del name  # not meaningful when executing eagerly.
     del self._tensor_array[:]
-    return
 
 
 # TensorArray is designed to hide an underlying implementation object
@@ -789,6 +753,8 @@ class TensorArray(object):
         colocate_with_first_write_call=colocate_with_first_write_call,
         name=name)
 
+    self._implementation.parent = weakref.ref(self)
+
   @property
   def flow(self):
     """The flow `Tensor` forcing ops leading to this TensorArray state."""