From fff901507e932188932fe21ae56c55e4aba5ae52 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Fri, 16 Mar 2018 11:45:42 -0700
Subject: [PATCH] Set the number of threads in the Java interpreter constructor
 so that Conv kernels can be selected properly. Remove setNumThreads from the
 Java API, as its behavior is ambiguous.

PiperOrigin-RevId: 189370770
---
 .../main/java/org/tensorflow/lite/Interpreter.java | 11 ++++++++++
 .../tensorflow/lite/NativeInterpreterWrapper.java  | 25 ++++++++++++++--------
 .../main/native/nativeinterpreterwrapper_jni.cc    | 18 ++++------------
 .../src/main/native/nativeinterpreterwrapper_jni.h | 16 +++-----------
 .../java/org/tensorflow/lite/TestHelper.java       | 14 ------------
 tensorflow/contrib/lite/model.cc                   |  8 ++++++-
 tensorflow/contrib/lite/model.h                    |  2 ++
 7 files changed, 43 insertions(+), 51 deletions(-)

diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
index cc17b49..14f461f 100644
--- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
+++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
@@ -78,6 +78,17 @@ public final class Interpreter implements AutoCloseable {
   }
 
   /**
+   * Initializes an {@code Interpreter} with a {@code MappedByteBuffer} to the model file and
+   * specifies the number of threads used for inference.
+   *
+   * <p>The {@code MappedByteBuffer} should remain unchanged after the construction of an {@code
+   * Interpreter}.
+   */
+  public Interpreter(@NonNull MappedByteBuffer mappedByteBuffer, int numThreads) {
+    wrapper = new NativeInterpreterWrapper(mappedByteBuffer, numThreads);
+  }
+
+  /**
    * Runs model inference if the model takes only one input, and provides only one output.
    *
    * <p>Warning: The API runs much faster if {@link ByteBuffer} is used as input data type. Please
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
index 518e8b3..dbf8f8f 100644
--- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
+++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
@@ -34,7 +34,7 @@ final class NativeInterpreterWrapper implements AutoCloseable {
   NativeInterpreterWrapper(String modelPath) {
     errorHandle = createErrorReporter(ERROR_BUFFER_SIZE);
     modelHandle = createModel(modelPath, errorHandle);
-    interpreterHandle = createInterpreter(modelHandle, errorHandle);
+    interpreterHandle = createInterpreter(modelHandle, errorHandle, /* numThreads= */ -1);
     isMemoryAllocated = true;
   }
 
@@ -47,7 +47,20 @@ final class NativeInterpreterWrapper implements AutoCloseable {
     modelByteBuffer = mappedByteBuffer;
     errorHandle = createErrorReporter(ERROR_BUFFER_SIZE);
     modelHandle = createModelWithBuffer(modelByteBuffer, errorHandle);
-    interpreterHandle = createInterpreter(modelHandle, errorHandle);
+    interpreterHandle = createInterpreter(modelHandle, errorHandle, /* numThreads= */ -1);
+    isMemoryAllocated = true;
+  }
+
+  /**
+   * Initializes a {@code NativeInterpreterWrapper} with a {@code MappedByteBuffer} and specifies
+   * the number of inference threads. The MappedByteBuffer should not be modified after the
+   * construction of a {@code NativeInterpreterWrapper}.
+   */
+  NativeInterpreterWrapper(MappedByteBuffer mappedByteBuffer, int numThreads) {
+    modelByteBuffer = mappedByteBuffer;
+    errorHandle = createErrorReporter(ERROR_BUFFER_SIZE);
+    modelHandle = createModelWithBuffer(modelByteBuffer, errorHandle);
+    interpreterHandle = createInterpreter(modelHandle, errorHandle, numThreads);
     isMemoryAllocated = true;
   }
 
@@ -140,10 +153,6 @@ final class NativeInterpreterWrapper implements AutoCloseable {
     useNNAPI(interpreterHandle, useNNAPI);
   }
 
-  void setNumThreads(int numRecommendedThreads) {
-    numThreads(interpreterHandle, numRecommendedThreads);
-  }
-
   /** Gets index of an input given its name. */
   int getInputIndex(String name) {
     if (inputsIndexes == null) {
@@ -312,15 +321,13 @@ final class NativeInterpreterWrapper implements AutoCloseable {
 
   private static native void useNNAPI(long interpreterHandle, boolean state);
 
-  private static native void numThreads(long interpreterHandle, int numRecommendedThreads);
-
   private static native long createErrorReporter(int size);
 
   private static native long createModel(String modelPathOrBuffer, long errorHandle);
 
   private static native long createModelWithBuffer(MappedByteBuffer modelBuffer, long errorHandle);
 
-  private static native long createInterpreter(long modelHandle, long errorHandle);
+  private static native long createInterpreter(long modelHandle, long errorHandle, int numThreads);
 
   private static native void delete(long errorHandle, long modelHandle, long interpreterHandle);
 
diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
index cc448b0..8442262 100644
--- a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
+++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
@@ -14,7 +14,6 @@ limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h"
-
 namespace {
 
 const int kByteBufferValue = 999;
@@ -316,16 +315,6 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_useNNAPI(JNIEnv* env,
   interpreter->UseNNAPI(static_cast<bool>(state));
 }
 
-JNIEXPORT void JNICALL
-Java_org_tensorflow_lite_NativeInterpreterWrapper_numThreads(JNIEnv* env,
-                                                             jclass clazz,
-                                                             jlong handle,
-                                                             jint num_threads) {
-  tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
-  if (interpreter == nullptr) return;
-  interpreter->SetNumThreads(static_cast<int>(num_threads));
-}
-
 JNIEXPORT jlong JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_createErrorReporter(
     JNIEnv* env, jclass clazz, jint size) {
@@ -401,7 +390,8 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createModelWithBuffer(
 
 JNIEXPORT jlong JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_createInterpreter(
-    JNIEnv* env, jclass clazz, jlong model_handle, jlong error_handle) {
+    JNIEnv* env, jclass clazz, jlong model_handle, jlong error_handle,
+    jint num_threads) {
   tflite::FlatBufferModel* model = convertLongToModel(env, model_handle);
   if (model == nullptr) return 0;
   BufferErrorReporter* error_reporter =
@@ -409,8 +399,8 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createInterpreter(
   if (error_reporter == nullptr) return 0;
   auto resolver = ::tflite::CreateOpResolver();
   std::unique_ptr<tflite::Interpreter> interpreter;
-  TfLiteStatus status =
-      tflite::InterpreterBuilder(*model, *(resolver.get()))(&interpreter);
+  TfLiteStatus status = tflite::InterpreterBuilder(*model, *(resolver.get()))(
+      &interpreter, static_cast<int>(num_threads));
   if (status != kTfLiteOk) {
     throwException(env, kIllegalArgumentException,
                    "Cannot create interpreter: %s",
diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h
index fb76125..0e28a77 100644
--- a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h
+++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h
@@ -72,17 +72,6 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_useNNAPI(JNIEnv* env,
 /*
  *  Class:     org_tensorflow_lite_NativeInterpreterWrapper
  *  Method:
- *  Signature: (JI)
- */
-JNIEXPORT void JNICALL
-Java_org_tensorflow_lite_NativeInterpreterWrapper_numThreads(JNIEnv* env,
-                                                             jclass clazz,
-                                                             jlong handle,
-                                                             jint num_threads);
-
-/*
- *  Class:     org_tensorflow_lite_NativeInterpreterWrapper
- *  Method:
  *  Signature: (I)J
  */
 JNIEXPORT jlong JNICALL
@@ -110,11 +99,12 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createModelWithBuffer(
 /*
  *  Class:     org_tensorflow_lite_NativeInterpreterWrapper
  *  Method:
- *  Signature: (JJ)J
+ *  Signature: (JJI)J
  */
 JNIEXPORT jlong JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_createInterpreter(
-    JNIEnv* env, jclass clazz, jlong model_handle, jlong error_handle);
+    JNIEnv* env, jclass clazz, jlong model_handle, jlong error_handle,
+    jint num_threads);
 
 /*
  *  Class:     org_tensorflow_lite_NativeInterpreterWrapper
diff --git a/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java b/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java
index 3722e51..3aef0c3 100644
--- a/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java
+++ b/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java
@@ -34,20 +34,6 @@ public class TestHelper {
   }
 
   /**
-   * Sets the number of threads for an {@code Interpreter}.
-   *
-   * @param interpreter an instance of {@code Interpreter}. If it is not initialized, an {@code
-   *     IllegalArgumentException} will be thrown.
-   * @param numRecommendedThreads an integer value indicating the number of recommended threads.
-   */
-  public static void setNumThreads(Interpreter interpreter, int numRecommendedThreads) {
-    if (interpreter != null && interpreter.wrapper != null) {
-      interpreter.wrapper.setNumThreads(numRecommendedThreads);
-    } else {
-      throw new IllegalArgumentException("Interpreter has not initialized; Failed to setUseNNAPI.");
-    }
-  }
-  /**
    * Gets the last inference duration in nanoseconds. It returns null if there is no previous
    * inference run or the last inference run failed.
    *
diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc
index f28d56a..f7daa6f 100644
--- a/tensorflow/contrib/lite/model.cc
+++ b/tensorflow/contrib/lite/model.cc
@@ -759,6 +759,11 @@ TfLiteStatus InterpreterBuilder::ParseTensors(
 
 TfLiteStatus InterpreterBuilder::operator()(
     std::unique_ptr<Interpreter>* interpreter) {
+  return operator()(interpreter, /*num_threads=*/-1);
+}
+
+TfLiteStatus InterpreterBuilder::operator()(
+    std::unique_ptr<Interpreter>* interpreter, int num_threads) {
   if (!interpreter) {
     error_reporter_->Report(
         "Null output pointer passed to InterpreterBuilder.");
@@ -813,7 +818,8 @@ TfLiteStatus InterpreterBuilder::operator()(
   if ((**interpreter).AddTensors(tensors->Length()) != kTfLiteOk) {
     return cleanup_and_error();
   }
-
+  // Set the number of threads before the nodes are parsed.
+  (**interpreter).SetNumThreads(num_threads);
   // Parse inputs/outputs
   (**interpreter).SetInputs(FlatBufferIntArrayToVector(subgraph->inputs()));
   (**interpreter).SetOutputs(FlatBufferIntArrayToVector(subgraph->outputs()));
diff --git a/tensorflow/contrib/lite/model.h b/tensorflow/contrib/lite/model.h
index 51a622a..0c77776 100644
--- a/tensorflow/contrib/lite/model.h
+++ b/tensorflow/contrib/lite/model.h
@@ -154,6 +154,8 @@ class InterpreterBuilder {
   InterpreterBuilder(const InterpreterBuilder&) = delete;
   InterpreterBuilder& operator=(const InterpreterBuilder&) = delete;
   TfLiteStatus operator()(std::unique_ptr<Interpreter>* interpreter);
+  TfLiteStatus operator()(std::unique_ptr<Interpreter>* interpreter,
+                          int num_threads);
 
  private:
   TfLiteStatus BuildLocalIndexToRegistrationMapping();
-- 
2.7.4
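
With this change, the number of threads used by a Java Interpreter is fixed at
construction time; there is no longer a setter to change it afterwards. A
minimal usage sketch of the new two-argument constructor, assuming a
placeholder model file "model.tflite" and placeholder tensor shapes (neither
is prescribed by the patch):

    import java.io.FileInputStream;
    import java.nio.MappedByteBuffer;
    import java.nio.channels.FileChannel;
    import org.tensorflow.lite.Interpreter;

    public class NumThreadsExample {
      public static void main(String[] args) throws Exception {
        try (FileInputStream stream = new FileInputStream("model.tflite");
            FileChannel channel = stream.getChannel()) {
          // Map the model file; the buffer must stay unchanged while the
          // Interpreter is alive.
          MappedByteBuffer model =
              channel.map(FileChannel.MapMode.READ_ONLY, 0, channel.size());
          // The thread count is handed to the native InterpreterBuilder up
          // front, so thread-count-dependent kernels (e.g. Conv) can be
          // selected correctly; the one-argument constructor passes -1.
          try (Interpreter interpreter = new Interpreter(model, 2)) {
            float[][] input = new float[1][4];   // placeholder input shape
            float[][] output = new float[1][2];  // placeholder output shape
            interpreter.run(input, output);
          }
        }
      }
    }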
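On the native side, the same value flows through the new two-argument
InterpreterBuilder::operator() declared in model.h. A C++ sketch of calling
the builder directly, assuming the standard BuiltinOpResolver from
kernels/register.h and the same placeholder model path:

    #include <memory>

    #include "tensorflow/contrib/lite/interpreter.h"
    #include "tensorflow/contrib/lite/kernels/register.h"
    #include "tensorflow/contrib/lite/model.h"

    int main() {
      auto model = tflite::FlatBufferModel::BuildFromFile("model.tflite");
      if (!model) return 1;
      tflite::ops::builtin::BuiltinOpResolver resolver;
      std::unique_ptr<tflite::Interpreter> interpreter;
      // SetNumThreads now runs inside the builder, before nodes are parsed,
      // so kernels are chosen with the final thread count already known.
      // Passing /*num_threads=*/-1 keeps the default behavior.
      if (tflite::InterpreterBuilder(*model, resolver)(
              &interpreter, /*num_threads=*/4) != kTfLiteOk) {
        return 1;
      }
      if (interpreter->AllocateTensors() != kTfLiteOk) return 1;
      return 0;
    }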