BREAKING_CHANGE: Split out event_ndims=0 bijectors from Affine and CholeskyOuterProduct.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 12 Mar 2018 23:54:10 +0000 (16:54 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 12 Mar 2018 23:57:27 +0000 (16:57 -0700)
    - Deprecate event_ndims argument
    - Create a Square bijector for the scalar case of CholeskyOuterProduct (which now only operates on matrices).
    - Create a AffineScalar bijector for the scalar case of Affine (which now only operates on vectors)

PiperOrigin-RevId: 188801116

14 files changed:
tensorflow/contrib/distributions/BUILD
tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_scalar_test.py [new file with mode: 0644]
tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_test.py
tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py
tensorflow/contrib/distributions/python/kernel_tests/bijectors/invert_test.py
tensorflow/contrib/distributions/python/kernel_tests/bijectors/square_test.py [new file with mode: 0644]
tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py
tensorflow/contrib/distributions/python/ops/bijectors/__init__.py
tensorflow/contrib/distributions/python/ops/bijectors/affine.py
tensorflow/contrib/distributions/python/ops/bijectors/affine_scalar.py [new file with mode: 0644]
tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py
tensorflow/contrib/distributions/python/ops/bijectors/square.py [new file with mode: 0644]
tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py
tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py

index 203fbf9..6bd3f5f 100644 (file)
@@ -817,6 +817,25 @@ cuda_py_test(
 )
 
 cuda_py_test(
+    name = "affine_scalar_test",
+    size = "small",
+    srcs = ["python/kernel_tests/bijectors/affine_scalar_test.py"],
+    additional_deps = [
+        ":bijectors_py",
+        ":distributions_py",
+        "//third_party/py/numpy",
+        "@six_archive//:six",
+        "//tensorflow/contrib/linalg:linalg_py",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:framework_for_generated_wrappers",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python:platform_test",
+    ],
+)
+
+cuda_py_test(
     name = "affine_linear_operator_test",
     size = "small",
     srcs = ["python/kernel_tests/bijectors/affine_linear_operator_test.py"],
@@ -1165,6 +1184,25 @@ cuda_py_test(
 )
 
 cuda_py_test(
+    name = "square_test",
+    size = "small",
+    srcs = ["python/kernel_tests/bijectors/square_test.py"],
+    additional_deps = [
+        ":bijectors_py",
+        ":distributions_py",
+        "//third_party/py/numpy",
+        "@six_archive//:six",
+        "//tensorflow/contrib/linalg:linalg_py",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:framework_for_generated_wrappers",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python:platform_test",
+    ],
+)
+
+cuda_py_test(
     name = "weibull_test",
     size = "small",
     srcs = ["python/kernel_tests/bijectors/weibull_test.py"],
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_scalar_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_scalar_test.py
new file mode 100644 (file)
index 0000000..16173a1
--- /dev/null
@@ -0,0 +1,153 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Affine Scalar Tests."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.distributions.python.ops.bijectors.affine_scalar import AffineScalar
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency
+from tensorflow.python.platform import test
+
+
+class AffineScalarBijectorTest(test.TestCase):
+  """Tests correctness of the Y = scale @ x + shift transformation."""
+
+  def testProperties(self):
+    with self.test_session():
+      mu = -1.
+      # scale corresponds to 1.
+      bijector = AffineScalar(shift=mu)
+      self.assertEqual("affine_scalar", bijector.name)
+
+  def testNoBatchScalar(self):
+    with self.test_session() as sess:
+
+      def static_run(fun, x):
+        return fun(x).eval()
+
+      def dynamic_run(fun, x_value):
+        x_value = np.array(x_value)
+        x = array_ops.placeholder(dtypes.float32, name="x")
+        return sess.run(fun(x), feed_dict={x: x_value})
+
+      for run in (static_run, dynamic_run):
+        mu = -1.
+        # Corresponds to scale = 2
+        bijector = AffineScalar(shift=mu, scale=2.)
+        x = [1., 2, 3]  # Three scalar samples (no batches).
+        self.assertAllClose([1., 3, 5], run(bijector.forward, x))
+        self.assertAllClose([1., 1.5, 2.], run(bijector.inverse, x))
+        self.assertAllClose([-np.log(2.)] * 3,
+                            run(bijector.inverse_log_det_jacobian, x))
+
+  def testOneBatchScalarViaIdentityIn64BitUserProvidesShiftOnly(self):
+    with self.test_session() as sess:
+
+      def static_run(fun, x):
+        return fun(x).eval()
+
+      def dynamic_run(fun, x_value):
+        x_value = np.array(x_value).astype(np.float64)
+        x = array_ops.placeholder(dtypes.float64, name="x")
+        return sess.run(fun(x), feed_dict={x: x_value})
+
+      for run in (static_run, dynamic_run):
+        mu = np.float64([1.])
+        # One batch, scalar.
+        # Corresponds to scale = 1.
+        bijector = AffineScalar(shift=mu)
+        x = np.float64([1.])  # One sample from one batches.
+        self.assertAllClose([2.], run(bijector.forward, x))
+        self.assertAllClose([0.], run(bijector.inverse, x))
+        self.assertAllClose([0.], run(bijector.inverse_log_det_jacobian, x))
+
+  def testOneBatchScalarViaIdentityIn64BitUserProvidesScaleOnly(self):
+    with self.test_session() as sess:
+
+      def static_run(fun, x):
+        return fun(x).eval()
+
+      def dynamic_run(fun, x_value):
+        x_value = np.array(x_value).astype(np.float64)
+        x = array_ops.placeholder(dtypes.float64, name="x")
+        return sess.run(fun(x), feed_dict={x: x_value})
+
+      for run in (static_run, dynamic_run):
+        multiplier = np.float64([2.])
+        # One batch, scalar.
+        # Corresponds to scale = 2, shift = 0.
+        bijector = AffineScalar(scale=multiplier)
+        x = np.float64([1.])  # One sample from one batches.
+        self.assertAllClose([2.], run(bijector.forward, x))
+        self.assertAllClose([0.5], run(bijector.inverse, x))
+        self.assertAllClose([np.log(0.5)],
+                            run(bijector.inverse_log_det_jacobian, x))
+
+  def testTwoBatchScalarIdentityViaIdentity(self):
+    with self.test_session() as sess:
+
+      def static_run(fun, x):
+        return fun(x).eval()
+
+      def dynamic_run(fun, x_value):
+        x_value = np.array(x_value)
+        x = array_ops.placeholder(dtypes.float32, name="x")
+        return sess.run(fun(x), feed_dict={x: x_value})
+
+      for run in (static_run, dynamic_run):
+        mu = [1., -1]
+        # Univariate, two batches.
+        # Corresponds to scale = 1.
+        bijector = AffineScalar(shift=mu)
+        x = [1., 1]  # One sample from each of two batches.
+        self.assertAllClose([2., 0], run(bijector.forward, x))
+        self.assertAllClose([0., 2], run(bijector.inverse, x))
+        self.assertAllClose([0., 0.], run(bijector.inverse_log_det_jacobian, x))
+
+  def testTwoBatchScalarIdentityViaScale(self):
+    with self.test_session() as sess:
+
+      def static_run(fun, x):
+        return fun(x).eval()
+
+      def dynamic_run(fun, x_value):
+        x_value = np.array(x_value)
+        x = array_ops.placeholder(dtypes.float32, name="x")
+        return sess.run(fun(x), feed_dict={x: x_value})
+
+      for run in (static_run, dynamic_run):
+        mu = [1., -1]
+        # Univariate, two batches.
+        # Corresponds to scale = 1.
+        bijector = AffineScalar(shift=mu, scale=[2., 1])
+        x = [1., 1]  # One sample from each of two batches.
+        self.assertAllClose([3., 0], run(bijector.forward, x))
+        self.assertAllClose([0., 2], run(bijector.inverse, x))
+        self.assertAllClose(
+            [-np.log(2), 0.], run(bijector.inverse_log_det_jacobian, x))
+
+  def testScalarCongruency(self):
+    with self.test_session():
+      bijector = AffineScalar(shift=3.6, scale=0.42)
+      assert_scalar_congruency(bijector, lower_x=-2., upper_x=2.)
+
+if __name__ == "__main__":
+  test.main()
index c915811..077e617 100644 (file)
@@ -25,7 +25,6 @@ import numpy as np
 from tensorflow.contrib.distributions.python.ops.bijectors.affine import Affine
 from tensorflow.python.framework import dtypes
 from tensorflow.python.ops import array_ops
-from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency
 from tensorflow.python.platform import test
 
 
@@ -36,192 +35,9 @@ class AffineBijectorTest(test.TestCase):
     with self.test_session():
       mu = -1.
       # scale corresponds to 1.
-      bijector = Affine(shift=mu, event_ndims=0)
+      bijector = Affine(shift=mu)
       self.assertEqual("affine", bijector.name)
 
-  def testNoBatchScalarViaIdentity(self):
-    with self.test_session() as sess:
-
-      def static_run(fun, x):
-        return fun(x).eval()
-
-      def dynamic_run(fun, x_value):
-        x_value = np.array(x_value)
-        x = array_ops.placeholder(dtypes.float32, name="x")
-        return sess.run(fun(x), feed_dict={x: x_value})
-
-      for run in (static_run, dynamic_run):
-        mu = -1.
-        # Corresponds to scale = 2
-        bijector = Affine(
-            shift=mu, scale_identity_multiplier=2., event_ndims=0)
-        self.assertEqual(0, bijector.event_ndims.eval())  # "is scalar"
-        x = [1., 2, 3]  # Three scalar samples (no batches).
-        self.assertAllClose([1., 3, 5], run(bijector.forward, x))
-        self.assertAllClose([1., 1.5, 2.], run(bijector.inverse, x))
-        self.assertAllClose(-np.log(2.),
-                            run(bijector.inverse_log_det_jacobian, x))
-
-  def testNoBatchScalarViaDiag(self):
-    with self.test_session() as sess:
-
-      def static_run(fun, x):
-        return fun(x).eval()
-
-      def dynamic_run(fun, x_value):
-        x_value = np.array(x_value)
-        x = array_ops.placeholder(dtypes.float32, name="x")
-        return sess.run(fun(x), feed_dict={x: x_value})
-
-      for run in (static_run, dynamic_run):
-        mu = -1.
-        # Corresponds to scale = 2
-        bijector = Affine(shift=mu, scale_identity_multiplier=2., event_ndims=0)
-        self.assertEqual(0, bijector.event_ndims.eval())  # "is scalar"
-        x = [1., 2, 3]  # Three scalar samples (no batches).
-        self.assertAllClose([1., 3, 5], run(bijector.forward, x))
-        self.assertAllClose([1., 1.5, 2.], run(bijector.inverse, x))
-        self.assertAllClose(-np.log(2.),
-                            run(bijector.inverse_log_det_jacobian, x))
-
-  def testWeirdSampleNoBatchScalarViaDiagMultiplier(self):
-    with self.test_session() as sess:
-
-      def static_run(fun, x):
-        return fun(x).eval()
-
-      def dynamic_run(fun, x_value):
-        x_value = np.array(x_value)
-        x = array_ops.placeholder(dtypes.float32, name="x")
-        return sess.run(fun(x), feed_dict={x: x_value})
-
-      for run in (static_run, dynamic_run):
-        mu = -1.
-        # Corresponds to scale = 2.
-        bijector = Affine(
-            shift=mu, scale_identity_multiplier=2., event_ndims=0)
-        self.assertEqual(0, bijector.event_ndims.eval())  # "is scalar"
-        x = [[1., 2, 3], [4, 5, 6]]  # Weird sample shape.
-        self.assertAllClose([[1., 3, 5],
-                             [7, 9, 11]],
-                            run(bijector.forward, x))
-        self.assertAllClose([[1., 1.5, 2.],
-                             [2.5, 3, 3.5]],
-                            run(bijector.inverse, x))
-        self.assertAllClose(-np.log(2.),
-                            run(bijector.inverse_log_det_jacobian, x))
-
-  def testOneBatchScalarViaIdentityIn64BitUserProvidesShiftOnly(self):
-    with self.test_session() as sess:
-
-      def static_run(fun, x):
-        return fun(x).eval()
-
-      def dynamic_run(fun, x_value):
-        x_value = np.array(x_value).astype(np.float64)
-        x = array_ops.placeholder(dtypes.float64, name="x")
-        return sess.run(fun(x), feed_dict={x: x_value})
-
-      for run in (static_run, dynamic_run):
-        mu = np.float64([1.])
-        # One batch, scalar.
-        # Corresponds to scale = 1.
-        bijector = Affine(shift=mu, event_ndims=0)
-        self.assertEqual(0, bijector.event_ndims.eval())  # "is scalar"
-        x = np.float64([1.])  # One sample from one batches.
-        self.assertAllClose([2.], run(bijector.forward, x))
-        self.assertAllClose([0.], run(bijector.inverse, x))
-        self.assertAllClose(0., run(bijector.inverse_log_det_jacobian, x))
-
-  def testOneBatchScalarViaIdentityIn64BitUserProvidesMultiplierOnly(self):
-    with self.test_session() as sess:
-
-      def static_run(fun, x):
-        return fun(x).eval()
-
-      def dynamic_run(fun, x_value):
-        x_value = np.array(x_value).astype(np.float64)
-        x = array_ops.placeholder(dtypes.float64, name="x")
-        return sess.run(fun(x), feed_dict={x: x_value})
-
-      for run in (static_run, dynamic_run):
-        multiplier = np.float64([2.])
-        # One batch, scalar.
-        # Corresponds to scale = 2, shift = 0.
-        bijector = Affine(scale_identity_multiplier=multiplier, event_ndims=0)
-        self.assertEqual(0, bijector.event_ndims.eval())  # "is scalar"
-        x = np.float64([1.])  # One sample from one batches.
-        self.assertAllClose([2.], run(bijector.forward, x))
-        self.assertAllClose([0.5], run(bijector.inverse, x))
-        self.assertAllClose([np.log(0.5)],
-                            run(bijector.inverse_log_det_jacobian, x))
-
-  def testOneBatchScalarViaDiagMultiplier(self):
-    with self.test_session() as sess:
-
-      def static_run(fun, x):
-        return fun(x).eval()
-
-      def dynamic_run(fun, x_value):
-        x_value = np.array(x_value)
-        x = array_ops.placeholder(dtypes.float32, name="x")
-        return sess.run(fun(x), feed_dict={x: x_value})
-
-      for run in (static_run, dynamic_run):
-        mu = [1.]
-        # One batch, scalar.
-        # Corresponds to scale = 1.
-        bijector = Affine(shift=mu, scale_identity_multiplier=1., event_ndims=0)
-        self.assertEqual(0, bijector.event_ndims.eval())  # "is scalar"
-        x = [1.]  # One sample from one batches.
-        self.assertAllClose([2.], run(bijector.forward, x))
-        self.assertAllClose([0.], run(bijector.inverse, x))
-        self.assertAllClose(0., run(bijector.inverse_log_det_jacobian, x))
-
-  def testTwoBatchScalarIdentityViaIdentity(self):
-    with self.test_session() as sess:
-
-      def static_run(fun, x):
-        return fun(x).eval()
-
-      def dynamic_run(fun, x_value):
-        x_value = np.array(x_value)
-        x = array_ops.placeholder(dtypes.float32, name="x")
-        return sess.run(fun(x), feed_dict={x: x_value})
-
-      for run in (static_run, dynamic_run):
-        mu = [1., -1]
-        # Univariate, two batches.
-        # Corresponds to scale = 1.
-        bijector = Affine(shift=mu, event_ndims=0)
-        self.assertEqual(0, bijector.event_ndims.eval())  # "is scalar"
-        x = [1., 1]  # One sample from each of two batches.
-        self.assertAllClose([2., 0], run(bijector.forward, x))
-        self.assertAllClose([0., 2], run(bijector.inverse, x))
-        self.assertAllClose(0., run(bijector.inverse_log_det_jacobian, x))
-
-  def testTwoBatchScalarIdentityViaDiagMultiplier(self):
-    with self.test_session() as sess:
-
-      def static_run(fun, x):
-        return fun(x).eval()
-
-      def dynamic_run(fun, x_value):
-        x_value = np.array(x_value)
-        x = array_ops.placeholder(dtypes.float32, name="x")
-        return sess.run(fun(x), feed_dict={x: x_value})
-
-      for run in (static_run, dynamic_run):
-        mu = [1., -1]
-        # Univariate, two batches.
-        # Corresponds to scale = 1.
-        bijector = Affine(shift=mu, scale_identity_multiplier=1., event_ndims=0)
-        self.assertEqual(0, bijector.event_ndims.eval())  # "is scalar"
-        x = [1., 1]  # One sample from each of two batches.
-        self.assertAllClose([2., 0], run(bijector.forward, x))
-        self.assertAllClose([0., 2], run(bijector.inverse, x))
-        self.assertAllClose(0., run(bijector.inverse_log_det_jacobian, x))
-
   def testNoBatchMultivariateIdentity(self):
     with self.test_session() as sess:
 
@@ -238,7 +54,6 @@ class AffineBijectorTest(test.TestCase):
         # Multivariate
         # Corresponds to scale = [[1., 0], [0, 1.]]
         bijector = Affine(shift=mu)
-        self.assertEqual(1, bijector.event_ndims.eval())  # "is vector"
         x = [1., 1]
         # matmul(sigma, x) + shift
         # = [-1, -1] + [1, -1]
@@ -269,7 +84,6 @@ class AffineBijectorTest(test.TestCase):
         # Multivariate
         # Corresponds to scale = [[2., 0], [0, 1.]]
         bijector = Affine(shift=mu, scale_diag=[2., 1])
-        self.assertEqual(1, bijector.event_ndims.eval())  # "is vector"
         x = [1., 1]
         # matmul(sigma, x) + shift
         # = [-1, -1] + [1, -1]
@@ -297,22 +111,17 @@ class AffineBijectorTest(test.TestCase):
       x = array_ops.placeholder(dtypes.float32, name="x")
       mu = array_ops.placeholder(dtypes.float32, name="mu")
       scale_diag = array_ops.placeholder(dtypes.float32, name="scale_diag")
-      event_ndims = array_ops.placeholder(dtypes.int32, name="event_ndims")
 
       x_value = np.array([[1., 1]], dtype=np.float32)
       mu_value = np.array([1., -1], dtype=np.float32)
       scale_diag_value = np.array([2., 2], dtype=np.float32)
-      event_ndims_value = np.array(1, dtype=np.int32)
       feed_dict = {
           x: x_value,
           mu: mu_value,
           scale_diag: scale_diag_value,
-          event_ndims: event_ndims_value
       }
 
-      bijector = Affine(
-          shift=mu, scale_diag=scale_diag, event_ndims=event_ndims)
-      self.assertEqual(1, sess.run(bijector.event_ndims, feed_dict))
+      bijector = Affine(shift=mu, scale_diag=scale_diag)
       self.assertAllClose([[3., 1]], sess.run(bijector.forward(x), feed_dict))
       self.assertAllClose([[0., 1]], sess.run(bijector.inverse(x), feed_dict))
       self.assertAllClose(
@@ -335,7 +144,6 @@ class AffineBijectorTest(test.TestCase):
         # Corresponds to 1 2x2 matrix, with twos on the diagonal.
         scale = 2.
         bijector = Affine(shift=mu, scale_identity_multiplier=scale)
-        self.assertEqual(1, bijector.event_ndims.eval())  # "is vector"
         x = [[[1., 1]]]
         self.assertAllClose([[[3., 1]]], run(bijector.forward, x))
         self.assertAllClose([[[0., 1]]], run(bijector.inverse, x))
@@ -358,7 +166,6 @@ class AffineBijectorTest(test.TestCase):
         # Corresponds to 1 2x2 matrix, with twos on the diagonal.
         scale_diag = [[2., 2]]
         bijector = Affine(shift=mu, scale_diag=scale_diag)
-        self.assertEqual(1, bijector.event_ndims.eval())  # "is vector"
         x = [[[1., 1]]]
         self.assertAllClose([[[3., 1]]], run(bijector.forward, x))
         self.assertAllClose([[[0., 1]]], run(bijector.inverse, x))
@@ -370,23 +177,18 @@ class AffineBijectorTest(test.TestCase):
       x = array_ops.placeholder(dtypes.float32, name="x")
       mu = array_ops.placeholder(dtypes.float32, name="mu")
       scale_diag = array_ops.placeholder(dtypes.float32, name="scale_diag")
-      event_ndims = array_ops.placeholder(dtypes.int32, name="event_ndims")
 
       x_value = np.array([[[1., 1]]], dtype=np.float32)
       mu_value = np.array([[1., -1]], dtype=np.float32)
       scale_diag_value = np.array([[2., 2]], dtype=np.float32)
-      event_ndims_value = 1
 
       feed_dict = {
           x: x_value,
           mu: mu_value,
           scale_diag: scale_diag_value,
-          event_ndims: event_ndims_value
       }
 
-      bijector = Affine(
-          shift=mu, scale_diag=scale_diag, event_ndims=event_ndims)
-      self.assertEqual(1, sess.run(bijector.event_ndims, feed_dict))
+      bijector = Affine(shift=mu, scale_diag=scale_diag)
       self.assertAllClose([[[3., 1]]], sess.run(bijector.forward(x), feed_dict))
       self.assertAllClose([[[0., 1]]], sess.run(bijector.inverse(x), feed_dict))
       self.assertAllClose([-np.log(4)],
@@ -410,9 +212,7 @@ class AffineBijectorTest(test.TestCase):
         bijector = Affine(
             shift=mu,
             scale_identity_multiplier=1.,
-            scale_diag=[1., 1., 1.],
-            event_ndims=1)
-        self.assertEqual(1, bijector.event_ndims.eval())  # "is vector"
+            scale_diag=[1., 1., 1.])
         x = [1., 2, 3]  # Three scalar samples (no batches).
         self.assertAllClose([1., 3, 5], run(bijector.forward, x))
         self.assertAllClose([1., 1.5, 2.], run(bijector.inverse, x))
@@ -437,7 +237,6 @@ class AffineBijectorTest(test.TestCase):
             shift=mu,
             scale_identity_multiplier=1.,
             scale_tril=[[1., 0], [2., 1]])
-        self.assertEqual(1, bijector.event_ndims.eval())  # "is vector"
         x = [[1., 2]]  # One multivariate sample.
         self.assertAllClose([[1., 5]], run(bijector.forward, x))
         self.assertAllClose([[1., 0.5]], run(bijector.inverse, x))
@@ -460,7 +259,6 @@ class AffineBijectorTest(test.TestCase):
         # scale = [[2., 0], [2, 3]]
         bijector = Affine(
             shift=mu, scale_diag=[1., 2.], scale_tril=[[1., 0], [2., 1]])
-        self.assertEqual(1, bijector.event_ndims.eval())  # "is vector"
         x = [[1., 2]]  # One multivariate sample.
         self.assertAllClose([[1., 7]], run(bijector.forward, x))
         self.assertAllClose([[1., 1 / 3.]], run(bijector.inverse, x))
@@ -486,7 +284,6 @@ class AffineBijectorTest(test.TestCase):
             scale_identity_multiplier=1.0,
             scale_diag=[1., 2.],
             scale_tril=[[1., 0], [2., 1]])
-        self.assertEqual(1, bijector.event_ndims.eval())  # "is vector"
         x = [[1., 2]]  # One multivariate sample.
         self.assertAllClose([[2., 9]], run(bijector.forward, x))
         self.assertAllClose([[2 / 3., 5 / 12.]], run(bijector.inverse, x))
@@ -514,7 +311,6 @@ class AffineBijectorTest(test.TestCase):
             scale_perturb_factor=[[2., 0], [0., 0], [0, 1]])
         bijector_ref = Affine(shift=mu, scale_diag=[10., 2, 3])
 
-        self.assertEqual(1, bijector.event_ndims.eval())  # "is vector"
         x = [1., 2, 3]  # Vector.
         self.assertAllClose([9., 3, 8], run(bijector.forward, x))
         self.assertAllClose(
@@ -550,7 +346,6 @@ class AffineBijectorTest(test.TestCase):
             scale_perturb_factor=[[2., 0], [0., 0], [0, 1]])
         bijector_ref = Affine(shift=mu, scale_diag=[10., 3, 5])
 
-        self.assertEqual(1, bijector.event_ndims.eval())  # "is vector"
         x = [1., 2, 3]  # Vector.
         self.assertAllClose([9., 5, 14], run(bijector.forward, x))
         self.assertAllClose(
@@ -586,7 +381,6 @@ class AffineBijectorTest(test.TestCase):
         bijector_ref = Affine(
             shift=mu, scale_tril=[[10., 0, 0], [1, 3, 0], [2, 3, 5]])
 
-        self.assertEqual(1, bijector.event_ndims.eval())  # "is vector"
         x = [1., 2, 3]  # Vector.
         self.assertAllClose([9., 6, 22], run(bijector.forward, x))
         self.assertAllClose(
@@ -622,7 +416,6 @@ class AffineBijectorTest(test.TestCase):
         bijector_ref = Affine(
             shift=mu, scale_tril=[[6., 0, 0], [1, 3, 0], [2, 3, 5]])
 
-        self.assertEqual(1, bijector.event_ndims.eval())  # "is vector"
         x = [1., 2, 3]  # Vector.
         self.assertAllClose([5., 6, 22], run(bijector.forward, x))
         self.assertAllClose(
@@ -647,38 +440,6 @@ class AffineBijectorTest(test.TestCase):
       with self.assertRaisesOpError("diagonal part must be non-zero"):
         bijector.forward([1., 1.]).eval()
 
-  def testEventNdimsLargerThanOneRaises(self):
-    with self.test_session():
-      mu = [1., -1]
-      with self.assertRaisesRegexp(
-          ValueError, (r"event_ndims\(2\) was not 0 or 1")):
-        # Scale corresponds to 2x2 identity matrix.
-        bijector = Affine(shift=mu, event_ndims=2, validate_args=True)
-        bijector.forward([1., 1.]).eval()
-
-  def testScaleZeroScalarRaises(self):
-    with self.test_session():
-      mu = -1.
-      # Check Identity matrix with zero scaling.
-      bijector = Affine(
-          shift=mu,
-          scale_identity_multiplier=0.,
-          event_ndims=0,
-          validate_args=True)
-      with self.assertRaisesOpError("identity_multiplier should be non-zero"):
-        bijector.forward(1.).eval()
-
-  def testScaleDiagAndEventNdimsZeroRaises(self):
-    # Check Diag matrix with zero scaling.
-    with self.assertRaisesRegexp(ValueError, "only scale argument"):
-      Affine(shift=None, scale_diag=[0.0], event_ndims=0, validate_args=True)
-
-  def testScalarCongruency(self):
-    with self.test_session():
-      bijector = Affine(
-          shift=3.6, scale_identity_multiplier=0.42, event_ndims=0)
-      assert_scalar_congruency(bijector, lower_x=-2., upper_x=2.)
-
   def _makeScale(self,
                  x,
                  scale_identity_multiplier=None,
@@ -747,14 +508,12 @@ class AffineBijectorTest(test.TestCase):
         scale_args = dict({"x": x}, **args)
         scale = self._makeScale(**scale_args)
 
-        bijector_args = dict({"event_ndims": 1}, **args)
-
         # We haven't specified enough information for the scale.
         if scale is None:
           with self.assertRaisesRegexp(ValueError, ("must be specified.")):
-            bijector = Affine(shift=shift, **bijector_args)
+            bijector = Affine(shift=shift, **args)
         else:
-          bijector = Affine(shift=shift, **bijector_args)
+          bijector = Affine(shift=shift, **args)
           np_x = x
           # For the case a vector is passed in, we need to make the shape
           # match the matrix for matmul to work.
@@ -829,15 +588,5 @@ class AffineBijectorTest(test.TestCase):
         x=np.array(
             [1., 2], dtype=np.float32))
 
-  def testScalarEventIdentityScale(self):
-    with self.test_session() as sess:
-      doubler = Affine(
-          scale_identity_multiplier=2.,
-          event_ndims=0)
-      doubler2 = doubler.inverse_log_det_jacobian(2.)
-      doubler2_ildj_ = sess.run([doubler2])
-      self.assertAllClose([-np.log(2.)], doubler2_ildj_)
-
-
 if __name__ == "__main__":
   test.main()
index ab2338f..f392e83 100644 (file)
@@ -23,7 +23,6 @@ import numpy as np
 from tensorflow.contrib.distributions.python.ops import bijectors
 from tensorflow.python.framework import dtypes
 from tensorflow.python.ops import array_ops
-from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency
 from tensorflow.python.platform import test
 
 
@@ -32,8 +31,7 @@ class CholeskyOuterProductBijectorTest(test.TestCase):
 
   def testBijectorMatrix(self):
     with self.test_session():
-      bijector = bijectors.CholeskyOuterProduct(
-          event_ndims=2, validate_args=True)
+      bijector = bijectors.CholeskyOuterProduct(validate_args=True)
       self.assertEqual("cholesky_outer_product", bijector.name)
       x = [[[1., 0], [2, 1]], [[np.sqrt(2.), 0], [np.sqrt(8.), 1]]]
       y = np.matmul(x, np.transpose(x, axes=(0, 2, 1)))
@@ -60,39 +58,12 @@ class CholeskyOuterProductBijectorTest(test.TestCase):
           atol=0.,
           rtol=1e-7)
 
-  def testBijectorScalar(self):
-    with self.test_session():
-      bijector = bijectors.CholeskyOuterProduct(
-          event_ndims=0, validate_args=True)
-      self.assertEqual("cholesky_outer_product", bijector.name)
-      x = [[[1., 5],
-            [2, 1]],
-           [[np.sqrt(2.), 3],
-            [np.sqrt(8.), 1]]]
-      y = np.square(x)
-      ildj = -np.log(2.) - np.log(x)
-      self.assertAllClose(y, bijector.forward(x).eval())
-      self.assertAllClose(x, bijector.inverse(y).eval())
-      self.assertAllClose(
-          ildj, bijector.inverse_log_det_jacobian(y).eval(), atol=0., rtol=1e-7)
-      self.assertAllClose(
-          -bijector.inverse_log_det_jacobian(y).eval(),
-          bijector.forward_log_det_jacobian(x).eval(),
-          atol=0.,
-          rtol=1e-7)
-
-  def testScalarCongruency(self):
-    with self.test_session():
-      bijector = bijectors.CholeskyOuterProduct(
-          event_ndims=0, validate_args=True)
-      assert_scalar_congruency(bijector, lower_x=1e-3, upper_x=1.5, rtol=0.05)
-
   def testNoBatchStatic(self):
     x = np.array([[1., 0], [2, 1]])  # np.linalg.cholesky(y)
     y = np.array([[1., 2], [2, 5]])  # np.matmul(x, x.T)
     with self.test_session() as sess:
-      y_actual = bijectors.CholeskyOuterProduct(event_ndims=2).forward(x=x)
-      x_actual = bijectors.CholeskyOuterProduct(event_ndims=2).inverse(y=y)
+      y_actual = bijectors.CholeskyOuterProduct().forward(x=x)
+      x_actual = bijectors.CholeskyOuterProduct().inverse(y=y)
     [y_actual_, x_actual_] = sess.run([y_actual, x_actual])
     self.assertAllEqual([2, 2], y_actual.get_shape())
     self.assertAllEqual([2, 2], x_actual.get_shape())
@@ -105,8 +76,8 @@ class CholeskyOuterProductBijectorTest(test.TestCase):
     with self.test_session() as sess:
       x_pl = array_ops.placeholder(dtypes.float32)
       y_pl = array_ops.placeholder(dtypes.float32)
-      y_actual = bijectors.CholeskyOuterProduct(event_ndims=2).forward(x=x_pl)
-      x_actual = bijectors.CholeskyOuterProduct(event_ndims=2).inverse(y=y_pl)
+      y_actual = bijectors.CholeskyOuterProduct().forward(x=x_pl)
+      x_actual = bijectors.CholeskyOuterProduct().inverse(y=y_pl)
     [y_actual_, x_actual_] = sess.run([y_actual, x_actual],
                                       feed_dict={x_pl: x, y_pl: y})
     self.assertEqual(None, y_actual.get_shape())
@@ -124,8 +95,8 @@ class CholeskyOuterProductBijectorTest(test.TestCase):
                   [[9., 3],
                    [3, 5]]])  # np.matmul(x, x.T)
     with self.test_session() as sess:
-      y_actual = bijectors.CholeskyOuterProduct(event_ndims=2).forward(x=x)
-      x_actual = bijectors.CholeskyOuterProduct(event_ndims=2).inverse(y=y)
+      y_actual = bijectors.CholeskyOuterProduct().forward(x=x)
+      x_actual = bijectors.CholeskyOuterProduct().inverse(y=y)
     [y_actual_, x_actual_] = sess.run([y_actual, x_actual])
     self.assertEqual([2, 2, 2], y_actual.get_shape())
     self.assertEqual([2, 2, 2], x_actual.get_shape())
@@ -144,8 +115,8 @@ class CholeskyOuterProductBijectorTest(test.TestCase):
     with self.test_session() as sess:
       x_pl = array_ops.placeholder(dtypes.float32)
       y_pl = array_ops.placeholder(dtypes.float32)
-      y_actual = bijectors.CholeskyOuterProduct(event_ndims=2).forward(x=x_pl)
-      x_actual = bijectors.CholeskyOuterProduct(event_ndims=2).inverse(y=y_pl)
+      y_actual = bijectors.CholeskyOuterProduct().forward(x=x_pl)
+      x_actual = bijectors.CholeskyOuterProduct().inverse(y=y_pl)
     [y_actual_, x_actual_] = sess.run([y_actual, x_actual],
                                       feed_dict={x_pl: x, y_pl: y})
     self.assertEqual(None, y_actual.get_shape())
index 0ff3530..28e3e31 100644 (file)
@@ -35,8 +35,7 @@ class InvertBijectorTest(test.TestCase):
       for fwd in [
           bijectors.Identity(),
           bijectors.Exp(event_ndims=1),
-          bijectors.Affine(
-              shift=[0., 1.], scale_diag=[2., 3.], event_ndims=1),
+          bijectors.Affine(shift=[0., 1.], scale_diag=[2., 3.]),
           bijectors.Softplus(event_ndims=1),
           bijectors.SoftmaxCentered(event_ndims=1),
           bijectors.SigmoidCentered(),
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/square_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/square_test.py
new file mode 100644 (file)
index 0000000..f03d6f1
--- /dev/null
@@ -0,0 +1,58 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Bijector."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.distributions.python.ops import bijectors
+from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency
+from tensorflow.python.platform import test
+
+
+class SquareBijectorTest(test.TestCase):
+  """Tests the correctness of the Y = X ** 2 transformation."""
+
+  def testBijectorScalar(self):
+    with self.test_session():
+      bijector = bijectors.Square(validate_args=True)
+      self.assertEqual("square", bijector.name)
+      x = [[[1., 5],
+            [2, 1]],
+           [[np.sqrt(2.), 3],
+            [np.sqrt(8.), 1]]]
+      y = np.square(x)
+      ildj = -np.log(2.) - np.log(x)
+      self.assertAllClose(y, bijector.forward(x).eval())
+      self.assertAllClose(x, bijector.inverse(y).eval())
+      self.assertAllClose(
+          ildj, bijector.inverse_log_det_jacobian(y).eval(), atol=0., rtol=1e-7)
+      self.assertAllClose(
+          -bijector.inverse_log_det_jacobian(y).eval(),
+          bijector.forward_log_det_jacobian(x).eval(),
+          atol=0.,
+          rtol=1e-7)
+
+  def testScalarCongruency(self):
+    with self.test_session():
+      bijector = bijectors.Square(validate_args=True)
+      assert_scalar_congruency(bijector, lower_x=1e-3, upper_x=1.5, rtol=0.05)
+
+
+if __name__ == "__main__":
+  test.main()
index cbaf74d..af13553 100644 (file)
@@ -245,9 +245,8 @@ class TransformedDistributionTest(test.TestCase):
     with self.test_session() as sess:
       exp2 = self._cls()(
           ds.Exponential(rate=0.25),
-          bijector=ds.bijectors.Affine(
-              scale_identity_multiplier=2.,
-              event_ndims=0))
+          bijector=ds.bijectors.AffineScalar(scale=2.)
+      )
       log_prob = exp2.log_prob(1.)
       log_prob_ = sess.run(log_prob)
       base_log_prob = -0.5 * 0.25 + np.log(0.25)
index 46ec497..452f1ca 100644 (file)
@@ -17,6 +17,7 @@
 @@AbsoluteValue
 @@Affine
 @@AffineLinearOperator
+@@AffineScalar
 @@Bijector
 @@BatchNormalization
 @@Chain
@@ -38,6 +39,7 @@
 @@SinhArcsinh
 @@SoftmaxCentered
 @@Softplus
+@@Square
 @@Weibull
 
 @@masked_autoregressive_default_template
@@ -54,6 +56,7 @@ from __future__ import print_function
 from tensorflow.contrib.distributions.python.ops.bijectors.absolute_value import *
 from tensorflow.contrib.distributions.python.ops.bijectors.affine import *
 from tensorflow.contrib.distributions.python.ops.bijectors.affine_linear_operator import *
+from tensorflow.contrib.distributions.python.ops.bijectors.affine_scalar import *
 from tensorflow.contrib.distributions.python.ops.bijectors.batch_normalization import *
 from tensorflow.contrib.distributions.python.ops.bijectors.chain import *
 from tensorflow.contrib.distributions.python.ops.bijectors.cholesky_outer_product import *
@@ -73,6 +76,7 @@ from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid_centered impo
 from tensorflow.contrib.distributions.python.ops.bijectors.sinh_arcsinh import *
 from tensorflow.contrib.distributions.python.ops.bijectors.softmax_centered import *
 from tensorflow.contrib.distributions.python.ops.bijectors.softplus import *
+from tensorflow.contrib.distributions.python.ops.bijectors.square import *
 from tensorflow.python.ops.distributions.bijector import *
 from tensorflow.python.ops.distributions.identity_bijector import Identity
 
index 05bb9c2..7fe73ad 100644 (file)
@@ -104,7 +104,6 @@ class Affine(bijector.Bijector):
                scale_tril=None,
                scale_perturb_factor=None,
                scale_perturb_diag=None,
-               event_ndims=1,
                validate_args=False,
                name="affine"):
     """Instantiates the `Affine` bijector.
@@ -157,8 +156,6 @@ class Affine(bijector.Bijector):
         matrix. `scale_perturb_diag` has shape [N1, N2, ...  r], which
         represents an `r x r` diagonal matrix. When `None` low rank updates will
         take the form `scale_perturb_factor * scale_perturb_factor.T`.
-      event_ndims: Scalar `int` `Tensor` indicating the number of dimensions
-        associated with a particular draw from the distribution. Must be 0 or 1.
       validate_args: Python `bool` indicating whether arguments should be
         checked for correctness.
       name: Python `str` name given to ops managed by this object.
@@ -187,23 +184,6 @@ class Affine(bijector.Bijector):
     with self._name_scope("init", values=[
         shift, scale_identity_multiplier, scale_diag, scale_tril,
         scale_perturb_diag, scale_perturb_factor]):
-      event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims")
-      event_ndims_const = tensor_util.constant_value(event_ndims)
-      if event_ndims_const is not None and event_ndims_const not in (0, 1):
-        raise ValueError("event_ndims(%s) was not 0 or 1" % event_ndims_const)
-      else:
-        if validate_args:
-          # Shape tool will catch if event_ndims is negative.
-          event_ndims = control_flow_ops.with_dependencies(
-              [check_ops.assert_less(
-                  event_ndims, 2, message="event_ndims must be 0 or 1")],
-              event_ndims)
-
-      if event_ndims_const == 0 and not self._is_only_identity_multiplier:
-        raise ValueError(
-            "If event_ndims == 0, the only scale argument you can pass is "
-            "scale_identity_multiplier.  All others operate on vectors.")
-
       # In the absence of `loc` and `scale`, we'll assume `dtype` is `float32`.
       dtype = dtypes.float32
 
@@ -251,12 +231,11 @@ class Affine(bijector.Bijector):
       self._scale = scale
       self._shaper = _DistributionShape(
           batch_ndims=batch_ndims,
-          event_ndims=event_ndims,
+          event_ndims=1,
           validate_args=validate_args)
       super(Affine, self).__init__(
-          event_ndims=event_ndims,
+          event_ndims=1,
           graph_parents=(
-              [event_ndims] +
               [self._scale] if tensor_util.is_tensor(self._scale)
               else self._scale.graph_parents +
               [self._shift] if self._shift is not None else []),
@@ -388,9 +367,7 @@ class Affine(bijector.Bijector):
     if self._is_only_identity_multiplier:
       # We don't pad in this case and instead let the fldj be applied
       # via broadcast.
-      event_size = distribution_util.pick_vector(
-          math_ops.equal(self._shaper.event_ndims, 0),
-          [1], array_ops.shape(x))[-1]
+      event_size = array_ops.shape(x)[-1]
       event_size = math_ops.cast(event_size, dtype=self._scale.dtype)
       return math_ops.log(math_ops.abs(self._scale)) * event_size
     return self.scale.log_abs_determinant()
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/affine_scalar.py b/tensorflow/contrib/distributions/python/ops/bijectors/affine_scalar.py
new file mode 100644 (file)
index 0000000..8adaa54
--- /dev/null
@@ -0,0 +1,138 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Affine bijector."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops.distributions import bijector
+
+
+__all__ = [
+    "AffineScalar",
+]
+
+
+class AffineScalar(bijector.Bijector):
+  """Compute `Y = g(X; shift, scale) = scale * X + shift`.
+
+  Examples:
+
+  ```python
+  # Y = X
+  b = AffineScalar()
+
+  # Y = X + shift
+  b = AffineScalar(shift=[1., 2, 3])
+
+  # Y = 2 * X + shift
+  b = AffineScalar(
+    shift=[1., 2, 3],
+    scale=2.)
+  ```
+
+  """
+
+  def __init__(self,
+               shift=None,
+               scale=None,
+               validate_args=False,
+               name="affine_scalar"):
+    """Instantiates the `AffineScalar` bijector.
+
+    This `Bijector` is initialized with `shift` `Tensor` and `scale` arguments,
+    giving the forward operation:
+
+    ```none
+    Y = g(X) = scale * X + shift
+    ```
+
+    if `scale` is not specified, then the bijector has the semantics of
+    `scale = 1.`. Similarly, if `shift` is not specified, then the bijector
+    has the semantics of `shift = 0.`.
+
+    Args:
+      shift: Floating-point `Tensor`. If this is set to `None`, no shift is
+        applied.
+      scale: Floating-point `Tensor`. If this is set to `None`, no scale is
+        applied.
+      validate_args: Python `bool` indicating whether arguments should be
+        checked for correctness.
+      name: Python `str` name given to ops managed by this object.
+    """
+    self._graph_parents = []
+    self._name = name
+    self._validate_args = validate_args
+
+    with self._name_scope("init", values=[scale, shift]):
+      self._shift = shift
+      self._scale = scale
+
+      if self._shift is not None:
+        self._shift = ops.convert_to_tensor(shift, name="shift")
+
+      if self._scale is not None:
+        self._scale = ops.convert_to_tensor(self._scale, name="scale")
+        if validate_args:
+          self._scale = control_flow_ops.with_dependencies(
+              [check_ops.assert_none_equal(
+                  self._scale,
+                  array_ops.zeros([], dtype=self._scale.dtype))],
+              self._scale)
+
+      super(AffineScalar, self).__init__(
+          event_ndims=0,
+          is_constant_jacobian=True,
+          validate_args=validate_args,
+          name=name)
+
+  @property
+  def shift(self):
+    """The `shift` `Tensor` in `Y = scale @ X + shift`."""
+    return self._shift
+
+  @property
+  def scale(self):
+    """The `scale` `LinearOperator` in `Y = scale @ X + shift`."""
+    return self._scale
+
+  def _forward(self, x):
+    y = array_ops.identity(x)
+    if self.scale is not None:
+      y *= self.scale
+    if self.shift is not None:
+      y += self.shift
+    return y
+
+  def _inverse(self, y):
+    x = array_ops.identity(y)
+    if self.shift is not None:
+      x -= self.shift
+    if self.scale is not None:
+      x /= self.scale
+    return x
+
+  def _forward_log_det_jacobian(self, x):
+    log_det_jacobian = array_ops.zeros_like(x)
+    if self.scale is None:
+      return log_det_jacobian
+    log_det_jacobian += math_ops.log(math_ops.abs(self.scale))
+    return log_det_jacobian
index cbd60f9..43208ff 100644 (file)
@@ -20,8 +20,6 @@ from __future__ import print_function
 
 import numpy as np
 
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import check_ops
 from tensorflow.python.ops import control_flow_ops
@@ -39,8 +37,6 @@ __all__ = [
 class CholeskyOuterProduct(bijector.Bijector):
   """Compute `g(X) = X @ X.T`; X is lower-triangular, positive-diagonal matrix.
 
-  `event_ndims` must be 0 or 2, i.e., scalar or matrix.
-
   Note: the upper-triangular part of X is ignored (whether or not its zero).
 
   The surjectivity of g as a map from  the set of n x n positive-diagonal
@@ -64,46 +60,31 @@ class CholeskyOuterProduct(bijector.Bijector):
   Examples:
 
   ```python
-  bijector.CholeskyOuterProduct(event_ndims=2).forward(x=[[1., 0], [2, 1]])
+  bijector.CholeskyOuterProduct().forward(x=[[1., 0], [2, 1]])
   # Result: [[1., 2], [2, 5]], i.e., x @ x.T
 
-  bijector.CholeskyOuterProduct(event_ndims=2).inverse(y=[[1., 2], [2, 5]])
+  bijector.CholeskyOuterProduct().inverse(y=[[1., 2], [2, 5]])
   # Result: [[1., 0], [2, 1]], i.e., cholesky(y).
   ```
 
   """
 
-  def __init__(self, event_ndims=2, validate_args=False,
-               name="cholesky_outer_product"):
+  def __init__(self, validate_args=False, name="cholesky_outer_product"):
     """Instantiates the `CholeskyOuterProduct` bijector.
 
     Args:
-      event_ndims: `constant` `int32` scalar `Tensor` indicating the number of
-        dimensions associated with a particular draw from the distribution. Must
-        be 0 or 2.
       validate_args: Python `bool` indicating whether arguments should be
         checked for correctness.
       name: Python `str` name given to ops managed by this object.
-
-    Raises:
-      ValueError: if event_ndims is neither 0 or 2.
     """
     self._graph_parents = []
     self._name = name
-    with self._name_scope("init", values=[event_ndims]):
-      event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims")
-      event_ndims = tensor_util.constant_value(event_ndims)
-    if event_ndims is None or event_ndims not in [0, 2]:
-      raise ValueError("`event_ndims` must be a TF constant which is 0 or 2")
-    self._static_event_ndims = event_ndims
     super(CholeskyOuterProduct, self).__init__(
-        event_ndims=event_ndims,
+        event_ndims=2,
         validate_args=validate_args,
         name=name)
 
   def _forward(self, x):
-    if self._static_event_ndims == 0:
-      return math_ops.square(x)
     if self.validate_args:
       is_matrix = check_ops.assert_rank_at_least(x, 2)
       shape = array_ops.shape(x)
@@ -114,11 +95,7 @@ class CholeskyOuterProduct(bijector.Bijector):
     return math_ops.matmul(x, x, adjoint_b=True)
 
   def _inverse(self, y):
-    return (math_ops.sqrt(y) if self._static_event_ndims == 0
-            else linalg_ops.cholesky(y))
-
-  def _inverse_log_det_jacobian(self, y):
-    return -self._forward_log_det_jacobian(x=self._inverse(y))
+    return linalg_ops.cholesky(y)
 
   def _forward_log_det_jacobian(self, x):
     # Let Y be a symmetric, positive definite matrix and write:
@@ -161,13 +138,6 @@ class CholeskyOuterProduct(bijector.Bijector):
     # Since there is a 2 X[j,j] term for every lower-triangular element of X we
     # conclude:
     #   |Jac(d vec[Y]/d vec[X])| = 2^p prod_{j=0}^{p-1} X[j,j]^{p-j}.
-    if self._static_event_ndims == 0:
-      if self.validate_args:
-        is_positive = check_ops.assert_positive(
-            x, message="All elements must be positive.")
-        x = control_flow_ops.with_dependencies([is_positive], x)
-      return np.log(2.) + math_ops.log(x)
-
     diag = array_ops.matrix_diag_part(x)
 
     # We now ensure diag is columnar. Eg, if `diag = [1, 2, 3]` then the output
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/square.py b/tensorflow/contrib/distributions/python/ops/bijectors/square.py
new file mode 100644 (file)
index 0000000..2831a92
--- /dev/null
@@ -0,0 +1,84 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Square bijector."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops.distributions import bijector
+
+
+__all__ = [
+    "Square",
+]
+
+
+class Square(bijector.Bijector):
+  """Compute `g(X) = X^2`; X is a positive real number.
+
+  g is a bijection between the non-negative real numbers (R_+) and the
+  non-negative real numbers.
+
+  Examples:
+
+  ```python
+  bijector.Square().forward(x=[[1., 0], [2, 1]])
+  # Result: [[1., 0], [4, 1]], i.e., x^2
+
+  bijector.Square().inverse(y=[[1., 4], [9, 1]])
+  # Result: [[1., 2], [3, 1]], i.e., sqrt(y).
+  ```
+
+  """
+
+  def __init__(self, validate_args=False, name="square"):
+    """Instantiates the `Square` bijector.
+
+    Args:
+      validate_args: Python `bool` indicating whether arguments should be
+        checked for correctness.
+      name: Python `str` name given to ops managed by this object.
+    """
+    self._name = name
+    super(Square, self).__init__(
+        event_ndims=0,
+        validate_args=validate_args,
+        name=name)
+
+  def _forward(self, x):
+    x = self._maybe_assert_valid(x)
+    return math_ops.square(x)
+
+  def _inverse(self, y):
+    y = self._maybe_assert_valid(y)
+    return math_ops.sqrt(y)
+
+  def _forward_log_det_jacobian(self, x):
+    x = self._maybe_assert_valid(x)
+    return np.log(2.) + math_ops.log(x)
+
+  def _maybe_assert_valid(self, t):
+    if not self.validate_args:
+      return t
+    is_valid = check_ops.assert_non_negative(
+        t, message="All elements must be non-negative.")
+    return control_flow_ops.with_dependencies([is_valid], t)
+
index c4b8f05..0d8a192 100644 (file)
@@ -174,13 +174,12 @@ class SinhArcsinh(transformed_distribution.TransformedDistribution):
             skewness=skewness.dtype.as_numpy_dtype(0.),
             tailweight=tailweight, event_ndims=0)
 
-      # Make the Affine bijector, Z --> loc + scale * Z (2 / F_0(2))
+      # Make the AffineScalar bijector, Z --> loc + scale * Z (2 / F_0(2))
       c = 2 * scale / f_noskew.forward(ops.convert_to_tensor(2, dtype=dtype))
-      affine = bijectors.Affine(
+      affine = bijectors.AffineScalar(
           shift=loc,
-          scale_identity_multiplier=c,
-          validate_args=validate_args,
-          event_ndims=0)
+          scale=c,
+          validate_args=validate_args)
 
       bijector = bijectors.Chain([affine, f])
 
index e1ccf11..003c66b 100644 (file)
@@ -227,7 +227,7 @@ class VectorSinhArcsinhDiag(transformed_distribution.TransformedDistribution):
       c = 2 * scale_diag_part / f_noskew.forward(
           ops.convert_to_tensor(2, dtype=dtype))
       affine = bijectors.Affine(
-          shift=loc, scale_diag=c, validate_args=validate_args, event_ndims=1)
+          shift=loc, scale_diag=c, validate_args=validate_args)
 
       bijector = bijectors.Chain([affine, f])