WIP implemented Ordered bijector
authorJunpeng Lao <junpeng.lao@unifr.ch>
Tue, 17 Apr 2018 16:51:01 +0000 (18:51 +0200)
committerJunpeng Lao <junpeng.lao@unifr.ch>
Tue, 17 Apr 2018 16:51:01 +0000 (18:51 +0200)
tensorflow/contrib/distributions/python/kernel_tests/bijectors/ordered_test.py [new file with mode: 0644]
tensorflow/contrib/distributions/python/ops/bijectors/__init__.py
tensorflow/contrib/distributions/python/ops/bijectors/ordered.py [new file with mode: 0644]

diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/ordered_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/ordered_test.py
new file mode 100644 (file)
index 0000000..1bcbfed
--- /dev/null
@@ -0,0 +1,111 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Bijector."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.distributions.python.ops.bijectors.ordered import Ordered
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite
+from tensorflow.python.platform import test
+
+
+rng = np.random.RandomState(42)
+
+
+class OrderedBijectorTest(test.TestCase):
+  """Tests correctness of the ordered transformation."""
+
+  def testBijectorVector(self):
+    with self.test_session():
+      ordered = Ordered()
+      self.assertEqual("ordered", ordered.name)
+      x = np.log([[2., 3, 4], [4., 8, 12]])
+      y = [[0.2, 0.3, 0.4, 0.1], [0.16, 0.32, 0.48, 0.04]]
+      self.assertAllClose(y, ordered.forward(x).eval())
+      self.assertAllClose(x, ordered.inverse(y).eval())
+      self.assertAllClose(
+          -np.sum(np.log(y), axis=1),
+          ordered.inverse_log_det_jacobian(y, event_ndims=1).eval(),
+          atol=0.,
+          rtol=1e-7)
+      self.assertAllClose(
+          -ordered.inverse_log_det_jacobian(y, event_ndims=1).eval(),
+          ordered.forward_log_det_jacobian(x, event_ndims=1).eval(),
+          atol=0.,
+          rtol=1e-7)
+
+  def testBijectorUnknownShape(self):
+    with self.test_session():
+      ordered = Ordered()
+      self.assertEqual("ordered", ordered.name)
+      x = array_ops.placeholder(shape=[2, None], dtype=dtypes.float32)
+      real_x = np.log([[2., 3, 4], [4., 8, 12]])
+      y = array_ops.placeholder(shape=[2, None], dtype=dtypes.float32)
+      real_y = [[0.2, 0.3, 0.4, 0.1], [0.16, 0.32, 0.48, 0.04]]
+      self.assertAllClose(real_y, ordered.forward(x).eval(
+          feed_dict={x: real_x}))
+      self.assertAllClose(real_x, ordered.inverse(y).eval(
+          feed_dict={y: real_y}))
+      self.assertAllClose(
+          -np.sum(np.log(real_y), axis=1),
+          ordered.inverse_log_det_jacobian(y, event_ndims=1).eval(
+              feed_dict={y: real_y}),
+          atol=0.,
+          rtol=1e-7)
+      self.assertAllClose(
+          -ordered.inverse_log_det_jacobian(y, event_ndims=1).eval(
+              feed_dict={y: real_y}),
+          ordered.forward_log_det_jacobian(x, event_ndims=1).eval(
+              feed_dict={x: real_x}),
+          atol=0.,
+          rtol=1e-7)
+
+  def testShapeGetters(self):
+    with self.test_session():
+      x = tensor_shape.TensorShape([4])
+      y = tensor_shape.TensorShape([5])
+      bijector = Ordered(validate_args=True)
+      self.assertAllEqual(y, bijector.forward_event_shape(x))
+      self.assertAllEqual(y.as_list(),
+                          bijector.forward_event_shape_tensor(
+                              x.as_list()).eval())
+      self.assertAllEqual(x, bijector.inverse_event_shape(y))
+      self.assertAllEqual(x.as_list(),
+                          bijector.inverse_event_shape_tensor(
+                              y.as_list()).eval())
+
+  def testBijectiveAndFinite(self):
+    with self.test_session():
+      ordered = Ordered()
+      x = np.linspace(-50, 50, num=10).reshape(5, 2).astype(np.float32)
+      # Make y values on the simplex with a wide range.
+      y_0 = np.ones(5).astype(np.float32)
+      y_1 = (1e-5 * rng.rand(5)).astype(np.float32)
+      y_2 = (1e1 * rng.rand(5)).astype(np.float32)
+      y = np.array([y_0, y_1, y_2])
+      y /= y.sum(axis=0)
+      y = y.T  # y.shape = [5, 3]
+      assert_bijective_and_finite(ordered, x, y, event_ndims=1)
+
+
+if __name__ == "__main__":
+  test.main()
index babce80..51478db 100644 (file)
@@ -30,6 +30,7 @@
 @@Invert
 @@Kumaraswamy
 @@MaskedAutoregressiveFlow
+@@Ordered
 @@Permute
 @@PowerTransform
 @@RealNVP
@@ -67,6 +68,7 @@ from tensorflow.contrib.distributions.python.ops.bijectors.inline import *
 from tensorflow.contrib.distributions.python.ops.bijectors.invert import *
 from tensorflow.contrib.distributions.python.ops.bijectors.kumaraswamy import *
 from tensorflow.contrib.distributions.python.ops.bijectors.masked_autoregressive import *
+from tensorflow.contrib.distributions.python.ops.bijectors.ordered import *
 from tensorflow.contrib.distributions.python.ops.bijectors.permute import *
 from tensorflow.contrib.distributions.python.ops.bijectors.power_transform import *
 from tensorflow.contrib.distributions.python.ops.bijectors.real_nvp import *
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/ordered.py b/tensorflow/contrib/distributions/python/ops/bijectors/ordered.py
new file mode 100644 (file)
index 0000000..ec8f660
--- /dev/null
@@ -0,0 +1,114 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Ordered bijector."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops.distributions import bijector
+
+
+__all__ = [
+    "Ordered",
+]
+
+
+class Ordered(bijector.Bijector):
+  """Bijector which maps a tensor x_k that has increasing elements in the last
+  dimension to an unconstrained tensor y_k.
+
+  On the last dimension of the tensor, Ordered bijector performs:
+  `y[0] = x[0]`
+  `y[1:] = math_ops.log(x[1:] - x[:-1])`
+
+  Example Use:
+
+  ```python
+  bijector.Ordered().forward(tf.log([2, 3, 4]))
+  # Result: [0.6931472, 3.6931472, 7.693147]
+
+  bijector.Ordered().inverse([0.2, 0.3, 0.4])
+  # Result: tf.log([2, 3, 4])
+  ```
+  """
+
+  def __init__(self,
+               validate_args=False,
+               name="ordered"):
+    self._graph_parents = []
+    self._name = name
+    super(Ordered, self).__init__(
+        forward_min_event_ndims=1,
+        validate_args=validate_args,
+        name=name)
+
+  def _forward_event_shape(self, input_shape):
+    if input_shape.ndims is None or input_shape[-1] is None:
+      return input_shape
+    return tensor_shape.TensorShape([input_shape[-1]])
+
+  def _forward_event_shape_tensor(self, input_shape):
+    return (input_shape[-1])[..., array_ops.newaxis]
+
+  def _inverse_event_shape(self, output_shape):
+    if output_shape.ndims is None or output_shape[-1] is None:
+      return output_shape
+    if output_shape[-1] <= 1:
+      raise ValueError("output_shape[-1] = %d <= 1" % output_shape[-1])
+    return tensor_shape.TensorShape([output_shape[-1]])
+
+  def _inverse_event_shape_tensor(self, output_shape):
+    if self.validate_args:
+      # It is not possible for a negative shape so we need only check <= 1.
+      is_greater_one = check_ops.assert_greater(
+          output_shape[-1], 1, message="Need last dimension greater than 1.")
+      output_shape = control_flow_ops.with_dependencies(
+          [is_greater_one], output_shape)
+    return (output_shape[-1])[..., array_ops.newaxis]
+
+  def _forward(self, x):
+    x = self._maybe_assert_valid_x(x)
+    y0 = array_ops.expand_dims(x[..., 0], -1)
+    yk = math_ops.log(x[..., 1:] - x[..., :-1])
+    y = array_ops.concat([y0, yk], axis=-1)
+    return y
+
+  def _inverse(self, y):
+    x0 = array_ops.expand_dims(y[..., 0], -1)
+    xk = math_ops.exp(y[..., 1:])
+    x = array_ops.concat([x0, xk], axis=-1)
+    return math_ops.cumsum(x, axis=-1)
+
+  def _inverse_log_det_jacobian(self, y):
+    return math_ops.reduce_sum(y[..., 1:], axis=-1)
+
+  def _forward_log_det_jacobian(self, x):
+    pass
+
+  def _maybe_assert_valid_x(self, x):
+    if not self.validate_args:
+      return x
+    is_valid = check_ops.is_strictly_increasing(
+        x,
+        message="Forward transformation input must be strictly increasing.")
+    return control_flow_ops.with_dependencies([is_valid], x)
\ No newline at end of file