Fix bug in @custom_gradient in Eager mode with numpy inputs
author     A. Unique TensorFlower <gardener@tensorflow.org>
Fri, 27 Apr 2018 14:21:37 +0000 (07:21 -0700)
committer  TensorFlower Gardener <gardener@tensorflow.org>
Fri, 27 Apr 2018 14:23:51 +0000 (07:23 -0700)
PiperOrigin-RevId: 194538828
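
The failure mode is not spelled out above, but the fix below implies it: in eager mode, _eager_mode_decorator built set(all_inputs) in order to subtract the explicit inputs from tape.watched_variables(), and numpy arrays are unhashable, so calling a @custom_gradient function with a numpy input raised a TypeError before any gradient code ran. A minimal standalone sketch of that failure (plain Python/numpy, not TensorFlow code):

    import numpy as np

    x = np.ones((3, 2), dtype=np.float32)
    all_inputs = [x]
    # Mirrors the old `set(all_inputs)` call: numpy arrays do not define
    # __hash__, so building the set raises immediately.
    try:
      set(all_inputs)
    except TypeError as e:
      print(e)  # unhashable type: 'numpy.ndarray'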

tensorflow/python/BUILD
tensorflow/python/ops/custom_gradient.py
tensorflow/python/ops/gradients_test.py

diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 105fcba..44d9147 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -1878,6 +1878,7 @@ py_library(
         ":math_grad",
         ":math_ops",
         ":platform",
+        ":resource_variable_ops",
         ":spectral_grad",
         ":util",
         ":variable_scope",
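
(The added ":resource_variable_ops" dependency is the BUILD counterpart of the new import introduced in custom_gradient.py in the next hunk.)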
diff --git a/tensorflow/python/ops/custom_gradient.py b/tensorflow/python/ops/custom_gradient.py
index c07c669..446ad1b 100644
--- a/tensorflow/python/ops/custom_gradient.py
+++ b/tensorflow/python/ops/custom_gradient.py
@@ -24,6 +24,7 @@ from tensorflow.python.eager import tape as tape_lib
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_array_ops
+from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.util import nest
 from tensorflow.python.util import tf_decorator
@@ -166,7 +167,11 @@ def _eager_mode_decorator(f, *args, **kwargs):
   all_inputs = list(args) + list(kwargs.values())
   # The variables that grad_fn needs to return gradients for are the set of
   # variables used that are *not* part of the inputs.
-  variables = list(set(tape.watched_variables()) - set(all_inputs))
+  variable_inputs = [
+      arg for arg in all_inputs
+      if isinstance(arg, resource_variable_ops.ResourceVariable)
+  ]
+  variables = list(set(tape.watched_variables()) - set(variable_inputs))
   flat_result = nest.flatten(result)
   # TODO(apassos) consider removing the identity below.
   flat_result = [gen_array_ops.identity(x) for x in flat_result]
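
Note on the fix: tape.watched_variables() yields (resource) variables, so arguments that are not ResourceVariables could never be removed by the set subtraction anyway; filtering them out first keeps unhashable inputs such as numpy arrays from ever reaching set(). A standalone sketch of that filtering idea, with a hypothetical mixed argument list (not part of the commit):

    from tensorflow.python.ops import resource_variable_ops

    def variable_args(all_inputs):
      # Keep only resource variables; numpy arrays, Tensors and Python
      # scalars are skipped, so nothing unhashable is placed in a set.
      return [a for a in all_inputs
              if isinstance(a, resource_variable_ops.ResourceVariable)]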
diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py
index f336372..9d29617 100644
--- a/tensorflow/python/ops/gradients_test.py
+++ b/tensorflow/python/ops/gradients_test.py
@@ -894,6 +894,22 @@ class CustomGradientTest(test_util.TensorFlowTestCase):
       self.assertEqual(6., math_ops.reduce_sum(dx).numpy())
       self.assertEqual(8., math_ops.reduce_sum(dw).numpy())
 
+  def testWithNumpyInputs(self):
+    with context.eager_mode():
+
+      @custom_gradient.custom_gradient
+      def F(x):
+        out = x
+
+        def Grad(_):
+          return (None, None)
+
+        return out, Grad
+
+      x = np.ones((3, 2), dtype=np.float32)
+      # Smoke test to ensure numpy inputs are accepted
+      F(x)
+
 
 if __name__ == "__main__":
   googletest.main()
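
For reference, a hedged end-user version of what the new test guards, written against the public API (function name and gradient rule are illustrative; on the 1.x releases contemporary with this change eager execution must be enabled explicitly, whereas on 2.x it is the default):

    import numpy as np
    import tensorflow as tf

    if hasattr(tf, "enable_eager_execution"):
      tf.enable_eager_execution()  # needed on the 1.x line only

    @tf.custom_gradient
    def scale(x):
      def grad(dy):
        return 2.0 * dy  # arbitrary custom gradient rule
      return tf.identity(x), grad

    # Before this fix, passing a plain numpy array here raised
    # TypeError: unhashable type: 'numpy.ndarray' in eager mode.
    y = scale(np.ones((3, 2), dtype=np.float32))
    print(y.shape)  # (3, 2)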