Use custom operator kernel in support (#1536)
author오형석/동작제어Lab(SR)/Staff Engineer/삼성전자 <hseok82.oh@samsung.com>
Mon, 4 Jun 2018 08:08:21 +0000 (17:08 +0900)
committer서상민/동작제어Lab(SR)/Staff Engineer/삼성전자 <sangmin7.seo@samsung.com>
Mon, 4 Jun 2018 08:08:21 +0000 (17:08 +0900)
tflite_run and nnapi_test use upport/tflite/kernel/CustomOps.h
to support custom operator in tflite interpreter

Signed-off-by: Hyeongseok Oh <hseok82.oh@samsung.com>
libs/support/tflite/src/kernels/register.cpp
tools/tflite_run/src/tflite_run.cc

index 0aeed3a..2f1634d 100644 (file)
@@ -17,6 +17,7 @@ limitations under the License.
 // NOTE This code is derived from the following file (in TensorFlow)
 //        'externals/tensorflow/tensorflow/contrib/lite/kernels/register.cc'
 #include "support/tflite/kernels/register.h"
+#include "support/tflite/kernels/CustomOps.h"
 
 // TODO Use namespace nnfw
 namespace tflite
@@ -125,6 +126,8 @@ BuiltinOpResolver::BuiltinOpResolver()
   AddBuiltin(BuiltinOperator_LOG_SOFTMAX, Register_LOG_SOFTMAX());
   AddBuiltin(BuiltinOperator_CAST, Register_CAST());
   AddBuiltin(BuiltinOperator_DEQUANTIZE, Register_DEQUANTIZE());
+
+  AddCustom("TensorFlowMax", tflite::ops::custom::nnfw::Register_TensorFlowMax());
 }
 
 TfLiteRegistration *BuiltinOpResolver::FindOp(tflite::BuiltinOperator op) const
index 75c3bb0..90e792e 100644 (file)
@@ -17,7 +17,6 @@
 #include "support/tflite/kernels/register.h"
 #include "tensorflow/contrib/lite/model.h"
 
-#include "operators.h"
 #include "bin_image.h"
 #include "args.h"
 #include "tensor_dumper.h"
@@ -66,20 +65,6 @@ int main(const int argc, char **argv)
   nnfw::util::benchmark::measure(t_prepare) << [&](void) {
     BuiltinOpResolver resolver;
 
-#define REGISTER(Name)                            \
-  {                                               \
-    resolver.AddCustom(#Name, Register_##Name()); \
-  }
-    REGISTER(CAST);
-    REGISTER(Stack);
-    REGISTER(ArgMax);
-    REGISTER(TensorFlowMax);
-    REGISTER(TensorFlowSqrt);
-    REGISTER(ExpandDims);
-    REGISTER(RSQRT);
-    REGISTER(SquaredDifference);
-#undef REGISTER
-
     InterpreterBuilder builder(*model, resolver);
 
     TFLITE_ENSURE(builder(&interpreter))