Makes tf.gather not silently snapshot resource variables.
authorAlexandre Passos <apassos@google.com>
Mon, 26 Mar 2018 22:39:54 +0000 (15:39 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 26 Mar 2018 22:42:46 +0000 (15:42 -0700)
PiperOrigin-RevId: 190537320

tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
tensorflow/python/ops/array_ops.py
tensorflow/python/ops/embedding_ops.py

index c4139dd..07b3ad7 100644 (file)
@@ -785,26 +785,31 @@ class AttentionWrapperTest(test.TestCase):
         wrapper.BahdanauAttention, wrapper.LuongAttention)
 
     expected_final_output = BasicDecoderOutput(
-        rnn_output=ResultSummary(
-            shape=(5, 3, 20), dtype=dtype('float32'), mean=0.11798714846372604),
-        sample_id=ResultSummary(
-            shape=(5, 3), dtype=dtype('int32'), mean=7.933333333333334))
+        rnn_output=ResultSummary(shape=(5, 3, 20),
+                                 dtype=dtype('float32'),
+                                 mean=0.11723966),
+        sample_id=ResultSummary(shape=(5, 3),
+                                dtype=dtype('int32'),
+                                mean=9.2666666666666675))
     expected_final_state = AttentionWrapperState(
         cell_state=LSTMStateTuple(
-            c=ResultSummary(
-                shape=(5, 9), dtype=dtype('float32'), mean=-0.0036486709),
-            h=ResultSummary(
-                shape=(5, 9), dtype=dtype('float32'), mean=-0.0018835809)),
-        attention=ResultSummary(
-            shape=(5, 20), dtype=dtype('float32'), mean=0.11798714846372604),
+            c=ResultSummary(shape=(5, 9),
+                            dtype=dtype('float32'),
+                            mean=-0.003545674),
+            h=ResultSummary(shape=(5, 9),
+                            dtype=dtype('float32'),
+                            mean=-0.0018327223)),
+        attention=ResultSummary(shape=(5, 20),
+                                dtype=dtype('float32'),
+                                mean=0.11728073),
         time=3,
         alignments=(
             ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125),
             ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125)),
+        alignment_history=(),
         attention_state=(
             ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125),
-            ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125)),
-        alignment_history=())
+            ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125)))
     expected_final_alignment_history = (
         ResultSummary(shape=(3, 5, 8), dtype=dtype('float32'), mean=0.125),
         ResultSummary(shape=(3, 5, 8), dtype=dtype('float32'), mean=0.125))
index ec7c14f..9106461 100644 (file)
@@ -2691,12 +2691,17 @@ reverse_sequence.__doc__ = deprecation.rewrite_argument_docstring(
 
 @tf_export("gather")
 def gather(params, indices, validate_indices=None, name=None, axis=0):
-  # TODO(rjryan): Remove "Gather" creation in favor of GatherV2 once the forward
-  # compatibility 3 week period has passed.
-  if axis == 0:
-    return gen_array_ops.gather(
-        params, indices, validate_indices=validate_indices, name=name)
-  else:
+  del validate_indices
+  if axis != 0:
+    # Note that we do a sparse_read here to avoid snapshotting the entire
+    # resource variable and doing a gather, which can be inefficient and lead to
+    # subtle race conditions. TODO(apassos) implement axis != 0 on sparse_read
+    return gen_array_ops.gather_v2(params, indices, axis, name=name)
+  try:
+    # TODO(apassos) find a less bad way of detecting resource variables without
+    # introducing a circular dependency.
+    return params.sparse_read(indices, name=name)
+  except AttributeError:
     return gen_array_ops.gather_v2(params, indices, axis, name=name)
 
 
index 20e4a28..f0120f2 100644 (file)
@@ -35,34 +35,14 @@ from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util.tf_export import tf_export
 
 
-def _gather(params, ids, name=None):
-  """Helper function for _embedding_lookup_and_transform.
-
-  This function gathers embeddings from a single tensor. The gather deals with
-  resource variables specially.
-
-  Args:
-    params: A `Tensor` of embeddings.
-    ids: A `Tensor` indexing the embeddings to be retrieved from `params`.
-    name: A name for the operation (optional).
-
-  Returns:
-    A `Tensor` with the same type as `params`.
-  """
-  if isinstance(params, resource_variable_ops.ResourceVariable):
-    return params.sparse_read(ids, name=name)
-  else:
-    return array_ops.gather(params, ids, name=name)
-
-
 def _clip(params, ids, max_norm):
   """Helper function for _embedding_lookup_and_transform.
 
   This function optionally clips embeddings to an l2-norm of max_norm.
 
   Args:
-    params: A `Tensor` of embeddings retrieved by `_gather`.
-    ids: The `ids` argument that was passed to `_gather`.
+    params: A `Tensor` of embeddings retrieved by `gather`.
+    ids: The `ids` argument that was passed to `gather`.
     max_norm: If provided, the embeddings are l2-normalized to the value of
       max_norm.
 
@@ -148,7 +128,8 @@ def _embedding_lookup_and_transform(params,
     ids = ops.convert_to_tensor(ids, name="ids")
     if np == 1 and (not transform_fn or ids.get_shape().ndims == 1):
       with ops.colocate_with(params[0]):
-        result = _clip(_gather(params[0], ids, name=name), ids, max_norm)
+        result = _clip(array_ops.gather(params[0], ids, name=name),
+                       ids, max_norm)
         if transform_fn:
           result = transform_fn(result)
         return result
@@ -212,7 +193,7 @@ def _embedding_lookup_and_transform(params,
       for p in xrange(np):
         pids = gather_ids[p]
         with ops.colocate_with(params[p]):
-          result = _gather(params[p], pids)
+          result = array_ops.gather(params[p], pids)
           if transform_fn:
             # If transform_fn is provided, the clip_by_norm precedes
             # the transform and hence must be co-located. See below