Fix Python API.
authorNupur Garg <nupurgarg@google.com>
Mon, 4 Jun 2018 19:08:15 +0000 (12:08 -0700)
committerGunhan Gulsoy <gunan@google.com>
Tue, 5 Jun 2018 03:39:49 +0000 (20:39 -0700)
PiperOrigin-RevId: 199171845

tensorflow/contrib/lite/python/convert_saved_model.py
tensorflow/contrib/lite/python/convert_saved_model_test.py

index b952a72..5dad49f 100644 (file)
@@ -216,9 +216,9 @@ def set_tensor_shapes(tensors, shapes):
   """
   if shapes:
     for tensor in tensors:
-      shape = shapes.get(tensor.name)
+      shape = shapes.get(tensor_name(tensor))
       if shape is not None:
-        tensor.set_shape(shapes[tensor.name])
+        tensor.set_shape(shape)
 
 
 def freeze_saved_model(saved_model_dir, input_arrays, input_shapes,
index 80e5dc6..1e570d2 100644 (file)
@@ -73,10 +73,15 @@ class TensorFunctionsTest(test_util.TensorFlowTestCase):
     tensor = array_ops.placeholder(shape=[None, 3, 5], dtype=dtypes.float32)
     self.assertEqual([None, 3, 5], tensor.shape.as_list())
 
-    convert_saved_model.set_tensor_shapes([tensor],
-                                          {"Placeholder:0": [5, 3, 5]})
+    convert_saved_model.set_tensor_shapes([tensor], {"Placeholder": [5, 3, 5]})
     self.assertEqual([5, 3, 5], tensor.shape.as_list())
 
+  def testSetTensorShapeNoneValid(self):
+    tensor = array_ops.placeholder(dtype=dtypes.float32)
+
+    convert_saved_model.set_tensor_shapes([tensor], {"Placeholder": [1, 3, 5]})
+    self.assertEqual([1, 3, 5], tensor.shape.as_list())
+
   def testSetTensorShapeInvalid(self):
     tensor = array_ops.placeholder(shape=[None, 3, 5], dtype=dtypes.float32)
     self.assertEqual([None, 3, 5], tensor.shape.as_list())