From: Nupur Garg Date: Mon, 4 Jun 2018 19:08:15 +0000 (-0700) Subject: Fix Python API. X-Git-Tag: upstream/v1.9.0_rc1~22 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=0bb7c844dd4375d7f53c88a7eacf78b0d6552498;p=platform%2Fupstream%2Ftensorflow.git Fix Python API. PiperOrigin-RevId: 199171845 --- diff --git a/tensorflow/contrib/lite/python/convert_saved_model.py b/tensorflow/contrib/lite/python/convert_saved_model.py index b952a72..5dad49f 100644 --- a/tensorflow/contrib/lite/python/convert_saved_model.py +++ b/tensorflow/contrib/lite/python/convert_saved_model.py @@ -216,9 +216,9 @@ def set_tensor_shapes(tensors, shapes): """ if shapes: for tensor in tensors: - shape = shapes.get(tensor.name) + shape = shapes.get(tensor_name(tensor)) if shape is not None: - tensor.set_shape(shapes[tensor.name]) + tensor.set_shape(shape) def freeze_saved_model(saved_model_dir, input_arrays, input_shapes, diff --git a/tensorflow/contrib/lite/python/convert_saved_model_test.py b/tensorflow/contrib/lite/python/convert_saved_model_test.py index 80e5dc6..1e570d2 100644 --- a/tensorflow/contrib/lite/python/convert_saved_model_test.py +++ b/tensorflow/contrib/lite/python/convert_saved_model_test.py @@ -73,10 +73,15 @@ class TensorFunctionsTest(test_util.TensorFlowTestCase): tensor = array_ops.placeholder(shape=[None, 3, 5], dtype=dtypes.float32) self.assertEqual([None, 3, 5], tensor.shape.as_list()) - convert_saved_model.set_tensor_shapes([tensor], - {"Placeholder:0": [5, 3, 5]}) + convert_saved_model.set_tensor_shapes([tensor], {"Placeholder": [5, 3, 5]}) self.assertEqual([5, 3, 5], tensor.shape.as_list()) + def testSetTensorShapeNoneValid(self): + tensor = array_ops.placeholder(dtype=dtypes.float32) + + convert_saved_model.set_tensor_shapes([tensor], {"Placeholder": [1, 3, 5]}) + self.assertEqual([1, 3, 5], tensor.shape.as_list()) + def testSetTensorShapeInvalid(self): tensor = array_ops.placeholder(shape=[None, 3, 5], dtype=dtypes.float32) self.assertEqual([None, 3, 5], tensor.shape.as_list())