From f4e35d30ed32d9d948d002eb6bc9e235fb433fb7 Mon Sep 17 00:00:00 2001 From: Yanghan Wang Date: Fri, 29 Mar 2019 13:31:45 -0700 Subject: [PATCH] register BoxWithNMSLimit with C10 Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17956 Reviewed By: houseroad Differential Revision: D14417300 fbshipit-source-id: eb5e2ba84513b3b7bfa509dc442424b13fe9148f --- caffe2/operators/box_with_nms_limit_op.cc | 25 +++ caffe2/operators/box_with_nms_limit_op.h | 2 + .../python/operator_test/torch_integration_test.py | 209 ++++++++++++++++----- 3 files changed, 192 insertions(+), 44 deletions(-) diff --git a/caffe2/operators/box_with_nms_limit_op.cc b/caffe2/operators/box_with_nms_limit_op.cc index 18646b4..915bf43 100644 --- a/caffe2/operators/box_with_nms_limit_op.cc +++ b/caffe2/operators/box_with_nms_limit_op.cc @@ -295,3 +295,28 @@ SHOULD_NOT_DO_GRADIENT(BoxWithNMSLimit); } // namespace } // namespace caffe2 + +C10_REGISTER_CAFFE2_OPERATOR_CPU( + BoxWithNMSLimit, + (std::vector{ + c10::Argument("scores"), + c10::Argument("boxes"), + c10::Argument("batch_splits"), + c10::Argument("score_thresh", FloatType::get()), + c10::Argument("nms", FloatType::get()), + c10::Argument("detections_per_im", IntType::get()), + c10::Argument("soft_nms_enabled", BoolType::get()), + c10::Argument("soft_nms_method", StringType::get()), + c10::Argument("soft_nms_sigma", FloatType::get()), + c10::Argument("soft_nms_min_score_thres", FloatType::get()), + c10::Argument("rotated", BoolType::get()), + }), + (std::vector{ + c10::Argument("scores"), + c10::Argument("boxes"), + c10::Argument("classes"), + c10::Argument("batch_splits"), + // c10::Argument("keeps"), + // c10::Argument("keeps_size"), + }), + caffe2::BoxWithNMSLimitOp); diff --git a/caffe2/operators/box_with_nms_limit_op.h b/caffe2/operators/box_with_nms_limit_op.h index 722fe2c..93cbbed 100644 --- a/caffe2/operators/box_with_nms_limit_op.h +++ b/caffe2/operators/box_with_nms_limit_op.h @@ -6,6 +6,8 @@ #include "caffe2/core/context.h" #include "caffe2/core/operator.h" +C10_DECLARE_CAFFE2_OPERATOR(BoxWithNMSLimit) + namespace caffe2 { // C++ implementation of function insert_box_results_with_nms_and_limit() diff --git a/caffe2/python/operator_test/torch_integration_test.py b/caffe2/python/operator_test/torch_integration_test.py index c9aa64d..d8ce5b6 100644 --- a/caffe2/python/operator_test/torch_integration_test.py +++ b/caffe2/python/operator_test/torch_integration_test.py @@ -1,7 +1,4 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -from __future__ import unicode_literals +from __future__ import absolute_import, division, print_function, unicode_literals from caffe2.python import core, workspace import torch @@ -44,16 +41,33 @@ def generate_rois_rotated(roi_counts, im_dims): # [batch_id, ctr_x, ctr_y, w, h, angle] rotated_rois = np.empty((rois.shape[0], 6)).astype(np.float32) rotated_rois[:, 0] = rois[:, 0] # batch_id - rotated_rois[:, 1] = (rois[:, 1] + rois[:, 3]) / 2. # ctr_x = (x1 + x2) / 2 - rotated_rois[:, 2] = (rois[:, 2] + rois[:, 4]) / 2. # ctr_y = (y1 + y2) / 2 + rotated_rois[:, 1] = (rois[:, 1] + rois[:, 3]) / 2.0 # ctr_x = (x1 + x2) / 2 + rotated_rois[:, 2] = (rois[:, 2] + rois[:, 4]) / 2.0 # ctr_y = (y1 + y2) / 2 rotated_rois[:, 3] = rois[:, 3] - rois[:, 1] + 1.0 # w = x2 - x1 + 1 rotated_rois[:, 4] = rois[:, 4] - rois[:, 2] + 1.0 # h = y2 - y1 + 1 rotated_rois[:, 5] = np.random.uniform(-90.0, 90.0) # angle in degrees return rotated_rois -class TorchIntegration(hu.HypothesisTestCase): +def create_bbox_transform_inputs(roi_counts, num_classes, rotated): + batch_size = len(roi_counts) + total_rois = sum(roi_counts) + im_dims = np.random.randint(100, 600, batch_size) + rois = ( + generate_rois_rotated(roi_counts, im_dims) + if rotated + else generate_rois(roi_counts, im_dims) + ) + box_dim = 5 if rotated else 4 + deltas = np.random.randn(total_rois, box_dim * num_classes).astype(np.float32) + im_info = np.zeros((batch_size, 3)).astype(np.float32) + im_info[:, 0] = im_dims + im_info[:, 1] = im_dims + im_info[:, 2] = 1.0 + return rois, deltas, im_info + +class TorchIntegration(hu.HypothesisTestCase): @given( roi_counts=st.lists(st.integers(0, 5), min_size=1, max_size=10), num_classes=st.integers(1, 10), @@ -75,20 +89,9 @@ class TorchIntegration(hu.HypothesisTestCase): """ Test with rois for multiple images in a batch """ - batch_size = len(roi_counts) - total_rois = sum(roi_counts) - im_dims = np.random.randint(100, 600, batch_size) - rois = ( - generate_rois_rotated(roi_counts, im_dims) - if rotated - else generate_rois(roi_counts, im_dims) + rois, deltas, im_info = create_bbox_transform_inputs( + roi_counts, num_classes, rotated ) - box_dim = 5 if rotated else 4 - deltas = np.random.randn(total_rois, box_dim * num_classes).astype(np.float32) - im_info = np.zeros((batch_size, 3)).astype(np.float32) - im_info[:, 0] = im_dims - im_info[:, 1] = im_dims - im_info[:, 2] = 1.0 def bbox_transform_ref(): ref_op = core.CreateOperator( @@ -108,13 +111,101 @@ class TorchIntegration(hu.HypothesisTestCase): box_out = torch.tensor(bbox_transform_ref()) a, b = torch.ops._caffe2.BBoxTransform( - torch.tensor(rois), torch.tensor(deltas), + torch.tensor(rois), + torch.tensor(deltas), + torch.tensor(im_info), + [1.0, 1.0, 1.0, 1.0], + False, + rotated, + angle_bound_on, + -90, + 90, + clip_angle_thresh, + ) + + torch.testing.assert_allclose(box_out, a) + + @given( + roi_counts=st.lists(st.integers(0, 5), min_size=1, max_size=10), + num_classes=st.integers(1, 10), + rotated=st.booleans(), + angle_bound_on=st.booleans(), + clip_angle_thresh=st.sampled_from([-1.0, 1.0]), + **hu.gcs_cpu_only + ) + def test_box_with_nms_limits( + self, + roi_counts, + num_classes, + rotated, + angle_bound_on, + clip_angle_thresh, + gc, + dc, + ): + rotated = False # FIXME remove this after rotation is supported + rois, deltas, im_info = create_bbox_transform_inputs( + roi_counts, num_classes, rotated + ) + pred_bbox, batch_splits = [ + t.detach().numpy() + for t in torch.ops._caffe2.BBoxTransform( + torch.tensor(rois), + torch.tensor(deltas), torch.tensor(im_info), [1.0, 1.0, 1.0, 1.0], - False, rotated, angle_bound_on, - -90, 90, clip_angle_thresh) + False, + rotated, + angle_bound_on, + -90, + 90, + clip_angle_thresh, + ) + ] + class_prob = np.random.randn(sum(roi_counts), num_classes).astype(np.float32) + score_thresh = 0.5 + nms_thresh = 0.5 + topk_per_image = sum(roi_counts) / 2 - torch.testing.assert_allclose(box_out, a) + def box_with_nms_limit_ref(): + input_blobs = ["class_prob", "pred_bbox", "batch_splits"] + output_blobs = ["score_nms", "bbox_nms", "class_nms", "batch_splits_nms"] + ref_op = core.CreateOperator( + "BoxWithNMSLimit", + input_blobs, + output_blobs, + score_thresh=float(score_thresh), + nms=float(nms_thresh), + detections_per_im=int(topk_per_image), + soft_nms_enabled=False, + soft_nms_method="linear", + soft_nms_sigma=0.5, + soft_nms_min_score_thres=0.001, + rotated=rotated, + ) + workspace.FeedBlob("class_prob", class_prob) + workspace.FeedBlob("pred_bbox", pred_bbox) + workspace.FeedBlob("batch_splits", batch_splits) + workspace.RunOperatorOnce(ref_op) + return (workspace.FetchBlob(b) for b in output_blobs) + + output_refs = box_with_nms_limit_ref() + outputs = torch.ops._caffe2.BoxWithNMSLimit( + torch.tensor(class_prob), + torch.tensor(pred_bbox), + torch.tensor(batch_splits), + score_thresh=float(score_thresh), + nms=float(nms_thresh), + detections_per_im=int(topk_per_image), + soft_nms_enabled=False, + soft_nms_method="linear", + soft_nms_sigma=0.5, + soft_nms_min_score_thres=0.001, + rotated=rotated, + ) + + for o, o_ref in zip(outputs, output_refs): + torch.testing.assert_allclose(o, o_ref) @given( A=st.integers(min_value=4, max_value=4), @@ -124,8 +215,11 @@ class TorchIntegration(hu.HypothesisTestCase): ) def test_generate_proposals(self, A, H, W, img_count): scores = np.ones((img_count, A, H, W)).astype(np.float32) - bbox_deltas = np.linspace(0, 10, num=img_count*4*A*H*W).reshape( - (img_count, 4*A, H, W)).astype(np.float32) + bbox_deltas = ( + np.linspace(0, 10, num=img_count * 4 * A * H * W) + .reshape((img_count, 4 * A, H, W)) + .astype(np.float32) + ) im_info = np.ones((img_count, 3)).astype(np.float32) / 10 anchors = np.ones((A, 4)).astype(np.float32) @@ -147,9 +241,20 @@ class TorchIntegration(hu.HypothesisTestCase): rois = torch.tensor(rois) rois_probs = torch.tensor(rois_probs) a, b = torch.ops._caffe2.GenerateProposals( - torch.tensor(scores), torch.tensor(bbox_deltas), - torch.tensor(im_info), torch.tensor(anchors), - 2.0, 6000, 300, 0.7, 16, True, -90, 90, 1.0) + torch.tensor(scores), + torch.tensor(bbox_deltas), + torch.tensor(im_info), + torch.tensor(anchors), + 2.0, + 6000, + 300, + 0.7, + 16, + True, + -90, + 90, + 1.0, + ) torch.testing.assert_allclose(rois, a) torch.testing.assert_allclose(rois_probs, b) @@ -241,11 +346,14 @@ class TorchIntegration(hu.HypothesisTestCase): H=st.integers(min_value=10, max_value=10), W=st.integers(min_value=8, max_value=8), img_count=st.integers(min_value=3, max_value=3), - ) + ) def test_generate_proposals_cuda(self, A, H, W, img_count): scores = np.ones((img_count, A, H, W)).astype(np.float32) - bbox_deltas = np.linspace(0, 10, num=img_count*4*A*H*W).reshape( - (img_count, 4*A, H, W)).astype(np.float32) + bbox_deltas = ( + np.linspace(0, 10, num=img_count * 4 * A * H * W) + .reshape((img_count, 4 * A, H, W)) + .astype(np.float32) + ) im_info = np.ones((img_count, 3)).astype(np.float32) / 10 anchors = np.ones((A, 4)).astype(np.float32) @@ -267,9 +375,20 @@ class TorchIntegration(hu.HypothesisTestCase): rois = torch.tensor(rois) rois_probs = torch.tensor(rois_probs) a, b = torch.ops._caffe2.GenerateProposals( - torch.tensor(scores).cuda(), torch.tensor(bbox_deltas).cuda(), - torch.tensor(im_info).cuda(), torch.tensor(anchors).cuda(), - 2.0, 6000, 300, 0.7, 16, True, -90, 90, 1.0) + torch.tensor(scores).cuda(), + torch.tensor(bbox_deltas).cuda(), + torch.tensor(im_info).cuda(), + torch.tensor(anchors).cuda(), + 2.0, + 6000, + 300, + 0.7, + 16, + True, + -90, + 90, + 1.0, + ) torch.testing.assert_allclose(rois, a.cpu()) torch.testing.assert_allclose(rois_probs, b.cpu()) @@ -281,13 +400,15 @@ class TorchIntegration(hu.HypothesisTestCase): ) def _test_roi_align(self, N, C, H, W, device): def rand_roi(): - return np.array([ - float(int(N * np.random.rand())), - 0.5 * np.random.rand() * W, - 0.5 * np.random.rand() * H, - (0.5 + 0.5 * np.random.rand()) * W, - (0.5 + 0.5 * np.random.rand()) * H, - ]).astype(np.float32) + return np.array( + [ + float(int(N * np.random.rand())), + 0.5 * np.random.rand() * W, + 0.5 * np.random.rand() * H, + (0.5 + 0.5 * np.random.rand()) * W, + (0.5 + 0.5 * np.random.rand()) * H, + ] + ).astype(np.float32) feature = np.random.randn(N, C, H, W).astype(np.float32) rois = np.array([rand_roi() for _ in range(10)]) @@ -300,7 +421,7 @@ class TorchIntegration(hu.HypothesisTestCase): spatial_scale=1.0, pooled_h=3, pooled_w=3, - sampling_ratio=0 + sampling_ratio=0, ) workspace.FeedBlob("feature", _feature) workspace.FeedBlob("rois", _rois) @@ -315,7 +436,7 @@ class TorchIntegration(hu.HypothesisTestCase): spatial_scale=1.0, pooled_h=3, pooled_w=3, - sampling_ratio=0 + sampling_ratio=0, ) torch.testing.assert_allclose(roi_feature_ref, roi_feature.cpu()) -- 2.7.4