Internal change.
authorShashi Shekhar <shashishekhar@google.com>
Fri, 30 Mar 2018 18:21:31 +0000 (11:21 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 30 Mar 2018 18:24:25 +0000 (11:24 -0700)
PiperOrigin-RevId: 191090993

tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java

index 14f461f..a33959d 100644 (file)
@@ -68,6 +68,19 @@ public final class Interpreter implements AutoCloseable {
   }
 
   /**
+   * Initializes a {@code Interpreter} and specifies the number of threads used for inference.
+   *
+   * @param modelFile: a file of a pre-trained TF Lite model
+   * @param numThreads: number of threads to use for inference
+   */
+  public Interpreter(@NonNull File modelFile, int numThreads) {
+    if (modelFile == null) {
+      return;
+    }
+    wrapper = new NativeInterpreterWrapper(modelFile.getAbsolutePath(), numThreads);
+  }
+
+  /**
    * Initializes a {@code Interpreter} with a {@code MappedByteBuffer} to the model file.
    *
    * <p>The {@code MappedByteBuffer} should remain unchanged after the construction of a {@code
index dbf8f8f..fc8187a 100644 (file)
@@ -32,9 +32,13 @@ import java.util.Map;
 final class NativeInterpreterWrapper implements AutoCloseable {
 
   NativeInterpreterWrapper(String modelPath) {
+    this(modelPath, /* numThreads= */ -1);
+  }
+
+  NativeInterpreterWrapper(String modelPath, int numThreads) {
     errorHandle = createErrorReporter(ERROR_BUFFER_SIZE);
     modelHandle = createModel(modelPath, errorHandle);
-    interpreterHandle = createInterpreter(modelHandle, errorHandle, /* numThreads= */ -1);
+    interpreterHandle = createInterpreter(modelHandle, errorHandle, numThreads);
     isMemoryAllocated = true;
   }
 
@@ -44,11 +48,7 @@ final class NativeInterpreterWrapper implements AutoCloseable {
    * NativeInterpreterWrapper}.
    */
   NativeInterpreterWrapper(MappedByteBuffer mappedByteBuffer) {
-    modelByteBuffer = mappedByteBuffer;
-    errorHandle = createErrorReporter(ERROR_BUFFER_SIZE);
-    modelHandle = createModelWithBuffer(modelByteBuffer, errorHandle);
-    interpreterHandle = createInterpreter(modelHandle, errorHandle, /* numThreads= */ -1);
-    isMemoryAllocated = true;
+    this(mappedByteBuffer, /* numThreads= */ -1);
   }
 
   /**