    return var.scatter_sub(delta, use_locking=self._use_locking)

  def _prepare(self):
-    if not context.executing_eagerly() or self._learning_rate_tensor is None:
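+    # A call traced inside a defun may have cached a graph Tensor here; when
+    # executing eagerly, reconvert so the learning rate is an EagerTensor again.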
+    if not context.executing_eagerly() or not isinstance(
+        self._learning_rate_tensor, ops.EagerTensor):
      self._learning_rate_tensor = ops.convert_to_tensor(self._learning_rate,
                                                         name="learning_rate")
from __future__ import division
from __future__ import print_function
+from tensorflow.python.eager import backprop
+from tensorflow.python.eager import context
+from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
      self.assertAllCloseAccordingToType([[3.0], [4.0 - 3.0 * 0.01]],
                                         var1.eval())
+  def testCapturingInDefunWhileExecutingEagerly(self):
+    with context.eager_mode():
+      optimizer = gradient_descent.GradientDescentOptimizer(1.0)
+
+      def step():
+        v = resource_variable_ops.ResourceVariable(1.0)
+        with backprop.GradientTape() as tape:
+          loss = v ** 2
+        grad = tape.gradient(loss, v)
+        optimizer.apply_gradients([(grad, v)])
+        return v.read_value()
+
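+      # Tracing step() as a defun builds a graph in which the learning rate is
+      # converted to, and cached as, a graph Tensor.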
+      compiled_step = function.defun(step)
+
+      self.assertEqual(float(step()), -1.0)
+      self.assertEqual(float(compiled_step()), -1.0)
+      # This shouldn't fail; in particular, the learning rate tensor should
+      # be an EagerTensor once again, not a graph Tensor.
+      self.assertEqual(float(step()), -1.0)
+
if __name__ == "__main__":
test.main()