If two identical functions are given different grad func,
authorA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 27 Apr 2018 21:28:12 +0000 (14:28 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 27 Apr 2018 21:31:07 +0000 (14:31 -0700)
they should be named differently. Otherwise, tf.gradients
gets confused.

PiperOrigin-RevId: 194593519

tensorflow/python/framework/function.py
tensorflow/python/framework/function_test.py
tensorflow/python/ops/gradients_impl.py

index 2432ab3..e7f9e59 100644 (file)
@@ -353,8 +353,10 @@ class _DefinedFunction(object):
           raise ValueError("Function can not return None.")
       # Ensures each output is a Tensor in the function graph.
       outputs = [ops.convert_to_tensor(t) for t in outputs]
-      outputs = [temp_graph.capture(t) if t.graph is not temp_graph else t
-                 for t in outputs]
+      outputs = [
+          temp_graph.capture(t) if t.graph is not temp_graph else t
+          for t in outputs
+      ]
     self._extra_inputs = temp_graph.extra_inputs
     inputs.extend(temp_graph.extra_args)
     # pylint: disable=protected-access
@@ -362,9 +364,13 @@ class _DefinedFunction(object):
     # pylint: enable=protected-access
 
     # Extra kwargs are treated as attrs on the function def.
-    base_func_name = self._func_name or _get_func_name(self._func)
-    kwargs_attr = _parse_kwargs_as_attrs(base_func_name,
-                                         **self._extra_kwargs)
+    if self._func_name:
+      base_func_name = self._func_name
+    else:
+      base_func_name = _get_func_name(self._func)
+      if self._grad_func:
+        base_func_name += ("_%s" % self._grad_func.name)
+    kwargs_attr = _parse_kwargs_as_attrs(base_func_name, **self._extra_kwargs)
 
     if not temp_graph._c_graph:  # pylint: disable=protected-access
       # Build the FunctionDef
@@ -503,6 +509,12 @@ class _DefinedFunction(object):
     self.add_to_graph(ops.get_default_graph())
     args = [ops.convert_to_tensor(_) for _ in args] + self._extra_inputs
     ret, op = _call(self._signature, *args, **kwargs)
+
+    # Set a hidden attr in 'op' so that gradients_impl can refer back
+    # to this _DefinedFunction instance to access python_grad_func.
+    assert isinstance(op, ops.Operation)
+    setattr(op, "__defun", self)
+
     if self._shape_func is not None:
       shapes = self._shape_func(op)
       if len(shapes) != len(op.outputs):
@@ -591,12 +603,11 @@ class _OverloadedFunction(object):
         # _OverloadedFunction. We need to instantiate it with the
         # right input types.
         output_types = [
-            dtypes.DType(_.type)
-            for _ in defined._signature.output_arg  # pylint: disable=protected-access
+            dtypes.DType(_.type) for _ in defined._signature.output_arg  # pylint: disable=protected-access
         ]
         # pylint: disable=protected-access
-        defined._grad_func = self._grad_func.instantiate(
-            input_types + output_types)
+        defined._grad_func = self._grad_func.instantiate(input_types +
+                                                         output_types)
         # pylint: enable=protected-access
       self._overload[key] = defined
     return defined
@@ -833,8 +844,8 @@ def _call(sig, *inputs, **kwargs):
     ValueError: if the arguments are invalid.
   """
   if len(inputs) != len(sig.input_arg):
-    raise ValueError("Expected number of arguments: %d, received: %d" %
-                     (len(sig.input_arg), len(inputs)))
+    raise ValueError("Expected number of arguments: %d, received: %d" % (len(
+        sig.input_arg), len(inputs)))
   name = kwargs.pop("name", None)
   g = ops.get_default_graph()
   func_name = sig.name
@@ -950,8 +961,8 @@ def _from_library(lib):
       fdef for fdef in lib.function if func_to_grad[fdef.signature.name] is None
   ]
   if not ready:
-    raise ValueError("FunctionDefLibrary contains cyclic gradient functions!\n"
-                     + str(lib))
+    raise ValueError(
+        "FunctionDefLibrary contains cyclic gradient functions!\n" + str(lib))
   # function name -> _DefinedFunction
   initialized = {}
 
index 594596e..a5c19f1 100644 (file)
@@ -136,7 +136,8 @@ class FunctionTest(test.TestCase):
   def testTooManyOutputNames(self):
 
     @function.Defun(
-        dtypes.float32, func_name="MyIdentity",
+        dtypes.float32,
+        func_name="MyIdentity",
         out_names=["my_result1", "my_result2"])
     def MyIdentityFunc(a):
       return a
@@ -239,10 +240,11 @@ class FunctionTest(test.TestCase):
 
     inp = np.array([-1, 1, 2, -2], dtype=np.float32)
     feed = {x: inp}
-    cfg = config_pb2.ConfigProto(graph_options=config_pb2.GraphOptions(
-        optimizer_options=config_pb2.OptimizerOptions(
-            opt_level=config_pb2.OptimizerOptions.L1,
-            do_function_inlining=True)))
+    cfg = config_pb2.ConfigProto(
+        graph_options=config_pb2.GraphOptions(
+            optimizer_options=config_pb2.OptimizerOptions(
+                opt_level=config_pb2.OptimizerOptions.L1,
+                do_function_inlining=True)))
     with session.Session(graph=g, config=cfg) as sess:
       out, = sess.run(dx, feed)
     self.assertAllClose(1 - np.square(np.tanh(inp)), out)
@@ -334,18 +336,20 @@ class FunctionTest(test.TestCase):
       y = Foo(x)
       dx, = gradients_impl.gradients(y, [x])
 
-    cfg = config_pb2.ConfigProto(graph_options=config_pb2.GraphOptions(
-        optimizer_options=config_pb2.OptimizerOptions(
-            opt_level=config_pb2.OptimizerOptions.L0,
-            do_common_subexpression_elimination=True,
-            do_function_inlining=True,
-            do_constant_folding=True)))
+    cfg = config_pb2.ConfigProto(
+        graph_options=config_pb2.GraphOptions(
+            optimizer_options=config_pb2.OptimizerOptions(
+                opt_level=config_pb2.OptimizerOptions.L0,
+                do_common_subexpression_elimination=True,
+                do_function_inlining=True,
+                do_constant_folding=True)))
 
     with self.test_session(graph=g, config=cfg):
       self.assertAllClose(y.eval(), 6.)
       self.assertAllClose(dx.eval(), 2.)
 
   def _testZNoDepOnY(self, use_const_grad_ys):
+
     @function.Defun(dtypes.float32, dtypes.float32)
     def Foo(x, y):  # pylint: disable=unused-argument
       return x * 2
@@ -775,9 +779,9 @@ class FunctionTest(test.TestCase):
 
       @function.Defun()
       def Foo():
-        return control_flow_ops.while_loop(lambda i: i < 10,
-                                           lambda i: i + x,
+        return control_flow_ops.while_loop(lambda i: i < 10, lambda i: i + x,
                                            [0])
+
       y = Foo()
 
     with self.test_session(graph=g) as sess:
@@ -790,9 +794,8 @@ class FunctionTest(test.TestCase):
 
       @function.Defun(dtypes.bool)
       def Foo(pred):
-        return control_flow_ops.cond(pred,
-                                     lambda: x,
-                                     lambda: x + 1)
+        return control_flow_ops.cond(pred, lambda: x, lambda: x + 1)
+
       y = Foo(True)
       z = Foo(False)
 
@@ -945,6 +948,7 @@ class FunctionTest(test.TestCase):
     self.assertEqual(len(f.signature.input_arg), 3)
 
   def testGradientWithIntegerFunctionArgument(self):
+
     @function.Defun(dtypes.int32, dtypes.float32)
     def Foo(t, x):
       return x[t]
@@ -959,8 +963,7 @@ class FunctionTest(test.TestCase):
     x = np.zeros((2,)).astype(np.float32)
     with session.Session(graph=g) as sess:
       self.assertAllClose(
-          np.array([1.0, 0.0]).astype(np.float32),
-          sess.run(dinp, {inp: x}))
+          np.array([1.0, 0.0]).astype(np.float32), sess.run(dinp, {inp: x}))
 
   def testFunctionMarkedStateful(self):
 
@@ -1073,6 +1076,60 @@ class FunctionTest(test.TestCase):
       sess.run(var.initializer)
       _ = sess.run(CapturesGuaranteedConst(), {also_not_const: 1.0})
 
+  def testSameFunctionDifferentGrads(self):
+
+    def PartOne(x):
+
+      # Default grad is dx = dy * 2
+      @function.Defun(dtypes.float32)
+      def Foo(x):
+        return x * 2
+
+      return Foo(x)
+
+    def PartTwo(x):
+
+      @function.Defun(dtypes.float32, dtypes.float32)
+      def Bar(x, dy):
+        return x + dy  # crazy backprop
+
+      @function.Defun(dtypes.float32, grad_func=Bar)
+      def Foo(x):
+        return x * 2
+
+      return Foo(x)
+
+    def PartThree(x):
+
+      def Bar(op, dy):
+        return op.inputs[0] * dy / 2  # crazy backprop
+
+      @function.Defun(dtypes.float32, python_grad_func=Bar)
+      def Foo(x):
+        return x * 2
+
+      return Foo(x)
+
+    g = ops.Graph()
+    with g.as_default():
+      x = constant_op.constant(100.)
+      x0 = x
+      y0 = PartOne(x0)
+      dx0, = gradients_impl.gradients(ys=[y0], xs=[x0])
+      x1 = x
+      y1 = PartTwo(x1)
+      dx1, = gradients_impl.gradients(ys=[y1], xs=[x1])
+      x2 = x
+      y2 = PartThree(x2)
+      dx2, = gradients_impl.gradients(ys=[y2], xs=[x2])
+
+    with self.test_session(graph=g) as sess:
+      v0, v1, v2 = sess.run([dx0, dx1, dx2])
+
+    self.assertAllEqual(v0, 2.)
+    self.assertAllEqual(v1, 101.)
+    self.assertAllEqual(v2, 50.)
+
 
 @test_util.with_c_shapes
 class FunctionsFromProtos(test.TestCase):
@@ -1271,9 +1328,10 @@ class FunctionsFromProtos(test.TestCase):
     @function.Defun(dtypes.int32, experimental_tag="tag_value")
     def FunctionWithAttr(i):
       return array_ops.identity(i)
+
     self.assertTrue("experimental_tag" in FunctionWithAttr.definition.attr)
-    self.assertEqual(
-        FunctionWithAttr.definition.attr["experimental_tag"].s, b"tag_value")
+    self.assertEqual(FunctionWithAttr.definition.attr["experimental_tag"].s,
+                     b"tag_value")
 
 
 @test_util.with_c_shapes
@@ -1401,7 +1459,8 @@ class UnrollLSTMTest(test.TestCase):
       return Loop(cell, weights, inp)
 
     cell = function.Defun(dtypes.float32, dtypes.float32, dtypes.float32,
-                          dtypes.float32)(cell)
+                          dtypes.float32)(
+                              cell)
     if mode == "cell":
       # Just represent the LSTM as a function.
       return Loop(cell, weights, inp)
@@ -1500,12 +1559,13 @@ class FunctionInlineControlTest(test.TestCase):
 
   def testFoo(self):
     dtype = dtypes.float32
-    cfg = config_pb2.ConfigProto(graph_options=config_pb2.GraphOptions(
-        optimizer_options=config_pb2.OptimizerOptions(
-            opt_level=config_pb2.OptimizerOptions.L0,
-            do_common_subexpression_elimination=True,
-            do_function_inlining=True,
-            do_constant_folding=True)))
+    cfg = config_pb2.ConfigProto(
+        graph_options=config_pb2.GraphOptions(
+            optimizer_options=config_pb2.OptimizerOptions(
+                opt_level=config_pb2.OptimizerOptions.L0,
+                do_common_subexpression_elimination=True,
+                do_function_inlining=True,
+                do_constant_folding=True)))
     cell_func_call_pattern = re.compile(r"Cell[^/]*\(")
     for noinline in [False, True]:
 
index 581ba7d..1448151 100644 (file)
@@ -256,21 +256,21 @@ def _DefaultGradYs(grad_ys,
         continue
       if y.dtype.is_floating or y.dtype.is_integer:
         if not grad_y.dtype.is_floating and not grad_y.dtype.is_integer:
-          raise TypeError("Gradient type %s generated for real or "
-                          "integer-valued tensor %s with type %s must be "
-                          "real or integer" %
-                          (dtypes.as_dtype(grad_y.dtype).name, y,
-                           dtypes.as_dtype(y.dtype).name))
+          raise TypeError(
+              "Gradient type %s generated for real or "
+              "integer-valued tensor %s with type %s must be "
+              "real or integer" % (dtypes.as_dtype(grad_y.dtype).name, y,
+                                   dtypes.as_dtype(y.dtype).name))
       elif y.dtype.is_complex:
         if not grad_y.dtype.is_complex:
-          raise TypeError("Gradient type %s generated for complex-valued "
-                          "tensor %s with type %s must be real" %
-                          (dtypes.as_dtype(grad_y.dtype).name, y,
-                           dtypes.as_dtype(y.dtype).name))
+          raise TypeError(
+              "Gradient type %s generated for complex-valued "
+              "tensor %s with type %s must be real" % (dtypes.as_dtype(
+                  grad_y.dtype).name, y, dtypes.as_dtype(y.dtype).name))
       else:
-        raise TypeError("Tensor %s with type %s must be numeric "
-                        "to obtain a default gradient" %
-                        (y, dtypes.as_dtype(y.dtype).name))
+        raise TypeError(
+            "Tensor %s with type %s must be numeric "
+            "to obtain a default gradient" % (y, dtypes.as_dtype(y.dtype).name))
       # Create a grad_y tensor in the name scope of the gradient.
       # Required for TensorArrays to identify which gradient call a
       # grad_y value is coming from.
@@ -605,15 +605,19 @@ def _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops,
           loop_state.ExitGradWhileContext(op, before=True)
 
         grad_fn = None
-        # pylint: disable=protected-access
         func_call = None
+        # pylint: disable=protected-access
         is_func_call = ops.get_default_graph()._is_function(op.type)
+        # pylint: enable=protected-access
         has_out_grads = any(isinstance(g, ops.Tensor) or g for g in out_grads)
         if has_out_grads and (op._id not in stop_ops):
           if is_func_call:
             func_call = ops.get_default_graph()._get_function(op.type)
+            # Note that __defun is not set if the graph is
+            # imported. If it's set, we prefer to access the original
+            # defun.
+            func_call = getattr(op, "__defun", func_call)
             grad_fn = func_call.python_grad_func
-            # pylint: enable=protected-access
           else:
             # A grad_fn must be defined, either as a function or as None
             # for ops that do not have gradients.