Cleanup gumbel_softmax (#13339)
author Egil Martinsson <egil.martinsson@gmail.com>
Thu, 17 Jan 2019 20:14:39 +0000 (12:14 -0800)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 17 Jan 2019 20:56:35 +0000 (12:56 -0800)
Summary:
Fixes #12643; follows up on #3341.

- Allow multidimensional input ~~(but apply softmax over `dim=-1`)~~ with a `dim` argument (see the usage sketch right after this list)
- Cleaner: fewer lines of code
- Faster (1.32x speedup vs. the original, 2x speedup vs. using `torch.distributions`)
- Small fixes to the docstring
- Remove some references from the docstring. Was the linked (excellent) ipynb really the first to use the straight-through trick? I propose referencing the two papers best known for it instead.
- Add a deprecation warning for `eps`; it is no longer needed.
- The initial commit keeps some code alternatives commented out to take advantage of CI.
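As a usage sketch of the new `dim` argument (the shapes below are just illustrative):

```
import torch
import torch.nn.functional as F

# [batch, seq, num_features]: one one-hot draw per (batch, seq) position
logits = torch.randn(5, 4, 3)
y_hard = F.gumbel_softmax(logits, tau=1, hard=True, dim=-1)
assert y_hard.shape == logits.shape
assert int(y_hard.sum().item()) == 5 * 4
```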

- As discussed when `gumbel_softmax` was added (#3341), it was merged into `torch.nn.functional` before the work on `Distributions` and `Pyro`, and there will probably be several other best practices for this in the future.
I tested building it on top of the `Distributions` API, but it was too slow; see the timings below.

I therefore propose not using `Distributions`, to keep it fast and simple, and instead adding a note to the docstring that `gumbel_softmax` may be deprecated in the future. For reference, the `Distributions`-based fragment I benchmarked:

```
dist = torch.distributions.RelaxedOneHotCategorical(temperature=tau, logits=logits, validate_args=False)
y_soft = dist.rsample()
```

Pros:
* Built on numerically stable primitives such as `logsumexp`
* Explicitly uses `torch.distributions.utils._finfo` to avoid overflow (the old implementation needed an `eps` argument for this)
* Maintained for exactly this purpose.

Cons:
* Very slow: constructing the distribution adds overhead (see the timings below). This may improve in the future with speedups to `TransformedDistribution` and `Distribution`.
* Hard-codes which `dim` the softmax is applied over (always the last dim of `logits`).
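For easier reproduction, a self-contained version of that fragment (a sketch; `tau` and `logits` are placeholders chosen here for illustration):

```
import torch
from torch.distributions import RelaxedOneHotCategorical

tau = 1.0
logits = torch.randn(1, 3)

# Reparameterized soft sample; the categorical (event) dim is always the last dim of `logits`.
dist = RelaxedOneHotCategorical(temperature=torch.tensor(tau), logits=logits, validate_args=False)
y_soft = dist.rsample()
```

The proposed implementation, by contrast, samples the Gumbel noise directly and lets the caller pick `dim`: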

```
y_soft = logits.new(logits.shape)                      # uninitialized tensor with logits' dtype/device
y_soft = (logits - y_soft.exponential_().log()) / tau  # (logits + Gumbel(0, 1) noise) / tau
y_soft = y_soft.softmax(dim)                           # Gumbel-softmax (soft) sample over `dim`
```
Pros:
* Faster (1.32x vs. the original implementation; see the timings below)
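To make the proposal concrete, here is a condensed, self-contained sketch combining the soft sample above with the straight-through branch from the patch (the helper name `gumbel_softmax_sketch` is only for illustration):

```
import torch

def gumbel_softmax_sketch(logits, tau=1.0, hard=False, dim=-1):
    # ~Gumbel(0, 1): -log(E) with E ~ Exponential(1)
    gumbels = -torch.empty_like(logits).exponential_().log()
    y_soft = ((logits + gumbels) / tau).softmax(dim)
    if not hard:
        # Reparametrization trick: differentiable soft sample.
        return y_soft
    # Straight-through: one-hot value in the forward pass,
    # gradient of the soft sample in the backward pass.
    index = y_soft.max(dim, keepdim=True)[1]
    y_hard = torch.zeros_like(logits).scatter_(dim, index, 1.0)
    return y_hard - y_soft.detach() + y_soft
```

The timing loop below uses the `hard=True` path: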

```
import time
import torch
import torch.nn.functional as F

num_draws = 1000000
logits = torch.randn(1, 3)
counts = torch.zeros_like(logits)

start = time.time()
for draw in range(num_draws):
    y_draw = F.gumbel_softmax(logits, hard=True)
    counts = counts + y_draw
end = time.time()
print(end - start)

>> 12.995795965194702

>> 7.658372640609741

>> 20.3382670879364
```

Please decide which path to choose. I'll commit changes to the unit tests shortly to show that both the old and the new tests pass. I'll also remove the commented-out code for `RelaxedOneHotCategorical`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13339

Differential Revision: D13092434

Pulled By: ezyang

fbshipit-source-id: 4c21788df336f4e9c2ac289022e395b261227b4b

test/test_nn.py
torch/nn/functional.py

diff --git a/test/test_nn.py b/test/test_nn.py
index 2722635..cc966d0 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -2140,64 +2140,90 @@ class TestNN(NNTestCase):
         # should be bitwise equal
         self.assertEqual(input.grad, inputf.grad.to(dtype), prec=0)
 
-    def _test_gumbel_softmax_st(self, cuda, dtype=torch.float):
-        th = torch.cuda if cuda else torch
-        """
-        Things we might want to check:
-        - if we make various draws, do we get different one-hot values?
-        - is the proportion approximately in line with the softmax values?
-        - with hard, is it one-hot?
-        - with hard, is there still a gradient?
-        """
+    def _test_gumbel_softmax_st_shapes(self, cuda, dtype, shape, dim, count_expected):
+        logits = torch.randn(shape, dtype=torch.float)
+        logits = logits.to(dtype)
+        if cuda:
+            logits = logits.cuda()
+
+        y_draw = F.gumbel_softmax(logits, hard=True, dim=dim)
+
+        # All values positive
+        self.assertGreaterEqual(y_draw.min(), 0)
+        # Shape unchanged
+        self.assertTrue(y_draw.shape == logits.shape)
+        # One choice per draw
+        self.assertEqual(y_draw.sum(), count_expected, prec=torch.finfo(y_draw.dtype).eps)
+
+    def _test_gumbel_softmax_straight_through(self, cuda, dtype):
         num_draws = 100
-        K = 3
-        logits = torch.tensor([[0.2, 0.8, 0.1]])
-        if dtype != torch.half:
-            logits = logits.to(dtype)
-        logits_softmax = torch.nn.functional.softmax(logits, 1)
-        y_draws = torch.zeros(num_draws, K)
-        preds = torch.zeros(num_draws)
 
+        logits = torch.tensor([[0.2, 0.8, 0.1]])
+        logits = logits.reshape([1, 3])
+        logits = logits.to(dtype).requires_grad_()
         if cuda:
             logits = logits.cuda()
-            y_draws = y_draws.cuda()
-            preds = preds.cuda()
+        probs = logits.softmax(dim=-1)
 
-        exceed_limits = 0
+        counts = torch.zeros_like(logits)
         for draw in range(num_draws):
-            logits_var = logits.detach().requires_grad_()
-            y_draw = torch.nn.functional.gumbel_softmax(
-                logits_var,
-                hard=True)
-            assert y_draw.size() == logits.size()
-            # check we have a gradient
-            assert y_draw.requires_grad
-            err = y_draw - logits.new_tensor([[0, 0.5, 0.3]])
-            loss = (err * err).sum()
-            loss.backward()
-            if logits_var.grad.std() < 0.01 or logits_var.grad.std() > 1.0:
-                exceed_limits += 1
-            y_draws[draw] = y_draw.data
-            _, pred = y_draw.max(1)
-            preds[draw] = pred.data[0]
-        assert exceed_limits / num_draws < 0.05
-        # check it's approximately one-hot
-        num_ones = (y_draws == 1).int().sum()
-        num_zeros = (y_draws == 0).int().sum()
-        assert num_ones + num_zeros == num_draws * K
-        assert num_ones == num_draws
-        # check output classes approx in line with logits
-        num_class_one = (preds == 1).int().sum()
-        assert num_class_one < num_draws
-        assert num_class_one > num_draws / 3
-
-    def test_gumbel_softmax_st(self):
-        self._test_gumbel_softmax_st(False)
+            y_draw = F.gumbel_softmax(logits, hard=True)
+            counts = counts + y_draw
+
+        # All values positive
+        self.assertGreaterEqual(y_draw.min(), 0)
+        # Each experiment should result in 1 draw.
+        self.assertEqual(counts.sum(), num_draws, prec=torch.finfo(counts.dtype).eps)
+
+        # check results is asymptotically as expected.
+        expected = probs * num_draws
+        # ~z is approximately N(0,1) for unbiased count
+        z = (counts - expected) / (expected * (1 - probs)).sqrt()
+        # A (lazy) approximate 99% two-sided test:
+        # occurs with prob alpha~>=0.01 if unbiased
+        self.assertLess(z.abs().max().item(), 2.58)
+
+    def _test_gumbel_softmax_grad(self, cuda, dtype):
+        # "hard" and "not hard" should propagate same gradient.
+        device = torch.device("cuda") if cuda else torch.device("cpu")
+        logits_soft = torch.zeros(10, 10, dtype=dtype, device=device, requires_grad=True)
+        logits_hard = torch.zeros(10, 10, dtype=dtype, device=device, requires_grad=True)
+
+        seed = torch.random.get_rng_state()
+        y_soft = F.gumbel_softmax(logits_soft, hard=False)
+        torch.random.set_rng_state(seed)
+        y_hard = F.gumbel_softmax(logits_hard, hard=True)
+
+        y_soft.sum().backward()
+        y_hard.sum().backward()
+
+        # 2eps = 1x addition + 1x subtraction.
+        tol = 2 * torch.finfo(dtype).eps
+        self.assertAlmostEqual(logits_soft.grad, logits_hard.grad, delta=tol)
+
+    @repeat_test_for_types(NO_HALF_TENSORTYPES)
+    def test_gumbel_softmax(self, dtype=torch.float):
+        """
+        NO_HALF_TENSORTYPES because many half-ops doesnt work on cpu.
+        """
+        self._test_gumbel_softmax_st_shapes(cuda=False, dtype=dtype, shape=[5], dim=0, count_expected=1)
+        self._test_gumbel_softmax_st_shapes(cuda=False, dtype=dtype, shape=[5], dim=-1, count_expected=1)
+        self._test_gumbel_softmax_st_shapes(cuda=False, dtype=dtype, shape=[5, 4], dim=1, count_expected=5)
+        self._test_gumbel_softmax_st_shapes(cuda=False, dtype=dtype, shape=[5, 4, 3], dim=1, count_expected=5 * 3)
+        self._test_gumbel_softmax_st_shapes(cuda=False, dtype=dtype, shape=[5, 4, 3], dim=-1, count_expected=5 * 4)
+        self._test_gumbel_softmax_straight_through(cuda=False, dtype=dtype)
+        self._test_gumbel_softmax_grad(cuda=False, dtype=dtype)
 
     @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
     @repeat_test_for_types(ALL_TENSORTYPES)
-    def test_gumbel_softmax_st_cuda(self, dtype=torch.float):
-        self._test_gumbel_softmax_st(True, dtype=dtype)
+    def test_gumbel_softmax_cuda(self, dtype=torch.float):
+        self._test_gumbel_softmax_st_shapes(cuda=True, dtype=dtype, shape=[5], dim=0, count_expected=1)
+        self._test_gumbel_softmax_st_shapes(cuda=True, dtype=dtype, shape=[5], dim=-1, count_expected=1)
+        self._test_gumbel_softmax_st_shapes(cuda=True, dtype=dtype, shape=[5, 4], dim=1, count_expected=5)
+        self._test_gumbel_softmax_st_shapes(cuda=True, dtype=dtype, shape=[5, 4, 3], dim=1, count_expected=5 * 3)
+        self._test_gumbel_softmax_st_shapes(cuda=True, dtype=dtype, shape=[5, 4, 3], dim=-1, count_expected=5 * 4)
+        self._test_gumbel_softmax_straight_through(cuda=True, dtype=dtype)
+        self._test_gumbel_softmax_grad(cuda=True, dtype=dtype)
 
     def _test_EmbeddingBag(self, cuda, mode, sparse, dtype=torch.double):
         # check a known test example
diff --git a/torch/nn/functional.py b/torch/nn/functional.py
index 36d50fe..485e4bf 100644
--- a/torch/nn/functional.py
+++ b/torch/nn/functional.py
@@ -1262,80 +1262,63 @@ def softmax(input, dim=None, _stacklevel=3, dtype=None):
 
 
 @weak_script
-def _sample_gumbel(shape, eps=1e-10, out=None):
-    # type: (List[int], float, Optional[Tensor]) -> Tensor
-    """
-    Sample from Gumbel(0, 1)
-
-    based on
-    https://github.com/ericjang/gumbel-softmax/blob/3c8584924603869e90ca74ac20a6a03d99a91ef9/Categorical%20VAE.ipynb ,
-    (MIT license)
-    """
-    if out is None:
-        U = torch.rand(shape)
-    else:
-        U = torch.jit._unwrap_optional(out).resize_(shape).uniform_()
-    return - torch.log(eps - torch.log(U + eps))
-
-
-@weak_script
-def _gumbel_softmax_sample(logits, tau=1, eps=1e-10):
-    # type: (Tensor, float, float) -> Tensor
-    """
-    Draw a sample from the Gumbel-Softmax distribution
-
-    based on
-    https://github.com/ericjang/gumbel-softmax/blob/3c8584924603869e90ca74ac20a6a03d99a91ef9/Categorical%20VAE.ipynb
-    (MIT license)
-    """
-    dims = logits.dim()
-    gumbel_noise = _sample_gumbel(logits.size(), eps=eps, out=torch.empty_like(logits))
-    y = logits + gumbel_noise
-    return softmax(y / tau, dims - 1)
-
-
-@weak_script
-def gumbel_softmax(logits, tau=1., hard=False, eps=1e-10):
-    # type: (Tensor, float, bool, float) -> Tensor
+def gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1):
+    # type: (Tensor, float, bool, float, int) -> Tensor
     r"""
-    Sample from the Gumbel-Softmax distribution and optionally discretize.
+    Samples from the `Gumbel-Softmax distribution`_ and optionally discretizes.
 
     Args:
-      logits: `[batch_size, num_features]` unnormalized log probabilities
+      logits: `[..., num_features]` unnormalized log probabilities
       tau: non-negative scalar temperature
       hard: if ``True``, the returned samples will be discretized as one-hot vectors,
             but will be differentiated as if it is the soft sample in autograd
+      dim (int): A dimension along which softmax will be computed. Default: -1.
 
     Returns:
-      Sampled tensor of shape ``batch_size x num_features`` from the Gumbel-Softmax distribution.
+      Sampled tensor of same shape as `logits` from the Gumbel-Softmax distribution.
       If ``hard=True``, the returned samples will be one-hot, otherwise they will
-      be probability distributions that sum to 1 across features
+      be probability distributions that sum to 1 across `dim`.
+
+    .. note::
+      This function is here for legacy reasons, may be removed from nn.Functional in the future.
 
-    Constraints:
+    .. note::
+      The main trick for `hard` is to do  `y_hard - y_soft.detach() + y_soft`
 
-    - Currently only work on 2D input :attr:`logits` tensor of shape ``batch_size x num_features``
+      It achieves two things:
+      - makes the output value exactly one-hot
+      (since we add then subtract y_soft value)
+      - makes the gradient equal to y_soft gradient
+      (since we strip all other gradients)
 
-    Based on
-    https://github.com/ericjang/gumbel-softmax/blob/3c8584924603869e90ca74ac20a6a03d99a91ef9/Categorical%20VAE.ipynb ,
-    (MIT license)
+    Examples::
+        >>> logits = torch.randn(20, 32)
+        >>> # Sample soft categorical using reparametrization trick:
+        >>> F.gumbel_softmax(logits, tau=1, hard=False)
+        >>> # Sample hard categorical using "Straight-through" trick:
+        >>> F.gumbel_softmax(logits, tau=1, hard=True)
+
+    .. _Gumbel-Softmax distribution:
+        https://arxiv.org/abs/1611.00712
+        https://arxiv.org/abs/1611.01144
     """
-    shape = logits.size()
-    assert len(shape) == 2
-    y_soft = _gumbel_softmax_sample(logits, tau=tau, eps=eps)
+
+    if eps != 1e-10:
+        warnings.warn("`eps` parameter is deprecated and has no effect.")
+
+    gumbels = -torch.empty_like(logits).exponential_().log()  # ~Gumbel(0,1)
+    gumbels = (logits + gumbels) / tau  # ~Gumbel(logits,tau)
+    y_soft = gumbels.softmax(dim)
+
     if hard:
-        _, k = y_soft.max(-1)
-        # this bit is based on
-        # https://discuss.pytorch.org/t/stop-gradients-for-st-gumbel-softmax/530/5
-        y_hard = torch.zeros(shape, dtype=logits.dtype, device=logits.device).scatter_(-1, k.view(-1, 1), 1.0)
-        # this cool bit of code achieves two things:
-        # - makes the output value exactly one-hot (since we add then
-        #   subtract y_soft value)
-        # - makes the gradient equal to y_soft gradient (since we strip
-        #   all other gradients)
-        y = y_hard - y_soft.detach() + y_soft
+        # Straight through.
+        index = y_soft.max(dim, keepdim=True)[1]
+        y_hard = torch.zeros_like(logits).scatter_(dim, index, 1.0)
+        ret = y_hard - y_soft.detach() + y_soft
     else:
-        y = y_soft
-    return y
+        # Reparametrization trick.
+        ret = y_soft
+    return ret
 
 
 @weak_script