Update `tflite_convert.sh` script (#6869)
author이성재/On-Device Lab(SR)/Principal Engineer/삼성전자 <sj925.lee@samsung.com>
Thu, 5 Sep 2019 11:24:29 +0000 (04:24 -0700)
committer오형석/On-Device Lab(SR)/Staff Engineer/삼성전자 <hseok82.oh@samsung.com>
Thu, 5 Sep 2019 11:24:29 +0000 (20:24 +0900)
- Change to use the 'tensorflow' module installed on the system by
  default.

Signed-off-by: Sung-Jae Lee <sj925.lee@samsung.com>
tools/tflkit/tflite_convert.sh

index 3fe1432..f5b94ed 100755 (executable)
@@ -3,15 +3,12 @@
 usage()
 {
   echo "usage : $0"
-  echo "       --info=Information file"
-  echo "       --tensorflow_path=TensorFlow path (Use externals/tensorflow by default)"
-  echo "       --tensorflow_version=TensorFlow version (Must be entered)"
+  echo "       --info=<infroamtion file>"
+  echo "       [ --tensorflow_path=<path> --tensorflow_version=<version> ] (If omitted, the module installed in system will be used by default.)"
 }
 
 SCRIPT_PATH="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
 
-TF_DIR="${SCRIPT_PATH}/../../externals/tensorflow"
-
 for i in "$@"
 do
   case $i in
@@ -41,15 +38,16 @@ if [ -z "$INFO" ]; then
   usage
   exit 1
 fi
-if [ -z "$TF_DIR" ]; then
-  echo "tensorflow_path is unset or set to the empty string"
-  usage
-  exit 1
-fi
+
 if [ -z "$TF_VERSION" ]; then
-  echo "tensorflow_version is unset or set to the empty string"
-  usage
-  exit 1
+  if [ -z "$TF_DIR" ]; then
+    TF_VERSION=$(python -c 'import tensorflow as tf; print(tf.__version__)')
+    echo "TensorFlow version detected : $TF_VERSION"
+  else
+    echo "tensorflow_version is unset or set to the empty string"
+    usage
+    exit 1
+  fi
 fi
 
 if [ ! -x "$(command -v bazel)" ]; then
@@ -87,8 +85,13 @@ fi
 
 CUR_DIR=$(pwd)
 {
-  echo "Enter $TF_DIR"
-  pushd $TF_DIR > /dev/null
+  if [ -e "$TF_DIR" ]; then
+    echo "Enter $TF_DIR"
+    pushd $TF_DIR > /dev/null
+    TFLITE_CONVERT="bazel run tensorflow/lite/python:tflite_convert -- "
+  else
+    TFLITE_CONVERT="python -m tensorflow.lite.python.tflite_convert "
+  fi
 
   NAME_LIST=()
   INPUT_SHAPE_LIST=()
@@ -111,7 +114,7 @@ CUR_DIR=$(pwd)
 
   for (( i=0; i < ${#NAME_LIST[@]}; ++i )); do
     if [ "${TF_VERSION%%.*}" = "2" ]; then
-      bazel run tensorflow/lite/python:tflite_convert -- \
+      $TFLITE_CONVERT \
       --output_file="${NAME_LIST[$i]}" \
       --graph_def_file="$GRAPHDEF_PATH" \
       --input_arrays="$INPUT" \
@@ -119,7 +122,7 @@ CUR_DIR=$(pwd)
       --output_arrays="$OUTPUT" \
       --allow_custom_ops=true
     else
-      bazel run tensorflow/lite/python:tflite_convert -- \
+      $TFLITE_CONVERT \
       --output_file="${NAME_LIST[$i]}" \
       --graph_def_file="$GRAPHDEF_PATH" \
       --input_arrays="$INPUT" \
@@ -128,7 +131,10 @@ CUR_DIR=$(pwd)
       --allow_custom_ops
     fi
   done
-  popd
+
+  if [ -e "$TF_DIR" ]; then
+    popd
+  fi
 
   for (( i=0; i < ${#NAME_LIST[@]}; ++i )); do
     echo "OUTPUT FILE : ${NAME_LIST[$i]}"