/** An array to hold inference results, to be feed into Tensorflow Lite as outputs. */
private byte[][] labelProbArray = null;
+ /** multi-stage low pass filter * */
+ private float[][] filterLabelProbArray = null;
+
+ private static final int FILTER_STAGES = 3;
+ private static final float FILTER_FACTOR = 0.4f;
private PriorityQueue<Map.Entry<String, Float>> sortedLabels =
new PriorityQueue<>(
DIM_BATCH_SIZE * DIM_IMG_SIZE_X * DIM_IMG_SIZE_Y * DIM_PIXEL_SIZE);
imgData.order(ByteOrder.nativeOrder());
labelProbArray = new byte[1][labelList.size()];
+ filterLabelProbArray = new float[FILTER_STAGES][labelList.size()];
Log.d(TAG, "Created a Tensorflow Lite Image Classifier.");
}
tflite.run(imgData, labelProbArray);
long endTime = SystemClock.uptimeMillis();
Log.d(TAG, "Timecost to run model inference: " + Long.toString(endTime - startTime));
+
+ // Smooth the results across frames.
+ applyFilter();
+
+ // Print the results.
String textToShow = printTopKLabels();
textToShow = Long.toString(endTime - startTime) + "ms" + textToShow;
return textToShow;
}
+ void applyFilter() {
+ int numLabels = labelList.size();
+
+ // Low pass filter `labelProbArray` into the first stage of the filter.
+ for (int j = 0; j < numLabels; ++j) {
+ filterLabelProbArray[0][j] +=
+ FILTER_FACTOR * (labelProbArray[0][j] - filterLabelProbArray[0][j]);
+ }
+ // Low pass filter each stage into the next.
+ for (int i = 1; i < FILTER_STAGES; ++i) {
+ for (int j = 0; j < numLabels; ++j) {
+ filterLabelProbArray[i][j] +=
+ FILTER_FACTOR * (filterLabelProbArray[i - 1][j] - filterLabelProbArray[i][j]);
+ }
+ }
+
+ // Copy the last stage filter output back to `labelProbArray`.
+ for (int j = 0; j < numLabels; ++j) {
+ labelProbArray[0][j] = (byte)filterLabelProbArray[FILTER_STAGES - 1][j];
+ }
+ }
+
/** Closes tflite to release resources. */
public void close() {
tflite.close();
final int size = sortedLabels.size();
for (int i = 0; i < size; ++i) {
Map.Entry<String, Float> label = sortedLabels.poll();
- textToShow = "\n" + label.getKey() + ":" + Float.toString(label.getValue()) + textToShow;
+ textToShow = String.format("\n%s: %4.2f", label.getKey(), label.getValue()) + textToShow;
}
return textToShow;
}