Adds setUseNNAPI to Interpreter.java, to enable develoeprs turn on & off NNAPI.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 2 Mar 2018 23:24:33 +0000 (15:24 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 2 Mar 2018 23:28:56 +0000 (15:28 -0800)
PiperOrigin-RevId: 187677765

tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java

index b071cda..9e47e92 100644 (file)
@@ -167,7 +167,6 @@ public final class Interpreter implements AutoCloseable {
     return wrapper.getOutputIndex(opName);
   }
 
-
   /**
    * Returns native inference timing.
    * <p>IllegalArgumentException will be thrown if the model is not initialized by the
@@ -180,6 +179,15 @@ public final class Interpreter implements AutoCloseable {
     return wrapper.getLastNativeInferenceDurationNanoseconds();
   }
 
+  /** Turns on/off Android NNAPI for hardware acceleration when it is available. */
+  public void setUseNNAPI(boolean useNNAPI) {
+    if (wrapper != null) {
+      wrapper.setUseNNAPI(useNNAPI);
+    } else {
+      throw new IllegalStateException("NativeInterpreterWrapper has already been closed.");
+    }
+  }
+
   /** Release resources associated with the {@code Interpreter}. */
   @Override
   public void close() {
index 424b3de..61d6c35 100644 (file)
@@ -218,4 +218,52 @@ public final class InterpreterTest {
     int index = interpreter.getOutputIndex("MobilenetV1/Predictions/Softmax");
     assertThat(index).isEqualTo(0);
   }
+
+  @Test
+  public void testTurnOffNNAPI() throws Exception {
+    Path path = MODEL_FILE.toPath();
+    FileChannel fileChannel =
+        (FileChannel) Files.newByteChannel(path, EnumSet.of(StandardOpenOption.READ));
+    MappedByteBuffer mappedByteBuffer =
+        fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileChannel.size());
+    Interpreter interpreter = new Interpreter(mappedByteBuffer);
+    interpreter.setUseNNAPI(true);
+    float[] oneD = {1.23f, 6.54f, 7.81f};
+    float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
+    float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
+    float[][][][] fourD = {threeD, threeD};
+    float[][][][] parsedOutputs = new float[2][8][8][3];
+    interpreter.run(fourD, parsedOutputs);
+    float[] outputOneD = parsedOutputs[0][0][0];
+    float[] expected = {3.69f, 19.62f, 23.43f};
+    assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
+    interpreter.setUseNNAPI(false);
+    interpreter.run(fourD, parsedOutputs);
+    outputOneD = parsedOutputs[0][0][0];
+    assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
+    interpreter.close();
+    fileChannel.close();
+  }
+
+  @Test
+  public void testTurnOnNNAPI() throws Exception {
+    Path path = MODEL_FILE.toPath();
+    FileChannel fileChannel =
+        (FileChannel) Files.newByteChannel(path, EnumSet.of(StandardOpenOption.READ));
+    MappedByteBuffer mappedByteBuffer =
+        fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileChannel.size());
+    Interpreter interpreter = new Interpreter(mappedByteBuffer);
+    interpreter.setUseNNAPI(true);
+    float[] oneD = {1.23f, 6.54f, 7.81f};
+    float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
+    float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
+    float[][][][] fourD = {threeD, threeD};
+    float[][][][] parsedOutputs = new float[2][8][8][3];
+    interpreter.run(fourD, parsedOutputs);
+    float[] outputOneD = parsedOutputs[0][0][0];
+    float[] expected = {3.69f, 19.62f, 23.43f};
+    assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
+    interpreter.close();
+    fileChannel.close();
+  }
 }