return all_anchors_vec;
}
+ERArrXXf ComputeSortedAnchors(
+ const Eigen::Map<const ERArrXXf>& anchors,
+ int height,
+ int width,
+ float feat_stride,
+ const vector<int>& order) {
+ const auto box_dim = anchors.cols();
+ CAFFE_ENFORCE(box_dim == 4 || box_dim == 5);
+
+ // Order is flattened in (A, H, W) format. Unravel the indices.
+ const auto& order_AHW = utils::AsEArrXt(order);
+ const auto& order_AH = order_AHW / width;
+ const auto& order_W = order_AHW - order_AH * width;
+ const auto& order_A = order_AH / height;
+ const auto& order_H = order_AH - order_A * height;
+
+ // Generate shifts for each location in the H * W grid corresponding
+ // to the sorted scores in (A, H, W) order.
+ const auto& shift_x = order_W.cast<float>() * feat_stride;
+ const auto& shift_y = order_H.cast<float>() * feat_stride;
+ Eigen::MatrixXf shifts(order.size(), box_dim);
+ if (box_dim == 4) {
+ // Upright boxes in [x1, y1, x2, y2] format
+ shifts << shift_x, shift_y, shift_x, shift_y;
+ } else {
+ // Rotated boxes in [ctr_x, ctr_y, w, h, angle] format.
+ // Zero shift for width, height and angle.
+ const auto& shift_zero = EArrXf::Constant(order.size(), 0.0);
+ shifts << shift_x, shift_y, shift_zero, shift_zero, shift_zero;
+ }
+
+ // Apply shifts to the relevant anchors.
+ // Equivalent to python code `all_anchors = self._anchors[order_A] + shifts`
+ ERArrXXf anchors_sorted;
+ utils::GetSubArrayRows(anchors, order_A, &anchors_sorted);
+ const auto& all_anchors_sorted = anchors_sorted + shifts.array();
+ return all_anchors_sorted;
+}
+
} // namespace utils
template <>
void GenerateProposalsOp<CPUContext>::ProposalsForOneImage(
const Eigen::Array3f& im_info,
- const Eigen::Map<const ERMatXf>& all_anchors,
+ const Eigen::Map<const ERArrXXf>& anchors,
const utils::ConstTensorView<float>& bbox_deltas_tensor,
const utils::ConstTensorView<float>& scores_tensor,
ERArrXXf* out_boxes,
const auto& post_nms_topN = rpn_post_nms_topN_;
const auto& nms_thresh = rpn_nms_thresh_;
const auto& min_size = rpn_min_size_;
- const int box_dim = static_cast<int>(all_anchors.cols());
+ const int box_dim = static_cast<int>(anchors.cols());
CAFFE_ENFORCE(box_dim == 4 || box_dim == 5);
- // Transpose and reshape predicted bbox transformations to get them
- // into the same order as the anchors:
- // - bbox deltas will be (box_dim * A, H, W) format from conv output
- // - transpose to (H, W, box_dim * A)
- // - reshape to (H * W * A, box_dim) where rows are ordered by (H, W, A)
- // in slowest to fastest order to match the enumerated anchors
CAFFE_ENFORCE_EQ(bbox_deltas_tensor.ndim(), 3);
CAFFE_ENFORCE_EQ(bbox_deltas_tensor.dim(0) % box_dim, 0);
auto A = bbox_deltas_tensor.dim(0) / box_dim;
auto H = bbox_deltas_tensor.dim(1);
auto W = bbox_deltas_tensor.dim(2);
- // equivalent to python code
- // bbox_deltas = bbox_deltas.transpose((1, 2, 0)).reshape((-1, box_dim))
- ERArrXXf bbox_deltas(H * W * A, box_dim);
- Eigen::Map<ERMatXf>(bbox_deltas.data(), H * W, box_dim * A) =
- Eigen::Map<const ERMatXf>(bbox_deltas_tensor.data(), A * box_dim, H * W)
- .transpose();
- CAFFE_ENFORCE_EQ(bbox_deltas.rows(), all_anchors.rows());
-
- // - scores are (A, H, W) format from conv output
- // - transpose to (H, W, A)
- // - reshape to (H * W * A, 1) where rows are ordered by (H, W, A)
- // to match the order of anchors and bbox_deltas
+ auto K = H * W;
+ CAFFE_ENFORCE_EQ(A, anchors.rows());
+
+ // scores are (A, H, W) format from conv output.
+ // Maintain the same order without transposing (which is slow)
+ // and compute anchors accordingly.
CAFFE_ENFORCE_EQ(scores_tensor.ndim(), 3);
CAFFE_ENFORCE_EQ(scores_tensor.dims(), (vector<int>{A, H, W}));
- // equivalent to python code
- // scores = scores.transpose((1, 2, 0)).reshape((-1, 1))
- EArrXf scores(scores_tensor.size());
- Eigen::Map<ERMatXf>(scores.data(), H * W, A) =
- Eigen::Map<const ERMatXf>(scores_tensor.data(), A, H * W).transpose();
+ Eigen::Map<const EArrXf> scores(scores_tensor.data(), scores_tensor.size());
std::vector<int> order(scores.size());
std::iota(order.begin(), order.end(), 0);
order.resize(rpn_pre_nms_topN_);
}
- ERArrXXf bbox_deltas_sorted;
- ERArrXXf all_anchors_sorted;
EArrXf scores_sorted;
- utils::GetSubArrayRows(
- bbox_deltas, utils::AsEArrXt(order), &bbox_deltas_sorted);
- utils::GetSubArrayRows(
- all_anchors.array(), utils::AsEArrXt(order), &all_anchors_sorted);
utils::GetSubArray(scores, utils::AsEArrXt(order), &scores_sorted);
+ // bbox_deltas are (A * box_dim, H, W) format from conv output.
+ // Order them based on scores maintaining the same format without
+ // expensive transpose.
+ // Note that order corresponds to (A, H * W) in row-major whereas
+ // bbox_deltas are in (A, box_dim, H * W) in row-major. Hence, we
+ // obtain a sub-view of bbox_deltas for each dim (4 for RPN, 5 for RRPN)
+ // in (A, H * W) with an outer stride of box_dim * H * W. Then we apply
+ // the ordering and filtering for each dim iteratively.
+ ERArrXXf bbox_deltas_sorted(order.size(), box_dim);
+ EArrXf bbox_deltas_per_dim(A * K);
+ EigenOuterStride stride(box_dim * K);
+ for (int j = 0; j < box_dim; ++j) {
+ Eigen::Map<ERMatXf>(bbox_deltas_per_dim.data(), A, K) =
+ Eigen::Map<const ERMatXf, 0, EigenOuterStride>(
+ bbox_deltas_tensor.data() + j * K, A, K, stride);
+ for (int i = 0; i < order.size(); ++i) {
+ bbox_deltas_sorted(i, j) = bbox_deltas_per_dim[order[i]];
+ }
+ }
+
+ // Compute anchors specific to the ordered and pre-filtered indices
+ // in (A, H, W) format.
+ const auto& all_anchors_sorted =
+ utils::ComputeSortedAnchors(anchors, H, W, feat_stride_, order);
+
// Transform anchors into proposals via bbox transformations
static const std::vector<float> bbox_weights{1.0, 1.0, 1.0, 1.0};
auto proposals = utils::bbox_transform(
const auto& scores = Input(0);
const auto& bbox_deltas = Input(1);
const auto& im_info_tensor = Input(2);
- const auto& anchors = Input(3);
+ const auto& anchors_tensor = Input(3);
CAFFE_ENFORCE_EQ(scores.dim(), 4, scores.dim());
CAFFE_ENFORCE(scores.template IsType<float>(), scores.dtype().name());
const auto A = scores.size(1);
const auto height = scores.size(2);
const auto width = scores.size(3);
- const auto K = height * width;
- const auto box_dim = anchors.size(1);
+ const auto box_dim = anchors_tensor.size(1);
CAFFE_ENFORCE(box_dim == 4 || box_dim == 5);
// bbox_deltas: (num_images, A * box_dim, H, W)
im_info_tensor.template IsType<float>(), im_info_tensor.dtype().name());
// anchors: (A, box_dim)
- CAFFE_ENFORCE_EQ(anchors.sizes(), (vector<int64_t>{A, box_dim}));
- CAFFE_ENFORCE(anchors.template IsType<float>(), anchors.dtype().name());
-
- // Broadcast the anchors to all pixels
- auto all_anchors_vec =
- utils::ComputeAllAnchors(anchors, height, width, feat_stride_);
- Eigen::Map<const ERMatXf> all_anchors(all_anchors_vec.data(), K * A, box_dim);
+ CAFFE_ENFORCE_EQ(anchors_tensor.sizes(), (vector<int64_t>{A, box_dim}));
+ CAFFE_ENFORCE(
+ anchors_tensor.template IsType<float>(), anchors_tensor.dtype().name());
Eigen::Map<const ERArrXXf> im_info(
im_info_tensor.data<float>(),
im_info_tensor.size(0),
im_info_tensor.size(1));
- const int roi_col_count = box_dim + 1;
- auto* out_rois = Output(0, {0, roi_col_count}, at::dtype<float>());
- auto* out_rois_probs = Output(1, {0}, at::dtype<float>());
+ Eigen::Map<const ERArrXXf> anchors(
+ anchors_tensor.data<float>(),
+ anchors_tensor.size(0),
+ anchors_tensor.size(1));
std::vector<ERArrXXf> im_boxes(num_images);
std::vector<EArrXf> im_probs(num_images);
EArrXf& im_i_probs = im_probs[i];
ProposalsForOneImage(
cur_im_info,
- all_anchors,
+ anchors,
cur_bbox_deltas,
cur_scores,
&im_i_boxes,
for (int i = 0; i < num_images; i++) {
roi_counts += im_boxes[i].rows();
}
- out_rois->Extend(roi_counts, 50);
- out_rois_probs->Extend(roi_counts, 50);
+ const int roi_col_count = box_dim + 1;
+ auto* out_rois = Output(0, {roi_counts, roi_col_count}, at::dtype<float>());
+ auto* out_rois_probs = Output(1, {roi_counts}, at::dtype<float>());
float* out_rois_ptr = out_rois->template mutable_data<float>();
float* out_rois_probs_ptr = out_rois_probs->template mutable_data<float>();
for (int i = 0; i < num_images; i++) {
EXPECT_EQ((all_anchors_result - all_anchors_gt).norm(), 0);
}
+TEST(GenerateProposalsTest, TestComputeSortedAnchors) {
+ ERMatXf anchors(3, 4);
+ anchors << -38, -16, 53, 31, -84, -40, 99, 55, -176, -88, 191, 103;
+
+ int height = 4;
+ int width = 3;
+ int A = anchors.rows();
+ float feat_stride = 16;
+ int total = height * width * A;
+
+ // Generate all anchors for ground truth
+ Tensor anchors_tensor(vector<int64_t>{anchors.rows(), anchors.cols()}, CPU);
+ Eigen::Map<ERMatXf>(
+ anchors_tensor.mutable_data<float>(), anchors.rows(), anchors.cols()) =
+ anchors;
+ auto all_anchors =
+ utils::ComputeAllAnchors(anchors_tensor, height, width, feat_stride);
+ Eigen::Map<const ERMatXf> all_anchors_result(
+ all_anchors.data(), height * width * A, 4);
+
+ Eigen::Map<const ERArrXXf> anchors_map(
+ anchors.data(), anchors.rows(), anchors.cols());
+
+ // Test with random subsets and ordering of indices
+ vector<int> indices(total);
+ std::iota(indices.begin(), indices.end(), 0);
+ std::random_device rd;
+ std::mt19937 gen(rd());
+ std::shuffle(indices.begin(), indices.end(), gen);
+ for (int count = 0; count <= total; ++count) {
+ vector<int> order(indices.begin(), indices.begin() + count);
+ auto result = utils::ComputeSortedAnchors(
+ anchors_map, height, width, feat_stride, order);
+
+ // Compare the result of ComputeSortedAnchors with first generating all
+ // anchors via ComputeAllAnchors and then applying ordering and filtering.
+ // Need to convert order from (A, H, W) to (H, W, A) format before for this.
+ const auto& order_AHW = utils::AsEArrXt(order);
+ const auto& order_AH = order_AHW / width;
+ const auto& order_W = order_AHW - order_AH * width;
+ const auto& order_A = order_AH / height;
+ const auto& order_H = order_AH - order_A * height;
+ const auto& order_HWA = (order_H * width + order_W) * A + order_A;
+
+ ERArrXXf gt;
+ utils::GetSubArrayRows(all_anchors_result.array(), order_HWA, >);
+ EXPECT_EQ((result.matrix() - gt.matrix()).norm(), 0);
+ }
+}
+
namespace {
template <class Derived>
EXPECT_EQ((all_anchors_result - all_anchors_gt).norm(), 0);
}
+TEST(GenerateProposalsTest, TestComputeSortedAnchorsRotated) {
+ // Similar to TestComputeSortedAnchors but for rotated boxes with angle info.
+ ERMatXf anchors_xyxy(3, 4);
+ anchors_xyxy << -38, -16, 53, 31, -84, -40, 99, 55, -176, -88, 191, 103;
+
+ // Convert to RRPN format and add angles
+ ERMatXf anchors(3, 5);
+ anchors.block(0, 0, 3, 4) = boxes_xyxy_to_xywh(anchors_xyxy);
+ std::vector<float> angles{0.0, 45.0, -120.0};
+ for (int i = 0; i < anchors.rows(); ++i) {
+ anchors(i, 4) = angles[i % angles.size()];
+ }
+
+ int height = 4;
+ int width = 3;
+ int A = anchors.rows();
+ float feat_stride = 16;
+ int total = height * width * A;
+
+ // Generate all anchors for ground truth
+ Tensor anchors_tensor(vector<int64_t>{anchors.rows(), anchors.cols()}, CPU);
+ Eigen::Map<ERMatXf>(
+ anchors_tensor.mutable_data<float>(), anchors.rows(), anchors.cols()) =
+ anchors;
+ auto all_anchors =
+ utils::ComputeAllAnchors(anchors_tensor, height, width, feat_stride);
+ Eigen::Map<const ERMatXf> all_anchors_result(
+ all_anchors.data(), height * width * A, 5);
+
+ Eigen::Map<const ERArrXXf> anchors_map(
+ anchors.data(), anchors.rows(), anchors.cols());
+
+ // Test with random subsets and ordering of indices
+ vector<int> indices(total);
+ std::iota(indices.begin(), indices.end(), 0);
+ std::random_device rd;
+ std::mt19937 gen(rd());
+ std::shuffle(indices.begin(), indices.end(), gen);
+ for (int count = 0; count <= total; ++count) {
+ vector<int> order(indices.begin(), indices.begin() + count);
+ auto result = utils::ComputeSortedAnchors(
+ anchors_map, height, width, feat_stride, order);
+
+ // Compare the result of ComputeSortedAnchors with first generating all
+ // anchors via ComputeAllAnchors and then applying ordering and filtering.
+ // Need to convert order from (A, H, W) to (H, W, A) format before for this.
+ const auto& order_AHW = utils::AsEArrXt(order);
+ const auto& order_AH = order_AHW / width;
+ const auto& order_W = order_AHW - order_AH * width;
+ const auto& order_A = order_AH / height;
+ const auto& order_H = order_AH - order_A * height;
+ const auto& order_HWA = (order_H * width + order_W) * A + order_A;
+
+ ERArrXXf gt;
+ utils::GetSubArrayRows(all_anchors_result.array(), order_HWA, >);
+ EXPECT_EQ((result.matrix() - gt.matrix()).norm(), 0);
+ }
+}
+
TEST(GenerateProposalsTest, TestEmpty) {
Workspace ws;
OperatorDef def;
EXPECT_NE(nullptr, op.get());
EXPECT_TRUE(op->Run());
- // Verify that the resulting angles are correct
Blob* rois_blob = ws.GetBlob("rois");
EXPECT_NE(nullptr, rois_blob);
auto& rois = rois_blob->Get<TensorCPU>();
- EXPECT_GT(rois.size(0), 0);
+ EXPECT_EQ(rois.sizes(), (vector<int64_t>{13, 6}));
+
+ Blob* rois_probs_blob = ws.GetBlob("rois_probs");
+ EXPECT_NE(nullptr, rois_probs_blob);
+ auto& rois_probs = rois_probs_blob->Get<TensorCPU>();
+ EXPECT_EQ(rois_probs.sizes(), (vector<int64_t>{13}));
+
+ // Verify that the resulting angles are correct
auto rois_data =
Eigen::Map<const ERMatXf>(rois.data<float>(), rois.size(0), rois.size(1));
for (int i = 0; i < rois.size(0); ++i) {