Optimize CPU GenerateProposals op by lazily generating anchors (3-5x faster) (#15103)
authorViswanath Sivakumar <viswanath@fb.com>
Wed, 12 Dec 2018 23:48:03 +0000 (15:48 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 12 Dec 2018 23:53:52 +0000 (15:53 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15103

There are two main optimizations in this diff:
1. We generate all anchors for every single spatial grid first, and then apply
NMS to pick 2000 anchors according to RPN_PRE_NMS_TOP_N. By first sorting the
score and picking the 2000 top ones and then lazily generating only the
corresponding anchors is much faster.
2. Transposing bbox_deltas from (num_anchors * 4, H, W) to
(H, W, num_anchors * 4) was also quite slow - taking about 20ms in the RRPN
case when there are lots of anchors which it's negligible for RPN case (like
0.1 ms). Instead of transponsing, performing all operations in the
(num_anchors, H, W) format speeds things up.

For regular RPN scenario, this gives 5x speedup from 5.84ms to 1.18ms a case
with 35 anchors over a 600x600 image.

For rotated boxes with 245 anchors, the runtime down from 80ms to 27ms per
iter.

Reviewed By: newstzpz

Differential Revision: D13428688

fbshipit-source-id: 6006b332925e01a7c9433ded2ff5dc9e6d96f7d3

caffe2/operators/generate_proposals_op.cc
caffe2/operators/generate_proposals_op.h
caffe2/operators/generate_proposals_op_test.cc

index 8cee7bc..3f50661 100644 (file)
@@ -104,12 +104,51 @@ ERMatXf ComputeAllAnchors(
   return all_anchors_vec;
 }
 
+ERArrXXf ComputeSortedAnchors(
+    const Eigen::Map<const ERArrXXf>& anchors,
+    int height,
+    int width,
+    float feat_stride,
+    const vector<int>& order) {
+  const auto box_dim = anchors.cols();
+  CAFFE_ENFORCE(box_dim == 4 || box_dim == 5);
+
+  // Order is flattened in (A, H, W) format. Unravel the indices.
+  const auto& order_AHW = utils::AsEArrXt(order);
+  const auto& order_AH = order_AHW / width;
+  const auto& order_W = order_AHW - order_AH * width;
+  const auto& order_A = order_AH / height;
+  const auto& order_H = order_AH - order_A * height;
+
+  // Generate shifts for each location in the H * W grid corresponding
+  // to the sorted scores in (A, H, W) order.
+  const auto& shift_x = order_W.cast<float>() * feat_stride;
+  const auto& shift_y = order_H.cast<float>() * feat_stride;
+  Eigen::MatrixXf shifts(order.size(), box_dim);
+  if (box_dim == 4) {
+    // Upright boxes in [x1, y1, x2, y2] format
+    shifts << shift_x, shift_y, shift_x, shift_y;
+  } else {
+    // Rotated boxes in [ctr_x, ctr_y, w, h, angle] format.
+    // Zero shift for width, height and angle.
+    const auto& shift_zero = EArrXf::Constant(order.size(), 0.0);
+    shifts << shift_x, shift_y, shift_zero, shift_zero, shift_zero;
+  }
+
+  // Apply shifts to the relevant anchors.
+  // Equivalent to python code `all_anchors = self._anchors[order_A] + shifts`
+  ERArrXXf anchors_sorted;
+  utils::GetSubArrayRows(anchors, order_A, &anchors_sorted);
+  const auto& all_anchors_sorted = anchors_sorted + shifts.array();
+  return all_anchors_sorted;
+}
+
 } // namespace utils
 
 template <>
 void GenerateProposalsOp<CPUContext>::ProposalsForOneImage(
     const Eigen::Array3f& im_info,
-    const Eigen::Map<const ERMatXf>& all_anchors,
+    const Eigen::Map<const ERArrXXf>& anchors,
     const utils::ConstTensorView<float>& bbox_deltas_tensor,
     const utils::ConstTensorView<float>& scores_tensor,
     ERArrXXf* out_boxes,
@@ -117,39 +156,23 @@ void GenerateProposalsOp<CPUContext>::ProposalsForOneImage(
   const auto& post_nms_topN = rpn_post_nms_topN_;
   const auto& nms_thresh = rpn_nms_thresh_;
   const auto& min_size = rpn_min_size_;
-  const int box_dim = static_cast<int>(all_anchors.cols());
+  const int box_dim = static_cast<int>(anchors.cols());
   CAFFE_ENFORCE(box_dim == 4 || box_dim == 5);
 
-  // Transpose and reshape predicted bbox transformations to get them
-  // into the same order as the anchors:
-  //   - bbox deltas will be (box_dim * A, H, W) format from conv output
-  //   - transpose to (H, W, box_dim * A)
-  //   - reshape to (H * W * A, box_dim) where rows are ordered by (H, W, A)
-  //     in slowest to fastest order to match the enumerated anchors
   CAFFE_ENFORCE_EQ(bbox_deltas_tensor.ndim(), 3);
   CAFFE_ENFORCE_EQ(bbox_deltas_tensor.dim(0) % box_dim, 0);
   auto A = bbox_deltas_tensor.dim(0) / box_dim;
   auto H = bbox_deltas_tensor.dim(1);
   auto W = bbox_deltas_tensor.dim(2);
-  // equivalent to python code
-  //  bbox_deltas = bbox_deltas.transpose((1, 2, 0)).reshape((-1, box_dim))
-  ERArrXXf bbox_deltas(H * W * A, box_dim);
-  Eigen::Map<ERMatXf>(bbox_deltas.data(), H * W, box_dim * A) =
-      Eigen::Map<const ERMatXf>(bbox_deltas_tensor.data(), A * box_dim, H * W)
-          .transpose();
-  CAFFE_ENFORCE_EQ(bbox_deltas.rows(), all_anchors.rows());
-
-  // - scores are (A, H, W) format from conv output
-  // - transpose to (H, W, A)
-  // - reshape to (H * W * A, 1) where rows are ordered by (H, W, A)
-  //   to match the order of anchors and bbox_deltas
+  auto K = H * W;
+  CAFFE_ENFORCE_EQ(A, anchors.rows());
+
+  // scores are (A, H, W) format from conv output.
+  // Maintain the same order without transposing (which is slow)
+  // and compute anchors accordingly.
   CAFFE_ENFORCE_EQ(scores_tensor.ndim(), 3);
   CAFFE_ENFORCE_EQ(scores_tensor.dims(), (vector<int>{A, H, W}));
-  // equivalent to python code
-  // scores = scores.transpose((1, 2, 0)).reshape((-1, 1))
-  EArrXf scores(scores_tensor.size());
-  Eigen::Map<ERMatXf>(scores.data(), H * W, A) =
-      Eigen::Map<const ERMatXf>(scores_tensor.data(), A, H * W).transpose();
+  Eigen::Map<const EArrXf> scores(scores_tensor.data(), scores_tensor.size());
 
   std::vector<int> order(scores.size());
   std::iota(order.begin(), order.end(), 0);
@@ -170,15 +193,34 @@ void GenerateProposalsOp<CPUContext>::ProposalsForOneImage(
     order.resize(rpn_pre_nms_topN_);
   }
 
-  ERArrXXf bbox_deltas_sorted;
-  ERArrXXf all_anchors_sorted;
   EArrXf scores_sorted;
-  utils::GetSubArrayRows(
-      bbox_deltas, utils::AsEArrXt(order), &bbox_deltas_sorted);
-  utils::GetSubArrayRows(
-      all_anchors.array(), utils::AsEArrXt(order), &all_anchors_sorted);
   utils::GetSubArray(scores, utils::AsEArrXt(order), &scores_sorted);
 
+  // bbox_deltas are (A * box_dim, H, W) format from conv output.
+  // Order them based on scores maintaining the same format without
+  // expensive transpose.
+  // Note that order corresponds to (A, H * W) in row-major whereas
+  // bbox_deltas are in (A, box_dim, H * W) in row-major. Hence, we
+  // obtain a sub-view of bbox_deltas for each dim (4 for RPN, 5 for RRPN)
+  // in (A, H * W) with an outer stride of box_dim * H * W. Then we apply
+  // the ordering and filtering for each dim iteratively.
+  ERArrXXf bbox_deltas_sorted(order.size(), box_dim);
+  EArrXf bbox_deltas_per_dim(A * K);
+  EigenOuterStride stride(box_dim * K);
+  for (int j = 0; j < box_dim; ++j) {
+    Eigen::Map<ERMatXf>(bbox_deltas_per_dim.data(), A, K) =
+        Eigen::Map<const ERMatXf, 0, EigenOuterStride>(
+            bbox_deltas_tensor.data() + j * K, A, K, stride);
+    for (int i = 0; i < order.size(); ++i) {
+      bbox_deltas_sorted(i, j) = bbox_deltas_per_dim[order[i]];
+    }
+  }
+
+  // Compute anchors specific to the ordered and pre-filtered indices
+  // in (A, H, W) format.
+  const auto& all_anchors_sorted =
+      utils::ComputeSortedAnchors(anchors, H, W, feat_stride_, order);
+
   // Transform anchors into proposals via bbox transformations
   static const std::vector<float> bbox_weights{1.0, 1.0, 1.0, 1.0};
   auto proposals = utils::bbox_transform(
@@ -220,7 +262,7 @@ bool GenerateProposalsOp<CPUContext>::RunOnDevice() {
   const auto& scores = Input(0);
   const auto& bbox_deltas = Input(1);
   const auto& im_info_tensor = Input(2);
-  const auto& anchors = Input(3);
+  const auto& anchors_tensor = Input(3);
 
   CAFFE_ENFORCE_EQ(scores.dim(), 4, scores.dim());
   CAFFE_ENFORCE(scores.template IsType<float>(), scores.dtype().name());
@@ -228,8 +270,7 @@ bool GenerateProposalsOp<CPUContext>::RunOnDevice() {
   const auto A = scores.size(1);
   const auto height = scores.size(2);
   const auto width = scores.size(3);
-  const auto K = height * width;
-  const auto box_dim = anchors.size(1);
+  const auto box_dim = anchors_tensor.size(1);
   CAFFE_ENFORCE(box_dim == 4 || box_dim == 5);
 
   // bbox_deltas: (num_images, A * box_dim, H, W)
@@ -243,22 +284,19 @@ bool GenerateProposalsOp<CPUContext>::RunOnDevice() {
       im_info_tensor.template IsType<float>(), im_info_tensor.dtype().name());
 
   // anchors: (A, box_dim)
-  CAFFE_ENFORCE_EQ(anchors.sizes(), (vector<int64_t>{A, box_dim}));
-  CAFFE_ENFORCE(anchors.template IsType<float>(), anchors.dtype().name());
-
-  // Broadcast the anchors to all pixels
-  auto all_anchors_vec =
-      utils::ComputeAllAnchors(anchors, height, width, feat_stride_);
-  Eigen::Map<const ERMatXf> all_anchors(all_anchors_vec.data(), K * A, box_dim);
+  CAFFE_ENFORCE_EQ(anchors_tensor.sizes(), (vector<int64_t>{A, box_dim}));
+  CAFFE_ENFORCE(
+      anchors_tensor.template IsType<float>(), anchors_tensor.dtype().name());
 
   Eigen::Map<const ERArrXXf> im_info(
       im_info_tensor.data<float>(),
       im_info_tensor.size(0),
       im_info_tensor.size(1));
 
-  const int roi_col_count = box_dim + 1;
-  auto* out_rois = Output(0, {0, roi_col_count}, at::dtype<float>());
-  auto* out_rois_probs = Output(1, {0}, at::dtype<float>());
+  Eigen::Map<const ERArrXXf> anchors(
+      anchors_tensor.data<float>(),
+      anchors_tensor.size(0),
+      anchors_tensor.size(1));
 
   std::vector<ERArrXXf> im_boxes(num_images);
   std::vector<EArrXf> im_probs(num_images);
@@ -271,7 +309,7 @@ bool GenerateProposalsOp<CPUContext>::RunOnDevice() {
     EArrXf& im_i_probs = im_probs[i];
     ProposalsForOneImage(
         cur_im_info,
-        all_anchors,
+        anchors,
         cur_bbox_deltas,
         cur_scores,
         &im_i_boxes,
@@ -282,8 +320,9 @@ bool GenerateProposalsOp<CPUContext>::RunOnDevice() {
   for (int i = 0; i < num_images; i++) {
     roi_counts += im_boxes[i].rows();
   }
-  out_rois->Extend(roi_counts, 50);
-  out_rois_probs->Extend(roi_counts, 50);
+  const int roi_col_count = box_dim + 1;
+  auto* out_rois = Output(0, {roi_counts, roi_col_count}, at::dtype<float>());
+  auto* out_rois_probs = Output(1, {roi_counts}, at::dtype<float>());
   float* out_rois_ptr = out_rois->template mutable_data<float>();
   float* out_rois_probs_ptr = out_rois_probs->template mutable_data<float>();
   for (int i = 0; i < num_images; i++) {
index 1d6e28c..fa933e3 100644 (file)
@@ -51,6 +51,17 @@ CAFFE2_API ERMatXf ComputeAllAnchors(
     int width,
     float feat_stride);
 
+// Like ComputeAllAnchors, but instead of computing anchors for every single
+// spatial location, only computes anchors for the already sorted and filtered
+// positions after NMS is applied to avoid unnecessary computation.
+// `order` is a raveled array of sorted indices in (A, H, W) format.
+CAFFE2_API ERArrXXf ComputeSortedAnchors(
+    const Eigen::Map<const ERArrXXf>& anchors,
+    int height,
+    int width,
+    float feat_stride,
+    const vector<int>& order);
+
 } // namespace utils
 
 // C++ implementation of GenerateProposalsOp
@@ -101,7 +112,7 @@ class GenerateProposalsOp final : public Operator<Context> {
   // out_probs: n
   void ProposalsForOneImage(
       const Eigen::Array3f& im_info,
-      const Eigen::Map<const ERMatXf>& all_anchors,
+      const Eigen::Map<const ERArrXXf>& anchors,
       const utils::ConstTensorView<float>& bbox_deltas_tensor,
       const utils::ConstTensorView<float>& scores_tensor,
       ERArrXXf* out_boxes,
index bfe26d0..4d76075 100644 (file)
@@ -92,6 +92,56 @@ TEST(GenerateProposalsTest, TestComputeAllAnchors) {
   EXPECT_EQ((all_anchors_result - all_anchors_gt).norm(), 0);
 }
 
+TEST(GenerateProposalsTest, TestComputeSortedAnchors) {
+  ERMatXf anchors(3, 4);
+  anchors << -38, -16, 53, 31, -84, -40, 99, 55, -176, -88, 191, 103;
+
+  int height = 4;
+  int width = 3;
+  int A = anchors.rows();
+  float feat_stride = 16;
+  int total = height * width * A;
+
+  // Generate all anchors for ground truth
+  Tensor anchors_tensor(vector<int64_t>{anchors.rows(), anchors.cols()}, CPU);
+  Eigen::Map<ERMatXf>(
+      anchors_tensor.mutable_data<float>(), anchors.rows(), anchors.cols()) =
+      anchors;
+  auto all_anchors =
+      utils::ComputeAllAnchors(anchors_tensor, height, width, feat_stride);
+  Eigen::Map<const ERMatXf> all_anchors_result(
+      all_anchors.data(), height * width * A, 4);
+
+  Eigen::Map<const ERArrXXf> anchors_map(
+      anchors.data(), anchors.rows(), anchors.cols());
+
+  // Test with random subsets and ordering of indices
+  vector<int> indices(total);
+  std::iota(indices.begin(), indices.end(), 0);
+  std::random_device rd;
+  std::mt19937 gen(rd());
+  std::shuffle(indices.begin(), indices.end(), gen);
+  for (int count = 0; count <= total; ++count) {
+    vector<int> order(indices.begin(), indices.begin() + count);
+    auto result = utils::ComputeSortedAnchors(
+        anchors_map, height, width, feat_stride, order);
+
+    // Compare the result of ComputeSortedAnchors with first generating all
+    // anchors via ComputeAllAnchors and then applying ordering and filtering.
+    // Need to convert order from (A, H, W) to (H, W, A) format before for this.
+    const auto& order_AHW = utils::AsEArrXt(order);
+    const auto& order_AH = order_AHW / width;
+    const auto& order_W = order_AHW - order_AH * width;
+    const auto& order_A = order_AH / height;
+    const auto& order_H = order_AH - order_A * height;
+    const auto& order_HWA = (order_H * width + order_W) * A + order_A;
+
+    ERArrXXf gt;
+    utils::GetSubArrayRows(all_anchors_result.array(), order_HWA, &gt);
+    EXPECT_EQ((result.matrix() - gt.matrix()).norm(), 0);
+  }
+}
+
 namespace {
 
 template <class Derived>
@@ -156,6 +206,65 @@ TEST(GenerateProposalsTest, TestComputeAllAnchorsRotated) {
   EXPECT_EQ((all_anchors_result - all_anchors_gt).norm(), 0);
 }
 
+TEST(GenerateProposalsTest, TestComputeSortedAnchorsRotated) {
+  // Similar to TestComputeSortedAnchors but for rotated boxes with angle info.
+  ERMatXf anchors_xyxy(3, 4);
+  anchors_xyxy << -38, -16, 53, 31, -84, -40, 99, 55, -176, -88, 191, 103;
+
+  // Convert to RRPN format and add angles
+  ERMatXf anchors(3, 5);
+  anchors.block(0, 0, 3, 4) = boxes_xyxy_to_xywh(anchors_xyxy);
+  std::vector<float> angles{0.0, 45.0, -120.0};
+  for (int i = 0; i < anchors.rows(); ++i) {
+    anchors(i, 4) = angles[i % angles.size()];
+  }
+
+  int height = 4;
+  int width = 3;
+  int A = anchors.rows();
+  float feat_stride = 16;
+  int total = height * width * A;
+
+  // Generate all anchors for ground truth
+  Tensor anchors_tensor(vector<int64_t>{anchors.rows(), anchors.cols()}, CPU);
+  Eigen::Map<ERMatXf>(
+      anchors_tensor.mutable_data<float>(), anchors.rows(), anchors.cols()) =
+      anchors;
+  auto all_anchors =
+      utils::ComputeAllAnchors(anchors_tensor, height, width, feat_stride);
+  Eigen::Map<const ERMatXf> all_anchors_result(
+      all_anchors.data(), height * width * A, 5);
+
+  Eigen::Map<const ERArrXXf> anchors_map(
+      anchors.data(), anchors.rows(), anchors.cols());
+
+  // Test with random subsets and ordering of indices
+  vector<int> indices(total);
+  std::iota(indices.begin(), indices.end(), 0);
+  std::random_device rd;
+  std::mt19937 gen(rd());
+  std::shuffle(indices.begin(), indices.end(), gen);
+  for (int count = 0; count <= total; ++count) {
+    vector<int> order(indices.begin(), indices.begin() + count);
+    auto result = utils::ComputeSortedAnchors(
+        anchors_map, height, width, feat_stride, order);
+
+    // Compare the result of ComputeSortedAnchors with first generating all
+    // anchors via ComputeAllAnchors and then applying ordering and filtering.
+    // Need to convert order from (A, H, W) to (H, W, A) format before for this.
+    const auto& order_AHW = utils::AsEArrXt(order);
+    const auto& order_AH = order_AHW / width;
+    const auto& order_W = order_AHW - order_AH * width;
+    const auto& order_A = order_AH / height;
+    const auto& order_H = order_AH - order_A * height;
+    const auto& order_HWA = (order_H * width + order_W) * A + order_A;
+
+    ERArrXXf gt;
+    utils::GetSubArrayRows(all_anchors_result.array(), order_HWA, &gt);
+    EXPECT_EQ((result.matrix() - gt.matrix()).norm(), 0);
+  }
+}
+
 TEST(GenerateProposalsTest, TestEmpty) {
   Workspace ws;
   OperatorDef def;
@@ -610,11 +719,17 @@ TEST(GenerateProposalsTest, TestRealDownSampledRotated) {
   EXPECT_NE(nullptr, op.get());
   EXPECT_TRUE(op->Run());
 
-  // Verify that the resulting angles are correct
   Blob* rois_blob = ws.GetBlob("rois");
   EXPECT_NE(nullptr, rois_blob);
   auto& rois = rois_blob->Get<TensorCPU>();
-  EXPECT_GT(rois.size(0), 0);
+  EXPECT_EQ(rois.sizes(), (vector<int64_t>{13, 6}));
+
+  Blob* rois_probs_blob = ws.GetBlob("rois_probs");
+  EXPECT_NE(nullptr, rois_probs_blob);
+  auto& rois_probs = rois_probs_blob->Get<TensorCPU>();
+  EXPECT_EQ(rois_probs.sizes(), (vector<int64_t>{13}));
+
+  // Verify that the resulting angles are correct
   auto rois_data =
       Eigen::Map<const ERMatXf>(rois.data<float>(), rois.size(0), rois.size(1));
   for (int i = 0; i < rois.size(0); ++i) {