Set number of threads at Java interpreter constructor so that Conv Kernels can be...
authorA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 16 Mar 2018 18:45:42 +0000 (11:45 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 16 Mar 2018 18:50:09 +0000 (11:50 -0700)
Remove setNumThreads in the Java API as its behavior is ambiguous.

PiperOrigin-RevId: 189370770

tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h
tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java
tensorflow/contrib/lite/model.cc
tensorflow/contrib/lite/model.h

index cc17b49..14f461f 100644 (file)
@@ -78,6 +78,17 @@ public final class Interpreter implements AutoCloseable {
   }
 
   /**
+   * Initializes a {@code Interpreter} with a {@code MappedByteBuffer} to the model file and
+   * specifies the number of threads used for inference.
+   *
+   * <p>The {@code MappedByteBuffer} should remain unchanged after the construction of a {@code
+   * Interpreter}.
+   */
+  public Interpreter(@NonNull MappedByteBuffer mappedByteBuffer, int numThreads) {
+    wrapper = new NativeInterpreterWrapper(mappedByteBuffer, numThreads);
+  }
+
+  /**
    * Runs model inference if the model takes only one input, and provides only one output.
    *
    * <p>Warning: The API runs much faster if {@link ByteBuffer} is used as input data type. Please
index 518e8b3..dbf8f8f 100644 (file)
@@ -34,7 +34,7 @@ final class NativeInterpreterWrapper implements AutoCloseable {
   NativeInterpreterWrapper(String modelPath) {
     errorHandle = createErrorReporter(ERROR_BUFFER_SIZE);
     modelHandle = createModel(modelPath, errorHandle);
-    interpreterHandle = createInterpreter(modelHandle, errorHandle);
+    interpreterHandle = createInterpreter(modelHandle, errorHandle, /* numThreads= */ -1);
     isMemoryAllocated = true;
   }
 
@@ -47,7 +47,20 @@ final class NativeInterpreterWrapper implements AutoCloseable {
     modelByteBuffer = mappedByteBuffer;
     errorHandle = createErrorReporter(ERROR_BUFFER_SIZE);
     modelHandle = createModelWithBuffer(modelByteBuffer, errorHandle);
-    interpreterHandle = createInterpreter(modelHandle, errorHandle);
+    interpreterHandle = createInterpreter(modelHandle, errorHandle, /* numThreads= */ -1);
+    isMemoryAllocated = true;
+  }
+
+  /**
+   * Initializes a {@code NativeInterpreterWrapper} with a {@code MappedByteBuffer} and specifies
+   * the number of inference threads. The MappedByteBuffer should not be modified after the
+   * construction of a {@code NativeInterpreterWrapper}.
+   */
+  NativeInterpreterWrapper(MappedByteBuffer mappedByteBuffer, int numThreads) {
+    modelByteBuffer = mappedByteBuffer;
+    errorHandle = createErrorReporter(ERROR_BUFFER_SIZE);
+    modelHandle = createModelWithBuffer(modelByteBuffer, errorHandle);
+    interpreterHandle = createInterpreter(modelHandle, errorHandle, numThreads);
     isMemoryAllocated = true;
   }
 
@@ -140,10 +153,6 @@ final class NativeInterpreterWrapper implements AutoCloseable {
     useNNAPI(interpreterHandle, useNNAPI);
   }
 
-  void setNumThreads(int numRecommendedThreads) {
-    numThreads(interpreterHandle, numRecommendedThreads);
-  }
-
   /** Gets index of an input given its name. */
   int getInputIndex(String name) {
     if (inputsIndexes == null) {
@@ -312,15 +321,13 @@ final class NativeInterpreterWrapper implements AutoCloseable {
 
   private static native void useNNAPI(long interpreterHandle, boolean state);
 
-  private static native void numThreads(long interpreterHandle, int numRecommendedThreads);
-
   private static native long createErrorReporter(int size);
 
   private static native long createModel(String modelPathOrBuffer, long errorHandle);
 
   private static native long createModelWithBuffer(MappedByteBuffer modelBuffer, long errorHandle);
 
-  private static native long createInterpreter(long modelHandle, long errorHandle);
+  private static native long createInterpreter(long modelHandle, long errorHandle, int numThreads);
 
   private static native void delete(long errorHandle, long modelHandle, long interpreterHandle);
 
index cc448b0..8442262 100644 (file)
@@ -14,7 +14,6 @@ limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h"
-
 namespace {
 
 const int kByteBufferValue = 999;
@@ -316,16 +315,6 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_useNNAPI(JNIEnv* env,
   interpreter->UseNNAPI(static_cast<bool>(state));
 }
 
-JNIEXPORT void JNICALL
-Java_org_tensorflow_lite_NativeInterpreterWrapper_numThreads(JNIEnv* env,
-                                                             jclass clazz,
-                                                             jlong handle,
-                                                             jint num_threads) {
-  tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
-  if (interpreter == nullptr) return;
-  interpreter->SetNumThreads(static_cast<int>(num_threads));
-}
-
 JNIEXPORT jlong JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_createErrorReporter(
     JNIEnv* env, jclass clazz, jint size) {
@@ -401,7 +390,8 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createModelWithBuffer(
 
 JNIEXPORT jlong JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_createInterpreter(
-    JNIEnv* env, jclass clazz, jlong model_handle, jlong error_handle) {
+    JNIEnv* env, jclass clazz, jlong model_handle, jlong error_handle,
+    jint num_threads) {
   tflite::FlatBufferModel* model = convertLongToModel(env, model_handle);
   if (model == nullptr) return 0;
   BufferErrorReporter* error_reporter =
@@ -409,8 +399,8 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createInterpreter(
   if (error_reporter == nullptr) return 0;
   auto resolver = ::tflite::CreateOpResolver();
   std::unique_ptr<tflite::Interpreter> interpreter;
-  TfLiteStatus status =
-      tflite::InterpreterBuilder(*model, *(resolver.get()))(&interpreter);
+  TfLiteStatus status = tflite::InterpreterBuilder(*model, *(resolver.get()))(
+      &interpreter, static_cast<int>(num_threads));
   if (status != kTfLiteOk) {
     throwException(env, kIllegalArgumentException,
                    "Cannot create interpreter: %s",
index fb76125..0e28a77 100644 (file)
@@ -72,17 +72,6 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_useNNAPI(JNIEnv* env,
 /*
  *  Class:     org_tensorflow_lite_NativeInterpreterWrapper
  *  Method:
- *  Signature: (JI)
- */
-JNIEXPORT void JNICALL
-Java_org_tensorflow_lite_NativeInterpreterWrapper_numThreads(JNIEnv* env,
-                                                             jclass clazz,
-                                                             jlong handle,
-                                                             jint num_threads);
-
-/*
- *  Class:     org_tensorflow_lite_NativeInterpreterWrapper
- *  Method:
  *  Signature: (I)J
  */
 JNIEXPORT jlong JNICALL
@@ -110,11 +99,12 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createModelWithBuffer(
 /*
  *  Class:     org_tensorflow_lite_NativeInterpreterWrapper
  *  Method:
- *  Signature: (JJ)J
+ *  Signature: (JJI)J
  */
 JNIEXPORT jlong JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_createInterpreter(
-    JNIEnv* env, jclass clazz, jlong model_handle, jlong error_handle);
+    JNIEnv* env, jclass clazz, jlong model_handle, jlong error_handle,
+    jint num_threads);
 
 /*
  *  Class:     org_tensorflow_lite_NativeInterpreterWrapper
index 3722e51..3aef0c3 100644 (file)
@@ -34,20 +34,6 @@ public class TestHelper {
   }
 
   /**
-   * Sets the number of threads for an {@code Interpreter}.
-   *
-   * @param interpreter an instance of {@code Interpreter}. If it is not initialized, an {@code
-   *     IllegalArgumentException} will be thrown.
-   * @param numRecommendedThreads an integer value indicating the number of recommended threads.
-   */
-  public static void setNumThreads(Interpreter interpreter, int numRecommendedThreads) {
-    if (interpreter != null && interpreter.wrapper != null) {
-      interpreter.wrapper.setNumThreads(numRecommendedThreads);
-    } else {
-      throw new IllegalArgumentException("Interpreter has not initialized; Failed to setUseNNAPI.");
-    }
-  }
-  /**
    * Gets the last inference duration in nanoseconds. It returns null if there is no previous
    * inference run or the last inference run failed.
    *
index f28d56a..f7daa6f 100644 (file)
@@ -759,6 +759,11 @@ TfLiteStatus InterpreterBuilder::ParseTensors(
 
 TfLiteStatus InterpreterBuilder::operator()(
     std::unique_ptr<Interpreter>* interpreter) {
+  return operator()(interpreter, /*num_threads=*/-1);
+}
+
+TfLiteStatus InterpreterBuilder::operator()(
+    std::unique_ptr<Interpreter>* interpreter, int num_threads) {
   if (!interpreter) {
     error_reporter_->Report(
         "Null output pointer passed to InterpreterBuilder.");
@@ -813,7 +818,8 @@ TfLiteStatus InterpreterBuilder::operator()(
   if ((**interpreter).AddTensors(tensors->Length()) != kTfLiteOk) {
     return cleanup_and_error();
   }
-
+  // Set num threads
+  (**interpreter).SetNumThreads(num_threads);
   // Parse inputs/outputs
   (**interpreter).SetInputs(FlatBufferIntArrayToVector(subgraph->inputs()));
   (**interpreter).SetOutputs(FlatBufferIntArrayToVector(subgraph->outputs()));
index 51a622a..0c77776 100644 (file)
@@ -154,6 +154,8 @@ class InterpreterBuilder {
   InterpreterBuilder(const InterpreterBuilder&) = delete;
   InterpreterBuilder& operator=(const InterpreterBuilder&) = delete;
   TfLiteStatus operator()(std::unique_ptr<Interpreter>* interpreter);
+  TfLiteStatus operator()(std::unique_ptr<Interpreter>* interpreter,
+                          int num_threads);
 
  private:
   TfLiteStatus BuildLocalIndexToRegistrationMapping();