From: Jing Huang Date: Thu, 28 Mar 2019 23:58:54 +0000 (-0700) Subject: Implement rotated generate_proposals_op without opencv dependency (CPU version) X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~575 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=11ac0cf276d804ec0b1eb07bcb10eaa988282bcf;p=platform%2Fupstream%2Fpytorch.git Implement rotated generate_proposals_op without opencv dependency (CPU version) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18533 Reviewed By: ezyang Differential Revision: D14648083 fbshipit-source-id: e53e8f537100862f8015c4efa4efe4d387cef551 --- diff --git a/caffe2/operators/generate_proposals_op_test.cc b/caffe2/operators/generate_proposals_op_test.cc index eff256d..f79cf68 100644 --- a/caffe2/operators/generate_proposals_op_test.cc +++ b/caffe2/operators/generate_proposals_op_test.cc @@ -413,7 +413,6 @@ TEST(GenerateProposalsTest, TestRealDownSampled) { 1e-4); } -#if defined(CV_MAJOR_VERSION) && (CV_MAJOR_VERSION >= 3) TEST(GenerateProposalsTest, TestRealDownSampledRotatedAngle0) { // Similar to TestRealDownSampled but for rotated boxes with angle info. const float angle = 0; @@ -522,7 +521,7 @@ TEST(GenerateProposalsTest, TestRealDownSampledRotatedAngle0) { ERMatXf rois_gt(rois_gt_xyxy.rows(), 6); // Batch ID rois_gt.block(0, 0, rois_gt.rows(), 1) = - rois_gt_xyxy.block(0, 0, rois_gt.rows(), 0); + rois_gt_xyxy.block(0, 0, rois_gt.rows(), 1); // rois_gt in [x_ctr, y_ctr, w, h] format rois_gt.block(0, 1, rois_gt.rows(), 4) = utils::bbox_xyxy_to_ctrwh( rois_gt_xyxy.block(0, 1, rois_gt.rows(), 4).array()); @@ -721,6 +720,5 @@ TEST(GenerateProposalsTest, TestRealDownSampledRotated) { EXPECT_LE(std::abs(rois_data(i, 5) - expected_angle), 1e-4); } } -#endif // CV_MAJOR_VERSION >= 3 } // namespace caffe2 diff --git a/caffe2/operators/generate_proposals_op_util_nms.h b/caffe2/operators/generate_proposals_op_util_nms.h index b90fea8..8c5234e 100644 --- a/caffe2/operators/generate_proposals_op_util_nms.h +++ b/caffe2/operators/generate_proposals_op_util_nms.h @@ -169,274 +169,296 @@ std::vector soft_nms_cpu_upright( return keep; } -#if defined(CV_MAJOR_VERSION) && (CV_MAJOR_VERSION >= 3) namespace { +const int INTERSECT_NONE = 0; +const int INTERSECT_PARTIAL = 1; +const int INTERSECT_FULL = 2; + +class RotatedRect { + public: + RotatedRect() {} + RotatedRect( + const Eigen::Vector2f& p_center, + const Eigen::Vector2f& p_size, + float p_angle) + : center(p_center), size(p_size), angle(p_angle) {} + void get_vertices(Eigen::Vector2f* pt) const { + // M_PI / 180. == 0.01745329251 + double _angle = angle * 0.01745329251; + float b = (float)cos(_angle) * 0.5f; + float a = (float)sin(_angle) * 0.5f; + + pt[0].x() = center.x() - a * size.y() - b * size.x(); + pt[0].y() = center.y() + b * size.y() - a * size.x(); + pt[1].x() = center.x() + a * size.y() - b * size.x(); + pt[1].y() = center.y() - b * size.y() - a * size.x(); + pt[2] = 2 * center - pt[0]; + pt[3] = 2 * center - pt[1]; + } + Eigen::Vector2f center; + Eigen::Vector2f size; + float angle; +}; template -cv::RotatedRect bbox_to_rotated_rect(const Eigen::ArrayBase& box) { +RotatedRect bbox_to_rotated_rect(const Eigen::ArrayBase& box) { CAFFE_ENFORCE_EQ(box.size(), 5); // cv::RotatedRect takes angle to mean clockwise rotation, but RRPN bbox // representation means counter-clockwise rotation. - return cv::RotatedRect( - cv::Point2f(box[0], box[1]), cv::Size2f(box[2], box[3]), -box[4]); + return RotatedRect( + Eigen::Vector2f(box[0], box[1]), + Eigen::Vector2f(box[2], box[3]), + -box[4]); +} + +// Eigen doesn't seem to support 2d cross product, so we make one here +float cross_2d(const Eigen::Vector2f& A, const Eigen::Vector2f& B) { + return A.x() * B.y() - B.x() * A.y(); } -// TODO: cvfix_rotatedRectangleIntersection is a replacement function for +// rotated_rect_intersection_pts is a replacement function for // cv::rotatedRectangleIntersection, which has a bug due to float underflow -// When OpenCV version is upgraded to be >= 4.0, -// we can remove this replacement function. // For anyone interested, here're the PRs on OpenCV: // https://github.com/opencv/opencv/issues/12221 // https://github.com/opencv/opencv/pull/12222 -int cvfix_rotatedRectangleIntersection( - const cv::RotatedRect& rect1, - const cv::RotatedRect& rect2, - cv::OutputArray intersectingRegion) { +// Note that we do not check if the number of intersections is <= 8 in this case +int rotated_rect_intersection_pts( + const RotatedRect& rect1, + const RotatedRect& rect2, + Eigen::Vector2f* intersections, + int& num) { // Used to test if two points are the same const float samePointEps = 0.00001f; const float EPS = 1e-14; + num = 0; // number of intersections - cv::Point2f vec1[4], vec2[4]; - cv::Point2f pts1[4], pts2[4]; - - std::vector intersection; + Eigen::Vector2f vec1[4], vec2[4], pts1[4], pts2[4]; - rect1.points(pts1); - rect2.points(pts2); - - int ret = cv::INTERSECT_FULL; + rect1.get_vertices(pts1); + rect2.get_vertices(pts2); // Specical case of rect1 == rect2 - { - bool same = true; + bool same = true; - for (int i = 0; i < 4; i++) { - if (fabs(pts1[i].x - pts2[i].x) > samePointEps || - (fabs(pts1[i].y - pts2[i].y) > samePointEps)) { - same = false; - break; - } + for (int i = 0; i < 4; i++) { + if (fabs(pts1[i].x() - pts2[i].x()) > samePointEps || + (fabs(pts1[i].y() - pts2[i].y()) > samePointEps)) { + same = false; + break; } + } - if (same) { - intersection.resize(4); - - for (int i = 0; i < 4; i++) { - intersection[i] = pts1[i]; - } - - cv::Mat(intersection).copyTo(intersectingRegion); - - return cv::INTERSECT_FULL; + if (same) { + for (int i = 0; i < 4; i++) { + intersections[i] = pts1[i]; } + num = 4; + return INTERSECT_FULL; } // Line vector // A line from p1 to p2 is: p1 + (p2-p1)*t, t=[0,1] for (int i = 0; i < 4; i++) { - vec1[i].x = pts1[(i + 1) % 4].x - pts1[i].x; - vec1[i].y = pts1[(i + 1) % 4].y - pts1[i].y; - - vec2[i].x = pts2[(i + 1) % 4].x - pts2[i].x; - vec2[i].y = pts2[(i + 1) % 4].y - pts2[i].y; + vec1[i] = pts1[(i + 1) % 4] - pts1[i]; + vec2[i] = pts2[(i + 1) % 4] - pts2[i]; } // Line test - test all line combos for intersection for (int i = 0; i < 4; i++) { for (int j = 0; j < 4; j++) { // Solve for 2x2 Ax=b - float x21 = pts2[j].x - pts1[i].x; - float y21 = pts2[j].y - pts1[i].y; - - const auto& l1 = vec1[i]; - const auto& l2 = vec2[j]; // This takes care of parallel lines - float det = l2.x * l1.y - l1.x * l2.y; + float det = cross_2d(vec2[j], vec1[i]); if (std::fabs(det) <= EPS) { continue; } - float t1 = (l2.x * y21 - l2.y * x21) / det; - float t2 = (l1.x * y21 - l1.y * x21) / det; + auto vec12 = pts2[j] - pts1[i]; + + float t1 = cross_2d(vec2[j], vec12) / det; + float t2 = cross_2d(vec1[i], vec12) / det; if (t1 >= 0.0f && t1 <= 1.0f && t2 >= 0.0f && t2 <= 1.0f) { - float xi = pts1[i].x + vec1[i].x * t1; - float yi = pts1[i].y + vec1[i].y * t1; - intersection.push_back(cv::Point2f(xi, yi)); + intersections[num++] = pts1[i] + t1 * vec1[i]; } } } - if (!intersection.empty()) { - ret = cv::INTERSECT_PARTIAL; - } - // Check for vertices from rect1 inside rect2 - for (int i = 0; i < 4; i++) { - // We do a sign test to see which side the point lies. - // If the point all lie on the same sign for all 4 sides of the rect, - // then there's an intersection - int posSign = 0; - int negSign = 0; + { + const auto& AB = vec2[0]; + const auto& DA = vec2[3]; + auto ABdotAB = AB.squaredNorm(); + auto ADdotAD = DA.squaredNorm(); + for (int i = 0; i < 4; i++) { + // assume ABCD is the rectangle, and P is the point to be judged + // P is inside ABCD iff. P's projection on AB lies within AB + // and P's projection on AD lies within AD - float x = pts1[i].x; - float y = pts1[i].y; + auto AP = pts1[i] - pts2[0]; - for (int j = 0; j < 4; j++) { - // line equation: Ax + By + C = 0 - // see which side of the line this point is at - - // float causes underflow! - // Original version: - // float A = -vec2[j].y; - // float B = vec2[j].x; - // float C = -(A * pts2[j].x + B * pts2[j].y); - // float s = A * x + B * y + C; - - double A = -vec2[j].y; - double B = vec2[j].x; - double C = -(A * pts2[j].x + B * pts2[j].y); - double s = A * x + B * y + C; - - if (s >= 0) { - posSign++; - } else { - negSign++; - } - } + auto APdotAB = AP.dot(AB); + auto APdotAD = -AP.dot(DA); - if (posSign == 4 || negSign == 4) { - intersection.push_back(pts1[i]); + if ((APdotAB >= 0) && (APdotAD >= 0) && (APdotAB <= ABdotAB) && + (APdotAD <= ADdotAD)) { + intersections[num++] = pts1[i]; + } } } // Reverse the check - check for vertices from rect2 inside rect1 - for (int i = 0; i < 4; i++) { - // We do a sign test to see which side the point lies. - // If the point all lie on the same sign for all 4 sides of the rect, - // then there's an intersection - int posSign = 0; - int negSign = 0; + { + const auto& AB = vec1[0]; + const auto& DA = vec1[3]; + auto ABdotAB = AB.squaredNorm(); + auto ADdotAD = DA.squaredNorm(); + for (int i = 0; i < 4; i++) { + auto AP = pts2[i] - pts1[0]; - float x = pts2[i].x; - float y = pts2[i].y; + auto APdotAB = AP.dot(AB); + auto APdotAD = -AP.dot(DA); - for (int j = 0; j < 4; j++) { - // line equation: Ax + By + C = 0 - // see which side of the line this point is at - - // float causes underflow! - // Original version: - // float A = -vec1[j].y; - // float B = vec1[j].x; - // float C = -(A * pts1[j].x + B * pts1[j].y); - // float s = A*x + B*y + C; - - double A = -vec1[j].y; - double B = vec1[j].x; - double C = -(A * pts1[j].x + B * pts1[j].y); - double s = A * x + B * y + C; - - if (s >= 0) { - posSign++; - } else { - negSign++; + if ((APdotAB >= 0) && (APdotAD >= 0) && (APdotAB <= ABdotAB) && + (APdotAD <= ADdotAD)) { + intersections[num++] = pts2[i]; } } - - if (posSign == 4 || negSign == 4) { - intersection.push_back(pts2[i]); - } } - // Get rid of dupes - for (int i = 0; i < (int)intersection.size() - 1; i++) { - for (size_t j = i + 1; j < intersection.size(); j++) { - float dx = intersection[i].x - intersection[j].x; - float dy = intersection[i].y - intersection[j].y; - // can be a really small number, need double here - double d2 = dx * dx + dy * dy; - - if (d2 < samePointEps * samePointEps) { - // Found a dupe, remove it - std::swap(intersection[j], intersection.back()); - intersection.pop_back(); - j--; // restart check - } + return num ? INTERSECT_PARTIAL : INTERSECT_NONE; +} + +// Compute convex hull using Graham scan algorithm +int convex_hull_graham( + const Eigen::Vector2f* p, + const int& num_in, + Eigen::Vector2f* q, + bool shift_to_zero = false) { + CAFFE_ENFORCE(num_in >= 2); + std::vector order; + + // Step 1: + // Find point with minimum y + // if more than 1 points have the same minimum y, + // pick the one with the mimimum x. + int t = 0; + for (int i = 1; i < num_in; i++) { + if (p[i].y() < p[t].y() || (p[i].y() == p[t].y() && p[i].x() < p[t].x())) { + t = i; } } + auto& s = p[t]; // starting point - if (intersection.empty()) { - return cv::INTERSECT_NONE; + // Step 2: + // Subtract starting point from every points (for sorting in the next step) + for (int i = 0; i < num_in; i++) { + q[i] = p[i] - s; } - // If this check fails then it means we're getting dupes - // CV_Assert(intersection.size() <= 8); - - // At this point, there might still be some edge cases failing the check above - // However, it doesn't affect the result of polygon area, - // even if the number of intersections is greater than 8. - // Therefore, we just print out these cases for now instead of assertion. - // TODO: These cases should provide good reference for improving the accuracy - // for intersection computation above (for example, we should use - // cross-product/dot-product of vectors instead of line equation to - // judge the relationships between the points and line segments) - - if (intersection.size() > 8) { - LOG(ERROR) << "Intersection size = " << intersection.size(); - LOG(ERROR) << "Rect 1:"; - for (int i = 0; i < 4; i++) { - LOG(ERROR) << " (" << pts1[i].x << " ," << pts1[i].y << "),"; - } - LOG(ERROR) << "Rect 2:"; - for (int i = 0; i < 4; i++) { - LOG(ERROR) << " (" << pts2[i].x << " ," << pts2[i].y << "),"; - } - LOG(ERROR) << "Intersections:"; - for (auto& p : intersection) { - LOG(ERROR) << " (" << p.x << " ," << p.y << "),"; + // Swap the starting point to position 0 + std::swap(q[0], q[t]); + + // Step 3: + // Sort point 1 ~ num_in according to their relative cross-product values + // (essentially sorting according to angles) + std::sort( + q + 1, + q + num_in, + [](const Eigen::Vector2f& A, const Eigen::Vector2f& B) -> bool { + float temp = cross_2d(A, B); + if (fabs(temp) < 1e-6) { + return A.squaredNorm() < B.squaredNorm(); + } else { + return temp > 0; + } + }); + + // Step 4: + // Make sure there are at least 2 points (that don't overlap with each other) + // in the stack + int k; // index of the non-overlapped second point + for (k = 1; k < num_in; k++) { + if (q[k].squaredNorm() > 1e-8) + break; + } + if (k == num_in) { + // We reach the end, which means the convex hull is just one point + q[0] = p[t]; + return 1; + } + q[1] = q[k]; + int m = 2; // 2 elements in the stack + // Step 5: + // Finally we can start the scanning process. + // If we find a non-convex relationship between the 3 points, + // we pop the previous point from the stack until the stack only has two + // points, or the 3-point relationship is convex again + for (int i = k + 1; i < num_in; i++) { + while (m > 1 && cross_2d(q[i] - q[m - 2], q[m - 1] - q[m - 2]) >= 0) { + m--; } + q[m++] = q[i]; + } + + // Step 6 (Optional): + // In general sense we need the original coordinates, so we + // need to shift the points back (reverting Step 2) + // But if we're only interested in getting the area/perimeter of the shape + // We can simply return. + if (!shift_to_zero) { + for (int i = 0; i < m; i++) + q[i] += s; } - cv::Mat(intersection).copyTo(intersectingRegion); + return m; +} - return ret; +double polygon_area(const Eigen::Vector2f* q, const int& m) { + if (m <= 2) + return 0; + double area = 0; + for (int i = 1; i < m - 1; i++) + area += fabs(cross_2d(q[i] - q[0], q[i + 1] - q[0])); + return area / 2.0; } /** * Returns the intersection area of two rotated rectangles. */ double rotated_rect_intersection( - const cv::RotatedRect& rect1, - const cv::RotatedRect& rect2) { - std::vector intersectPts, orderedPts; + const RotatedRect& rect1, + const RotatedRect& rect2) { + // There are up to 16 intersections returned from + // rotated_rect_intersection_pts + Eigen::Vector2f intersectPts[16], orderedPts[16]; + int num = 0; // number of intersections // Find points of intersection - // TODO: cvfix_rotatedRectangleIntersection is a replacement function for + // TODO: rotated_rect_intersection_pts is a replacement function for // cv::rotatedRectangleIntersection, which has a bug due to float underflow - // When OpenCV version is upgraded to be >= 4.0, - // we can remove this replacement function and use the following instead: - // auto ret = cv::rotatedRectangleIntersection(rect1, rect2, intersectPts); // For anyone interested, here're the PRs on OpenCV: // https://github.com/opencv/opencv/issues/12221 // https://github.com/opencv/opencv/pull/12222 - auto ret = cvfix_rotatedRectangleIntersection(rect1, rect2, intersectPts); - if (intersectPts.size() <= 2) { + // Note: it doesn't matter if #intersections is greater than 8 here + auto ret = rotated_rect_intersection_pts(rect1, rect2, intersectPts, num); + CAFFE_ENFORCE(num <= 16); + if (num <= 2) return 0.0; - } // If one rectangle is fully enclosed within another, return the area // of the smaller one early. - if (ret == cv::INTERSECT_FULL) { - return std::min(rect1.size.area(), rect2.size.area()); + if (ret == INTERSECT_FULL) { + return std::min( + rect1.size.x() * rect1.size.y(), rect2.size.x() * rect2.size.y()); } // Convex Hull to order the intersection points in clockwise or // counter-clockwise order and find the countour area. - cv::convexHull(intersectPts, orderedPts); - return cv::contourArea(orderedPts); + int num_convex = convex_hull_graham(intersectPts, num, orderedPts, true); + return polygon_area(orderedPts, num_convex); } } // namespace @@ -507,7 +529,7 @@ std::vector nms_cpu_rotated( auto heights = proposals.col(3); EArrX areas = widths * heights; - std::vector rotated_rects(proposals.rows()); + std::vector rotated_rects(proposals.rows()); for (int i = 0; i < proposals.rows(); ++i) { rotated_rects[i] = bbox_to_rotated_rect(proposals.row(i)); } @@ -568,7 +590,7 @@ std::vector soft_nms_cpu_rotated( auto heights = proposals.col(3); EArrX areas = widths * heights; - std::vector rotated_rects(proposals.rows()); + std::vector rotated_rects(proposals.rows()); for (int i = 0; i < proposals.rows(); ++i) { rotated_rects[i] = bbox_to_rotated_rect(proposals.row(i)); } @@ -627,7 +649,6 @@ std::vector soft_nms_cpu_rotated( return keep; } -#endif // CV_MAJOR_VERSION >= 3 template std::vector nms_cpu( @@ -636,7 +657,6 @@ std::vector nms_cpu( const std::vector& sorted_indices, float thresh, int topN = -1) { -#if defined(CV_MAJOR_VERSION) && (CV_MAJOR_VERSION >= 3) CAFFE_ENFORCE(proposals.cols() == 4 || proposals.cols() == 5); if (proposals.cols() == 4) { // Upright boxes @@ -645,9 +665,6 @@ std::vector nms_cpu( // Rotated boxes with angle info return nms_cpu_rotated(proposals, scores, sorted_indices, thresh, topN); } -#else - return nms_cpu_upright(proposals, scores, sorted_indices, thresh, topN); -#endif // CV_MAJOR_VERSION >= 3 } // Greedy non-maximum suppression for proposed bounding boxes @@ -686,7 +703,6 @@ std::vector soft_nms_cpu( float score_thresh = 0.001, unsigned int method = 1, int topN = -1) { -#if defined(CV_MAJOR_VERSION) && (CV_MAJOR_VERSION >= 3) CAFFE_ENFORCE(proposals.cols() == 4 || proposals.cols() == 5); if (proposals.cols() == 4) { // Upright boxes @@ -713,18 +729,6 @@ std::vector soft_nms_cpu( method, topN); } -#else - return soft_nms_cpu_upright( - out_scores, - proposals, - scores, - indices, - sigma, - overlap_thresh, - score_thresh, - method, - topN); -#endif // CV_MAJOR_VERSION >= 3 } template diff --git a/caffe2/operators/generate_proposals_op_util_nms_gpu_test.cc b/caffe2/operators/generate_proposals_op_util_nms_gpu_test.cc index 8999fda..372acca 100644 --- a/caffe2/operators/generate_proposals_op_util_nms_gpu_test.cc +++ b/caffe2/operators/generate_proposals_op_util_nms_gpu_test.cc @@ -379,13 +379,9 @@ TEST(UtilsNMSTest, TestNMSGPURotatedAngle0) { return; const int box_dim = 5; // Same boxes in TestNMS with (x_ctr, y_ctr, w, h, angle) format - std::vector boxes = { - 30, 35, 41, 51, 0, - 29.5, 36, 38, 49, 0, - 24, 29.5, 33, 42, 0, - 125, 120, 51, 41, 0, - 127, 124.5, 57, 30, 0 - }; + std::vector boxes = {30, 35, 41, 51, 0, 29.5, 36, 38, 49, + 0, 24, 29.5, 33, 42, 0, 125, 120, 51, + 41, 0, 127, 124.5, 57, 30, 0}; std::vector scores = {0.5f, 0.7f, 0.6f, 0.9f, 0.8f}; @@ -466,7 +462,6 @@ TEST(UtilsNMSTest, TestNMSGPURotatedAngle0) { cuda_context.FinishDeviceComputation(); } -#if defined(CV_MAJOR_VERSION) && (CV_MAJOR_VERSION >= 3) TEST(UtilsNMSTest, TestPerfRotatedNMS) { if (!HasCudaGPU()) return; @@ -678,6 +673,5 @@ TEST(UtilsNMSTest, GPUEqualsCPURotatedCorrectnessTest) { } } } -#endif // CV_MAJOR_VERSION >= 3 } // namespace caffe2 diff --git a/caffe2/operators/generate_proposals_op_util_nms_test.cc b/caffe2/operators/generate_proposals_op_util_nms_test.cc index b7da35b..8e8b5f1 100644 --- a/caffe2/operators/generate_proposals_op_util_nms_test.cc +++ b/caffe2/operators/generate_proposals_op_util_nms_test.cc @@ -212,7 +212,6 @@ TEST(UtilsNMSTest, TestSoftNMS) { } } -#if defined(CV_MAJOR_VERSION) && (CV_MAJOR_VERSION >= 3) TEST(UtilsNMSTest, TestNMSRotatedAngle0) { // Same inputs as TestNMS, but in RRPN format with angle 0 for testing // nms_cpu_rotated @@ -389,6 +388,42 @@ TEST(UtilsNMSTest, TestSoftNMSRotatedAngle0) { TEST(UtilsNMSTest, RotatedBBoxOverlaps) { { + // One box is fully within another box, the angle is irrelavant + int M = 2, N = 3; + Eigen::ArrayXXf boxes(M, 5); + for (int i = 0; i < M; i++) { + boxes.row(i) << 0, 0, 5, 6, (360.0 / M - 180.0); + } + + Eigen::ArrayXXf query_boxes(N, 5); + for (int i = 0; i < N; i++) { + query_boxes.row(i) << 0, 0, 3, 3, (360.0 / M - 180.0); + } + + Eigen::ArrayXXf expected(M, N); + // 0.3 == (3 * 3) / (5 * 6) + expected.fill(0.3); + + auto actual = utils::bbox_overlaps_rotated(boxes, query_boxes); + EXPECT_TRUE(((expected - actual).abs() < 1e-6).all()); + } + + { + // Angle 0 + Eigen::ArrayXXf boxes(1, 5); + boxes << 39.500000, 50.451096, 80.000000, 18.097809, -0.000000; + + Eigen::ArrayXXf query_boxes(1, 5); + query_boxes << 39.120628, 41.014862, 79.241257, 36.427757, -0.000000; + + Eigen::ArrayXXf expected(1, 1); + expected << 0.48346716237; + + auto actual = utils::bbox_overlaps_rotated(boxes, query_boxes); + EXPECT_TRUE(((expected - actual).abs() < 1e-6).all()); + } + + { // Simple case with angle 0 (upright boxes) Eigen::ArrayXXf boxes(2, 5); boxes << 10.5, 15.5, 21, 31, 0, 14.0, 17, 4, 10, 0; @@ -436,6 +471,5 @@ TEST(UtilsNMSTest, RotatedBBoxOverlaps) { EXPECT_TRUE(((expected - actual).abs() < 1e-6).all()); } } -#endif // CV_MAJOR_VERSION >= 3 } // namespace caffe2