Always cast `tf.distributions.Distribution` `_event_shape`, `_batch_shape`.
authorJoshua V. Dillon <jvdillon@google.com>
Mon, 26 Mar 2018 04:57:09 +0000 (21:57 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 26 Mar 2018 04:59:14 +0000 (21:59 -0700)
PiperOrigin-RevId: 190415923

tensorflow/python/ops/distributions/distribution.py

index c055ca4..0866fa8 100644 (file)
@@ -593,7 +593,7 @@ class Distribution(_BaseDistribution):
     Returns:
       batch_shape: `TensorShape`, possibly unknown.
     """
-    return self._batch_shape()
+    return tensor_shape.as_shape(self._batch_shape())
 
   def _event_shape_tensor(self):
     raise NotImplementedError("event_shape_tensor is not implemented")
@@ -626,7 +626,7 @@ class Distribution(_BaseDistribution):
     Returns:
       event_shape: `TensorShape`, possibly unknown.
     """
-    return self._event_shape()
+    return tensor_shape.as_shape(self._event_shape())
 
   def is_scalar_event(self, name="is_scalar_event"):
     """Indicates that `event_shape == []`.