}
/**
+ * Initializes an {@code Interpreter} with a {@code MappedByteBuffer} of the model file and
+ * specifies the number of threads used for inference.
+ *
+ * <p>The {@code MappedByteBuffer} should remain unchanged after the construction of an {@code
+ * Interpreter}.
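+ *
+ * <p>For example, a caller might map the model file and pass it here, as in the sketch below.
+ * The file path and thread count are illustrative only, not part of the API:
+ *
+ * <pre>{@code
+ * try (FileInputStream input = new FileInputStream("/path/to/model.tflite");
+ *     FileChannel channel = input.getChannel()) {
+ *   MappedByteBuffer model =
+ *       channel.map(FileChannel.MapMode.READ_ONLY, 0, channel.size());
+ *   Interpreter interpreter = new Interpreter(model, /* numThreads= */ 4);
+ * }
+ * }</pre>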
+ */
+ public Interpreter(@NonNull MappedByteBuffer mappedByteBuffer, int numThreads) {
+ wrapper = new NativeInterpreterWrapper(mappedByteBuffer, numThreads);
+ }
+
+ /**
* Runs model inference if the model takes only one input, and provides only one output.
*
* <p>Warning: The API runs much faster if {@link ByteBuffer} is used as input data type. Please
NativeInterpreterWrapper(String modelPath) {
errorHandle = createErrorReporter(ERROR_BUFFER_SIZE);
modelHandle = createModel(modelPath, errorHandle);
- interpreterHandle = createInterpreter(modelHandle, errorHandle);
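+ // A numThreads value of -1 requests the runtime's default thread count.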
+ interpreterHandle = createInterpreter(modelHandle, errorHandle, /* numThreads= */ -1);
isMemoryAllocated = true;
}
modelByteBuffer = mappedByteBuffer;
errorHandle = createErrorReporter(ERROR_BUFFER_SIZE);
modelHandle = createModelWithBuffer(modelByteBuffer, errorHandle);
- interpreterHandle = createInterpreter(modelHandle, errorHandle);
+ interpreterHandle = createInterpreter(modelHandle, errorHandle, /* numThreads= */ -1);
+ isMemoryAllocated = true;
+ }
+
+ /**
+ * Initializes a {@code NativeInterpreterWrapper} with a {@code MappedByteBuffer} and specifies
+ * the number of inference threads; a {@code numThreads} value of -1 requests the runtime's
+ * default thread count, as the other constructors do. The {@code MappedByteBuffer} should not
+ * be modified after the construction of a {@code NativeInterpreterWrapper}.
+ */
+ NativeInterpreterWrapper(MappedByteBuffer mappedByteBuffer, int numThreads) {
+ modelByteBuffer = mappedByteBuffer;
+ errorHandle = createErrorReporter(ERROR_BUFFER_SIZE);
+ modelHandle = createModelWithBuffer(modelByteBuffer, errorHandle);
+ interpreterHandle = createInterpreter(modelHandle, errorHandle, numThreads);
isMemoryAllocated = true;
}
useNNAPI(interpreterHandle, useNNAPI);
}
- void setNumThreads(int numRecommendedThreads) {
- numThreads(interpreterHandle, numRecommendedThreads);
- }
-
/** Gets index of an input given its name. */
int getInputIndex(String name) {
if (inputsIndexes == null) {
private static native void useNNAPI(long interpreterHandle, boolean state);
- private static native void numThreads(long interpreterHandle, int numRecommendedThreads);
-
private static native long createErrorReporter(int size);
private static native long createModel(String modelPathOrBuffer, long errorHandle);
private static native long createModelWithBuffer(MappedByteBuffer modelBuffer, long errorHandle);
- private static native long createInterpreter(long modelHandle, long errorHandle);
+ private static native long createInterpreter(long modelHandle, long errorHandle, int numThreads);
private static native void delete(long errorHandle, long modelHandle, long interpreterHandle);
==============================================================================*/
#include "tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h"
-
namespace {
const int kByteBufferValue = 999;
interpreter->UseNNAPI(static_cast<bool>(state));
}
-JNIEXPORT void JNICALL
-Java_org_tensorflow_lite_NativeInterpreterWrapper_numThreads(JNIEnv* env,
- jclass clazz,
- jlong handle,
- jint num_threads) {
- tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
- if (interpreter == nullptr) return;
- interpreter->SetNumThreads(static_cast<int>(num_threads));
-}
-
JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_createErrorReporter(
JNIEnv* env, jclass clazz, jint size) {
JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_createInterpreter(
- JNIEnv* env, jclass clazz, jlong model_handle, jlong error_handle) {
+ JNIEnv* env, jclass clazz, jlong model_handle, jlong error_handle,
+ jint num_threads) {
tflite::FlatBufferModel* model = convertLongToModel(env, model_handle);
if (model == nullptr) return 0;
BufferErrorReporter* error_reporter =
if (error_reporter == nullptr) return 0;
auto resolver = ::tflite::CreateOpResolver();
std::unique_ptr<tflite::Interpreter> interpreter;
- TfLiteStatus status =
- tflite::InterpreterBuilder(*model, *(resolver.get()))(&interpreter);
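+ // Build the interpreter, forwarding the requested thread count to the
+ // two-argument InterpreterBuilder overload.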
+ TfLiteStatus status = tflite::InterpreterBuilder(*model, *(resolver.get()))(
+ &interpreter, static_cast<int>(num_threads));
if (status != kTfLiteOk) {
throwException(env, kIllegalArgumentException,
"Cannot create interpreter: %s",
/*
* Class: org_tensorflow_lite_NativeInterpreterWrapper
* Method:
- * Signature: (JI)
- */
-JNIEXPORT void JNICALL
-Java_org_tensorflow_lite_NativeInterpreterWrapper_numThreads(JNIEnv* env,
- jclass clazz,
- jlong handle,
- jint num_threads);
-
-/*
- * Class: org_tensorflow_lite_NativeInterpreterWrapper
- * Method:
* Signature: (I)J
*/
JNIEXPORT jlong JNICALL
/*
* Class: org_tensorflow_lite_NativeInterpreterWrapper
* Method:
- * Signature: (JJ)J
+ * Signature: (JJI)J
*/
JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_createInterpreter(
- JNIEnv* env, jclass clazz, jlong model_handle, jlong error_handle);
+ JNIEnv* env, jclass clazz, jlong model_handle, jlong error_handle,
+ jint num_threads);
/*
* Class: org_tensorflow_lite_NativeInterpreterWrapper
}
/**
- * Sets the number of threads for an {@code Interpreter}.
- *
- * @param interpreter an instance of {@code Interpreter}. If it is not initialized, an {@code
- * IllegalArgumentException} will be thrown.
- * @param numRecommendedThreads an integer value indicating the number of recommended threads.
- */
- public static void setNumThreads(Interpreter interpreter, int numRecommendedThreads) {
- if (interpreter != null && interpreter.wrapper != null) {
- interpreter.wrapper.setNumThreads(numRecommendedThreads);
- } else {
- throw new IllegalArgumentException("Interpreter has not initialized; Failed to setUseNNAPI.");
- }
- }
- /**
* Gets the last inference duration in nanoseconds. It returns null if there is no previous
* inference run or the last inference run failed.
*
TfLiteStatus InterpreterBuilder::operator()(
std::unique_ptr<Interpreter>* interpreter) {
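+ // Delegate to the two-argument overload; -1 keeps the default thread count.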
+ return operator()(interpreter, /*num_threads=*/-1);
+}
+
+TfLiteStatus InterpreterBuilder::operator()(
+ std::unique_ptr<Interpreter>* interpreter, int num_threads) {
if (!interpreter) {
error_reporter_->Report(
"Null output pointer passed to InterpreterBuilder.");
if ((**interpreter).AddTensors(tensors->Length()) != kTfLiteOk) {
return cleanup_and_error();
}
-
+ // Record the requested thread count before the rest of the graph is set up.
+ (**interpreter).SetNumThreads(num_threads);
// Parse inputs/outputs
(**interpreter).SetInputs(FlatBufferIntArrayToVector(subgraph->inputs()));
(**interpreter).SetOutputs(FlatBufferIntArrayToVector(subgraph->outputs()));
InterpreterBuilder(const InterpreterBuilder&) = delete;
InterpreterBuilder& operator=(const InterpreterBuilder&) = delete;
TfLiteStatus operator()(std::unique_ptr<Interpreter>* interpreter);
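+ // Same as above, but also sets the number of threads the interpreter may
+ // use; -1 (the value the single-argument overload passes) requests the
+ // default.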
+ TfLiteStatus operator()(std::unique_ptr<Interpreter>* interpreter,
+ int num_threads);
private:
TfLiteStatus BuildLocalIndexToRegistrationMapping();