export tflite::Intepreter's UseNNAPI() and setNumThreads() to java (#16065)
authorfreedom" Koan-Sin Tan <koansin.tan@gmail.com>
Mon, 16 Apr 2018 02:24:14 +0000 (10:24 +0800)
committerJonathan Hseu <vomjom@vomjom.net>
Mon, 16 Apr 2018 02:24:14 +0000 (19:24 -0700)
* export UseNNAPI() and setNumThreads() to java

Export tflite::Intepreter's UseNNAPI() and SetNumThreads() to Java
and modify the Android TfLiteCameraDemo app to use them.

* change CheckedChangeListener accordingly

* add error checking to setNumThreads()

tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java
tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java
tensorflow/contrib/lite/java/demo/app/src/main/res/layout/fragment_camera2_basic.xml
tensorflow/contrib/lite/java/demo/app/src/main/res/values/strings.xml
tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h

index 300786c..18f6465 100644 (file)
@@ -54,6 +54,9 @@ import android.view.Surface;
 import android.view.TextureView;
 import android.view.View;
 import android.view.ViewGroup;
+import android.widget.CompoundButton;
+import android.widget.NumberPicker;
+import android.widget.ToggleButton;
 import android.widget.TextView;
 import android.widget.Toast;
 import java.io.IOException;
@@ -82,6 +85,8 @@ public class Camera2BasicFragment extends Fragment
   private boolean runClassifier = false;
   private boolean checkedPermissions = false;
   private TextView textView;
+  private ToggleButton toggle;
+  private NumberPicker np;
   private ImageClassifier classifier;
 
   /** Max preview width that is guaranteed by Camera2 API */
@@ -289,6 +294,24 @@ public class Camera2BasicFragment extends Fragment
   public void onViewCreated(final View view, Bundle savedInstanceState) {
     textureView = (AutoFitTextureView) view.findViewById(R.id.texture);
     textView = (TextView) view.findViewById(R.id.text);
+    toggle = (ToggleButton) view.findViewById(R.id.button);
+
+    toggle.setOnCheckedChangeListener(new CompoundButton.OnCheckedChangeListener() {
+      public void onCheckedChanged(CompoundButton buttonView, boolean isChecked) {
+        classifier.setUseNNAPI(isChecked);
+      }
+    });
+
+    np = (NumberPicker) view.findViewById(R.id.np);
+    np.setMinValue(1);
+    np.setMaxValue(10);
+    np.setWrapSelectorWheel(true);
+    np.setOnValueChangedListener(new NumberPicker.OnValueChangeListener() {
+      @Override
+      public void onValueChange(NumberPicker picker, int oldVal, int newVal){
+        classifier.setNumThreads(newVal);
+      }
+    });
   }
 
   /** Load the model and labels. */
index c57bb34..d32c077 100644 (file)
@@ -142,6 +142,16 @@ public abstract class ImageClassifier {
     }
   }
 
+  public void setUseNNAPI(Boolean nnapi) {
+    if (tflite != null)
+        tflite.setUseNNAPI(nnapi);
+  }
+
+  public void setNumThreads(int num_threads) {
+    if (tflite != null)
+        tflite.setNumThreads(num_threads);
+  }
+
   /** Closes tflite to release resources. */
   public void close() {
     tflite.close();
index 15305c4..db557ad 100644 (file)
         android:layout_width="wrap_content"
         android:layout_height="wrap_content"
         android:layout_alignParentStart="true"
+        android:layout_alignParentLeft="true"
         android:layout_alignParentTop="true" />
 
     <FrameLayout
         android:id="@+id/control"
         android:layout_width="match_parent"
-        android:layout_height="112dp"
+        android:layout_height="135dp"
         android:layout_alignParentBottom="true"
         android:layout_alignParentStart="true"
+        android:layout_alignParentLeft="true"
+        android:layout_alignParentEnd="true"
+        android:layout_alignParentRight="true"
+        android:layout_marginEnd="150dp"
+        android:layout_marginRight="150dp"
         android:background="@color/control_background">
 
-        <TextView android:id="@+id/text"
+        <TextView
+            android:id="@+id/text"
             android:layout_width="wrap_content"
             android:layout_height="wrap_content"
-            android:paddingLeft="80dp"
+            android:paddingLeft="20dp"
             android:textColor="#FFF"
             android:textSize="20sp"
             android:textStyle="bold" />
 
     </FrameLayout>
 
+    <RelativeLayout
+        android:id="@+id/control2"
+        android:layout_width="match_parent"
+        android:layout_height="135dp"
+        android:layout_alignParentLeft="true"
+        android:layout_alignParentStart="true"
+        android:layout_alignTop="@+id/control"
+        android:layout_marginLeft="300dp"
+        android:layout_marginStart="300dp"
+        android:background="@color/control_background">
+
+        <ToggleButton
+            android:id="@+id/button"
+            android:textOff="@string/tflite"
+            android:textOn="@string/nnapi"
+            android:layout_width="wrap_content"
+            android:layout_height="wrap_content"
+            android:layout_alignParentLeft="true"
+            android:layout_alignParentStart="true" />
+
+        <NumberPicker
+            android:id="@+id/np"
+            android:layout_width="wrap_content"
+            android:layout_height="wrap_content"
+            android:layout_below="@+id/button"
+            android:visibility="visible" />
+    </RelativeLayout>
+
 </RelativeLayout>
index a08ec3e..29a033b 100644 (file)
@@ -21,4 +21,6 @@
     <string name="toggle_turn_on">NN:On</string>
     <string name="toggle_turn_off">NN:Off</string>
     <string name="toggle">Use NNAPI</string>
+    <string name="tflite">tflite</string>
+    <string name="nnapi">NNAPI</string>
 </resources>
index a33959d..451a1cd 100644 (file)
@@ -212,6 +212,13 @@ public final class Interpreter implements AutoCloseable {
     }
   }
 
+  public void setNumThreads(int num_threads) {
+    if (wrapper == null) {
+      throw new IllegalStateException("The interpreter has already been closed.");
+    }
+    wrapper.setNumThreads(num_threads);
+  }
+
   /** Release resources associated with the {@code Interpreter}. */
   @Override
   public void close() {
index fc8187a..61a552d 100644 (file)
@@ -153,6 +153,10 @@ final class NativeInterpreterWrapper implements AutoCloseable {
     useNNAPI(interpreterHandle, useNNAPI);
   }
 
+  void setNumThreads(int num_threads) {
+    numThreads(interpreterHandle, num_threads);
+  }
+
   /** Gets index of an input given its name. */
   int getInputIndex(String name) {
     if (inputsIndexes == null) {
@@ -321,6 +325,8 @@ final class NativeInterpreterWrapper implements AutoCloseable {
 
   private static native void useNNAPI(long interpreterHandle, boolean state);
 
+  private static native void numThreads(long interpreterHandle, int num_threads);
+
   private static native long createErrorReporter(int size);
 
   private static native long createModel(String modelPathOrBuffer, long errorHandle);
index 8442262..4c33a2d 100644 (file)
@@ -315,6 +315,16 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_useNNAPI(JNIEnv* env,
   interpreter->UseNNAPI(static_cast<bool>(state));
 }
 
+JNIEXPORT void JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_numThreads(JNIEnv* env,
+                                                           jclass clazz,
+                                                           jlong handle,
+                                                           jint num_threads) {
+  tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+  if (interpreter == nullptr) return;
+  interpreter->SetNumThreads(static_cast<int>(num_threads));
+}
+
 JNIEXPORT jlong JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_createErrorReporter(
     JNIEnv* env, jclass clazz, jint size) {
index 0e28a77..eaa765c 100644 (file)
@@ -61,7 +61,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputNames(JNIEnv* env,
 /*
  *  Class:     org_tensorflow_lite_NativeInterpreterWrapper
  *  Method:
- *  Signature: (JZ)
+ *  Signature: (JZ)V
  */
 JNIEXPORT void JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_useNNAPI(JNIEnv* env,
@@ -72,6 +72,16 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_useNNAPI(JNIEnv* env,
 /*
  *  Class:     org_tensorflow_lite_NativeInterpreterWrapper
  *  Method:
+ *  Signature: (JI)V
+ */
+JNIEXPORT void JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_numThreads(JNIEnv* env,
+                                                           jclass clazz,
+                                                           jlong handle,
+                                                           jint num_threads);
+/*
+ *  Class:     org_tensorflow_lite_NativeInterpreterWrapper
+ *  Method:
  *  Signature: (I)J
  */
 JNIEXPORT jlong JNICALL