From a51bf6db47bf4de36cc74464f818b89c5ef3ed96 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EA=B9=80=EC=A0=95=ED=98=84/=EB=8F=99=EC=9E=91=EC=A0=9C?= =?utf8?q?=EC=96=B4Lab=28SR=29/Senior=20Engineer/=EC=82=BC=EC=84=B1?= =?utf8?q?=EC=A0=84=EC=9E=90?= Date: Fri, 30 Mar 2018 15:52:41 +0900 Subject: [PATCH] Introduce weights extractor from .tflite (#288) This commit introduces weight/bias extractor from .tflite file. You can execute this script like below: `$ extract_from_tflite.sh inception.tflite` Then, you can find *.npy files in the current working directory. Those *.npy files contain weight or bias values. The name of each file is set to be the name of each layer. Signed-off-by: Junghyun Kim --- tools/extract_weights_from_tflite/extract.py | 73 ++++++++++++++++++++++ .../extract_from_tflite.sh | 16 +++++ 2 files changed, 89 insertions(+) create mode 100755 tools/extract_weights_from_tflite/extract.py create mode 100755 tools/extract_weights_from_tflite/extract_from_tflite.sh diff --git a/tools/extract_weights_from_tflite/extract.py b/tools/extract_weights_from_tflite/extract.py new file mode 100755 index 0000000..0bc88ac --- /dev/null +++ b/tools/extract_weights_from_tflite/extract.py @@ -0,0 +1,73 @@ +#!/usr/bin/python +import numpy as np +import sys +import json +import struct + +def printUsage(progname): + print("%s <.json>"%(progname)) + print(" This program extracts weight and bias values into .npy files in ACL format [N,H,W,C]") + print(" .npy filenames is set according to the layer's name") + +if len(sys.argv) < 2: + printUsage(sys.argv[0]) + exit() + +filename = sys.argv[1] +f = open(filename) +j = json.loads(f.read()) + +tensors = j['subgraphs'][0]['tensors'] +buffer_name_map={} + +for t in tensors: + if 'buffer' in t: + if t['buffer'] in buffer_name_map: + print 'find conflict!!' + print t + print buffer_name_map + comps = t['name'].split('/') + names = [] + if len(comps) > 1 and comps[0] == comps[1]: + names = comps[2:] + else: + names = comps[1:] + + layername = '_'.join(names) + + shape = t['shape'] + buffer_name_map[t['buffer']] = {'name': layername, "shape":shape} + +for i in range(len(j['buffers'])): + b = j['buffers'][i] + if 'data' in b: + if i not in buffer_name_map: + print "buffer %d is not found in buffer_name_map. skip printing the buffer..." + continue + + filename = "%s.npy" % (buffer_name_map[i]['name']) + shape = buffer_name_map[i]['shape'] + buf = struct.pack('%sB' % len(b['data']), *b['data']) + + elem_size = 1 + for s in shape: + elem_size *= s + + l = struct.unpack('%sf' % elem_size, buf) + n = np.array(l, dtype='f') + n = n.reshape(shape) + if len(shape) == 4: + # [N,C,H,W] -> [N,H,W,C] + n = np.rollaxis(n,3,1) + elif len(shape) == 3: + # [C,H,W] -> [H,W,C] + n = np.rollaxis(n,2,0) + elif len(shape) == 1: + pass + else: + print "Undefined length: conversion skipped. shape=", shape + #print shape, filename, n.shape + np.save(filename, n) + +print "Done." + diff --git a/tools/extract_weights_from_tflite/extract_from_tflite.sh b/tools/extract_weights_from_tflite/extract_from_tflite.sh new file mode 100755 index 0000000..25b67ff --- /dev/null +++ b/tools/extract_weights_from_tflite/extract_from_tflite.sh @@ -0,0 +1,16 @@ +#!/bin/bash +SCRIPT_PATH="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" +ROOT_PATH=$SCRIPT_PATH/../.. +FLATC=$ROOT_PATH/Product/out/bin/flatc + +if [ ! -e "$1" ]; then + echo "file not exists: $1" + exit 1 +fi + +TFLITE_FILE=$1 +TFLITE_FILENAME=${TFLITE_FILE##*\/} +TFLITE_JSON=${TFLITE_FILENAME%\.tflite}.json + +$FLATC --json --strict-json $ROOT_PATH/externals/tensorflow/tensorflow/contrib/lite/schema/schema.fbs -- $TFLITE_FILE +$SCRIPT_PATH/extract.py $TFLITE_JSON -- 2.7.4