From 7134e84a3dcf2e18e98e4ccc1498e4b4f41de014 Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Fri, 23 Feb 2018 14:38:37 -0800 Subject: [PATCH] Make tf.size() with optimize=True encode 0 if any dimension is 0. PiperOrigin-RevId: 186824964 --- tensorflow/python/ops/array_ops.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 08db8a1..b3020ef 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -401,8 +401,11 @@ def size_internal(input, name=None, optimize=True, out_type=dtypes.int32): else: input_tensor = ops.convert_to_tensor(input) input_shape = input_tensor.get_shape() - if optimize and input_shape.is_fully_defined(): - return constant(input_shape.num_elements(), out_type, name=name) + if optimize: + if input_shape.is_fully_defined(): + return constant(input_shape.num_elements(), out_type, name=name) + if input_shape.dims and any(dim == 0 for dim in input_shape.dims): + return constant(0, out_type, name=name) return gen_array_ops.size(input, name=name, out_type=out_type) -- 2.7.4