circle_nms_kernel.cu

// Copyright 2022 TIER IV, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Modified from
// https://github.com/open-mmlab/OpenPCDet/blob/master/pcdet/ops/iou3d_nms/src/iou3d_nms_kernel.cu

/*
3D IoU Calculation and Rotated NMS(modified from 2D NMS written by others)
Written by Shaoshuai Shi
All Rights Reserved 2019-2020.
*/

#include "lidar_centerpoint/postprocess/circle_nms_kernel.hpp"

#include <lidar_centerpoint/cuda_utils.hpp>
#include <lidar_centerpoint/utils.hpp>

#include <thrust/host_vector.h>

namespace
{
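// Boxes handled per tile. A note on the choice: this value must not exceed 64 so that one
// std::uint64_t word in the suppression mask can hold one bit per box of a column tile.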
const std::size_t THREADS_PER_BLOCK_NMS = 16;
}  // namespace

namespace centerpoint
{

// Squared 2D distance between two box centers (avoids a sqrt in the inner loop).
__device__ inline float dist2dPow(const Box3D * a, const Box3D * b)
{
  return powf(a->x - b->x, 2) + powf(a->y - b->y, 2);
}

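// Tile-based distance test: each (row, col) block pair compares a tile of THREADS_PER_BLOCK_NMS
// "row" boxes against a tile of "column" boxes cached in shared memory. Every thread builds a
// 64-bit mask whose bit i is set when the squared 2D center distance to column box i is below
// the threshold. Only the upper triangle (row_start <= col_start) is computed; the symmetric
// suppression decision is resolved later on the host.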
__global__ void circleNMS_Kernel(
  const Box3D * boxes, const std::size_t num_boxes3d, const std::size_t col_blocks,
  const float dist2d_pow_threshold, std::uint64_t * mask)
{
  // params: boxes (N,)
  // params: mask (N, divup(N/THREADS_PER_BLOCK_NMS))
  const auto row_start = blockIdx.y;
  const auto col_start = blockIdx.x;

  if (row_start > col_start) return;

  const std::size_t row_size =
    fminf(num_boxes3d - row_start * THREADS_PER_BLOCK_NMS, THREADS_PER_BLOCK_NMS);
  const std::size_t col_size =
    fminf(num_boxes3d - col_start * THREADS_PER_BLOCK_NMS, THREADS_PER_BLOCK_NMS);

  __shared__ Box3D block_boxes[THREADS_PER_BLOCK_NMS];

  if (threadIdx.x < col_size) {
    block_boxes[threadIdx.x] = boxes[THREADS_PER_BLOCK_NMS * col_start + threadIdx.x];
  }
  __syncthreads();

  if (threadIdx.x < row_size) {
    const std::size_t cur_box_idx = THREADS_PER_BLOCK_NMS * row_start + threadIdx.x;
    const Box3D * cur_box = boxes + cur_box_idx;

    std::uint64_t t = 0;
    std::size_t start = 0;
    if (row_start == col_start) {
      // On diagonal tiles, skip self-comparison and already-covered pairs.
      start = threadIdx.x + 1;
    }
    for (std::size_t i = start; i < col_size; i++) {
      if (dist2dPow(cur_box, block_boxes + i) < dist2d_pow_threshold) {
        t |= 1ULL << i;
      }
    }
    mask[cur_box_idx * col_blocks + col_start] = t;
  }
}

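// Kernel launch: a col_blocks x col_blocks grid, one thread per box of a tile. The distance
// threshold is squared once here so the kernel only compares squared distances.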
cudaError_t circleNMS_launch(
  const thrust::device_vector<Box3D> & boxes3d, const std::size_t num_boxes3d,
  std::size_t col_blocks, const float distance_threshold,
  thrust::device_vector<std::uint64_t> & mask, cudaStream_t stream)
{
  const float dist2d_pow_thres = powf(distance_threshold, 2);

  dim3 blocks(col_blocks, col_blocks);
  dim3 threads(THREADS_PER_BLOCK_NMS);
  circleNMS_Kernel<<<blocks, threads, 0, stream>>>(
    thrust::raw_pointer_cast(boxes3d.data()), num_boxes3d, col_blocks, dist2d_pow_thres,
    thrust::raw_pointer_cast(mask.data()));

  return cudaGetLastError();
}

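// Greedy selection on the host: boxes are visited in order (typically pre-sorted by descending
// confidence, as in standard greedy NMS); a box is kept when no previously kept box has marked
// it, and the kept box's suppression bits are OR-ed into remv_h to discard its close neighbours.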
std::size_t circleNMS(
  thrust::device_vector<Box3D> & boxes3d, const float distance_threshold,
  thrust::device_vector<bool> & keep_mask, cudaStream_t stream)
{
  const auto num_boxes3d = boxes3d.size();
  const auto col_blocks = divup(num_boxes3d, THREADS_PER_BLOCK_NMS);
  thrust::device_vector<std::uint64_t> mask_d(num_boxes3d * col_blocks);

  CHECK_CUDA_ERROR(
    circleNMS_launch(boxes3d, num_boxes3d, col_blocks, distance_threshold, mask_d, stream));

  // memcpy device to host
  thrust::host_vector<std::uint64_t> mask_h(mask_d.size());
  thrust::copy(mask_d.begin(), mask_d.end(), mask_h.begin());
  CHECK_CUDA_ERROR(cudaStreamSynchronize(stream));

  // generate keep_mask
  std::vector<std::uint64_t> remv_h(col_blocks);
  thrust::host_vector<bool> keep_mask_h(keep_mask.size());

  std::size_t num_to_keep = 0;
  for (std::size_t i = 0; i < num_boxes3d; i++) {
    auto nblock = i / THREADS_PER_BLOCK_NMS;
    auto inblock = i % THREADS_PER_BLOCK_NMS;

    if (!(remv_h[nblock] & (1ULL << inblock))) {
      keep_mask_h[i] = true;
      num_to_keep++;
      std::uint64_t * p = &mask_h[0] + i * col_blocks;
      for (std::size_t j = nblock; j < col_blocks; j++) {
        remv_h[j] |= p[j];
      }
    } else {
      keep_mask_h[i] = false;
    }
  }

  // memcpy host to device
  keep_mask = keep_mask_h;

  return num_to_keep;
}

}  // namespace centerpoint