# should be bitwise equal
self.assertEqual(input.grad, inputf.grad.to(dtype), prec=0)
- def _test_gumbel_softmax_st(self, cuda, dtype=torch.float):
- th = torch.cuda if cuda else torch
- """
- Things we might want to check:
- - if we make various draws, do we get different one-hot values?
- - is the proportion approximately in line with the softmax values?
- - with hard, is it one-hot?
- - with hard, is there still a gradient?
- """
+ def _test_gumbel_softmax_st_shapes(self, cuda, dtype, shape, dim, count_expected):
+ logits = torch.randn(shape, dtype=torch.float)
+ logits = logits.to(dtype)
+ if cuda:
+ logits = logits.cuda()
+
+ y_draw = F.gumbel_softmax(logits, hard=True, dim=dim)
+
+ # All values non-negative
+ self.assertGreaterEqual(y_draw.min(), 0)
+ # Shape unchanged
+ self.assertTrue(y_draw.shape == logits.shape)
+ # One choice per draw
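+ # (count_expected is the product of all dims except `dim`: each softmax
+ # slice along `dim` selects exactly one class, contributing a single 1)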
+ self.assertEqual(y_draw.sum(), count_expected, prec=torch.finfo(y_draw.dtype).eps)
+
+ def _test_gumbel_softmax_straight_through(self, cuda, dtype):
num_draws = 100
- K = 3
- logits = torch.tensor([[0.2, 0.8, 0.1]])
- if dtype != torch.half:
- logits = logits.to(dtype)
- logits_softmax = torch.nn.functional.softmax(logits, 1)
- y_draws = torch.zeros(num_draws, K)
- preds = torch.zeros(num_draws)
+ logits = torch.tensor([[0.2, 0.8, 0.1]])
+ logits = logits.reshape([1, 3])
+ logits = logits.to(dtype).requires_grad_()
if cuda:
logits = logits.cuda()
- y_draws = y_draws.cuda()
- preds = preds.cuda()
+ probs = logits.softmax(dim=-1)
- exceed_limits = 0
+ counts = torch.zeros_like(logits)
for draw in range(num_draws):
- logits_var = logits.detach().requires_grad_()
- y_draw = torch.nn.functional.gumbel_softmax(
- logits_var,
- hard=True)
- assert y_draw.size() == logits.size()
- # check we have a gradient
- assert y_draw.requires_grad
- err = y_draw - logits.new_tensor([[0, 0.5, 0.3]])
- loss = (err * err).sum()
- loss.backward()
- if logits_var.grad.std() < 0.01 or logits_var.grad.std() > 1.0:
- exceed_limits += 1
- y_draws[draw] = y_draw.data
- _, pred = y_draw.max(1)
- preds[draw] = pred.data[0]
- assert exceed_limits / num_draws < 0.05
- # check it's approximately one-hot
- num_ones = (y_draws == 1).int().sum()
- num_zeros = (y_draws == 0).int().sum()
- assert num_ones + num_zeros == num_draws * K
- assert num_ones == num_draws
- # check output classes approx in line with logits
- num_class_one = (preds == 1).int().sum()
- assert num_class_one < num_draws
- assert num_class_one > num_draws / 3
-
- def test_gumbel_softmax_st(self):
- self._test_gumbel_softmax_st(False)
+ y_draw = F.gumbel_softmax(logits, hard=True)
+ counts = counts + y_draw
+
+ # All values non-negative
+ self.assertGreaterEqual(y_draw.min(), 0)
+ # Each draw is one-hot, so the counts should sum to num_draws.
+ self.assertEqual(counts.sum(), num_draws, prec=torch.finfo(counts.dtype).eps)
+
+ # Check that the observed counts match the expected counts asymptotically.
+ expected = probs * num_draws
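+ # If the sampler is unbiased, each per-class count is Binomial(num_draws, p),
+ # with mean num_draws * p and variance num_draws * p * (1 - p).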
+ # ~z is approximately N(0,1) for unbiased count
+ z = (counts - expected) / (expected * (1 - probs)).sqrt()
+ # A (lazy) approximate two-sided test at the 99% level:
+ # |z| > 2.58 occurs with probability alpha ~= 0.01 if the sampler is unbiased.
+ self.assertLess(z.abs().max().item(), 2.58)
+
+ def _test_gumbel_softmax_grad(self, cuda, dtype):
+ # "hard" and "not hard" should propagate same gradient.
+ device = torch.device("cuda") if cuda else torch.device("cpu")
+ logits_soft = torch.zeros(10, 10, dtype=dtype, device=device, requires_grad=True)
+ logits_hard = torch.zeros(10, 10, dtype=dtype, device=device, requires_grad=True)
+
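+ # Re-using the RNG state is intended to make both calls draw the same
+ # Gumbel noise, so the hard (straight-through) gradient can be compared
+ # against the soft one directly.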
+ seed = torch.random.get_rng_state()
+ y_soft = F.gumbel_softmax(logits_soft, hard=False)
+ torch.random.set_rng_state(seed)
+ y_hard = F.gumbel_softmax(logits_hard, hard=True)
+
+ y_soft.sum().backward()
+ y_hard.sum().backward()
+
+ # 2eps = 1x addition + 1x subtraction.
+ tol = 2 * torch.finfo(dtype).eps
+ self.assertAlmostEqual(logits_soft.grad, logits_hard.grad, delta=tol)
+
+ @repeat_test_for_types(NO_HALF_TENSORTYPES)
+ def test_gumbel_softmax(self, dtype=torch.float):
+ """
+ NO_HALF_TENSORTYPES because many half ops don't work on CPU.
+ """
+ self._test_gumbel_softmax_st_shapes(cuda=False, dtype=dtype, shape=[5], dim=0, count_expected=1)
+ self._test_gumbel_softmax_st_shapes(cuda=False, dtype=dtype, shape=[5], dim=-1, count_expected=1)
+ self._test_gumbel_softmax_st_shapes(cuda=False, dtype=dtype, shape=[5, 4], dim=1, count_expected=5)
+ self._test_gumbel_softmax_st_shapes(cuda=False, dtype=dtype, shape=[5, 4, 3], dim=1, count_expected=5 * 3)
+ self._test_gumbel_softmax_st_shapes(cuda=False, dtype=dtype, shape=[5, 4, 3], dim=-1, count_expected=5 * 4)
+ self._test_gumbel_softmax_straight_through(cuda=False, dtype=dtype)
+ self._test_gumbel_softmax_grad(cuda=False, dtype=dtype)
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
@repeat_test_for_types(ALL_TENSORTYPES)
- def test_gumbel_softmax_st_cuda(self, dtype=torch.float):
- self._test_gumbel_softmax_st(True, dtype=dtype)
+ def test_gumbel_softmax_cuda(self, dtype=torch.float):
+ self._test_gumbel_softmax_st_shapes(cuda=True, dtype=dtype, shape=[5], dim=0, count_expected=1)
+ self._test_gumbel_softmax_st_shapes(cuda=True, dtype=dtype, shape=[5], dim=-1, count_expected=1)
+ self._test_gumbel_softmax_st_shapes(cuda=True, dtype=dtype, shape=[5, 4], dim=1, count_expected=5)
+ self._test_gumbel_softmax_st_shapes(cuda=True, dtype=dtype, shape=[5, 4, 3], dim=1, count_expected=5 * 3)
+ self._test_gumbel_softmax_st_shapes(cuda=True, dtype=dtype, shape=[5, 4, 3], dim=-1, count_expected=5 * 4)
+ self._test_gumbel_softmax_straight_through(cuda=True, dtype=dtype)
+ self._test_gumbel_softmax_grad(cuda=True, dtype=dtype)
def _test_EmbeddingBag(self, cuda, mode, sparse, dtype=torch.double):
# check a known test example
@weak_script
-def _sample_gumbel(shape, eps=1e-10, out=None):
- # type: (List[int], float, Optional[Tensor]) -> Tensor
- """
- Sample from Gumbel(0, 1)
-
- based on
- https://github.com/ericjang/gumbel-softmax/blob/3c8584924603869e90ca74ac20a6a03d99a91ef9/Categorical%20VAE.ipynb ,
- (MIT license)
- """
- if out is None:
- U = torch.rand(shape)
- else:
- U = torch.jit._unwrap_optional(out).resize_(shape).uniform_()
- return - torch.log(eps - torch.log(U + eps))
-
-
-@weak_script
-def _gumbel_softmax_sample(logits, tau=1, eps=1e-10):
- # type: (Tensor, float, float) -> Tensor
- """
- Draw a sample from the Gumbel-Softmax distribution
-
- based on
- https://github.com/ericjang/gumbel-softmax/blob/3c8584924603869e90ca74ac20a6a03d99a91ef9/Categorical%20VAE.ipynb
- (MIT license)
- """
- dims = logits.dim()
- gumbel_noise = _sample_gumbel(logits.size(), eps=eps, out=torch.empty_like(logits))
- y = logits + gumbel_noise
- return softmax(y / tau, dims - 1)
-
-
-@weak_script
-def gumbel_softmax(logits, tau=1., hard=False, eps=1e-10):
- # type: (Tensor, float, bool, float) -> Tensor
+def gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1):
+ # type: (Tensor, float, bool, float, int) -> Tensor
r"""
- Sample from the Gumbel-Softmax distribution and optionally discretize.
+ Samples from the Gumbel-Softmax distribution (`Link 1`_ `Link 2`_) and optionally discretizes.
Args:
- logits: `[batch_size, num_features]` unnormalized log probabilities
+ logits: `[..., num_features]` unnormalized log probabilities
tau: non-negative scalar temperature
hard: if ``True``, the returned samples will be discretized as one-hot vectors,
but will be differentiated as if it is the soft sample in autograd
+ dim (int): A dimension along which softmax will be computed. Default: -1.
Returns:
- Sampled tensor of shape ``batch_size x num_features`` from the Gumbel-Softmax distribution.
+ Sampled tensor of same shape as `logits` from the Gumbel-Softmax distribution.
If ``hard=True``, the returned samples will be one-hot, otherwise they will
- be probability distributions that sum to 1 across features
+ be probability distributions that sum to 1 across `dim`.
+
+ .. note::
+ This function is here for legacy reasons and may be removed from nn.functional in the future.
- Constraints:
+ .. note::
+ The main trick for ``hard`` is to do ``y_hard - y_soft.detach() + y_soft``.
- - Currently only work on 2D input :attr:`logits` tensor of shape ``batch_size x num_features``
+ It achieves two things:
+ - makes the output value exactly one-hot
+ (since we add and then subtract the y_soft value)
+ - makes the gradient equal to the y_soft gradient
+ (since we strip all other gradients)
- Based on
- https://github.com/ericjang/gumbel-softmax/blob/3c8584924603869e90ca74ac20a6a03d99a91ef9/Categorical%20VAE.ipynb ,
- (MIT license)
+ Examples::
+ >>> logits = torch.randn(20, 32)
+ >>> # Sample soft categorical using reparametrization trick:
+ >>> F.gumbel_softmax(logits, tau=1, hard=False)
+ >>> # Sample hard categorical using "Straight-through" trick:
+ >>> F.gumbel_softmax(logits, tau=1, hard=True)
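+ >>> # With hard=True the output is one-hot, but gradients still flow to
+ >>> # logits via the straight-through trick, e.g.:
+ >>> logits = torch.randn(20, 32, requires_grad=True)
+ >>> F.gumbel_softmax(logits, tau=1, hard=True).sum().backward()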
+
+ .. _Link 1:
+ https://arxiv.org/abs/1611.00712
+
+ .. _Link 2:
+ https://arxiv.org/abs/1611.01144
"""
- shape = logits.size()
- assert len(shape) == 2
- y_soft = _gumbel_softmax_sample(logits, tau=tau, eps=eps)
+
+ if eps != 1e-10:
+ warnings.warn("`eps` parameter is deprecated and has no effect.")
+
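+ # If E ~ Exponential(1), then -log(E) ~ Gumbel(0, 1)
+ # (equivalently, -log(-log(U)) for U ~ Uniform(0, 1)).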
+ gumbels = -torch.empty_like(logits).exponential_().log() # ~Gumbel(0,1)
+ gumbels = (logits + gumbels) / tau # ~Gumbel(logits,tau)
+ y_soft = gumbels.softmax(dim)
+
if hard:
- _, k = y_soft.max(-1)
- # this bit is based on
- # https://discuss.pytorch.org/t/stop-gradients-for-st-gumbel-softmax/530/5
- y_hard = torch.zeros(shape, dtype=logits.dtype, device=logits.device).scatter_(-1, k.view(-1, 1), 1.0)
- # this cool bit of code achieves two things:
- # - makes the output value exactly one-hot (since we add then
- # subtract y_soft value)
- # - makes the gradient equal to y_soft gradient (since we strip
- # all other gradients)
- y = y_hard - y_soft.detach() + y_soft
+ # Straight through.
+ index = y_soft.max(dim, keepdim=True)[1]
+ y_hard = torch.zeros_like(logits).scatter_(dim, index, 1.0)
+ ret = y_hard - y_soft.detach() + y_soft
else:
- y = y_soft
- return y
+ # Reparametrization trick.
+ ret = y_soft
+ return ret
@weak_script