assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op)
- assert len(input_tensors) == 2, "input tensors length should be 2"
+ assert input_tensors, "input tensors should not be empty"
input_tensor = input_tensors[0]
input_tensor_idx = input_tensor.tensor_idx
except ImportError:
from tensorflow.contrib import lite as interpreter_wrapper
+from tvm.contrib.download import download_testdata
import tvm.relay.testing.tf as tf_testing
from packaging import version as package_version
rtol=1e-5, atol=1e-5)
#######################################################################
+# MediaPipe
+# -------------
+
+def test_forward_mediapipe_hand_landmark():
+ """Test MediaPipe 2D hand landmark TF Lite model."""
+ # MediaPipe 2D hand landmark TF
+ tflite_model_file = download_testdata(
+ "https://github.com/google/mediapipe/raw/master/mediapipe/models/hand_landmark.tflite",
+ "hand_landmark.tflite")
+ with open(tflite_model_file, "rb") as f:
+ tflite_model_buf = f.read()
+ data = np.random.uniform(size=(1, 256, 256, 3)).astype('float32')
+ tflite_output = run_tflite_graph(tflite_model_buf, data)
+ tvm_output = run_tvm_graph(tflite_model_buf, data, 'input_1', num_output=2)
+ for i in range(2):
+ tvm.testing.assert_allclose(np.squeeze(tvm_output[i]), np.squeeze(tflite_output[i]),
+ rtol=1e-5, atol=1e-5)
+
+#######################################################################
# Main
# ----
if __name__ == '__main__':
test_forward_inception_v3_net()
test_forward_inception_v4_net()
test_forward_ssd_mobilenet_v1()
+ test_forward_mediapipe_hand_landmark()
# End to End quantized
test_forward_qnn_inception_v1_net()