Add a _TensorProcessor for computing gradients but not applying them.
author     A. Unique TensorFlower <gardener@tensorflow.org>
           Thu, 8 Feb 2018 22:55:09 +0000 (14:55 -0800)
committer  TensorFlower Gardener <gardener@tensorflow.org>
           Thu, 8 Feb 2018 23:00:01 +0000 (15:00 -0800)
PiperOrigin-RevId: 185056764
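
As a usage sketch (not part of the commit): with this change, Optimizer.compute_gradients accepts a plain Tensor in var_list, while apply_gradients on the resulting pair raises NotImplementedError. A minimal graph-mode example, assuming the TF 1.x API current at this revision:

    import tensorflow as tf

    x = tf.convert_to_tensor(2.0)
    loss = x * x * x  # d(loss)/dx = 3 * x**2, i.e. 12.0 at x = 2
    opt = tf.train.GradientDescentOptimizer(learning_rate=0.1)
    # x is a plain Tensor, not a Variable: gradients can be computed...
    grads_and_vars = opt.compute_gradients(loss, var_list=[x])
    grad, _ = grads_and_vars[0]
    with tf.Session() as sess:
      print(sess.run(grad))  # 12.0
    # ...but applying them is unsupported:
    # opt.apply_gradients(grads_and_vars)  # raises NotImplementedError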

tensorflow/python/training/optimizer.py
tensorflow/python/training/optimizer_test.py

diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py
index 9ec588b..b806251 100644
--- a/tensorflow/python/training/optimizer.py
+++ b/tensorflow/python/training/optimizer.py
@@ -175,16 +175,39 @@ class _StreamingModelPortProcessor(_OptimizableVariable):
     return g
 
 
+class _TensorProcessor(_OptimizableVariable):
+  """Processor for ordinary Tensors.
+
+  Even though a Tensor can't really be updated, sometimes it is useful to
+  compute the gradients with respect to a Tensor using the optimizer. Updating
+  the Tensor is, of course, unsupported.
+  """
+
+  def __init__(self, v):
+    self._v = v
+
+  def target(self):
+    return self._v
+
+  def update_op(self, optimizer, g):
+    raise NotImplementedError("Trying to update a Tensor ", self._v)
+
+
 def _get_processor(v):
   """The processor of v."""
   if context.in_eager_mode():
-    return _DenseResourceVariableProcessor(v)
+    if isinstance(v, ops.Tensor):
+      return _TensorProcessor(v)
+    else:
+      return _DenseResourceVariableProcessor(v)
   if v.op.type == "VarHandleOp":
     return _DenseResourceVariableProcessor(v)
   if isinstance(v, variables.Variable):
     return _RefVariableProcessor(v)
   if v.op.type == "SubmodelPort":
     return _StreamingModelPortProcessor(v)
+  if isinstance(v, ops.Tensor):
+    return _TensorProcessor(v)
   raise NotImplementedError("Trying to optimize unsupported type ", v)
 
 
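To make the dispatch above concrete, a hypothetical trace (identifiers from optimizer.py; not part of the commit) of _get_processor on a plain graph-mode Tensor. Note that the eager branch tests isinstance(v, ops.Tensor) first, since an EagerTensor carries no graph op to inspect via v.op.type:

    # Hypothetical trace inside optimizer.py's namespace, graph mode:
    v = ops.convert_to_tensor(1.0) * 3.0   # plain Tensor; v.op.type == "Mul"
    # Not a VarHandleOp, not a variables.Variable, not a SubmodelPort,
    # so the new trailing isinstance(v, ops.Tensor) check matches:
    p = _get_processor(v)                  # -> _TensorProcessor
    assert p.target() is v                 # gradients are taken w.r.t. v itself
    # p.update_op(opt, g) would raise NotImplementedError by design.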
diff --git a/tensorflow/python/training/optimizer_test.py b/tensorflow/python/training/optimizer_test.py
index 6bdae39..8652c61 100644
--- a/tensorflow/python/training/optimizer_test.py
+++ b/tensorflow/python/training/optimizer_test.py
@@ -221,6 +221,22 @@ class OptimizerTest(test.TestCase):
       self.assertAllClose([-14., -13.], self.evaluate(var0))
       self.assertAllClose([-6., -5.], self.evaluate(var1))
 
+  @test_util.run_in_graph_and_eager_modes()
+  def testComputeGradientsWithTensors(self):
+    x = ops.convert_to_tensor(1.0)
+    def f():  # zero-arg closure; in eager mode the loss is called with no args
+      return x * x
+    y = f if context.in_eager_mode() else f()
+    sgd_op = gradient_descent.GradientDescentOptimizer(3.0)
+    grads_and_vars = sgd_op.compute_gradients(y, [x])
+    self.assertEqual(1, len(grads_and_vars))
+    grad, x_as_var = grads_and_vars[0]
+    self.assertIs(x, x_as_var)
+    self.assertEqual(2.0, self.evaluate(grad))
+
+    with self.assertRaises(NotImplementedError):
+      sgd_op.apply_gradients(grads_and_vars)
+
   def testTrainOp(self):
     with self.test_session():
       var0 = variables.Variable([1.0, 2.0])
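
For the eager branch exercised by the test above, a hedged standalone sketch (assuming the contrib eager API available at this revision, where tfe.enable_eager_execution was the entry point):

    import tensorflow as tf
    import tensorflow.contrib.eager as tfe  # eager lived in contrib at this revision

    tfe.enable_eager_execution()

    x = tf.convert_to_tensor(1.0)
    opt = tf.train.GradientDescentOptimizer(3.0)
    # In eager mode the loss is a zero-argument callable, mirroring f above.
    grads_and_vars = opt.compute_gradients(lambda: x * x, var_list=[x])
    grad, _ = grads_and_vars[0]
    print(grad.numpy())  # 2.0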