Expose setNumThreads in the Java API.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 15 Mar 2018 16:07:41 +0000 (09:07 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 15 Mar 2018 16:11:35 +0000 (09:11 -0700)
PiperOrigin-RevId: 189193847

tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h
tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java

index 014636f..518e8b3 100644 (file)
@@ -140,6 +140,10 @@ final class NativeInterpreterWrapper implements AutoCloseable {
     useNNAPI(interpreterHandle, useNNAPI);
   }
 
+  void setNumThreads(int numRecommendedThreads) {
+    numThreads(interpreterHandle, numRecommendedThreads);
+  }
+
   /** Gets index of an input given its name. */
   int getInputIndex(String name) {
     if (inputsIndexes == null) {
@@ -308,6 +312,8 @@ final class NativeInterpreterWrapper implements AutoCloseable {
 
   private static native void useNNAPI(long interpreterHandle, boolean state);
 
+  private static native void numThreads(long interpreterHandle, int numRecommendedThreads);
+
   private static native long createErrorReporter(int size);
 
   private static native long createModel(String modelPathOrBuffer, long errorHandle);
index 2870ffe..21bcff4 100644 (file)
@@ -316,6 +316,16 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_useNNAPI(JNIEnv* env,
   interpreter->UseNNAPI(static_cast<bool>(state));
 }
 
+JNIEXPORT void JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_numThreads(JNIEnv* env,
+                                                             jclass clazz,
+                                                             jlong handle,
+                                                             jint num_threads) {
+  tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+  if (interpreter == nullptr) return;
+  interpreter->SetNumThreads(static_cast<int>(num_threads));
+}
+
 JNIEXPORT jlong JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_createErrorReporter(
     JNIEnv* env, jclass clazz, jint size) {
index d611ec7..fb76125 100644 (file)
@@ -72,6 +72,17 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_useNNAPI(JNIEnv* env,
 /*
  *  Class:     org_tensorflow_lite_NativeInterpreterWrapper
  *  Method:
+ *  Signature: (JI)
+ */
+JNIEXPORT void JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_numThreads(JNIEnv* env,
+                                                             jclass clazz,
+                                                             jlong handle,
+                                                             jint num_threads);
+
+/*
+ *  Class:     org_tensorflow_lite_NativeInterpreterWrapper
+ *  Method:
  *  Signature: (I)J
  */
 JNIEXPORT jlong JNICALL
index 3aef0c3..3722e51 100644 (file)
@@ -34,6 +34,20 @@ public class TestHelper {
   }
 
   /**
+   * Sets the number of threads for an {@code Interpreter}.
+   *
+   * @param interpreter an instance of {@code Interpreter}. If it is not initialized, an {@code
+   *     IllegalArgumentException} will be thrown.
+   * @param numRecommendedThreads an integer value indicating the number of recommended threads.
+   */
+  public static void setNumThreads(Interpreter interpreter, int numRecommendedThreads) {
+    if (interpreter != null && interpreter.wrapper != null) {
+      interpreter.wrapper.setNumThreads(numRecommendedThreads);
+    } else {
+      throw new IllegalArgumentException("Interpreter has not initialized; Failed to setUseNNAPI.");
+    }
+  }
+  /**
    * Gets the last inference duration in nanoseconds. It returns null if there is no previous
    * inference run or the last inference run failed.
    *