# NOTE(review): this file reads as a line-numbered excerpt of a Python script.
# Every line is prefixed with what looks like its original line number,
# indentation has been stripped, and several lines are absent. For example,
# there is no visible import for `argparse` or `np`, yet both are used below.
# TODO confirm against the full original before relying on this text.
3 import tensorflow as tf
9 # This script compares the execution result of luci-interpreter with that of TFLite interpreter
12 # eval_verifier.py --driver build/compiler/luci-value-test/tester/luci_eval_tester
13 # --model inception_v3
# Both options are mandatory.
# --driver presumably names the executable run in the "Execute luci
# interpreter" step; the usage example above points at luci_eval_tester, and
# args.driver is not referenced on any visible line.
# --model is a path prefix, not a complete file name.
14 parser = argparse.ArgumentParser()
15 parser.add_argument('--driver', type=str, required=True)
16 parser.add_argument('--model', type=str, required=True)
17 args = parser.parse_args()
# The TFLite and Circle models share the --model prefix and differ only in
# their extension.
20 tflite_model = args.model + ".tflite"
21 circle_model = args.model + ".circle"
23 # Build TFLite interpreter.
24 interpreter = tf.lite.Interpreter(tflite_model)
25 interpreter.allocate_tensors()
27 # Generate random input data.
# For every model input:
# - draw random data matching the tensor's shape and dtype;
# - feed it to the TFLite interpreter;
# - dump the same array to "<model>.circle.input<i>", presumably so the luci
#   driver can read identical inputs.
# NOTE(review): no random seed is set on any visible line, so each run uses
# different inputs. Consider seeding if failures must be reproducible.
# NOTE(review): get_input_details() is re-queried on every iteration; it could
# be fetched once before the loop.
28 num_inputs = len(interpreter.get_input_details())
29 for i in range(num_inputs):
30 input_details = interpreter.get_input_details()[i]
# float32: uniform samples in [0.0, 1.0).
31 if input_details["dtype"] == np.float32:
32 input_data = np.array(
33 np.random.random_sample(input_details["shape"]), input_details["dtype"])
# uint8: integers 0..255 (randint's upper bound is exclusive).
34 elif input_details["dtype"] == np.uint8:
35 input_data = np.array(
36 np.random.randint(0, 256, size=input_details["shape"]),
37 input_details["dtype"])
# bool: uniform choice between True and False.
38 elif input_details["dtype"] == np.bool_:
39 input_data = np.array(
40 np.random.choice(a=[True, False], size=input_details["shape"]),
41 input_details["dtype"])
# Any other input dtype aborts the script. The `else:` guarding this line is
# not visible here (presumably on omitted original line 42).
43 raise SystemExit("Unsupported input dtype")
# tofile() writes the raw array bytes with no shape/dtype header, so the
# reader must already know both.
45 interpreter.set_tensor(input_details["index"], input_data)
46 input_data.tofile(circle_model + ".input" + str(i))
51 # Execute luci interpreter.
# NOTE(review): only one argument line of this step is visible. It passes the
# input count plus the input/output file prefixes. The driver presumably
# appends an index to each prefix, matching the ".input<i>" files written above
# and the ".output<idx>" files read below (confirm).
# The call itself is not among the visible lines. Neither is the TFLite
# `interpreter.invoke()` that the output comparison below needs (presumably on
# omitted lines 47-50).
55 str(num_inputs), circle_model + ".input", circle_model + ".output"
59 # Compare the results.
# For each TFLite output: load the corresponding luci output file, reshape it,
# and compare it with the TFLite tensor. A mismatch exits with an error
# message.
60 for idx in range(len(interpreter.get_output_details())):
61 output_details = interpreter.get_output_details()[idx]
# The ".output<idx>" file is decoded as raw data using the TFLite output's
# dtype. Its shape comes from a sibling ".shape" file holding comma-separated
# integers.
62 output_data = np.fromfile(circle_model + ".output" + str(idx),
63 output_details["dtype"])
# NOTE(review): shape_file is never closed on any visible line. A
# `with open(...) as shape_file:` block would release it deterministically.
64 shape_file = open(circle_model + ".output" + str(idx) + ".shape", 'r')
65 output_shape = [int(i) for i in shape_file.read().split(',')]
66 luci_output_data = np.reshape(output_data, output_shape)
# One branch per supported output dtype. The comparison call and its
# tolerances are on omitted lines; only the TFLite tensor argument is visible.
# NOTE(review): interpreter.get_output_details()[idx] is the same value already
# held in output_details. The visible parts of the four branches differ only in
# dtype, so a dtype-keyed table (e.g. of tolerances) could remove the
# duplication.
68 if output_details["dtype"] == np.uint8:
71 interpreter.get_tensor(
72 interpreter.get_output_details()[idx]["index"]),
75 raise SystemExit("Execution result of " + tflite_model +
76 " does not match with " + circle_model)
77 elif output_details["dtype"] == np.float32:
80 interpreter.get_tensor(
81 interpreter.get_output_details()[idx]["index"]),
84 raise SystemExit("Execution result of " + tflite_model +
85 " does not match with " + circle_model)
86 elif output_details["dtype"] == np.int64:
89 interpreter.get_tensor(
90 interpreter.get_output_details()[idx]["index"]),
93 raise SystemExit("Execution result of " + tflite_model +
94 " does not match with " + circle_model)
95 elif output_details["dtype"] == np.int32:
98 interpreter.get_tensor(
99 interpreter.get_output_details()[idx]["index"]),
102 raise SystemExit("Execution result of " + tflite_model +
103 " does not match with " + circle_model)
# NOTE(review): SystemExit receives two arguments here, so its payload is a
# tuple. If it is uncaught, Python prints the tuple representation of both
# values rather than a single message; string formatting was probably intended.
105 raise SystemExit("Unsupported data type: ", output_details["dtype"])
# Presumably this runs inside an exception handler on an omitted line (106).
# If that handler is a bare `except:`, it also catches the SystemExit raised
# above. Confirm how the exit status is set afterwards.
107 print(traceback.format_exc())