Simplify softmax_centered implementation.
author    A. Unique TensorFlower <gardener@tensorflow.org>
Mon, 5 Mar 2018 20:46:30 +0000 (12:46 -0800)
committer TensorFlower Gardener <gardener@tensorflow.org>
Mon, 5 Mar 2018 20:50:46 +0000 (12:50 -0800)
This also resolves a bug where softmax_centered.inverse did not work on inputs
with partially known shape.

PiperOrigin-RevId: 187907026
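
For context, a minimal sketch (TF 1.x graph mode, mirroring the new test below) of what the simplified inverse enables: the ellipsis slicing used in this change builds a valid graph even when the trailing dimension of y is only known at run time, which is the partially-known-shape case mentioned above. The session setup here is illustrative, not part of the commit.

    import tensorflow as tf

    y = tf.placeholder(shape=[2, None], dtype=tf.float32)  # last dimension unknown
    x = tf.log(y)
    # Same slicing as the new _inverse body: x[i] = log(y[i]) - log(y[end]).
    log_normalization = (-x[..., -1])[..., tf.newaxis]
    x = x[..., :-1] + log_normalization

    with tf.Session() as sess:
        real_y = [[0.2, 0.3, 0.4, 0.1], [0.16, 0.32, 0.48, 0.04]]
        print(sess.run(x, feed_dict={y: real_y}))  # ~= log([[2, 3, 4], [4, 8, 12]])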

tensorflow/contrib/distributions/python/kernel_tests/bijectors/softmax_centered_test.py
tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py

diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softmax_centered_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softmax_centered_test.py
index 62e3869..4a7679d 100644
@@ -21,7 +21,9 @@ from __future__ import print_function
 import numpy as np
 
 from tensorflow.contrib.distributions.python.ops.bijectors.softmax_centered import SoftmaxCentered
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
 from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite
 from tensorflow.python.platform import test
 
@@ -76,6 +78,32 @@ class SoftmaxCenteredBijectorTest(test.TestCase):
           atol=0.,
           rtol=1e-7)
 
+  def testBijectorUnknownShape(self):
+    with self.test_session():
+      softmax = SoftmaxCentered(event_ndims=1)
+      self.assertEqual("softmax_centered", softmax.name)
+      x = array_ops.placeholder(shape=[2, None], dtype=dtypes.float32)
+      real_x = np.log([[2., 3, 4], [4., 8, 12]])
+      y = array_ops.placeholder(shape=[2, None], dtype=dtypes.float32)
+      real_y = [[0.2, 0.3, 0.4, 0.1], [0.16, 0.32, 0.48, 0.04]]
+      self.assertAllClose(real_y, softmax.forward(x).eval(
+          feed_dict={x: real_x}))
+      self.assertAllClose(real_x, softmax.inverse(y).eval(
+          feed_dict={y: real_y}))
+      self.assertAllClose(
+          -np.sum(np.log(real_y), axis=1),
+          softmax.inverse_log_det_jacobian(y).eval(
+              feed_dict={y: real_y}),
+          atol=0.,
+          rtol=1e-7)
+      self.assertAllClose(
+          -softmax.inverse_log_det_jacobian(y).eval(
+              feed_dict={y: real_y}),
+          softmax.forward_log_det_jacobian(x).eval(
+              feed_dict={x: real_x}),
+          atol=0.,
+          rtol=1e-7)
+
   def testShapeGetters(self):
     with self.test_session():
       for x, y, b in ((tensor_shape.TensorShape([]),
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py b/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py
index a9dcce6..24add40 100644
@@ -18,8 +18,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import numpy as np
-
 from tensorflow.contrib.distributions.python.ops import distribution_util
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
@@ -161,33 +159,16 @@ class SoftmaxCentered(bijector.Bijector):
     # x[i] = log(exp(x[i])) - log(y[end]) - log(normalization)
     #      = log(exp(x[i])/normalization) - log(y[end])
     #      = log(y[i]) - log(y[end])
-    shape = (np.asarray(y.shape.as_list(), dtype=np.int32)
-             if y.shape.is_fully_defined()
-             else array_ops.shape(y, name="shape"))
-    ndims = distribution_util.prefer_static_rank(y)
 
     # Do this first to make sure CSE catches that it'll happen again in
     # _inverse_log_det_jacobian.
     x = math_ops.log(y)
 
-    # We now extract the last coordinate of the rightmost dimension.
-    # Our trick is to slice from [0,0,...,shape[-1]-1] to shape[:-1]+[1].
-    begin = array_ops.one_hot(indices=ndims-1,
-                              depth=ndims,
-                              on_value=shape[-1]-np.array(1, dtype=shape.dtype),
-                              dtype=shape.dtype)
-    size = array_ops.concat([shape[:-1], np.asarray([1], dtype=shape.dtype)], 0)
-    log_normalization = -array_ops.strided_slice(x, begin, begin + size)
-
-    # Here we slice out all but the last coordinate; see above for idea.
-    begin = array_ops.zeros_like(shape)
-    size = array_ops.concat([shape[:-1], [shape[-1] - 1]], 0)
-    x = array_ops.strided_slice(x, begin, begin + size)
-
-    x += log_normalization
+    log_normalization = (-x[..., -1])[..., array_ops.newaxis]
+    x = x[..., :-1] + log_normalization
 
     if self._static_event_ndims == 0:
-      x = array_ops.squeeze(x, squeeze_dims=[ndims-1])
+      x = array_ops.squeeze(x, squeeze_dims=-1)
 
     # Set shape hints.
     if y.shape.ndims is not None:
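
As a standalone illustration of the comment above (x[i] = log(y[i]) - log(y[end])), here is a plain-numpy version of the simplified inverse; the values are the ones used in the new test, and this is a sketch rather than the library code:

    import numpy as np

    real_y = np.array([[0.2, 0.3, 0.4, 0.1], [0.16, 0.32, 0.48, 0.04]])
    log_y = np.log(real_y)
    log_normalization = -log_y[..., -1:]       # -log(y[end]), kept as a column
    x = log_y[..., :-1] + log_normalization    # log(y[i]) - log(y[end])
    print(np.allclose(x, np.log([[2., 3, 4], [4., 8, 12]])))  # True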