postprocess_kernel.cu 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. // Copyright 2022 TIER IV, Inc.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  #include "lidar_centerpoint/postprocess/circle_nms_kernel.hpp"
  #include <lidar_centerpoint/postprocess/postprocess_kernel.hpp>
  #include <thrust/copy.h>
  #include <thrust/count.h>
  #include <thrust/execution_policy.h>
  #include <thrust/sort.h>
  #include <thrust/system/cuda/execution_policy.h>
  namespace
  {
  // Edge length of the square 2-D thread block used to launch generateBoxes3D_kernel:
  // THREADS_PER_BLOCK x THREADS_PER_BLOCK = 32 x 32 = 1024 threads per block.
  constexpr std::size_t THREADS_PER_BLOCK{32};
  }  // namespace
  22. namespace centerpoint
  23. {
  // Thrust predicate: true for boxes whose score is strictly greater than the threshold given at
  // construction. The call operator is const so the predicate can also be invoked through a
  // const reference, as generic algorithm code (e.g. Thrust/CUB internals) may do.
  struct is_score_greater
  {
    is_score_greater(float t) : t_(t) {}
    __device__ bool operator()(const Box3D & b) const { return b.score > t_; }

  private:
    float t_{0.0f};  // score threshold (exclusive)
  };
  // Thrust stencil predicate: selects the elements whose keep-mask entry is true.
  // Const and __host__ __device__ so it works through const references and on either side.
  struct is_kept
  {
    __host__ __device__ bool operator()(const bool keep) const { return keep; }
  };
  // Strict-weak-ordering comparator for thrust::sort: orders boxes by descending score.
  // The call operator is const so it can be invoked through a const reference.
  struct score_greater
  {
    __device__ bool operator()(const Box3D & lb, const Box3D & rb) const
    {
      return lb.score > rb.score;
    }
  };
  // Logistic sigmoid 1 / (1 + e^-x): maps a heatmap logit to a score in [0, 1].
  // (The previous form 1 / e^-x equals e^x, which is unbounded and is not a sigmoid.)
  // Marked __host__ __device__ so it can also be evaluated and checked on the host.
  __host__ __device__ inline float sigmoid(float x) { return 1.0f / (1.0f + expf(-x)); }
  // Decodes the raw network head outputs into one Box3D per cell of the down-sampled grid.
  //
  // Expected launch layout (as set up by PostProcessCUDA::generateDetectedBoxes3D_launch):
  //   blockDim = (THREADS_PER_BLOCK, THREADS_PER_BLOCK)
  //   gridDim  = (divup(down_grid_size_y, THREADS_PER_BLOCK),
  //               divup(down_grid_size_x, THREADS_PER_BLOCK))
  // blockIdx.x selects a tile of rows (y) and blockIdx.y a tile of columns (x). Inside the square
  // tile, threadIdx.x walks along x so the lanes of a warp touch consecutive elements of each
  // output plane (coalesced global loads/stores) instead of 32 different rows.
  //
  // Every out_* tensor is planar with shape (N, down_grid_size_y, down_grid_size_x):
  //   heatmap: N = class_size, offset: N = 2, z: N = 1, dim: N = 3, rot: N = 2, vel: N = 2.
  // det_boxes3d must hold down_grid_size_y * down_grid_size_x boxes; every in-range cell is
  // written, whatever its score (score filtering happens afterwards).
  __global__ void generateBoxes3D_kernel(
    const float * out_heatmap, const float * out_offset, const float * out_z, const float * out_dim,
    const float * out_rot, const float * out_vel, const float voxel_size_x, const float voxel_size_y,
    const float range_min_x, const float range_min_y, const std::size_t down_grid_size_x,
    const std::size_t down_grid_size_y, const std::size_t downsample_factor, const int class_size,
    Box3D * det_boxes3d)
  {
    const auto yi = blockIdx.x * THREADS_PER_BLOCK + threadIdx.y;
    const auto xi = blockIdx.y * THREADS_PER_BLOCK + threadIdx.x;
    if (yi >= down_grid_size_y || xi >= down_grid_size_x) {
      return;
    }
    const auto down_grid_size = down_grid_size_y * down_grid_size_x;  // elements per plane
    const auto idx = down_grid_size_x * yi + xi;                        // offset inside a plane

    // Arg-max over the per-class heatmap planes; the winning class becomes the box label.
    int label = -1;
    float max_score = -1.0f;
    for (int ci = 0; ci < class_size; ci++) {
      const float score = sigmoid(out_heatmap[down_grid_size * ci + idx]);
      if (score > max_score) {
        label = ci;
        max_score = score;
      }
    }

    // Cell index plus predicted sub-cell offset, scaled back to metric coordinates.
    const float offset_x = out_offset[down_grid_size * 0 + idx];
    const float offset_y = out_offset[down_grid_size * 1 + idx];
    const float x = voxel_size_x * downsample_factor * (xi + offset_x) + range_min_x;
    const float y = voxel_size_y * downsample_factor * (yi + offset_y) + range_min_y;
    const float z = out_z[idx];
    // Dimensions are regressed in log space (decoded with expf below).
    const float w = out_dim[down_grid_size * 0 + idx];
    const float l = out_dim[down_grid_size * 1 + idx];
    const float h = out_dim[down_grid_size * 2 + idx];
    const float yaw_sin = out_rot[down_grid_size * 0 + idx];
    const float yaw_cos = out_rot[down_grid_size * 1 + idx];
    const float vel_x = out_vel[down_grid_size * 0 + idx];
    const float vel_y = out_vel[down_grid_size * 1 + idx];

    det_boxes3d[idx].label = label;
    det_boxes3d[idx].score = max_score;
    det_boxes3d[idx].x = x;
    det_boxes3d[idx].y = y;
    det_boxes3d[idx].z = z;
    det_boxes3d[idx].length = expf(l);
    det_boxes3d[idx].width = expf(w);
    det_boxes3d[idx].height = expf(h);
    det_boxes3d[idx].yaw = atan2f(yaw_sin, yaw_cos);
    det_boxes3d[idx].vel_x = vel_x;
    det_boxes3d[idx].vel_y = vel_y;
  }
  // Stores the configuration and sizes the device buffer that generateDetectedBoxes3D_launch
  // fills with one decoded box per cell of the down-sampled grid.
  PostProcessCUDA::PostProcessCUDA(const CenterPointConfig & config) : config_(config)
  {
    const auto cell_count = config.down_grid_size_x_ * config.down_grid_size_y_;
    boxes3d_d_.assign(cell_count, Box3D());
  }
  // Decodes the network outputs into 3D boxes and keeps those scoring above
  // config_.score_threshold_. It sorts them by descending score, applies circle NMS with
  // config_.circle_nms_dist_threshold_, and copies the survivors to the host.
  //
  // out_heatmap .. out_vel : device-accessible head outputs (passed straight to
  //                          generateBoxes3D_kernel), planar (N, down_grid_size_y, down_grid_size_x).
  // det_boxes3d            : [out] replaced with the final detections. It is left empty when no
  //                          box passes the score threshold, or when the kernel launch or the
  //                          stream synchronization fails.
  // stream                 : stream on which the kernel, every Thrust algorithm and circleNMS run.
  // Returns cudaSuccess or the first CUDA error observed. Exceptions thrown by the Thrust calls
  // are not caught here.
  cudaError_t PostProcessCUDA::generateDetectedBoxes3D_launch(
    const float * out_heatmap, const float * out_offset, const float * out_z, const float * out_dim,
    const float * out_rot, const float * out_vel, std::vector<Box3D> & det_boxes3d,
    cudaStream_t stream)
  {
    // Never hand boxes from a previous call back to the caller on an early return.
    det_boxes3d.clear();

    // Enqueue every Thrust algorithm on `stream` so it is ordered after the kernel (and
    // consistent with circleNMS, which also takes `stream`). thrust::device would use the default
    // stream, which is not ordered with a stream created with cudaStreamNonBlocking.
    auto exec_policy = thrust::cuda::par.on(stream);

    // One thread per grid cell, in square THREADS_PER_BLOCK x THREADS_PER_BLOCK tiles:
    // grid x covers rows (y), grid y covers columns (x).
    dim3 blocks(
      divup(config_.down_grid_size_y_, THREADS_PER_BLOCK),
      divup(config_.down_grid_size_x_, THREADS_PER_BLOCK));
    dim3 threads(THREADS_PER_BLOCK, THREADS_PER_BLOCK);
    generateBoxes3D_kernel<<<blocks, threads, 0, stream>>>(
      out_heatmap, out_offset, out_z, out_dim, out_rot, out_vel, config_.voxel_size_x_,
      config_.voxel_size_y_, config_.range_min_x_, config_.range_min_y_, config_.down_grid_size_x_,
      config_.down_grid_size_y_, config_.downsample_factor_, config_.class_size_,
      thrust::raw_pointer_cast(boxes3d_d_.data()));
    // Launch-configuration errors are only reported through cudaGetLastError(); stop before the
    // Thrust calls read a buffer the kernel never wrote.
    const cudaError_t launch_status = cudaGetLastError();
    if (launch_status != cudaSuccess) {
      return launch_status;
    }

    // suppress by score
    const auto num_det_boxes3d = thrust::count_if(
      exec_policy, boxes3d_d_.begin(), boxes3d_d_.end(),
      is_score_greater(config_.score_threshold_));
    if (num_det_boxes3d == 0) {
      return cudaGetLastError();
    }
    thrust::device_vector<Box3D> det_boxes3d_d(num_det_boxes3d);
    thrust::copy_if(
      exec_policy, boxes3d_d_.begin(), boxes3d_d_.end(), det_boxes3d_d.begin(),
      is_score_greater(config_.score_threshold_));

    // sort by descending score before NMS
    thrust::sort(exec_policy, det_boxes3d_d.begin(), det_boxes3d_d.end(), score_greater());

    // suppress by NMS
    thrust::device_vector<bool> final_keep_mask_d(num_det_boxes3d);
    const auto num_final_det_boxes3d =
      circleNMS(det_boxes3d_d, config_.circle_nms_dist_threshold_, final_keep_mask_d, stream);
    thrust::device_vector<Box3D> final_det_boxes3d_d(num_final_det_boxes3d);
    thrust::copy_if(
      exec_policy, det_boxes3d_d.begin(), det_boxes3d_d.end(), final_keep_mask_d.begin(),
      final_det_boxes3d_d.begin(), is_kept());

    // Make sure all work queued on `stream` has completed before the device-to-host copy below,
    // which does not run on `stream`, reads final_det_boxes3d_d.
    const cudaError_t sync_status = cudaStreamSynchronize(stream);
    if (sync_status != cudaSuccess) {
      return sync_status;
    }

    // memcpy device to host
    det_boxes3d.resize(num_final_det_boxes3d);
    thrust::copy(final_det_boxes3d_d.begin(), final_det_boxes3d_d.end(), det_boxes3d.begin());
    return cudaGetLastError();
  }
  135. } // namespace centerpoint