Correct handling of dtype for Categorical sampling.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 23 Jan 2018 17:18:11 +0000 (09:18 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 23 Jan 2018 17:22:21 +0000 (09:22 -0800)
PiperOrigin-RevId: 182943806

tensorflow/python/kernel_tests/distributions/categorical_test.py
tensorflow/python/ops/distributions/categorical.py

index 019c1bc353a9891da6967a7ce9114b58226a980a..ca2358fe99934e110ba743c6085d1f25ff0f5e5e 100644 (file)
@@ -100,6 +100,10 @@ class CategoricalTest(test.TestCase):
     self.assertEqual(
         dist.logits.dtype, dist.log_prob(np.array(
             0, dtype=np.int64)).dtype)
+    for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
+      dist = make_categorical([], 5, dtype=dtype)
+      self.assertEqual(dist.dtype, dtype)
+      self.assertEqual(dist.dtype, dist.sample(5).dtype)
 
   def testUnknownShape(self):
     with self.test_session():
index 60a515583e2a6f750975ed9356d6c885eb83632f..9161e3fa9f5f7f844e7f4926992c954acae246d6 100644 (file)
@@ -265,7 +265,9 @@ class Categorical(distribution.Distribution):
       logits_2d = self.logits
     else:
       logits_2d = array_ops.reshape(self.logits, [-1, self.event_size])
-    draws = random_ops.multinomial(logits_2d, n, seed=seed)
+    sample_dtype = dtypes.int64 if self.dtype.size > 4 else dtypes.int32
+    draws = random_ops.multinomial(
+        logits_2d, n, seed=seed, output_dtype=sample_dtype)
     draws = array_ops.reshape(
         array_ops.transpose(draws),
         array_ops.concat([[n], self.batch_shape_tensor()], 0))