Imported Upstream version 1.8.0
[platform/core/ml/nnfw.git] / compiler / luci-value-test / luci_eval_verifier.py
1 #!/usr/bin/env python3
2 import numpy as np
3 import tensorflow as tf
4 import subprocess
5 import argparse
6 import traceback
7
#
# This script compares the execution result of luci-interpreter with that of the TFLite interpreter.
#
# Basic usage:
#   luci_eval_verifier.py --driver build/compiler/luci-value-test/tester/luci_eval_tester
#           --model inception_v3
# Command-line interface; both flags are mandatory (argparse exits otherwise).
#   --driver  executable that runs the Circle model with luci-interpreter
#   --model   model path prefix; ".tflite" and ".circle" are appended to it
parser = argparse.ArgumentParser()
parser.add_argument('--driver', required=True, type=str)
parser.add_argument('--model', required=True, type=str)
args = parser.parse_args()

driver = args.driver
# The TFLite reference model and the Circle model under test share one prefix.
tflite_model = "{}.tflite".format(args.model)
circle_model = "{}.circle".format(args.model)
22
# Build the TFLite reference interpreter.
interpreter = tf.lite.Interpreter(tflite_model)
interpreter.allocate_tensors()

# Generate random input data for every model input. Each tensor is set on the
# TFLite interpreter and also dumped as raw bytes to <circle_model>.input<i>,
# so the luci tester can be fed exactly the same inputs.
# The details list is fetched once instead of once per input.
input_details_list = interpreter.get_input_details()
num_inputs = len(input_details_list)
for i, input_details in enumerate(input_details_list):
    input_dtype = input_details["dtype"]
    input_shape = input_details["shape"]
    if input_dtype == np.float32:
        # Uniform samples in [0, 1).
        input_data = np.array(np.random.random_sample(input_shape), input_dtype)
    elif input_dtype == np.uint8:
        # Uniform over the full uint8 range [0, 255].
        input_data = np.array(
            np.random.randint(0, 256, size=input_shape), input_dtype)
    elif input_dtype == np.bool_:
        input_data = np.array(
            np.random.choice(a=[True, False], size=input_shape), input_dtype)
    else:
        # Name the offending dtype so the failure is diagnosable from the log.
        raise SystemExit("Unsupported input dtype: {}".format(input_dtype))

    interpreter.set_tensor(input_details["index"], input_data)
    input_data.tofile(circle_model + ".input" + str(i))
47
# Run the TFLite reference model on the inputs set above.
interpreter.invoke()

# Run the luci tester on the Circle model. Arguments, in order:
#   <circle model> <number of inputs> <input file prefix> <output file prefix>
# Inputs were dumped as <circle_model>.input<i>; outputs are read back from
# <circle_model>.output<i> afterwards.
luci_command = [
    driver,
    circle_model,
    str(num_inputs),
    circle_model + ".input",
    circle_model + ".output",
]
# check=True: a non-zero exit from the tester raises CalledProcessError.
subprocess.run(luci_command, check=True)
58
# Compare the results.
#
# For output <idx> the luci tester is expected to have written the raw tensor
# data to <circle_model>.output<idx> and its comma-separated dimensions to
# <circle_model>.output<idx>.shape. Each output is checked against the TFLite
# output: float32 within rtol = atol = 1e-5, uint8/int32/int64 exactly.
# A mismatch or an unsupported dtype prints a traceback and exits with status
# 255; status 0 means every output matched.


def _read_luci_output(idx, dtype):
    """Load luci output ``idx`` as ``dtype`` and reshape it to its recorded shape."""
    output_path = circle_model + ".output" + str(idx)
    output_data = np.fromfile(output_path, dtype)
    with open(output_path + ".shape", 'r') as shape_file:
        # Empty fields are skipped, so an empty shape file gives shape ()
        # (presumably a rank-0 output -- TODO confirm against the tester).
        output_shape = [
            int(dim) for dim in shape_file.read().split(',') if dim.strip()
        ]
    return np.reshape(output_data, output_shape)


def _outputs_match(luci_output, tflite_output, dtype):
    """Return True if both outputs have the same shape and matching values.

    Raises:
        TypeError: ``dtype`` is not one the verifier knows how to compare.
    """
    if dtype == np.float32:
        exact = False
    elif dtype in (np.uint8, np.int32, np.int64):
        exact = True
    else:
        raise TypeError("Unsupported data type: {}".format(dtype))

    # np.allclose broadcasts (e.g. shape (1,) vs (3,) can "match"), so the
    # shapes must be compared explicitly.
    if luci_output.shape != tflite_output.shape:
        return False
    if exact:
        # np.allclose converts to float64, which cannot represent every int64
        # (|x| > 2**53); compare integers bit-exactly instead.
        return bool(np.array_equal(luci_output, tflite_output))
    return bool(
        np.allclose(luci_output, tflite_output, rtol=1.e-5, atol=1.e-5))


for idx, output_details in enumerate(interpreter.get_output_details()):
    luci_output_data = _read_luci_output(idx, output_details["dtype"])
    try:
        tflite_output_data = interpreter.get_tensor(output_details["index"])
        if not _outputs_match(luci_output_data, tflite_output_data,
                              output_details["dtype"]):
            raise RuntimeError("Execution result of " + tflite_model +
                               " does not match with " + circle_model)
    except Exception:
        print(traceback.format_exc())
        # SystemExit is what quit() raises, without depending on the site module.
        raise SystemExit(255) from None

raise SystemExit(0)