Imported Upstream version 1.8.0
[platform/core/ml/nnfw.git] / infra / packaging / res / tf2nnpkg.20200630
1 #!/bin/bash
2
3 set -e
4
5 ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
6
7 command_exists() {
8   if [ "$#" -le 0 ]; then
9     return 1
10   fi
11   command -v "$@" > /dev/null 2>&1
12 }
13
14 usage()
15 {
16   echo "Convert TensorFlow model to nnpackage."
17   echo "Usage: tf2nnpkg"
18   echo "    --info <path/to/info>"
19   echo "    --graphdef <path/to/pb>"
20   echo "    -o <path/to/nnpkg/directory>"
21   echo "    --v2 (optional) Use TF 2.x interface"
22   exit 255
23 }
24
25 TF_INTERFACE="--v1"
26
27 # Parse command-line arguments
28 #
29 while [ "$#" -ne 0 ]; do
30   CUR="$1"
31
32   case $CUR in
33     '--help')
34       usage
35       ;;
36     '--info')
37       export INFO_FILE="$2"
38       shift 2
39       ;;
40     '--graphdef')
41       export GRAPHDEF_FILE="$2"
42       shift 2
43       ;;
44     '-o')
45       export OUTPUT_DIR="$2"
46       shift 2
47       ;;
48     '--v2')
49       TF_INTERFACE="--v2"
50       shift
51       ;;
52     *)
53       echo "${CUR}"
54       shift
55       ;;
56   esac
57 done
58
59 if [ -z ${GRAPHDEF_FILE} ] || [ ! -e ${GRAPHDEF_FILE} ]; then
60   echo "pb is not found. Please check --graphdef is correct."
61   exit 2
62 fi
63
64 if [ -z ${INFO_FILE} ] || [ ! -e ${INFO_FILE} ]; then
65   echo "info is not found. Please check --info is correct."
66   exit 2
67 fi
68
69 if [ -z ${OUTPUT_DIR} ]; then
70   echo "output directory is not specifed. Please check -o is correct.."
71   exit 2
72 fi
73
74 FILE_BASE=$(basename ${GRAPHDEF_FILE})
75 MODEL_NAME="${FILE_BASE%.*}"
76 TMPDIR=$(mktemp -d)
77 trap "{ rm -rf $TMPDIR; }" EXIT
78
79 # activate python virtual environment
80 VIRTUALENV_LINUX="${ROOT}/bin/venv/bin/activate"
81 VIRTUALENV_WINDOWS="${ROOT}/bin/venv/Scripts/activate"
82
83 if [ -e ${VIRTUALENV_LINUX} ]; then
84   source ${VIRTUALENV_LINUX}
85 elif [ -e ${VIRTUALENV_WINDOWS} ]; then
86   source ${VIRTUALENV_WINDOWS}
87 fi
88
89 # parse inputs, outputs from info file
90 INPUT=$(awk -F, '/^input/ { print $2 }' ${INFO_FILE} | cut -d: -f1 | tr -d ' ' | paste -d, -s)
91 OUTPUT=$(awk -F, '/^output/ { print $2 }' ${INFO_FILE} | cut -d: -f1 | tr -d ' ' | paste -d, -s)
92
93 INPUT_SHAPES=$(grep ^input ${INFO_FILE} | cut -d "[" -f2 | cut -d "]" -f1 | tr -d ' ' | xargs | tr ' ' ':')
94
95 # generate tflite file
96 python "${ROOT}/bin/tf2tfliteV2.py" ${TF_INTERFACE} --input_path ${GRAPHDEF_FILE} \
97 --output_path "${TMPDIR}/${MODEL_NAME}.tflite" \
98 --input_arrays ${INPUT} --input_shapes ${INPUT_SHAPES} \
99 --output_arrays ${OUTPUT}
100
101 # convert .tflite to .circle
102 "${ROOT}/bin/tflite2circle" "${TMPDIR}/${MODEL_NAME}.tflite" "${TMPDIR}/${MODEL_NAME}.tmp.circle"
103
104 # optimize
105 "${ROOT}/bin/circle2circle" --all "${TMPDIR}/${MODEL_NAME}.tmp.circle" "${TMPDIR}/${MODEL_NAME}.circle"
106
107 "${ROOT}/bin/model2nnpkg.sh" -o "${OUTPUT_DIR}" "${TMPDIR}/${MODEL_NAME}.circle"