From c83a54adcface7d4bb666d7c4fd3968ba980a50d Mon Sep 17 00:00:00 2001 From: Alexandre Passos Date: Mon, 26 Mar 2018 15:39:54 -0700 Subject: [PATCH] Makes tf.gather not silently snapshot resource variables. PiperOrigin-RevId: 190537320 --- .../python/kernel_tests/attention_wrapper_test.py | 29 +++++++++++++--------- tensorflow/python/ops/array_ops.py | 17 ++++++++----- tensorflow/python/ops/embedding_ops.py | 29 ++++------------------ 3 files changed, 33 insertions(+), 42 deletions(-) diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py index c4139dd..07b3ad7 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py @@ -785,26 +785,31 @@ class AttentionWrapperTest(test.TestCase): wrapper.BahdanauAttention, wrapper.LuongAttention) expected_final_output = BasicDecoderOutput( - rnn_output=ResultSummary( - shape=(5, 3, 20), dtype=dtype('float32'), mean=0.11798714846372604), - sample_id=ResultSummary( - shape=(5, 3), dtype=dtype('int32'), mean=7.933333333333334)) + rnn_output=ResultSummary(shape=(5, 3, 20), + dtype=dtype('float32'), + mean=0.11723966), + sample_id=ResultSummary(shape=(5, 3), + dtype=dtype('int32'), + mean=9.2666666666666675)) expected_final_state = AttentionWrapperState( cell_state=LSTMStateTuple( - c=ResultSummary( - shape=(5, 9), dtype=dtype('float32'), mean=-0.0036486709), - h=ResultSummary( - shape=(5, 9), dtype=dtype('float32'), mean=-0.0018835809)), - attention=ResultSummary( - shape=(5, 20), dtype=dtype('float32'), mean=0.11798714846372604), + c=ResultSummary(shape=(5, 9), + dtype=dtype('float32'), + mean=-0.003545674), + h=ResultSummary(shape=(5, 9), + dtype=dtype('float32'), + mean=-0.0018327223)), + attention=ResultSummary(shape=(5, 20), + dtype=dtype('float32'), + mean=0.11728073), time=3, alignments=( ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125), ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125)), + alignment_history=(), attention_state=( ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125), - ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125)), - alignment_history=()) + ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125))) expected_final_alignment_history = ( ResultSummary(shape=(3, 5, 8), dtype=dtype('float32'), mean=0.125), ResultSummary(shape=(3, 5, 8), dtype=dtype('float32'), mean=0.125)) diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index ec7c14f..9106461 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -2691,12 +2691,17 @@ reverse_sequence.__doc__ = deprecation.rewrite_argument_docstring( @tf_export("gather") def gather(params, indices, validate_indices=None, name=None, axis=0): - # TODO(rjryan): Remove "Gather" creation in favor of GatherV2 once the forward - # compatibility 3 week period has passed. - if axis == 0: - return gen_array_ops.gather( - params, indices, validate_indices=validate_indices, name=name) - else: + del validate_indices + if axis != 0: + # Note that we do a sparse_read here to avoid snapshotting the entire + # resource variable and doing a gather, which can be inefficient and lead to + # subtle race conditions. TODO(apassos) implement axis != 0 on sparse_read + return gen_array_ops.gather_v2(params, indices, axis, name=name) + try: + # TODO(apassos) find a less bad way of detecting resource variables without + # introducing a circular dependency. + return params.sparse_read(indices, name=name) + except AttributeError: return gen_array_ops.gather_v2(params, indices, axis, name=name) diff --git a/tensorflow/python/ops/embedding_ops.py b/tensorflow/python/ops/embedding_ops.py index 20e4a28..f0120f2 100644 --- a/tensorflow/python/ops/embedding_ops.py +++ b/tensorflow/python/ops/embedding_ops.py @@ -35,34 +35,14 @@ from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util.tf_export import tf_export -def _gather(params, ids, name=None): - """Helper function for _embedding_lookup_and_transform. - - This function gathers embeddings from a single tensor. The gather deals with - resource variables specially. - - Args: - params: A `Tensor` of embeddings. - ids: A `Tensor` indexing the embeddings to be retrieved from `params`. - name: A name for the operation (optional). - - Returns: - A `Tensor` with the same type as `params`. - """ - if isinstance(params, resource_variable_ops.ResourceVariable): - return params.sparse_read(ids, name=name) - else: - return array_ops.gather(params, ids, name=name) - - def _clip(params, ids, max_norm): """Helper function for _embedding_lookup_and_transform. This function optionally clips embeddings to an l2-norm of max_norm. Args: - params: A `Tensor` of embeddings retrieved by `_gather`. - ids: The `ids` argument that was passed to `_gather`. + params: A `Tensor` of embeddings retrieved by `gather`. + ids: The `ids` argument that was passed to `gather`. max_norm: If provided, the embeddings are l2-normalized to the value of max_norm. @@ -148,7 +128,8 @@ def _embedding_lookup_and_transform(params, ids = ops.convert_to_tensor(ids, name="ids") if np == 1 and (not transform_fn or ids.get_shape().ndims == 1): with ops.colocate_with(params[0]): - result = _clip(_gather(params[0], ids, name=name), ids, max_norm) + result = _clip(array_ops.gather(params[0], ids, name=name), + ids, max_norm) if transform_fn: result = transform_fn(result) return result @@ -212,7 +193,7 @@ def _embedding_lookup_and_transform(params, for p in xrange(np): pids = gather_ids[p] with ops.colocate_with(params[p]): - result = _gather(params[p], pids) + result = array_ops.gather(params[p], pids) if transform_fn: # If transform_fn is provided, the clip_by_norm precedes # the transform and hence must be co-located. See below -- 2.7.4