From e47f970f8395be1428c2d7fe0bf42aa4119803fa Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Thu, 8 Feb 2018 14:55:09 -0800
Subject: [PATCH] Add a _TensorProcessor for computing gradients but not
 applying them.

PiperOrigin-RevId: 185056764
---
 tensorflow/python/training/optimizer.py      | 25 ++++++++++++++++++++++++-
 tensorflow/python/training/optimizer_test.py | 16 ++++++++++++++++
 2 files changed, 40 insertions(+), 1 deletion(-)

diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py
index 9ec588b..b806251 100644
--- a/tensorflow/python/training/optimizer.py
+++ b/tensorflow/python/training/optimizer.py
@@ -175,16 +175,39 @@ class _StreamingModelPortProcessor(_OptimizableVariable):
     return g
 
 
+class _TensorProcessor(_OptimizableVariable):
+  """Processor for ordinary Tensors.
+
+  Even though a Tensor can't really be updated, sometimes it is useful to
+  compute the gradients with respect to a Tensor using the optimizer. Updating
+  the Tensor is, of course, unsupported.
+  """
+
+  def __init__(self, v):
+    self._v = v
+
+  def target(self):
+    return self._v
+
+  def update_op(self, optimizer, g):
+    raise NotImplementedError("Trying to update a Tensor ", self._v)
+
+
 def _get_processor(v):
   """The processor of v."""
   if context.in_eager_mode():
-    return _DenseResourceVariableProcessor(v)
+    if isinstance(v, ops.Tensor):
+      return _TensorProcessor(v)
+    else:
+      return _DenseResourceVariableProcessor(v)
   if v.op.type == "VarHandleOp":
     return _DenseResourceVariableProcessor(v)
   if isinstance(v, variables.Variable):
     return _RefVariableProcessor(v)
   if v.op.type == "SubmodelPort":
     return _StreamingModelPortProcessor(v)
+  if isinstance(v, ops.Tensor):
+    return _TensorProcessor(v)
   raise NotImplementedError("Trying to optimize unsupported type ", v)
 
 
diff --git a/tensorflow/python/training/optimizer_test.py b/tensorflow/python/training/optimizer_test.py
index 6bdae39..8652c61 100644
--- a/tensorflow/python/training/optimizer_test.py
+++ b/tensorflow/python/training/optimizer_test.py
@@ -221,6 +221,22 @@ class OptimizerTest(test.TestCase):
     self.assertAllClose([-14., -13.], self.evaluate(var0))
     self.assertAllClose([-6., -5.], self.evaluate(var1))
 
+  @test_util.run_in_graph_and_eager_modes()
+  def testComputeGradientsWithTensors(self):
+    def f():
+      return x * x
+    x = ops.convert_to_tensor(1.0)
+    y = f if context.in_eager_mode() else f()
+    sgd_op = gradient_descent.GradientDescentOptimizer(3.0)
+    grads_and_vars = sgd_op.compute_gradients(y, [x])
+    self.assertEqual(1, len(grads_and_vars))
+    grad, x_as_var = grads_and_vars[0]
+    self.assertIs(x, x_as_var)
+    self.assertEqual(2.0, self.evaluate(grad))
+
+    with self.assertRaises(NotImplementedError):
+      sgd_op.apply_gradients(grads_and_vars)
+
   def testTrainOp(self):
     with self.test_session():
       var0 = variables.Variable([1.0, 2.0])
-- 
2.7.4
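
Usage sketch (illustrative only, not part of the diff above; assumes the
public TF 1.x graph-mode API of the period, which this patch targets):

  import tensorflow as tf

  x = tf.convert_to_tensor(1.0)
  y = x * x  # loss defined directly on a plain Tensor
  opt = tf.train.GradientDescentOptimizer(3.0)
  # With this patch, var_list may contain plain Tensors, not just Variables.
  grads_and_vars = opt.compute_gradients(y, [x])
  with tf.Session() as sess:
    print(sess.run(grads_and_vars[0][0]))  # dy/dx = 2*x = 2.0
  # Applying the update is unsupported for Tensors and raises
  # NotImplementedError, as exercised by the test above:
  # opt.apply_gradients(grads_and_vars)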