From 3fdf567752b484c5542efea02afc64e4f875eeca Mon Sep 17 00:00:00 2001 From: "hbraun@nvidia.com" Date: Thu, 20 Dec 2018 14:24:27 -0800 Subject: [PATCH] Adding CUDA version for C2 operators generate proposals and nms (#13694) Summary: Related to issue #13684 Pull Request resolved: https://github.com/pytorch/pytorch/pull/13694 Reviewed By: wat3rBro Differential Revision: D13017791 Pulled By: newstzpz fbshipit-source-id: 4bdc58e474d8e1f6cd73a02bf51f91542a2b9d0b --- caffe2/core/common_gpu.h | 41 ++ caffe2/operators/generate_proposals_op.cc | 3 - caffe2/operators/generate_proposals_op.cu | 453 +++++++++++++++++++++ caffe2/operators/generate_proposals_op.h | 29 +- caffe2/operators/generate_proposals_op_gpu_test.cc | 239 +++++++++++ .../generate_proposals_op_util_nms_gpu.cu | 196 +++++++++ .../operators/generate_proposals_op_util_nms_gpu.h | 39 ++ .../generate_proposals_op_util_nms_gpu_test.cc | 338 +++++++++++++++ tools/amd_build/pyHIPIFY/cuda_to_hip_mappings.py | 1 + 9 files changed, 1335 insertions(+), 4 deletions(-) create mode 100644 caffe2/operators/generate_proposals_op.cu create mode 100644 caffe2/operators/generate_proposals_op_gpu_test.cc create mode 100644 caffe2/operators/generate_proposals_op_util_nms_gpu.cu create mode 100644 caffe2/operators/generate_proposals_op_util_nms_gpu.h create mode 100644 caffe2/operators/generate_proposals_op_util_nms_gpu_test.cc diff --git a/caffe2/core/common_gpu.h b/caffe2/core/common_gpu.h index db87887..1c2b487 100644 --- a/caffe2/core/common_gpu.h +++ b/caffe2/core/common_gpu.h @@ -286,6 +286,12 @@ CAFFE2_CUDA_API const char* curandGetErrorString(curandStatus_t error); for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ i += blockDim.x * gridDim.x) +#define CUDA_2D_KERNEL_LOOP(i, n, j, m) \ + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ + i += blockDim.x * gridDim.x) \ + for (size_t j = blockIdx.y * blockDim.y + threadIdx.y; j < (m); \ + j += blockDim.y * gridDim.y) + // CUDA_KERNEL_ASSERT is a macro that wraps an assert() call inside cuda // kernels. This is not supported by Apple platforms so we special case it. // See http://docs.nvidia.com/cuda/cuda-c-programming-guide/#assertion @@ -309,13 +315,22 @@ CAFFE2_CUDA_API const char* curandGetErrorString(curandStatus_t error); // The number of cuda threads to use. Since work is assigned to SMs at the // granularity of a block, 128 is chosen to allow utilizing more SMs for // smaller input sizes. +// 1D grid constexpr int CAFFE_CUDA_NUM_THREADS = 128; +// 2D grid +constexpr int CAFFE_CUDA_NUM_THREADS_2D_DIMX = 16; +constexpr int CAFFE_CUDA_NUM_THREADS_2D_DIMY = 16; + // The maximum number of blocks to use in the default kernel call. We set it to // 4096 which would work for compute capability 2.x (where 65536 is the limit). // This number is very carelessly chosen. Ideally, one would like to look at // the hardware at runtime, and pick the number of blocks that makes most // sense for the specific runtime environment. This is a todo item. 
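A minimal sketch of how the `CUDA_2D_KERNEL_LOOP` macro added above is meant to be used (the kernel here is illustrative, not part of this patch):

```cuda
// Visits every (i, j) of an n-by-m row-major matrix with a 2D grid-stride
// loop, so any launch configuration covers all elements.
__global__ void ScaleMatrixKernel(
    float* data, const size_t n, const size_t m, const float alpha) {
  CUDA_2D_KERNEL_LOOP(i, n, j, m) {
    data[i * m + j] *= alpha;
  }
}
```

A natural launch pairs `CAFFE_GET_BLOCKS_2D(n, m)` with a `dim3` block of `CAFFE_CUDA_NUM_THREADS_2D_DIMX` by `CAFFE_CUDA_NUM_THREADS_2D_DIMY`, as the NMS kernel below does.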
+// 1D grid constexpr int CAFFE_MAXIMUM_NUM_BLOCKS = 4096; +// 2D grid +constexpr int CAFFE_MAXIMUM_NUM_BLOCKS_2D_DIMX = 128; +constexpr int CAFFE_MAXIMUM_NUM_BLOCKS_2D_DIMY = 128; constexpr int kCUDAGridDimMaxX = 2147483647; constexpr int kCUDAGridDimMaxY = 65535; @@ -333,6 +348,32 @@ inline int CAFFE_GET_BLOCKS(const int N) { 1); } +/** + * @brief Compute the number of blocks needed to run N threads for a 2D grid + */ +inline dim3 CAFFE_GET_BLOCKS_2D(const int N, const int /* M */) { + dim3 grid; + // Not calling the 1D version for each dim to keep all constants as literals + + grid.x = std::max( + std::min( + (N + CAFFE_CUDA_NUM_THREADS_2D_DIMX - 1) / + CAFFE_CUDA_NUM_THREADS_2D_DIMX, + CAFFE_MAXIMUM_NUM_BLOCKS_2D_DIMX), + // Use at least 1 block, since CUDA does not allow empty block + 1); + + grid.y = std::max( + std::min( + (N + CAFFE_CUDA_NUM_THREADS_2D_DIMY - 1) / + CAFFE_CUDA_NUM_THREADS_2D_DIMY, + CAFFE_MAXIMUM_NUM_BLOCKS_2D_DIMY), + // Use at least 1 block, since CUDA does not allow empty block + 1); + + return grid; +} + class DeviceGuard { public: explicit DeviceGuard(int newDevice) : previous_(CaffeCudaGetDevice()) { diff --git a/caffe2/operators/generate_proposals_op.cc b/caffe2/operators/generate_proposals_op.cc index 3f50661..e391f1e 100644 --- a/caffe2/operators/generate_proposals_op.cc +++ b/caffe2/operators/generate_proposals_op.cc @@ -345,8 +345,6 @@ bool GenerateProposalsOp::RunOnDevice() { return true; } -namespace { - REGISTER_CPU_OPERATOR(GenerateProposals, GenerateProposalsOp); // For backward compatibility REGISTER_CPU_OPERATOR(GenerateProposalsCPP, GenerateProposalsOp); @@ -413,5 +411,4 @@ SHOULD_NOT_DO_GRADIENT(GenerateProposals); // For backward compatibility SHOULD_NOT_DO_GRADIENT(GenerateProposalsCPP); -} // namespace } // namespace caffe2 diff --git a/caffe2/operators/generate_proposals_op.cu b/caffe2/operators/generate_proposals_op.cu new file mode 100644 index 0000000..fe89a66 --- /dev/null +++ b/caffe2/operators/generate_proposals_op.cu @@ -0,0 +1,453 @@ +#include +#include "caffe2/core/context.h" +#include "caffe2/core/context_gpu.h" +#include "caffe2/operators/generate_proposals_op.h" +#include "caffe2/operators/generate_proposals_op_util_boxes.h" // BBOX_XFORM_CLIP_DEFAULT +#include "caffe2/operators/generate_proposals_op_util_nms.h" +#include "caffe2/operators/generate_proposals_op_util_nms_gpu.h" + +namespace caffe2 { +namespace { +__global__ void GeneratePreNMSUprightBoxesKernel( + const int* d_sorted_scores_keys, + const int nboxes_to_generate, + const float* d_bbox_deltas, + const float4* d_anchors, + const int H, + const int W, + const int K, // K = H*W + const int A, + const int KA, // KA = K*A + const float feat_stride, + const float min_size, + const float* d_img_info_vec, + const int num_images, + const float bbox_xform_clip, + const bool correct_transform, + float4* d_out_boxes, + const int prenms_nboxes, // leading dimension of out_boxes + float* d_inout_scores, + char* d_boxes_keep_flags) { + CUDA_2D_KERNEL_LOOP(ibox, nboxes_to_generate, image_index, num_images) { + // box_conv_index : # of the same box, but indexed in + // the scores from the conv layer, of shape (A,H,W) + // the num_images dimension was already removed + // box_conv_index = a*K + h*W + w + const int box_conv_index = d_sorted_scores_keys[image_index * KA + ibox]; + + // We want to decompose box_conv_index in (a,h,w) + // such as box_conv_index = a*K + h*W + w + // (avoiding modulos in the process) + int remaining = box_conv_index; + const int dA = K; // stride of A + 
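+    // Worked example with illustrative values: H=4, W=5 (so K=20, dA=20)
+    // and (a,h,w) = (1,2,3) gives box_conv_index = 1*20 + 2*5 + 3 = 33;
+    // then 33 / 20 = 1 -> a (remainder 13), 13 / 5 = 2 -> h (remainder 3 -> w).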
const int a = remaining / dA; + remaining -= a * dA; + const int dH = W; // stride of H + const int h = remaining / dH; + remaining -= h * dH; + const int w = remaining; // dW = 1 + + // Loading the anchor a + // float is a struct with float x,y,z,w + const float4 anchor = d_anchors[a]; + // x1,y1,x2,y2 :coordinates of anchor a, shifted for position (h,w) + const float shift_w = feat_stride * w; + float x1 = shift_w + anchor.x; + float x2 = shift_w + anchor.z; + const float shift_h = feat_stride * h; + float y1 = shift_h + anchor.y; + float y2 = shift_h + anchor.w; + + // TODO use fast math when possible + + // Deltas for that box + // Deltas of shape (num_images,4*A,K) + // We're going to compute 4 scattered reads + // better than the alternative, ie transposing the complete deltas + // array first + int deltas_idx = image_index * (KA * 4) + a * 4 * K + h * W + w; + const float dx = d_bbox_deltas[deltas_idx]; + // Stride of K between each dimension + deltas_idx += K; + const float dy = d_bbox_deltas[deltas_idx]; + deltas_idx += K; + float dw = d_bbox_deltas[deltas_idx]; + deltas_idx += K; + float dh = d_bbox_deltas[deltas_idx]; + + // Upper bound on dw,dh + dw = fmin(dw, bbox_xform_clip); + dh = fmin(dh, bbox_xform_clip); + + // Applying the deltas + float width = x2 - x1 + 1.0f; + const float ctr_x = x1 + 0.5f * width; + const float pred_ctr_x = ctr_x + width * dx; // TODO fuse madd + const float pred_w = width * expf(dw); + x1 = pred_ctr_x - 0.5f * pred_w; + x2 = pred_ctr_x + 0.5f * pred_w; + + float height = y2 - y1 + 1.0f; + const float ctr_y = y1 + 0.5f * height; + const float pred_ctr_y = ctr_y + height * dy; + const float pred_h = height * expf(dh); + y1 = pred_ctr_y - 0.5f * pred_h; + y2 = pred_ctr_y + 0.5f * pred_h; + + if (correct_transform) { + x2 -= 1.0f; + y2 -= 1.0f; + } + + // Clipping box to image + const float img_height = d_img_info_vec[3 * image_index + 0]; + const float img_width = d_img_info_vec[3 * image_index + 1]; + const float min_size_scaled = + min_size * d_img_info_vec[3 * image_index + 2]; + x1 = fmax(fmin(x1, img_width - 1.0f), 0.0f); + y1 = fmax(fmin(y1, img_height - 1.0f), 0.0f); + x2 = fmax(fmin(x2, img_width - 1.0f), 0.0f); + y2 = fmax(fmin(y2, img_height - 1.0f), 0.0f); + + // Filter boxes + // Removing boxes with one dim < min_size + // (center of box is in image, because of previous step) + width = x2 - x1 + 1.0f; // may have changed + height = y2 - y1 + 1.0f; + bool keep_box = fmin(width, height) >= min_size_scaled; + + // We are not deleting the box right now even if !keep_box + // we want to keep the relative order of the elements stable + // we'll do it in such a way later + // d_boxes_keep_flags size: (num_images,prenms_nboxes) + // d_out_boxes size: (num_images,prenms_nboxes) + const int out_index = image_index * prenms_nboxes + ibox; + d_boxes_keep_flags[out_index] = keep_box; + d_out_boxes[out_index] = {x1, y1, x2, y2}; + + // d_inout_scores size: (num_images,KA) + if (!keep_box) + d_inout_scores[image_index * KA + ibox] = FLT_MIN; // for NMS + } +} + +__global__ void WriteOutput( + const float4* d_image_boxes, + const float* d_image_scores, + const int* d_image_boxes_keep_list, + const int nboxes, + const int image_index, + float* d_image_out_rois, + float* d_image_out_rois_probs) { + CUDA_1D_KERNEL_LOOP(i, nboxes) { + const int ibox = d_image_boxes_keep_list[i]; + const float4 box = d_image_boxes[ibox]; + const float score = d_image_scores[ibox]; + // Scattered memory accesses + // postnms_nboxes is small anyway + d_image_out_rois_probs[i] = 
score; + const int base_idx = 5 * i; + d_image_out_rois[base_idx + 0] = image_index; + d_image_out_rois[base_idx + 1] = box.x; + d_image_out_rois[base_idx + 2] = box.y; + d_image_out_rois[base_idx + 3] = box.z; + d_image_out_rois[base_idx + 4] = box.w; + } +} + +__global__ void InitializeDataKernel( + const int num_images, + const int KA, + int* d_image_offsets, + int* d_boxes_keys_iota) { + CUDA_2D_KERNEL_LOOP(box_idx, KA, img_idx, num_images) { + d_boxes_keys_iota[img_idx * KA + box_idx] = box_idx; + + // One 1D line sets the 1D data + if (box_idx == 0) { + d_image_offsets[img_idx] = KA * img_idx; + // One thread sets the last+1 offset + if (img_idx == 0) + d_image_offsets[num_images] = KA * num_images; + } + } +} + +} // namespace + +template <> +bool GenerateProposalsOp::RunOnDevice() { + const auto& scores = Input(0); + const auto& bbox_deltas = Input(1); + const auto& im_info_tensor = Input(2); + const auto& anchors = Input(3); + auto* out_rois = Output(0); + auto* out_rois_probs = Output(1); + + CAFFE_ENFORCE_EQ(scores.ndim(), 4, scores.ndim()); + CAFFE_ENFORCE(scores.template IsType(), scores.meta().name()); + + const auto num_images = scores.dim(0); + const auto A = scores.dim(1); + const auto H = scores.dim(2); + const auto W = scores.dim(3); + const auto box_dim_conv = anchors.dim(1); + + CAFFE_ENFORCE(box_dim_conv == 4); // only upright boxes in GPU version for now + + constexpr int box_dim = 4; + const int K = H * W; + const int conv_layer_nboxes = K * A; + // Getting data members ready + + // We'll sort the scores + // we want to remember their original indexes, + // ie their indexes in the tensor of shape (num_images,A,K) + // from the conv layer + // each row of d_conv_layer_indexes is at first initialized to 1..A*K + dev_conv_layer_indexes_.Resize(num_images, conv_layer_nboxes); + int* d_conv_layer_indexes = + dev_conv_layer_indexes_.template mutable_data(); + + // d_image_offset[i] = i*K*A for i from 1 to num_images+1 + // Used by the segmented sort to only sort scores within one image + dev_image_offset_.Resize(num_images + 1); + int* d_image_offset = dev_image_offset_.template mutable_data(); + + // The following calls to CUB primitives do nothing + // (because the first arg is nullptr) + // except setting cub_*_temp_storage_bytes + size_t cub_sort_temp_storage_bytes = 0; + float* flt_ptr = nullptr; + int* int_ptr = nullptr; + cub::DeviceSegmentedRadixSort::SortPairsDescending( + nullptr, + cub_sort_temp_storage_bytes, + flt_ptr, + flt_ptr, + int_ptr, + int_ptr, + num_images * conv_layer_nboxes, + num_images, + int_ptr, + int_ptr, + 0, + 8 * sizeof(float), // sort all bits + context_.cuda_stream()); + + // Allocate temporary storage for CUB + dev_cub_sort_buffer_.Resize(cub_sort_temp_storage_bytes); + void* d_cub_sort_temp_storage = + dev_cub_sort_buffer_.template mutable_data(); + + size_t cub_select_temp_storage_bytes = 0; + char* char_ptr = nullptr; + cub::DeviceSelect::Flagged( + nullptr, + cub_select_temp_storage_bytes, + flt_ptr, + char_ptr, + flt_ptr, + int_ptr, + K * A, + context_.cuda_stream()); + + // Allocate temporary storage for CUB + dev_cub_select_buffer_.Resize(cub_select_temp_storage_bytes); + void* d_cub_select_temp_storage = + dev_cub_select_buffer_.template mutable_data(); + + // Initialize : + // - each row of dev_conv_layer_indexes to 1..K*A + // - each d_nboxes to 0 + // - d_image_offset[i] = K*A*i for i 1..num_images+1 + // 2D grid + InitializeDataKernel<<< + (CAFFE_GET_BLOCKS(A * K), num_images), + CAFFE_CUDA_NUM_THREADS, // blockDim.y == 1 + 0, 
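+      // Launch config is <<<grid, block, shared_mem_bytes, stream>>>. Note
+      // that the parenthesized (CAFFE_GET_BLOCKS(A * K), num_images) above
+      // is a comma expression, so the grid is effectively dim3(num_images);
+      // the 2D grid-stride loop in the kernel still covers every
+      // (box, image) pair.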
+      context_.cuda_stream()>>>(
+      num_images, conv_layer_nboxes, d_image_offset, d_conv_layer_indexes);
+
+  // Sorting input scores
+  dev_sorted_conv_layer_indexes_.Resize(num_images, conv_layer_nboxes);
+  dev_sorted_scores_.Resize(num_images, conv_layer_nboxes);
+  const float* d_in_scores = scores.data<float>();
+  int* d_sorted_conv_layer_indexes =
+      dev_sorted_conv_layer_indexes_.template mutable_data<int>();
+  float* d_sorted_scores = dev_sorted_scores_.template mutable_data<float>();
+  cub::DeviceSegmentedRadixSort::SortPairsDescending(
+      d_cub_sort_temp_storage,
+      cub_sort_temp_storage_bytes,
+      d_in_scores,
+      d_sorted_scores,
+      d_conv_layer_indexes,
+      d_sorted_conv_layer_indexes,
+      num_images * conv_layer_nboxes,
+      num_images,
+      d_image_offset,
+      d_image_offset + 1,
+      0,
+      8 * sizeof(float), // sort all bits
+      context_.cuda_stream());
+
+  // Keeping only the topN pre_nms
+  const int nboxes_to_generate = std::min(conv_layer_nboxes, rpn_pre_nms_topN_);
+
+  // Generating the boxes associated to the topN pre_nms scores
+  dev_boxes_.Resize(num_images, box_dim * nboxes_to_generate);
+  dev_boxes_keep_flags_.Resize(num_images, nboxes_to_generate);
+  const float* d_bbox_deltas = bbox_deltas.data<float>();
+  const float* d_anchors = anchors.data<float>();
+  const float* d_im_info_vec = im_info_tensor.data<float>();
+  float4* d_boxes =
+      reinterpret_cast<float4*>(dev_boxes_.template mutable_data<float>());
+  char* d_boxes_keep_flags =
+      dev_boxes_keep_flags_.template mutable_data<char>();
+
+  GeneratePreNMSUprightBoxesKernel<<<
+      (CAFFE_GET_BLOCKS(nboxes_to_generate), num_images),
+      CAFFE_CUDA_NUM_THREADS, // blockDim.y == 1
+      0,
+      context_.cuda_stream()>>>(
+      d_sorted_conv_layer_indexes,
+      nboxes_to_generate,
+      d_bbox_deltas,
+      reinterpret_cast<const float4*>(d_anchors),
+      H,
+      W,
+      K,
+      A,
+      K * A,
+      feat_stride_,
+      rpn_min_size_,
+      d_im_info_vec,
+      num_images,
+      utils::BBOX_XFORM_CLIP_DEFAULT,
+      correct_transform_coords_,
+      d_boxes,
+      nboxes_to_generate,
+      d_sorted_scores,
+      d_boxes_keep_flags);
+  const int nboxes_generated = nboxes_to_generate;
+  dev_image_prenms_boxes_.Resize(box_dim * nboxes_generated);
+  float4* d_image_prenms_boxes = reinterpret_cast<float4*>(
+      dev_image_prenms_boxes_.template mutable_data<float>());
+  dev_image_prenms_scores_.Resize(nboxes_generated);
+  float* d_image_prenms_scores =
+      dev_image_prenms_scores_.template mutable_data<float>();
+  dev_image_boxes_keep_list_.Resize(nboxes_generated);
+  int* d_image_boxes_keep_list =
+      dev_image_boxes_keep_list_.template mutable_data<int>();
+
+  const int max_postnms_nboxes = std::min(nboxes_generated, rpn_post_nms_topN_);
+  dev_postnms_rois_.Resize(5 * num_images * max_postnms_nboxes);
+  dev_postnms_rois_probs_.Resize(num_images * max_postnms_nboxes);
+  float* d_postnms_rois = dev_postnms_rois_.template mutable_data<float>();
+  float* d_postnms_rois_probs =
+      dev_postnms_rois_probs_.template mutable_data<float>();
+
+  dev_prenms_nboxes_.Resize(num_images);
+  host_prenms_nboxes_.Resize(num_images);
+  int* d_prenms_nboxes = dev_prenms_nboxes_.template mutable_data<int>();
+  int* h_prenms_nboxes = host_prenms_nboxes_.template mutable_data<int>();
+
+  int nrois_in_output = 0;
+  for (int image_index = 0; image_index < num_images; ++image_index) {
+    // Sub matrices for current image
+    const float4* d_image_boxes = &d_boxes[image_index * nboxes_generated];
+    const float* d_image_sorted_scores = &d_sorted_scores[image_index * K * A];
+    char* d_image_boxes_keep_flags =
+        &d_boxes_keep_flags[image_index * nboxes_generated];
+
+    float* d_image_postnms_rois = &d_postnms_rois[5 * nrois_in_output];
+    float* d_image_postnms_rois_probs =
&d_postnms_rois_probs[nrois_in_output]; + + // Moving valid boxes (ie the ones with d_boxes_keep_flags[ibox] == true) + // to the output tensors + + cub::DeviceSelect::Flagged( + d_cub_select_temp_storage, + cub_select_temp_storage_bytes, + d_image_boxes, + d_image_boxes_keep_flags, + d_image_prenms_boxes, + d_prenms_nboxes, + nboxes_generated, + context_.cuda_stream()); + + cub::DeviceSelect::Flagged( + d_cub_select_temp_storage, + cub_select_temp_storage_bytes, + d_image_sorted_scores, + d_image_boxes_keep_flags, + d_image_prenms_scores, + d_prenms_nboxes, + nboxes_generated, + context_.cuda_stream()); + + host_prenms_nboxes_.CopyFrom(dev_prenms_nboxes_); + + // We know prenms_boxes <= topN_prenms, because nboxes_generated <= + // topN_prenms Calling NMS on the generated boxes + const int prenms_nboxes = *h_prenms_nboxes; + int nkeep; + utils::nms_gpu_upright( + reinterpret_cast(d_image_prenms_boxes), + prenms_nboxes, + rpn_nms_thresh_, + d_image_boxes_keep_list, + &nkeep, + dev_nms_mask_, + host_nms_mask_, + &context_); + + // All operations done after previous sort were keeping the relative order + // of the elements the elements are still sorted keep topN <=> truncate the + // array + const int postnms_nboxes = std::min(nkeep, rpn_post_nms_topN_); + + // Moving the out boxes to the output tensors, + // adding the image_index dimension on the fly + WriteOutput<<< + CAFFE_GET_BLOCKS(postnms_nboxes), + CAFFE_CUDA_NUM_THREADS, + 0, + context_.cuda_stream()>>>( + d_image_prenms_boxes, + d_image_prenms_scores, + d_image_boxes_keep_list, + postnms_nboxes, + image_index, + d_image_postnms_rois, + d_image_postnms_rois_probs); + + nrois_in_output += postnms_nboxes; + } + + // Using a buffer because we cannot call ShrinkTo + out_rois->Resize(nrois_in_output, 5); + out_rois_probs->Resize(nrois_in_output); + float* d_out_rois = out_rois->template mutable_data(); + float* d_out_rois_probs = out_rois_probs->template mutable_data(); + + CUDA_CHECK(cudaMemcpyAsync( + d_out_rois, + d_postnms_rois, + nrois_in_output * 5 * sizeof(float), + cudaMemcpyDeviceToDevice, + context_.cuda_stream())); + CUDA_CHECK(cudaMemcpyAsync( + d_out_rois_probs, + d_postnms_rois_probs, + nrois_in_output * sizeof(float), + cudaMemcpyDeviceToDevice, + context_.cuda_stream())); + + return true; +} + +REGISTER_CUDA_OPERATOR(GenerateProposals, GenerateProposalsOp); +} // namespace caffe2 diff --git a/caffe2/operators/generate_proposals_op.h b/caffe2/operators/generate_proposals_op.h index fa933e3..6d667f2 100644 --- a/caffe2/operators/generate_proposals_op.h +++ b/caffe2/operators/generate_proposals_op.h @@ -144,8 +144,35 @@ class GenerateProposalsOp final : public Operator { // tolerance for backward compatibility. Set to negative value for // no clipping. 
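The `RunOnDevice` implementation above relies on CUB's two-phase calling convention: each primitive is first invoked with a null temp-storage pointer purely to obtain the required scratch size, then invoked again to do the work. A condensed standalone sketch of that pattern (the wrapper and names are illustrative; error handling omitted):

```cpp
#include <cub/cub.cuh>

// Phase 1 only writes the required scratch size into temp_bytes;
// phase 2 performs the actual flagged selection.
void SelectFlagged(
    const float* d_in, const char* d_flags, float* d_out,
    int* d_num_selected, int num_items, cudaStream_t stream) {
  size_t temp_bytes = 0;
  cub::DeviceSelect::Flagged(
      nullptr, temp_bytes, d_in, d_flags, d_out, d_num_selected,
      num_items, stream);
  void* d_temp = nullptr;
  cudaMalloc(&d_temp, temp_bytes);
  cub::DeviceSelect::Flagged(
      d_temp, temp_bytes, d_in, d_flags, d_out, d_num_selected,
      num_items, stream);
  cudaFree(d_temp);
}
```

The operator avoids the per-call `cudaMalloc` by caching the scratch space in the `dev_cub_*_buffer_` members declared just below.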
float clip_angle_thresh_{1.0}; + + // Scratch space required by the CUDA version + // CUB buffers + Tensor dev_cub_sort_buffer_{Context::GetDeviceType()}; + Tensor dev_cub_select_buffer_{Context::GetDeviceType()}; + Tensor dev_image_offset_{Context::GetDeviceType()}; + Tensor dev_conv_layer_indexes_{Context::GetDeviceType()}; + Tensor dev_sorted_conv_layer_indexes_{Context::GetDeviceType()}; + Tensor dev_sorted_scores_{Context::GetDeviceType()}; + Tensor dev_boxes_{Context::GetDeviceType()}; + Tensor dev_boxes_keep_flags_{Context::GetDeviceType()}; + + // prenms proposals (raw proposals minus empty boxes) + Tensor dev_image_prenms_boxes_{Context::GetDeviceType()}; + Tensor dev_image_prenms_scores_{Context::GetDeviceType()}; + Tensor dev_prenms_nboxes_{Context::GetDeviceType()}; + Tensor host_prenms_nboxes_{CPU}; + + Tensor dev_image_boxes_keep_list_{Context::GetDeviceType()}; + + // Tensors used by NMS + Tensor dev_nms_mask_{Context::GetDeviceType()}; + Tensor host_nms_mask_{CPU}; + + // Buffer for output + Tensor dev_postnms_rois_{Context::GetDeviceType()}; + Tensor dev_postnms_rois_probs_{Context::GetDeviceType()}; }; } // namespace caffe2 -#endif // CAFFE2_OPERATORS_GENERATE_PROPOSALS_OP_H_ +#endif // CAFFE2_OPERATORS_GENERATE_PROPOSALS_OP_H_ \ No newline at end of file diff --git a/caffe2/operators/generate_proposals_op_gpu_test.cc b/caffe2/operators/generate_proposals_op_gpu_test.cc new file mode 100644 index 0000000..36ad5aa --- /dev/null +++ b/caffe2/operators/generate_proposals_op_gpu_test.cc @@ -0,0 +1,239 @@ +#include "caffe2/operators/generate_proposals_op.h" + +#include +#include "caffe2/core/flags.h" +#include "caffe2/core/macros.h" + +#include "caffe2/core/context.h" +#include "caffe2/core/context_gpu.h" + +#ifdef CAFFE2_USE_OPENCV +#include +#endif // CAFFE2_USE_OPENCV + +namespace caffe2 { + +static void AddLinSpacedInput( + const vector& shape, + const float min_val, + const float max_val, + const string& name, + Workspace* ws) { + DeviceOption option; + CPUContext context(option); + Blob* blob = ws->CreateBlob(name); + auto* tensor = BlobGetMutableTensor(blob, CPU); + tensor->Resize(shape); + EigenVectorMap tensor_vec( + tensor->template mutable_data(), tensor->size()); + tensor_vec.setLinSpaced(min_val, max_val); + + return; +} + +template +void AddConstInput( + const vector& shape, + const float value, + const string& name, + Context* context, + Workspace* ws) { + Blob* blob = ws->CreateBlob(name); + auto* tensor = BlobGetMutableTensor(blob, Context::GetDeviceType()); + tensor->Resize(shape); + math::Set( + tensor->size(), value, tensor->template mutable_data(), context); + return; +} + +template +void AddInput( + const vector& shape, + const vector& values, + const string& name, + Workspace* ws); + +template <> +void AddInput( + const vector& shape, + const vector& values, + const string& name, + Workspace* ws) { + Blob* blob = ws->CreateBlob(name); + auto* tensor = BlobGetMutableTensor(blob, CPU); + tensor->Resize(shape); + EigenVectorMap tensor_vec( + tensor->template mutable_data(), tensor->size()); + tensor_vec.array() = utils::AsEArrXt(values); +} + +template <> +void AddInput( + const vector& shape, + const vector& values, + const string& name, + Workspace* ws) { + Tensor tmp(shape, CPU); + EigenVectorMap tmp_vec(tmp.mutable_data(), tmp.size()); + tmp_vec.array() = utils::AsEArrXt(values); + + Blob* blob = ws->CreateBlob(name); + auto* tensor = BlobGetMutableTensor(blob, CUDA); + tensor->CopyFrom(tmp); +} + +TEST(GenerateProposalsTest, TestRealDownSampledGPU) { + 
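+  // Input shapes exercised by this test:
+  //   scores      : (img_count, A, H, W)
+  //   bbox_deltas : (img_count, 4*A, H, W)
+  //   im_info     : (img_count, 3)   rows of [height, width, scale]
+  //   anchors     : (A, 4)           rows of [x1, y1, x2, y2]
+  // Outputs: rois is (n, 5) with rows [batch_idx, x1, y1, x2, y2];
+  // rois_probs is (n,) with the matching objectness scores.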
if (!HasCudaGPU()) + return; + Workspace ws; + OperatorDef def; + def.set_name("test"); + def.set_type("GenerateProposals"); + def.add_input("scores"); + def.add_input("bbox_deltas"); + def.add_input("im_info"); + def.add_input("anchors"); + def.add_output("rois"); + def.add_output("rois_probs"); + def.mutable_device_option()->set_device_type(PROTO_CUDA); + const int img_count = 2; + const int A = 2; + const int H = 4; + const int W = 5; + + vector scores{ + 5.44218998e-03f, 1.19207997e-03f, 1.12379994e-03f, 1.17181998e-03f, + 1.20544003e-03f, 6.17993006e-04f, 1.05261997e-05f, 8.91025957e-06f, + 9.29536981e-09f, 6.09605013e-05f, 4.72735002e-04f, 1.13482002e-10f, + 1.50015003e-05f, 4.45032993e-06f, 3.21612994e-08f, 8.02662980e-04f, + 1.40488002e-04f, 3.12508007e-07f, 3.02616991e-06f, 1.97759000e-08f, + 2.66913995e-02f, 5.26766013e-03f, 5.05053019e-03f, 5.62100019e-03f, + 5.37420018e-03f, 5.26280981e-03f, 2.48894998e-04f, 1.06842002e-04f, + 3.92931997e-06f, 1.79388002e-03f, 4.79440019e-03f, 3.41609990e-07f, + 5.20430971e-04f, 3.34090000e-05f, 2.19159006e-07f, 2.28786003e-03f, + 5.16703985e-05f, 4.04523007e-06f, 1.79227004e-06f, 5.32449000e-08f}; + vector bbx{ + -1.65040009e-02f, -1.84051003e-02f, -1.85930002e-02f, -2.08263006e-02f, + -1.83814000e-02f, -2.89172009e-02f, -3.89706008e-02f, -7.52277970e-02f, + -1.54091999e-01f, -2.55433004e-02f, -1.77490003e-02f, -1.10340998e-01f, + -4.20190990e-02f, -2.71421000e-02f, 6.89801015e-03f, 5.71171008e-02f, + -1.75665006e-01f, 2.30021998e-02f, 3.08554992e-02f, -1.39333997e-02f, + 3.40579003e-01f, 3.91070992e-01f, 3.91624004e-01f, 3.92527014e-01f, + 3.91445011e-01f, 3.79328012e-01f, 4.26631987e-01f, 3.64892989e-01f, + 2.76894987e-01f, 5.13985991e-01f, 3.79999995e-01f, 1.80457994e-01f, + 4.37402993e-01f, 4.18545991e-01f, 2.51549989e-01f, 4.48318988e-01f, + 1.68564007e-01f, 4.65440989e-01f, 4.21891987e-01f, 4.45928007e-01f, + 3.27155995e-03f, 3.71480011e-03f, 3.60032008e-03f, 4.27092984e-03f, + 3.74579988e-03f, 5.95752988e-03f, -3.14473989e-03f, 3.52022005e-03f, + -1.88564006e-02f, 1.65188999e-03f, 1.73791999e-03f, -3.56074013e-02f, + -1.66615995e-04f, 3.14146001e-03f, -1.11830998e-02f, -5.35363983e-03f, + 6.49790000e-03f, -9.27671045e-03f, -2.83346009e-02f, -1.61233004e-02f, + -2.15505004e-01f, -2.19910994e-01f, -2.20872998e-01f, -2.12831005e-01f, + -2.19145000e-01f, -2.27687001e-01f, -3.43973994e-01f, -2.75869995e-01f, + -3.19516987e-01f, -2.50418007e-01f, -2.48537004e-01f, -5.08224010e-01f, + -2.28724003e-01f, -2.82402009e-01f, -3.75815988e-01f, -2.86352992e-01f, + -5.28333001e-02f, -4.43836004e-01f, -4.55134988e-01f, -4.34897989e-01f, + -5.65053988e-03f, -9.25739005e-04f, -1.06790999e-03f, -2.37016007e-03f, + -9.71166010e-04f, -8.90910998e-03f, -1.17592998e-02f, -2.08992008e-02f, + -4.94231991e-02f, 6.63906988e-03f, 3.20469006e-03f, -6.44695014e-02f, + -3.11607006e-03f, 2.02738005e-03f, 1.48096997e-02f, 4.39785011e-02f, + -8.28424022e-02f, 3.62076014e-02f, 2.71668993e-02f, 1.38250999e-02f, + 6.76669031e-02f, 1.03252999e-01f, 1.03255004e-01f, 9.89722982e-02f, + 1.03646003e-01f, 4.79663983e-02f, 1.11014001e-01f, 9.31736007e-02f, + 1.15768999e-01f, 1.04014002e-01f, -8.90677981e-03f, 1.13103002e-01f, + 1.33085996e-01f, 1.25405997e-01f, 1.50051996e-01f, -1.13038003e-01f, + 7.01059997e-02f, 1.79651007e-01f, 1.41055003e-01f, 1.62841007e-01f, + -1.00247003e-02f, -8.17587040e-03f, -8.32176022e-03f, -8.90108012e-03f, + -8.13035015e-03f, -1.77263003e-02f, -3.69572006e-02f, -3.51580009e-02f, + -5.92143014e-02f, -1.80795006e-02f, -5.46086021e-03f, 
-4.10550982e-02f, + -1.83081999e-02f, -2.15411000e-02f, -1.17953997e-02f, 3.33894007e-02f, + -5.29635996e-02f, -6.97528012e-03f, -3.15250992e-03f, -3.27355005e-02f, + 1.29676998e-01f, 1.16080999e-01f, 1.15947001e-01f, 1.21797003e-01f, + 1.16089001e-01f, 1.44875005e-01f, 1.15617000e-01f, 1.31586999e-01f, + 1.74735002e-02f, 1.21973999e-01f, 1.31596997e-01f, 2.48907991e-02f, + 6.18605018e-02f, 1.12855002e-01f, -6.99798986e-02f, 9.58312973e-02f, + 1.53593004e-01f, -8.75087008e-02f, -4.92327996e-02f, -3.32239009e-02f}; + vector im_info{60, 80, 0.166667f}; + vector anchors{-38, -16, 53, 31, -120, -120, 135, 135}; + + // Doubling everything related to images, to simulate + // num_images = 2 + scores.insert(scores.begin(), scores.begin(), scores.end()); + bbx.insert(bbx.begin(), bbx.begin(), bbx.end()); + im_info.insert(im_info.begin(), im_info.begin(), im_info.end()); + + ERMatXf rois_gt(18, 5); + rois_gt << 0, 0, 0, 79, 59, 0, 0, 5.0005703f, 51.6324f, 42.6950f, 0, + 24.13628387f, 7.51243401f, 79, 45.0663f, 0, 0, 7.50924301f, 67.4779f, + 45.0336, 0, 0, 23.09477997f, 50.61448669f, 59, 0, 0, 39.52141571f, + 51.44710541f, 59, 0, 23.57396317f, 29.98791885f, 79, 59, 0, 0, + 41.90219116f, 79, 59, 0, 0, 23.30098343f, 78.2413f, 58.7287f, 1, 0, 0, 79, + 59, 1, 0, 5.0005703f, 51.6324f, 42.6950f, 1, 24.13628387f, 7.51243401f, + 79, 45.0663f, 1, 0, 7.50924301f, 67.4779f, 45.0336, 1, 0, 23.09477997f, + 50.61448669f, 59, 1, 0, 39.52141571f, 51.44710541f, 59, 1, 23.57396317f, + 29.98791885f, 79, 59, 1, 0, 41.90219116f, 79, 59, 1, 0, 23.30098343f, + 78.2413f, 58.7287f; + + vector rois_probs_gt{2.66913995e-02f, + 5.44218998e-03f, + 1.20544003e-03f, + 1.19207997e-03f, + 6.17993006e-04f, + 4.72735002e-04f, + 6.09605013e-05f, + 1.50015003e-05f, + 8.91025957e-06f}; + + // Doubling everything related to images, to simulate + // num_images = 2 + rois_probs_gt.insert( + rois_probs_gt.begin(), rois_probs_gt.begin(), rois_probs_gt.end()); + + AddInput( + vector{img_count, A, H, W}, scores, "scores", &ws); + AddInput( + vector{img_count, 4 * A, H, W}, bbx, "bbox_deltas", &ws); + AddInput(vector{img_count, 3}, im_info, "im_info", &ws); + AddInput(vector{A, 4}, anchors, "anchors", &ws); + + def.add_arg()->CopyFrom(MakeArgument("spatial_scale", 1.0f / 16.0f)); + def.add_arg()->CopyFrom(MakeArgument("pre_nms_topN", 6000)); + def.add_arg()->CopyFrom(MakeArgument("post_nms_topN", 300)); + def.add_arg()->CopyFrom(MakeArgument("nms_thresh", 0.7f)); + def.add_arg()->CopyFrom(MakeArgument("min_size", 16.0f)); + def.add_arg()->CopyFrom(MakeArgument("correct_transform_coords", true)); + + unique_ptr op(CreateOperator(def, &ws)); + EXPECT_NE(nullptr, op.get()); + EXPECT_TRUE(op->Run()); + + // test rois + Blob* rois_blob = ws.GetBlob("rois"); + EXPECT_NE(nullptr, rois_blob); + auto& rois_gpu = rois_blob->Get(); + Tensor rois{CPU}; + rois.CopyFrom(rois_gpu); + + EXPECT_EQ(rois.dims(), (vector{rois_gt.rows(), rois_gt.cols()})); + auto rois_data = + Eigen::Map(rois.data(), rois.dim(0), rois.dim(1)); + EXPECT_NEAR((rois_data.matrix() - rois_gt).cwiseAbs().maxCoeff(), 0, 1e-4); + + // test rois_probs + Blob* rois_probs_blob = ws.GetBlob("rois_probs"); + EXPECT_NE(nullptr, rois_probs_blob); + auto& rois_probs_gpu = rois_probs_blob->Get(); + Tensor rois_probs{CPU}; + rois_probs.CopyFrom(rois_probs_gpu); + EXPECT_EQ( + rois_probs.dims(), (vector{int64_t(rois_probs_gt.size())})); + auto rois_probs_data = + ConstEigenVectorArrayMap(rois_probs.data(), rois.dim(0)); + EXPECT_NEAR( + (rois_probs_data.matrix() - 
utils::AsEArrXt(rois_probs_gt).matrix()) + .cwiseAbs() + .maxCoeff(), + 0, + 1e-4); +} +} // namespace caffe2 diff --git a/caffe2/operators/generate_proposals_op_util_nms_gpu.cu b/caffe2/operators/generate_proposals_op_util_nms_gpu.cu new file mode 100644 index 0000000..eb6695c --- /dev/null +++ b/caffe2/operators/generate_proposals_op_util_nms_gpu.cu @@ -0,0 +1,196 @@ +#include "caffe2/core/context_gpu.h" +#include "caffe2/operators/generate_proposals_op_util_nms_gpu.h" + +namespace caffe2 { +namespace utils { +namespace { +// Helper data structure used locally +struct +#ifndef __HIP_PLATFORM_HCC__ + __align__(16) +#endif + Box { + float x1, y1, x2, y2; +}; + +#define BOXES_PER_THREAD (8 * sizeof(int)) +#define CHUNK_SIZE 2000 + +const dim3 CAFFE_CUDA_NUM_THREADS_2D = { + static_cast(CAFFE_CUDA_NUM_THREADS_2D_DIMX), + static_cast(CAFFE_CUDA_NUM_THREADS_2D_DIMY), + 1u}; + +__launch_bounds__( + CAFFE_CUDA_NUM_THREADS_2D_DIMX* CAFFE_CUDA_NUM_THREADS_2D_DIMY, + 4) __global__ + void NMSKernel( + const Box* d_desc_sorted_boxes, + const int nboxes, + const float thresh, + const int mask_ld, + int* d_delete_mask) { + // Storing boxes used by this CUDA block in the shared memory + __shared__ Box shared_i_boxes[CAFFE_CUDA_NUM_THREADS_2D_DIMX]; + // Same thing with areas + __shared__ float shared_i_areas[CAFFE_CUDA_NUM_THREADS_2D_DIMX]; + // The condition of the for loop is common to all threads in the block + // This is necessary to be able to call __syncthreads() inside of the loop + for (int i_block_offset = blockIdx.x * blockDim.x; i_block_offset < nboxes; + i_block_offset += blockDim.x * gridDim.x) { + const int i_to_load = i_block_offset + threadIdx.x; + if (i_to_load < nboxes) { + // One 1D line load the boxes for x-dimension + if (threadIdx.y == 0) { + const Box box = d_desc_sorted_boxes[i_to_load]; + shared_i_areas[threadIdx.x] = + (box.x2 - box.x1 + 1.0f) * (box.y2 - box.y1 + 1.0f); + shared_i_boxes[threadIdx.x] = box; + } + } + __syncthreads(); + const int i = i_block_offset + threadIdx.x; + for (int j_thread_offset = + BOXES_PER_THREAD * (blockIdx.y * blockDim.y + threadIdx.y); + j_thread_offset < nboxes; + j_thread_offset += BOXES_PER_THREAD * blockDim.y * gridDim.y) { + // Note : We can do everything using multiplication, + // and use fp16 - we are comparing against a low precision + // threshold + int above_thresh = 0; + bool valid = false; + for (int ib = 0; ib < BOXES_PER_THREAD; ++ib) { + // This thread will compare Box i and Box j + const int j = j_thread_offset + ib; + if (i < j && i < nboxes && j < nboxes) { + valid = true; + const Box j_box = d_desc_sorted_boxes[j]; + const Box i_box = shared_i_boxes[threadIdx.x]; + const float j_area = + (j_box.x2 - j_box.x1 + 1.0f) * (j_box.y2 - j_box.y1 + 1.0f); + const float i_area = shared_i_areas[threadIdx.x]; + // The following code will not be valid with empty boxes + if (i_area == 0.0f || j_area == 0.0f) + continue; + const float xx1 = fmaxf(i_box.x1, j_box.x1); + const float yy1 = fmaxf(i_box.y1, j_box.y1); + const float xx2 = fminf(i_box.x2, j_box.x2); + const float yy2 = fminf(i_box.y2, j_box.y2); + + // fdimf computes the positive difference between xx2+1 and xx1 + const float w = fdimf(xx2 + 1.0f, xx1); + const float h = fdimf(yy2 + 1.0f, yy1); + const float intersection = w * h; + + // Testing for a/b > t + // eq with a > b*t (b is !=0) + // avoiding divisions + const float a = intersection; + const float b = i_area + j_area - intersection; + const float bt = b * thresh; + // eq. 
to if ovr > thresh + if (a > bt) { + // we have score[j] <= score[i] + above_thresh |= (1U << ib); + } + } + } + if (valid) + d_delete_mask[i * mask_ld + j_thread_offset / BOXES_PER_THREAD] = + above_thresh; + } + __syncthreads(); // making sure everyone is done reading smem + } +} +} // namespace + +void nms_gpu_upright( + const float* d_desc_sorted_boxes_float_ptr, + const int N, + const float thresh, + int* d_keep_sorted_list, + int* h_nkeep, + TensorCUDA& dev_delete_mask, + TensorCPU& host_delete_mask, + CUDAContext* context) { + // Making sure we respect the __align(16)__ we promised to the compiler + auto iptr = reinterpret_cast(d_desc_sorted_boxes_float_ptr); + CAFFE_ENFORCE_EQ(iptr % 16, 0); + + // The next kernel expects squares + CAFFE_ENFORCE_EQ( + CAFFE_CUDA_NUM_THREADS_2D_DIMX, CAFFE_CUDA_NUM_THREADS_2D_DIMY); + + const int mask_ld = (N + BOXES_PER_THREAD - 1) / BOXES_PER_THREAD; + const Box* d_desc_sorted_boxes = + reinterpret_cast(d_desc_sorted_boxes_float_ptr); + dev_delete_mask.Resize(N * mask_ld); + int* d_delete_mask = dev_delete_mask.template mutable_data(); + NMSKernel<<< + CAFFE_GET_BLOCKS_2D(N, mask_ld), + CAFFE_CUDA_NUM_THREADS_2D, + 0, + context->cuda_stream()>>>( + d_desc_sorted_boxes, N, thresh, mask_ld, d_delete_mask); + + host_delete_mask.Resize(N * mask_ld); + int* h_delete_mask = host_delete_mask.template mutable_data(); + + // Overlapping CPU computes and D2H memcpy + // both take about the same time + cudaEvent_t copy_done; + cudaEventCreate(©_done); + int nto_copy = std::min(CHUNK_SIZE, N); + CUDA_CHECK(cudaMemcpyAsync( + &h_delete_mask[0], + &d_delete_mask[0], + nto_copy * mask_ld * sizeof(int), + cudaMemcpyDeviceToHost, + context->cuda_stream())); + CUDA_CHECK(cudaEventRecord(copy_done, context->cuda_stream())); + int offset = 0; + std::vector h_keep_sorted_list; + std::vector rmv(mask_ld, 0); + while (offset < N) { + const int ncopied = nto_copy; + int next_offset = offset + ncopied; + nto_copy = std::min(CHUNK_SIZE, N - next_offset); + if (nto_copy > 0) { + CUDA_CHECK(cudaMemcpyAsync( + &h_delete_mask[next_offset * mask_ld], + &d_delete_mask[next_offset * mask_ld], + nto_copy * mask_ld * sizeof(int), + cudaMemcpyDeviceToHost, + context->cuda_stream())); + } + // Waiting for previous copy + CUDA_CHECK(cudaEventSynchronize(copy_done)); + if (nto_copy > 0) + cudaEventRecord(copy_done, context->cuda_stream()); + for (int i = offset; i < next_offset; ++i) { + int iblock = i / BOXES_PER_THREAD; + int inblock = i % BOXES_PER_THREAD; + if (!(rmv[iblock] & (1 << inblock))) { + h_keep_sorted_list.push_back(i); + int* p = &h_delete_mask[i * mask_ld]; + for (int ib = 0; ib < mask_ld; ++ib) { + rmv[ib] |= p[ib]; + } + } + } + offset = next_offset; + } + cudaEventDestroy(copy_done); + + const int nkeep = h_keep_sorted_list.size(); + cudaMemcpyAsync( + d_keep_sorted_list, + &h_keep_sorted_list[0], + nkeep * sizeof(int), + cudaMemcpyHostToDevice, + context->cuda_stream()); + + *h_nkeep = nkeep; +} +} // namespace utils +} // namespace caffe2 diff --git a/caffe2/operators/generate_proposals_op_util_nms_gpu.h b/caffe2/operators/generate_proposals_op_util_nms_gpu.h new file mode 100644 index 0000000..8b639f4 --- /dev/null +++ b/caffe2/operators/generate_proposals_op_util_nms_gpu.h @@ -0,0 +1,39 @@ +#ifndef CAFFE2_OPERATORS_UTILS_NMS_GPU_H_ +#define CAFFE2_OPERATORS_UTILS_NMS_GPU_H_ + +#include + +#include "caffe2/core/context_gpu.h" + +namespace caffe2 { +namespace utils { + +// Computes Non-Maximum Suppression on the GPU +// Reject a bounding box if its region has an 
intersection-over-union (IoU)
+// overlap with a higher-scoring selected bounding box that exceeds the
+// given threshold.
+//
+// d_desc_sorted_boxes : pixel coordinates of proposed bounding boxes
+//                       size: (N,4), format: [x1; y1; x2; y2]
+//                       the boxes are sorted by scores in descending order
+// N : number of boxes
+// d_keep_sorted_list : row indices of the selected proposals, sorted by score
+// h_nkeep : number of selected proposals
+// dev_delete_mask, host_delete_mask : Tensors that will be used as temp
+//                                     storage by NMS; they will be resized
+//                                     to the necessary size
+// context : current CUDA context
+CAFFE2_API void nms_gpu_upright(
+    const float* d_desc_sorted_boxes,
+    const int N,
+    const float thresh,
+    int* d_keep_sorted_list,
+    int* h_nkeep,
+    TensorCUDA& dev_delete_mask,
+    TensorCPU& host_delete_mask,
+    CUDAContext* context);
+
+} // namespace utils
+} // namespace caffe2
+
+#endif // CAFFE2_OPERATORS_UTILS_NMS_GPU_H_
diff --git a/caffe2/operators/generate_proposals_op_util_nms_gpu_test.cc b/caffe2/operators/generate_proposals_op_util_nms_gpu_test.cc
new file mode 100644
index 0000000..20f8f05
--- /dev/null
+++ b/caffe2/operators/generate_proposals_op_util_nms_gpu_test.cc
@@ -0,0 +1,338 @@
+#include "caffe2/core/context.h"
+#include "caffe2/core/context_gpu.h"
+#include "caffe2/core/flags.h"
+#include "caffe2/operators/generate_proposals_op_util_nms.h"
+#include "caffe2/operators/generate_proposals_op_util_nms_gpu.h"
+#include "caffe2/operators/utility_ops.h"
+#include "caffe2/utils/eigen_utils.h"
+
+#include <gtest/gtest.h>
+
+#include <chrono>
+#include <random>
+
+namespace caffe2 {
+
+TEST(UtilsNMSTest, TestNMSGPU) {
+  if (!HasCudaGPU())
+    return;
+  const int box_dim = 4;
+  std::vector<float> boxes = {10, 10, 50, 60, 11, 12, 48, 60, 8, 9,
+                              40, 50, 100, 100, 150, 140, 99, 110, 155, 139};
+
+  std::vector<float> scores = {0.5f, 0.7f, 0.6f, 0.9f, 0.8f};
+
+  std::vector<int> indices(scores.size());
+  std::iota(indices.begin(), indices.end(), 0);
+  std::sort(indices.begin(), indices.end(), [&scores](int lhs, int rhs) {
+    return scores[lhs] > scores[rhs];
+  });
+  std::vector<float> sorted_boxes(boxes.size());
+  for (int i = 0; i < scores.size(); ++i) {
+    for (int d = 0; d < box_dim; ++d)
+      sorted_boxes[i * box_dim + d] = boxes[indices[i] * box_dim + d];
+  }
+
+  Workspace ws;
+  DeviceOption option;
+  option.set_device_type(PROTO_CUDA);
+  CUDAContext cuda_context(option);
+
+  Tensor dev_sorted_boxes{CUDA};
+  Tensor dev_scores{CUDA};
+  Tensor dev_boxes_valid_flags{CUDA};
+  Tensor dev_list{CUDA};
+  Tensor dev_delete_mask{CUDA};
+  Tensor host_delete_mask{CPU};
+  Tensor dev_list_nitems{CUDA};
+  Tensor host_list{CPU};
+
+  int nboxes = boxes.size() / box_dim;
+  dev_sorted_boxes.Resize(box_dim * nboxes);
+  dev_list.Resize(nboxes);
+  host_list.Resize(nboxes);
+
+  float* d_sorted_boxes = dev_sorted_boxes.template mutable_data<float>();
+  int* d_list = dev_list.template mutable_data<int>();
+  int* h_list = host_list.template mutable_data<int>();
+
+  CUDA_CHECK(cudaMemcpyAsync(
+      d_sorted_boxes,
+      &sorted_boxes[0],
+      sizeof(*d_sorted_boxes) * box_dim * nboxes,
+      cudaMemcpyHostToDevice,
+      cuda_context.cuda_stream()));
+
+  std::vector<float> input_thresh{0.1f, 0.3f, 0.5f, 0.8f, 0.9f};
+  std::vector<std::set<int>> output_gt{
+      {0, 2}, {0, 2}, {0, 2}, {0, 1, 2, 3}, {0, 1, 2, 3, 4}};
+
+  std::vector<int> keep(nboxes);
+  std::set<int> keep_as_set;
+  for (int itest = 0; itest < input_thresh.size(); ++itest) {
+    const float thresh = input_thresh[itest];
+    int list_nitems;
+    utils::nms_gpu_upright(
+        d_sorted_boxes,
+        nboxes,
+        thresh,
+        d_list,
+        &list_nitems,
+        dev_delete_mask,
+        host_delete_mask,
&cuda_context); + + cuda_context.FinishDeviceComputation(); + host_list.CopyFrom(dev_list); + + keep_as_set.clear(); + for (int i = 0; i < list_nitems; ++i) { + keep_as_set.insert(h_list[i]); + } + + // Sets are sorted + // sets are equal <=> sets contains the same elements + EXPECT_TRUE(output_gt[itest] == keep_as_set); + } + + cuda_context.FinishDeviceComputation(); +} + +void generateRandomBoxes(float* h_boxes, float* h_scores, const int nboxes) { + const float x_y_max = 100; + const float w_h_max = 10; + const float score_max = 1; + + auto seed = std::chrono::system_clock::now().time_since_epoch().count(); + std::default_random_engine generator(seed); + + std::uniform_real_distribution coordinate_distribution( + 0.0, x_y_max - w_h_max); + std::uniform_real_distribution length_distribution(0.0, w_h_max); + std::uniform_real_distribution score_distribution(0.0, score_max); + + for (int ibox = 0; ibox < nboxes; ++ibox) { + float x1, y1, x2, y2; + x1 = coordinate_distribution(generator); + y1 = coordinate_distribution(generator); + x2 = x1 + length_distribution(generator); + y2 = y1 + length_distribution(generator); + h_boxes[ibox * 4 + 0] = x1; + h_boxes[ibox * 4 + 1] = y1; + h_boxes[ibox * 4 + 2] = x2; + h_boxes[ibox * 4 + 3] = y2; + h_scores[ibox] = score_distribution(generator); + } +} + +TEST(UtilsNMSTest, TestPerfNMS) { + if (!HasCudaGPU()) + return; + const int box_dim = 4; + const int nboxes = 6000; + + Workspace ws; + DeviceOption option; + option.set_device_type(PROTO_CUDA); + CUDAContext cuda_context(option); + + Tensor host_boxes{CPU}; + Tensor host_scores{CPU}; + host_boxes.Resize(box_dim * nboxes); + host_scores.Resize(nboxes); + + float* h_boxes = host_boxes.template mutable_data(); + float* h_scores = host_scores.template mutable_data(); + + // Generating random input + generateRandomBoxes(h_boxes, h_scores, nboxes); + + Eigen::ArrayXXf proposals(nboxes, box_dim); + Eigen::ArrayXXf scores(nboxes, 1); + for (int i = 0; i < nboxes; ++i) { + for (int d = 0; d < box_dim; ++d) + proposals(i, d) = h_boxes[box_dim * i + d]; + scores(i, 0) = h_scores[i]; + } + + const int ntests = 50; + const float thresh = 0.7; + // Not timing the sort for the CPU + // in the real-world use case scores already have been sorted earlier in the + // generate proposals workflow + std::vector indices(proposals.rows()); + std::iota(indices.begin(), indices.end(), 0); + std::sort( + indices.data(), + indices.data() + indices.size(), + [&scores](int lhs, int rhs) { return scores(lhs) > scores(rhs); }); + + // Running ntests runs of CPU NMS + auto cpu_start = std::chrono::steady_clock::now(); + for (int itest = 0; itest < ntests; ++itest) { + utils::nms_cpu(proposals, scores, indices, thresh); + } + auto cpu_stop = std::chrono::steady_clock::now(); + + std::vector sorted_boxes(nboxes * box_dim); + for (int i = 0; i < scores.size(); ++i) { + for (int d = 0; d < box_dim; ++d) + sorted_boxes[i * box_dim + d] = h_boxes[indices[i] * box_dim + d]; + } + + Tensor dev_boxes{CUDA}; + Tensor dev_delete_mask{CUDA}; + Tensor host_delete_mask{CPU}; + Tensor dev_list{CUDA}; + + dev_boxes.Resize(box_dim * nboxes); + float* d_sorted_boxes = dev_boxes.template mutable_data(); + dev_list.Resize(nboxes); + int* d_list = dev_list.template mutable_data(); + int list_nitems; + + // No timing the memcpies because data is already on the GPU in the real-world + // use case (generated by the GPU generate_proposals) + CUDA_CHECK(cudaMemcpyAsync( + d_sorted_boxes, + &sorted_boxes[0], + sizeof(*d_sorted_boxes) * box_dim * nboxes, + 
cudaMemcpyHostToDevice, + cuda_context.cuda_stream())); + + // Running ntests runs of GPU NMS + auto gpu_start = std::chrono::steady_clock::now(); + for (int itest = 0; itest < ntests; ++itest) { + utils::nms_gpu_upright( + d_sorted_boxes, + nboxes, + thresh, + d_list, + &list_nitems, + dev_delete_mask, + host_delete_mask, + &cuda_context); + } + // Waiting for everything to be done + CUDA_CHECK(cudaStreamSynchronize(cuda_context.cuda_stream())); + auto gpu_stop = std::chrono::steady_clock::now(); + + double total_cpu_time = + std::chrono::duration(cpu_stop - cpu_start).count(); + double total_gpu_time = + std::chrono::duration(gpu_stop - gpu_start).count(); + double ratio = total_cpu_time / total_gpu_time; + + double avg_cpu_time = total_cpu_time / ntests; + double avg_gpu_time = total_gpu_time / ntests; + + printf( + "NMS, nproposals=%i, ntests=%i, Avg GPU time = %fms, Avg CPU time = %fms, GPU speed up = %fX \n", + nboxes, + ntests, + avg_gpu_time, + avg_cpu_time, + ratio); +} + +TEST(UtilsNMSTest, GPUEqualsCPUCorrectnessTest) { + if (!HasCudaGPU()) + return; + Workspace ws; + DeviceOption option; + option.set_device_type(PROTO_CUDA); + CUDAContext cuda_context(option); + + const int box_dim = 4; + const std::vector nboxes_vec = {10, 100, 1000, 2000, 6000, 12000}; + for (int nboxes : nboxes_vec) { + Tensor host_boxes{CPU}; + Tensor host_scores{CPU}; + host_boxes.Resize(box_dim * nboxes); + host_scores.Resize(nboxes); + + float* h_boxes = host_boxes.template mutable_data(); + float* h_scores = host_scores.template mutable_data(); + + // Generating random input + generateRandomBoxes(h_boxes, h_scores, nboxes); + + const int ntests = 1; + const float thresh = 0.7; + // Not timing the sort for the CPU + // in the real-world use case scores already have been sorted earlier in the + // generate proposals workflow + std::vector indices(nboxes); + std::iota(indices.begin(), indices.end(), 0); + std::sort(indices.begin(), indices.end(), [h_scores](int lhs, int rhs) { + return h_scores[lhs] > h_scores[rhs]; + }); + + std::vector sorted_boxes(nboxes * box_dim); + std::vector sorted_scores(nboxes); + Eigen::ArrayXXf eig_proposals(nboxes, box_dim); + Eigen::ArrayXXf eig_scores(nboxes, 1); + for (int i = 0; i < nboxes; ++i) { + for (int d = 0; d < box_dim; ++d) { + sorted_boxes[i * box_dim + d] = h_boxes[indices[i] * box_dim + d]; + eig_proposals(i, d) = h_boxes[indices[i] * box_dim + d]; + } + sorted_scores[i] = h_scores[indices[i]]; + eig_scores(i) = h_scores[indices[i]]; + } + std::vector sorted_indices(nboxes); + std::iota(sorted_indices.begin(), sorted_indices.end(), 0); + + Tensor dev_boxes{CUDA}; + Tensor dev_delete_mask{CUDA}; + Tensor host_delete_mask{CPU}; + Tensor dev_list{CUDA}; + + dev_boxes.Resize(box_dim * nboxes); + float* d_sorted_boxes = dev_boxes.template mutable_data(); + dev_list.Resize(nboxes); + int* d_list = dev_list.template mutable_data(); + + // No timing the memcpies because data is already on the GPU in the + // real-world use case (generated by the GPU generate_proposals) + CUDA_CHECK(cudaMemcpyAsync( + d_sorted_boxes, + &sorted_boxes[0], + sizeof(*d_sorted_boxes) * box_dim * nboxes, + cudaMemcpyHostToDevice, + cuda_context.cuda_stream())); + + // Running ntests runs of CPU NMS + for (int itest = 0; itest < ntests; ++itest) { + std::vector keep = + utils::nms_cpu(eig_proposals, eig_scores, sorted_indices, thresh); + int list_nitems; + utils::nms_gpu_upright( + d_sorted_boxes, + nboxes, + thresh, + d_list, + &list_nitems, + dev_delete_mask, + host_delete_mask, + 
&cuda_context);
+      std::vector<int> gpu_keep(list_nitems);
+      CUDA_CHECK(cudaMemcpyAsync(
+          &gpu_keep[0],
+          d_list,
+          list_nitems * sizeof(int),
+          cudaMemcpyDeviceToHost,
+          cuda_context.cuda_stream()));
+      CUDA_CHECK(cudaStreamSynchronize(cuda_context.cuda_stream()));
+
+      ASSERT_EQ(keep.size(), gpu_keep.size());
+      std::sort(keep.begin(), keep.end());
+      std::sort(gpu_keep.begin(), gpu_keep.end());
+
+      for (int i = 0; i < list_nitems; ++i)
+        EXPECT_EQ(keep[i], gpu_keep[i]);
+    }
+  }
+}
+
+} // namespace caffe2
diff --git a/tools/amd_build/pyHIPIFY/cuda_to_hip_mappings.py b/tools/amd_build/pyHIPIFY/cuda_to_hip_mappings.py
index 72db033..cb45333 100644
--- a/tools/amd_build/pyHIPIFY/cuda_to_hip_mappings.py
+++ b/tools/amd_build/pyHIPIFY/cuda_to_hip_mappings.py
@@ -2244,6 +2244,7 @@ CAFFE2_SPECIFIC_MAPPINGS = collections.OrderedDict([
 ("/operator_fallback_gpu" , ("/hip/operator_fallback_gpu", API_CAFFE2)),
 ("/spatial_batch_norm_op_gpu_impl" , ("/hip/spatial_batch_norm_op_gpu_impl", API_CAFFE2)),
 ("/recurrent_network_executor_gpu" , ("/hip/recurrent_network_executor_gpu", API_CAFFE2)),
+("/generate_proposals_op_util_nms_gpu" , ("/hip/generate_proposals_op_util_nms_gpu", API_CAFFE2)),
 ("/max_pool_with_index_gpu", ("/hip/max_pool_with_index_gpu", API_CAFFE2)),
 ("/THCCachingAllocator_gpu", ("/hip/THCCachingAllocator_gpu", API_CAFFE2)),
 ("/top_k_heap_selection", ("/hip/top_k_heap_selection", API_CAFFE2)),
-- 
2.7.4
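For readers wiring the new utility into other call sites, a condensed sketch of the call sequence the tests above exercise (the wrapper is illustrative; boxes must already be sorted by descending score and resident in GPU memory):

```cpp
#include "caffe2/core/context_gpu.h"
#include "caffe2/operators/generate_proposals_op_util_nms_gpu.h"

namespace caffe2 {
// d_boxes: device pointer to N boxes as [x1, y1, x2, y2] rows, sorted by
// descending score. Indices of the kept boxes are written to d_keep
// (device memory); the number kept is returned.
int RunGpuNMS(
    const float* d_boxes, int N, float iou_thresh,
    int* d_keep, CUDAContext* context) {
  Tensor dev_delete_mask{CUDA}; // scratch; resized inside nms_gpu_upright
  Tensor host_delete_mask{CPU}; // scratch; resized inside nms_gpu_upright
  int nkeep = 0;
  utils::nms_gpu_upright(
      d_boxes, N, iou_thresh, d_keep, &nkeep,
      dev_delete_mask, host_delete_mask, context);
  return nkeep;
}
} // namespace caffe2
```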