int index = interpreter.getOutputIndex("MobilenetV1/Predictions/Softmax");
assertThat(index).isEqualTo(0);
}
+
+ @Test
+ public void testTurnOffNNAPI() throws Exception {
+ Path path = MODEL_FILE.toPath();
+ FileChannel fileChannel =
+ (FileChannel) Files.newByteChannel(path, EnumSet.of(StandardOpenOption.READ));
+ MappedByteBuffer mappedByteBuffer =
+ fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileChannel.size());
+ Interpreter interpreter = new Interpreter(mappedByteBuffer);
+ interpreter.setUseNNAPI(true);
+ float[] oneD = {1.23f, 6.54f, 7.81f};
+ float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
+ float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
+ float[][][][] fourD = {threeD, threeD};
+ float[][][][] parsedOutputs = new float[2][8][8][3];
+ interpreter.run(fourD, parsedOutputs);
+ float[] outputOneD = parsedOutputs[0][0][0];
+ float[] expected = {3.69f, 19.62f, 23.43f};
+ assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
+ interpreter.setUseNNAPI(false);
+ interpreter.run(fourD, parsedOutputs);
+ outputOneD = parsedOutputs[0][0][0];
+ assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
+ interpreter.close();
+ fileChannel.close();
+ }
+
+ @Test
+ public void testTurnOnNNAPI() throws Exception {
+ Path path = MODEL_FILE.toPath();
+ FileChannel fileChannel =
+ (FileChannel) Files.newByteChannel(path, EnumSet.of(StandardOpenOption.READ));
+ MappedByteBuffer mappedByteBuffer =
+ fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileChannel.size());
+ Interpreter interpreter = new Interpreter(mappedByteBuffer);
+ interpreter.setUseNNAPI(true);
+ float[] oneD = {1.23f, 6.54f, 7.81f};
+ float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
+ float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
+ float[][][][] fourD = {threeD, threeD};
+ float[][][][] parsedOutputs = new float[2][8][8][3];
+ interpreter.run(fourD, parsedOutputs);
+ float[] outputOneD = parsedOutputs[0][0][0];
+ float[] expected = {3.69f, 19.62f, 23.43f};
+ assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
+ interpreter.close();
+ fileChannel.close();
+ }
}