Imported Upstream version 1.9.0
[platform/core/ml/nnfw.git] / compiler / one-cmds / one-import-tf
1 #!/bin/bash
2
3 # Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 #
5 # Licensed under the Apache License, Version 2.0 (the "License");
6 # you may not use this file except in compliance with the License.
7 # You may obtain a copy of the License at
8 #
9 #    http://www.apache.org/licenses/LICENSE-2.0
10 #
11 # Unless required by applicable law or agreed to in writing, software
12 # distributed under the License is distributed on an "AS IS" BASIS,
13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 # See the License for the specific language governing permissions and
15 # limitations under the License.
16
17 set -e
18
19 DRIVER_PATH="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
20
21 usage()
22 {
23   echo "Convert TensorFlow model to circle."
24   echo "Usage: one-import-tf"
25   echo "    --version Show version information and exit"
26   echo "    --input_path <path/to/tfmodel>"
27   echo "    --output_path <path/to/circle>"
28   echo "    --input_arrays <names of the input arrays, comma-separated>"
29   echo "    --input_shapes <input shapes, colon-separated>"
30   echo "    --output_arrays <names of the output arrays, comma-separated>"
31   echo "    --v2 Use TensorFlow 2.x interface (default is 1.x interface)"
32   exit 255
33 }
34
35 version()
36 {
37   $DRIVER_PATH/one-version one-import-tf
38   exit 255
39 }
40
41 TF_INTERFACE="--v1"
42
43 # Parse command-line arguments
44 #
45 while [ "$#" -ne 0 ]; do
46   CUR="$1"
47
48   case $CUR in
49     '--help')
50       usage
51       ;;
52     '--version')
53       version
54       ;;
55     '--input_path')
56       export INPUT_PATH="$2"
57       shift 2
58       ;;
59     '--output_path')
60       export OUTPUT_PATH="$2"
61       shift 2
62       ;;
63     '--input_arrays')
64       export INPUT_ARRAYS="$2"
65       shift 2
66       ;;
67     '--input_shapes')
68       export INPUT_SHAPES="$2"
69       shift 2
70       ;;
71     '--output_arrays')
72       export OUTPUT_ARRAYS="$2"
73       shift 2
74       ;;
75     '--v2')
76       TF_INTERFACE="--v2"
77       shift
78       ;;
79     *)
80       echo "Unknown parameter: ${CUR}"
81       shift
82       ;;
83   esac
84 done
85
86 if [ -n ${INPUT_SHAPES} ] && [ ${TF_INTERFACE} = "--v2" ]; then
87   echo "Warning: if --v2 option is used, shape will be ignored"
88 fi
89
90 if [ -z ${INPUT_PATH} ] || [ ! -e ${INPUT_PATH} ]; then
91   echo "Error: input model not found"
92   echo ""
93   usage
94   exit 2
95 fi
96
97 FILE_BASE=$(basename ${OUTPUT_PATH})
98 MODEL_NAME="${FILE_BASE%.*}"
99
100 TMPDIR=$(mktemp -d)
101 trap "{ rm -rf $TMPDIR; }" EXIT
102
103 # activate python virtual environment
104 VIRTUALENV_LINUX="${DRIVER_PATH}/venv/bin/activate"
105 VIRTUALENV_WINDOWS="${DRIVER_PATH}/venv/Scripts/activate"
106
107 if [ -e ${VIRTUALENV_LINUX} ]; then
108   source ${VIRTUALENV_LINUX}
109 elif [ -e ${VIRTUALENV_WINDOWS} ]; then
110   source ${VIRTUALENV_WINDOWS}
111 fi
112
113 # remove previous log
114 rm -rf "${OUTPUT_PATH}.log"
115
116 show_err_onexit()
117 {
118   cat "${OUTPUT_PATH}.log"
119 }
120
121 trap show_err_onexit ERR
122
123 # generate temporary tflite file
124 CONVERT_SCRIPT="python ${DRIVER_PATH}/tf2tfliteV2.py ${TF_INTERFACE} "
125 CONVERT_SCRIPT+="--input_path ${INPUT_PATH} "
126 CONVERT_SCRIPT+="--input_arrays ${INPUT_ARRAYS} "
127 CONVERT_SCRIPT+="--output_path ${TMPDIR}/${MODEL_NAME}.tflite "
128 CONVERT_SCRIPT+="--output_arrays ${OUTPUT_ARRAYS} "
129 if [ ! -z ${INPUT_SHAPES} ]; then
130   CONVERT_SCRIPT+="--input_shapes ${INPUT_SHAPES} "
131 fi
132
133 echo ${CONVERT_SCRIPT} > "${OUTPUT_PATH}.log"
134 echo "" >> "${OUTPUT_PATH}.log"
135 $CONVERT_SCRIPT >> "${OUTPUT_PATH}.log" 2>&1
136
137 # convert .tflite to .circle
138 echo " " >> "${OUTPUT_PATH}.log"
139 echo "${DRIVER_PATH}/tflite2circle" "${TMPDIR}/${MODEL_NAME}.tflite" \
140 "${OUTPUT_PATH}" >> "${OUTPUT_PATH}.log"
141 echo " " >> "${OUTPUT_PATH}.log"
142
143 "${DRIVER_PATH}/tflite2circle" "${TMPDIR}/${MODEL_NAME}.tflite" \
144 "${OUTPUT_PATH}" >> "${OUTPUT_PATH}.log" 2>&1