Adding CUDA versions for C2 operators GenerateProposals and NMS (#13694)
author     hbraun@nvidia.com <hbraun@nvidia.com>
Thu, 20 Dec 2018 22:24:27 +0000 (14:24 -0800)
committer  Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 20 Dec 2018 22:39:09 +0000 (14:39 -0800)
Summary:
Related to issue #13684
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13694

Reviewed By: wat3rBro

Differential Revision: D13017791

Pulled By: newstzpz

fbshipit-source-id: 4bdc58e474d8e1f6cd73a02bf51f91542a2b9d0b
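
For reference, a minimal sketch of how the new GPU path is exercised, mirroring the new generate_proposals_op_gpu_test.cc below. It assumes a Workspace ws whose input blobs (scores, bbox_deltas, im_info, anchors) already exist as CUDA tensors; this is illustrative only, not part of the change itself.

    // Illustrative only: run GenerateProposals on the GPU, as the new test does.
    OperatorDef def;
    def.set_name("test");
    def.set_type("GenerateProposals");
    def.add_input("scores");       // (num_images, A, H, W)
    def.add_input("bbox_deltas");  // (num_images, 4*A, H, W)
    def.add_input("im_info");      // (num_images, 3)
    def.add_input("anchors");      // (A, 4)
    def.add_output("rois");        // (n, 5): [image_index, x1, y1, x2, y2]
    def.add_output("rois_probs");  // (n,)
    def.mutable_device_option()->set_device_type(PROTO_CUDA);
    def.add_arg()->CopyFrom(MakeArgument("spatial_scale", 1.0f / 16.0f));
    def.add_arg()->CopyFrom(MakeArgument("pre_nms_topN", 6000));
    def.add_arg()->CopyFrom(MakeArgument("post_nms_topN", 300));
    def.add_arg()->CopyFrom(MakeArgument("nms_thresh", 0.7f));
    def.add_arg()->CopyFrom(MakeArgument("min_size", 16.0f));
    unique_ptr<OperatorBase> op(CreateOperator(def, &ws));
    op->Run();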

caffe2/core/common_gpu.h
caffe2/operators/generate_proposals_op.cc
caffe2/operators/generate_proposals_op.cu [new file with mode: 0644]
caffe2/operators/generate_proposals_op.h
caffe2/operators/generate_proposals_op_gpu_test.cc [new file with mode: 0644]
caffe2/operators/generate_proposals_op_util_nms_gpu.cu [new file with mode: 0644]
caffe2/operators/generate_proposals_op_util_nms_gpu.h [new file with mode: 0644]
caffe2/operators/generate_proposals_op_util_nms_gpu_test.cc [new file with mode: 0644]
tools/amd_build/pyHIPIFY/cuda_to_hip_mappings.py

diff --git a/caffe2/core/common_gpu.h b/caffe2/core/common_gpu.h
index db87887..1c2b487 100644 (file)
@@ -286,6 +286,12 @@ CAFFE2_CUDA_API const char* curandGetErrorString(curandStatus_t error);
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
        i += blockDim.x * gridDim.x)
 
+#define CUDA_2D_KERNEL_LOOP(i, n, j, m)                             \
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n);   \
+       i += blockDim.x * gridDim.x)                                 \
+    for (size_t j = blockIdx.y * blockDim.y + threadIdx.y; j < (m); \
+         j += blockDim.y * gridDim.y)
+
 // CUDA_KERNEL_ASSERT is a macro that wraps an assert() call inside cuda
 // kernels. This is not supported by Apple platforms so we special case it.
 // See http://docs.nvidia.com/cuda/cuda-c-programming-guide/#assertion
@@ -309,13 +315,22 @@ CAFFE2_CUDA_API const char* curandGetErrorString(curandStatus_t error);
 // The number of cuda threads to use. Since work is assigned to SMs at the
 // granularity of a block, 128 is chosen to allow utilizing more SMs for
 // smaller input sizes.
+// 1D grid
 constexpr int CAFFE_CUDA_NUM_THREADS = 128;
+// 2D grid
+constexpr int CAFFE_CUDA_NUM_THREADS_2D_DIMX = 16;
+constexpr int CAFFE_CUDA_NUM_THREADS_2D_DIMY = 16;
+
 // The maximum number of blocks to use in the default kernel call. We set it to
 // 4096 which would work for compute capability 2.x (where 65536 is the limit).
 // This number is very carelessly chosen. Ideally, one would like to look at
 // the hardware at runtime, and pick the number of blocks that makes most
 // sense for the specific runtime environment. This is a todo item.
+// 1D grid
 constexpr int CAFFE_MAXIMUM_NUM_BLOCKS = 4096;
+// 2D grid
+constexpr int CAFFE_MAXIMUM_NUM_BLOCKS_2D_DIMX = 128;
+constexpr int CAFFE_MAXIMUM_NUM_BLOCKS_2D_DIMY = 128;
 
 constexpr int kCUDAGridDimMaxX = 2147483647;
 constexpr int kCUDAGridDimMaxY = 65535;
@@ -333,6 +348,32 @@ inline int CAFFE_GET_BLOCKS(const int N) {
       1);
 }
 
+/**
+ * @brief Compute the number of blocks needed to run N threads for a 2D grid.
+ * Note: both grid dimensions are currently sized from N; the second argument
+ * is unused.
+ */
+inline dim3 CAFFE_GET_BLOCKS_2D(const int N, const int /* M */) {
+  dim3 grid;
+  // Not calling the 1D version for each dim to keep all constants as literals
+
+  grid.x = std::max(
+      std::min(
+          (N + CAFFE_CUDA_NUM_THREADS_2D_DIMX - 1) /
+              CAFFE_CUDA_NUM_THREADS_2D_DIMX,
+          CAFFE_MAXIMUM_NUM_BLOCKS_2D_DIMX),
+      // Use at least 1 block, since CUDA does not allow empty block
+      1);
+
+  grid.y = std::max(
+      std::min(
+          (N + CAFFE_CUDA_NUM_THREADS_2D_DIMY - 1) /
+              CAFFE_CUDA_NUM_THREADS_2D_DIMY,
+          CAFFE_MAXIMUM_NUM_BLOCKS_2D_DIMY),
+      // Use at least 1 block, since CUDA does not allow empty block
+      1);
+
+  return grid;
+}
+
 class DeviceGuard {
  public:
   explicit DeviceGuard(int newDevice) : previous_(CaffeCudaGetDevice()) {
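
For illustration, a minimal sketch (not part of this change) of how the new 2D helpers are meant to be combined; the kernel, d_data, N, M, and the CUDAContext context are hypothetical. NMSKernel in generate_proposals_op_util_nms_gpu.cu below uses the same launch pattern.

    __global__ void FillExampleKernel(const int N, const int M, float* d_data) {
      // Each (i, j) pair covers one element of an N x M matrix.
      CUDA_2D_KERNEL_LOOP(i, N, j, M) {
        d_data[i * M + j] = static_cast<float>(i + j);
      }
    }

    // Launch with 16x16 thread blocks and a grid capped at the 2D maxima.
    const dim3 threads(
        CAFFE_CUDA_NUM_THREADS_2D_DIMX, CAFFE_CUDA_NUM_THREADS_2D_DIMY);
    FillExampleKernel<<<CAFFE_GET_BLOCKS_2D(N, M), threads, 0,
                        context.cuda_stream()>>>(N, M, d_data);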
diff --git a/caffe2/operators/generate_proposals_op.cc b/caffe2/operators/generate_proposals_op.cc
index 3f50661..e391f1e 100644 (file)
@@ -345,8 +345,6 @@ bool GenerateProposalsOp<CPUContext>::RunOnDevice() {
   return true;
 }
 
-namespace {
-
 REGISTER_CPU_OPERATOR(GenerateProposals, GenerateProposalsOp<CPUContext>);
 // For backward compatibility
 REGISTER_CPU_OPERATOR(GenerateProposalsCPP, GenerateProposalsOp<CPUContext>);
@@ -413,5 +411,4 @@ SHOULD_NOT_DO_GRADIENT(GenerateProposals);
 // For backward compatibility
 SHOULD_NOT_DO_GRADIENT(GenerateProposalsCPP);
 
-} // namespace
 } // namespace caffe2
diff --git a/caffe2/operators/generate_proposals_op.cu b/caffe2/operators/generate_proposals_op.cu
new file mode 100644 (file)
index 0000000..fe89a66
--- /dev/null
@@ -0,0 +1,453 @@
+#include <cub/cub.cuh>
+#include "caffe2/core/context.h"
+#include "caffe2/core/context_gpu.h"
+#include "caffe2/operators/generate_proposals_op.h"
+#include "caffe2/operators/generate_proposals_op_util_boxes.h" // BBOX_XFORM_CLIP_DEFAULT
+#include "caffe2/operators/generate_proposals_op_util_nms.h"
+#include "caffe2/operators/generate_proposals_op_util_nms_gpu.h"
+
+namespace caffe2 {
+namespace {
+__global__ void GeneratePreNMSUprightBoxesKernel(
+    const int* d_sorted_scores_keys,
+    const int nboxes_to_generate,
+    const float* d_bbox_deltas,
+    const float4* d_anchors,
+    const int H,
+    const int W,
+    const int K, // K = H*W
+    const int A,
+    const int KA, // KA = K*A
+    const float feat_stride,
+    const float min_size,
+    const float* d_img_info_vec,
+    const int num_images,
+    const float bbox_xform_clip,
+    const bool correct_transform,
+    float4* d_out_boxes,
+    const int prenms_nboxes, // leading dimension of out_boxes
+    float* d_inout_scores,
+    char* d_boxes_keep_flags) {
+  CUDA_2D_KERNEL_LOOP(ibox, nboxes_to_generate, image_index, num_images) {
+    // box_conv_index : index of the same box, but within
+    // the scores from the conv layer, of shape (A,H,W)
+    // (the num_images dimension has already been removed)
+    // box_conv_index = a*K + h*W + w
+    const int box_conv_index = d_sorted_scores_keys[image_index * KA + ibox];
+
+    // We want to decompose box_conv_index into (a,h,w)
+    // such that box_conv_index = a*K + h*W + w
+    // (avoiding modulos in the process)
+    int remaining = box_conv_index;
+    const int dA = K; // stride of A
+    const int a = remaining / dA;
+    remaining -= a * dA;
+    const int dH = W; // stride of H
+    const int h = remaining / dH;
+    remaining -= h * dH;
+    const int w = remaining; // dW = 1
+
+    // Loading the anchor a
+    // float4 is a struct with float x,y,z,w
+    const float4 anchor = d_anchors[a];
+    // x1,y1,x2,y2 :coordinates of anchor a, shifted for position (h,w)
+    const float shift_w = feat_stride * w;
+    float x1 = shift_w + anchor.x;
+    float x2 = shift_w + anchor.z;
+    const float shift_h = feat_stride * h;
+    float y1 = shift_h + anchor.y;
+    float y2 = shift_h + anchor.w;
+
+    // TODO use fast math when possible
+
+    // Deltas for that box
+    // Deltas of shape (num_images,4*A,K)
+    // We're going to compute 4 scattered reads
+    // better than the alternative, ie transposing the complete deltas
+    // array first
+    int deltas_idx = image_index * (KA * 4) + a * 4 * K + h * W + w;
+    const float dx = d_bbox_deltas[deltas_idx];
+    // Stride of K between each dimension
+    deltas_idx += K;
+    const float dy = d_bbox_deltas[deltas_idx];
+    deltas_idx += K;
+    float dw = d_bbox_deltas[deltas_idx];
+    deltas_idx += K;
+    float dh = d_bbox_deltas[deltas_idx];
+
+    // Upper bound on dw,dh
+    dw = fmin(dw, bbox_xform_clip);
+    dh = fmin(dh, bbox_xform_clip);
+
+    // Applying the deltas
+    float width = x2 - x1 + 1.0f;
+    const float ctr_x = x1 + 0.5f * width;
+    const float pred_ctr_x = ctr_x + width * dx; // TODO fuse madd
+    const float pred_w = width * expf(dw);
+    x1 = pred_ctr_x - 0.5f * pred_w;
+    x2 = pred_ctr_x + 0.5f * pred_w;
+
+    float height = y2 - y1 + 1.0f;
+    const float ctr_y = y1 + 0.5f * height;
+    const float pred_ctr_y = ctr_y + height * dy;
+    const float pred_h = height * expf(dh);
+    y1 = pred_ctr_y - 0.5f * pred_h;
+    y2 = pred_ctr_y + 0.5f * pred_h;
+
+    if (correct_transform) {
+      x2 -= 1.0f;
+      y2 -= 1.0f;
+    }
+
+    // Clipping box to image
+    const float img_height = d_img_info_vec[3 * image_index + 0];
+    const float img_width = d_img_info_vec[3 * image_index + 1];
+    const float min_size_scaled =
+        min_size * d_img_info_vec[3 * image_index + 2];
+    x1 = fmax(fmin(x1, img_width - 1.0f), 0.0f);
+    y1 = fmax(fmin(y1, img_height - 1.0f), 0.0f);
+    x2 = fmax(fmin(x2, img_width - 1.0f), 0.0f);
+    y2 = fmax(fmin(y2, img_height - 1.0f), 0.0f);
+
+    // Filter boxes
+    // Removing boxes with one dim < min_size
+    // (center of box is in image, because of previous step)
+    width = x2 - x1 + 1.0f; // may have changed
+    height = y2 - y1 + 1.0f;
+    bool keep_box = fmin(width, height) >= min_size_scaled;
+
+    // We are not deleting the box right now, even if !keep_box,
+    // because we want to keep the relative order of the elements stable;
+    // the actual filtering happens later
+    // d_boxes_keep_flags size: (num_images,prenms_nboxes)
+    // d_out_boxes size: (num_images,prenms_nboxes)
+    const int out_index = image_index * prenms_nboxes + ibox;
+    d_boxes_keep_flags[out_index] = keep_box;
+    d_out_boxes[out_index] = {x1, y1, x2, y2};
+
+    // d_inout_scores size: (num_images,KA)
+    if (!keep_box)
+      d_inout_scores[image_index * KA + ibox] = FLT_MIN; // for NMS
+  }
+}
+
+__global__ void WriteOutput(
+    const float4* d_image_boxes,
+    const float* d_image_scores,
+    const int* d_image_boxes_keep_list,
+    const int nboxes,
+    const int image_index,
+    float* d_image_out_rois,
+    float* d_image_out_rois_probs) {
+  CUDA_1D_KERNEL_LOOP(i, nboxes) {
+    const int ibox = d_image_boxes_keep_list[i];
+    const float4 box = d_image_boxes[ibox];
+    const float score = d_image_scores[ibox];
+    // Scattered memory accesses
+    // postnms_nboxes is small anyway
+    d_image_out_rois_probs[i] = score;
+    const int base_idx = 5 * i;
+    d_image_out_rois[base_idx + 0] = image_index;
+    d_image_out_rois[base_idx + 1] = box.x;
+    d_image_out_rois[base_idx + 2] = box.y;
+    d_image_out_rois[base_idx + 3] = box.z;
+    d_image_out_rois[base_idx + 4] = box.w;
+  }
+}
+
+__global__ void InitializeDataKernel(
+    const int num_images,
+    const int KA,
+    int* d_image_offsets,
+    int* d_boxes_keys_iota) {
+  CUDA_2D_KERNEL_LOOP(box_idx, KA, img_idx, num_images) {
+    d_boxes_keys_iota[img_idx * KA + box_idx] = box_idx;
+
+    // One 1D line sets the 1D data
+    if (box_idx == 0) {
+      d_image_offsets[img_idx] = KA * img_idx;
+      // One thread sets the last+1 offset
+      if (img_idx == 0)
+        d_image_offsets[num_images] = KA * num_images;
+    }
+  }
+}
+
+} // namespace
+
+template <>
+bool GenerateProposalsOp<CUDAContext>::RunOnDevice() {
+  const auto& scores = Input(0);
+  const auto& bbox_deltas = Input(1);
+  const auto& im_info_tensor = Input(2);
+  const auto& anchors = Input(3);
+  auto* out_rois = Output(0);
+  auto* out_rois_probs = Output(1);
+
+  CAFFE_ENFORCE_EQ(scores.ndim(), 4, scores.ndim());
+  CAFFE_ENFORCE(scores.template IsType<float>(), scores.meta().name());
+
+  const auto num_images = scores.dim(0);
+  const auto A = scores.dim(1);
+  const auto H = scores.dim(2);
+  const auto W = scores.dim(3);
+  const auto box_dim_conv = anchors.dim(1);
+
+  CAFFE_ENFORCE(box_dim_conv == 4); // only upright boxes in GPU version for now
+
+  constexpr int box_dim = 4;
+  const int K = H * W;
+  const int conv_layer_nboxes = K * A;
+  // Getting data members ready
+
+  // We'll sort the scores
+  // we want to remember their original indexes,
+  // ie their indexes in the tensor of shape (num_images,A,K)
+  // from the conv layer
+  // each row of d_conv_layer_indexes is at first initialized to 0..A*K-1
+  dev_conv_layer_indexes_.Resize(num_images, conv_layer_nboxes);
+  int* d_conv_layer_indexes =
+      dev_conv_layer_indexes_.template mutable_data<int>();
+
+  // d_image_offset[i] = i*K*A for i in [0, num_images]
+  // Used by the segmented sort to only sort scores within one image
+  dev_image_offset_.Resize(num_images + 1);
+  int* d_image_offset = dev_image_offset_.template mutable_data<int>();
+
+  // The following calls to CUB primitives do nothing
+  // (because the first arg is nullptr)
+  // except setting cub_*_temp_storage_bytes
+  size_t cub_sort_temp_storage_bytes = 0;
+  float* flt_ptr = nullptr;
+  int* int_ptr = nullptr;
+  cub::DeviceSegmentedRadixSort::SortPairsDescending(
+      nullptr,
+      cub_sort_temp_storage_bytes,
+      flt_ptr,
+      flt_ptr,
+      int_ptr,
+      int_ptr,
+      num_images * conv_layer_nboxes,
+      num_images,
+      int_ptr,
+      int_ptr,
+      0,
+      8 * sizeof(float), // sort all bits
+      context_.cuda_stream());
+
+  // Allocate temporary storage for CUB
+  dev_cub_sort_buffer_.Resize(cub_sort_temp_storage_bytes);
+  void* d_cub_sort_temp_storage =
+      dev_cub_sort_buffer_.template mutable_data<char>();
+
+  size_t cub_select_temp_storage_bytes = 0;
+  char* char_ptr = nullptr;
+  cub::DeviceSelect::Flagged(
+      nullptr,
+      cub_select_temp_storage_bytes,
+      flt_ptr,
+      char_ptr,
+      flt_ptr,
+      int_ptr,
+      K * A,
+      context_.cuda_stream());
+
+  // Allocate temporary storage for CUB
+  dev_cub_select_buffer_.Resize(cub_select_temp_storage_bytes);
+  void* d_cub_select_temp_storage =
+      dev_cub_select_buffer_.template mutable_data<char>();
+
+  // Initialize :
+  // - each row of dev_conv_layer_indexes to 0..K*A-1
+  // - d_image_offset[i] = K*A*i for i in [0, num_images]
+  // 2D grid
+  InitializeDataKernel<<<
+      (CAFFE_GET_BLOCKS(A * K), num_images),
+      CAFFE_CUDA_NUM_THREADS, // blockDim.y == 1
+      0,
+      context_.cuda_stream()>>>(
+      num_images, conv_layer_nboxes, d_image_offset, d_conv_layer_indexes);
+
+  // Sorting input scores
+  dev_sorted_conv_layer_indexes_.Resize(num_images, conv_layer_nboxes);
+  dev_sorted_scores_.Resize(num_images, conv_layer_nboxes);
+  const float* d_in_scores = scores.data<float>();
+  int* d_sorted_conv_layer_indexes =
+      dev_sorted_conv_layer_indexes_.template mutable_data<int>();
+  float* d_sorted_scores = dev_sorted_scores_.template mutable_data<float>();
+  cub::DeviceSegmentedRadixSort::SortPairsDescending(
+      d_cub_sort_temp_storage,
+      cub_sort_temp_storage_bytes,
+      d_in_scores,
+      d_sorted_scores,
+      d_conv_layer_indexes,
+      d_sorted_conv_layer_indexes,
+      num_images * conv_layer_nboxes,
+      num_images,
+      d_image_offset,
+      d_image_offset + 1,
+      0,
+      8 * sizeof(float), // sort all bits
+      context_.cuda_stream());
+
+  // Keeping only the topN pre_nms
+  const int nboxes_to_generate = std::min(conv_layer_nboxes, rpn_pre_nms_topN_);
+
+  // Generating the boxes associated to the topN pre_nms scores
+  dev_boxes_.Resize(num_images, box_dim * nboxes_to_generate);
+  dev_boxes_keep_flags_.Resize(num_images, nboxes_to_generate);
+  const float* d_bbox_deltas = bbox_deltas.data<float>();
+  const float* d_anchors = anchors.data<float>();
+  const float* d_im_info_vec = im_info_tensor.data<float>();
+  float4* d_boxes =
+      reinterpret_cast<float4*>(dev_boxes_.template mutable_data<float>());
+  char* d_boxes_keep_flags =
+      dev_boxes_keep_flags_.template mutable_data<char>();
+
+  GeneratePreNMSUprightBoxesKernel<<<
+      (CAFFE_GET_BLOCKS(nboxes_to_generate), num_images),
+      CAFFE_CUDA_NUM_THREADS, // blockDim.y == 1
+      0,
+      context_.cuda_stream()>>>(
+      d_sorted_conv_layer_indexes,
+      nboxes_to_generate,
+      d_bbox_deltas,
+      reinterpret_cast<const float4*>(d_anchors),
+      H,
+      W,
+      K,
+      A,
+      K * A,
+      feat_stride_,
+      rpn_min_size_,
+      d_im_info_vec,
+      num_images,
+      utils::BBOX_XFORM_CLIP_DEFAULT,
+      correct_transform_coords_,
+      d_boxes,
+      nboxes_to_generate,
+      d_sorted_scores,
+      d_boxes_keep_flags);
+  const int nboxes_generated = nboxes_to_generate;
+  dev_image_prenms_boxes_.Resize(box_dim * nboxes_generated);
+  float4* d_image_prenms_boxes = reinterpret_cast<float4*>(
+      dev_image_prenms_boxes_.template mutable_data<float>());
+  dev_image_prenms_scores_.Resize(nboxes_generated);
+  float* d_image_prenms_scores =
+      dev_image_prenms_scores_.template mutable_data<float>();
+  dev_image_boxes_keep_list_.Resize(nboxes_generated);
+  int* d_image_boxes_keep_list =
+      dev_image_boxes_keep_list_.template mutable_data<int>();
+
+  const int max_postnms_nboxes = std::min(nboxes_generated, rpn_post_nms_topN_);
+  dev_postnms_rois_.Resize(5 * num_images * max_postnms_nboxes);
+  dev_postnms_rois_probs_.Resize(num_images * max_postnms_nboxes);
+  float* d_postnms_rois = dev_postnms_rois_.template mutable_data<float>();
+  float* d_postnms_rois_probs =
+      dev_postnms_rois_probs_.template mutable_data<float>();
+
+  dev_prenms_nboxes_.Resize(num_images);
+  host_prenms_nboxes_.Resize(num_images);
+  int* d_prenms_nboxes = dev_prenms_nboxes_.template mutable_data<int>();
+  int* h_prenms_nboxes = host_prenms_nboxes_.template mutable_data<int>();
+
+  int nrois_in_output = 0;
+  for (int image_index = 0; image_index < num_images; ++image_index) {
+    // Sub matrices for current image
+    const float4* d_image_boxes = &d_boxes[image_index * nboxes_generated];
+    const float* d_image_sorted_scores = &d_sorted_scores[image_index * K * A];
+    char* d_image_boxes_keep_flags =
+        &d_boxes_keep_flags[image_index * nboxes_generated];
+
+    float* d_image_postnms_rois = &d_postnms_rois[5 * nrois_in_output];
+    float* d_image_postnms_rois_probs = &d_postnms_rois_probs[nrois_in_output];
+
+    // Moving valid boxes (ie the ones with d_boxes_keep_flags[ibox] == true)
+    // to the output tensors
+
+    cub::DeviceSelect::Flagged(
+        d_cub_select_temp_storage,
+        cub_select_temp_storage_bytes,
+        d_image_boxes,
+        d_image_boxes_keep_flags,
+        d_image_prenms_boxes,
+        d_prenms_nboxes,
+        nboxes_generated,
+        context_.cuda_stream());
+
+    cub::DeviceSelect::Flagged(
+        d_cub_select_temp_storage,
+        cub_select_temp_storage_bytes,
+        d_image_sorted_scores,
+        d_image_boxes_keep_flags,
+        d_image_prenms_scores,
+        d_prenms_nboxes,
+        nboxes_generated,
+        context_.cuda_stream());
+
+    host_prenms_nboxes_.CopyFrom(dev_prenms_nboxes_);
+
+    // We know prenms_boxes <= topN_prenms, because nboxes_generated <=
+    // topN_prenms. Calling NMS on the generated boxes
+    const int prenms_nboxes = *h_prenms_nboxes;
+    int nkeep;
+    utils::nms_gpu_upright(
+        reinterpret_cast<const float*>(d_image_prenms_boxes),
+        prenms_nboxes,
+        rpn_nms_thresh_,
+        d_image_boxes_keep_list,
+        &nkeep,
+        dev_nms_mask_,
+        host_nms_mask_,
+        &context_);
+
+    // All operations done after the previous sort kept the relative order of
+    // the elements, so the elements are still sorted; keeping the topN is
+    // equivalent to truncating the array
+    const int postnms_nboxes = std::min(nkeep, rpn_post_nms_topN_);
+
+    // Moving the out boxes to the output tensors,
+    // adding the image_index dimension on the fly
+    WriteOutput<<<
+        CAFFE_GET_BLOCKS(postnms_nboxes),
+        CAFFE_CUDA_NUM_THREADS,
+        0,
+        context_.cuda_stream()>>>(
+        d_image_prenms_boxes,
+        d_image_prenms_scores,
+        d_image_boxes_keep_list,
+        postnms_nboxes,
+        image_index,
+        d_image_postnms_rois,
+        d_image_postnms_rois_probs);
+
+    nrois_in_output += postnms_nboxes;
+  }
+
+  // Using a buffer because we cannot call ShrinkTo
+  out_rois->Resize(nrois_in_output, 5);
+  out_rois_probs->Resize(nrois_in_output);
+  float* d_out_rois = out_rois->template mutable_data<float>();
+  float* d_out_rois_probs = out_rois_probs->template mutable_data<float>();
+
+  CUDA_CHECK(cudaMemcpyAsync(
+      d_out_rois,
+      d_postnms_rois,
+      nrois_in_output * 5 * sizeof(float),
+      cudaMemcpyDeviceToDevice,
+      context_.cuda_stream()));
+  CUDA_CHECK(cudaMemcpyAsync(
+      d_out_rois_probs,
+      d_postnms_rois_probs,
+      nrois_in_output * sizeof(float),
+      cudaMemcpyDeviceToDevice,
+      context_.cuda_stream()));
+
+  return true;
+}
+
+REGISTER_CUDA_OPERATOR(GenerateProposals, GenerateProposalsOp<CUDAContext>);
+} // namespace caffe2
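
The nullptr-first CUB calls above follow CUB's standard two-phase convention: the first call only reports the required scratch size, and the identical call with real storage does the work. A condensed, self-contained sketch of that pattern follows (illustrative only; the helper name and the plain cudaMalloc are assumptions, the operator instead reuses cached Caffe2 tensors for the scratch space).

    #include <cub/cub.cuh>

    // Sort n float keys in descending order on `stream` using CUB's two-phase API.
    void SortKeysDescendingExample(
        const float* d_keys_in, float* d_keys_out, int n, cudaStream_t stream) {
      size_t temp_bytes = 0;
      // Phase 1: with a null temp-storage pointer, CUB only reports the
      // required scratch size in temp_bytes.
      cub::DeviceRadixSort::SortKeysDescending(
          nullptr, temp_bytes, d_keys_in, d_keys_out, n,
          0, 8 * sizeof(float), stream);
      void* d_temp = nullptr;
      cudaMalloc(&d_temp, temp_bytes);
      // Phase 2: the identical call with real scratch space performs the sort.
      cub::DeviceRadixSort::SortKeysDescending(
          d_temp, temp_bytes, d_keys_in, d_keys_out, n,
          0, 8 * sizeof(float), stream);
      cudaFree(d_temp);
    }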
diff --git a/caffe2/operators/generate_proposals_op.h b/caffe2/operators/generate_proposals_op.h
index fa933e3..6d667f2 100644 (file)
@@ -144,8 +144,35 @@ class GenerateProposalsOp final : public Operator<Context> {
   // tolerance for backward compatibility. Set to negative value for
   // no clipping.
   float clip_angle_thresh_{1.0};
+
+  // Scratch space required by the CUDA version
+  // CUB buffers
+  Tensor dev_cub_sort_buffer_{Context::GetDeviceType()};
+  Tensor dev_cub_select_buffer_{Context::GetDeviceType()};
+  Tensor dev_image_offset_{Context::GetDeviceType()};
+  Tensor dev_conv_layer_indexes_{Context::GetDeviceType()};
+  Tensor dev_sorted_conv_layer_indexes_{Context::GetDeviceType()};
+  Tensor dev_sorted_scores_{Context::GetDeviceType()};
+  Tensor dev_boxes_{Context::GetDeviceType()};
+  Tensor dev_boxes_keep_flags_{Context::GetDeviceType()};
+
+  // prenms proposals (raw proposals minus empty boxes)
+  Tensor dev_image_prenms_boxes_{Context::GetDeviceType()};
+  Tensor dev_image_prenms_scores_{Context::GetDeviceType()};
+  Tensor dev_prenms_nboxes_{Context::GetDeviceType()};
+  Tensor host_prenms_nboxes_{CPU};
+
+  Tensor dev_image_boxes_keep_list_{Context::GetDeviceType()};
+
+  // Tensors used by NMS
+  Tensor dev_nms_mask_{Context::GetDeviceType()};
+  Tensor host_nms_mask_{CPU};
+
+  // Buffer for output
+  Tensor dev_postnms_rois_{Context::GetDeviceType()};
+  Tensor dev_postnms_rois_probs_{Context::GetDeviceType()};
 };
 
 } // namespace caffe2
 
-#endif // CAFFE2_OPERATORS_GENERATE_PROPOSALS_OP_H_
+#endif // CAFFE2_OPERATORS_GENERATE_PROPOSALS_OP_H_
\ No newline at end of file
diff --git a/caffe2/operators/generate_proposals_op_gpu_test.cc b/caffe2/operators/generate_proposals_op_gpu_test.cc
new file mode 100644 (file)
index 0000000..36ad5aa
--- /dev/null
@@ -0,0 +1,239 @@
+#include "caffe2/operators/generate_proposals_op.h"
+
+#include <gtest/gtest.h>
+#include "caffe2/core/flags.h"
+#include "caffe2/core/macros.h"
+
+#include "caffe2/core/context.h"
+#include "caffe2/core/context_gpu.h"
+
+#ifdef CAFFE2_USE_OPENCV
+#include <opencv2/opencv.hpp>
+#endif // CAFFE2_USE_OPENCV
+
+namespace caffe2 {
+
+static void AddLinSpacedInput(
+    const vector<int64_t>& shape,
+    const float min_val,
+    const float max_val,
+    const string& name,
+    Workspace* ws) {
+  DeviceOption option;
+  CPUContext context(option);
+  Blob* blob = ws->CreateBlob(name);
+  auto* tensor = BlobGetMutableTensor(blob, CPU);
+  tensor->Resize(shape);
+  EigenVectorMap<float> tensor_vec(
+      tensor->template mutable_data<float>(), tensor->size());
+  tensor_vec.setLinSpaced(min_val, max_val);
+
+  return;
+}
+
+template <class Context>
+void AddConstInput(
+    const vector<int64_t>& shape,
+    const float value,
+    const string& name,
+    Context* context,
+    Workspace* ws) {
+  Blob* blob = ws->CreateBlob(name);
+  auto* tensor = BlobGetMutableTensor(blob, Context::GetDeviceType());
+  tensor->Resize(shape);
+  math::Set<float, Context>(
+      tensor->size(), value, tensor->template mutable_data<float>(), context);
+  return;
+}
+
+template <class Context>
+void AddInput(
+    const vector<int64_t>& shape,
+    const vector<float>& values,
+    const string& name,
+    Workspace* ws);
+
+template <>
+void AddInput<CPUContext>(
+    const vector<int64_t>& shape,
+    const vector<float>& values,
+    const string& name,
+    Workspace* ws) {
+  Blob* blob = ws->CreateBlob(name);
+  auto* tensor = BlobGetMutableTensor(blob, CPU);
+  tensor->Resize(shape);
+  EigenVectorMap<float> tensor_vec(
+      tensor->template mutable_data<float>(), tensor->size());
+  tensor_vec.array() = utils::AsEArrXt(values);
+}
+
+template <>
+void AddInput<CUDAContext>(
+    const vector<int64_t>& shape,
+    const vector<float>& values,
+    const string& name,
+    Workspace* ws) {
+  Tensor tmp(shape, CPU);
+  EigenVectorMap<float> tmp_vec(tmp.mutable_data<float>(), tmp.size());
+  tmp_vec.array() = utils::AsEArrXt(values);
+
+  Blob* blob = ws->CreateBlob(name);
+  auto* tensor = BlobGetMutableTensor(blob, CUDA);
+  tensor->CopyFrom(tmp);
+}
+
+TEST(GenerateProposalsTest, TestRealDownSampledGPU) {
+  if (!HasCudaGPU())
+    return;
+  Workspace ws;
+  OperatorDef def;
+  def.set_name("test");
+  def.set_type("GenerateProposals");
+  def.add_input("scores");
+  def.add_input("bbox_deltas");
+  def.add_input("im_info");
+  def.add_input("anchors");
+  def.add_output("rois");
+  def.add_output("rois_probs");
+  def.mutable_device_option()->set_device_type(PROTO_CUDA);
+  const int img_count = 2;
+  const int A = 2;
+  const int H = 4;
+  const int W = 5;
+
+  vector<float> scores{
+      5.44218998e-03f, 1.19207997e-03f, 1.12379994e-03f, 1.17181998e-03f,
+      1.20544003e-03f, 6.17993006e-04f, 1.05261997e-05f, 8.91025957e-06f,
+      9.29536981e-09f, 6.09605013e-05f, 4.72735002e-04f, 1.13482002e-10f,
+      1.50015003e-05f, 4.45032993e-06f, 3.21612994e-08f, 8.02662980e-04f,
+      1.40488002e-04f, 3.12508007e-07f, 3.02616991e-06f, 1.97759000e-08f,
+      2.66913995e-02f, 5.26766013e-03f, 5.05053019e-03f, 5.62100019e-03f,
+      5.37420018e-03f, 5.26280981e-03f, 2.48894998e-04f, 1.06842002e-04f,
+      3.92931997e-06f, 1.79388002e-03f, 4.79440019e-03f, 3.41609990e-07f,
+      5.20430971e-04f, 3.34090000e-05f, 2.19159006e-07f, 2.28786003e-03f,
+      5.16703985e-05f, 4.04523007e-06f, 1.79227004e-06f, 5.32449000e-08f};
+  vector<float> bbx{
+      -1.65040009e-02f, -1.84051003e-02f, -1.85930002e-02f, -2.08263006e-02f,
+      -1.83814000e-02f, -2.89172009e-02f, -3.89706008e-02f, -7.52277970e-02f,
+      -1.54091999e-01f, -2.55433004e-02f, -1.77490003e-02f, -1.10340998e-01f,
+      -4.20190990e-02f, -2.71421000e-02f, 6.89801015e-03f,  5.71171008e-02f,
+      -1.75665006e-01f, 2.30021998e-02f,  3.08554992e-02f,  -1.39333997e-02f,
+      3.40579003e-01f,  3.91070992e-01f,  3.91624004e-01f,  3.92527014e-01f,
+      3.91445011e-01f,  3.79328012e-01f,  4.26631987e-01f,  3.64892989e-01f,
+      2.76894987e-01f,  5.13985991e-01f,  3.79999995e-01f,  1.80457994e-01f,
+      4.37402993e-01f,  4.18545991e-01f,  2.51549989e-01f,  4.48318988e-01f,
+      1.68564007e-01f,  4.65440989e-01f,  4.21891987e-01f,  4.45928007e-01f,
+      3.27155995e-03f,  3.71480011e-03f,  3.60032008e-03f,  4.27092984e-03f,
+      3.74579988e-03f,  5.95752988e-03f,  -3.14473989e-03f, 3.52022005e-03f,
+      -1.88564006e-02f, 1.65188999e-03f,  1.73791999e-03f,  -3.56074013e-02f,
+      -1.66615995e-04f, 3.14146001e-03f,  -1.11830998e-02f, -5.35363983e-03f,
+      6.49790000e-03f,  -9.27671045e-03f, -2.83346009e-02f, -1.61233004e-02f,
+      -2.15505004e-01f, -2.19910994e-01f, -2.20872998e-01f, -2.12831005e-01f,
+      -2.19145000e-01f, -2.27687001e-01f, -3.43973994e-01f, -2.75869995e-01f,
+      -3.19516987e-01f, -2.50418007e-01f, -2.48537004e-01f, -5.08224010e-01f,
+      -2.28724003e-01f, -2.82402009e-01f, -3.75815988e-01f, -2.86352992e-01f,
+      -5.28333001e-02f, -4.43836004e-01f, -4.55134988e-01f, -4.34897989e-01f,
+      -5.65053988e-03f, -9.25739005e-04f, -1.06790999e-03f, -2.37016007e-03f,
+      -9.71166010e-04f, -8.90910998e-03f, -1.17592998e-02f, -2.08992008e-02f,
+      -4.94231991e-02f, 6.63906988e-03f,  3.20469006e-03f,  -6.44695014e-02f,
+      -3.11607006e-03f, 2.02738005e-03f,  1.48096997e-02f,  4.39785011e-02f,
+      -8.28424022e-02f, 3.62076014e-02f,  2.71668993e-02f,  1.38250999e-02f,
+      6.76669031e-02f,  1.03252999e-01f,  1.03255004e-01f,  9.89722982e-02f,
+      1.03646003e-01f,  4.79663983e-02f,  1.11014001e-01f,  9.31736007e-02f,
+      1.15768999e-01f,  1.04014002e-01f,  -8.90677981e-03f, 1.13103002e-01f,
+      1.33085996e-01f,  1.25405997e-01f,  1.50051996e-01f,  -1.13038003e-01f,
+      7.01059997e-02f,  1.79651007e-01f,  1.41055003e-01f,  1.62841007e-01f,
+      -1.00247003e-02f, -8.17587040e-03f, -8.32176022e-03f, -8.90108012e-03f,
+      -8.13035015e-03f, -1.77263003e-02f, -3.69572006e-02f, -3.51580009e-02f,
+      -5.92143014e-02f, -1.80795006e-02f, -5.46086021e-03f, -4.10550982e-02f,
+      -1.83081999e-02f, -2.15411000e-02f, -1.17953997e-02f, 3.33894007e-02f,
+      -5.29635996e-02f, -6.97528012e-03f, -3.15250992e-03f, -3.27355005e-02f,
+      1.29676998e-01f,  1.16080999e-01f,  1.15947001e-01f,  1.21797003e-01f,
+      1.16089001e-01f,  1.44875005e-01f,  1.15617000e-01f,  1.31586999e-01f,
+      1.74735002e-02f,  1.21973999e-01f,  1.31596997e-01f,  2.48907991e-02f,
+      6.18605018e-02f,  1.12855002e-01f,  -6.99798986e-02f, 9.58312973e-02f,
+      1.53593004e-01f,  -8.75087008e-02f, -4.92327996e-02f, -3.32239009e-02f};
+  vector<float> im_info{60, 80, 0.166667f};
+  vector<float> anchors{-38, -16, 53, 31, -120, -120, 135, 135};
+
+  // Doubling everything related to images, to simulate
+  // num_images = 2
+  scores.insert(scores.begin(), scores.begin(), scores.end());
+  bbx.insert(bbx.begin(), bbx.begin(), bbx.end());
+  im_info.insert(im_info.begin(), im_info.begin(), im_info.end());
+
+  ERMatXf rois_gt(18, 5);
+  rois_gt << 0, 0, 0, 79, 59, 0, 0, 5.0005703f, 51.6324f, 42.6950f, 0,
+      24.13628387f, 7.51243401f, 79, 45.0663f, 0, 0, 7.50924301f, 67.4779f,
+      45.0336, 0, 0, 23.09477997f, 50.61448669f, 59, 0, 0, 39.52141571f,
+      51.44710541f, 59, 0, 23.57396317f, 29.98791885f, 79, 59, 0, 0,
+      41.90219116f, 79, 59, 0, 0, 23.30098343f, 78.2413f, 58.7287f, 1, 0, 0, 79,
+      59, 1, 0, 5.0005703f, 51.6324f, 42.6950f, 1, 24.13628387f, 7.51243401f,
+      79, 45.0663f, 1, 0, 7.50924301f, 67.4779f, 45.0336, 1, 0, 23.09477997f,
+      50.61448669f, 59, 1, 0, 39.52141571f, 51.44710541f, 59, 1, 23.57396317f,
+      29.98791885f, 79, 59, 1, 0, 41.90219116f, 79, 59, 1, 0, 23.30098343f,
+      78.2413f, 58.7287f;
+
+  vector<float> rois_probs_gt{2.66913995e-02f,
+                              5.44218998e-03f,
+                              1.20544003e-03f,
+                              1.19207997e-03f,
+                              6.17993006e-04f,
+                              4.72735002e-04f,
+                              6.09605013e-05f,
+                              1.50015003e-05f,
+                              8.91025957e-06f};
+
+  // Doubling everything related to images, to simulate
+  // num_images = 2
+  rois_probs_gt.insert(
+      rois_probs_gt.begin(), rois_probs_gt.begin(), rois_probs_gt.end());
+
+  AddInput<CUDAContext>(
+      vector<int64_t>{img_count, A, H, W}, scores, "scores", &ws);
+  AddInput<CUDAContext>(
+      vector<int64_t>{img_count, 4 * A, H, W}, bbx, "bbox_deltas", &ws);
+  AddInput<CUDAContext>(vector<int64_t>{img_count, 3}, im_info, "im_info", &ws);
+  AddInput<CUDAContext>(vector<int64_t>{A, 4}, anchors, "anchors", &ws);
+
+  def.add_arg()->CopyFrom(MakeArgument("spatial_scale", 1.0f / 16.0f));
+  def.add_arg()->CopyFrom(MakeArgument("pre_nms_topN", 6000));
+  def.add_arg()->CopyFrom(MakeArgument("post_nms_topN", 300));
+  def.add_arg()->CopyFrom(MakeArgument("nms_thresh", 0.7f));
+  def.add_arg()->CopyFrom(MakeArgument("min_size", 16.0f));
+  def.add_arg()->CopyFrom(MakeArgument("correct_transform_coords", true));
+
+  unique_ptr<OperatorBase> op(CreateOperator(def, &ws));
+  EXPECT_NE(nullptr, op.get());
+  EXPECT_TRUE(op->Run());
+
+  // test rois
+  Blob* rois_blob = ws.GetBlob("rois");
+  EXPECT_NE(nullptr, rois_blob);
+  auto& rois_gpu = rois_blob->Get<TensorCUDA>();
+  Tensor rois{CPU};
+  rois.CopyFrom(rois_gpu);
+
+  EXPECT_EQ(rois.dims(), (vector<int64_t>{rois_gt.rows(), rois_gt.cols()}));
+  auto rois_data =
+      Eigen::Map<const ERMatXf>(rois.data<float>(), rois.dim(0), rois.dim(1));
+  EXPECT_NEAR((rois_data.matrix() - rois_gt).cwiseAbs().maxCoeff(), 0, 1e-4);
+
+  // test rois_probs
+  Blob* rois_probs_blob = ws.GetBlob("rois_probs");
+  EXPECT_NE(nullptr, rois_probs_blob);
+  auto& rois_probs_gpu = rois_probs_blob->Get<TensorCUDA>();
+  Tensor rois_probs{CPU};
+  rois_probs.CopyFrom(rois_probs_gpu);
+  EXPECT_EQ(
+      rois_probs.dims(), (vector<int64_t>{int64_t(rois_probs_gt.size())}));
+  auto rois_probs_data =
+      ConstEigenVectorArrayMap<float>(rois_probs.data<float>(), rois.dim(0));
+  EXPECT_NEAR(
+      (rois_probs_data.matrix() - utils::AsEArrXt(rois_probs_gt).matrix())
+          .cwiseAbs()
+          .maxCoeff(),
+      0,
+      1e-4);
+}
+} // namespace caffe2
diff --git a/caffe2/operators/generate_proposals_op_util_nms_gpu.cu b/caffe2/operators/generate_proposals_op_util_nms_gpu.cu
new file mode 100644 (file)
index 0000000..eb6695c
--- /dev/null
@@ -0,0 +1,196 @@
+#include "caffe2/core/context_gpu.h"
+#include "caffe2/operators/generate_proposals_op_util_nms_gpu.h"
+
+namespace caffe2 {
+namespace utils {
+namespace {
+// Helper data structure used locally
+struct
+#ifndef __HIP_PLATFORM_HCC__
+    __align__(16)
+#endif
+        Box {
+  float x1, y1, x2, y2;
+};
+
+#define BOXES_PER_THREAD (8 * sizeof(int))
+#define CHUNK_SIZE 2000
+
+const dim3 CAFFE_CUDA_NUM_THREADS_2D = {
+    static_cast<unsigned int>(CAFFE_CUDA_NUM_THREADS_2D_DIMX),
+    static_cast<unsigned int>(CAFFE_CUDA_NUM_THREADS_2D_DIMY),
+    1u};
+
+__launch_bounds__(
+    CAFFE_CUDA_NUM_THREADS_2D_DIMX* CAFFE_CUDA_NUM_THREADS_2D_DIMY,
+    4) __global__
+    void NMSKernel(
+        const Box* d_desc_sorted_boxes,
+        const int nboxes,
+        const float thresh,
+        const int mask_ld,
+        int* d_delete_mask) {
+  // Storing boxes used by this CUDA block in the shared memory
+  __shared__ Box shared_i_boxes[CAFFE_CUDA_NUM_THREADS_2D_DIMX];
+  // Same thing with areas
+  __shared__ float shared_i_areas[CAFFE_CUDA_NUM_THREADS_2D_DIMX];
+  // The condition of the for loop is common to all threads in the block
+  // This is necessary to be able to call __syncthreads() inside of the loop
+  for (int i_block_offset = blockIdx.x * blockDim.x; i_block_offset < nboxes;
+       i_block_offset += blockDim.x * gridDim.x) {
+    const int i_to_load = i_block_offset + threadIdx.x;
+    if (i_to_load < nboxes) {
+      // One 1D line of threads (threadIdx.y == 0) loads the boxes for the x-dimension
+      if (threadIdx.y == 0) {
+        const Box box = d_desc_sorted_boxes[i_to_load];
+        shared_i_areas[threadIdx.x] =
+            (box.x2 - box.x1 + 1.0f) * (box.y2 - box.y1 + 1.0f);
+        shared_i_boxes[threadIdx.x] = box;
+      }
+    }
+    __syncthreads();
+    const int i = i_block_offset + threadIdx.x;
+    for (int j_thread_offset =
+             BOXES_PER_THREAD * (blockIdx.y * blockDim.y + threadIdx.y);
+         j_thread_offset < nboxes;
+         j_thread_offset += BOXES_PER_THREAD * blockDim.y * gridDim.y) {
+      // Note : We can do everything using multiplication,
+      // and use fp16 - we are comparing against a low precision
+      // threshold
+      int above_thresh = 0;
+      bool valid = false;
+      for (int ib = 0; ib < BOXES_PER_THREAD; ++ib) {
+        // This thread will compare Box i and Box j
+        const int j = j_thread_offset + ib;
+        if (i < j && i < nboxes && j < nboxes) {
+          valid = true;
+          const Box j_box = d_desc_sorted_boxes[j];
+          const Box i_box = shared_i_boxes[threadIdx.x];
+          const float j_area =
+              (j_box.x2 - j_box.x1 + 1.0f) * (j_box.y2 - j_box.y1 + 1.0f);
+          const float i_area = shared_i_areas[threadIdx.x];
+          // The following code will not be valid with empty boxes
+          if (i_area == 0.0f || j_area == 0.0f)
+            continue;
+          const float xx1 = fmaxf(i_box.x1, j_box.x1);
+          const float yy1 = fmaxf(i_box.y1, j_box.y1);
+          const float xx2 = fminf(i_box.x2, j_box.x2);
+          const float yy2 = fminf(i_box.y2, j_box.y2);
+
+          // fdimf computes the positive difference between xx2+1 and xx1
+          const float w = fdimf(xx2 + 1.0f, xx1);
+          const float h = fdimf(yy2 + 1.0f, yy1);
+          const float intersection = w * h;
+
+          // Testing for a/b > t
+          // is equivalent to a > b*t (since b != 0),
+          // avoiding divisions
+          const float a = intersection;
+          const float b = i_area + j_area - intersection;
+          const float bt = b * thresh;
+          // eq. to if ovr > thresh
+          if (a > bt) {
+            // we have score[j] <= score[i]
+            above_thresh |= (1U << ib);
+          }
+        }
+      }
+      if (valid)
+        d_delete_mask[i * mask_ld + j_thread_offset / BOXES_PER_THREAD] =
+            above_thresh;
+    }
+    __syncthreads(); // making sure everyone is done reading smem
+  }
+}
+} // namespace
+
+void nms_gpu_upright(
+    const float* d_desc_sorted_boxes_float_ptr,
+    const int N,
+    const float thresh,
+    int* d_keep_sorted_list,
+    int* h_nkeep,
+    TensorCUDA& dev_delete_mask,
+    TensorCPU& host_delete_mask,
+    CUDAContext* context) {
+  // Making sure we respect the __align__(16) we promised to the compiler
+  auto iptr = reinterpret_cast<std::uintptr_t>(d_desc_sorted_boxes_float_ptr);
+  CAFFE_ENFORCE_EQ(iptr % 16, 0);
+
+  // The next kernel expects square thread blocks (DIMX == DIMY)
+  CAFFE_ENFORCE_EQ(
+      CAFFE_CUDA_NUM_THREADS_2D_DIMX, CAFFE_CUDA_NUM_THREADS_2D_DIMY);
+
+  const int mask_ld = (N + BOXES_PER_THREAD - 1) / BOXES_PER_THREAD;
+  const Box* d_desc_sorted_boxes =
+      reinterpret_cast<const Box*>(d_desc_sorted_boxes_float_ptr);
+  dev_delete_mask.Resize(N * mask_ld);
+  int* d_delete_mask = dev_delete_mask.template mutable_data<int>();
+  NMSKernel<<<
+      CAFFE_GET_BLOCKS_2D(N, mask_ld),
+      CAFFE_CUDA_NUM_THREADS_2D,
+      0,
+      context->cuda_stream()>>>(
+      d_desc_sorted_boxes, N, thresh, mask_ld, d_delete_mask);
+
+  host_delete_mask.Resize(N * mask_ld);
+  int* h_delete_mask = host_delete_mask.template mutable_data<int>();
+
+  // Overlapping CPU computation and D2H memcpy;
+  // both take about the same time
+  cudaEvent_t copy_done;
+  cudaEventCreate(&copy_done);
+  int nto_copy = std::min(CHUNK_SIZE, N);
+  CUDA_CHECK(cudaMemcpyAsync(
+      &h_delete_mask[0],
+      &d_delete_mask[0],
+      nto_copy * mask_ld * sizeof(int),
+      cudaMemcpyDeviceToHost,
+      context->cuda_stream()));
+  CUDA_CHECK(cudaEventRecord(copy_done, context->cuda_stream()));
+  int offset = 0;
+  std::vector<int> h_keep_sorted_list;
+  std::vector<int> rmv(mask_ld, 0);
+  while (offset < N) {
+    const int ncopied = nto_copy;
+    int next_offset = offset + ncopied;
+    nto_copy = std::min(CHUNK_SIZE, N - next_offset);
+    if (nto_copy > 0) {
+      CUDA_CHECK(cudaMemcpyAsync(
+          &h_delete_mask[next_offset * mask_ld],
+          &d_delete_mask[next_offset * mask_ld],
+          nto_copy * mask_ld * sizeof(int),
+          cudaMemcpyDeviceToHost,
+          context->cuda_stream()));
+    }
+    // Waiting for previous copy
+    CUDA_CHECK(cudaEventSynchronize(copy_done));
+    if (nto_copy > 0)
+      cudaEventRecord(copy_done, context->cuda_stream());
+    for (int i = offset; i < next_offset; ++i) {
+      int iblock = i / BOXES_PER_THREAD;
+      int inblock = i % BOXES_PER_THREAD;
+      if (!(rmv[iblock] & (1 << inblock))) {
+        h_keep_sorted_list.push_back(i);
+        int* p = &h_delete_mask[i * mask_ld];
+        for (int ib = 0; ib < mask_ld; ++ib) {
+          rmv[ib] |= p[ib];
+        }
+      }
+    }
+    offset = next_offset;
+  }
+  cudaEventDestroy(copy_done);
+
+  const int nkeep = h_keep_sorted_list.size();
+  cudaMemcpyAsync(
+      d_keep_sorted_list,
+      &h_keep_sorted_list[0],
+      nkeep * sizeof(int),
+      cudaMemcpyHostToDevice,
+      context->cuda_stream());
+
+  *h_nkeep = nkeep;
+}
+} // namespace utils
+} // namespace caffe2
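
The mask written by NMSKernel packs BOXES_PER_THREAD (= 8 * sizeof(int) = 32) decisions into each int: bit (j % 32) of d_delete_mask[i * mask_ld + j / 32] is set when the higher-scoring box i suppresses box j. A hypothetical host-side accessor (not part of this change) makes that layout explicit; the host loop above does the same arithmetic while accumulating the rmv bitmask.

    // Illustrative only: true iff box i (higher score, i < j in the
    // score-descending order) suppresses box j according to the NMS mask.
    inline bool IsSuppressed(
        const int* h_delete_mask, const int mask_ld, const int i, const int j) {
      const int word = h_delete_mask[i * mask_ld + j / BOXES_PER_THREAD];
      return (word >> (j % BOXES_PER_THREAD)) & 1;
    }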
diff --git a/caffe2/operators/generate_proposals_op_util_nms_gpu.h b/caffe2/operators/generate_proposals_op_util_nms_gpu.h
new file mode 100644 (file)
index 0000000..8b639f4
--- /dev/null
@@ -0,0 +1,39 @@
+#ifndef CAFFE2_OPERATORS_UTILS_NMS_GPU_H_
+#define CAFFE2_OPERATORS_UTILS_NMS_GPU_H_
+
+#include <vector>
+
+#include "caffe2/core/context_gpu.h"
+
+namespace caffe2 {
+namespace utils {
+
+// Computes Non-Maximum Suppression on the GPU
+// Reject a bounding box if its region has an intersection-over-union (IoU)
+//    overlap with a higher scoring selected bounding box larger than a
+//    threshold.
+//
+// d_desc_sorted_boxes : pixel coordinates of proposed bounding boxes
+//    size: (N,4), format: [x1; y1; x2; y2]
+//    the boxes are sorted by scores in descending order
+// N : number of boxes
+// d_keep_sorted_list : row indices of the selected proposals, sorted by score
+// h_nkeep  : number of selected proposals
+// dev_delete_mask, host_delete_mask : Tensors that will be used as temp storage
+// by NMS
+//    Those tensors will be resized to the necessary size
+// context : current CUDA context
+CAFFE2_API void nms_gpu_upright(
+    const float* d_desc_sorted_boxes,
+    const int N,
+    const float thresh,
+    int* d_keep_sorted_list,
+    int* h_nkeep,
+    TensorCUDA& dev_delete_mask,
+    TensorCPU& host_delete_mask,
+    CUDAContext* context);
+
+} // namespace utils
+} // namespace caffe2
+
+#endif // CAFFE2_OPERATORS_UTILS_NMS_GPU_H_
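
A minimal usage sketch, following the tests below. It assumes d_sorted_boxes (N boxes as [x1,y1,x2,y2], already on the GPU and sorted by score in descending order) and a CUDAContext named context; the tensor names are illustrative.

    Tensor dev_delete_mask{CUDA};  // temp storage, resized inside nms_gpu_upright
    Tensor host_delete_mask{CPU};  // temp storage, resized inside nms_gpu_upright
    Tensor dev_keep_list{CUDA};
    dev_keep_list.Resize(N);
    int nkeep = 0;
    utils::nms_gpu_upright(
        d_sorted_boxes,
        N,
        /* thresh */ 0.7f,
        dev_keep_list.template mutable_data<int>(),
        &nkeep,
        dev_delete_mask,
        host_delete_mask,
        &context);
    // The first nkeep entries of dev_keep_list are the row indices (into the
    // sorted box array) of the proposals kept by NMS.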
diff --git a/caffe2/operators/generate_proposals_op_util_nms_gpu_test.cc b/caffe2/operators/generate_proposals_op_util_nms_gpu_test.cc
new file mode 100644 (file)
index 0000000..20f8f05
--- /dev/null
@@ -0,0 +1,338 @@
+#include "caffe2/core/context.h"
+#include "caffe2/core/context_gpu.h"
+#include "caffe2/core/flags.h"
+#include "caffe2/operators/generate_proposals_op_util_nms.h"
+#include "caffe2/operators/generate_proposals_op_util_nms_gpu.h"
+#include "caffe2/operators/utility_ops.h"
+#include "caffe2/utils/eigen_utils.h"
+
+#include <gtest/gtest.h>
+
+#include <chrono>
+#include <random>
+
+namespace caffe2 {
+
+TEST(UtilsNMSTest, TestNMSGPU) {
+  if (!HasCudaGPU())
+    return;
+  const int box_dim = 4;
+  std::vector<float> boxes = {10, 10, 50,  60,  11,  12,  48, 60,  8,   9,
+                              40, 50, 100, 100, 150, 140, 99, 110, 155, 139};
+
+  std::vector<float> scores = {0.5f, 0.7f, 0.6f, 0.9f, 0.8f};
+
+  std::vector<int> indices(scores.size());
+  std::iota(indices.begin(), indices.end(), 0);
+  std::sort(indices.begin(), indices.end(), [&scores](int lhs, int rhs) {
+    return scores[lhs] > scores[rhs];
+  });
+  std::vector<float> sorted_boxes(boxes.size());
+  for (int i = 0; i < scores.size(); ++i) {
+    for (int d = 0; d < box_dim; ++d)
+      sorted_boxes[i * box_dim + d] = boxes[indices[i] * box_dim + d];
+  }
+
+  Workspace ws;
+  DeviceOption option;
+  option.set_device_type(PROTO_CUDA);
+  CUDAContext cuda_context(option);
+
+  Tensor dev_sorted_boxes{CUDA};
+  Tensor dev_scores{CUDA};
+  Tensor dev_boxes_valid_flags{CUDA};
+  Tensor dev_list{CUDA};
+  Tensor dev_delete_mask{CUDA};
+  Tensor host_delete_mask{CPU};
+  Tensor dev_list_nitems{CUDA};
+  Tensor host_list{CPU};
+
+  int nboxes = boxes.size() / box_dim;
+  dev_sorted_boxes.Resize(box_dim * nboxes);
+  dev_list.Resize(nboxes);
+  host_list.Resize(nboxes);
+
+  float* d_sorted_boxes = dev_sorted_boxes.template mutable_data<float>();
+  int* d_list = dev_list.template mutable_data<int>();
+  int* h_list = host_list.template mutable_data<int>();
+
+  CUDA_CHECK(cudaMemcpyAsync(
+      d_sorted_boxes,
+      &sorted_boxes[0],
+      sizeof(*d_sorted_boxes) * box_dim * nboxes,
+      cudaMemcpyHostToDevice,
+      cuda_context.cuda_stream()));
+
+  std::vector<float> input_thresh{0.1f, 0.3f, 0.5f, 0.8f, 0.9f};
+  std::vector<std::set<int>> output_gt{
+      {0, 2}, {0, 2}, {0, 2}, {0, 1, 2, 3}, {0, 1, 2, 3, 4}};
+
+  std::vector<int> keep(nboxes);
+  std::set<int> keep_as_set;
+  for (int itest = 0; itest < input_thresh.size(); ++itest) {
+    const float thresh = input_thresh[itest];
+    int list_nitems;
+    utils::nms_gpu_upright(
+        d_sorted_boxes,
+        nboxes,
+        thresh,
+        d_list,
+        &list_nitems,
+        dev_delete_mask,
+        host_delete_mask,
+        &cuda_context);
+
+    cuda_context.FinishDeviceComputation();
+    host_list.CopyFrom(dev_list);
+
+    keep_as_set.clear();
+    for (int i = 0; i < list_nitems; ++i) {
+      keep_as_set.insert(h_list[i]);
+    }
+
+    // Sets are sorted;
+    // sets are equal <=> sets contain the same elements
+    EXPECT_TRUE(output_gt[itest] == keep_as_set);
+  }
+
+  cuda_context.FinishDeviceComputation();
+}
+
+void generateRandomBoxes(float* h_boxes, float* h_scores, const int nboxes) {
+  const float x_y_max = 100;
+  const float w_h_max = 10;
+  const float score_max = 1;
+
+  auto seed = std::chrono::system_clock::now().time_since_epoch().count();
+  std::default_random_engine generator(seed);
+
+  std::uniform_real_distribution<float> coordinate_distribution(
+      0.0, x_y_max - w_h_max);
+  std::uniform_real_distribution<float> length_distribution(0.0, w_h_max);
+  std::uniform_real_distribution<float> score_distribution(0.0, score_max);
+
+  for (int ibox = 0; ibox < nboxes; ++ibox) {
+    float x1, y1, x2, y2;
+    x1 = coordinate_distribution(generator);
+    y1 = coordinate_distribution(generator);
+    x2 = x1 + length_distribution(generator);
+    y2 = y1 + length_distribution(generator);
+    h_boxes[ibox * 4 + 0] = x1;
+    h_boxes[ibox * 4 + 1] = y1;
+    h_boxes[ibox * 4 + 2] = x2;
+    h_boxes[ibox * 4 + 3] = y2;
+    h_scores[ibox] = score_distribution(generator);
+  }
+}
+
+TEST(UtilsNMSTest, TestPerfNMS) {
+  if (!HasCudaGPU())
+    return;
+  const int box_dim = 4;
+  const int nboxes = 6000;
+
+  Workspace ws;
+  DeviceOption option;
+  option.set_device_type(PROTO_CUDA);
+  CUDAContext cuda_context(option);
+
+  Tensor host_boxes{CPU};
+  Tensor host_scores{CPU};
+  host_boxes.Resize(box_dim * nboxes);
+  host_scores.Resize(nboxes);
+
+  float* h_boxes = host_boxes.template mutable_data<float>();
+  float* h_scores = host_scores.template mutable_data<float>();
+
+  // Generating random input
+  generateRandomBoxes(h_boxes, h_scores, nboxes);
+
+  Eigen::ArrayXXf proposals(nboxes, box_dim);
+  Eigen::ArrayXXf scores(nboxes, 1);
+  for (int i = 0; i < nboxes; ++i) {
+    for (int d = 0; d < box_dim; ++d)
+      proposals(i, d) = h_boxes[box_dim * i + d];
+    scores(i, 0) = h_scores[i];
+  }
+
+  const int ntests = 50;
+  const float thresh = 0.7;
+  // Not timing the sort for the CPU version:
+  // in the real-world use case, scores have already been sorted earlier in the
+  // generate proposals workflow
+  std::vector<int> indices(proposals.rows());
+  std::iota(indices.begin(), indices.end(), 0);
+  std::sort(
+      indices.data(),
+      indices.data() + indices.size(),
+      [&scores](int lhs, int rhs) { return scores(lhs) > scores(rhs); });
+
+  // Running ntests runs of CPU NMS
+  auto cpu_start = std::chrono::steady_clock::now();
+  for (int itest = 0; itest < ntests; ++itest) {
+    utils::nms_cpu(proposals, scores, indices, thresh);
+  }
+  auto cpu_stop = std::chrono::steady_clock::now();
+
+  std::vector<float> sorted_boxes(nboxes * box_dim);
+  for (int i = 0; i < scores.size(); ++i) {
+    for (int d = 0; d < box_dim; ++d)
+      sorted_boxes[i * box_dim + d] = h_boxes[indices[i] * box_dim + d];
+  }
+
+  Tensor dev_boxes{CUDA};
+  Tensor dev_delete_mask{CUDA};
+  Tensor host_delete_mask{CPU};
+  Tensor dev_list{CUDA};
+
+  dev_boxes.Resize(box_dim * nboxes);
+  float* d_sorted_boxes = dev_boxes.template mutable_data<float>();
+  dev_list.Resize(nboxes);
+  int* d_list = dev_list.template mutable_data<int>();
+  int list_nitems;
+
+  // Not timing the memcpies, because data is already on the GPU in the real-world
+  // use case (generated by the GPU generate_proposals)
+  CUDA_CHECK(cudaMemcpyAsync(
+      d_sorted_boxes,
+      &sorted_boxes[0],
+      sizeof(*d_sorted_boxes) * box_dim * nboxes,
+      cudaMemcpyHostToDevice,
+      cuda_context.cuda_stream()));
+
+  // Running ntests runs of GPU NMS
+  auto gpu_start = std::chrono::steady_clock::now();
+  for (int itest = 0; itest < ntests; ++itest) {
+    utils::nms_gpu_upright(
+        d_sorted_boxes,
+        nboxes,
+        thresh,
+        d_list,
+        &list_nitems,
+        dev_delete_mask,
+        host_delete_mask,
+        &cuda_context);
+  }
+  // Waiting for everything to be done
+  CUDA_CHECK(cudaStreamSynchronize(cuda_context.cuda_stream()));
+  auto gpu_stop = std::chrono::steady_clock::now();
+
+  double total_cpu_time =
+      std::chrono::duration<double, std::milli>(cpu_stop - cpu_start).count();
+  double total_gpu_time =
+      std::chrono::duration<double, std::milli>(gpu_stop - gpu_start).count();
+  double ratio = total_cpu_time / total_gpu_time;
+
+  double avg_cpu_time = total_cpu_time / ntests;
+  double avg_gpu_time = total_gpu_time / ntests;
+
+  printf(
+      "NMS, nproposals=%i, ntests=%i, Avg GPU time = %fms, Avg CPU time = %fms, GPU speed up = %fX \n",
+      nboxes,
+      ntests,
+      avg_gpu_time,
+      avg_cpu_time,
+      ratio);
+}
+
+TEST(UtilsNMSTest, GPUEqualsCPUCorrectnessTest) {
+  if (!HasCudaGPU())
+    return;
+  Workspace ws;
+  DeviceOption option;
+  option.set_device_type(PROTO_CUDA);
+  CUDAContext cuda_context(option);
+
+  const int box_dim = 4;
+  const std::vector<int> nboxes_vec = {10, 100, 1000, 2000, 6000, 12000};
+  for (int nboxes : nboxes_vec) {
+    Tensor host_boxes{CPU};
+    Tensor host_scores{CPU};
+    host_boxes.Resize(box_dim * nboxes);
+    host_scores.Resize(nboxes);
+
+    float* h_boxes = host_boxes.template mutable_data<float>();
+    float* h_scores = host_scores.template mutable_data<float>();
+
+    // Generating random input
+    generateRandomBoxes(h_boxes, h_scores, nboxes);
+
+    const int ntests = 1;
+    const float thresh = 0.7;
+    // Not timing the sort for the CPU version:
+    // in the real-world use case, scores have already been sorted earlier in the
+    // generate proposals workflow
+    std::vector<int> indices(nboxes);
+    std::iota(indices.begin(), indices.end(), 0);
+    std::sort(indices.begin(), indices.end(), [h_scores](int lhs, int rhs) {
+      return h_scores[lhs] > h_scores[rhs];
+    });
+
+    std::vector<float> sorted_boxes(nboxes * box_dim);
+    std::vector<float> sorted_scores(nboxes);
+    Eigen::ArrayXXf eig_proposals(nboxes, box_dim);
+    Eigen::ArrayXXf eig_scores(nboxes, 1);
+    for (int i = 0; i < nboxes; ++i) {
+      for (int d = 0; d < box_dim; ++d) {
+        sorted_boxes[i * box_dim + d] = h_boxes[indices[i] * box_dim + d];
+        eig_proposals(i, d) = h_boxes[indices[i] * box_dim + d];
+      }
+      sorted_scores[i] = h_scores[indices[i]];
+      eig_scores(i) = h_scores[indices[i]];
+    }
+    std::vector<int> sorted_indices(nboxes);
+    std::iota(sorted_indices.begin(), sorted_indices.end(), 0);
+
+    Tensor dev_boxes{CUDA};
+    Tensor dev_delete_mask{CUDA};
+    Tensor host_delete_mask{CPU};
+    Tensor dev_list{CUDA};
+
+    dev_boxes.Resize(box_dim * nboxes);
+    float* d_sorted_boxes = dev_boxes.template mutable_data<float>();
+    dev_list.Resize(nboxes);
+    int* d_list = dev_list.template mutable_data<int>();
+
+    // Not timing the memcpies, because data is already on the GPU in the
+    // real-world use case (generated by the GPU generate_proposals)
+    CUDA_CHECK(cudaMemcpyAsync(
+        d_sorted_boxes,
+        &sorted_boxes[0],
+        sizeof(*d_sorted_boxes) * box_dim * nboxes,
+        cudaMemcpyHostToDevice,
+        cuda_context.cuda_stream()));
+
+    // Running ntests runs of CPU NMS
+    for (int itest = 0; itest < ntests; ++itest) {
+      std::vector<int> keep =
+          utils::nms_cpu(eig_proposals, eig_scores, sorted_indices, thresh);
+      int list_nitems;
+      utils::nms_gpu_upright(
+          d_sorted_boxes,
+          nboxes,
+          thresh,
+          d_list,
+          &list_nitems,
+          dev_delete_mask,
+          host_delete_mask,
+          &cuda_context);
+      std::vector<int> gpu_keep(list_nitems);
+      CUDA_CHECK(cudaMemcpyAsync(
+          &gpu_keep[0],
+          d_list,
+          list_nitems * sizeof(int),
+          cudaMemcpyDeviceToHost,
+          cuda_context.cuda_stream()));
+      CUDA_CHECK(cudaStreamSynchronize(cuda_context.cuda_stream()));
+
+      ASSERT_EQ(keep.size(), gpu_keep.size());
+      std::sort(keep.begin(), keep.end());
+      std::sort(gpu_keep.begin(), gpu_keep.end());
+
+      for (int i = 0; i < list_nitems; ++i)
+        EXPECT_EQ(keep[i], gpu_keep[i]);
+    }
+  }
+}
+
+} // namespace caffe2
diff --git a/tools/amd_build/pyHIPIFY/cuda_to_hip_mappings.py b/tools/amd_build/pyHIPIFY/cuda_to_hip_mappings.py
index 72db033..cb45333 100644 (file)
@@ -2244,6 +2244,7 @@ CAFFE2_SPECIFIC_MAPPINGS = collections.OrderedDict([
     ("/operator_fallback_gpu" , ("/hip/operator_fallback_gpu", API_CAFFE2)),
     ("/spatial_batch_norm_op_gpu_impl" , ("/hip/spatial_batch_norm_op_gpu_impl", API_CAFFE2)),
     ("/recurrent_network_executor_gpu" , ("/hip/recurrent_network_executor_gpu", API_CAFFE2)),
+    ("/generate_proposals_op_util_nms_gpu" , ("/hip/generate_proposals_op_util_nms_gpu", API_CAFFE2)),
     ("/max_pool_with_index_gpu", ("/hip/max_pool_with_index_gpu", API_CAFFE2)),
     ("/THCCachingAllocator_gpu", ("/hip/THCCachingAllocator_gpu", API_CAFFE2)),
     ("/top_k_heap_selection", ("/hip/top_k_heap_selection", API_CAFFE2)),