Add score filtering to tf.image.non_max_suppression.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 14 May 2018 21:32:03 +0000 (14:32 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 14 May 2018 21:34:48 +0000 (14:34 -0700)
PiperOrigin-RevId: 196567928

tensorflow/core/api_def/base_api/api_def_NonMaxSuppressionV3.pbtxt [new file with mode: 0644]
tensorflow/core/api_def/python_api/api_def_NonMaxSuppressionV3.pbtxt [new file with mode: 0644]
tensorflow/core/kernels/non_max_suppression_op.cc
tensorflow/core/kernels/non_max_suppression_op.h
tensorflow/core/kernels/non_max_suppression_op_test.cc
tensorflow/core/ops/image_ops.cc
tensorflow/python/ops/image_ops_impl.py
tensorflow/tools/api/golden/tensorflow.image.pbtxt

diff --git a/tensorflow/core/api_def/base_api/api_def_NonMaxSuppressionV3.pbtxt b/tensorflow/core/api_def/base_api/api_def_NonMaxSuppressionV3.pbtxt
new file mode 100644 (file)
index 0000000..25ec87e
--- /dev/null
@@ -0,0 +1,64 @@
+op {
+  graph_op_name: "NonMaxSuppressionV3"
+  in_arg {
+    name: "boxes"
+    description: <<END
+A 2-D float tensor of shape `[num_boxes, 4]`.
+END
+  }
+  in_arg {
+    name: "scores"
+    description: <<END
+A 1-D float tensor of shape `[num_boxes]` representing a single
+score corresponding to each box (each row of boxes).
+END
+  }
+  in_arg {
+    name: "max_output_size"
+    description: <<END
+A scalar integer tensor representing the maximum number of
+boxes to be selected by non max suppression.
+END
+  }
+  in_arg {
+    name: "iou_threshold"
+    description: <<END
+A 0-D float tensor representing the threshold for deciding whether
+boxes overlap too much with respect to IOU.
+END
+  }
+  in_arg {
+    name: "score_threshold"
+    description: <<END
+A 0-D float tensor representing the threshold for deciding when to remove
+boxes based on score.
+END
+  }
+  out_arg {
+    name: "selected_indices"
+    description: <<END
+A 1-D integer tensor of shape `[M]` representing the selected
+indices from the boxes tensor, where `M <= max_output_size`.
+END
+  }
+  summary: "Greedily selects a subset of bounding boxes in descending order of score,"
+  description: <<END
+pruning away boxes that have high intersection-over-union (IOU) overlap
+with previously selected boxes.  Bounding boxes with score less than
+`score_threshold` are removed.  Bounding boxes are supplied as
+[y1, x1, y2, x2], where (y1, x1) and (y2, x2) are the coordinates of any
+diagonal pair of box corners and the coordinates can be provided as normalized
+(i.e., lying in the interval [0, 1]) or absolute.  Note that this algorithm
+is agnostic to where the origin is in the coordinate system and more
+generally is invariant to orthogonal transformations and translations
+of the coordinate system; thus translating or reflections of the coordinate
+system result in the same boxes being selected by the algorithm.
+The output of this operation is a set of integers indexing into the input
+collection of bounding boxes representing the selected boxes.  The bounding
+box coordinates corresponding to the selected indices can then be obtained
+using the `tf.gather operation`.  For example:
+  selected_indices = tf.image.non_max_suppression_v2(
+      boxes, scores, max_output_size, iou_threshold, score_threshold)
+  selected_boxes = tf.gather(boxes, selected_indices)
+END
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_NonMaxSuppressionV3.pbtxt b/tensorflow/core/api_def/python_api/api_def_NonMaxSuppressionV3.pbtxt
new file mode 100644 (file)
index 0000000..263cba1
--- /dev/null
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "NonMaxSuppressionV3"
+  visibility: HIDDEN
+}
index 903b898..2b010f8 100644 (file)
@@ -19,6 +19,7 @@ limitations under the License.
 
 #include "tensorflow/core/kernels/non_max_suppression_op.h"
 
+#include <queue>
 #include <vector>
 
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -56,20 +57,9 @@ static inline void ParseAndCheckBoxSizes(OpKernelContext* context,
               errors::InvalidArgument("scores has incompatible shape"));
 }
 
-static inline void DecreasingArgSort(const std::vector<float>& values,
-                                     std::vector<int>* indices) {
-  indices->resize(values.size());
-  for (int i = 0; i < values.size(); ++i) (*indices)[i] = i;
-  std::sort(
-      indices->begin(), indices->end(),
-      [&values](const int i, const int j) { return values[i] > values[j]; });
-}
-
-// Return true if intersection-over-union overlap between boxes i and j
-// is greater than iou_threshold.
-static inline bool IOUGreaterThanThreshold(
-    typename TTypes<float, 2>::ConstTensor boxes, int i, int j,
-    float iou_threshold) {
+// Return intersection-over-union overlap between boxes i and j
+static inline float IOU(typename TTypes<float, 2>::ConstTensor boxes, int i,
+                        int j) {
   const float ymin_i = std::min<float>(boxes(i, 0), boxes(i, 2));
   const float xmin_i = std::min<float>(boxes(i, 1), boxes(i, 3));
   const float ymax_i = std::max<float>(boxes(i, 0), boxes(i, 2));
@@ -88,13 +78,13 @@ static inline bool IOUGreaterThanThreshold(
   const float intersection_area =
       std::max<float>(intersection_ymax - intersection_ymin, 0.0) *
       std::max<float>(intersection_xmax - intersection_xmin, 0.0);
-  const float iou = intersection_area / (area_i + area_j - intersection_area);
-  return iou > iou_threshold;
+  return intersection_area / (area_i + area_j - intersection_area);
 }
 
 void DoNonMaxSuppressionOp(OpKernelContext* context, const Tensor& boxes,
                            const Tensor& scores, const Tensor& max_output_size,
-                           const float iou_threshold) {
+                           const float iou_threshold,
+                           const float score_threshold) {
   OP_REQUIRES(context, iou_threshold >= 0 && iou_threshold <= 1,
               errors::InvalidArgument("iou_threshold must be in [0, 1]"));
 
@@ -109,37 +99,61 @@ void DoNonMaxSuppressionOp(OpKernelContext* context, const Tensor& boxes,
 
   std::vector<float> scores_data(num_boxes);
   std::copy_n(scores.flat<float>().data(), num_boxes, scores_data.begin());
-  std::vector<int> sorted_indices;
-  DecreasingArgSort(scores_data, &sorted_indices);
+
+  // Data structure for selection candidate in NMS.
+  struct Candidate {
+    int box_index;
+    float score;
+  };
+
+  auto cmp = [](const Candidate bs_i, const Candidate bs_j) {
+    return bs_i.score < bs_j.score;
+  };
+  std::priority_queue<Candidate, std::deque<Candidate>, decltype(cmp)>
+      candidate_priority_queue(cmp);
+  for (int i = 0; i < scores_data.size(); ++i) {
+    if (scores_data[i] > score_threshold) {
+      candidate_priority_queue.emplace(Candidate({i, scores_data[i]}));
+    }
+  }
+
+  auto suppress_func = [iou_threshold](const float x) {
+    return x <= iou_threshold ? 1 : 0;
+  };
 
   std::vector<int> selected;
-  std::vector<int> selected_indices(output_size, 0);
-  int num_selected = 0;
-  for (int i = 0; i < num_boxes; ++i) {
-    if (selected.size() >= output_size) break;
-    bool should_select = true;
+  std::vector<float> selected_scores;
+  Candidate next_candidate;
+  float iou, original_score;
+
+  while (selected.size() < output_size && !candidate_priority_queue.empty()) {
+    next_candidate = candidate_priority_queue.top();
+    original_score = next_candidate.score;
+    candidate_priority_queue.pop();
+
     // Overlapping boxes are likely to have similar scores,
-    // therefore we iterate through the selected boxes backwards.
-    for (int j = num_selected - 1; j >= 0; --j) {
-      if (IOUGreaterThanThreshold(boxes_data, sorted_indices[i],
-                                  sorted_indices[selected_indices[j]],
-                                  iou_threshold)) {
-        should_select = false;
-        break;
-      }
+    // therefore we iterate through the previously selected boxes backwards
+    // in order to see if `next_candidate` should be suppressed.
+    for (int j = selected.size() - 1; j >= 0; --j) {
+      iou = IOU(boxes_data, next_candidate.box_index, selected[j]);
+      if (iou == 0.0) continue;
+      next_candidate.score *= suppress_func(iou);
+      if (next_candidate.score <= score_threshold) break;
     }
-    if (should_select) {
-      selected.push_back(sorted_indices[i]);
-      selected_indices[num_selected++] = i;
+
+    if (original_score == next_candidate.score) {
+      selected.push_back(next_candidate.box_index);
+      selected_scores.push_back(next_candidate.score);
     }
   }
 
-  // Allocate output tensor
-  Tensor* output = nullptr;
+  // Allocate output tensors
+  Tensor* output_indices = nullptr;
   TensorShape output_shape({static_cast<int>(selected.size())});
-  OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
-  TTypes<int, 1>::Tensor selected_indices_data = output->tensor<int, 1>();
-  std::copy_n(selected.begin(), selected.size(), selected_indices_data.data());
+  OP_REQUIRES_OK(context,
+                 context->allocate_output(0, output_shape, &output_indices));
+  TTypes<int, 1>::Tensor output_indices_data = output_indices->tensor<int, 1>();
+  std::copy_n(selected.begin(), selected.size(), output_indices_data.data());
 }
 
 }  // namespace
@@ -164,8 +178,9 @@ class NonMaxSuppressionOp : public OpKernel {
         errors::InvalidArgument("max_output_size must be 0-D, got shape ",
                                 max_output_size.shape().DebugString()));
 
+    const float score_threshold_val = 0.0;
     DoNonMaxSuppressionOp(context, boxes, scores, max_output_size,
-                          iou_threshold_);
+                          iou_threshold_, score_threshold_val);
   }
 
  private:
@@ -194,11 +209,48 @@ class NonMaxSuppressionV2Op : public OpKernel {
     OP_REQUIRES(context, TensorShapeUtils::IsScalar(iou_threshold.shape()),
                 errors::InvalidArgument("iou_threshold must be 0-D, got shape ",
                                         iou_threshold.shape().DebugString()));
+    const float iou_threshold_val = iou_threshold.scalar<float>()();
 
+    const float score_threshold_val = 0.0;
+    DoNonMaxSuppressionOp(context, boxes, scores, max_output_size,
+                          iou_threshold_val, score_threshold_val);
+  }
+};
+
+template <typename Device>
+class NonMaxSuppressionV3Op : public OpKernel {
+ public:
+  explicit NonMaxSuppressionV3Op(OpKernelConstruction* context)
+      : OpKernel(context) {}
+
+  void Compute(OpKernelContext* context) override {
+    // boxes: [num_boxes, 4]
+    const Tensor& boxes = context->input(0);
+    // scores: [num_boxes]
+    const Tensor& scores = context->input(1);
+    // max_output_size: scalar
+    const Tensor& max_output_size = context->input(2);
+    OP_REQUIRES(
+        context, TensorShapeUtils::IsScalar(max_output_size.shape()),
+        errors::InvalidArgument("max_output_size must be 0-D, got shape ",
+                                max_output_size.shape().DebugString()));
+    // iou_threshold: scalar
+    const Tensor& iou_threshold = context->input(3);
+    OP_REQUIRES(context, TensorShapeUtils::IsScalar(iou_threshold.shape()),
+                errors::InvalidArgument("iou_threshold must be 0-D, got shape ",
+                                        iou_threshold.shape().DebugString()));
     const float iou_threshold_val = iou_threshold.scalar<float>()();
 
+    // score_threshold: scalar
+    const Tensor& score_threshold = context->input(4);
+    OP_REQUIRES(
+        context, TensorShapeUtils::IsScalar(score_threshold.shape()),
+        errors::InvalidArgument("score_threshold must be 0-D, got shape ",
+                                score_threshold.shape().DebugString()));
+    const float score_threshold_val = score_threshold.scalar<float>()();
+
     DoNonMaxSuppressionOp(context, boxes, scores, max_output_size,
-                          iou_threshold_val);
+                          iou_threshold_val, score_threshold_val);
   }
 };
 
@@ -208,4 +260,7 @@ REGISTER_KERNEL_BUILDER(Name("NonMaxSuppression").Device(DEVICE_CPU),
 REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV2").Device(DEVICE_CPU),
                         NonMaxSuppressionV2Op<CPUDevice>);
 
+REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV3").Device(DEVICE_CPU),
+                        NonMaxSuppressionV3Op<CPUDevice>);
+
 }  // namespace tensorflow
index d4349ed..933b1af 100644 (file)
@@ -27,7 +27,8 @@ template <typename Device, typename T>
 struct NonMaxSuppression {
   void operator()(const Device& d, typename TTypes<float, 2>::ConstTensor boxes,
                   typename TTypes<float, 1>::ConstTensor scores,
-                  float iou_threshold, int max_output_size,
+                  float iou_threshold, float score_threshold,
+                  int max_output_size,
                   typename TTypes<int, 1>::Tensor selected_indices);
 };
 
index 9387fb1..c71aa23 100644 (file)
@@ -340,4 +340,195 @@ TEST_F(NonMaxSuppressionV2OpTest, TestEmptyInput) {
   test::ExpectTensorEqual<int>(expected, *GetOutput(0));
 }
 
+//
+// NonMaxSuppressionV3Op Tests
+//
+
+class NonMaxSuppressionV3OpTest : public OpsTestBase {
+ protected:
+  void MakeOp() {
+    TF_EXPECT_OK(NodeDefBuilder("non_max_suppression_op", "NonMaxSuppressionV3")
+                     .Input(FakeInput(DT_FLOAT))
+                     .Input(FakeInput(DT_FLOAT))
+                     .Input(FakeInput(DT_INT32))
+                     .Input(FakeInput(DT_FLOAT))
+                     .Input(FakeInput(DT_FLOAT))
+                     .Finalize(node_def()));
+    TF_EXPECT_OK(InitOp());
+  }
+};
+
+TEST_F(NonMaxSuppressionV3OpTest, TestSelectFromThreeClusters) {
+  MakeOp();
+  AddInputFromArray<float>(
+      TensorShape({6, 4}),
+      {0, 0,  1, 1,  0, 0.1f,  1, 1.1f,  0, -0.1f, 1, 0.9f,
+       0, 10, 1, 11, 0, 10.1f, 1, 11.1f, 0, 100,   1, 101});
+  AddInputFromArray<float>(TensorShape({6}), {.9f, .75f, .6f, .95f, .5f, .3f});
+  AddInputFromArray<int>(TensorShape({}), {3});
+  AddInputFromArray<float>(TensorShape({}), {.5f});
+  AddInputFromArray<float>(TensorShape({}), {0.0f});
+  TF_ASSERT_OK(RunOpKernel());
+
+  Tensor expected(allocator(), DT_INT32, TensorShape({3}));
+  test::FillValues<int>(&expected, {3, 0, 5});
+  test::ExpectTensorEqual<int>(expected, *GetOutput(0));
+}
+
+TEST_F(NonMaxSuppressionV3OpTest,
+       TestSelectFromThreeClustersWithScoreThreshold) {
+  MakeOp();
+  AddInputFromArray<float>(
+      TensorShape({6, 4}),
+      {0, 0,  1, 1,  0, 0.1f,  1, 1.1f,  0, -0.1f, 1, 0.9f,
+       0, 10, 1, 11, 0, 10.1f, 1, 11.1f, 0, 100,   1, 101});
+  AddInputFromArray<float>(TensorShape({6}), {.9f, .75f, .6f, .95f, .5f, .3f});
+  AddInputFromArray<int>(TensorShape({}), {3});
+  AddInputFromArray<float>(TensorShape({}), {0.5f});
+  AddInputFromArray<float>(TensorShape({}), {0.4f});
+  TF_ASSERT_OK(RunOpKernel());
+
+  Tensor expected(allocator(), DT_INT32, TensorShape({2}));
+  test::FillValues<int>(&expected, {3, 0});
+  test::ExpectTensorEqual<int>(expected, *GetOutput(0));
+}
+
+TEST_F(NonMaxSuppressionV3OpTest,
+       TestSelectFromThreeClustersFlippedCoordinates) {
+  MakeOp();
+  AddInputFromArray<float>(TensorShape({6, 4}),
+                           {1, 1,  0, 0,  0, 0.1f,  1, 1.1f,  0, .9f, 1, -0.1f,
+                            0, 10, 1, 11, 1, 10.1f, 0, 11.1f, 1, 101, 0, 100});
+  AddInputFromArray<float>(TensorShape({6}), {.9f, .75f, .6f, .95f, .5f, .3f});
+  AddInputFromArray<int>(TensorShape({}), {3});
+  AddInputFromArray<float>(TensorShape({}), {.5f});
+  AddInputFromArray<float>(TensorShape({}), {0.0f});
+  TF_ASSERT_OK(RunOpKernel());
+
+  Tensor expected(allocator(), DT_INT32, TensorShape({3}));
+  test::FillValues<int>(&expected, {3, 0, 5});
+  test::ExpectTensorEqual<int>(expected, *GetOutput(0));
+}
+
+TEST_F(NonMaxSuppressionV3OpTest, TestSelectAtMostTwoBoxesFromThreeClusters) {
+  MakeOp();
+  AddInputFromArray<float>(
+      TensorShape({6, 4}),
+      {0, 0,  1, 1,  0, 0.1f,  1, 1.1f,  0, -0.1f, 1, 0.9f,
+       0, 10, 1, 11, 0, 10.1f, 1, 11.1f, 0, 100,   1, 101});
+  AddInputFromArray<float>(TensorShape({6}), {.9f, .75f, .6f, .95f, .5f, .3f});
+  AddInputFromArray<int>(TensorShape({}), {2});
+  AddInputFromArray<float>(TensorShape({}), {.5f});
+  AddInputFromArray<float>(TensorShape({}), {0.0f});
+  TF_ASSERT_OK(RunOpKernel());
+
+  Tensor expected(allocator(), DT_INT32, TensorShape({2}));
+  test::FillValues<int>(&expected, {3, 0});
+  test::ExpectTensorEqual<int>(expected, *GetOutput(0));
+}
+
+TEST_F(NonMaxSuppressionV3OpTest,
+       TestSelectAtMostThirtyBoxesFromThreeClusters) {
+  MakeOp();
+  AddInputFromArray<float>(
+      TensorShape({6, 4}),
+      {0, 0,  1, 1,  0, 0.1f,  1, 1.1f,  0, -0.1f, 1, 0.9f,
+       0, 10, 1, 11, 0, 10.1f, 1, 11.1f, 0, 100,   1, 101});
+  AddInputFromArray<float>(TensorShape({6}), {.9f, .75f, .6f, .95f, .5f, .3f});
+  AddInputFromArray<int>(TensorShape({}), {30});
+  AddInputFromArray<float>(TensorShape({}), {.5f});
+  AddInputFromArray<float>(TensorShape({}), {0.0f});
+  TF_ASSERT_OK(RunOpKernel());
+
+  Tensor expected(allocator(), DT_INT32, TensorShape({3}));
+  test::FillValues<int>(&expected, {3, 0, 5});
+  test::ExpectTensorEqual<int>(expected, *GetOutput(0));
+}
+
+TEST_F(NonMaxSuppressionV3OpTest, TestSelectSingleBox) {
+  MakeOp();
+  AddInputFromArray<float>(TensorShape({1, 4}), {0, 0, 1, 1});
+  AddInputFromArray<float>(TensorShape({1}), {.9f});
+  AddInputFromArray<int>(TensorShape({}), {3});
+  AddInputFromArray<float>(TensorShape({}), {.5f});
+  AddInputFromArray<float>(TensorShape({}), {0.0f});
+  TF_ASSERT_OK(RunOpKernel());
+
+  Tensor expected(allocator(), DT_INT32, TensorShape({1}));
+  test::FillValues<int>(&expected, {0});
+  test::ExpectTensorEqual<int>(expected, *GetOutput(0));
+}
+
+TEST_F(NonMaxSuppressionV3OpTest, TestSelectFromTenIdenticalBoxes) {
+  MakeOp();
+
+  int num_boxes = 10;
+  std::vector<float> corners(num_boxes * 4);
+  std::vector<float> scores(num_boxes);
+  for (int i = 0; i < num_boxes; ++i) {
+    corners[i * 4 + 0] = 0;
+    corners[i * 4 + 1] = 0;
+    corners[i * 4 + 2] = 1;
+    corners[i * 4 + 3] = 1;
+    scores[i] = .9;
+  }
+  AddInputFromArray<float>(TensorShape({num_boxes, 4}), corners);
+  AddInputFromArray<float>(TensorShape({num_boxes}), scores);
+  AddInputFromArray<int>(TensorShape({}), {3});
+  AddInputFromArray<float>(TensorShape({}), {.5f});
+  AddInputFromArray<float>(TensorShape({}), {0.0f});
+  TF_ASSERT_OK(RunOpKernel());
+
+  Tensor expected(allocator(), DT_INT32, TensorShape({1}));
+  test::FillValues<int>(&expected, {0});
+  test::ExpectTensorEqual<int>(expected, *GetOutput(0));
+}
+
+TEST_F(NonMaxSuppressionV3OpTest, TestInconsistentBoxAndScoreShapes) {
+  MakeOp();
+  AddInputFromArray<float>(
+      TensorShape({6, 4}),
+      {0, 0,  1, 1,  0, 0.1f,  1, 1.1f,  0, -0.1f, 1, 0.9f,
+       0, 10, 1, 11, 0, 10.1f, 1, 11.1f, 0, 100,   1, 101});
+  AddInputFromArray<float>(TensorShape({5}), {.9f, .75f, .6f, .95f, .5f});
+  AddInputFromArray<int>(TensorShape({}), {30});
+  AddInputFromArray<float>(TensorShape({}), {.5f});
+  AddInputFromArray<float>(TensorShape({}), {0.0f});
+  Status s = RunOpKernel();
+
+  ASSERT_FALSE(s.ok());
+  EXPECT_TRUE(
+      str_util::StrContains(s.ToString(), "scores has incompatible shape"))
+      << s;
+}
+
+TEST_F(NonMaxSuppressionV3OpTest, TestInvalidIOUThreshold) {
+  MakeOp();
+  AddInputFromArray<float>(TensorShape({1, 4}), {0, 0, 1, 1});
+  AddInputFromArray<float>(TensorShape({1}), {.9f});
+  AddInputFromArray<int>(TensorShape({}), {3});
+  AddInputFromArray<float>(TensorShape({}), {1.2f});
+  AddInputFromArray<float>(TensorShape({}), {0.0f});
+  Status s = RunOpKernel();
+
+  ASSERT_FALSE(s.ok());
+  EXPECT_TRUE(
+      str_util::StrContains(s.ToString(), "iou_threshold must be in [0, 1]"))
+      << s;
+}
+
+TEST_F(NonMaxSuppressionV3OpTest, TestEmptyInput) {
+  MakeOp();
+  AddInputFromArray<float>(TensorShape({0, 4}), {});
+  AddInputFromArray<float>(TensorShape({0}), {});
+  AddInputFromArray<int>(TensorShape({}), {30});
+  AddInputFromArray<float>(TensorShape({}), {.5f});
+  AddInputFromArray<float>(TensorShape({}), {0.0f});
+  TF_ASSERT_OK(RunOpKernel());
+
+  Tensor expected(allocator(), DT_INT32, TensorShape({0}));
+  test::FillValues<int>(&expected, {});
+  test::ExpectTensorEqual<int>(expected, *GetOutput(0));
+}
+
 }  // namespace tensorflow
index 0d0677b..82330ec 100644 (file)
@@ -657,4 +657,35 @@ REGISTER_OP("NonMaxSuppressionV2")
       return Status::OK();
     });
 
+REGISTER_OP("NonMaxSuppressionV3")
+    .Input("boxes: float")
+    .Input("scores: float")
+    .Input("max_output_size: int32")
+    .Input("iou_threshold: float")
+    .Input("score_threshold: float")
+    .Output("selected_indices: int32")
+    .SetShapeFn([](InferenceContext* c) {
+      // Get inputs and validate ranks.
+      ShapeHandle boxes;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &boxes));
+      ShapeHandle scores;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &scores));
+      ShapeHandle max_output_size;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &max_output_size));
+      ShapeHandle iou_threshold;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &iou_threshold));
+      ShapeHandle score_threshold;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &score_threshold));
+      // The boxes is a 2-D float Tensor of shape [num_boxes, 4].
+      DimensionHandle unused;
+      // The boxes[0] and scores[0] are both num_boxes.
+      TF_RETURN_IF_ERROR(
+          c->Merge(c->Dim(boxes, 0), c->Dim(scores, 0), &unused));
+      // The boxes[1] is 4.
+      TF_RETURN_IF_ERROR(c->WithValue(c->Dim(boxes, 1), 4, &unused));
+
+      c->set_output(0, c->Vector(c->UnknownDim()));
+      return Status::OK();
+    });
+
 }  // namespace tensorflow
index bd5b2ae..54e27b8 100644 (file)
@@ -1772,6 +1772,7 @@ def non_max_suppression(boxes,
                         scores,
                         max_output_size,
                         iou_threshold=0.5,
+                        score_threshold=0.0,
                         name=None):
   """Greedily selects a subset of bounding boxes in descending order of score.
 
@@ -1800,6 +1801,8 @@ def non_max_suppression(boxes,
       of boxes to be selected by non max suppression.
     iou_threshold: A float representing the threshold for deciding whether boxes
       overlap too much with respect to IOU.
+    score_threshold: A float representing the threshold for deciding when to
+      remove boxes based on score.
     name: A name for the operation (optional).
 
   Returns:
@@ -1808,8 +1811,10 @@ def non_max_suppression(boxes,
   """
   with ops.name_scope(name, 'non_max_suppression'):
     iou_threshold = ops.convert_to_tensor(iou_threshold, name='iou_threshold')
-    return gen_image_ops.non_max_suppression_v2(boxes, scores, max_output_size,
-                                                iou_threshold)
+    score_threshold = ops.convert_to_tensor(
+        score_threshold, name='score_threshold')
+    return gen_image_ops.non_max_suppression_v3(boxes, scores, max_output_size,
+                                                iou_threshold, score_threshold)
 
 
 _rgb_to_yiq_kernel = [[0.299, 0.59590059,
index 3fc64da..acc3fc4 100644 (file)
@@ -110,7 +110,7 @@ tf_module {
   }
   member_method {
     name: "non_max_suppression"
-    argspec: "args=[\'boxes\', \'scores\', \'max_output_size\', \'iou_threshold\', \'name\'], varargs=None, keywords=None, defaults=[\'0.5\', \'None\'], "
+    argspec: "args=[\'boxes\', \'scores\', \'max_output_size\', \'iou_threshold\', \'score_threshold\', \'name\'], varargs=None, keywords=None, defaults=[\'0.5\', \'0.0\', \'None\'], "
   }
   member_method {
     name: "pad_to_bounding_box"