From 47c79c228d91a6b065fc5275b0b696490f6684cd Mon Sep 17 00:00:00 2001 From: "freedom\" Koan-Sin Tan" Date: Mon, 16 Apr 2018 10:24:14 +0800 Subject: [PATCH] export tflite::Intepreter's UseNNAPI() and setNumThreads() to java (#16065) * export UseNNAPI() and setNumThreads() to java Export tflite::Intepreter's UseNNAPI() and SetNumThreads() to Java and modify the Android TfLiteCameraDemo app to use them. * change CheckedChangeListener accordingly * add error checking to setNumThreads() --- .../tflitecamerademo/Camera2BasicFragment.java | 23 ++++++++++++ .../android/tflitecamerademo/ImageClassifier.java | 10 ++++++ .../src/main/res/layout/fragment_camera2_basic.xml | 41 ++++++++++++++++++++-- .../java/demo/app/src/main/res/values/strings.xml | 2 ++ .../main/java/org/tensorflow/lite/Interpreter.java | 7 ++++ .../tensorflow/lite/NativeInterpreterWrapper.java | 6 ++++ .../main/native/nativeinterpreterwrapper_jni.cc | 10 ++++++ .../src/main/native/nativeinterpreterwrapper_jni.h | 12 ++++++- 8 files changed, 107 insertions(+), 4 deletions(-) diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java index 300786c..18f6465 100644 --- a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java +++ b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java @@ -54,6 +54,9 @@ import android.view.Surface; import android.view.TextureView; import android.view.View; import android.view.ViewGroup; +import android.widget.CompoundButton; +import android.widget.NumberPicker; +import android.widget.ToggleButton; import android.widget.TextView; import android.widget.Toast; import java.io.IOException; @@ -82,6 +85,8 @@ public class Camera2BasicFragment extends Fragment private boolean runClassifier = false; private boolean checkedPermissions = false; private TextView textView; + private ToggleButton toggle; + private NumberPicker np; private ImageClassifier classifier; /** Max preview width that is guaranteed by Camera2 API */ @@ -289,6 +294,24 @@ public class Camera2BasicFragment extends Fragment public void onViewCreated(final View view, Bundle savedInstanceState) { textureView = (AutoFitTextureView) view.findViewById(R.id.texture); textView = (TextView) view.findViewById(R.id.text); + toggle = (ToggleButton) view.findViewById(R.id.button); + + toggle.setOnCheckedChangeListener(new CompoundButton.OnCheckedChangeListener() { + public void onCheckedChanged(CompoundButton buttonView, boolean isChecked) { + classifier.setUseNNAPI(isChecked); + } + }); + + np = (NumberPicker) view.findViewById(R.id.np); + np.setMinValue(1); + np.setMaxValue(10); + np.setWrapSelectorWheel(true); + np.setOnValueChangedListener(new NumberPicker.OnValueChangeListener() { + @Override + public void onValueChange(NumberPicker picker, int oldVal, int newVal){ + classifier.setNumThreads(newVal); + } + }); } /** Load the model and labels. */ diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java index c57bb34..d32c077 100644 --- a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java +++ b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java @@ -142,6 +142,16 @@ public abstract class ImageClassifier { } } + public void setUseNNAPI(Boolean nnapi) { + if (tflite != null) + tflite.setUseNNAPI(nnapi); + } + + public void setNumThreads(int num_threads) { + if (tflite != null) + tflite.setNumThreads(num_threads); + } + /** Closes tflite to release resources. */ public void close() { tflite.close(); diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/fragment_camera2_basic.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/fragment_camera2_basic.xml index 15305c4..db557ad 100644 --- a/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/fragment_camera2_basic.xml +++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/fragment_camera2_basic.xml @@ -22,24 +22,59 @@ android:layout_width="wrap_content" android:layout_height="wrap_content" android:layout_alignParentStart="true" + android:layout_alignParentLeft="true" android:layout_alignParentTop="true" /> - + + + + + + + diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values/strings.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/strings.xml index a08ec3e..29a033b 100644 --- a/tensorflow/contrib/lite/java/demo/app/src/main/res/values/strings.xml +++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/strings.xml @@ -21,4 +21,6 @@ NN:On NN:Off Use NNAPI + tflite + NNAPI diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java index a33959d..451a1cd 100644 --- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java +++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java @@ -212,6 +212,13 @@ public final class Interpreter implements AutoCloseable { } } + public void setNumThreads(int num_threads) { + if (wrapper == null) { + throw new IllegalStateException("The interpreter has already been closed."); + } + wrapper.setNumThreads(num_threads); + } + /** Release resources associated with the {@code Interpreter}. */ @Override public void close() { diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java index fc8187a..61a552d 100644 --- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java +++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java @@ -153,6 +153,10 @@ final class NativeInterpreterWrapper implements AutoCloseable { useNNAPI(interpreterHandle, useNNAPI); } + void setNumThreads(int num_threads) { + numThreads(interpreterHandle, num_threads); + } + /** Gets index of an input given its name. */ int getInputIndex(String name) { if (inputsIndexes == null) { @@ -321,6 +325,8 @@ final class NativeInterpreterWrapper implements AutoCloseable { private static native void useNNAPI(long interpreterHandle, boolean state); + private static native void numThreads(long interpreterHandle, int num_threads); + private static native long createErrorReporter(int size); private static native long createModel(String modelPathOrBuffer, long errorHandle); diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc index 8442262..4c33a2d 100644 --- a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc +++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc @@ -315,6 +315,16 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_useNNAPI(JNIEnv* env, interpreter->UseNNAPI(static_cast(state)); } +JNIEXPORT void JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_numThreads(JNIEnv* env, + jclass clazz, + jlong handle, + jint num_threads) { + tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle); + if (interpreter == nullptr) return; + interpreter->SetNumThreads(static_cast(num_threads)); +} + JNIEXPORT jlong JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_createErrorReporter( JNIEnv* env, jclass clazz, jint size) { diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h index 0e28a77..eaa765c 100644 --- a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h +++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h @@ -61,7 +61,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputNames(JNIEnv* env, /* * Class: org_tensorflow_lite_NativeInterpreterWrapper * Method: - * Signature: (JZ) + * Signature: (JZ)V */ JNIEXPORT void JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_useNNAPI(JNIEnv* env, @@ -72,6 +72,16 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_useNNAPI(JNIEnv* env, /* * Class: org_tensorflow_lite_NativeInterpreterWrapper * Method: + * Signature: (JI)V + */ +JNIEXPORT void JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_numThreads(JNIEnv* env, + jclass clazz, + jlong handle, + jint num_threads); +/* + * Class: org_tensorflow_lite_NativeInterpreterWrapper + * Method: * Signature: (I)J */ JNIEXPORT jlong JNICALL -- 2.7.4