Modify tf.contrib.distributions.BatchReshape to behave a bit more like
tf.reshape: accept a single unknown dimension and infer partial shape
information statically.

author     A. Unique TensorFlower <gardener@tensorflow.org>
           Wed, 16 May 2018 15:45:28 +0000 (08:45 -0700)
committer  TensorFlower Gardener <gardener@tensorflow.org>
           Wed, 16 May 2018 15:47:59 +0000 (08:47 -0700)

PiperOrigin-RevId: 196833267
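
With this change, `batch_shape` may contain a single `-1`, which is inferred
from the base distribution's batch size, mirroring `tf.reshape`. A minimal
sketch of the new behavior (TF 1.x contrib API; the printed shape is the
expected result, adapted from the updated docstring example below):

```python
import numpy as np
import tensorflow as tf

tfd = tf.contrib.distributions

# Base distribution with batch shape [6]; reshape its batch to [2, -1].
mvn = tfd.MultivariateNormalDiag(scale_diag=np.ones([6, 2], np.float32))
reshaped = tfd.BatchReshape(
    distribution=mvn, batch_shape=[2, -1], validate_args=True)
print(reshaped.batch_shape)  # Expected: (2, 3); the -1 is inferred as 3.
```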

tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py
tensorflow/contrib/distributions/python/ops/batch_reshape.py

diff --git a/tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py b/tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py
index 59d549b..f2bb2d3 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py
@@ -448,8 +448,7 @@ class _BatchReshapeTest(object):
 
     else:
       with self.test_session():
-        with self.assertRaisesOpError(r"`batch_shape` size must match "
-                                      r"`distributions.batch_shape` size"):
+        with self.assertRaisesOpError(r"Shape sizes do not match."):
           batch_reshape_lib.BatchReshape(
               distribution=mvn,
               batch_shape=new_batch_shape_ph,
@@ -457,8 +456,13 @@ class _BatchReshapeTest(object):
 
   def test_non_positive_shape(self):
     dims = 2
-    new_batch_shape = [-1, -2]   # -1*-2=2 so will pass size check.
-    old_batch_shape = [2]
+    old_batch_shape = [4]
+    if self.is_static_shape:
+      # Unknown first dimension does not trigger size check. Note that
+      # any dimension < 0 is treated statically as unknown.
+      new_batch_shape = [-1, 0]
+    else:
+      new_batch_shape = [-2, -2]  # -2 * -2 = 4, same size as the old shape.
 
     new_batch_shape_ph = (
         constant_op.constant(np.int32(new_batch_shape)) if self.is_static_shape
@@ -471,7 +475,7 @@ class _BatchReshapeTest(object):
     mvn = mvn_lib.MultivariateNormalDiag(scale_diag=scale_ph)
 
     if self.is_static_shape:
-      with self.assertRaisesRegexp(ValueError, r".*must be positive.*"):
+      with self.assertRaisesRegexp(ValueError, r".*must be >=-1.*"):
         batch_reshape_lib.BatchReshape(
             distribution=mvn,
             batch_shape=new_batch_shape_ph,
@@ -479,7 +483,7 @@ class _BatchReshapeTest(object):
 
     else:
       with self.test_session():
-        with self.assertRaisesOpError(r".*must be positive.*"):
+        with self.assertRaisesOpError(r".*must be >=-1.*"):
           batch_reshape_lib.BatchReshape(
               distribution=mvn,
               batch_shape=new_batch_shape_ph,
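
The static branch of `test_non_positive_shape` above relies on
`constant_value_as_shape` treating any negative entry as an unknown
dimension, so the static size check is skipped and only the positivity
check fires. A small sketch of that behavior (TF 1.x internal API; the
comments show the expected values, not output from this commit):

```python
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import tensor_util

shape = tensor_util.constant_value_as_shape(constant_op.constant([-1, 0]))
print(shape)                 # Expected: (?, 0) -- the -1 becomes unknown.
print(shape.num_elements())  # Expected: None, so the size check is skipped.
```
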
diff --git a/tensorflow/contrib/distributions/python/ops/batch_reshape.py b/tensorflow/contrib/distributions/python/ops/batch_reshape.py
index 8a4041c..c709318 100644
--- a/tensorflow/contrib/distributions/python/ops/batch_reshape.py
+++ b/tensorflow/contrib/distributions/python/ops/batch_reshape.py
@@ -42,9 +42,6 @@ class BatchReshape(distribution_lib.Distribution):
   This "meta-distribution" reshapes the batch dimensions of another
   distribution.
 
-  Note: Unlike `tf.reshape`, the `BatchReshape` distribution does not support
-  `-1` for flattening.
-
   #### Examples
 
   ```python
@@ -52,7 +49,7 @@ class BatchReshape(distribution_lib.Distribution):
 
   dtype = np.float32
   dims = 2
-  new_batch_shape = [1, 2, 3]
+  new_batch_shape = [1, 2, -1]
   old_batch_shape = [6]
 
   scale = np.ones(old_batch_shape + [dims], dtype)
@@ -86,8 +83,9 @@ class BatchReshape(distribution_lib.Distribution):
     Args:
       distribution: The base distribution instance to reshape. Typically an
         instance of `Distribution`.
-      batch_shape: Positive `int`-like vector-shaped `Tensor` representing the
-        new shape of the batch dimensions.
+      batch_shape: Positive `int`-like vector-shaped `Tensor` representing
+        the new shape of the batch dimensions. Up to one dimension may contain
+        `-1`, meaning the remainder of the batch size.
       validate_args: Python `bool`, default `False`. When `True` distribution
         parameters are checked for validity despite possibly degrading runtime
         performance. When `False` invalid inputs may silently render incorrect
@@ -107,29 +105,26 @@ class BatchReshape(distribution_lib.Distribution):
     """
     parameters = distribution_util.parent_frame_arguments()
     name = name or "BatchReshape" + distribution.name
-    self._distribution = distribution
     with ops.name_scope(name, values=[batch_shape]) as name:
-      self._batch_shape_ = ops.convert_to_tensor(
-          batch_shape,
-          dtype=dtypes.int32,
-          name="batch_shape")
-      self._batch_shape_static = tensor_util.constant_value(self._batch_shape_)
-      if self._batch_shape_static is not None:
-        self._batch_shape_static = np.int32(self._batch_shape_static)
-      self._runtime_assertions = validate_init_args(
-          self._distribution,
-          self._batch_shape_,
-          validate_args,
-          self._batch_shape_static)
+      # The unexpanded batch shape may contain up to one dimension of -1.
+      self._batch_shape_unexpanded = ops.convert_to_tensor(
+          batch_shape, dtype=dtypes.int32, name="batch_shape")
+      validate_init_args_statically(distribution, self._batch_shape_unexpanded)
+      batch_shape, batch_shape_static, runtime_assertions = calculate_reshape(
+          distribution.batch_shape_tensor(), self._batch_shape_unexpanded,
+          validate_args)
+      self._distribution = distribution
+      self._batch_shape_ = batch_shape
+      self._batch_shape_static = batch_shape_static
+      self._runtime_assertions = runtime_assertions
       super(BatchReshape, self).__init__(
-          dtype=self._distribution.dtype,
-          reparameterization_type=self._distribution.reparameterization_type,
+          dtype=distribution.dtype,
+          reparameterization_type=distribution.reparameterization_type,
           validate_args=validate_args,
           allow_nan_stats=allow_nan_stats,
           parameters=parameters,
           graph_parents=(
-              [self._batch_shape_] +
-              self._distribution._graph_parents),  # pylint: disable=protected-access
+              [self._batch_shape_unexpanded] + distribution._graph_parents),  # pylint: disable=protected-access
           name=name)
 
   @property
@@ -141,7 +136,7 @@ class BatchReshape(distribution_lib.Distribution):
       return array_ops.identity(self._batch_shape_)
 
   def _batch_shape(self):
-    return tensor_shape.TensorShape(self._batch_shape_static)
+    return self._batch_shape_static
 
   def _event_shape_tensor(self):
     with ops.control_dependencies(self._runtime_assertions):
@@ -153,11 +148,13 @@ class BatchReshape(distribution_lib.Distribution):
   def _sample_n(self, n, seed=None):
     with ops.control_dependencies(self._runtime_assertions):
       x = self.distribution.sample(sample_shape=n, seed=seed)
-      new_shape = array_ops.concat([
-          [n],
-          self.batch_shape_tensor(),
-          self.event_shape_tensor(),
-      ], axis=0)
+      new_shape = array_ops.concat(
+          [
+              [n],
+              self._batch_shape_unexpanded,
+              self.event_shape_tensor(),
+          ],
+          axis=0)
       return array_ops.reshape(x, new_shape)
 
   def _log_prob(self, x):
@@ -214,9 +211,9 @@ class BatchReshape(distribution_lib.Distribution):
     event_ndims = (array_ops.size(self.event_shape_tensor())
                    if self.event_shape.ndims is None
                    else self.event_shape.ndims)
-    batch_ndims = (array_ops.size(self.batch_shape_tensor())
-                   if self.batch_shape.ndims is None
-                   else self.batch_shape.ndims)
+    batch_ndims = (
+        array_ops.size(self._batch_shape_unexpanded)
+        if self.batch_shape.ndims is None else self.batch_shape.ndims)
     sample_ndims = x_ndims - batch_ndims - event_ndims
     if isinstance(sample_ndims, int):
       static_sample_shape = x.shape[:sample_ndims]
@@ -239,10 +236,11 @@ class BatchReshape(distribution_lib.Distribution):
           self.event_shape_tensor(),
       ], axis=0)
       result = fn(array_ops.reshape(x, old_shape))
-      new_shape = array_ops.concat([
-          sample_shape,
-          self.batch_shape_tensor(),
-      ], axis=0)
+      new_shape = array_ops.concat(
+          [
+              sample_shape,
+              self._batch_shape_unexpanded,
+          ], axis=0)
       result = array_ops.reshape(result, new_shape)
       if (static_sample_shape.ndims is not None and
           self.batch_shape.ndims is not None):
@@ -262,8 +260,7 @@ class BatchReshape(distribution_lib.Distribution):
       if static_event_shape_list is None:
         static_event_shape_list = [self.event_shape]
       new_shape = array_ops.concat(
-          [self.batch_shape_tensor()] + event_shape_list,
-          axis=0)
+          [self._batch_shape_unexpanded] + event_shape_list, axis=0)
       result = array_ops.reshape(fn(), new_shape)
       if (self.batch_shape.ndims is not None and
           self.event_shape.ndims is not None):
@@ -282,9 +279,9 @@ class BatchReshape(distribution_lib.Distribution):
       event_ndims = (array_ops.size(self.event_shape_tensor())
                      if self.event_shape.ndims is None
                      else self.event_shape.ndims)
-      batch_ndims = (array_ops.size(self.batch_shape_tensor())
-                     if self.batch_shape.ndims is None
-                     else self.batch_shape.ndims)
+      batch_ndims = (
+          array_ops.size(self._batch_shape_unexpanded)
+          if self.batch_shape.ndims is None else self.batch_shape.ndims)
       expected_batch_event_ndims = batch_ndims + event_ndims
 
       if (isinstance(x_ndims, int) and
@@ -356,62 +353,56 @@ class BatchReshape(distribution_lib.Distribution):
       return runtime_assertions
 
 
-def validate_init_args(
-    distribution,
-    batch_shape,
-    validate_args,
-    batch_shape_static):
+def calculate_reshape(original_shape, new_shape, validate=False, name=None):
+  """Calculates the reshaped dimensions (replacing up to one -1 in reshape)."""
+  batch_shape_static = tensor_util.constant_value_as_shape(new_shape)
+  if batch_shape_static.is_fully_defined():
+    return np.int32(batch_shape_static.as_list()), batch_shape_static, []
+  with ops.name_scope(name, "calculate_reshape", [original_shape, new_shape]):
+    original_size = math_ops.reduce_prod(original_shape)
+    implicit_dim = math_ops.equal(new_shape, -1)
+    size_implicit_dim = (
+        original_size // math_ops.maximum(1, -math_ops.reduce_prod(new_shape)))
+    new_ndims = array_ops.shape(new_shape)
+    expanded_new_shape = array_ops.where(  # Assumes exactly one `-1`.
+        implicit_dim, array_ops.fill(new_ndims, size_implicit_dim), new_shape)
+    validations = [] if not validate else [
+        check_ops.assert_rank(
+            original_shape, 1, message="Original shape must be a vector."),
+        check_ops.assert_rank(
+            new_shape, 1, message="New shape must be a vector."),
+        check_ops.assert_less_equal(
+            math_ops.count_nonzero(implicit_dim, dtype=dtypes.int32),
+            1,
+            message="At most one dimension can be unknown."),
+        check_ops.assert_positive(
+            expanded_new_shape, message="Shape elements must be >=-1."),
+        check_ops.assert_equal(
+            math_ops.reduce_prod(expanded_new_shape),
+            original_size,
+            message="Shape sizes do not match."),
+    ]
+    return expanded_new_shape, batch_shape_static, validations
+
+
+def validate_init_args_statically(distribution, batch_shape):
   """Helper to __init__ which makes or raises assertions."""
-  with ops.name_scope(name="validate_init_args",
-                      values=[batch_shape] + distribution._graph_parents):  # pylint: disable=protected-access
-    runtime_assertions = []
-
-    if batch_shape.shape.ndims is not None:
-      if batch_shape.shape.ndims != 1:
-        raise ValueError("`batch_shape` must be a vector "
-                         "(saw rank: {}).".format(
-                             batch_shape.shape.ndims))
-    elif validate_args:
-      runtime_assertions += [
-          check_ops.assert_rank(
-              batch_shape,
-              1,
-              message="`batch_shape` must be a vector.",
-              name="assert_batch_shape_is_vector"),
-      ]
-
-    batch_size_static = np.prod(batch_shape_static)
-    dist_batch_size_static = (
-        None if not distribution.batch_shape.is_fully_defined()
-        else np.prod(distribution.batch_shape).value)
-
-    if batch_size_static is not None and dist_batch_size_static is not None:
-      if batch_size_static != dist_batch_size_static:
-        raise ValueError("`batch_shape` size ({}) must match "
-                         "`distribution.batch_shape` size ({}).".format(
-                             batch_size_static,
-                             dist_batch_size_static))
-    elif validate_args:
-      runtime_assertions += [
-          check_ops.assert_equal(
-              math_ops.reduce_prod(batch_shape),
-              math_ops.reduce_prod(distribution.batch_shape_tensor()),
-              message=("`batch_shape` size must match "
-                       "`distributions.batch_shape` size."),
-              name="assert_batch_size"),
-      ]
-
-    if batch_shape_static is not None:
-      if np.any(batch_shape_static < 1):
-        raise ValueError("`batch_shape` elements must be positive "
-                         "(i.e., larger than zero).")
-    elif validate_args:
-      runtime_assertions += [
-          check_ops.assert_positive(
-              batch_shape,
-              message=("`batch_shape` elements must be positive "
-                       "(i.e., larger than zero)."),
-              name="assert_batch_shape_positive")
-      ]
-
-    return runtime_assertions
+  if batch_shape.shape.ndims is not None:
+    if batch_shape.shape.ndims != 1:
+      raise ValueError("`batch_shape` must be a vector "
+                       "(saw rank: {}).".format(batch_shape.shape.ndims))
+
+  batch_shape_static = tensor_util.constant_value_as_shape(batch_shape)
+  batch_size_static = batch_shape_static.num_elements()
+  dist_batch_size_static = distribution.batch_shape.num_elements()
+
+  if batch_size_static is not None and dist_batch_size_static is not None:
+    if batch_size_static != dist_batch_size_static:
+      raise ValueError("`batch_shape` size ({}) must match "
+                       "`distribution.batch_shape` size ({}).".format(
+                           batch_size_static, dist_batch_size_static))
+
+  if batch_shape_static.dims is not None:
+    if any(
+        dim.value is not None and dim.value < 1 for dim in batch_shape_static):
+      raise ValueError("`batch_shape` elements must be >=-1.")
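
For reference, the `-1` expansion in `calculate_reshape` reduces to simple
integer arithmetic: with exactly one `-1` present, negating the product of
`new_shape` gives the product of the known dimensions. A pure-NumPy sketch
of that step (illustrative only, not part of the commit):

```python
import numpy as np

def expand_new_shape(original_shape, new_shape):
  """NumPy sketch of the -1 expansion performed by `calculate_reshape`."""
  original_size = np.prod(original_shape)
  # With one -1 present, np.prod(new_shape) is negative; negating it yields
  # the product of the known dims (max(1, ...) guards the no--1 case).
  size_implicit_dim = original_size // max(1, -np.prod(new_shape))
  return [int(size_implicit_dim) if d == -1 else d for d in new_shape]

print(expand_new_shape([6], [1, 2, -1]))  # [1, 2, 3]
```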