From 6052d04100cff9bddcdd5a1444405f5daf7b3b74 Mon Sep 17 00:00:00 2001 From: Jing Huang Date: Fri, 22 Mar 2019 18:12:27 -0700 Subject: [PATCH] Implement rotated generate_proposals_op without opencv dependency (1.8x faster) (#18010) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18010 [C2] Implement rotated generate_proposals_op without opencv dependency. Reviewed By: newstzpz Differential Revision: D14446895 fbshipit-source-id: 847f2443e645f8cae1327dfbaa111c48875ca9be --- caffe2/operators/generate_proposals_op_gpu_test.cc | 2 - caffe2/operators/generate_proposals_op_test.cc | 2 - caffe2/operators/generate_proposals_op_util_nms.h | 358 +++++++++------------ .../generate_proposals_op_util_nms_gpu_test.cc | 2 - .../generate_proposals_op_util_nms_test.cc | 2 - 5 files changed, 156 insertions(+), 210 deletions(-) diff --git a/caffe2/operators/generate_proposals_op_gpu_test.cc b/caffe2/operators/generate_proposals_op_gpu_test.cc index da3f56a..21d0bf0 100644 --- a/caffe2/operators/generate_proposals_op_gpu_test.cc +++ b/caffe2/operators/generate_proposals_op_gpu_test.cc @@ -234,7 +234,6 @@ TEST(GenerateProposalsTest, TestRealDownSampledGPU) { 1e-4); } -#if defined(CV_MAJOR_VERSION) && (CV_MAJOR_VERSION >= 3) TEST(GenerateProposalsTest, TestRealDownSampledRotatedAngle0GPU) { // Similar to TestRealDownSampledGPU but for rotated boxes with angle info. if (!HasCudaGPU()) @@ -637,6 +636,5 @@ TEST(GenerateProposalsTest, TestRealDownSampledRotatedGPU) { 0, 1e-4); } -#endif // CV_MAJOR_VERSION >= 3 } // namespace caffe2 diff --git a/caffe2/operators/generate_proposals_op_test.cc b/caffe2/operators/generate_proposals_op_test.cc index eff256d..557c826 100644 --- a/caffe2/operators/generate_proposals_op_test.cc +++ b/caffe2/operators/generate_proposals_op_test.cc @@ -413,7 +413,6 @@ TEST(GenerateProposalsTest, TestRealDownSampled) { 1e-4); } -#if defined(CV_MAJOR_VERSION) && (CV_MAJOR_VERSION >= 3) TEST(GenerateProposalsTest, TestRealDownSampledRotatedAngle0) { // Similar to TestRealDownSampled but for rotated boxes with angle info. const float angle = 0; @@ -721,6 +720,5 @@ TEST(GenerateProposalsTest, TestRealDownSampledRotated) { EXPECT_LE(std::abs(rois_data(i, 5) - expected_angle), 1e-4); } } -#endif // CV_MAJOR_VERSION >= 3 } // namespace caffe2 diff --git a/caffe2/operators/generate_proposals_op_util_nms.h b/caffe2/operators/generate_proposals_op_util_nms.h index b90fea8..d2d83ae 100644 --- a/caffe2/operators/generate_proposals_op_util_nms.h +++ b/caffe2/operators/generate_proposals_op_util_nms.h @@ -169,274 +169,246 @@ std::vector soft_nms_cpu_upright( return keep; } -#if defined(CV_MAJOR_VERSION) && (CV_MAJOR_VERSION >= 3) namespace { +const int INTERSECT_NONE = 0; +const int INTERSECT_PARTIAL = 1; +const int INTERSECT_FULL = 2; + +class RotatedRect { + public: + RotatedRect() {} + RotatedRect( + const Eigen::Vector2f& p_center, + const Eigen::Vector2f& p_size, + float p_angle) + : center(p_center), size(p_size), angle(p_angle) {} + void get_vertices(Eigen::Vector2f* pt) const { + // M_PI / 180. == 0.01745329251 + double _angle = angle * 0.01745329251; + float b = (float)cos(_angle) * 0.5f; + float a = (float)sin(_angle) * 0.5f; + + pt[0].x() = center.x() - a * size.y() - b * size.x(); + pt[0].y() = center.y() + b * size.y() - a * size.x(); + pt[1].x() = center.x() + a * size.y() - b * size.x(); + pt[1].y() = center.y() - b * size.y() - a * size.x(); + pt[2] = 2 * center - pt[0]; + pt[3] = 2 * center - pt[1]; + } + Eigen::Vector2f center; + Eigen::Vector2f size; + float angle; +}; template -cv::RotatedRect bbox_to_rotated_rect(const Eigen::ArrayBase& box) { +RotatedRect bbox_to_rotated_rect(const Eigen::ArrayBase& box) { CAFFE_ENFORCE_EQ(box.size(), 5); // cv::RotatedRect takes angle to mean clockwise rotation, but RRPN bbox // representation means counter-clockwise rotation. - return cv::RotatedRect( - cv::Point2f(box[0], box[1]), cv::Size2f(box[2], box[3]), -box[4]); + return RotatedRect( + Eigen::Vector2f(box[0], box[1]), + Eigen::Vector2f(box[2], box[3]), + -box[4]); +} + +// Eigen doesn't seem to support 2d cross product, so we make one here +float cross_2d(const Eigen::Vector2f& A, const Eigen::Vector2f& B) { + return A.x() * B.y() - B.x() * A.y(); } -// TODO: cvfix_rotatedRectangleIntersection is a replacement function for +// rotated_rect_intersection_pts is a replacement function for // cv::rotatedRectangleIntersection, which has a bug due to float underflow -// When OpenCV version is upgraded to be >= 4.0, -// we can remove this replacement function. // For anyone interested, here're the PRs on OpenCV: // https://github.com/opencv/opencv/issues/12221 // https://github.com/opencv/opencv/pull/12222 -int cvfix_rotatedRectangleIntersection( - const cv::RotatedRect& rect1, - const cv::RotatedRect& rect2, - cv::OutputArray intersectingRegion) { +// Note that we do not check if the number of intersections is <= 8 in this case +int rotated_rect_intersection_pts( + const RotatedRect& rect1, + const RotatedRect& rect2, + Eigen::Vector2f* intersections, + int& num) { // Used to test if two points are the same const float samePointEps = 0.00001f; const float EPS = 1e-14; + num = 0; // number of intersections - cv::Point2f vec1[4], vec2[4]; - cv::Point2f pts1[4], pts2[4]; - - std::vector intersection; + Eigen::Vector2f vec1[4], vec2[4], pts1[4], pts2[4]; - rect1.points(pts1); - rect2.points(pts2); - - int ret = cv::INTERSECT_FULL; + rect1.get_vertices(pts1); + rect2.get_vertices(pts2); // Specical case of rect1 == rect2 - { - bool same = true; + bool same = true; - for (int i = 0; i < 4; i++) { - if (fabs(pts1[i].x - pts2[i].x) > samePointEps || - (fabs(pts1[i].y - pts2[i].y) > samePointEps)) { - same = false; - break; - } + for (int i = 0; i < 4; i++) { + if (fabs(pts1[i].x() - pts2[i].x()) > samePointEps || + (fabs(pts1[i].y() - pts2[i].y()) > samePointEps)) { + same = false; + break; } + } - if (same) { - intersection.resize(4); - - for (int i = 0; i < 4; i++) { - intersection[i] = pts1[i]; - } - - cv::Mat(intersection).copyTo(intersectingRegion); - - return cv::INTERSECT_FULL; + if (same) { + for (int i = 0; i < 4; i++) { + intersections[i] = pts1[i]; } + num = 4; + return INTERSECT_FULL; } // Line vector // A line from p1 to p2 is: p1 + (p2-p1)*t, t=[0,1] for (int i = 0; i < 4; i++) { - vec1[i].x = pts1[(i + 1) % 4].x - pts1[i].x; - vec1[i].y = pts1[(i + 1) % 4].y - pts1[i].y; - - vec2[i].x = pts2[(i + 1) % 4].x - pts2[i].x; - vec2[i].y = pts2[(i + 1) % 4].y - pts2[i].y; + vec1[i] = pts1[(i + 1) % 4] - pts1[i]; + vec2[i] = pts2[(i + 1) % 4] - pts2[i]; } // Line test - test all line combos for intersection for (int i = 0; i < 4; i++) { for (int j = 0; j < 4; j++) { // Solve for 2x2 Ax=b - float x21 = pts2[j].x - pts1[i].x; - float y21 = pts2[j].y - pts1[i].y; - - const auto& l1 = vec1[i]; - const auto& l2 = vec2[j]; // This takes care of parallel lines - float det = l2.x * l1.y - l1.x * l2.y; + float det = cross_2d(vec2[j], vec1[i]); if (std::fabs(det) <= EPS) { continue; } - float t1 = (l2.x * y21 - l2.y * x21) / det; - float t2 = (l1.x * y21 - l1.y * x21) / det; + auto vec12 = pts2[j] - pts1[i]; + + float t1 = cross_2d(vec2[j], vec12) / det; + float t2 = cross_2d(vec1[i], vec12) / det; if (t1 >= 0.0f && t1 <= 1.0f && t2 >= 0.0f && t2 <= 1.0f) { - float xi = pts1[i].x + vec1[i].x * t1; - float yi = pts1[i].y + vec1[i].y * t1; - intersection.push_back(cv::Point2f(xi, yi)); + intersections[num++] = pts1[i] + t1 * vec1[i]; } } } - if (!intersection.empty()) { - ret = cv::INTERSECT_PARTIAL; - } - // Check for vertices from rect1 inside rect2 - for (int i = 0; i < 4; i++) { - // We do a sign test to see which side the point lies. - // If the point all lie on the same sign for all 4 sides of the rect, - // then there's an intersection - int posSign = 0; - int negSign = 0; + { + const auto& AB = vec2[0]; + const auto& DA = vec2[3]; + auto ABdotAB = AB.squaredNorm(); + auto ADdotAD = DA.squaredNorm(); + for (int i = 0; i < 4; i++) { + // assume ABCD is the rectangle, and P is the point to be judged + // P is inside ABCD iff. P's projection on AB lies within AB + // and P's projection on AD lies within AD - float x = pts1[i].x; - float y = pts1[i].y; + auto AP = pts1[i] - pts2[0]; - for (int j = 0; j < 4; j++) { - // line equation: Ax + By + C = 0 - // see which side of the line this point is at - - // float causes underflow! - // Original version: - // float A = -vec2[j].y; - // float B = vec2[j].x; - // float C = -(A * pts2[j].x + B * pts2[j].y); - // float s = A * x + B * y + C; - - double A = -vec2[j].y; - double B = vec2[j].x; - double C = -(A * pts2[j].x + B * pts2[j].y); - double s = A * x + B * y + C; - - if (s >= 0) { - posSign++; - } else { - negSign++; - } - } + auto APdotAB = AP.dot(AB); + auto APdotAD = -AP.dot(DA); - if (posSign == 4 || negSign == 4) { - intersection.push_back(pts1[i]); + if ((APdotAB >= 0) && (APdotAD >= 0) && (APdotAB <= ABdotAB) && + (APdotAD <= ADdotAD)) { + intersections[num++] = pts1[i]; + } } } // Reverse the check - check for vertices from rect2 inside rect1 - for (int i = 0; i < 4; i++) { - // We do a sign test to see which side the point lies. - // If the point all lie on the same sign for all 4 sides of the rect, - // then there's an intersection - int posSign = 0; - int negSign = 0; - - float x = pts2[i].x; - float y = pts2[i].y; - - for (int j = 0; j < 4; j++) { - // line equation: Ax + By + C = 0 - // see which side of the line this point is at - - // float causes underflow! - // Original version: - // float A = -vec1[j].y; - // float B = vec1[j].x; - // float C = -(A * pts1[j].x + B * pts1[j].y); - // float s = A*x + B*y + C; - - double A = -vec1[j].y; - double B = vec1[j].x; - double C = -(A * pts1[j].x + B * pts1[j].y); - double s = A * x + B * y + C; - - if (s >= 0) { - posSign++; - } else { - negSign++; - } - } + { + const auto& AB = vec1[0]; + const auto& DA = vec1[3]; + auto ABdotAB = AB.squaredNorm(); + auto ADdotAD = DA.squaredNorm(); + for (int i = 0; i < 4; i++) { + auto AP = pts2[i] - pts1[0]; - if (posSign == 4 || negSign == 4) { - intersection.push_back(pts2[i]); - } - } + auto APdotAB = AP.dot(AB); + auto APdotAD = -AP.dot(DA); - // Get rid of dupes - for (int i = 0; i < (int)intersection.size() - 1; i++) { - for (size_t j = i + 1; j < intersection.size(); j++) { - float dx = intersection[i].x - intersection[j].x; - float dy = intersection[i].y - intersection[j].y; - // can be a really small number, need double here - double d2 = dx * dx + dy * dy; - - if (d2 < samePointEps * samePointEps) { - // Found a dupe, remove it - std::swap(intersection[j], intersection.back()); - intersection.pop_back(); - j--; // restart check + if ((APdotAB >= 0) && (APdotAD >= 0) && (APdotAB <= ABdotAB) && + (APdotAD <= ADdotAD)) { + intersections[num++] = pts2[i]; } } } - if (intersection.empty()) { - return cv::INTERSECT_NONE; - } + return num ? INTERSECT_PARTIAL : INTERSECT_NONE; +} - // If this check fails then it means we're getting dupes - // CV_Assert(intersection.size() <= 8); - - // At this point, there might still be some edge cases failing the check above - // However, it doesn't affect the result of polygon area, - // even if the number of intersections is greater than 8. - // Therefore, we just print out these cases for now instead of assertion. - // TODO: These cases should provide good reference for improving the accuracy - // for intersection computation above (for example, we should use - // cross-product/dot-product of vectors instead of line equation to - // judge the relationships between the points and line segments) - - if (intersection.size() > 8) { - LOG(ERROR) << "Intersection size = " << intersection.size(); - LOG(ERROR) << "Rect 1:"; - for (int i = 0; i < 4; i++) { - LOG(ERROR) << " (" << pts1[i].x << " ," << pts1[i].y << "),"; +// Compute convex hull using Graham scan algorithm +int convex_hull_graham( + const Eigen::Vector2f* p, + const int& num_in, + Eigen::Vector2f* q) { + int m = 0; + std::vector order; + + // start from point with minimum y + int t = 0; + for (int i = 0; i < num_in; i++) { + if (p[i].y() < p[t].y()) { + t = i; } - LOG(ERROR) << "Rect 2:"; - for (int i = 0; i < 4; i++) { - LOG(ERROR) << " (" << pts2[i].x << " ," << pts2[i].y << "),"; + } + order.push_back(t); + do { + t = order[m] + 1; + if (t >= num_in) { + t = 0; } - LOG(ERROR) << "Intersections:"; - for (auto& p : intersection) { - LOG(ERROR) << " (" << p.x << " ," << p.y << "),"; + for (int i = 0; i < num_in; i++) { + if (cross_2d(p[i] - p[order[m]], p[t] - p[order[m]]) < 0) { + t = i; + } } - } - - cv::Mat(intersection).copyTo(intersectingRegion); + m++; + order.push_back(t); + } while (order[m] != order[0]); + // ignore the ending point (which is the same as the starting point) + for (int i = 0; i < m; i++) + q[i] = p[order[i]]; + return m; +} - return ret; +double polygon_area(const Eigen::Vector2f* q, const int& m) { + if (m <= 2) + return 0; + double area = 0; + for (int i = 1; i < m - 1; i++) + area += fabs(cross_2d(q[i] - q[0], q[i + 1] - q[0])); + return area / 2.0; } /** * Returns the intersection area of two rotated rectangles. */ double rotated_rect_intersection( - const cv::RotatedRect& rect1, - const cv::RotatedRect& rect2) { - std::vector intersectPts, orderedPts; + const RotatedRect& rect1, + const RotatedRect& rect2) { + // There are up to 16 intersections returned from + // rotated_rect_intersection_pts + Eigen::Vector2f intersectPts[16], orderedPts[16]; + int num = 0; // number of intersections // Find points of intersection - // TODO: cvfix_rotatedRectangleIntersection is a replacement function for + // TODO: rotated_rect_intersection_pts is a replacement function for // cv::rotatedRectangleIntersection, which has a bug due to float underflow - // When OpenCV version is upgraded to be >= 4.0, - // we can remove this replacement function and use the following instead: - // auto ret = cv::rotatedRectangleIntersection(rect1, rect2, intersectPts); // For anyone interested, here're the PRs on OpenCV: // https://github.com/opencv/opencv/issues/12221 // https://github.com/opencv/opencv/pull/12222 - auto ret = cvfix_rotatedRectangleIntersection(rect1, rect2, intersectPts); - if (intersectPts.size() <= 2) { + // Note: it doesn't matter if #intersections is greater than 8 here + auto ret = rotated_rect_intersection_pts(rect1, rect2, intersectPts, num); + CAFFE_ENFORCE(num <= 16); + if (num <= 2) return 0.0; - } // If one rectangle is fully enclosed within another, return the area // of the smaller one early. - if (ret == cv::INTERSECT_FULL) { - return std::min(rect1.size.area(), rect2.size.area()); + if (ret == INTERSECT_FULL) { + return std::min( + rect1.size.x() * rect1.size.y(), rect2.size.x() * rect2.size.y()); } // Convex Hull to order the intersection points in clockwise or // counter-clockwise order and find the countour area. - cv::convexHull(intersectPts, orderedPts); - return cv::contourArea(orderedPts); + int num_convex = convex_hull_graham(intersectPts, num, orderedPts); + return polygon_area(orderedPts, num_convex); } } // namespace @@ -507,7 +479,7 @@ std::vector nms_cpu_rotated( auto heights = proposals.col(3); EArrX areas = widths * heights; - std::vector rotated_rects(proposals.rows()); + std::vector rotated_rects(proposals.rows()); for (int i = 0; i < proposals.rows(); ++i) { rotated_rects[i] = bbox_to_rotated_rect(proposals.row(i)); } @@ -568,7 +540,7 @@ std::vector soft_nms_cpu_rotated( auto heights = proposals.col(3); EArrX areas = widths * heights; - std::vector rotated_rects(proposals.rows()); + std::vector rotated_rects(proposals.rows()); for (int i = 0; i < proposals.rows(); ++i) { rotated_rects[i] = bbox_to_rotated_rect(proposals.row(i)); } @@ -627,7 +599,6 @@ std::vector soft_nms_cpu_rotated( return keep; } -#endif // CV_MAJOR_VERSION >= 3 template std::vector nms_cpu( @@ -636,7 +607,6 @@ std::vector nms_cpu( const std::vector& sorted_indices, float thresh, int topN = -1) { -#if defined(CV_MAJOR_VERSION) && (CV_MAJOR_VERSION >= 3) CAFFE_ENFORCE(proposals.cols() == 4 || proposals.cols() == 5); if (proposals.cols() == 4) { // Upright boxes @@ -645,9 +615,6 @@ std::vector nms_cpu( // Rotated boxes with angle info return nms_cpu_rotated(proposals, scores, sorted_indices, thresh, topN); } -#else - return nms_cpu_upright(proposals, scores, sorted_indices, thresh, topN); -#endif // CV_MAJOR_VERSION >= 3 } // Greedy non-maximum suppression for proposed bounding boxes @@ -686,7 +653,6 @@ std::vector soft_nms_cpu( float score_thresh = 0.001, unsigned int method = 1, int topN = -1) { -#if defined(CV_MAJOR_VERSION) && (CV_MAJOR_VERSION >= 3) CAFFE_ENFORCE(proposals.cols() == 4 || proposals.cols() == 5); if (proposals.cols() == 4) { // Upright boxes @@ -713,18 +679,6 @@ std::vector soft_nms_cpu( method, topN); } -#else - return soft_nms_cpu_upright( - out_scores, - proposals, - scores, - indices, - sigma, - overlap_thresh, - score_thresh, - method, - topN); -#endif // CV_MAJOR_VERSION >= 3 } template diff --git a/caffe2/operators/generate_proposals_op_util_nms_gpu_test.cc b/caffe2/operators/generate_proposals_op_util_nms_gpu_test.cc index 8999fda..762c111 100644 --- a/caffe2/operators/generate_proposals_op_util_nms_gpu_test.cc +++ b/caffe2/operators/generate_proposals_op_util_nms_gpu_test.cc @@ -466,7 +466,6 @@ TEST(UtilsNMSTest, TestNMSGPURotatedAngle0) { cuda_context.FinishDeviceComputation(); } -#if defined(CV_MAJOR_VERSION) && (CV_MAJOR_VERSION >= 3) TEST(UtilsNMSTest, TestPerfRotatedNMS) { if (!HasCudaGPU()) return; @@ -678,6 +677,5 @@ TEST(UtilsNMSTest, GPUEqualsCPURotatedCorrectnessTest) { } } } -#endif // CV_MAJOR_VERSION >= 3 } // namespace caffe2 diff --git a/caffe2/operators/generate_proposals_op_util_nms_test.cc b/caffe2/operators/generate_proposals_op_util_nms_test.cc index b7da35b..a2b9dbe 100644 --- a/caffe2/operators/generate_proposals_op_util_nms_test.cc +++ b/caffe2/operators/generate_proposals_op_util_nms_test.cc @@ -212,7 +212,6 @@ TEST(UtilsNMSTest, TestSoftNMS) { } } -#if defined(CV_MAJOR_VERSION) && (CV_MAJOR_VERSION >= 3) TEST(UtilsNMSTest, TestNMSRotatedAngle0) { // Same inputs as TestNMS, but in RRPN format with angle 0 for testing // nms_cpu_rotated @@ -436,6 +435,5 @@ TEST(UtilsNMSTest, RotatedBBoxOverlaps) { EXPECT_TRUE(((expected - actual).abs() < 1e-6).all()); } } -#endif // CV_MAJOR_VERSION >= 3 } // namespace caffe2 -- 2.7.4