Add snippet illustrating discretized logistic mixture for WaveNet.
authorDustin Tran <trandustin@google.com>
Mon, 30 Apr 2018 18:14:51 +0000 (11:14 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 30 Apr 2018 18:17:51 +0000 (11:17 -0700)
Currently, the example manually centers the bins in order to capture ?rounding? intervals and not ?ceiling? intervals. In the future, we may simplify the example by expanding QuantizedDistribution with a binning argument.

PiperOrigin-RevId: 194814662

tensorflow/contrib/distributions/python/ops/quantized_distribution.py

index 1ef7651..eb94760 100644 (file)
@@ -128,7 +128,7 @@ The base distribution's `log_cdf` method must be defined on `y - 1`.
 class QuantizedDistribution(distributions.Distribution):
   """Distribution representing the quantization `Y = ceiling(X)`.
 
-  #### Definition in terms of sampling.
+  #### Definition in Terms of Sampling
 
   ```
   1. Draw X
@@ -138,7 +138,7 @@ class QuantizedDistribution(distributions.Distribution):
   5. Return Y
   ```
 
-  #### Definition in terms of the probability mass function.
+  #### Definition in Terms of the Probability Mass Function
 
   Given scalar random variable `X`, we define a discrete random variable `Y`
   supported on the integers as follows:
@@ -170,12 +170,62 @@ class QuantizedDistribution(distributions.Distribution):
 
   `P[Y = j]` is still the mass of `X` within the `jth` interval.
 
-  #### Caveats
+  #### Examples
+
+  We illustrate a mixture of discretized logistic distributions
+  [(Salimans et al., 2017)][1]. This is used, for example, for capturing 16-bit
+  audio in WaveNet [(van den Oord et al., 2017)][2]. The values range in
+  a 1-D integer domain of `[0, 2**16-1]`, and the discretization captures
+  `P(x - 0.5 < X <= x + 0.5)` for all `x` in the domain excluding the endpoints.
+  The lowest value has probability `P(X <= 0.5)` and the highest value has
+  probability `P(2**16 - 1.5 < X)`.
+
+  Below we assume a `wavenet` function. It takes as `input` right-shifted audio
+  samples of shape `[..., sequence_length]`. It returns a real-valued tensor of
+  shape `[..., num_mixtures * 3]`, i.e., each mixture component has a `loc` and
+  `scale` parameter belonging to the logistic distribution, and a `logits`
+  parameter determining the unnormalized probability of that component.
+
+  ```python
+  tfd = tf.contrib.distributions
+  tfb = tfd.bijectors
+
+  net = wavenet(inputs)
+  loc, unconstrained_scale, logits = tf.split(net,
+                                              num_or_size_splits=3,
+                                              axis=-1)
+  scale = tf.nn.softplus(unconstrained_scale)
+
+  # Form mixture of discretized logistic distributions. Note we shift the
+  # logistic distribution by -0.5. This lets the quantization capture "rounding"
+  # intervals, `(x-0.5, x+0.5]`, and not "ceiling" intervals, `(x-1, x]`.
+  discretized_logistic_dist = tfd.QuantizedDistribution(
+      distribution=tfd.TransformedDistribution(
+          distribution=tfd.Logistic(loc=loc, scale=scale),
+          bijector=tfb.AffineScalar(shift=-0.5)),
+      low=0.,
+      high=2**16 - 1.)
+  mixture_dist = tfd.MixtureSameFamily(
+      mixture_distribution=tfd.Categorical(logits=logits),
+      components_distribution=discretized_logistic_dist)
+
+  neg_log_likelihood = -tf.reduce_sum(mixture_dist.log_prob(targets))
+  train_op = tf.train.AdamOptimizer().minimize(neg_log_likelihood)
+  ```
+
+  After instantiating `mixture_dist`, we illustrate maximum likelihood by
+  calculating its log-probability of audio samples as `target` and optimizing.
+
+  #### References
 
-  Since evaluation of each `P[Y = j]` involves a cdf evaluation (rather than
-  a closed form function such as for a Poisson), computations such as mean and
-  entropy are better done with samples or approximations, and are not
-  implemented by this class.
+  [1]: Tim Salimans, Andrej Karpathy, Xi Chen, and Diederik P. Kingma.
+       PixelCNN++: Improving the PixelCNN with discretized logistic mixture
+       likelihood and other modifications.
+       _International Conference on Learning Representations_, 2017.
+       https://arxiv.org/abs/1701.05517
+  [2]: Aaron van den Oord et al. Parallel WaveNet: Fast High-Fidelity Speech
+       Synthesis. _arXiv preprint arXiv:1711.10433_, 2017.
+       https://arxiv.org/abs/1711.10433
   """
 
   def __init__(self,