From a4dbc33512adb3705345b093a0aafec151e7e32d Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Fri, 27 Apr 2018 14:28:12 -0700
Subject: [PATCH] If two identical functions are given different grad func,
 they should be named differently. Otherwise, tf.gradients gets confused.

PiperOrigin-RevId: 194593519
---
 tensorflow/python/framework/function.py      |  37 ++++++---
 tensorflow/python/framework/function_test.py | 114 ++++++++++++++++++++-------
 tensorflow/python/ops/gradients_impl.py      |  32 ++++----
 3 files changed, 129 insertions(+), 54 deletions(-)

diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py
index 2432ab3..e7f9e59 100644
--- a/tensorflow/python/framework/function.py
+++ b/tensorflow/python/framework/function.py
@@ -353,8 +353,10 @@ class _DefinedFunction(object):
       raise ValueError("Function can not return None.")
     # Ensures each output is a Tensor in the function graph.
     outputs = [ops.convert_to_tensor(t) for t in outputs]
-    outputs = [temp_graph.capture(t) if t.graph is not temp_graph else t
-               for t in outputs]
+    outputs = [
+        temp_graph.capture(t) if t.graph is not temp_graph else t
+        for t in outputs
+    ]
     self._extra_inputs = temp_graph.extra_inputs
     inputs.extend(temp_graph.extra_args)
     # pylint: disable=protected-access
@@ -362,9 +364,13 @@ class _DefinedFunction(object):
     # pylint: enable=protected-access

     # Extra kwargs are treated as attrs on the function def.
-    base_func_name = self._func_name or _get_func_name(self._func)
-    kwargs_attr = _parse_kwargs_as_attrs(base_func_name,
-                                         **self._extra_kwargs)
+    if self._func_name:
+      base_func_name = self._func_name
+    else:
+      base_func_name = _get_func_name(self._func)
+      if self._grad_func:
+        base_func_name += ("_%s" % self._grad_func.name)
+    kwargs_attr = _parse_kwargs_as_attrs(base_func_name, **self._extra_kwargs)

     if not temp_graph._c_graph:  # pylint: disable=protected-access
       # Build the FunctionDef
@@ -503,6 +509,12 @@ class _DefinedFunction(object):
     self.add_to_graph(ops.get_default_graph())
     args = [ops.convert_to_tensor(_) for _ in args] + self._extra_inputs
     ret, op = _call(self._signature, *args, **kwargs)
+
+    # Set a hidden attr in 'op' so that gradients_impl can refer back
+    # to this _DefinedFunction instance to access python_grad_func.
+    assert isinstance(op, ops.Operation)
+    setattr(op, "__defun", self)
+
     if self._shape_func is not None:
       shapes = self._shape_func(op)
       if len(shapes) != len(op.outputs):
@@ -591,12 +603,11 @@ class _OverloadedFunction(object):
       # _OverloadedFunction. We need to instantiate it with the
      # right input types.
       output_types = [
-          dtypes.DType(_.type)
-          for _ in defined._signature.output_arg  # pylint: disable=protected-access
+          dtypes.DType(_.type) for _ in defined._signature.output_arg  # pylint: disable=protected-access
       ]
       # pylint: disable=protected-access
-      defined._grad_func = self._grad_func.instantiate(
-          input_types + output_types)
+      defined._grad_func = self._grad_func.instantiate(input_types +
+                                                       output_types)
       # pylint: enable=protected-access
       self._overload[key] = defined
     return defined
@@ -833,8 +844,8 @@ def _call(sig, *inputs, **kwargs):
     ValueError: if the arguments are invalid.
   """
   if len(inputs) != len(sig.input_arg):
-    raise ValueError("Expected number of arguments: %d, received: %d" %
-                     (len(sig.input_arg), len(inputs)))
+    raise ValueError("Expected number of arguments: %d, received: %d" % (len(
+        sig.input_arg), len(inputs)))
   name = kwargs.pop("name", None)
   g = ops.get_default_graph()
   func_name = sig.name
@@ -950,8 +961,8 @@ def _from_library(lib):
       fdef for fdef in lib.function if func_to_grad[fdef.signature.name] is None
   ]
   if not ready:
-    raise ValueError("FunctionDefLibrary contains cyclic gradient functions!\n"
-                     + str(lib))
+    raise ValueError(
+        "FunctionDefLibrary contains cyclic gradient functions!\n" + str(lib))

   # function name -> _DefinedFunction
   initialized = {}
diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py
index 594596e..a5c19f1 100644
--- a/tensorflow/python/framework/function_test.py
+++ b/tensorflow/python/framework/function_test.py
@@ -136,7 +136,8 @@ class FunctionTest(test.TestCase):
   def testTooManyOutputNames(self):

     @function.Defun(
-        dtypes.float32, func_name="MyIdentity",
+        dtypes.float32,
+        func_name="MyIdentity",
         out_names=["my_result1", "my_result2"])
     def MyIdentityFunc(a):
       return a
@@ -239,10 +240,11 @@ class FunctionTest(test.TestCase):

       inp = np.array([-1, 1, 2, -2], dtype=np.float32)
       feed = {x: inp}
-      cfg = config_pb2.ConfigProto(graph_options=config_pb2.GraphOptions(
-          optimizer_options=config_pb2.OptimizerOptions(
-              opt_level=config_pb2.OptimizerOptions.L1,
-              do_function_inlining=True)))
+      cfg = config_pb2.ConfigProto(
+          graph_options=config_pb2.GraphOptions(
+              optimizer_options=config_pb2.OptimizerOptions(
+                  opt_level=config_pb2.OptimizerOptions.L1,
+                  do_function_inlining=True)))
       with session.Session(graph=g, config=cfg) as sess:
         out, = sess.run(dx, feed)
         self.assertAllClose(1 - np.square(np.tanh(inp)), out)
@@ -334,18 +336,20 @@ class FunctionTest(test.TestCase):
       y = Foo(x)
       dx, = gradients_impl.gradients(y, [x])

-      cfg = config_pb2.ConfigProto(graph_options=config_pb2.GraphOptions(
-          optimizer_options=config_pb2.OptimizerOptions(
-              opt_level=config_pb2.OptimizerOptions.L0,
-              do_common_subexpression_elimination=True,
-              do_function_inlining=True,
-              do_constant_folding=True)))
+      cfg = config_pb2.ConfigProto(
+          graph_options=config_pb2.GraphOptions(
+              optimizer_options=config_pb2.OptimizerOptions(
+                  opt_level=config_pb2.OptimizerOptions.L0,
+                  do_common_subexpression_elimination=True,
+                  do_function_inlining=True,
+                  do_constant_folding=True)))
       with self.test_session(graph=g, config=cfg):
         self.assertAllClose(y.eval(), 6.)
         self.assertAllClose(dx.eval(), 2.)


   def _testZNoDepOnY(self, use_const_grad_ys):
+
     @function.Defun(dtypes.float32, dtypes.float32)
     def Foo(x, y):  # pylint: disable=unused-argument
       return x * 2
@@ -775,9 +779,9 @@ class FunctionTest(test.TestCase):

       @function.Defun()
       def Foo():
-        return control_flow_ops.while_loop(lambda i: i < 10,
-                                           lambda i: i + x,
+        return control_flow_ops.while_loop(lambda i: i < 10, lambda i: i + x,
                                            [0])
+
       y = Foo()

       with self.test_session(graph=g) as sess:
@@ -790,9 +794,8 @@ class FunctionTest(test.TestCase):

       @function.Defun(dtypes.bool)
       def Foo(pred):
-        return control_flow_ops.cond(pred,
-                                     lambda: x,
-                                     lambda: x + 1)
+        return control_flow_ops.cond(pred, lambda: x, lambda: x + 1)
+
       y = Foo(True)
       z = Foo(False)

@@ -945,6 +948,7 @@ class FunctionTest(test.TestCase):
     self.assertEqual(len(f.signature.input_arg), 3)

   def testGradientWithIntegerFunctionArgument(self):
+
     @function.Defun(dtypes.int32, dtypes.float32)
     def Foo(t, x):
       return x[t]
@@ -959,8 +963,7 @@ class FunctionTest(test.TestCase):
     x = np.zeros((2,)).astype(np.float32)
     with session.Session(graph=g) as sess:
       self.assertAllClose(
-          np.array([1.0, 0.0]).astype(np.float32),
-          sess.run(dinp, {inp: x}))
+          np.array([1.0, 0.0]).astype(np.float32), sess.run(dinp, {inp: x}))

   def testFunctionMarkedStateful(self):

@@ -1073,6 +1076,60 @@ class FunctionTest(test.TestCase):
       sess.run(var.initializer)
       _ = sess.run(CapturesGuaranteedConst(), {also_not_const: 1.0})

+  def testSameFunctionDifferentGrads(self):
+
+    def PartOne(x):
+
+      # Default grad is dx = dy * 2
+      @function.Defun(dtypes.float32)
+      def Foo(x):
+        return x * 2
+
+      return Foo(x)
+
+    def PartTwo(x):
+
+      @function.Defun(dtypes.float32, dtypes.float32)
+      def Bar(x, dy):
+        return x + dy  # crazy backprop
+
+      @function.Defun(dtypes.float32, grad_func=Bar)
+      def Foo(x):
+        return x * 2
+
+      return Foo(x)
+
+    def PartThree(x):
+
+      def Bar(op, dy):
+        return op.inputs[0] * dy / 2  # crazy backprop
+
+      @function.Defun(dtypes.float32, python_grad_func=Bar)
+      def Foo(x):
+        return x * 2
+
+      return Foo(x)
+
+    g = ops.Graph()
+    with g.as_default():
+      x = constant_op.constant(100.)
+      x0 = x
+      y0 = PartOne(x0)
+      dx0, = gradients_impl.gradients(ys=[y0], xs=[x0])
+      x1 = x
+      y1 = PartTwo(x1)
+      dx1, = gradients_impl.gradients(ys=[y1], xs=[x1])
+      x2 = x
+      y2 = PartThree(x2)
+      dx2, = gradients_impl.gradients(ys=[y2], xs=[x2])
+
+    with self.test_session(graph=g) as sess:
+      v0, v1, v2 = sess.run([dx0, dx1, dx2])
+
+    self.assertAllEqual(v0, 2.)
+    self.assertAllEqual(v1, 101.)
+    self.assertAllEqual(v2, 50.)
+

 @test_util.with_c_shapes
 class FunctionsFromProtos(test.TestCase):
@@ -1271,9 +1328,10 @@ class FunctionsFromProtos(test.TestCase):
     @function.Defun(dtypes.int32, experimental_tag="tag_value")
     def FunctionWithAttr(i):
       return array_ops.identity(i)
+
     self.assertTrue("experimental_tag" in FunctionWithAttr.definition.attr)
-    self.assertEqual(
-        FunctionWithAttr.definition.attr["experimental_tag"].s, b"tag_value")
+    self.assertEqual(FunctionWithAttr.definition.attr["experimental_tag"].s,
+                     b"tag_value")


 @test_util.with_c_shapes
@@ -1401,7 +1459,8 @@ class UnrollLSTMTest(test.TestCase):
       return Loop(cell, weights, inp)

    cell = function.Defun(dtypes.float32, dtypes.float32, dtypes.float32,
-                         dtypes.float32)(cell)
+                         dtypes.float32)(
+                             cell)
    if mode == "cell":
      # Just represent the LSTM as a function.
      return Loop(cell, weights, inp)
@@ -1500,12 +1559,13 @@ class FunctionInlineControlTest(test.TestCase):

  def testFoo(self):
    dtype = dtypes.float32
-    cfg = config_pb2.ConfigProto(graph_options=config_pb2.GraphOptions(
-        optimizer_options=config_pb2.OptimizerOptions(
-            opt_level=config_pb2.OptimizerOptions.L0,
-            do_common_subexpression_elimination=True,
-            do_function_inlining=True,
-            do_constant_folding=True)))
+    cfg = config_pb2.ConfigProto(
+        graph_options=config_pb2.GraphOptions(
+            optimizer_options=config_pb2.OptimizerOptions(
+                opt_level=config_pb2.OptimizerOptions.L0,
+                do_common_subexpression_elimination=True,
+                do_function_inlining=True,
+                do_constant_folding=True)))
    cell_func_call_pattern = re.compile(r"Cell[^/]*\(")
    for noinline in [False, True]:

diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py
index 581ba7d..1448151 100644
--- a/tensorflow/python/ops/gradients_impl.py
+++ b/tensorflow/python/ops/gradients_impl.py
@@ -256,21 +256,21 @@ def _DefaultGradYs(grad_ys,
         continue
       if y.dtype.is_floating or y.dtype.is_integer:
         if not grad_y.dtype.is_floating and not grad_y.dtype.is_integer:
-          raise TypeError("Gradient type %s generated for real or "
-                          "integer-valued tensor %s with type %s must be "
-                          "real or integer" %
-                          (dtypes.as_dtype(grad_y.dtype).name, y,
-                           dtypes.as_dtype(y.dtype).name))
+          raise TypeError(
+              "Gradient type %s generated for real or "
+              "integer-valued tensor %s with type %s must be "
+              "real or integer" % (dtypes.as_dtype(grad_y.dtype).name, y,
+                                   dtypes.as_dtype(y.dtype).name))
       elif y.dtype.is_complex:
         if not grad_y.dtype.is_complex:
-          raise TypeError("Gradient type %s generated for complex-valued "
-                          "tensor %s with type %s must be real" %
-                          (dtypes.as_dtype(grad_y.dtype).name, y,
-                           dtypes.as_dtype(y.dtype).name))
+          raise TypeError(
+              "Gradient type %s generated for complex-valued "
+              "tensor %s with type %s must be real" % (dtypes.as_dtype(
+                  grad_y.dtype).name, y, dtypes.as_dtype(y.dtype).name))
       else:
-        raise TypeError("Tensor %s with type %s must be numeric "
-                        "to obtain a default gradient" %
-                        (y, dtypes.as_dtype(y.dtype).name))
+        raise TypeError(
+            "Tensor %s with type %s must be numeric "
+            "to obtain a default gradient" % (y, dtypes.as_dtype(y.dtype).name))
       # Create a grad_y tensor in the name scope of the gradient.
       # Required for TensorArrays to identify which gradient call a
       # grad_y value is coming from.
@@ -605,15 +605,19 @@ def _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops,
           loop_state.ExitGradWhileContext(op, before=True)

         grad_fn = None
-        # pylint: disable=protected-access
         func_call = None
+        # pylint: disable=protected-access
         is_func_call = ops.get_default_graph()._is_function(op.type)
+        # pylint: enable=protected-access
         has_out_grads = any(isinstance(g, ops.Tensor) or g for g in out_grads)
         if has_out_grads and (op._id not in stop_ops):
           if is_func_call:
             func_call = ops.get_default_graph()._get_function(op.type)
+            # Note that __defun is not set if the graph is
+            # imported. If it's set, we prefer to access the original
+            # defun.
+            func_call = getattr(op, "__defun", func_call)
             grad_fn = func_call.python_grad_func
-            # pylint: enable=protected-access
           else:
             # A grad_fn must be defined, either as a function or as None
             # for ops that do not have gradients.
-- 
2.7.4
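
For context, a minimal standalone sketch of the situation this patch guards against, lifted from the new testSameFunctionDifferentGrads test (its PartOne/PartTwo helpers): two Defun-wrapped functions with the same name and identical bodies, where only one is given a custom grad_func. The expected values (2. and 101.) come directly from that test; the standalone-script framing here — the imports, top-level tf.gradients calls, and tf.Session — is an illustrative assumption and not part of the patch:

    import tensorflow as tf
    from tensorflow.python.framework import function

    def PartOne(x):
      # Identical body to PartTwo's Foo, but uses the default gradient
      # of x * 2, i.e. dx = 2 * dy.
      @function.Defun(tf.float32)
      def Foo(x):
        return x * 2

      return Foo(x)

    def PartTwo(x):
      @function.Defun(tf.float32, tf.float32)
      def Bar(x, dy):  # custom gradient: dx = x + dy
        return x + dy

      @function.Defun(tf.float32, grad_func=Bar)
      def Foo(x):      # same name and body as PartOne's Foo
        return x * 2

      return Foo(x)

    x = tf.constant(100.)
    dx0, = tf.gradients([PartOne(x)], [x])  # expected 2.   (default grad)
    dx1, = tf.gradients([PartTwo(x)], [x])  # expected 101. (custom grad)

    with tf.Session() as sess:
      print(sess.run([dx0, dx1]))  # [2.0, 101.0]

Because the two Foo definitions are otherwise identical, they could previously be canonicalized to the same function name, letting tf.gradients resolve one of the calls to the wrong gradient; appending the grad function's name to the derived name (base_func_name += "_%s" % self._grad_func.name) keeps the two FunctionDefs distinct.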