Implement rotated generate_proposals_op without opencv dependency (CPU version)
authorJing Huang <jinghuang@fb.com>
Thu, 28 Mar 2019 23:58:54 +0000 (16:58 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 29 Mar 2019 00:02:50 +0000 (17:02 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18533

Reviewed By: ezyang

Differential Revision: D14648083

fbshipit-source-id: e53e8f537100862f8015c4efa4efe4d387cef551

caffe2/operators/generate_proposals_op_test.cc
caffe2/operators/generate_proposals_op_util_nms.h
caffe2/operators/generate_proposals_op_util_nms_gpu_test.cc
caffe2/operators/generate_proposals_op_util_nms_test.cc

index eff256d..f79cf68 100644 (file)
@@ -413,7 +413,6 @@ TEST(GenerateProposalsTest, TestRealDownSampled) {
       1e-4);
 }
 
-#if defined(CV_MAJOR_VERSION) && (CV_MAJOR_VERSION >= 3)
 TEST(GenerateProposalsTest, TestRealDownSampledRotatedAngle0) {
   // Similar to TestRealDownSampled but for rotated boxes with angle info.
   const float angle = 0;
@@ -522,7 +521,7 @@ TEST(GenerateProposalsTest, TestRealDownSampledRotatedAngle0) {
   ERMatXf rois_gt(rois_gt_xyxy.rows(), 6);
   // Batch ID
   rois_gt.block(0, 0, rois_gt.rows(), 1) =
-      rois_gt_xyxy.block(0, 0, rois_gt.rows(), 0);
+      rois_gt_xyxy.block(0, 0, rois_gt.rows(), 1);
   // rois_gt in [x_ctr, y_ctr, w, h] format
   rois_gt.block(0, 1, rois_gt.rows(), 4) = utils::bbox_xyxy_to_ctrwh(
       rois_gt_xyxy.block(0, 1, rois_gt.rows(), 4).array());
@@ -721,6 +720,5 @@ TEST(GenerateProposalsTest, TestRealDownSampledRotated) {
     EXPECT_LE(std::abs(rois_data(i, 5) - expected_angle), 1e-4);
   }
 }
-#endif // CV_MAJOR_VERSION >= 3
 
 } // namespace caffe2
index b90fea8..8c5234e 100644 (file)
@@ -169,274 +169,296 @@ std::vector<int> soft_nms_cpu_upright(
   return keep;
 }
 
-#if defined(CV_MAJOR_VERSION) && (CV_MAJOR_VERSION >= 3)
 namespace {
+const int INTERSECT_NONE = 0;
+const int INTERSECT_PARTIAL = 1;
+const int INTERSECT_FULL = 2;
+
+class RotatedRect {
+ public:
+  RotatedRect() {}
+  RotatedRect(
+      const Eigen::Vector2f& p_center,
+      const Eigen::Vector2f& p_size,
+      float p_angle)
+      : center(p_center), size(p_size), angle(p_angle) {}
+  void get_vertices(Eigen::Vector2f* pt) const {
+    // M_PI / 180. == 0.01745329251
+    double _angle = angle * 0.01745329251;
+    float b = (float)cos(_angle) * 0.5f;
+    float a = (float)sin(_angle) * 0.5f;
+
+    pt[0].x() = center.x() - a * size.y() - b * size.x();
+    pt[0].y() = center.y() + b * size.y() - a * size.x();
+    pt[1].x() = center.x() + a * size.y() - b * size.x();
+    pt[1].y() = center.y() - b * size.y() - a * size.x();
+    pt[2] = 2 * center - pt[0];
+    pt[3] = 2 * center - pt[1];
+  }
+  Eigen::Vector2f center;
+  Eigen::Vector2f size;
+  float angle;
+};
 
 template <class Derived>
-cv::RotatedRect bbox_to_rotated_rect(const Eigen::ArrayBase<Derived>& box) {
+RotatedRect bbox_to_rotated_rect(const Eigen::ArrayBase<Derived>& box) {
   CAFFE_ENFORCE_EQ(box.size(), 5);
   // cv::RotatedRect takes angle to mean clockwise rotation, but RRPN bbox
   // representation means counter-clockwise rotation.
-  return cv::RotatedRect(
-      cv::Point2f(box[0], box[1]), cv::Size2f(box[2], box[3]), -box[4]);
+  return RotatedRect(
+      Eigen::Vector2f(box[0], box[1]),
+      Eigen::Vector2f(box[2], box[3]),
+      -box[4]);
+}
+
+// Eigen doesn't seem to support 2d cross product, so we make one here
+float cross_2d(const Eigen::Vector2f& A, const Eigen::Vector2f& B) {
+  return A.x() * B.y() - B.x() * A.y();
 }
 
-// TODO: cvfix_rotatedRectangleIntersection is a replacement function for
+// rotated_rect_intersection_pts is a replacement function for
 // cv::rotatedRectangleIntersection, which has a bug due to float underflow
-// When OpenCV version is upgraded to be >= 4.0,
-// we can remove this replacement function.
 // For anyone interested, here're the PRs on OpenCV:
 // https://github.com/opencv/opencv/issues/12221
 // https://github.com/opencv/opencv/pull/12222
-int cvfix_rotatedRectangleIntersection(
-    const cv::RotatedRect& rect1,
-    const cv::RotatedRect& rect2,
-    cv::OutputArray intersectingRegion) {
+// Note that we do not check if the number of intersections is <= 8 in this case
+int rotated_rect_intersection_pts(
+    const RotatedRect& rect1,
+    const RotatedRect& rect2,
+    Eigen::Vector2f* intersections,
+    int& num) {
   // Used to test if two points are the same
   const float samePointEps = 0.00001f;
   const float EPS = 1e-14;
+  num = 0; // number of intersections
 
-  cv::Point2f vec1[4], vec2[4];
-  cv::Point2f pts1[4], pts2[4];
-
-  std::vector<cv::Point2f> intersection;
+  Eigen::Vector2f vec1[4], vec2[4], pts1[4], pts2[4];
 
-  rect1.points(pts1);
-  rect2.points(pts2);
-
-  int ret = cv::INTERSECT_FULL;
+  rect1.get_vertices(pts1);
+  rect2.get_vertices(pts2);
 
   // Specical case of rect1 == rect2
-  {
-    bool same = true;
+  bool same = true;
 
-    for (int i = 0; i < 4; i++) {
-      if (fabs(pts1[i].x - pts2[i].x) > samePointEps ||
-          (fabs(pts1[i].y - pts2[i].y) > samePointEps)) {
-        same = false;
-        break;
-      }
+  for (int i = 0; i < 4; i++) {
+    if (fabs(pts1[i].x() - pts2[i].x()) > samePointEps ||
+        (fabs(pts1[i].y() - pts2[i].y()) > samePointEps)) {
+      same = false;
+      break;
     }
+  }
 
-    if (same) {
-      intersection.resize(4);
-
-      for (int i = 0; i < 4; i++) {
-        intersection[i] = pts1[i];
-      }
-
-      cv::Mat(intersection).copyTo(intersectingRegion);
-
-      return cv::INTERSECT_FULL;
+  if (same) {
+    for (int i = 0; i < 4; i++) {
+      intersections[i] = pts1[i];
     }
+    num = 4;
+    return INTERSECT_FULL;
   }
 
   // Line vector
   // A line from p1 to p2 is: p1 + (p2-p1)*t, t=[0,1]
   for (int i = 0; i < 4; i++) {
-    vec1[i].x = pts1[(i + 1) % 4].x - pts1[i].x;
-    vec1[i].y = pts1[(i + 1) % 4].y - pts1[i].y;
-
-    vec2[i].x = pts2[(i + 1) % 4].x - pts2[i].x;
-    vec2[i].y = pts2[(i + 1) % 4].y - pts2[i].y;
+    vec1[i] = pts1[(i + 1) % 4] - pts1[i];
+    vec2[i] = pts2[(i + 1) % 4] - pts2[i];
   }
 
   // Line test - test all line combos for intersection
   for (int i = 0; i < 4; i++) {
     for (int j = 0; j < 4; j++) {
       // Solve for 2x2 Ax=b
-      float x21 = pts2[j].x - pts1[i].x;
-      float y21 = pts2[j].y - pts1[i].y;
-
-      const auto& l1 = vec1[i];
-      const auto& l2 = vec2[j];
 
       // This takes care of parallel lines
-      float det = l2.x * l1.y - l1.x * l2.y;
+      float det = cross_2d(vec2[j], vec1[i]);
       if (std::fabs(det) <= EPS) {
         continue;
       }
 
-      float t1 = (l2.x * y21 - l2.y * x21) / det;
-      float t2 = (l1.x * y21 - l1.y * x21) / det;
+      auto vec12 = pts2[j] - pts1[i];
+
+      float t1 = cross_2d(vec2[j], vec12) / det;
+      float t2 = cross_2d(vec1[i], vec12) / det;
 
       if (t1 >= 0.0f && t1 <= 1.0f && t2 >= 0.0f && t2 <= 1.0f) {
-        float xi = pts1[i].x + vec1[i].x * t1;
-        float yi = pts1[i].y + vec1[i].y * t1;
-        intersection.push_back(cv::Point2f(xi, yi));
+        intersections[num++] = pts1[i] + t1 * vec1[i];
       }
     }
   }
 
-  if (!intersection.empty()) {
-    ret = cv::INTERSECT_PARTIAL;
-  }
-
   // Check for vertices from rect1 inside rect2
-  for (int i = 0; i < 4; i++) {
-    // We do a sign test to see which side the point lies.
-    // If the point all lie on the same sign for all 4 sides of the rect,
-    // then there's an intersection
-    int posSign = 0;
-    int negSign = 0;
+  {
+    const auto& AB = vec2[0];
+    const auto& DA = vec2[3];
+    auto ABdotAB = AB.squaredNorm();
+    auto ADdotAD = DA.squaredNorm();
+    for (int i = 0; i < 4; i++) {
+      // assume ABCD is the rectangle, and P is the point to be judged
+      // P is inside ABCD iff. P's projection on AB lies within AB
+      // and P's projection on AD lies within AD
 
-    float x = pts1[i].x;
-    float y = pts1[i].y;
+      auto AP = pts1[i] - pts2[0];
 
-    for (int j = 0; j < 4; j++) {
-      // line equation: Ax + By + C = 0
-      // see which side of the line this point is at
-
-      // float causes underflow!
-      // Original version:
-      // float A = -vec2[j].y;
-      // float B = vec2[j].x;
-      // float C = -(A * pts2[j].x + B * pts2[j].y);
-      // float s = A * x + B * y + C;
-
-      double A = -vec2[j].y;
-      double B = vec2[j].x;
-      double C = -(A * pts2[j].x + B * pts2[j].y);
-      double s = A * x + B * y + C;
-
-      if (s >= 0) {
-        posSign++;
-      } else {
-        negSign++;
-      }
-    }
+      auto APdotAB = AP.dot(AB);
+      auto APdotAD = -AP.dot(DA);
 
-    if (posSign == 4 || negSign == 4) {
-      intersection.push_back(pts1[i]);
+      if ((APdotAB >= 0) && (APdotAD >= 0) && (APdotAB <= ABdotAB) &&
+          (APdotAD <= ADdotAD)) {
+        intersections[num++] = pts1[i];
+      }
     }
   }
 
   // Reverse the check - check for vertices from rect2 inside rect1
-  for (int i = 0; i < 4; i++) {
-    // We do a sign test to see which side the point lies.
-    // If the point all lie on the same sign for all 4 sides of the rect,
-    // then there's an intersection
-    int posSign = 0;
-    int negSign = 0;
+  {
+    const auto& AB = vec1[0];
+    const auto& DA = vec1[3];
+    auto ABdotAB = AB.squaredNorm();
+    auto ADdotAD = DA.squaredNorm();
+    for (int i = 0; i < 4; i++) {
+      auto AP = pts2[i] - pts1[0];
 
-    float x = pts2[i].x;
-    float y = pts2[i].y;
+      auto APdotAB = AP.dot(AB);
+      auto APdotAD = -AP.dot(DA);
 
-    for (int j = 0; j < 4; j++) {
-      // line equation: Ax + By + C = 0
-      // see which side of the line this point is at
-
-      // float causes underflow!
-      // Original version:
-      // float A = -vec1[j].y;
-      // float B = vec1[j].x;
-      // float C = -(A * pts1[j].x + B * pts1[j].y);
-      // float s = A*x + B*y + C;
-
-      double A = -vec1[j].y;
-      double B = vec1[j].x;
-      double C = -(A * pts1[j].x + B * pts1[j].y);
-      double s = A * x + B * y + C;
-
-      if (s >= 0) {
-        posSign++;
-      } else {
-        negSign++;
+      if ((APdotAB >= 0) && (APdotAD >= 0) && (APdotAB <= ABdotAB) &&
+          (APdotAD <= ADdotAD)) {
+        intersections[num++] = pts2[i];
       }
     }
-
-    if (posSign == 4 || negSign == 4) {
-      intersection.push_back(pts2[i]);
-    }
   }
 
-  // Get rid of dupes
-  for (int i = 0; i < (int)intersection.size() - 1; i++) {
-    for (size_t j = i + 1; j < intersection.size(); j++) {
-      float dx = intersection[i].x - intersection[j].x;
-      float dy = intersection[i].y - intersection[j].y;
-      // can be a really small number, need double here
-      double d2 = dx * dx + dy * dy;
-
-      if (d2 < samePointEps * samePointEps) {
-        // Found a dupe, remove it
-        std::swap(intersection[j], intersection.back());
-        intersection.pop_back();
-        j--; // restart check
-      }
+  return num ? INTERSECT_PARTIAL : INTERSECT_NONE;
+}
+
+// Compute convex hull using Graham scan algorithm
+int convex_hull_graham(
+    const Eigen::Vector2f* p,
+    const int& num_in,
+    Eigen::Vector2f* q,
+    bool shift_to_zero = false) {
+  CAFFE_ENFORCE(num_in >= 2);
+  std::vector<int> order;
+
+  // Step 1:
+  // Find point with minimum y
+  // if more than 1 points have the same minimum y,
+  // pick the one with the mimimum x.
+  int t = 0;
+  for (int i = 1; i < num_in; i++) {
+    if (p[i].y() < p[t].y() || (p[i].y() == p[t].y() && p[i].x() < p[t].x())) {
+      t = i;
     }
   }
+  auto& s = p[t]; // starting point
 
-  if (intersection.empty()) {
-    return cv::INTERSECT_NONE;
+  // Step 2:
+  // Subtract starting point from every points (for sorting in the next step)
+  for (int i = 0; i < num_in; i++) {
+    q[i] = p[i] - s;
   }
 
-  // If this check fails then it means we're getting dupes
-  // CV_Assert(intersection.size() <= 8);
-
-  // At this point, there might still be some edge cases failing the check above
-  // However, it doesn't affect the result of polygon area,
-  // even if the number of intersections is greater than 8.
-  // Therefore, we just print out these cases for now instead of assertion.
-  // TODO: These cases should provide good reference for improving the accuracy
-  // for intersection computation above (for example, we should use
-  // cross-product/dot-product of vectors instead of line equation to
-  // judge the relationships between the points and line segments)
-
-  if (intersection.size() > 8) {
-    LOG(ERROR) << "Intersection size = " << intersection.size();
-    LOG(ERROR) << "Rect 1:";
-    for (int i = 0; i < 4; i++) {
-      LOG(ERROR) << " (" << pts1[i].x << " ," << pts1[i].y << "),";
-    }
-    LOG(ERROR) << "Rect 2:";
-    for (int i = 0; i < 4; i++) {
-      LOG(ERROR) << " (" << pts2[i].x << " ," << pts2[i].y << "),";
-    }
-    LOG(ERROR) << "Intersections:";
-    for (auto& p : intersection) {
-      LOG(ERROR) << " (" << p.x << " ," << p.y << "),";
+  // Swap the starting point to position 0
+  std::swap(q[0], q[t]);
+
+  // Step 3:
+  // Sort point 1 ~ num_in according to their relative cross-product values
+  // (essentially sorting according to angles)
+  std::sort(
+      q + 1,
+      q + num_in,
+      [](const Eigen::Vector2f& A, const Eigen::Vector2f& B) -> bool {
+        float temp = cross_2d(A, B);
+        if (fabs(temp) < 1e-6) {
+          return A.squaredNorm() < B.squaredNorm();
+        } else {
+          return temp > 0;
+        }
+      });
+
+  // Step 4:
+  // Make sure there are at least 2 points (that don't overlap with each other)
+  // in the stack
+  int k; // index of the non-overlapped second point
+  for (k = 1; k < num_in; k++) {
+    if (q[k].squaredNorm() > 1e-8)
+      break;
+  }
+  if (k == num_in) {
+    // We reach the end, which means the convex hull is just one point
+    q[0] = p[t];
+    return 1;
+  }
+  q[1] = q[k];
+  int m = 2; // 2 elements in the stack
+  // Step 5:
+  // Finally we can start the scanning process.
+  // If we find a non-convex relationship between the 3 points,
+  // we pop the previous point from the stack until the stack only has two
+  // points, or the 3-point relationship is convex again
+  for (int i = k + 1; i < num_in; i++) {
+    while (m > 1 && cross_2d(q[i] - q[m - 2], q[m - 1] - q[m - 2]) >= 0) {
+      m--;
     }
+    q[m++] = q[i];
+  }
+
+  // Step 6 (Optional):
+  // In general sense we need the original coordinates, so we
+  // need to shift the points back (reverting Step 2)
+  // But if we're only interested in getting the area/perimeter of the shape
+  // We can simply return.
+  if (!shift_to_zero) {
+    for (int i = 0; i < m; i++)
+      q[i] += s;
   }
 
-  cv::Mat(intersection).copyTo(intersectingRegion);
+  return m;
+}
 
-  return ret;
+double polygon_area(const Eigen::Vector2f* q, const int& m) {
+  if (m <= 2)
+    return 0;
+  double area = 0;
+  for (int i = 1; i < m - 1; i++)
+    area += fabs(cross_2d(q[i] - q[0], q[i + 1] - q[0]));
+  return area / 2.0;
 }
 
 /**
  * Returns the intersection area of two rotated rectangles.
  */
 double rotated_rect_intersection(
-    const cv::RotatedRect& rect1,
-    const cv::RotatedRect& rect2) {
-  std::vector<cv::Point2f> intersectPts, orderedPts;
+    const RotatedRect& rect1,
+    const RotatedRect& rect2) {
+  // There are up to 16 intersections returned from
+  // rotated_rect_intersection_pts
+  Eigen::Vector2f intersectPts[16], orderedPts[16];
+  int num = 0; // number of intersections
 
   // Find points of intersection
 
-  // TODO: cvfix_rotatedRectangleIntersection is a replacement function for
+  // TODO: rotated_rect_intersection_pts is a replacement function for
   // cv::rotatedRectangleIntersection, which has a bug due to float underflow
-  // When OpenCV version is upgraded to be >= 4.0,
-  // we can remove this replacement function and use the following instead:
-  // auto ret = cv::rotatedRectangleIntersection(rect1, rect2, intersectPts);
   // For anyone interested, here're the PRs on OpenCV:
   // https://github.com/opencv/opencv/issues/12221
   // https://github.com/opencv/opencv/pull/12222
-  auto ret = cvfix_rotatedRectangleIntersection(rect1, rect2, intersectPts);
-  if (intersectPts.size() <= 2) {
+  // Note: it doesn't matter if #intersections is greater than 8 here
+  auto ret = rotated_rect_intersection_pts(rect1, rect2, intersectPts, num);
+  CAFFE_ENFORCE(num <= 16);
+  if (num <= 2)
     return 0.0;
-  }
 
   // If one rectangle is fully enclosed within another, return the area
   // of the smaller one early.
-  if (ret == cv::INTERSECT_FULL) {
-    return std::min(rect1.size.area(), rect2.size.area());
+  if (ret == INTERSECT_FULL) {
+    return std::min(
+        rect1.size.x() * rect1.size.y(), rect2.size.x() * rect2.size.y());
   }
 
   // Convex Hull to order the intersection points in clockwise or
   // counter-clockwise order and find the countour area.
-  cv::convexHull(intersectPts, orderedPts);
-  return cv::contourArea(orderedPts);
+  int num_convex = convex_hull_graham(intersectPts, num, orderedPts, true);
+  return polygon_area(orderedPts, num_convex);
 }
 
 } // namespace
@@ -507,7 +529,7 @@ std::vector<int> nms_cpu_rotated(
   auto heights = proposals.col(3);
   EArrX areas = widths * heights;
 
-  std::vector<cv::RotatedRect> rotated_rects(proposals.rows());
+  std::vector<RotatedRect> rotated_rects(proposals.rows());
   for (int i = 0; i < proposals.rows(); ++i) {
     rotated_rects[i] = bbox_to_rotated_rect(proposals.row(i));
   }
@@ -568,7 +590,7 @@ std::vector<int> soft_nms_cpu_rotated(
   auto heights = proposals.col(3);
   EArrX areas = widths * heights;
 
-  std::vector<cv::RotatedRect> rotated_rects(proposals.rows());
+  std::vector<RotatedRect> rotated_rects(proposals.rows());
   for (int i = 0; i < proposals.rows(); ++i) {
     rotated_rects[i] = bbox_to_rotated_rect(proposals.row(i));
   }
@@ -627,7 +649,6 @@ std::vector<int> soft_nms_cpu_rotated(
 
   return keep;
 }
-#endif // CV_MAJOR_VERSION >= 3
 
 template <class Derived1, class Derived2>
 std::vector<int> nms_cpu(
@@ -636,7 +657,6 @@ std::vector<int> nms_cpu(
     const std::vector<int>& sorted_indices,
     float thresh,
     int topN = -1) {
-#if defined(CV_MAJOR_VERSION) && (CV_MAJOR_VERSION >= 3)
   CAFFE_ENFORCE(proposals.cols() == 4 || proposals.cols() == 5);
   if (proposals.cols() == 4) {
     // Upright boxes
@@ -645,9 +665,6 @@ std::vector<int> nms_cpu(
     // Rotated boxes with angle info
     return nms_cpu_rotated(proposals, scores, sorted_indices, thresh, topN);
   }
-#else
-  return nms_cpu_upright(proposals, scores, sorted_indices, thresh, topN);
-#endif // CV_MAJOR_VERSION >= 3
 }
 
 // Greedy non-maximum suppression for proposed bounding boxes
@@ -686,7 +703,6 @@ std::vector<int> soft_nms_cpu(
     float score_thresh = 0.001,
     unsigned int method = 1,
     int topN = -1) {
-#if defined(CV_MAJOR_VERSION) && (CV_MAJOR_VERSION >= 3)
   CAFFE_ENFORCE(proposals.cols() == 4 || proposals.cols() == 5);
   if (proposals.cols() == 4) {
     // Upright boxes
@@ -713,18 +729,6 @@ std::vector<int> soft_nms_cpu(
         method,
         topN);
   }
-#else
-  return soft_nms_cpu_upright(
-      out_scores,
-      proposals,
-      scores,
-      indices,
-      sigma,
-      overlap_thresh,
-      score_thresh,
-      method,
-      topN);
-#endif // CV_MAJOR_VERSION >= 3
 }
 
 template <class Derived1, class Derived2, class Derived3>
index 8999fda..372acca 100644 (file)
@@ -379,13 +379,9 @@ TEST(UtilsNMSTest, TestNMSGPURotatedAngle0) {
     return;
   const int box_dim = 5;
   // Same boxes in TestNMS with (x_ctr, y_ctr, w, h, angle) format
-  std::vector<float> boxes = {
-    30, 35, 41, 51, 0,
-    29.5, 36, 38, 49, 0,
-    24, 29.5, 33, 42, 0,
-    125, 120, 51, 41, 0,
-    127, 124.5, 57, 30, 0
-  };
+  std::vector<float> boxes = {30, 35, 41,   51,    0,  29.5, 36,  38,  49,
+                              0,  24, 29.5, 33,    42, 0,    125, 120, 51,
+                              41, 0,  127,  124.5, 57, 30,   0};
 
   std::vector<float> scores = {0.5f, 0.7f, 0.6f, 0.9f, 0.8f};
 
@@ -466,7 +462,6 @@ TEST(UtilsNMSTest, TestNMSGPURotatedAngle0) {
   cuda_context.FinishDeviceComputation();
 }
 
-#if defined(CV_MAJOR_VERSION) && (CV_MAJOR_VERSION >= 3)
 TEST(UtilsNMSTest, TestPerfRotatedNMS) {
   if (!HasCudaGPU())
     return;
@@ -678,6 +673,5 @@ TEST(UtilsNMSTest, GPUEqualsCPURotatedCorrectnessTest) {
     }
   }
 }
-#endif // CV_MAJOR_VERSION >= 3
 
 } // namespace caffe2
index b7da35b..8e8b5f1 100644 (file)
@@ -212,7 +212,6 @@ TEST(UtilsNMSTest, TestSoftNMS) {
   }
 }
 
-#if defined(CV_MAJOR_VERSION) && (CV_MAJOR_VERSION >= 3)
 TEST(UtilsNMSTest, TestNMSRotatedAngle0) {
   // Same inputs as TestNMS, but in RRPN format with angle 0 for testing
   // nms_cpu_rotated
@@ -389,6 +388,42 @@ TEST(UtilsNMSTest, TestSoftNMSRotatedAngle0) {
 
 TEST(UtilsNMSTest, RotatedBBoxOverlaps) {
   {
+    // One box is fully within another box, the angle is irrelavant
+    int M = 2, N = 3;
+    Eigen::ArrayXXf boxes(M, 5);
+    for (int i = 0; i < M; i++) {
+      boxes.row(i) << 0, 0, 5, 6, (360.0 / M - 180.0);
+    }
+
+    Eigen::ArrayXXf query_boxes(N, 5);
+    for (int i = 0; i < N; i++) {
+      query_boxes.row(i) << 0, 0, 3, 3, (360.0 / M - 180.0);
+    }
+
+    Eigen::ArrayXXf expected(M, N);
+    // 0.3 == (3 * 3) / (5 * 6)
+    expected.fill(0.3);
+
+    auto actual = utils::bbox_overlaps_rotated(boxes, query_boxes);
+    EXPECT_TRUE(((expected - actual).abs() < 1e-6).all());
+  }
+
+  {
+    // Angle 0
+    Eigen::ArrayXXf boxes(1, 5);
+    boxes << 39.500000, 50.451096, 80.000000, 18.097809, -0.000000;
+
+    Eigen::ArrayXXf query_boxes(1, 5);
+    query_boxes << 39.120628, 41.014862, 79.241257, 36.427757, -0.000000;
+
+    Eigen::ArrayXXf expected(1, 1);
+    expected << 0.48346716237;
+
+    auto actual = utils::bbox_overlaps_rotated(boxes, query_boxes);
+    EXPECT_TRUE(((expected - actual).abs() < 1e-6).all());
+  }
+
+  {
     // Simple case with angle 0 (upright boxes)
     Eigen::ArrayXXf boxes(2, 5);
     boxes << 10.5, 15.5, 21, 31, 0, 14.0, 17, 4, 10, 0;
@@ -436,6 +471,5 @@ TEST(UtilsNMSTest, RotatedBBoxOverlaps) {
     EXPECT_TRUE(((expected - actual).abs() < 1e-6).all());
   }
 }
-#endif // CV_MAJOR_VERSION >= 3
 
 } // namespace caffe2