From 68a2716ba9c553b0442bb6ecca8742c0ad5d37f8 Mon Sep 17 00:00:00 2001
From: Alexandre Passos
Date: Tue, 3 Apr 2018 17:35:54 -0700
Subject: [PATCH] Refactors the eager guide to be more researcher-friendly.

* Shows how to build layers
* Orders the topics as they'd be ordered in a normal model

PiperOrigin-RevId: 191526275
---
 tensorflow/docs_src/programmers_guide/eager.md | 414 ++++++++++++++-----------
 1 file changed, 232 insertions(+), 182 deletions(-)

diff --git a/tensorflow/docs_src/programmers_guide/eager.md b/tensorflow/docs_src/programmers_guide/eager.md
index 414653c..dc5b403 100644
--- a/tensorflow/docs_src/programmers_guide/eager.md
+++ b/tensorflow/docs_src/programmers_guide/eager.md
@@ -109,9 +109,106 @@ environments and is useful for writing code to [work with graphs](#work_with_gra
 import tensorflow.contrib.eager as tfe
 ```
 
-## Updating model parameters
+## Dynamic control flow
 
-### Automatic differentiation
+A major benefit of eager execution is that all the functionality of the host
+language is available while your model is executing. So, for example,
+it is easy to write [fizzbuzz](https://en.wikipedia.org/wiki/Fizz_buzz):
+
+```py
+def fizzbuzz(max_num):
+  counter = tf.constant(0)
+  for num in range(max_num):
+    num = tf.constant(num)
+    if num % 3 == 0 and num % 5 == 0:
+      print('FizzBuzz')
+    elif num % 3 == 0:
+      print('Fizz')
+    elif num % 5 == 0:
+      print('Buzz')
+    else:
+      print(num)
+    counter += 1
+  return counter
+```
+
+This has conditionals that depend on tensor values, and it prints these values
+at runtime.
+
+## Build a model
+
+Many machine learning models are represented by composing layers. When
+using TensorFlow with eager execution, you can either write your own layers or
+use a layer provided in the `tf.keras.layers` package.
+
+While you can use any Python object to represent a layer,
+TensorFlow has `tf.keras.layers.Layer` as a convenient base class. Inherit from
+it to implement your own layer:
+
+```py
+class MySimpleLayer(tf.keras.layers.Layer):
+  def __init__(self, output_units):
+    super(MySimpleLayer, self).__init__()
+    self.output_units = output_units
+
+  def build(self, input_shape):
+    # The build method gets called the first time your layer is used.
+    # Creating variables in build() lets their shape depend on the input
+    # shape, so the user does not have to specify full shapes. You can
+    # still create variables in __init__() if you already know their shapes.
+    self.kernel = self.add_variable(
+      "kernel", [input_shape[-1], self.output_units])
+
+  def call(self, input):
+    # Override call() instead of __call__ so we can perform some bookkeeping.
+    return tf.matmul(input, self.kernel)
+```
+
+Use the `tf.keras.layers.Dense` layer instead of `MySimpleLayer` above, as it has
+a superset of its functionality (it can also add a bias).
+
+When composing layers into models, you can use `tf.keras.Sequential` to represent
+models that are a linear stack of layers. It is easy to use for basic models:
+
+```py
+model = tf.keras.Sequential([
+  tf.keras.layers.Dense(10, input_shape=(784,)),  # must declare input shape
+  tf.keras.layers.Dense(10)
+])
+```
+
+Alternatively, organize models in classes by inheriting from `tf.keras.Model`.
+This is a container for layers that is a layer itself, allowing `tf.keras.Model`
+objects to contain other `tf.keras.Model` objects.
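+
+For instance, one model can hold another model as an attribute. The snippet
+below is a minimal sketch of that nesting, assuming eager execution is enabled
+and `tf` is imported as above; the `Block` and `TwoBlocks` names are purely
+illustrative:
+
+```py
+class Block(tf.keras.Model):
+  def __init__(self):
+    super(Block, self).__init__()
+    self.dense = tf.keras.layers.Dense(units=10)
+
+  def call(self, input):
+    return self.dense(input)
+
+class TwoBlocks(tf.keras.Model):
+  def __init__(self):
+    super(TwoBlocks, self).__init__()
+    self.first = Block()   # a Model used as a layer inside another Model
+    self.second = Block()
+
+  def call(self, input):
+    return self.second(self.first(input))
+```
+
+The `MNISTModel` class below uses the same subclassing pattern: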
+
+```py
+class MNISTModel(tf.keras.Model):
+  def __init__(self):
+    super(MNISTModel, self).__init__()
+    self.dense1 = tf.keras.layers.Dense(units=10)
+    self.dense2 = tf.keras.layers.Dense(units=10)
+
+  def call(self, input):
+    """Run the model."""
+    result = self.dense1(input)
+    result = self.dense2(result)
+    result = self.dense2(result)  # reuse variables from dense2 layer
+    return result
+
+model = MNISTModel()
+```
+
+It's not required to set an input shape for the `tf.keras.Model` class since
+the parameters are set the first time input is passed to the layer.
+
+`tf.keras.layers` classes create and contain their own model variables that
+are tied to the lifetime of their layer objects. To share layer variables, share
+their objects.
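+
+As a minimal illustrative sketch (the shapes and names here are hypothetical),
+calling one layer object from two places reuses the same variables; assigning
+that object into two different models shares them in the same way:
+
+```py
+shared = tf.keras.layers.Dense(units=10)
+
+# Both calls go through the same layer object, so they read and update the
+# same kernel and bias variables.
+out_a = shared(tf.zeros([1, 5]))
+out_b = shared(tf.ones([1, 5]))
+len(shared.variables)  # => 2 (one kernel, one bias), regardless of call count
+```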
+
+
+## Eager training
+
+### Computing gradients
 
 [Automatic differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation)
 is useful for implementing machine learning algorithms such as
@@ -215,189 +312,12 @@ for (batch, (images, labels)) in enumerate(tfe.Iterator(dataset)):
                             global_step=tf.train.get_or_create_global_step())
 ```
 
-#### Dynamic models
-
-`tfe.GradientTape` can also be used in dynamic models. This example for a
-[backtracking line search](https://wikipedia.org/wiki/Backtracking_line_search)
-algorithm looks like normal NumPy code, except there are gradients and is
-differentiable, despite the complex control flow:
-
-```py
-def line_search_step(fn, init_x, rate=1.0):
-  with tfe.GradientTape() as tape:
-    # Variables are automatically recorded, but manually watch a tensor
-    tape.watch(init_x)
-    value = fn(init_x)
-  grad, = tape.gradient(value, [init_x])
-  grad_norm = tf.reduce_sum(grad * grad)
-  init_value = value
-  while value > init_value - rate * grad_norm:
-    x = init_x - rate * grad
-    value = fn(x)
-    rate /= 2.0
-  return x, value
-```
-
-#### Additional functions to compute gradients
-
-`tfe.GradientTape` is a powerful interface for computing gradients, but there
-is another [Autograd](https://github.com/HIPS/autograd)-style API available for
-automatic differentiation. These functions are useful if writing math code with
-only tensors and gradient functions, and without `tfe.Variables`:
-
-* `tfe.gradients_function` —Returns a function that computes the derivatives
-  of its input function parameter with respect to its arguments. The input
-  function parameter must return a scalar value. When the returned function is
-  invoked, it returns a list of `tf.Tensor` objects: one element for each
-  argument of the input function. Since anything of interest must be passed as a
-  function parameter, this becomes unwieldy if there's a dependency on many
-  trainable parameters.
-* `tfe.value_and_gradients_function` —Similar to
-  `tfe.gradients_function`, but when the returned function is invoked, it
-  returns the value from the input function in addition to the list of
-  derivatives of the input function with respect to its arguments.
-
-In the following example, `tfe.gradients_function` takes the `square`
-function as an argument and returns a function that computes the partial
-derivatives of `square` with respect to its inputs. To calculate the derivative
-of `square` at `3`, `grad(3.0)` returns `6`.
-
-```py
-def square(x):
-  return tf.multiply(x, x)
-
-grad = tfe.gradients_function(square)
-
-square(3.)  # => 9.0
-grad(3.)    # => [6.0]
-
-# The second-order derivative of square:
-gradgrad = tfe.gradients_function(lambda x: grad(x)[0])
-gradgrad(3.)  # => [2.0]
-
-# The third-order derivative is None:
-gradgradgrad = tfe.gradients_function(lambda x: gradgrad(x)[0])
-gradgradgrad(3.)  # => [None]
-
-
-# With flow control:
-def abs(x):
-  return x if x > 0. else -x
-
-grad = tfe.gradients_function(abs)
-
-grad(3.)  # => [1.0]
-grad(-3.)  # => [-1.0]
-```
-
-### Custom gradients
-
-Custom gradients are an easy way to override gradients in eager and graph
-execution. Within the forward function, define the gradient with respect to the
-inputs, outputs, or intermediate results. For example, here's an easy way to clip
-the norm of the gradients in the backward pass:
-
-```py
-@tf.custom_gradient
-def clip_gradient_by_norm(x, norm):
-  y = tf.identity(x)
-  def grad_fn(dresult):
-    return [tf.clip_by_norm(dresult, norm), None]
-  return y, grad_fn
-```
-
-Custom gradients are commonly used to provide a numerically stable gradient for a
-sequence of operations:
-
-```py
-def log1pexp(x):
-  return tf.log(1 + tf.exp(x))
-grad_log1pexp = tfe.gradients_function(log1pexp)
-
-# The gradient computation works fine at x = 0.
-grad_log1pexp(0.)  # => [0.5]
-
-# However, x = 100 fails because of numerical instability.
-grad_log1pexp(100.)  # => [nan]
-```
-
-Here, the `log1pexp` function can be analytically simplified with a custom
-gradient. The implementation below reuses the value for `tf.exp(x)` that is
-computed during the forward pass—making it more efficient by eliminating
-redundant calculations:
-
-```py
-@tf.custom_gradient
-def log1pexp(x):
-  e = tf.exp(x)
-  def grad(dy):
-    return dy * (1 - 1 / (1 + e))
-  return tf.log(1 + e), grad
-
-grad_log1pexp = tfe.gradients_function(log1pexp)
-
-# As before, the gradient computation works fine at x = 0.
-grad_log1pexp(0.)  # => [0.5]
-
-# And the gradient computation also works at x = 100.
-grad_log1pexp(100.)  # => [1.0]
-```
-
-
-## Build and train models
-
-There are many parameters to optimize when calculating derivatives. TensorFlow
-code is easier to read when structured into reusable classes and objects instead
-of a single top-level function. Eager execution encourages the use of the
-Keras-style layer classes in the `tf.keras.layers` module. Additionally, the
-`tf.train.Optimizer` classes provide sophisticated techniques to calculate
-parameter updates.
 
 The following example creates a multi-layer model that classifies the standard [MNIST handwritten digits](https://www.tensorflow.org/tutorials/layers). It demonstrates the optimizer and layer APIs to build trainable graphs in an eager execution environment.
 
-### Build a model
-
-The `tf.keras.Sequential` model is a linear stack of layers. It is easy to
-use for basic models:
-
-```py
-model = tf.keras.Sequential([
-  tf.keras.layers.Dense(10, input_shape=(784,)),  # must declare input shape
-  tf.keras.layers.Dense(10)
-])
-```
-
-Alternatively, organize models in classes by inheriting from `tf.keras.Model`.
-This is a container for layers that is a layer itself, allowing `tf.keras.Model`
-objects to contain other `tf.keras.Model` objects.
-
-```py
-class MNISTModel(tf.keras.Model):
-  def __init__(self):
-    super(MNISTModel, self).__init__()
-    self.dense1 = tf.keras.layers.Dense(units=10)
-    self.dense2 = tf.keras.layers.Dense(units=10)
-
-  def call(self, input):
-    """Run the model."""
-    result = self.dense1(input)
-    result = self.dense2(result)
-    result = self.dense2(result)  # reuse variables from dense2 layer
-    return result
-
-model = MNISTModel()
-```
-
-It's not required to set an input shape for the `tf.keras.Model` class since
-the parameters are set the first time input is passed to the layer.
-
-`tf.keras.layers` classes create and contain their own model variables that
-are tied to the lifetime of their layer objects. To share layer variables, share
-their objects.
 
 ### Train a model
 
 Even without training, call the model and inspect the output in eager execution:
@@ -661,11 +581,141 @@ for _ in range(iterations):
   ...
 ```
 
+## Advanced automatic differentiation topics
+
+### Dynamic models
+
+`tfe.GradientTape` can also be used in dynamic models. This example for a
+[backtracking line search](https://wikipedia.org/wiki/Backtracking_line_search)
+algorithm looks like normal NumPy code, except that there are gradients and it
+is differentiable, despite the complex control flow:
+
+```py
+def line_search_step(fn, init_x, rate=1.0):
+  with tfe.GradientTape() as tape:
+    # Variables are automatically recorded, but manually watch a tensor
+    tape.watch(init_x)
+    value = fn(init_x)
+  grad, = tape.gradient(value, [init_x])
+  grad_norm = tf.reduce_sum(grad * grad)
+  init_value = value
+  while value > init_value - rate * grad_norm:
+    x = init_x - rate * grad
+    value = fn(x)
+    rate /= 2.0
+  return x, value
+```
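+
+As a quick, hypothetical usage sketch, the routine above can be called with any
+scalar-valued differentiable function and a starting tensor:
+
+```py
+def quadratic(x):  # an arbitrary example objective
+  return tf.reduce_sum(x * x)
+
+x, value = line_search_step(quadratic, tf.constant([2.0, -3.0]))
+```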
+
+### Additional functions to compute gradients
+
+`tfe.GradientTape` is a powerful interface for computing gradients, but there
+is another [Autograd](https://github.com/HIPS/autograd)-style API available for
+automatic differentiation. These functions are useful if writing math code with
+only tensors and gradient functions, and without `tfe.Variables`:
+
+* `tfe.gradients_function` —Returns a function that computes the derivatives
+  of its input function parameter with respect to its arguments. The input
+  function parameter must return a scalar value. When the returned function is
+  invoked, it returns a list of `tf.Tensor` objects: one element for each
+  argument of the input function. Since anything of interest must be passed as a
+  function parameter, this becomes unwieldy if there's a dependency on many
+  trainable parameters.
+* `tfe.value_and_gradients_function` —Similar to
+  `tfe.gradients_function`, but when the returned function is invoked, it
+  returns the value from the input function in addition to the list of
+  derivatives of the input function with respect to its arguments.
+
+In the following example, `tfe.gradients_function` takes the `square`
+function as an argument and returns a function that computes the partial
+derivatives of `square` with respect to its inputs. To calculate the derivative
+of `square` at `3`, `grad(3.0)` returns `6`.
+
+```py
+def square(x):
+  return tf.multiply(x, x)
+
+grad = tfe.gradients_function(square)
+
+square(3.)  # => 9.0
+grad(3.)    # => [6.0]
+
+# The second-order derivative of square:
+gradgrad = tfe.gradients_function(lambda x: grad(x)[0])
+gradgrad(3.)  # => [2.0]
+
+# The third-order derivative is None:
+gradgradgrad = tfe.gradients_function(lambda x: gradgrad(x)[0])
+gradgradgrad(3.)  # => [None]
+
+
+# With flow control:
+def abs(x):
+  return x if x > 0. else -x
+
+grad = tfe.gradients_function(abs)
+
+grad(3.)  # => [1.0]
+grad(-3.)  # => [-1.0]
+```
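+
+For the `tfe.value_and_gradients_function` variant described above, a minimal
+sketch (reusing the `square` function from the example) looks like this; it
+returns the function's value together with the derivatives:
+
+```py
+val_and_grads = tfe.value_and_gradients_function(square)
+val_and_grads(3.)  # => (9.0, [6.0])
+```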
+
+### Custom gradients
+
+Custom gradients are an easy way to override gradients in eager and graph
+execution. Within the forward function, define the gradient with respect to the
+inputs, outputs, or intermediate results. For example, here's an easy way to clip
+the norm of the gradients in the backward pass:
+
+```py
+@tf.custom_gradient
+def clip_gradient_by_norm(x, norm):
+  y = tf.identity(x)
+  def grad_fn(dresult):
+    return [tf.clip_by_norm(dresult, norm), None]
+  return y, grad_fn
+```
+
+Custom gradients are commonly used to provide a numerically stable gradient for a
+sequence of operations:
+
+```py
+def log1pexp(x):
+  return tf.log(1 + tf.exp(x))
+grad_log1pexp = tfe.gradients_function(log1pexp)
+
+# The gradient computation works fine at x = 0.
+grad_log1pexp(0.)  # => [0.5]
+
+# However, x = 100 fails because of numerical instability.
+grad_log1pexp(100.)  # => [nan]
+```
+
+Here, the `log1pexp` function can be analytically simplified with a custom
+gradient. The implementation below reuses the value for `tf.exp(x)` that is
+computed during the forward pass—making it more efficient by eliminating
+redundant calculations:
+
+```py
+@tf.custom_gradient
+def log1pexp(x):
+  e = tf.exp(x)
+  def grad(dy):
+    return dy * (1 - 1 / (1 + e))
+  return tf.log(1 + e), grad
+
+grad_log1pexp = tfe.gradients_function(log1pexp)
+
+# As before, the gradient computation works fine at x = 0.
+grad_log1pexp(0.)  # => [0.5]
+
+# And the gradient computation also works at x = 100.
+grad_log1pexp(100.)  # => [1.0]
+```
+
 ## Performance
 
-Computation is not automatically offloaded to GPUs during eager execution. To
-explicitly direct a computation to a GPU, enclose it in a
-`tf.device('/gpu:0')` block:
+Computation is automatically offloaded to GPUs during eager execution. If you
+want control over where a computation runs, you can enclose it in a
+`tf.device('/gpu:0')` block (or the CPU equivalent):
 
 ```py
 import time
-- 
2.7.4