From 9518a7390f5381d2f8f3e5e1bf414ae08c1d8659 Mon Sep 17 00:00:00 2001 From: DongHak Park Date: Mon, 24 Apr 2023 19:30:53 +0900 Subject: [PATCH] [Application] Fix Resnet Application -ENABLE_TFLITE_INTERPRETER CASES Now TFLITE Interpreter is not support loss : cross type So in Resnet Application we made some macro to make them mse and there was some wrong part in ResNet Application there was another macro for ENABLE_TEST GTEST's result assume that Application use cross loss For Correct Result Fix some #if statement TODO : even if fix this situation TEST still failed regardless of tflite export releated code Signed-off-by: DongHak Park --- Applications/Resnet/jni/main.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/Applications/Resnet/jni/main.cpp b/Applications/Resnet/jni/main.cpp index 9282c67..074f453 100644 --- a/Applications/Resnet/jni/main.cpp +++ b/Applications/Resnet/jni/main.cpp @@ -199,7 +199,7 @@ std::vector createResnet18Graph() { /// @todo update createResnet18 to be more generic ModelHandle createResnet18() { /// @todo support "LOSS : cross" for TF_Lite Exporter -#if defined(ENABLE_TEST) +#if (defined(ENABLE_TFLITE_INTERPRETER) && !defined(ENABLE_TEST)) ModelHandle model = ml::train::createModel(ml::train::ModelType::NEURAL_NET, {withKey("loss", "mse")}); #else @@ -270,9 +270,10 @@ void createAndRun(unsigned int epochs, unsigned int batch_size, model->train(); #if defined(ENABLE_TEST) - model->exports(ml::train::ExportMethods::METHOD_TFLITE, "resnet_test.tflite"); training_loss = model->getTrainingLoss(); validation_loss = model->getValidationLoss(); +#elif defined(ENABLE_TFLITE_INTERPRETER) + model->exports(ml::train::ExportMethods::METHOD_TFLITE, "resnet_test.tflite"); #endif } -- 2.7.4