From 3a133550aa2cd52e911cb0d7b033e90eef6daa1f Mon Sep 17 00:00:00 2001 From: Alexander Pivovarov Date: Wed, 20 Nov 2019 09:36:57 -0800 Subject: [PATCH] Compare all outputs in TFLite test_forward_ssd_mobilenet_v1 (#4373) --- tests/python/frontend/tflite/test_forward.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 9d05835..8292dd5 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -1134,9 +1134,10 @@ def test_forward_ssd_mobilenet_v1(): tflite_model_buf = f.read() data = np.random.uniform(size=(1, 300, 300, 3)).astype('float32') tflite_output = run_tflite_graph(tflite_model_buf, data) - tvm_output = run_tvm_graph(tflite_model_buf, data, 'normalized_input_image_tensor') - tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), - rtol=1e-5, atol=1e-5) + tvm_output = run_tvm_graph(tflite_model_buf, data, 'normalized_input_image_tensor', num_output=2) + for i in range(2): + tvm.testing.assert_allclose(np.squeeze(tvm_output[i]), np.squeeze(tflite_output[i]), + rtol=1e-5, atol=2e-5) ####################################################################### # MediaPipe -- 2.7.4