register RoIAlign with C10
authorYanghan Wang <yanghan@instagram.com>
Thu, 14 Mar 2019 18:49:31 +0000 (11:49 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 14 Mar 2019 18:55:29 +0000 (11:55 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17889

Reviewed By: smessmer

Differential Revision: D14411630

fbshipit-source-id: c3b7941d725ae2c78e8d79f52a7983db92b75807

caffe2/operators/roi_align_op.cc
caffe2/operators/roi_align_op.cu
caffe2/operators/roi_align_op.h
caffe2/python/operator_test/torch_integration_test.py

index 06e1cbd..31bfbb1 100644 (file)
@@ -373,3 +373,21 @@ Region of Interest (RoI) align operation as used in Mask R-CNN.
         "is a pooled feature map cooresponding to the r-th RoI.");
 
 } // namespace caffe2
+
+using RoIAlignOpFloatCPU = caffe2::RoIAlignOp<float, caffe2::CPUContext>;
+
+C10_REGISTER_CAFFE2_OPERATOR_CPU(
+    RoIAlign,
+    (std::vector<c10::Argument>{
+        c10::Argument("features"),
+        c10::Argument("rois"),
+        c10::Argument("order", StringType::get()),
+        c10::Argument("spatial_scale", FloatType::get()),
+        c10::Argument("pooled_h", IntType::get()),
+        c10::Argument("pooled_w", IntType::get()),
+        c10::Argument("sampling_ratio", IntType::get()),
+    }),
+    (std::vector<c10::Argument>{
+        c10::Argument("pooled_features"),
+    }),
+    RoIAlignOpFloatCPU);
index e0aa92d..1a2661a 100644 (file)
@@ -184,3 +184,7 @@ bool RoIAlignOp<float, CUDAContext>::RunOnDevice() {
 
 REGISTER_CUDA_OPERATOR(RoIAlign, RoIAlignOp<float, CUDAContext>);
 } // namespace caffe2
+
+using RoIAlignOpFloatCUDA = caffe2::RoIAlignOp<float, caffe2::CUDAContext>;
+
+C10_REGISTER_CAFFE2_OPERATOR_CUDA(RoIAlign, RoIAlignOpFloatCUDA);
index f7e793b..f5b1730 100644 (file)
@@ -7,6 +7,8 @@
 #include "caffe2/core/logging.h"
 #include "caffe2/core/operator.h"
 
+C10_DECLARE_CAFFE2_OPERATOR(RoIAlign)
+
 namespace caffe2 {
 
 template <typename T, class Context>
index f2236ec..2d9aeaf 100644 (file)
@@ -192,3 +192,56 @@ class TorchIntegration(hu.HypothesisTestCase):
                 2.0, 6000, 300, 0.7, 16, True, -90, 90, 1.0)
         torch.testing.assert_allclose(rois, a.cpu())
         torch.testing.assert_allclose(rois_probs, b.cpu())
+
+    @given(
+        N=st.integers(min_value=1, max_value=2),
+        C=st.integers(min_value=4, max_value=4),
+        H=st.integers(min_value=10, max_value=10),
+        W=st.integers(min_value=8, max_value=8),
+    )
+    def _test_roi_align(self, N, C, H, W, device):
+        def rand_roi():
+            return np.array([
+                float(int(N * np.random.rand())),
+                0.5 * np.random.rand() * W,
+                0.5 * np.random.rand() * H,
+                (0.5 + 0.5 * np.random.rand()) * W,
+                (0.5 + 0.5 * np.random.rand()) * H,
+            ]).astype(np.float32)
+
+        feature = np.random.randn(N, C, H, W).astype(np.float32)
+        rois = np.array([rand_roi() for _ in range(10)])
+
+        def roi_align_ref(_feature, _rois):
+            ref_op = core.CreateOperator(
+                "RoIAlign",
+                ["feature", "rois"],
+                ["roi_feature"],
+                spatial_scale=1.0,
+                pooled_h=3,
+                pooled_w=3,
+                sampling_ratio=0
+            )
+            workspace.FeedBlob("feature", _feature)
+            workspace.FeedBlob("rois", _rois)
+            workspace.RunOperatorOnce(ref_op)
+            return workspace.FetchBlob("roi_feature")
+
+        roi_feature_ref = roi_align_ref(feature, rois)
+        roi_feature = torch.ops._caffe2.RoIAlign(
+            torch.Tensor(feature).to(device),
+            torch.Tensor(rois).to(device),
+            order="NCHW",
+            spatial_scale=1.0,
+            pooled_h=3,
+            pooled_w=3,
+            sampling_ratio=0
+        )
+        torch.testing.assert_allclose(roi_feature_ref, roi_feature.cpu())
+
+    def test_roi_align_cpu(self):
+        self._test_roi_align(device="cpu")
+
+    @unittest.skipIf(not workspace.has_cuda_support, "No cuda support")
+    def test_roi_align_cuda(self):
+        self._test_roi_align(device="cuda")