Fix #18180
authorAsim Shankar <ashankar@google.com>
Mon, 2 Apr 2018 19:01:32 +0000 (12:01 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 2 Apr 2018 19:03:55 +0000 (12:03 -0700)
tf.size() was not respecting the `out_type` argument when eager execution was
enabled.

PiperOrigin-RevId: 191326039

tensorflow/python/kernel_tests/array_ops_test.py
tensorflow/python/ops/array_ops.py

index 78bdb7e..5a20eeb 100644 (file)
@@ -1007,30 +1007,38 @@ class SliceAssignTest(test_util.TensorFlowTestCase):
 
 class ShapeSizeRankTest(test_util.TensorFlowTestCase):
 
+  @test_util.run_in_graph_and_eager_modes()
   def testDenseShape(self):
-    with self.test_session():
-      t_value = [[0, 42], [24, 0]]
-      self.assertAllEqual((2, 2), array_ops.shape(t_value).eval())
-      self.assertEqual(4, array_ops.size(t_value).eval())
-      self.assertEqual(2, array_ops.rank(t_value).eval())
+    t_value = [[0, 42], [24, 0]]
+    self.assertAllEqual((2, 2), self.evaluate(array_ops.shape(t_value)))
+    self.assertEqual(4, self.evaluate(array_ops.size(t_value)))
+    self.assertEqual(2, self.evaluate(array_ops.rank(t_value)))
 
-      t = constant_op.constant(t_value)
-      self.assertAllEqual((2, 2), array_ops.shape(t).eval())
-      self.assertEqual(4, array_ops.size(t).eval())
-      self.assertEqual(2, array_ops.rank(t).eval())
+    t = constant_op.constant(t_value)
+    self.assertAllEqual((2, 2), self.evaluate(array_ops.shape(t)))
+    self.assertEqual(4, self.evaluate(array_ops.size(t)))
+    self.assertEqual(2, self.evaluate(array_ops.rank(t)))
 
+  @test_util.run_in_graph_and_eager_modes()
   def testSparseShape(self):
-    with self.test_session():
-      sp_value = sparse_tensor.SparseTensorValue(
-          indices=((0, 1), (1, 0)), values=(42, 24), dense_shape=(2, 2))
-      self.assertAllEqual((2, 2), array_ops.shape(sp_value).eval())
-      self.assertEqual(4, array_ops.size(sp_value).eval())
-      self.assertEqual(2, array_ops.rank(sp_value).eval())
-
-      sp = sparse_tensor.SparseTensor.from_value(sp_value)
-      self.assertAllEqual((2, 2), array_ops.shape(sp).eval())
-      self.assertEqual(4, array_ops.size(sp).eval())
-      self.assertEqual(2, array_ops.rank(sp).eval())
+    sp_value = sparse_tensor.SparseTensorValue(
+        indices=((0, 1), (1, 0)), values=(42, 24), dense_shape=(2, 2))
+    self.assertAllEqual((2, 2), self.evaluate(array_ops.shape(sp_value)))
+    self.assertEqual(4, self.evaluate(array_ops.size(sp_value)))
+    self.assertEqual(2, self.evaluate(array_ops.rank(sp_value)))
+
+    sp = sparse_tensor.SparseTensor.from_value(sp_value)
+    self.assertAllEqual((2, 2), self.evaluate(array_ops.shape(sp)))
+    self.assertEqual(4, self.evaluate(array_ops.size(sp)))
+    self.assertEqual(2, self.evaluate(array_ops.rank(sp)))
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testSizeDtype(self):
+    tensor = [1]
+    self.assertEqual(dtypes.int32, self.evaluate(array_ops.size(tensor)).dtype)
+    self.assertEqual(
+        dtypes.int64,
+        self.evaluate(array_ops.size(tensor, out_type=dtypes.int64)).dtype)
 
 
 @test_util.with_c_api
index 2078666..68d4466 100644 (file)
@@ -387,7 +387,10 @@ def size_internal(input, name=None, optimize=True, out_type=dtypes.int32):
   """
   if context.executing_eagerly() and not isinstance(
       input, (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)):
-    return np.prod(ops.convert_to_tensor(input)._shape_tuple())  # pylint: disable=protected-access
+    input = ops.convert_to_tensor(input)
+    np_out_type = out_type.as_numpy_dtype
+    num_elements = np.prod(input._shape_tuple(), dtype=np_out_type)  # pylint: disable=protected-acces:
+    return ops.convert_to_tensor(num_elements, dtype=out_type)
   with ops.name_scope(name, "Size", [input]) as name:
     if isinstance(input, (sparse_tensor.SparseTensor,
                           sparse_tensor.SparseTensorValue)):