Refactors the eager guide to be more researcher-friendly.
authorAlexandre Passos <apassos@google.com>
Wed, 4 Apr 2018 00:35:54 +0000 (17:35 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 4 Apr 2018 00:38:30 +0000 (17:38 -0700)
 * Shows how to build layers
 * Orders the topics as they'd be ordered in a normal model

PiperOrigin-RevId: 191526275

tensorflow/docs_src/programmers_guide/eager.md

index 414653c..dc5b403 100644 (file)
@@ -109,9 +109,106 @@ environments and is useful for writing code to [work with graphs](#work_with_gra
 import tensorflow.contrib.eager as tfe
 ```
 
-## Updating model parameters
+## Dynamic control flow
 
-### Automatic differentiation
+A major benefit of eager execution is that all the functionality of the host
+language is available while your model is executing. So, for example,
+it is easy to write [fizzbuzz](https://en.wikipedia.org/wiki/Fizz_buzz):
+
+```py
+def fizzbuzz(max_num):
+  counter = tf.constant(0)
+  for num in range(max_num):
+    num = tf.constant(num)
+    if num % 3 == 0 and num % 5 == 0:
+      print('FizzBuzz')
+    elif num % 3 == 0:
+      print('Fizz')
+    elif num % 5 == 0:
+      print('Buzz')
+    else:
+      print(num)
+    counter += 1
+  return counter
+```
+
+This has conditionals that depend on tensor values and it prints these values
+at runtime.
+
+## Build a model
+
+Many machine learning models are represented by composing layers. When
+using TensorFlow with eager execution you can either write your own layers or
+use a layer provided in the `tf.keras.layers` package.
+
+While you can use any Python object to represent a layer,
+TensorFlow has `tf.keras.layers.Layer` as a convenient base class. Inherit from
+it to implement your own layer:
+
+```py
+class MySimpleLayer(tf.keras.layers.Layer):
+  def __init__(self, output_units):
+    self.output_units = output_units
+
+  def build(self, input):
+    # The build method gets called the first time your layer is used.
+    # Creating variables on build() allows you to make their shape depend
+    # on the input shape and hence remove the need for the user to specify
+    # full shapes. It is possible to create variables during __init__() if
+    # you already know their full shapes.
+    self.kernel = self.add_variable(
+      "kernel", [input.shape[-1], self.output_units])
+
+  def call(self, input):
+    # Override call() instead of __call__ so we can perform some bookkeeping.
+    return tf.matmul(input, self.kernel)
+```
+
+Use `tf.keras.layers.Dense` layer instead  of `MySimpleLayer` above as it has
+a superset of its functionality (it can also add a bias).
+
+When composing layers into models you can use `tf.keras.Sequential` to represent
+models which are a linear stack of layers. It is easy to use for basic models:
+
+```py
+model = tf.keras.Sequential([
+  tf.keras.layers.Dense(10, input_shape=(784,)),  # must declare input shape
+  tf.keras.layers.Dense(10)
+])
+```
+
+Alternatively, organize models in classes by inheriting from `tf.keras.Model`.
+This is a container for layers that is a layer itself, allowing `tf.keras.Model`
+objects to contain other `tf.keras.Model` objects.
+
+```py
+class MNISTModel(tf.keras.Model):
+  def __init__(self):
+    super(MNISTModel, self).__init__()
+    self.dense1 = tf.keras.layers.Dense(units=10)
+    self.dense2 = tf.keras.layers.Dense(units=10)
+
+  def call(self, input):
+    """Run the model."""
+    result = self.dense1(input)
+    result = self.dense2(result)
+    result = self.dense2(result)  # reuse variables from dense2 layer
+    return result
+
+model = MNISTModel()
+```
+
+It's not required to set an input shape for the `tf.keras.Model` class since
+the parameters are set the first time input is passed to the layer.
+
+`tf.keras.layers` classes create and contain their own model variables that
+are tied to the lifetime of their layer objects. To share layer variables, share
+their objects.
+
+
+## Eager training
+
+### Computing gradients
 
 [Automatic differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation)
 is useful for implementing machine learning algorithms such as
@@ -215,189 +312,12 @@ for (batch, (images, labels)) in enumerate(tfe.Iterator(dataset)):
                             global_step=tf.train.get_or_create_global_step())
 ```
 
-#### Dynamic models
-
-`tfe.GradientTape` can also be used in dynamic models. This example for a
-[backtracking line search](https://wikipedia.org/wiki/Backtracking_line_search)
-algorithm looks like normal NumPy code, except there are gradients and is
-differentiable, despite the complex control flow:
-
-```py
-def line_search_step(fn, init_x, rate=1.0):
-  with tfe.GradientTape() as tape:
-    # Variables are automatically recorded, but manually watch a tensor
-    tape.watch(init_x)
-    value = fn(init_x)
-  grad, = tape.gradient(value, [init_x])
-  grad_norm = tf.reduce_sum(grad * grad)
-  init_value = value
-  while value > init_value - rate * grad_norm:
-    x = init_x - rate * grad
-    value = fn(x)
-    rate /= 2.0
-  return x, value
-```
-
-#### Additional functions to compute gradients
-
-`tfe.GradientTape` is a powerful interface for computing gradients, but there
-is another [Autograd](https://github.com/HIPS/autograd)-style API available for
-automatic differentiation. These functions are useful if writing math code with
-only tensors and gradient functions, and without `tfe.Variables`:
-
-* `tfe.gradients_function` —Returns a function that computes the derivatives
-  of its input function parameter with respect to its arguments. The input
-  function parameter must return a scalar value. When the returned function is
-  invoked, it returns a list of `tf.Tensor` objects: one element for each
-  argument of the input function. Since anything of interest must be passed as a
-  function parameter, this becomes unwieldy if there's a dependency on many
-  trainable parameters.
-* `tfe.value_and_gradients_function` —Similar to
-  `tfe.gradients_function`, but when the returned function is invoked, it
-  returns the value from the input function in addition to the list of
-  derivatives of the input function with respect to its arguments.
-
-In the following example, `tfe.gradients_function` takes the `square`
-function as an argument and returns a function that computes the partial
-derivatives of `square` with respect to its inputs. To calculate the derivative
-of `square` at `3`, `grad(3.0)` returns `6`.
-
-```py
-def square(x):
-  return tf.multiply(x, x)
-
-grad = tfe.gradients_function(square)
-
-square(3.)  # => 9.0
-grad(3.)    # => [6.0]
-
-# The second-order derivative of square:
-gradgrad = tfe.gradients_function(lambda x: grad(x)[0])
-gradgrad(3.)  # => [2.0]
-
-# The third-order derivative is None:
-gradgradgrad = tfe.gradients_function(lambda x: gradgrad(x)[0])
-gradgradgrad(3.)  # => [None]
-
-
-# With flow control:
-def abs(x):
-  return x if x > 0. else -x
-
-grad = tfe.gradients_function(abs)
-
-grad(3.)   # => [1.0]
-grad(-3.)  # => [-1.0]
-```
-
-### Custom gradients
-
-Custom gradients are an easy way to override gradients in eager and graph
-execution. Within the forward function, define the gradient with respect to the
-inputs, outputs, or intermediate results. For example, here's an easy way to clip
-the norm of the gradients in the backward pass:
-
-```py
-@tf.custom_gradient
-def clip_gradient_by_norm(x, norm):
-  y = tf.identity(x)
-  def grad_fn(dresult):
-    return [tf.clip_by_norm(dresult, norm), None]
-  return y, grad_fn
-```
-
-Custom gradients are commonly used to provide a numerically stable gradient for a
-sequence of operations:
-
-```py
-def log1pexp(x):
-  return tf.log(1 + tf.exp(x))
-grad_log1pexp = tfe.gradients_function(log1pexp)
-
-# The gradient computation works fine at x = 0.
-grad_log1pexp(0.)  # => [0.5]
-
-# However, x = 100 fails because of numerical instability.
-grad_log1pexp(100.)  # => [nan]
-```
-
-Here, the `log1pexp` function can be analytically simplified with a custom
-gradient. The implementation below reuses the value for `tf.exp(x)` that is
-computed during the forward pass—making it more efficient by eliminating
-redundant calculations:
-
-```py
-@tf.custom_gradient
-def log1pexp(x):
-  e = tf.exp(x)
-  def grad(dy):
-    return dy * (1 - 1 / (1 + e))
-  return tf.log(1 + e), grad
-
-grad_log1pexp = tfe.gradients_function(log1pexp)
-
-# As before, the gradient computation works fine at x = 0.
-grad_log1pexp(0.)  # => [0.5]
-
-# And the gradient computation also works at x = 100.
-grad_log1pexp(100.)  # => [1.0]
-```
-
-
-## Build and train models
-
-There are many parameters to optimize when calculating derivatives. TensorFlow
-code is easier to read when structured into reusable classes and objects instead
-of a single top-level function. Eager execution encourages the use of the
-Keras-style layer classes in the `tf.keras.layers` module. Additionally, the
-`tf.train.Optimizer` classes provide sophisticated techniques to calculate
-parameter updates.
 
 The following example creates a multi-layer model that classifies the standard
 [MNIST handwritten digits](https://www.tensorflow.org/tutorials/layers). It
 demonstrates the optimizer and layer APIs to build trainable graphs in an eager
 execution environment.
 
-### Build a model
-
-The `tf.keras.Sequential` model is a linear stack of layers. It is easy to
-use for basic models:
-
-```py
-model = tf.keras.Sequential([
-  tf.keras.layers.Dense(10, input_shape=(784,)),  # must declare input shape
-  tf.keras.layers.Dense(10)
-])
-```
-
-Alternatively, organize models in classes by inheriting from `tf.keras.Model`.
-This is a container for layers that is a layer itself, allowing `tf.keras.Model`
-objects to contain other `tf.keras.Model` objects.
-
-```py
-class MNISTModel(tf.keras.Model):
-  def __init__(self):
-    super(MNISTModel, self).__init__()
-    self.dense1 = tf.keras.layers.Dense(units=10)
-    self.dense2 = tf.keras.layers.Dense(units=10)
-
-  def call(self, input):
-    """Run the model."""
-    result = self.dense1(input)
-    result = self.dense2(result)
-    result = self.dense2(result)  # reuse variables from dense2 layer
-    return result
-
-model = MNISTModel()
-```
-
-It's not required to set an input shape for the `tf.keras.Model` class since
-the parameters are set the first time input is passed to the layer.
-
-`tf.keras.layers` classes create and contain their own model variables that
-are tied to the lifetime of their layer objects. To share layer variables, share
-their objects.
-
 ### Train a model
 
 Even without training, call the model and inspect the output in eager execution:
@@ -661,11 +581,141 @@ for _ in range(iterations):
      ...
 ```
 
+## Advanced automatic differentiation topics
+
+### Dynamic models
+
+`tfe.GradientTape` can also be used in dynamic models. This example for a
+[backtracking line search](https://wikipedia.org/wiki/Backtracking_line_search)
+algorithm looks like normal NumPy code, except there are gradients and is
+differentiable, despite the complex control flow:
+
+```py
+def line_search_step(fn, init_x, rate=1.0):
+  with tfe.GradientTape() as tape:
+    # Variables are automatically recorded, but manually watch a tensor
+    tape.watch(init_x)
+    value = fn(init_x)
+  grad, = tape.gradient(value, [init_x])
+  grad_norm = tf.reduce_sum(grad * grad)
+  init_value = value
+  while value > init_value - rate * grad_norm:
+    x = init_x - rate * grad
+    value = fn(x)
+    rate /= 2.0
+  return x, value
+```
+
+### Additional functions to compute gradients
+
+`tfe.GradientTape` is a powerful interface for computing gradients, but there
+is another [Autograd](https://github.com/HIPS/autograd)-style API available for
+automatic differentiation. These functions are useful if writing math code with
+only tensors and gradient functions, and without `tfe.Variables`:
+
+* `tfe.gradients_function` —Returns a function that computes the derivatives
+  of its input function parameter with respect to its arguments. The input
+  function parameter must return a scalar value. When the returned function is
+  invoked, it returns a list of `tf.Tensor` objects: one element for each
+  argument of the input function. Since anything of interest must be passed as a
+  function parameter, this becomes unwieldy if there's a dependency on many
+  trainable parameters.
+* `tfe.value_and_gradients_function` —Similar to
+  `tfe.gradients_function`, but when the returned function is invoked, it
+  returns the value from the input function in addition to the list of
+  derivatives of the input function with respect to its arguments.
+
+In the following example, `tfe.gradients_function` takes the `square`
+function as an argument and returns a function that computes the partial
+derivatives of `square` with respect to its inputs. To calculate the derivative
+of `square` at `3`, `grad(3.0)` returns `6`.
+
+```py
+def square(x):
+  return tf.multiply(x, x)
+
+grad = tfe.gradients_function(square)
+
+square(3.)  # => 9.0
+grad(3.)    # => [6.0]
+
+# The second-order derivative of square:
+gradgrad = tfe.gradients_function(lambda x: grad(x)[0])
+gradgrad(3.)  # => [2.0]
+
+# The third-order derivative is None:
+gradgradgrad = tfe.gradients_function(lambda x: gradgrad(x)[0])
+gradgradgrad(3.)  # => [None]
+
+
+# With flow control:
+def abs(x):
+  return x if x > 0. else -x
+
+grad = tfe.gradients_function(abs)
+
+grad(3.)   # => [1.0]
+grad(-3.)  # => [-1.0]
+```
+
+### Custom gradients
+
+Custom gradients are an easy way to override gradients in eager and graph
+execution. Within the forward function, define the gradient with respect to the
+inputs, outputs, or intermediate results. For example, here's an easy way to clip
+the norm of the gradients in the backward pass:
+
+```py
+@tf.custom_gradient
+def clip_gradient_by_norm(x, norm):
+  y = tf.identity(x)
+  def grad_fn(dresult):
+    return [tf.clip_by_norm(dresult, norm), None]
+  return y, grad_fn
+```
+
+Custom gradients are commonly used to provide a numerically stable gradient for a
+sequence of operations:
+
+```py
+def log1pexp(x):
+  return tf.log(1 + tf.exp(x))
+grad_log1pexp = tfe.gradients_function(log1pexp)
+
+# The gradient computation works fine at x = 0.
+grad_log1pexp(0.)  # => [0.5]
+
+# However, x = 100 fails because of numerical instability.
+grad_log1pexp(100.)  # => [nan]
+```
+
+Here, the `log1pexp` function can be analytically simplified with a custom
+gradient. The implementation below reuses the value for `tf.exp(x)` that is
+computed during the forward pass—making it more efficient by eliminating
+redundant calculations:
+
+```py
+@tf.custom_gradient
+def log1pexp(x):
+  e = tf.exp(x)
+  def grad(dy):
+    return dy * (1 - 1 / (1 + e))
+  return tf.log(1 + e), grad
+
+grad_log1pexp = tfe.gradients_function(log1pexp)
+
+# As before, the gradient computation works fine at x = 0.
+grad_log1pexp(0.)  # => [0.5]
+
+# And the gradient computation also works at x = 100.
+grad_log1pexp(100.)  # => [1.0]
+```
+
 ## Performance
 
-Computation is not automatically offloaded to GPUs during eager execution. To
-explicitly direct a computation to a GPU, enclose it in a
-`tf.device('/gpu:0')` block:
+Computation is automatically offloaded to GPUs during eager execution. If you
+want control over where a computation runs you can enclose it in a
+`tf.device('/gpu:0')` block (or the CPU equivalent):
 
 ```py
 import time