From a400f825281f3c6f0688e8b16deea4ba12ee6bb5 Mon Sep 17 00:00:00 2001
From: Michal Piszczek
Date: Thu, 14 May 2020 20:16:57 -0700
Subject: [PATCH] [TFLite Runtime] Fix bug and re-enable RPC execution test (#5436)

---
 src/runtime/contrib/tflite/tflite_runtime.cc |   8 +-
 src/runtime/contrib/tflite/tflite_runtime.h  |   3 +
 src/runtime/module.cc                        |   2 +
 tests/python/contrib/test_tflite_runtime.py  | 202 ++++++++++++++++-----------
 tests/scripts/task_config_build_cpu.sh       |   3 +
 5 files changed, 135 insertions(+), 83 deletions(-)

diff --git a/src/runtime/contrib/tflite/tflite_runtime.cc b/src/runtime/contrib/tflite/tflite_runtime.cc
index 53d7754..8b34e90 100644
--- a/src/runtime/contrib/tflite/tflite_runtime.cc
+++ b/src/runtime/contrib/tflite/tflite_runtime.cc
@@ -93,8 +93,12 @@ DataType TfLiteDType2TVMDType(TfLiteType dtype) {
 void TFLiteRuntime::Init(const std::string& tflite_model_bytes, TVMContext ctx) {
   const char* buffer = tflite_model_bytes.c_str();
   size_t buffer_size = tflite_model_bytes.size();
+  // The buffer used to construct the model must be kept alive for
+  // dependent interpreters to be used.
+  flatBuffersBuffer_ = std::unique_ptr<char[]>(new char[buffer_size]);
+  std::memcpy(flatBuffersBuffer_.get(), buffer, buffer_size);
   std::unique_ptr<tflite::FlatBufferModel> model =
-      tflite::FlatBufferModel::BuildFromBuffer(buffer, buffer_size);
+      tflite::FlatBufferModel::BuildFromBuffer(flatBuffersBuffer_.get(), buffer_size);
   tflite::ops::builtin::BuiltinOpResolver resolver;
   // Build interpreter
   TfLiteStatus status = tflite::InterpreterBuilder(*model, resolver)(&interpreter_);
@@ -173,5 +177,7 @@ Module TFLiteRuntimeCreate(const std::string& tflite_model_bytes, TVMContext ctx
 TVM_REGISTER_GLOBAL("tvm.tflite_runtime.create").set_body([](TVMArgs args, TVMRetValue* rv) {
   *rv = TFLiteRuntimeCreate(args[0], args[1]);
 });
+
+TVM_REGISTER_GLOBAL("target.runtime.tflite").set_body_typed(TFLiteRuntimeCreate);
 }  // namespace runtime
 }  // namespace tvm
diff --git a/src/runtime/contrib/tflite/tflite_runtime.h b/src/runtime/contrib/tflite/tflite_runtime.h
index f61f6ee..f3e3bd9 100644
--- a/src/runtime/contrib/tflite/tflite_runtime.h
+++ b/src/runtime/contrib/tflite/tflite_runtime.h
@@ -26,6 +26,7 @@
 #define TVM_RUNTIME_CONTRIB_TFLITE_TFLITE_RUNTIME_H_
 
 #include <dlpack/dlpack.h>
+#include <tensorflow/lite/interpreter.h>
 #include <tvm/runtime/module.h>
 #include <tvm/runtime/ndarray.h>
@@ -93,6 +94,8 @@ class TFLiteRuntime : public ModuleNode {
    */
   NDArray GetOutput(int index) const;
 
+  // Buffer backing the interpreter's model
+  std::unique_ptr<char[]> flatBuffersBuffer_;
   // TFLite interpreter
   std::unique_ptr<tflite::Interpreter> interpreter_;
   // TVM context
diff --git a/src/runtime/module.cc b/src/runtime/module.cc
index be75ff2..46ef6fa 100644
--- a/src/runtime/module.cc
+++ b/src/runtime/module.cc
@@ -129,6 +129,8 @@ bool RuntimeEnabled(const std::string& target) {
     f_name = "device_api.opencl";
   } else if (target == "mtl" || target == "metal") {
     f_name = "device_api.metal";
+  } else if (target == "tflite") {
+    f_name = "target.runtime.tflite";
   } else if (target == "vulkan") {
     f_name = "device_api.vulkan";
   } else if (target == "stackvm") {
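The crux of the runtime fix: tflite::FlatBufferModel::BuildFromBuffer() does not
copy the caller-owned buffer; the model, and any interpreter built from it, keep
reading through that pointer. Init() previously handed it the c_str() of the
tflite_model_bytes argument, a string the caller may destroy as soon as creation
returns, so a long-lived interpreter could end up reading freed memory (the RPC
path exercised exactly that). The patch instead copies the bytes into the
flatBuffersBuffer_ member, which lives as long as the interpreter. Below is a
minimal standalone sketch of that ownership pattern; the Runner class is
illustrative only and not part of the patch:

#include <cstring>
#include <memory>
#include <string>

#include <tensorflow/lite/interpreter.h>
#include <tensorflow/lite/kernels/register.h>
#include <tensorflow/lite/model.h>

class Runner {
 public:
  explicit Runner(const std::string& model_bytes) {
    // BuildFromBuffer() borrows the buffer, so copy the bytes into a member
    // whose lifetime matches the interpreter's, as TFLiteRuntime::Init now does.
    buffer_.reset(new char[model_bytes.size()]);
    std::memcpy(buffer_.get(), model_bytes.data(), model_bytes.size());
    auto model = tflite::FlatBufferModel::BuildFromBuffer(buffer_.get(), model_bytes.size());
    tflite::ops::builtin::BuiltinOpResolver resolver;
    tflite::InterpreterBuilder(*model, resolver)(&interpreter_);
  }

 private:
  std::unique_ptr<char[]> buffer_;  // must outlive interpreter_
  std::unique_ptr<tflite::Interpreter> interpreter_;
};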
diff --git a/tests/python/contrib/test_tflite_runtime.py b/tests/python/contrib/test_tflite_runtime.py
index 8c883b0..1b911b7 100644
--- a/tests/python/contrib/test_tflite_runtime.py
+++ b/tests/python/contrib/test_tflite_runtime.py
@@ -14,92 +14,130 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import pytest
+
 import tvm
 from tvm import te
 import numpy as np
 from tvm import rpc
 from tvm.contrib import util, tflite_runtime
-# import tensorflow as tf
-# import tflite_runtime.interpreter as tflite
-
-
-def skipped_test_tflite_runtime():
-
-    def create_tflite_model():
-        root = tf.Module()
-        root.const = tf.constant([1., 2.], tf.float32)
-        root.f = tf.function(lambda x: root.const * x)
-
-        input_signature = tf.TensorSpec(shape=[2, ], dtype=tf.float32)
-        concrete_func = root.f.get_concrete_function(input_signature)
-        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
-        tflite_model = converter.convert()
-        return tflite_model
-
-
-    def check_local():
-        tflite_fname = "model.tflite"
-        tflite_model = create_tflite_model()
-        temp = util.tempdir()
-        tflite_model_path = temp.relpath(tflite_fname)
-        open(tflite_model_path, 'wb').write(tflite_model)
-
-        # inference via tflite interpreter python apis
-        interpreter = tflite.Interpreter(model_path=tflite_model_path)
-        interpreter.allocate_tensors()
-        input_details = interpreter.get_input_details()
-        output_details = interpreter.get_output_details()
-
-        input_shape = input_details[0]['shape']
-        tflite_input = np.array(np.random.random_sample(input_shape), dtype=np.float32)
-        interpreter.set_tensor(input_details[0]['index'], tflite_input)
-        interpreter.invoke()
-        tflite_output = interpreter.get_tensor(output_details[0]['index'])
-
-        # inference via tvm tflite runtime
-        with open(tflite_model_path, 'rb') as model_fin:
-            runtime = tflite_runtime.create(model_fin.read(), tvm.cpu(0))
-            runtime.set_input(0, tvm.nd.array(tflite_input))
-            runtime.invoke()
-            out = runtime.get_output(0)
-            np.testing.assert_equal(out.asnumpy(), tflite_output)
-
-
-    def check_remote():
-        tflite_fname = "model.tflite"
-        tflite_model = create_tflite_model()
-        temp = util.tempdir()
-        tflite_model_path = temp.relpath(tflite_fname)
-        open(tflite_model_path, 'wb').write(tflite_model)
-
-        # inference via tflite interpreter python apis
-        interpreter = tflite.Interpreter(model_path=tflite_model_path)
-        interpreter.allocate_tensors()
-        input_details = interpreter.get_input_details()
-        output_details = interpreter.get_output_details()
-
-        input_shape = input_details[0]['shape']
-        tflite_input = np.array(np.random.random_sample(input_shape), dtype=np.float32)
-        interpreter.set_tensor(input_details[0]['index'], tflite_input)
-        interpreter.invoke()
-        tflite_output = interpreter.get_tensor(output_details[0]['index'])
-
-        # inference via remote tvm tflite runtime
-        server = rpc.Server("localhost")
-        remote = rpc.connect(server.host, server.port)
-        ctx = remote.cpu(0)
-        a = remote.upload(tflite_model_path)
-
-        with open(tflite_model_path, 'rb') as model_fin:
-            runtime = tflite_runtime.create(model_fin.read(), remote.cpu(0))
-            runtime.set_input(0, tvm.nd.array(tflite_input, remote.cpu(0)))
-            runtime.invoke()
-            out = runtime.get_output(0)
-            np.testing.assert_equal(out.asnumpy(), tflite_output)
-
-    check_local()
-    check_remote()
+
+
+def _create_tflite_model():
+    if not tvm.runtime.enabled("tflite"):
+        print("skip because tflite runtime is not enabled...")
+        return
+    if not tvm.get_global_func("tvm.tflite_runtime.create", True):
+        print("skip because tflite runtime is not enabled...")
+        return
+
+    try:
+        import tensorflow as tf
+    except ImportError:
+        print('skip because tensorflow not installed...')
+        return
+
+    root = tf.Module()
+    root.const = tf.constant([1., 2.], tf.float32)
+    root.f = tf.function(lambda x: root.const * x)
+
+    input_signature = tf.TensorSpec(shape=[2, ], dtype=tf.float32)
+    concrete_func = root.f.get_concrete_function(input_signature)
+    converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+    tflite_model = converter.convert()
+    return tflite_model
+
+
+@pytest.mark.skip('skip because accessing output tensor is flakey')
+def test_local():
+    if not tvm.runtime.enabled("tflite"):
+        print("skip because tflite runtime is not enabled...")
+        return
+    if not tvm.get_global_func("tvm.tflite_runtime.create", True):
+        print("skip because tflite runtime is not enabled...")
+        return
+
+    try:
+        import tensorflow as tf
+    except ImportError:
+        print('skip because tensorflow not installed...')
+        return
+
+    tflite_fname = "model.tflite"
+    tflite_model = _create_tflite_model()
+    temp = util.tempdir()
+    tflite_model_path = temp.relpath(tflite_fname)
+    open(tflite_model_path, 'wb').write(tflite_model)
+
+    # inference via tflite interpreter python apis
+    interpreter = tf.lite.Interpreter(model_path=tflite_model_path)
+    interpreter.allocate_tensors()
+    input_details = interpreter.get_input_details()
+    output_details = interpreter.get_output_details()
+
+    input_shape = input_details[0]['shape']
+    tflite_input = np.array(np.random.random_sample(input_shape), dtype=np.float32)
+    interpreter.set_tensor(input_details[0]['index'], tflite_input)
+    interpreter.invoke()
+    tflite_output = interpreter.get_tensor(output_details[0]['index'])
+
+    # inference via tvm tflite runtime
+    with open(tflite_model_path, 'rb') as model_fin:
+        runtime = tflite_runtime.create(model_fin.read(), tvm.cpu(0))
+        runtime.set_input(0, tvm.nd.array(tflite_input))
+        runtime.invoke()
+        out = runtime.get_output(0)
+        np.testing.assert_equal(out.asnumpy(), tflite_output)
+
+
+def test_remote():
+    if not tvm.runtime.enabled("tflite"):
+        print("skip because tflite runtime is not enabled...")
+        return
+    if not tvm.get_global_func("tvm.tflite_runtime.create", True):
+        print("skip because tflite runtime is not enabled...")
+        return
+
+    try:
+        import tensorflow as tf
+    except ImportError:
+        print('skip because tensorflow not installed...')
+        return
+
+    tflite_fname = "model.tflite"
+    tflite_model = _create_tflite_model()
+    temp = util.tempdir()
+    tflite_model_path = temp.relpath(tflite_fname)
+    open(tflite_model_path, 'wb').write(tflite_model)
+
+    # inference via tflite interpreter python apis
+    interpreter = tf.lite.Interpreter(model_path=tflite_model_path)
+    interpreter.allocate_tensors()
+    input_details = interpreter.get_input_details()
+    output_details = interpreter.get_output_details()
+
+    input_shape = input_details[0]['shape']
+    tflite_input = np.array(np.random.random_sample(input_shape), dtype=np.float32)
+    interpreter.set_tensor(input_details[0]['index'], tflite_input)
+    interpreter.invoke()
+    tflite_output = interpreter.get_tensor(output_details[0]['index'])
+
+    # inference via remote tvm tflite runtime
+    server = rpc.Server("localhost")
+    remote = rpc.connect(server.host, server.port)
+    ctx = remote.cpu(0)
+    a = remote.upload(tflite_model_path)
+
+    with open(tflite_model_path, 'rb') as model_fin:
+        runtime = tflite_runtime.create(model_fin.read(), remote.cpu(0))
+        runtime.set_input(0, tvm.nd.array(tflite_input, remote.cpu(0)))
+        runtime.invoke()
+        out = runtime.get_output(0)
+        np.testing.assert_equal(out.asnumpy(), tflite_output)
+
+    server.terminate()
+
 
 if __name__ == "__main__":
-    # skipped_test_tflite_runtime()
-    pass
+    test_local()
+    test_remote()
diff --git a/tests/scripts/task_config_build_cpu.sh b/tests/scripts/task_config_build_cpu.sh
index 9c1cf28..ce545bd 100755
--- a/tests/scripts/task_config_build_cpu.sh
+++ b/tests/scripts/task_config_build_cpu.sh
@@ -38,3 +38,6 @@ echo set\(CMAKE_CXX_FLAGS -Werror\) >> config.cmake
 echo set\(HIDE_PRIVATE_SYMBOLS ON\) >> config.cmake
 echo set\(USE_VTA_TSIM ON\) >> config.cmake
 echo set\(USE_VTA_FSIM ON\) >> config.cmake
+echo set\(USE_TFLITE ON\) >> config.cmake
+echo set\(USE_TENSORFLOW_PATH \"/tensorflow\"\) >> config.cmake
+echo set\(USE_FLATBUFFERS_PATH \"/flatbuffers\"\) >> config.cmake
-- 
2.7.4
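For reference, the module.cc change is what lets the new test guards work:
tvm.runtime.enabled("tflite") calls RuntimeEnabled(), which now maps "tflite" to
the global function name "target.runtime.tflite" and checks whether anything has
registered it; the added TVM_REGISTER_GLOBAL line in tflite_runtime.cc performs
that registration whenever the TFLite runtime is compiled in (USE_TFLITE ON, per
the CI script change above). A minimal sketch of the same lookup, assuming only
tvm/runtime/registry.h:

#include <tvm/runtime/registry.h>

// True iff some translation unit registered "target.runtime.tflite",
// i.e. the TFLite runtime was compiled into the TVM runtime library.
bool TFLiteRuntimeAvailable() {
  return tvm::runtime::Registry::Get("target.runtime.tflite") != nullptr;
}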