Fixing package path in tflite test (#3427)
authorSammy <samshang.wang@mail.utoronto.ca>
Tue, 25 Jun 2019 03:55:55 +0000 (23:55 -0400)
committerTianqi Chen <tqchen@users.noreply.github.com>
Tue, 25 Jun 2019 03:55:55 +0000 (20:55 -0700)
tests/python/frontend/tflite/test_forward.py

index 41147fe..c9fd0dc 100644 (file)
@@ -32,7 +32,10 @@ from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn_ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import variables
-from tensorflow import lite as interpreter_wrapper
+try:
+    from tensorflow import lite as interpreter_wrapper
+except ImportError:
+    from tensorflow.contrib import lite as interpreter_wrapper
 
 import tvm.relay.testing.tf as tf_testing
 
@@ -131,7 +134,7 @@ def compare_tflite_with_tvm(in_data, in_name, input_tensors,
         if init_global_variables:
             sess.run(variables.global_variables_initializer())
         # convert to tflite model
-        converter = tf.contrib.lite.TFLiteConverter.from_session(
+        converter = interpreter_wrapper.TFLiteConverter.from_session(
             sess, input_tensors, output_tensors)
         tflite_model_buffer = converter.convert()
         tflite_output = run_tflite_graph(tflite_model_buffer, in_data)