Expose GenerateProposals to PyTorch
authorSebastian Messmer <messmer@fb.com>
Mon, 11 Feb 2019 22:03:45 +0000 (14:03 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Mon, 11 Feb 2019 22:15:47 +0000 (14:15 -0800)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/16880

Reviewed By: bwasti

Differential Revision: D13998092

fbshipit-source-id: 23ab886ba137377312557fa718f262f4c8149cc7

caffe2/operators/generate_proposals_op.cc
caffe2/operators/generate_proposals_op.h
caffe2/python/operator_test/torch_integration_test.py

index e391f1e..75253b9 100644 (file)
@@ -412,3 +412,28 @@ SHOULD_NOT_DO_GRADIENT(GenerateProposals);
 SHOULD_NOT_DO_GRADIENT(GenerateProposalsCPP);
 
 } // namespace caffe2
+
+C10_REGISTER_CAFFE2_OPERATOR_CPU(
+    GenerateProposals,
+    (std::vector<c10::Argument>{
+        c10::Argument("scores"),
+        c10::Argument("bbox_deltas"),
+        c10::Argument("im_info"),
+        c10::Argument("anchors"),
+        c10::Argument("spatial_scale", FloatType::get()),
+        c10::Argument("pre_nms_topN", IntType::get()),
+        c10::Argument("post_nms_topN", IntType::get()),
+        c10::Argument("nms_thresh", FloatType::get()),
+        c10::Argument("min_size", FloatType::get()),
+        c10::Argument("correct_transform_coords", BoolType::get()),
+        c10::Argument("angle_bound_on", BoolType::get()),
+        c10::Argument("angle_bound_lo", IntType::get()),
+        c10::Argument("angle_bound_hi", IntType::get()),
+        c10::Argument("clip_angle_thresh", FloatType::get()),
+    }),
+    (std::vector<c10::Argument>{
+        c10::Argument("output_0"),
+        c10::Argument("output_1"),
+    }),
+    caffe2::GenerateProposalsOp<caffe2::CPUContext>
+);
index 0de3c50..235c9ab 100644 (file)
@@ -6,6 +6,8 @@
 #include "caffe2/utils/eigen_utils.h"
 #include "caffe2/utils/math.h"
 
+C10_DECLARE_CAFFE2_OPERATOR(GenerateProposalsOp);
+
 namespace caffe2 {
 
 namespace utils {
@@ -76,8 +78,9 @@ template <class Context>
 class GenerateProposalsOp final : public Operator<Context> {
  public:
   USE_OPERATOR_CONTEXT_FUNCTIONS;
-  GenerateProposalsOp(const OperatorDef& operator_def, Workspace* ws)
-      : Operator<Context>(operator_def, ws),
+  template<class... Args>
+  explicit GenerateProposalsOp(Args&&... args)
+      : Operator<Context>(std::forward<Args>(args)...),
         spatial_scale_(
             this->template GetSingleArgument<float>("spatial_scale", 1.0 / 16)),
         feat_stride_(1.0 / spatial_scale_),
@@ -176,4 +179,4 @@ class GenerateProposalsOp final : public Operator<Context> {
 
 } // namespace caffe2
 
-#endif // CAFFE2_OPERATORS_GENERATE_PROPOSALS_OP_H_
\ No newline at end of file
+#endif // CAFFE2_OPERATORS_GENERATE_PROPOSALS_OP_H_
index 21f4bcd..06b2519 100644 (file)
@@ -111,3 +111,40 @@ class TorchIntegration(hu.HypothesisTestCase):
                 -90, 90, clip_angle_thresh)
 
         torch.testing.assert_allclose(box_out, a)
+
+    @given(
+        A=st.integers(min_value=4, max_value=4),
+        H=st.integers(min_value=10, max_value=10),
+        W=st.integers(min_value=8, max_value=8),
+        img_count=st.integers(min_value=3, max_value=3),
+        )
+    def test_generate_proposals(self, A, H, W, img_count):
+        scores = np.ones((img_count, A, H, W)).astype(np.float32)
+        bbox_deltas = np.linspace(0, 10, num=img_count*4*A*H*W).reshape(
+                (img_count, 4*A, H, W)).astype(np.float32)
+        im_info = np.ones((img_count, 3)).astype(np.float32) / 10
+        anchors = np.ones((A, 4)).astype(np.float32)
+
+        def generate_proposals_ref():
+            ref_op = core.CreateOperator(
+                "GenerateProposals",
+                ["scores", "bbox_deltas", "im_info", "anchors"],
+                ["rois", "rois_probs"],
+                spatial_scale=2.0,
+            )
+            workspace.FeedBlob("scores", scores)
+            workspace.FeedBlob("bbox_deltas", bbox_deltas)
+            workspace.FeedBlob("im_info", im_info)
+            workspace.FeedBlob("anchors", anchors)
+            workspace.RunOperatorOnce(ref_op)
+            return workspace.FetchBlob("rois"), workspace.FetchBlob("rois_probs")
+
+        rois, rois_probs = generate_proposals_ref()
+        rois = torch.tensor(rois)
+        rois_probs = torch.tensor(rois_probs)
+        a, b = torch.ops._caffe2.GenerateProposals(
+                torch.tensor(scores), torch.tensor(bbox_deltas),
+                torch.tensor(im_info), torch.tensor(anchors),
+                2.0, 6000, 300, 0.7, 16, False, True, -90, 90, 1.0)
+        torch.testing.assert_allclose(rois, a)
+        torch.testing.assert_allclose(rois_probs, b)