From 8e9681486efc504b940683a4d0306c273e6179db Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Thu, 17 May 2018 09:22:24 -0700 Subject: [PATCH] Update SessionTest.testFeedShapeCompatibility to work with C API enabled. This test got lost in the transition. Prior to enabling the C API, some constant node whose values were used for shape inference would be marked as unfeedable in tensor_util.constant_value (https://github.com/tensorflow/tensorflow/blob/r1.8/tensorflow/python/framework/tensor_util.py#L810). This shape inference path is no longer used with the C API enabled, so the constant node is successfully fed, triggering a runtime shape error. This is arguably a regression, but given that the Python code wouldn't mark all nodes evaluated during shape inference as unfeedable, it seems ok to relax the check a little more. PiperOrigin-RevId: 197002741 --- tensorflow/python/client/session_test.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py index e9a7d9a..4824970 100644 --- a/tensorflow/python/client/session_test.py +++ b/tensorflow/python/client/session_test.py @@ -1565,10 +1565,6 @@ class SessionTest(test_util.TensorFlowTestCase): self.assertEquals(len(run_metadata.step_stats.dev_stats), 1) def testFeedShapeCompatibility(self): - # TODO(nolivia): C API doesn't yet handle marking nodes as not feedable. - if ops._USE_C_API: - return - with session.Session() as sess: some_tensor = constant_op.constant([2.0, 2.0, 2.0, 2.0]) new_shape = constant_op.constant([2, 2]) @@ -1577,7 +1573,10 @@ class SessionTest(test_util.TensorFlowTestCase): with self.assertRaisesRegexp(ValueError, 'Cannot feed value of shape'): sess.run(reshaped_tensor, feed_dict={some_tensor: [1.0, 2.0, 3.0]}) - with self.assertRaisesRegexp(ValueError, 'may not be fed'): + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + 'Input to reshape is a tensor with 4 values, ' + 'but the requested shape has 21'): sess.run(reshaped_tensor, feed_dict={new_shape: [3, 7]}) def testInferShapesFalse(self): -- 2.7.4