From 620348fb6d045dc1f644925a3828ebb12de944d7 Mon Sep 17 00:00:00 2001 From: Alexandre Passos Date: Mon, 26 Feb 2018 10:24:56 -0800 Subject: [PATCH] Move accumulate_n_v2 to core. PiperOrigin-RevId: 187042001 --- tensorflow/contrib/framework/BUILD | 38 ------- .../framework/python/ops/accumulate_n_v2.py | 111 --------------------- tensorflow/python/kernel_tests/BUILD | 34 +++++++ .../kernel_tests/accumulate_n_eager_test.py} | 27 ++--- .../kernel_tests/accumulate_n_test.py} | 34 +++---- tensorflow/python/ops/math_ops.py | 81 ++++++++------- 6 files changed, 99 insertions(+), 226 deletions(-) delete mode 100644 tensorflow/contrib/framework/python/ops/accumulate_n_v2.py rename tensorflow/{contrib/framework/python/ops/accumulate_n_v2_eager_test.py => python/kernel_tests/accumulate_n_eager_test.py} (72%) rename tensorflow/{contrib/framework/python/ops/accumulate_n_v2_test.py => python/kernel_tests/accumulate_n_test.py} (79%) diff --git a/tensorflow/contrib/framework/BUILD b/tensorflow/contrib/framework/BUILD index dbdb5cf..1accb31 100644 --- a/tensorflow/contrib/framework/BUILD +++ b/tensorflow/contrib/framework/BUILD @@ -28,7 +28,6 @@ tf_custom_op_py_library( "python/framework/graph_util.py", "python/framework/tensor_util.py", "python/ops/__init__.py", - "python/ops/accumulate_n_v2.py", "python/ops/arg_scope.py", "python/ops/audio_ops.py", "python/ops/checkpoint_ops.py", @@ -161,23 +160,6 @@ py_test( ], ) -py_test( - name = "accumulate_n_v2_test", - size = "small", - srcs = ["python/ops/accumulate_n_v2_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":framework_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:gradients", - "//tensorflow/python:platform_test", - "//tensorflow/python:variables", - "//third_party/py/numpy", - ], -) - cuda_py_test( name = "critical_section_test", size = "medium", @@ -197,26 +179,6 @@ cuda_py_test( ) py_test( - name = "accumulate_n_v2_eager_test", - size = "small", - srcs = ["python/ops/accumulate_n_v2_eager_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":framework_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:gradients", - "//tensorflow/python:math_ops", - "//tensorflow/python:resource_variable_ops", - "//tensorflow/python/eager:backprop", - "//tensorflow/python/eager:context", - "//tensorflow/python/eager:tape", - "//third_party/py/numpy", - ], -) - -py_test( name = "ops_test", size = "small", srcs = ["python/ops/ops_test.py"], diff --git a/tensorflow/contrib/framework/python/ops/accumulate_n_v2.py b/tensorflow/contrib/framework/python/ops/accumulate_n_v2.py deleted file mode 100644 index 476528b..0000000 --- a/tensorflow/contrib/framework/python/ops/accumulate_n_v2.py +++ /dev/null @@ -1,111 +0,0 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Ops that will eventually be folded into tensorflow/python/ops/math_ops.py -""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - - -from tensorflow.python.eager import context -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import gen_math_ops -from tensorflow.python.ops import math_ops - - - -def accumulate_n_v2(inputs, shape=None, tensor_dtype=None, name=None): - """Returns the element-wise sum of a list of tensors. - - Optionally, pass `shape` and `tensor_dtype` for shape and type checking, - otherwise, these are inferred. - - `tf.accumulate_n_v2` performs the same operation as `tf.add_n`, but does not - wait for all of its inputs to be ready before beginning to sum. This can - save memory if inputs are ready at different times, since minimum temporary - storage is proportional to the output size rather than the inputs size. - - Unlike the original `accumulate_n`, `accumulate_n_v2` is differentiable. - - For example: - - ```python - a = tf.constant([[1, 2], [3, 4]]) - b = tf.constant([[5, 0], [0, 6]]) - tf.accumulate_n_v2([a, b, a]) # [[7, 4], [6, 14]] - - # Explicitly pass shape and type - tf.accumulate_n_v2([a, b, a], shape=[2, 2], tensor_dtype=tf.int32) - # [[7, 4], - # [6, 14]] - ``` - - Args: - inputs: A list of `Tensor` objects, each with same shape and type. - shape: Shape of elements of `inputs`. - tensor_dtype: The type of `inputs`. - name: A name for the operation (optional). - - Returns: - A `Tensor` of same shape and type as the elements of `inputs`. - - Raises: - ValueError: If `inputs` don't all have same shape and dtype or the shape - cannot be inferred. - """ - _INPUTS_ERR_MSG = ValueError("inputs must be a list of at least one Tensor" - "with the same dtype and shape") - if not inputs or not isinstance(inputs, (list, tuple)): - raise _INPUTS_ERR_MSG - inputs = ops.convert_n_to_tensor_or_indexed_slices(inputs) - if not all(isinstance(x, ops.Tensor) for x in inputs): - raise _INPUTS_ERR_MSG - if not all(x.dtype == inputs[0].dtype for x in inputs): - raise _INPUTS_ERR_MSG - if shape is not None: - shape = tensor_shape.as_shape(shape) - else: - shape = tensor_shape.unknown_shape() - for input_tensor in inputs: - if isinstance(input_tensor, ops.Tensor): - shape = shape.merge_with(input_tensor.get_shape()) - - # tensor_dtype is for safety only; operator's output type computed in C++ - if tensor_dtype is not None and tensor_dtype != inputs[0].dtype: - raise TypeError("tensor_dtype is {}, but input is of type {}" - .format(tensor_dtype, inputs[0].dtype)) - - if len(inputs) == 1 and name is None: - return inputs[0] - elif len(inputs) == 1 and name is not None: - return array_ops.identity(inputs[0], name=name) - elif context.in_eager_mode(): - # TemporaryVariable not currently supported in eager mode; fall back - # onto AddN for now. - # TODO(frreiss) remove this once the lifetime of eager variables gets - # addressed - return math_ops.add_n(inputs, name=name) - else: - return gen_math_ops._accumulate_nv2(inputs, name=name, shape=shape) - -# The following code should eventually be merged into -# tensorflow/python/ops/math_grad.py -@ops.RegisterGradient("AccumulateNV2") -def _AddNGrad(op, grad): - """Same as gradient for AddN. Copies the gradient to all inputs.""" - # Not broadcasting. - return [grad] * len(op.inputs) diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index d4ceb2e..c9aa4a2 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -2892,6 +2892,40 @@ tf_py_test( ], ) +tf_py_test( + name = "accumulate_n_test", + size = "small", + srcs = ["accumulate_n_test.py"], + additional_deps = [ + "//third_party/py/numpy", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:gradients", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:variables", + ], +) + +tf_py_test( + name = "accumulate_n_eager_test", + size = "small", + srcs = ["accumulate_n_eager_test.py"], + additional_deps = [ + "//third_party/py/numpy", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:gradients", + "//tensorflow/python:math_ops", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python/eager:backprop", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:tape", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/framework/python/ops/accumulate_n_v2_eager_test.py b/tensorflow/python/kernel_tests/accumulate_n_eager_test.py similarity index 72% rename from tensorflow/contrib/framework/python/ops/accumulate_n_v2_eager_test.py rename to tensorflow/python/kernel_tests/accumulate_n_eager_test.py index 35974b9..dc11b7d 100644 --- a/tensorflow/contrib/framework/python/ops/accumulate_n_v2_eager_test.py +++ b/tensorflow/python/kernel_tests/accumulate_n_eager_test.py @@ -12,48 +12,41 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for new version of accumulate_n op that will eventually go into -`ops.math_ops`. - -These test cases spefically exercise the `eager` APIs. They need to be in a -separate file from the remaining tests because eager mode is currently something -you can turn on but can't turn off for the lifetime of the current process.""" +"""Tests for new version of accumulate_n op.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import numpy as np -from tensorflow.contrib.framework.python.ops import accumulate_n_v2 as av2 - from tensorflow.python.eager import backprop from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.framework import test_util +from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import test - class AccumulateNV2EagerTest(test_util.TensorFlowTestCase): - """Tests of the new, differentiable version of accumulate_n""" + """Tests of the new, differentiable version of accumulate_n.""" def testMinimalEagerMode(self): forty = constant_op.constant(40) two = constant_op.constant(2) - answer = av2.accumulate_n_v2([forty, two]) + answer = math_ops.accumulate_n([forty, two]) self.assertEqual(42, answer.numpy()) - def testFloat(self): np.random.seed(12345) x = [np.random.random((1, 2, 3, 4, 5)) - 0.5 for _ in range(5)] tf_x = ops.convert_n_to_tensor(x) with self.test_session(use_gpu=True): - self.assertAllClose(sum(x), av2.accumulate_n_v2(tf_x).numpy()) - self.assertAllClose(x[0] * 5, av2.accumulate_n_v2([tf_x[0]] * 5).numpy()) + self.assertAllClose(sum(x), math_ops.accumulate_n(tf_x).numpy()) + self.assertAllClose(x[0] * 5, + math_ops.accumulate_n([tf_x[0]] * 5).numpy()) def testGrad(self): np.random.seed(42) @@ -65,16 +58,14 @@ class AccumulateNV2EagerTest(test_util.TensorFlowTestCase): ] def fn(first, second, third): - return av2.accumulate_n_v2([first, second, third]) + return math_ops.accumulate_n([first, second, third]) grad_fn = backprop.gradients_function(fn) grad = grad_fn(input_vars[0], input_vars[1], input_vars[2]) - self.assertAllEqual(np.repeat(1.0, num_inputs), # d/dx (x + y + ...) = 1 + self.assertAllEqual(np.repeat(1.0, num_inputs), # d/dx (x + y + ...) = 1 [elem.numpy() for elem in grad]) - if __name__ == "__main__": ops.enable_eager_execution() test.main() - diff --git a/tensorflow/contrib/framework/python/ops/accumulate_n_v2_test.py b/tensorflow/python/kernel_tests/accumulate_n_test.py similarity index 79% rename from tensorflow/contrib/framework/python/ops/accumulate_n_v2_test.py rename to tensorflow/python/kernel_tests/accumulate_n_test.py index 4596209..0a6d4ae 100644 --- a/tensorflow/contrib/framework/python/ops/accumulate_n_v2_test.py +++ b/tensorflow/python/kernel_tests/accumulate_n_test.py @@ -12,42 +12,42 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for new version of accumulate_n op that will eventually go into -`ops.math_ops`.""" +"""Tests for new version of accumulate_n op.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import numpy as np -from tensorflow.contrib.framework.python.ops import accumulate_n_v2 as av2 - from tensorflow.python.framework import dtypes as dtypes_lib from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import gradients +from tensorflow.python.ops import math_ops from tensorflow.python.ops import variables from tensorflow.python.platform import googletest class AccumulateNV2Test(test_util.TensorFlowTestCase): - """Tests of the new, differentiable version of accumulate_n""" + """Tests of the new, differentiable version of accumulate_n.""" def testFloat(self): np.random.seed(12345) x = [np.random.random((1, 2, 3, 4, 5)) - 0.5 for _ in range(5)] tf_x = ops.convert_n_to_tensor(x) with self.test_session(use_gpu=True): - self.assertAllClose(sum(x), av2.accumulate_n_v2(tf_x).eval()) - self.assertAllClose(x[0] * 5, av2.accumulate_n_v2([tf_x[0]] * 5).eval()) + self.assertAllClose(sum(x), math_ops.accumulate_n(tf_x).eval()) + self.assertAllClose(x[0] * 5, + math_ops.accumulate_n([tf_x[0]] * 5).eval()) def testInt(self): np.random.seed(54321) x = [np.random.randint(-128, 128, (5, 4, 3, 2, 1)) for _ in range(6)] tf_x = ops.convert_n_to_tensor(x) with self.test_session(use_gpu=True): - self.assertAllEqual(sum(x), av2.accumulate_n_v2(tf_x).eval()) - self.assertAllEqual(x[0] * 6, av2.accumulate_n_v2([tf_x[0]] * 6).eval()) + self.assertAllEqual(sum(x), math_ops.accumulate_n(tf_x).eval()) + self.assertAllEqual(x[0] * 6, + math_ops.accumulate_n([tf_x[0]] * 6).eval()) def testGrad(self): np.random.seed(42) @@ -55,9 +55,9 @@ class AccumulateNV2Test(test_util.TensorFlowTestCase): with self.test_session(use_gpu=True) as sess: input_vars = [ variables.Variable(10.0 * np.random.random()) - for i in range(0, num_inputs) + for _ in range(0, num_inputs) ] - accum_n = av2.accumulate_n_v2(input_vars) + accum_n = math_ops.accumulate_n(input_vars) sess.run(variables.global_variables_initializer()) accum_n_grad = gradients.gradients(accum_n, input_vars) self.assertAllEqual( @@ -77,7 +77,7 @@ class AccumulateNV2Test(test_util.TensorFlowTestCase): ops.convert_to_tensor(x, dtype=dtypes_lib.float32) for x in random_arrays ] - tf_val = av2.accumulate_n_v2(random_tensors) + tf_val = math_ops.accumulate_n(random_tensors) np_val = random_arrays[0] for random_array in random_arrays[1:]: np_val += random_array @@ -86,7 +86,7 @@ class AccumulateNV2Test(test_util.TensorFlowTestCase): def testZeroArgs(self): with self.test_session(): with self.assertRaises(ValueError): - tf_val = av2.accumulate_n_v2([]) + tf_val = math_ops.accumulate_n([]) tf_val.eval() def testWrongShape(self): @@ -94,28 +94,28 @@ class AccumulateNV2Test(test_util.TensorFlowTestCase): with self.assertRaises(ValueError): a = variables.Variable(0.2) b = variables.Variable(0.1) - tf_val = av2.accumulate_n_v2([a, b], shape=[2, 2]) # Should be shape=[] + math_ops.accumulate_n([a, b], shape=[2, 2]) # Should be shape=[] def testIncompatibleShapes(self): with self.test_session(): with self.assertRaises(ValueError): a = variables.Variable(np.array([0.1, 0.2])) b = variables.Variable(np.array([[0.3], [0.4]])) - tf_val = av2.accumulate_n_v2([a, b]) + math_ops.accumulate_n([a, b]) def testWrongType(self): with self.test_session(): with self.assertRaises(TypeError): a = variables.Variable(0.2, dtype=np.float32) b = variables.Variable(0.1, dtype=np.float32) - tf_val = av2.accumulate_n_v2([a, b], tensor_dtype=np.int32) + math_ops.accumulate_n([a, b], tensor_dtype=np.int32) def testWrongTypeOneInput(self): # Scenario that used to trigger a bug, even when testWrongType() worked with self.test_session(): with self.assertRaises(TypeError): a = variables.Variable(0.2, dtype=np.float32) - tf_val = av2.accumulate_n_v2([a], tensor_dtype=np.int32) + math_ops.accumulate_n([a], tensor_dtype=np.int32) if __name__ == "__main__": diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index 2ae8b61..ed11fe5 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -161,14 +161,11 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops -from tensorflow.python.ops import gen_control_flow_ops from tensorflow.python.ops import gen_data_flow_ops from tensorflow.python.ops import gen_math_ops from tensorflow.python.ops import gen_nn_ops from tensorflow.python.ops import gen_sparse_ops from tensorflow.python.ops import gen_spectral_ops -from tensorflow.python.ops import gen_state_ops -from tensorflow.python.ops import state_ops # go/tf-wildcard-import # pylint: disable=wildcard-import from tensorflow.python.ops.gen_math_ops import * @@ -2218,14 +2215,12 @@ def accumulate_n(inputs, shape=None, tensor_dtype=None, name=None): Optionally, pass `shape` and `tensor_dtype` for shape and type checking, otherwise, these are inferred. - NOTE: This operation is not differentiable and cannot be used if inputs depend - on trainable variables. Please use `tf.add_n` for such cases. + `tf.accumulate_n` performs the same operation as `tf.add_n`, but does not + wait for all of its inputs to be ready before beginning to sum. This can + save memory if inputs are ready at different times, since minimum temporary + storage is proportional to the output size rather than the inputs size. - Aside from differentiability, `tf.accumulate_n` performs the same operation as - `tf.add_n`, but does not wait for all of its inputs to be ready before - beginning to sum. This can save memory if inputs are ready at different times, - since minimum temporary storage is proportional to the output size rather than - the inputs size. + `accumulate_n` is differentiable (but wasn't previous to TensorFlow 1.7). For example: @@ -2235,8 +2230,9 @@ def accumulate_n(inputs, shape=None, tensor_dtype=None, name=None): tf.accumulate_n([a, b, a]) # [[7, 4], [6, 14]] # Explicitly pass shape and type - tf.accumulate_n([a, b, a], shape=[2, 2], tensor_dtype=tf.int32) # [[7, 4], - # [6, 14]] + tf.accumulate_n([a, b, a], shape=[2, 2], tensor_dtype=tf.int32) + # [[7, 4], + # [6, 14]] ``` Args: @@ -2252,20 +2248,17 @@ def accumulate_n(inputs, shape=None, tensor_dtype=None, name=None): ValueError: If `inputs` don't all have same shape and dtype or the shape cannot be inferred. """ - if context.in_eager_mode(): - # TODO(apassos) remove this once the lifetime of eager variables gets - # addressed. - raise ValueError("accumulate_n not supported in eager mode") + def _input_error(): + return ValueError( + "inputs must be a list of at least one Tensor with the " + "same dtype and shape") if not inputs or not isinstance(inputs, (list, tuple)): - raise ValueError("inputs must be a list of at least one Tensor with the " - "same dtype and shape") + raise _input_error() inputs = ops.convert_n_to_tensor_or_indexed_slices(inputs) if not all(isinstance(x, ops.Tensor) for x in inputs): - raise ValueError("inputs must be a list of at least one Tensor with the " - "same dtype and shape") + raise _input_error() if not all(x.dtype == inputs[0].dtype for x in inputs): - raise ValueError("inputs must be a list of at least one Tensor with the " - "same dtype and shape") + raise _input_error() if shape is not None: shape = tensor_shape.as_shape(shape) else: @@ -2273,27 +2266,31 @@ def accumulate_n(inputs, shape=None, tensor_dtype=None, name=None): for input_tensor in inputs: if isinstance(input_tensor, ops.Tensor): shape = shape.merge_with(input_tensor.get_shape()) - if tensor_dtype is None: - tensor_dtype = inputs[0].dtype - if tensor_dtype != inputs[0].dtype: - raise TypeError("tensor_dtype is {}, but input is of type {}".format( - tensor_dtype, inputs[0].dtype)) - if len(inputs) == 1: + + # tensor_dtype is for safety only; operator's output type computed in C++ + if tensor_dtype is not None and tensor_dtype != inputs[0].dtype: + raise TypeError("tensor_dtype is {}, but input is of type {}" + .format(tensor_dtype, inputs[0].dtype)) + + if len(inputs) == 1 and name is None: return inputs[0] - with ops.name_scope(name, "AccumulateN", inputs) as name: - var = gen_state_ops._temporary_variable( - shape=tensor_shape.vector(0), dtype=tensor_dtype) - with ops.colocate_with(var): - zeros = array_ops.zeros_like(gen_control_flow_ops._merge(inputs)[0]) - zeros.set_shape(shape) - ref = state_ops.assign(var, zeros, validate_shape=False) - update_ops = [ - state_ops.assign_add(ref, input_tensor, use_locking=True) - for input_tensor in inputs - ] - with ops.control_dependencies(update_ops): - return gen_state_ops._destroy_temporary_variable( - ref, var_name=var.op.name, name=name) + elif len(inputs) == 1 and name is not None: + return array_ops.identity(inputs[0], name=name) + elif context.in_eager_mode(): + # TemporaryVariable not currently supported in eager mode; fall back + # onto AddN for now. + # TODO(frreiss) remove this once the lifetime of eager variables gets + # addressed + return add_n(inputs, name=name) + else: + return gen_math_ops._accumulate_nv2(inputs, name=name, shape=shape) # pylint: disable=protected-access + + +@ops.RegisterGradient("AccumulateNV2") +def _accumulate_n_grad(op, grad): + """Same as gradient for AddN. Copies the gradient to all inputs.""" + # Not broadcasting. + return [grad] * len(op.inputs) @tf_export("nn.sigmoid", "sigmoid") -- 2.7.4