register BoxWithNMSLimit with C10
author     Yanghan Wang <yanghan@instagram.com>
           Fri, 29 Mar 2019 20:31:45 +0000 (13:31 -0700)
committer  Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
           Fri, 29 Mar 2019 20:41:40 +0000 (13:41 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17956

Reviewed By: houseroad

Differential Revision: D14417300

fbshipit-source-id: eb5e2ba84513b3b7bfa509dc442424b13fe9148f

caffe2/operators/box_with_nms_limit_op.cc
caffe2/operators/box_with_nms_limit_op.h
caffe2/python/operator_test/torch_integration_test.py
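
This change registers the Caffe2 BoxWithNMSLimit operator with the c10 dispatcher, which makes it callable from PyTorch as torch.ops._caffe2.BoxWithNMSLimit. Below is a minimal usage sketch mirroring the integration test added in this diff; the tensor shapes and threshold values are illustrative only, and the op must be present in the build (e.g. a Caffe2-enabled PyTorch):

    import torch

    # Illustrative inputs (shapes and values are placeholders, not from this diff).
    num_rois, num_classes = 10, 3
    class_prob = torch.rand(num_rois, num_classes)        # per-class scores
    pred_bbox = torch.rand(num_rois, 4 * num_classes)     # per-class box predictions
    batch_splits = torch.tensor([float(num_rois)])        # rois per image in the batch

    outputs = torch.ops._caffe2.BoxWithNMSLimit(
        class_prob,
        pred_bbox,
        batch_splits,
        score_thresh=0.5,
        nms=0.5,
        detections_per_im=5,
        soft_nms_enabled=False,
        soft_nms_method="linear",
        soft_nms_sigma=0.5,
        soft_nms_min_score_thres=0.001,
        rotated=False,
    )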

diff --git a/caffe2/operators/box_with_nms_limit_op.cc b/caffe2/operators/box_with_nms_limit_op.cc
index 18646b4..915bf43 100644
@@ -295,3 +295,28 @@ SHOULD_NOT_DO_GRADIENT(BoxWithNMSLimit);
 
 } // namespace
 } // namespace caffe2
+
+C10_REGISTER_CAFFE2_OPERATOR_CPU(
+    BoxWithNMSLimit,
+    (std::vector<c10::Argument>{
+        c10::Argument("scores"),
+        c10::Argument("boxes"),
+        c10::Argument("batch_splits"),
+        c10::Argument("score_thresh", FloatType::get()),
+        c10::Argument("nms", FloatType::get()),
+        c10::Argument("detections_per_im", IntType::get()),
+        c10::Argument("soft_nms_enabled", BoolType::get()),
+        c10::Argument("soft_nms_method", StringType::get()),
+        c10::Argument("soft_nms_sigma", FloatType::get()),
+        c10::Argument("soft_nms_min_score_thres", FloatType::get()),
+        c10::Argument("rotated", BoolType::get()),
+    }),
+    (std::vector<c10::Argument>{
+        c10::Argument("scores"),
+        c10::Argument("boxes"),
+        c10::Argument("classes"),
+        c10::Argument("batch_splits"),
+        // c10::Argument("keeps"),
+        // c10::Argument("keeps_size"),
+    }),
+    caffe2::BoxWithNMSLimitOp<caffe2::CPUContext>);
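
The output Argument list above fixes the order of the tensors returned by the PyTorch binding: scores, boxes, classes, batch_splits (the keeps / keeps_size outputs are left commented out). A hypothetical convenience wrapper, not part of this diff, that names the four outputs accordingly:

    import collections
    import torch

    # Hypothetical wrapper: names the four outputs declared in the registration above.
    BoxWithNMSLimitResult = collections.namedtuple(
        "BoxWithNMSLimitResult", ["scores", "boxes", "classes", "batch_splits"]
    )

    def box_with_nms_limit(class_prob, pred_bbox, batch_splits, **op_kwargs):
        # Return order follows the output Arguments of the c10 registration.
        return BoxWithNMSLimitResult(
            *torch.ops._caffe2.BoxWithNMSLimit(
                class_prob, pred_bbox, batch_splits, **op_kwargs
            )
        )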
diff --git a/caffe2/operators/box_with_nms_limit_op.h b/caffe2/operators/box_with_nms_limit_op.h
index 722fe2c..93cbbed 100644
@@ -6,6 +6,8 @@
 #include "caffe2/core/context.h"
 #include "caffe2/core/operator.h"
 
+C10_DECLARE_CAFFE2_OPERATOR(BoxWithNMSLimit)
+
 namespace caffe2 {
 
 // C++ implementation of function insert_box_results_with_nms_and_limit()
diff --git a/caffe2/python/operator_test/torch_integration_test.py b/caffe2/python/operator_test/torch_integration_test.py
index c9aa64d..d8ce5b6 100644
@@ -1,7 +1,4 @@
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-from __future__ import unicode_literals
+from __future__ import absolute_import, division, print_function, unicode_literals
 
 from caffe2.python import core, workspace
 import torch
@@ -44,16 +41,33 @@ def generate_rois_rotated(roi_counts, im_dims):
     # [batch_id, ctr_x, ctr_y, w, h, angle]
     rotated_rois = np.empty((rois.shape[0], 6)).astype(np.float32)
     rotated_rois[:, 0] = rois[:, 0]  # batch_id
-    rotated_rois[:, 1] = (rois[:, 1] + rois[:, 3]) / 2.  # ctr_x = (x1 + x2) / 2
-    rotated_rois[:, 2] = (rois[:, 2] + rois[:, 4]) / 2.  # ctr_y = (y1 + y2) / 2
+    rotated_rois[:, 1] = (rois[:, 1] + rois[:, 3]) / 2.0  # ctr_x = (x1 + x2) / 2
+    rotated_rois[:, 2] = (rois[:, 2] + rois[:, 4]) / 2.0  # ctr_y = (y1 + y2) / 2
     rotated_rois[:, 3] = rois[:, 3] - rois[:, 1] + 1.0  # w = x2 - x1 + 1
     rotated_rois[:, 4] = rois[:, 4] - rois[:, 2] + 1.0  # h = y2 - y1 + 1
     rotated_rois[:, 5] = np.random.uniform(-90.0, 90.0)  # angle in degrees
     return rotated_rois
 
 
-class TorchIntegration(hu.HypothesisTestCase):
+def create_bbox_transform_inputs(roi_counts, num_classes, rotated):
+    batch_size = len(roi_counts)
+    total_rois = sum(roi_counts)
+    im_dims = np.random.randint(100, 600, batch_size)
+    rois = (
+        generate_rois_rotated(roi_counts, im_dims)
+        if rotated
+        else generate_rois(roi_counts, im_dims)
+    )
+    box_dim = 5 if rotated else 4
+    deltas = np.random.randn(total_rois, box_dim * num_classes).astype(np.float32)
+    im_info = np.zeros((batch_size, 3)).astype(np.float32)
+    im_info[:, 0] = im_dims
+    im_info[:, 1] = im_dims
+    im_info[:, 2] = 1.0
+    return rois, deltas, im_info
+
 
+class TorchIntegration(hu.HypothesisTestCase):
     @given(
         roi_counts=st.lists(st.integers(0, 5), min_size=1, max_size=10),
         num_classes=st.integers(1, 10),
@@ -75,20 +89,9 @@ class TorchIntegration(hu.HypothesisTestCase):
         """
         Test with rois for multiple images in a batch
         """
-        batch_size = len(roi_counts)
-        total_rois = sum(roi_counts)
-        im_dims = np.random.randint(100, 600, batch_size)
-        rois = (
-            generate_rois_rotated(roi_counts, im_dims)
-            if rotated
-            else generate_rois(roi_counts, im_dims)
+        rois, deltas, im_info = create_bbox_transform_inputs(
+            roi_counts, num_classes, rotated
         )
-        box_dim = 5 if rotated else 4
-        deltas = np.random.randn(total_rois, box_dim * num_classes).astype(np.float32)
-        im_info = np.zeros((batch_size, 3)).astype(np.float32)
-        im_info[:, 0] = im_dims
-        im_info[:, 1] = im_dims
-        im_info[:, 2] = 1.0
 
         def bbox_transform_ref():
             ref_op = core.CreateOperator(
@@ -108,13 +111,101 @@ class TorchIntegration(hu.HypothesisTestCase):
 
         box_out = torch.tensor(bbox_transform_ref())
         a, b = torch.ops._caffe2.BBoxTransform(
-                torch.tensor(rois), torch.tensor(deltas),
+            torch.tensor(rois),
+            torch.tensor(deltas),
+            torch.tensor(im_info),
+            [1.0, 1.0, 1.0, 1.0],
+            False,
+            rotated,
+            angle_bound_on,
+            -90,
+            90,
+            clip_angle_thresh,
+        )
+
+        torch.testing.assert_allclose(box_out, a)
+
+    @given(
+        roi_counts=st.lists(st.integers(0, 5), min_size=1, max_size=10),
+        num_classes=st.integers(1, 10),
+        rotated=st.booleans(),
+        angle_bound_on=st.booleans(),
+        clip_angle_thresh=st.sampled_from([-1.0, 1.0]),
+        **hu.gcs_cpu_only
+    )
+    def test_box_with_nms_limits(
+        self,
+        roi_counts,
+        num_classes,
+        rotated,
+        angle_bound_on,
+        clip_angle_thresh,
+        gc,
+        dc,
+    ):
+        rotated = False  # FIXME remove this after rotation is supported
+        rois, deltas, im_info = create_bbox_transform_inputs(
+            roi_counts, num_classes, rotated
+        )
+        pred_bbox, batch_splits = [
+            t.detach().numpy()
+            for t in torch.ops._caffe2.BBoxTransform(
+                torch.tensor(rois),
+                torch.tensor(deltas),
                 torch.tensor(im_info),
                 [1.0, 1.0, 1.0, 1.0],
-                False, rotated, angle_bound_on,
-                -90, 90, clip_angle_thresh)
+                False,
+                rotated,
+                angle_bound_on,
+                -90,
+                90,
+                clip_angle_thresh,
+            )
+        ]
+        class_prob = np.random.randn(sum(roi_counts), num_classes).astype(np.float32)
+        score_thresh = 0.5
+        nms_thresh = 0.5
+        topk_per_image = sum(roi_counts) / 2
 
-        torch.testing.assert_allclose(box_out, a)
+        def box_with_nms_limit_ref():
+            input_blobs = ["class_prob", "pred_bbox", "batch_splits"]
+            output_blobs = ["score_nms", "bbox_nms", "class_nms", "batch_splits_nms"]
+            ref_op = core.CreateOperator(
+                "BoxWithNMSLimit",
+                input_blobs,
+                output_blobs,
+                score_thresh=float(score_thresh),
+                nms=float(nms_thresh),
+                detections_per_im=int(topk_per_image),
+                soft_nms_enabled=False,
+                soft_nms_method="linear",
+                soft_nms_sigma=0.5,
+                soft_nms_min_score_thres=0.001,
+                rotated=rotated,
+            )
+            workspace.FeedBlob("class_prob", class_prob)
+            workspace.FeedBlob("pred_bbox", pred_bbox)
+            workspace.FeedBlob("batch_splits", batch_splits)
+            workspace.RunOperatorOnce(ref_op)
+            return (workspace.FetchBlob(b) for b in output_blobs)
+
+        output_refs = box_with_nms_limit_ref()
+        outputs = torch.ops._caffe2.BoxWithNMSLimit(
+            torch.tensor(class_prob),
+            torch.tensor(pred_bbox),
+            torch.tensor(batch_splits),
+            score_thresh=float(score_thresh),
+            nms=float(nms_thresh),
+            detections_per_im=int(topk_per_image),
+            soft_nms_enabled=False,
+            soft_nms_method="linear",
+            soft_nms_sigma=0.5,
+            soft_nms_min_score_thres=0.001,
+            rotated=rotated,
+        )
+
+        for o, o_ref in zip(outputs, output_refs):
+            torch.testing.assert_allclose(o, o_ref)
 
     @given(
         A=st.integers(min_value=4, max_value=4),
@@ -124,8 +215,11 @@ class TorchIntegration(hu.HypothesisTestCase):
     )
     def test_generate_proposals(self, A, H, W, img_count):
         scores = np.ones((img_count, A, H, W)).astype(np.float32)
-        bbox_deltas = np.linspace(0, 10, num=img_count*4*A*H*W).reshape(
-                (img_count, 4*A, H, W)).astype(np.float32)
+        bbox_deltas = (
+            np.linspace(0, 10, num=img_count * 4 * A * H * W)
+            .reshape((img_count, 4 * A, H, W))
+            .astype(np.float32)
+        )
         im_info = np.ones((img_count, 3)).astype(np.float32) / 10
         anchors = np.ones((A, 4)).astype(np.float32)
 
@@ -147,9 +241,20 @@ class TorchIntegration(hu.HypothesisTestCase):
         rois = torch.tensor(rois)
         rois_probs = torch.tensor(rois_probs)
         a, b = torch.ops._caffe2.GenerateProposals(
-                torch.tensor(scores), torch.tensor(bbox_deltas),
-                torch.tensor(im_info), torch.tensor(anchors),
-                2.0, 6000, 300, 0.7, 16, True, -90, 90, 1.0)
+            torch.tensor(scores),
+            torch.tensor(bbox_deltas),
+            torch.tensor(im_info),
+            torch.tensor(anchors),
+            2.0,
+            6000,
+            300,
+            0.7,
+            16,
+            True,
+            -90,
+            90,
+            1.0,
+        )
         torch.testing.assert_allclose(rois, a)
         torch.testing.assert_allclose(rois_probs, b)
 
@@ -241,11 +346,14 @@ class TorchIntegration(hu.HypothesisTestCase):
         H=st.integers(min_value=10, max_value=10),
         W=st.integers(min_value=8, max_value=8),
         img_count=st.integers(min_value=3, max_value=3),
-        )
+    )
     def test_generate_proposals_cuda(self, A, H, W, img_count):
         scores = np.ones((img_count, A, H, W)).astype(np.float32)
-        bbox_deltas = np.linspace(0, 10, num=img_count*4*A*H*W).reshape(
-                (img_count, 4*A, H, W)).astype(np.float32)
+        bbox_deltas = (
+            np.linspace(0, 10, num=img_count * 4 * A * H * W)
+            .reshape((img_count, 4 * A, H, W))
+            .astype(np.float32)
+        )
         im_info = np.ones((img_count, 3)).astype(np.float32) / 10
         anchors = np.ones((A, 4)).astype(np.float32)
 
@@ -267,9 +375,20 @@ class TorchIntegration(hu.HypothesisTestCase):
         rois = torch.tensor(rois)
         rois_probs = torch.tensor(rois_probs)
         a, b = torch.ops._caffe2.GenerateProposals(
-                torch.tensor(scores).cuda(), torch.tensor(bbox_deltas).cuda(),
-                torch.tensor(im_info).cuda(), torch.tensor(anchors).cuda(),
-                2.0, 6000, 300, 0.7, 16, True, -90, 90, 1.0)
+            torch.tensor(scores).cuda(),
+            torch.tensor(bbox_deltas).cuda(),
+            torch.tensor(im_info).cuda(),
+            torch.tensor(anchors).cuda(),
+            2.0,
+            6000,
+            300,
+            0.7,
+            16,
+            True,
+            -90,
+            90,
+            1.0,
+        )
         torch.testing.assert_allclose(rois, a.cpu())
         torch.testing.assert_allclose(rois_probs, b.cpu())
 
@@ -281,13 +400,15 @@ class TorchIntegration(hu.HypothesisTestCase):
     )
     def _test_roi_align(self, N, C, H, W, device):
         def rand_roi():
-            return np.array([
-                float(int(N * np.random.rand())),
-                0.5 * np.random.rand() * W,
-                0.5 * np.random.rand() * H,
-                (0.5 + 0.5 * np.random.rand()) * W,
-                (0.5 + 0.5 * np.random.rand()) * H,
-            ]).astype(np.float32)
+            return np.array(
+                [
+                    float(int(N * np.random.rand())),
+                    0.5 * np.random.rand() * W,
+                    0.5 * np.random.rand() * H,
+                    (0.5 + 0.5 * np.random.rand()) * W,
+                    (0.5 + 0.5 * np.random.rand()) * H,
+                ]
+            ).astype(np.float32)
 
         feature = np.random.randn(N, C, H, W).astype(np.float32)
         rois = np.array([rand_roi() for _ in range(10)])
@@ -300,7 +421,7 @@ class TorchIntegration(hu.HypothesisTestCase):
                 spatial_scale=1.0,
                 pooled_h=3,
                 pooled_w=3,
-                sampling_ratio=0
+                sampling_ratio=0,
             )
             workspace.FeedBlob("feature", _feature)
             workspace.FeedBlob("rois", _rois)
@@ -315,7 +436,7 @@ class TorchIntegration(hu.HypothesisTestCase):
             spatial_scale=1.0,
             pooled_h=3,
             pooled_w=3,
-            sampling_ratio=0
+            sampling_ratio=0,
         )
         torch.testing.assert_allclose(roi_feature_ref, roi_feature.cpu())