Add low pass filtering to the classification results.
authorMark Daoust <markdaoust@google.com>
Wed, 3 Jan 2018 14:59:29 +0000 (06:59 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 3 Jan 2018 15:03:25 +0000 (07:03 -0800)
This makes the output much more consistent.

Also round the probabilities to 2 decimal places.

PiperOrigin-RevId: 180666095

tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java

index e7bad4637041d003c1e507d81c0c30404c587653..e44c5ae6b48eda187079dd3a0a1bc563276d816e 100644 (file)
@@ -73,6 +73,11 @@ public class ImageClassifier {
 
   /** An array to hold inference results, to be feed into Tensorflow Lite as outputs. */
   private byte[][] labelProbArray = null;
+  /** multi-stage low pass filter * */
+  private float[][] filterLabelProbArray = null;
+
+  private static final int FILTER_STAGES = 3;
+  private static final float FILTER_FACTOR = 0.4f;
 
   private PriorityQueue<Map.Entry<String, Float>> sortedLabels =
       new PriorityQueue<>(
@@ -93,6 +98,7 @@ public class ImageClassifier {
             DIM_BATCH_SIZE * DIM_IMG_SIZE_X * DIM_IMG_SIZE_Y * DIM_PIXEL_SIZE);
     imgData.order(ByteOrder.nativeOrder());
     labelProbArray = new byte[1][labelList.size()];
+    filterLabelProbArray = new float[FILTER_STAGES][labelList.size()];
     Log.d(TAG, "Created a Tensorflow Lite Image Classifier.");
   }
 
@@ -108,11 +114,38 @@ public class ImageClassifier {
     tflite.run(imgData, labelProbArray);
     long endTime = SystemClock.uptimeMillis();
     Log.d(TAG, "Timecost to run model inference: " + Long.toString(endTime - startTime));
+
+    // Smooth the results across frames.
+    applyFilter();
+
+    // Print the results.
     String textToShow = printTopKLabels();
     textToShow = Long.toString(endTime - startTime) + "ms" + textToShow;
     return textToShow;
   }
 
+  void applyFilter() {
+    int numLabels = labelList.size();
+
+    // Low pass filter `labelProbArray` into the first stage of the filter.
+    for (int j = 0; j < numLabels; ++j) {
+      filterLabelProbArray[0][j] +=
+          FILTER_FACTOR * (labelProbArray[0][j] - filterLabelProbArray[0][j]);
+    }
+    // Low pass filter each stage into the next.
+    for (int i = 1; i < FILTER_STAGES; ++i) {
+      for (int j = 0; j < numLabels; ++j) {
+        filterLabelProbArray[i][j] +=
+            FILTER_FACTOR * (filterLabelProbArray[i - 1][j] - filterLabelProbArray[i][j]);
+      }
+    }
+
+    // Copy the last stage filter output back to `labelProbArray`.
+    for (int j = 0; j < numLabels; ++j) {
+      labelProbArray[0][j] = (byte)filterLabelProbArray[FILTER_STAGES - 1][j];
+    }
+  }
+
   /** Closes tflite to release resources. */
   public void close() {
     tflite.close();
@@ -177,7 +210,7 @@ public class ImageClassifier {
     final int size = sortedLabels.size();
     for (int i = 0; i < size; ++i) {
       Map.Entry<String, Float> label = sortedLabels.poll();
-      textToShow = "\n" + label.getKey() + ":" + Float.toString(label.getValue()) + textToShow;
+      textToShow = String.format("\n%s: %4.2f", label.getKey(), label.getValue()) + textToShow;
     }
     return textToShow;
   }