BREAKING_CHANGE: Remove SigmoidCentered bijector.
author    A. Unique TensorFlower <gardener@tensorflow.org>
Fri, 16 Mar 2018 19:55:18 +0000 (12:55 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
Fri, 16 Mar 2018 20:00:02 +0000 (13:00 -0700)
 - SoftmaxCentered now works solely on vector events and supports broadcasting.
 - The Sigmoid bijector covers the event_ndims=0 (scalar) case.
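
For downstream code that constructed the removed `SigmoidCentered` bijector, a possible migration looks like the sketch below (an illustration only, not part of this change; it assumes the TF 1.x `tf.contrib.distributions.bijectors` API at this revision):

```python
import tensorflow as tf

tfb = tf.contrib.distributions.bijectors

# SigmoidCentered() mapped a scalar x to the 2-vector [sigmoid(x), 1 - sigmoid(x)].
# The same mapping is obtained by treating x as a length-1 vector event and
# applying SoftmaxCentered, which appends the implicit zero coordinate and
# normalizes.
x = tf.log([[2.], [3.]])                     # shape [2, 1]: two length-1 events
y = tfb.SoftmaxCentered().forward(x)         # shape [2, 2]: [[2/3, 1/3], [3/4, 1/4]]

# For plain elementwise squashing (the event_ndims=0 case), use Sigmoid.
s = tfb.Sigmoid().forward(tf.log([2., 3.]))  # [2/3, 3/4]

with tf.Session() as sess:
  print(sess.run([y, s]))
```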

PiperOrigin-RevId: 189380445

tensorflow/contrib/distributions/BUILD
tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py
tensorflow/contrib/distributions/python/kernel_tests/bijectors/invert_test.py
tensorflow/contrib/distributions/python/kernel_tests/bijectors/sigmoid_centered_test.py [deleted file]
tensorflow/contrib/distributions/python/kernel_tests/bijectors/softmax_centered_test.py
tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py
tensorflow/contrib/distributions/python/ops/bijectors/__init__.py
tensorflow/contrib/distributions/python/ops/bijectors/sigmoid_centered.py [deleted file]
tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py
tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py
tensorflow/docs_src/api_guides/python/contrib.distributions.bijectors.md

diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD
index 6bd3f5f..e9c827a 100644
@@ -1105,25 +1105,6 @@ cuda_py_test(
     ],
 )
 
-cuda_py_test(
-    name = "sigmoid_centered_test",
-    size = "small",
-    srcs = ["python/kernel_tests/bijectors/sigmoid_centered_test.py"],
-    additional_deps = [
-        ":bijectors_py",
-        ":distributions_py",
-        "//third_party/py/numpy",
-        "@six_archive//:six",
-        "//tensorflow/contrib/linalg:linalg_py",
-        "//tensorflow/python:array_ops",
-        "//tensorflow/python:client_testlib",
-        "//tensorflow/python:framework_for_generated_wrappers",
-        "//tensorflow/python:framework_test_lib",
-        "//tensorflow/python:math_ops",
-        "//tensorflow/python:platform_test",
-    ],
-)
-
 # Tests for SinhArcSinh bijector.  The file name has the extra "_bijector" to
 # avoid BUILD rule name conflicts with the distribution by the same name.
 cuda_py_test(
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py
index 20e7543..a748acd 100644
@@ -66,12 +66,10 @@ class ChainBijectorTest(test.TestCase):
   def testShapeGetters(self):
     with self.test_session():
       bijector = Chain([
-          SoftmaxCentered(
-              event_ndims=1, validate_args=True),
-          SoftmaxCentered(
-              event_ndims=0, validate_args=True)
+          SoftmaxCentered(validate_args=True),
+          SoftmaxCentered(validate_args=True),
       ])
-      x = tensor_shape.TensorShape([])
+      x = tensor_shape.TensorShape([1])
       y = tensor_shape.TensorShape([2 + 1])
       self.assertAllEqual(y, bijector.forward_event_shape(x))
       self.assertAllEqual(
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/invert_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/invert_test.py
index 28e3e31..58ba9ce 100644
@@ -37,8 +37,7 @@ class InvertBijectorTest(test.TestCase):
           bijectors.Exp(event_ndims=1),
           bijectors.Affine(shift=[0., 1.], scale_diag=[2., 3.]),
           bijectors.Softplus(event_ndims=1),
-          bijectors.SoftmaxCentered(event_ndims=1),
-          bijectors.SigmoidCentered(),
+          bijectors.SoftmaxCentered(),
       ]:
         rev = bijectors.Invert(fwd)
         self.assertEqual("_".join(["invert", fwd.name]), rev.name)
@@ -61,9 +60,9 @@ class InvertBijectorTest(test.TestCase):
 
   def testShapeGetters(self):
     with self.test_session():
-      bijector = bijectors.Invert(bijectors.SigmoidCentered(validate_args=True))
+      bijector = bijectors.Invert(bijectors.SoftmaxCentered(validate_args=True))
       x = tensor_shape.TensorShape([2])
-      y = tensor_shape.TensorShape([])
+      y = tensor_shape.TensorShape([1])
       self.assertAllEqual(y, bijector.forward_event_shape(x))
       self.assertAllEqual(
           y.as_list(),
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sigmoid_centered_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sigmoid_centered_test.py
deleted file mode 100644
index 4ff3f33..0000000
+++ /dev/null
@@ -1,57 +0,0 @@
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for Bijector."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid_centered import SigmoidCentered
-from tensorflow.python.platform import test
-
-
-class SigmoidCenteredBijectorTest(test.TestCase):
-  """Tests correctness of the Y = g(X) = (1 + exp(-X))^-1 transformation."""
-
-  def testBijector(self):
-    with self.test_session():
-      sigmoid = SigmoidCentered()
-      self.assertEqual("sigmoid_centered", sigmoid.name)
-      x = np.log([[2., 3, 4],
-                  [4., 8, 12]])
-      y = [[[2. / 3, 1. / 3],
-            [3. / 4, 1. / 4],
-            [4. / 5, 1. / 5]],
-           [[4. / 5, 1. / 5],
-            [8. / 9, 1. / 9],
-            [12. / 13, 1. / 13]]]
-      self.assertAllClose(y, sigmoid.forward(x).eval())
-      self.assertAllClose(x, sigmoid.inverse(y).eval())
-      self.assertAllClose(
-          -np.sum(np.log(y), axis=2),
-          sigmoid.inverse_log_det_jacobian(y).eval(),
-          atol=0.,
-          rtol=1e-7)
-      self.assertAllClose(
-          -sigmoid.inverse_log_det_jacobian(y).eval(),
-          sigmoid.forward_log_det_jacobian(x).eval(),
-          atol=0.,
-          rtol=1e-7)
-
-
-if __name__ == "__main__":
-  test.main()
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softmax_centered_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softmax_centered_test.py
index 4a7679d..cad4dd1 100644
@@ -34,34 +34,9 @@ rng = np.random.RandomState(42)
 class SoftmaxCenteredBijectorTest(test.TestCase):
   """Tests correctness of the Y = g(X) = exp(X) / sum(exp(X)) transformation."""
 
-  def testBijectorScalar(self):
-    with self.test_session():
-      softmax = SoftmaxCentered()  # scalar by default
-      self.assertEqual("softmax_centered", softmax.name)
-      x = np.log([[2., 3, 4],
-                  [4., 8, 12]])
-      y = [[[2. / 3, 1. / 3],
-            [3. / 4, 1. / 4],
-            [4. / 5, 1. / 5]],
-           [[4. / 5, 1. / 5],
-            [8. / 9, 1. / 9],
-            [12. / 13, 1. / 13]]]
-      self.assertAllClose(y, softmax.forward(x).eval())
-      self.assertAllClose(x, softmax.inverse(y).eval())
-      self.assertAllClose(
-          -np.sum(np.log(y), axis=2),
-          softmax.inverse_log_det_jacobian(y).eval(),
-          atol=0.,
-          rtol=1e-7)
-      self.assertAllClose(
-          -softmax.inverse_log_det_jacobian(y).eval(),
-          softmax.forward_log_det_jacobian(x).eval(),
-          atol=0.,
-          rtol=1e-7)
-
   def testBijectorVector(self):
     with self.test_session():
-      softmax = SoftmaxCentered(event_ndims=1)
+      softmax = SoftmaxCentered()
       self.assertEqual("softmax_centered", softmax.name)
       x = np.log([[2., 3, 4], [4., 8, 12]])
       y = [[0.2, 0.3, 0.4, 0.1], [0.16, 0.32, 0.48, 0.04]]
@@ -80,7 +55,7 @@ class SoftmaxCenteredBijectorTest(test.TestCase):
 
   def testBijectorUnknownShape(self):
     with self.test_session():
-      softmax = SoftmaxCentered(event_ndims=1)
+      softmax = SoftmaxCentered()
       self.assertEqual("softmax_centered", softmax.name)
       x = array_ops.placeholder(shape=[2, None], dtype=dtypes.float32)
       real_x = np.log([[2., 3, 4], [4., 8, 12]])
@@ -106,24 +81,21 @@ class SoftmaxCenteredBijectorTest(test.TestCase):
 
   def testShapeGetters(self):
     with self.test_session():
-      for x, y, b in ((tensor_shape.TensorShape([]),
-                       tensor_shape.TensorShape([2]),
-                       SoftmaxCentered(
-                           event_ndims=0, validate_args=True)),
-                      (tensor_shape.TensorShape([4]),
-                       tensor_shape.TensorShape([5]),
-                       SoftmaxCentered(
-                           event_ndims=1, validate_args=True))):
-        self.assertAllEqual(y, b.forward_event_shape(x))
-        self.assertAllEqual(y.as_list(),
-                            b.forward_event_shape_tensor(x.as_list()).eval())
-        self.assertAllEqual(x, b.inverse_event_shape(y))
-        self.assertAllEqual(x.as_list(),
-                            b.inverse_event_shape_tensor(y.as_list()).eval())
+      x = tensor_shape.TensorShape([4])
+      y = tensor_shape.TensorShape([5])
+      bijector = SoftmaxCentered(validate_args=True)
+      self.assertAllEqual(y, bijector.forward_event_shape(x))
+      self.assertAllEqual(y.as_list(),
+                          bijector.forward_event_shape_tensor(
+                              x.as_list()).eval())
+      self.assertAllEqual(x, bijector.inverse_event_shape(y))
+      self.assertAllEqual(x.as_list(),
+                          bijector.inverse_event_shape_tensor(
+                              y.as_list()).eval())
 
   def testBijectiveAndFinite(self):
     with self.test_session():
-      softmax = SoftmaxCentered(event_ndims=1)
+      softmax = SoftmaxCentered()
       x = np.linspace(-50, 50, num=10).reshape(5, 2).astype(np.float32)
       # Make y values on the simplex with a wide range.
       y_0 = np.ones(5).astype(np.float32)
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py b/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py
index af13553..f0ba1ec 100644
@@ -186,12 +186,14 @@ class TransformedDistributionTest(test.TestCase):
       standard_normal = ds.Normal(loc=0., scale=1.)
       multi_logit_normal = self._cls()(
           distribution=standard_normal,
-          bijector=softmax)
-      x = [[-np.log(3.), 0.],
-           [np.log(3), np.log(5)]]
+          bijector=softmax,
+          event_shape=[1])
+      x = [[[-np.log(3.)], [0.]],
+           [[np.log(3)], [np.log(5)]]]
       y = softmax.forward(x).eval()
-      expected_log_pdf = (stats.norm(loc=0., scale=1.).logpdf(x) -
-                          np.sum(np.log(y), axis=-1))
+      expected_log_pdf = (
+          np.squeeze(stats.norm(loc=0., scale=1.).logpdf(x)) -
+          np.sum(np.log(y), axis=-1))
       self.assertAllClose(expected_log_pdf,
                           multi_logit_normal.log_prob(y).eval())
       self.assertAllClose(
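
The test change above shows the corresponding `TransformedDistribution` migration: the scalar base distribution is reshaped to a length-1 vector event via `event_shape=[1]`. A minimal standalone sketch of that pattern (an illustration only, assuming the TF 1.x contrib API at this revision):

```python
import tensorflow as tf

tfd = tf.contrib.distributions
tfb = tfd.bijectors

# A scalar Normal can no longer be pushed through SoftmaxCentered directly;
# event_shape=[1] first reshapes it to a length-1 vector event.
multi_logit_normal = tfd.TransformedDistribution(
    distribution=tfd.Normal(loc=0., scale=1.),
    bijector=tfb.SoftmaxCentered(),
    event_shape=[1])

samples = multi_logit_normal.sample(4)  # shape [4, 2]: points on the 1-simplex

with tf.Session() as sess:
  print(sess.run(samples))
```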
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py
index 452f1ca..bc6b025 100644
@@ -35,7 +35,6 @@
 @@RealNVP
 @@Reshape
 @@Sigmoid
-@@SigmoidCentered
 @@SinhArcsinh
 @@SoftmaxCentered
 @@Softplus
@@ -72,7 +71,6 @@ from tensorflow.contrib.distributions.python.ops.bijectors.power_transform impor
 from tensorflow.contrib.distributions.python.ops.bijectors.real_nvp import *
 from tensorflow.contrib.distributions.python.ops.bijectors.reshape import *
 from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid import *
-from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid_centered import *
 from tensorflow.contrib.distributions.python.ops.bijectors.sinh_arcsinh import *
 from tensorflow.contrib.distributions.python.ops.bijectors.softmax_centered import *
 from tensorflow.contrib.distributions.python.ops.bijectors.softplus import *
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid_centered.py b/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid_centered.py
deleted file mode 100644
index 223bc9d..0000000
+++ /dev/null
@@ -1,39 +0,0 @@
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""SigmoidCentered bijector."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.distributions.python.ops.bijectors import softmax_centered
-
-
-__all__ = [
-    "SigmoidCentered",
-]
-
-
-class SigmoidCentered(softmax_centered.SoftmaxCentered):
-  """Bijector which computes Y = g(X) = exp([X 0]) / (1 + exp(-X)).
-
-  Equivalent to: `bijector.SoftmaxCentered(event_ndims=0)`.
-
-  See `bijector.SoftmaxCentered` for more details.
-  """
-
-  def __init__(self, validate_args=False, name="sigmoid_centered"):
-    super(SigmoidCentered, self).__init__(
-        event_ndims=0, validate_args=validate_args, name=name)
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py b/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py
index 24add40..dc94fd0 100644
@@ -19,10 +19,7 @@ from __future__ import division
 from __future__ import print_function
 
 from tensorflow.contrib.distributions.python.ops import distribution_util
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
-from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import check_ops
 from tensorflow.python.ops import control_flow_ops
@@ -45,17 +42,14 @@ class SoftmaxCentered(bijector.Bijector):
   e.g., `softmax(x) = exp(x-c) / sum(exp(x-c))` where `c` is the implicit last
   coordinate.
 
-  Because we append a coordinate, this bijector only supports `event_ndim in [0,
-  1]`, i.e., scalars and vectors.
-
   Example Use:
 
   ```python
-  bijector.SoftmaxCentered(event_ndims=1).forward(tf.log([2, 3, 4]))
+  bijector.SoftmaxCentered().forward(tf.log([2, 3, 4]))
   # Result: [0.2, 0.3, 0.4, 0.1]
   # Extra result: 0.1
 
-  bijector.SoftmaxCentered(event_ndims=1).inverse([0.2, 0.3, 0.4, 0.1])
+  bijector.SoftmaxCentered().inverse([0.2, 0.3, 0.4, 0.1])
   # Result: tf.log([2, 3, 4])
   # Extra coordinate removed.
   ```
@@ -67,82 +61,47 @@ class SoftmaxCentered(bijector.Bijector):
   """
 
   def __init__(self,
-               event_ndims=0,
                validate_args=False,
                name="softmax_centered"):
     self._graph_parents = []
     self._name = name
-    with self._name_scope("init", values=[event_ndims]):
-      event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims")
-      event_ndims = tensor_util.constant_value(event_ndims)
-      if event_ndims is None or event_ndims not in [0, 1]:
-        raise ValueError("`event_ndims` must be a TF constant which is 0 or 1")
-    self._static_event_ndims = event_ndims
     super(SoftmaxCentered, self).__init__(
-        event_ndims=event_ndims,
+        event_ndims=1,
         validate_args=validate_args,
         name=name)
 
   def _forward_event_shape(self, input_shape):
-    if input_shape.ndims is None:
+    if input_shape.ndims is None or input_shape[-1] is None:
       return input_shape
-    if input_shape.ndims != self._static_event_ndims:
-      raise ValueError("input_shape.dims = %d != %d" %
-                       (input_shape.ndims, self._static_event_ndims))
-    if input_shape.ndims == 0:
-      return tensor_shape.TensorShape([2])
-    if input_shape.ndims == 1:
-      return tensor_shape.TensorShape(input_shape[0] + 1)
-    # Unreachable code:
-    raise ValueError("event_ndims = %d must be 0 or 1" % input_shape.ndims)
+    return tensor_shape.TensorShape([input_shape[-1] + 1])
 
   def _forward_event_shape_tensor(self, input_shape):
-    ndims = array_ops.shape(input_shape)
-    if self.validate_args:
-      # It is not possible for a negative shape so we need only check <= 1.
-      is_zero_or_one = check_ops.assert_equal(
-          ndims, 0 if self._static_event_ndims == 0 else 1,
-          message="event_ndims must be 0 or 1")
-      ndims = control_flow_ops.with_dependencies([is_zero_or_one], ndims)
-    if self._static_event_ndims == 0:
-      return ops.convert_to_tensor(
-          [2], dtype=dtypes.int32, name="output_shape")
-    return input_shape + 1
+    return (input_shape[-1] + 1)[..., array_ops.newaxis]
 
   def _inverse_event_shape(self, output_shape):
-    if output_shape.ndims is None:
+    if output_shape.ndims is None or output_shape[-1] is None:
       return output_shape
-    if output_shape.ndims != 1:
-      raise ValueError("output_shape.ndims = %d != 1" % output_shape.ndims)
-    if self._static_event_ndims == 0:
-      return tensor_shape.TensorShape([])
-    return tensor_shape.TensorShape(output_shape[0] - 1)
+    if output_shape[-1] <= 1:
+      raise ValueError("output_shape[-1] = %d <= 1" % output_shape[-1])
+    return tensor_shape.TensorShape([output_shape[-1] - 1])
 
   def _inverse_event_shape_tensor(self, output_shape):
-    ndims = array_ops.shape(output_shape)[0]
     if self.validate_args:
       # It is not possible for a negative shape so we need only check <= 1.
-      is_one = check_ops.assert_equal(
-          ndims, 1, message="event_ndims must be 1")
-      ndims = control_flow_ops.with_dependencies([is_one], ndims)
-    if self._static_event_ndims == 0:
-      return ops.convert_to_tensor([], dtype=dtypes.int32, name="output_shape")
-    return array_ops.expand_dims(output_shape[0] - 1, dim=0)
+      is_greater_one = check_ops.assert_greater(
+          output_shape[-1], 1, message="Need last dimension greater than 1.")
+      output_shape = control_flow_ops.with_dependencies(
+          [is_greater_one], output_shape)
+    return (output_shape[-1] - 1)[..., array_ops.newaxis]
 
   def _forward(self, x):
     # Pad the last dim with a zeros vector. We need this because it lets us
     # infer the scale in the inverse function.
-    y = array_ops.expand_dims(x, dim=-1) if self._static_event_ndims == 0 else x
-    y = distribution_util.pad(y, axis=-1, back=True)
+    y = distribution_util.pad(x, axis=-1, back=True)
 
     # Set shape hints.
     if x.shape.ndims is not None:
-      shape = x.shape.as_list()
-      if self._static_event_ndims == 0:
-        shape += [2]
-      elif shape[-1] is not None:
-        shape[-1] += 1
-      shape = tensor_shape.TensorShape(shape)
+      shape = x.shape[:-1].concatenate(x.shape[-1] + 1)
       y.shape.assert_is_compatible_with(shape)
       y.set_shape(shape)
 
@@ -167,17 +126,9 @@ class SoftmaxCentered(bijector.Bijector):
     log_normalization = (-x[..., -1])[..., array_ops.newaxis]
     x = x[..., :-1] + log_normalization
 
-    if self._static_event_ndims == 0:
-      x = array_ops.squeeze(x, squeeze_dims=-1)
-
     # Set shape hints.
     if y.shape.ndims is not None:
-      shape = y.shape.as_list()
-      if self._static_event_ndims == 0:
-        shape = shape[:-1]
-      elif shape[-1] is not None:
-        shape[-1] -= 1
-      shape = tensor_shape.TensorShape(shape)
+      shape = y.shape[:-1].concatenate(y.shape[-1] - 1)
       x.shape.assert_is_compatible_with(shape)
       x.set_shape(shape)
 
@@ -203,19 +154,16 @@ class SoftmaxCentered(bijector.Bijector):
     return -math_ops.reduce_sum(math_ops.log(y), axis=-1)
 
   def _forward_log_det_jacobian(self, x):
-    if self._static_event_ndims == 0:
-      return x - 2. * nn_ops.softplus(x)
-    else:
-      # This code is similar to nn_ops.log_softmax but different because we have
-      # an implicit zero column to handle. I.e., instead of:
-      #   reduce_sum(logits - reduce_sum(exp(logits), dim))
-      # we must do:
-      #   log_normalization = 1 + reduce_sum(exp(logits))
-      #   -log_normalization + reduce_sum(logits - log_normalization)
-      log_normalization = nn_ops.softplus(
-          math_ops.reduce_logsumexp(x, axis=-1, keep_dims=True))
-      fldj = (-log_normalization +
-              math_ops.reduce_sum(x - log_normalization,
-                                  axis=-1,
-                                  keep_dims=True))
-      return array_ops.squeeze(fldj, squeeze_dims=-1)
+    # This code is similar to nn_ops.log_softmax but different because we have
+    # an implicit zero column to handle. I.e., instead of:
+    #   reduce_sum(logits - reduce_sum(exp(logits), dim))
+    # we must do:
+    #   log_normalization = 1 + reduce_sum(exp(logits))
+    #   -log_normalization + reduce_sum(logits - log_normalization)
+    log_normalization = nn_ops.softplus(
+        math_ops.reduce_logsumexp(x, axis=-1, keep_dims=True))
+    fldj = (-log_normalization +
+            math_ops.reduce_sum(x - log_normalization,
+                                axis=-1,
+                                keep_dims=True))
+    return array_ops.squeeze(fldj, squeeze_dims=-1)
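
The comment in `_forward_log_det_jacobian` above sketches how the implicit zero logit changes the log-normalizer. As a sanity check (an illustration only, independent of this change), the formula can be verified against a brute-force Jacobian in NumPy:

```python
import numpy as np

x = np.array([0.3, -1.2, 2.0])   # arbitrary logits (illustrative values)

# Formula from the code above:
#   log_normalization = log(1 + sum(exp(x))) = softplus(logsumexp(x))
#   fldj = -log_normalization + sum(x - log_normalization)
log_norm = np.log1p(np.exp(x).sum())
fldj = -log_norm + np.sum(x - log_norm)

# Brute force: y = softmax([x, 0]).  The Jacobian of the first len(x) output
# coordinates with respect to x is diag(y[:-1]) - outer(y[:-1], y[:-1]).
z = np.concatenate([x, [0.]])
y = np.exp(z) / np.exp(z).sum()
jacobian = np.diag(y[:-1]) - np.outer(y[:-1], y[:-1])
_, logdet = np.linalg.slogdet(jacobian)

assert np.allclose(fldj, logdet)
# Consistent with _inverse_log_det_jacobian: fldj == -ildj == sum(log(y)).
assert np.allclose(fldj, np.sum(np.log(y)))
```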
diff --git a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py
index 0c747f8..3208ecd 100644
@@ -181,7 +181,7 @@ def quadrature_scheme_softmaxnormal_quantiles(
       edges = array_ops.reshape(edges, shape=array_ops.concat([
           [-1], array_ops.ones([batch_ndims], dtype=dtypes.int32)], axis=0))
       quantiles = dist.quantile(edges)
-      quantiles = SoftmaxCentered(event_ndims=1).forward(quantiles)
+      quantiles = SoftmaxCentered().forward(quantiles)
       # Cyclically permute left by one.
       perm = array_ops.concat([
           math_ops.range(1, 1 + batch_ndims), [0]], axis=0)
diff --git a/tensorflow/docs_src/api_guides/python/contrib.distributions.bijectors.md b/tensorflow/docs_src/api_guides/python/contrib.distributions.bijectors.md
index 0ce187b..e169897 100644
@@ -28,6 +28,5 @@ To apply a `Bijector`, use `distributions.TransformedDistribution`.
 *   @{tf.contrib.distributions.bijectors.Inline}
 *   @{tf.contrib.distributions.bijectors.Invert}
 *   @{tf.contrib.distributions.bijectors.PowerTransform}
-*   @{tf.contrib.distributions.bijectors.SigmoidCentered}
 *   @{tf.contrib.distributions.bijectors.SoftmaxCentered}
 *   @{tf.contrib.distributions.bijectors.Softplus}