postprocess.cu 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383
  1. /******************************************************************************
  2. * Copyright 2020 The Apollo Authors. All Rights Reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. *****************************************************************************/
  16. /*
  17. * Copyright 2018-2019 Autoware Foundation. All rights reserved.
  18. *
  19. * Licensed under the Apache License, Version 2.0 (the "License");
  20. * you may not use this file except in compliance with the License.
  21. * You may obtain a copy of the License at
  22. *
  23. * http://www.apache.org/licenses/LICENSE-2.0
  24. *
  25. * Unless required by applicable law or agreed to in writing, software
  26. * distributed under the License is distributed on an "AS IS" BASIS,
  27. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  28. * See the License for the specific language governing permissions and
  29. * limitations under the License.
  30. */
  31. /**
  32. * @author Kosuke Murakami
  33. * @date 2019/02/26
  34. */
  35. /**
  36. * @author Yan haixu
  37. * Contact: just github.com/hova88
  38. * @date 2021/04/30
  39. */
  40. // headers in CUDA
  41. #include <thrust/sort.h>
  42. // headers in local files
  43. #include "common.h"
  44. #include "postprocess.h"
  45. #include <stdio.h>
  46. // sigmoid_filter_warp
  47. __device__ void box_decode_warp(int head_offset, const float* box_pred,
  48. int tid , int num_anchors_per_head , int counter, float* filtered_box)
  49. {
  50. filtered_box[blockIdx.z * num_anchors_per_head * 7 + counter * 7 + 0] = box_pred[ head_offset + tid * 9 + 0];
  51. filtered_box[blockIdx.z * num_anchors_per_head * 7 + counter * 7 + 1] = box_pred[ head_offset + tid * 9 + 1];
  52. filtered_box[blockIdx.z * num_anchors_per_head * 7 + counter * 7 + 2] = box_pred[ head_offset + tid * 9 + 2];
  53. filtered_box[blockIdx.z * num_anchors_per_head * 7 + counter * 7 + 3] = box_pred[ head_offset + tid * 9 + 3];
  54. filtered_box[blockIdx.z * num_anchors_per_head * 7 + counter * 7 + 4] = box_pred[ head_offset + tid * 9 + 4];
  55. filtered_box[blockIdx.z * num_anchors_per_head * 7 + counter * 7 + 5] = box_pred[ head_offset + tid * 9 + 5];
  56. filtered_box[blockIdx.z * num_anchors_per_head * 7 + counter * 7 + 6] = box_pred[ head_offset + tid * 9 + 6];
  57. }
  58. __global__ void sigmoid_filter_kernel(
  59. float* cls_pred_0,
  60. float* cls_pred_12,
  61. float* cls_pred_34,
  62. float* cls_pred_5,
  63. float* cls_pred_67,
  64. float* cls_pred_89,
  65. const float* box_pred_0,
  66. const float* box_pred_1,
  67. const float* box_pred_2,
  68. const float* box_pred_3,
  69. const float* box_pred_4,
  70. const float* box_pred_5,
  71. const float* box_pred_6,
  72. const float* box_pred_7,
  73. const float* box_pred_8,
  74. const float* box_pred_9,
  75. float* filtered_box,
  76. float* filtered_score,
  77. int* filter_count,
  78. const float score_threshold) {
  79. // cls_pred_34
  80. // 32768*2 , 2
  81. int num_anchors_per_head = gridDim.x * gridDim.y * blockDim.x;
  82. // 16 * 4 * 512 = 32768
  83. extern __shared__ float cls_score[];
  84. cls_score[threadIdx.x + blockDim.x] = -1.0f;
  85. int tid = blockIdx.x * gridDim.y * blockDim.x + blockIdx.y * blockDim.x + threadIdx.x;
  86. if ( blockIdx.z == 0) cls_score[ threadIdx.x ] = 1 / (1 + expf(-cls_pred_0[ tid ]));
  87. if ( blockIdx.z == 1) {
  88. cls_score[ threadIdx.x ] = 1 / (1 + expf(-cls_pred_12[ tid * 2 ]));
  89. cls_score[ threadIdx.x + blockDim.x] = 1 / (1 + expf(-cls_pred_12[ (num_anchors_per_head + tid) * 2]));}
  90. if ( blockIdx.z == 2) {
  91. cls_score[ threadIdx.x ] = 1 / (1 + expf(-cls_pred_12[ tid * 2 + 1]));
  92. cls_score[ threadIdx.x + blockDim.x] = 1 / (1 + expf(-cls_pred_12[ (num_anchors_per_head + tid) * 2 + 1]));}
  93. if ( blockIdx.z == 3) {
  94. cls_score[ threadIdx.x ] = 1 / (1 + expf(-cls_pred_34[ tid * 2 ]));
  95. cls_score[ threadIdx.x + blockDim.x] = 1 / (1 + expf(-cls_pred_34[ (num_anchors_per_head + tid) * 2]));}
  96. if ( blockIdx.z == 4) {
  97. cls_score[ threadIdx.x ] = 1 / (1 + expf(-cls_pred_34[ tid * 2 + 1 ]));
  98. cls_score[ threadIdx.x + blockDim.x] = 1 / (1 + expf(-cls_pred_34[ (num_anchors_per_head + tid) * 2 + 1]));}
  99. if ( blockIdx.z == 5) cls_score[ threadIdx.x ] = 1 / (1 + expf(-cls_pred_5[ tid ]));
  100. if ( blockIdx.z == 6) {
  101. cls_score[ threadIdx.x ] = 1 / (1 + expf(-cls_pred_67[ tid * 2 ]));
  102. cls_score[ threadIdx.x + blockDim.x] = 1 / (1 + expf(-cls_pred_67[ (num_anchors_per_head + tid) * 2]));}
  103. if ( blockIdx.z == 7) {
  104. cls_score[ threadIdx.x ] = 1 / (1 + expf(-cls_pred_67[ tid * 2 + 1 ]));
  105. cls_score[ threadIdx.x + blockDim.x] = 1 / (1 + expf(-cls_pred_67[ (num_anchors_per_head + tid) * 2 + 1]));}
  106. if ( blockIdx.z == 8) {
  107. cls_score[ threadIdx.x ] = 1 / (1 + expf(-cls_pred_89[ tid * 2 ]));
  108. cls_score[ threadIdx.x + blockDim.x] = 1 / (1 + expf(-cls_pred_89[ (num_anchors_per_head + tid) * 2]));}
  109. if ( blockIdx.z == 9) {
  110. cls_score[ threadIdx.x ] = 1 / (1 + expf(-cls_pred_89[ tid * 2 + 1 ]));
  111. cls_score[ threadIdx.x + blockDim.x] = 1 / (1 + expf(-cls_pred_89[ (num_anchors_per_head + tid) * 2 + 1]));}
  112. __syncthreads();
  113. if( cls_score[ threadIdx.x ] > score_threshold)
  114. {
  115. int counter = atomicAdd(&filter_count[blockIdx.z], 1);
  116. if ( blockIdx.z == 0) {
  117. box_decode_warp(0 ,box_pred_0 , tid , num_anchors_per_head , counter , filtered_box);
  118. filtered_score[blockIdx.z * num_anchors_per_head + counter] = cls_score[ threadIdx.x ];
  119. }else
  120. if ( blockIdx.z == 1) {
  121. box_decode_warp(0 ,box_pred_1 , tid , num_anchors_per_head , counter , filtered_box);
  122. filtered_score[blockIdx.z * num_anchors_per_head + counter] = cls_score[ threadIdx.x ];
  123. }else
  124. if ( blockIdx.z == 2) {
  125. box_decode_warp(0 ,box_pred_1 , tid , num_anchors_per_head , counter , filtered_box);
  126. filtered_score[blockIdx.z * num_anchors_per_head + counter] = cls_score[ threadIdx.x ];
  127. }else
  128. if ( blockIdx.z == 3) {
  129. box_decode_warp(0 ,box_pred_3 , tid , num_anchors_per_head , counter , filtered_box);
  130. filtered_score[blockIdx.z * num_anchors_per_head + counter] = cls_score[ threadIdx.x ];
  131. }else
  132. if (blockIdx.z == 4) {
  133. box_decode_warp(0 ,box_pred_3 , tid , num_anchors_per_head , counter , filtered_box);
  134. filtered_score[blockIdx.z * num_anchors_per_head + counter] = cls_score[ threadIdx.x ];
  135. }else
  136. if ( blockIdx.z == 5) {
  137. box_decode_warp(0 ,box_pred_5 , tid , num_anchors_per_head , counter , filtered_box);
  138. filtered_score[blockIdx.z * num_anchors_per_head + counter] = cls_score[ threadIdx.x ];
  139. }else
  140. if ( blockIdx.z == 6) {
  141. box_decode_warp(0 ,box_pred_6 , tid , num_anchors_per_head , counter , filtered_box);
  142. filtered_score[blockIdx.z * num_anchors_per_head + counter] = cls_score[ threadIdx.x ];
  143. }else
  144. if ( blockIdx.z == 7) {
  145. box_decode_warp(0 ,box_pred_6 , tid , num_anchors_per_head , counter , filtered_box);
  146. filtered_score[blockIdx.z * num_anchors_per_head + counter] = cls_score[ threadIdx.x ];
  147. }else
  148. if ( blockIdx.z == 8) {
  149. box_decode_warp(0 ,box_pred_8 , tid , num_anchors_per_head , counter , filtered_box);
  150. filtered_score[blockIdx.z * num_anchors_per_head + counter] = cls_score[ threadIdx.x ];
  151. }else
  152. if ( blockIdx.z == 9) {
  153. box_decode_warp(0 ,box_pred_8 , tid , num_anchors_per_head , counter , filtered_box);
  154. filtered_score[blockIdx.z * num_anchors_per_head + counter] = cls_score[ threadIdx.x ];
  155. }
  156. }
  157. __syncthreads();
  158. if( cls_score[ threadIdx.x + blockDim.x ] > score_threshold) {
  159. int counter = atomicAdd(&filter_count[blockIdx.z], 1);
  160. // printf("counter : %d \n" , counter);
  161. if (blockIdx.z == 1) {
  162. box_decode_warp(0 ,box_pred_2 , tid , num_anchors_per_head , counter , filtered_box);
  163. filtered_score[blockIdx.z * num_anchors_per_head + counter] = cls_score[ threadIdx.x ];
  164. }else
  165. if (blockIdx.z == 2) {
  166. box_decode_warp(0 ,box_pred_2 , tid , num_anchors_per_head , counter , filtered_box);
  167. filtered_score[blockIdx.z * num_anchors_per_head + counter] = cls_score[ threadIdx.x ];
  168. }else
  169. if (blockIdx.z == 3) {
  170. box_decode_warp(0 ,box_pred_4 , tid , num_anchors_per_head , counter , filtered_box);
  171. filtered_score[blockIdx.z * num_anchors_per_head + counter] = cls_score[ threadIdx.x ];
  172. }else
  173. if (blockIdx.z == 4) {
  174. box_decode_warp(0 ,box_pred_4 , tid , num_anchors_per_head , counter , filtered_box);
  175. filtered_score[blockIdx.z * num_anchors_per_head + counter] = cls_score[ threadIdx.x ];
  176. }else
  177. if (blockIdx.z == 6) {
  178. box_decode_warp(0 ,box_pred_7 , tid , num_anchors_per_head , counter , filtered_box);
  179. filtered_score[blockIdx.z * num_anchors_per_head + counter] = cls_score[ threadIdx.x ];
  180. }else
  181. if (blockIdx.z == 7) {
  182. box_decode_warp(0 ,box_pred_7 , tid , num_anchors_per_head , counter , filtered_box);
  183. filtered_score[blockIdx.z * num_anchors_per_head + counter] = cls_score[ threadIdx.x ];
  184. }else
  185. if (blockIdx.z == 8) {
  186. box_decode_warp(0 ,box_pred_9 , tid , num_anchors_per_head , counter , filtered_box);
  187. filtered_score[blockIdx.z * num_anchors_per_head + counter] = cls_score[ threadIdx.x ];
  188. }else
  189. if (blockIdx.z == 9) {
  190. box_decode_warp(0 ,box_pred_9 , tid , num_anchors_per_head , counter , filtered_box);
  191. filtered_score[blockIdx.z * num_anchors_per_head + counter] = cls_score[ threadIdx.x ];
  192. }
  193. }
  194. }
  195. __global__ void sort_boxes_by_indexes_kernel(float* filtered_box, float* filtered_scores, int* indexes, int filter_count,
  196. float* sorted_filtered_boxes, float* sorted_filtered_scores,
  197. const int num_output_box_feature)
  198. {
  199. int tid = threadIdx.x + blockIdx.x * blockDim.x;
  200. if(tid < filter_count) {
  201. int sort_index = indexes[tid];
  202. sorted_filtered_boxes[tid * num_output_box_feature + 0] = filtered_box[sort_index * num_output_box_feature + 0];
  203. sorted_filtered_boxes[tid * num_output_box_feature + 1] = filtered_box[sort_index * num_output_box_feature + 1];
  204. sorted_filtered_boxes[tid * num_output_box_feature + 2] = filtered_box[sort_index * num_output_box_feature + 2];
  205. sorted_filtered_boxes[tid * num_output_box_feature + 3] = filtered_box[sort_index * num_output_box_feature + 3];
  206. sorted_filtered_boxes[tid * num_output_box_feature + 4] = filtered_box[sort_index * num_output_box_feature + 4];
  207. sorted_filtered_boxes[tid * num_output_box_feature + 5] = filtered_box[sort_index * num_output_box_feature + 5];
  208. sorted_filtered_boxes[tid * num_output_box_feature + 6] = filtered_box[sort_index * num_output_box_feature + 6];
  209. // sorted_filtered_dir[tid] = filtered_dir[sort_index];
  210. sorted_filtered_scores[tid] = filtered_scores[sort_index];
  211. }
  212. }
  213. PostprocessCuda::PostprocessCuda(const int num_threads, const float float_min, const float float_max,
  214. const int num_class,const int num_anchor_per_cls,
  215. const std::vector<std::vector<int>> multihead_label_mapping,
  216. const float score_threshold, const float nms_overlap_threshold,
  217. const int nms_pre_maxsize, const int nms_post_maxsize,
  218. const int num_box_corners,
  219. const int num_input_box_feature,
  220. const int num_output_box_feature)
  221. : num_threads_(num_threads),
  222. float_min_(float_min),
  223. float_max_(float_max),
  224. num_class_(num_class),
  225. num_anchor_per_cls_(num_anchor_per_cls),
  226. multihead_label_mapping_(multihead_label_mapping),
  227. score_threshold_(score_threshold),
  228. nms_overlap_threshold_(nms_overlap_threshold),
  229. nms_pre_maxsize_(nms_pre_maxsize),
  230. nms_post_maxsize_(nms_post_maxsize),
  231. num_box_corners_(num_box_corners),
  232. num_input_box_feature_(num_input_box_feature),
  233. num_output_box_feature_(num_output_box_feature) {
  234. nms_cuda_ptr_.reset(
  235. new NmsCuda(num_threads_, num_box_corners_, nms_overlap_threshold_));
  236. }
  237. void PostprocessCuda::DoPostprocessCuda(
  238. float* cls_pred_0,
  239. float* cls_pred_12,
  240. float* cls_pred_34,
  241. float* cls_pred_5,
  242. float* cls_pred_67,
  243. float* cls_pred_89,
  244. const float* box_preds,
  245. float* dev_filtered_box,
  246. float* dev_filtered_score,
  247. int* dev_filter_count,
  248. std::vector<float>& out_detection, std::vector<int>& out_label , std::vector<float>& out_score) {
  249. // 在此之前,先进行rpn_box_output的concat.
  250. // 128x128 的feature map, cls_pred 的shape为(32768,1),(32768,1),(32768,1),(65536,2),(32768,1)
  251. dim3 gridsize(16, 4 , 10); //16 * 4 * 512 = 32768 代表一个head的anchors
  252. sigmoid_filter_kernel<<< gridsize, 512 , 512 * 2 * sizeof(float)>>>(
  253. cls_pred_0,
  254. cls_pred_12,
  255. cls_pred_34,
  256. cls_pred_5,
  257. cls_pred_67,
  258. cls_pred_89,
  259. &box_preds[0 * 32768 * 9],
  260. &box_preds[1 * 32768 * 9],
  261. &box_preds[2 * 32768 * 9],
  262. &box_preds[3 * 32768 * 9],
  263. &box_preds[4 * 32768 * 9],
  264. &box_preds[5 * 32768 * 9],
  265. &box_preds[6 * 32768 * 9],
  266. &box_preds[7 * 32768 * 9],
  267. &box_preds[8 * 32768 * 9],
  268. &box_preds[9 * 32768 * 9],
  269. dev_filtered_box,
  270. dev_filtered_score,
  271. dev_filter_count,
  272. score_threshold_);
  273. cudaDeviceSynchronize();
  274. int host_filter_count[num_class_] = {0};
  275. GPU_CHECK(cudaMemcpy(host_filter_count, dev_filter_count, num_class_ * sizeof(int), cudaMemcpyDeviceToHost));
  276. for (int i = 0; i < num_class_; ++ i) {
  277. if(host_filter_count[i] <= 0) continue;
  278. int* dev_indexes;
  279. float* dev_sorted_filtered_box;
  280. float* dev_sorted_filtered_scores;
  281. GPU_CHECK(cudaMalloc((void**)&dev_indexes, host_filter_count[i] * sizeof(int)));
  282. GPU_CHECK(cudaMalloc((void**)&dev_sorted_filtered_box, host_filter_count[i] * num_output_box_feature_ * sizeof(float)));
  283. GPU_CHECK(cudaMalloc((void**)&dev_sorted_filtered_scores, host_filter_count[i]*sizeof(float)));
  284. // GPU_CHECK(cudaMalloc((void**)&dev_sorted_box_for_nms, NUM_BOX_CORNERS_*host_filter_count[i]*sizeof(float)));
  285. thrust::sequence(thrust::device, dev_indexes, dev_indexes + host_filter_count[i]);
  286. thrust::sort_by_key(thrust::device,
  287. &dev_filtered_score[i * num_anchor_per_cls_],
  288. &dev_filtered_score[i * num_anchor_per_cls_ + host_filter_count[i]],
  289. dev_indexes,
  290. thrust::greater<float>());
  291. const int num_blocks = DIVUP(host_filter_count[i], num_threads_);
  292. sort_boxes_by_indexes_kernel<<<num_blocks, num_threads_>>>(
  293. &dev_filtered_box[i * num_anchor_per_cls_ * num_output_box_feature_],
  294. &dev_filtered_score[i * num_anchor_per_cls_],
  295. dev_indexes,
  296. host_filter_count[i],
  297. dev_sorted_filtered_box,
  298. dev_sorted_filtered_scores,
  299. num_output_box_feature_);
  300. int num_box_for_nms = min(nms_pre_maxsize_, host_filter_count[i]);
  301. long* keep_inds = new long[num_box_for_nms]; // index of kept box
  302. memset(keep_inds, 0, num_box_for_nms * sizeof(int));
  303. int num_out = 0;
  304. nms_cuda_ptr_->DoNmsCuda(num_box_for_nms, dev_sorted_filtered_box, keep_inds, &num_out);
  305. num_out = min(num_out, nms_post_maxsize_);
  306. float* host_filtered_box = new float[host_filter_count[i] * num_output_box_feature_]();
  307. float* host_filtered_scores = new float[host_filter_count[i]]();
  308. cudaMemcpy(host_filtered_box, dev_sorted_filtered_box, host_filter_count[i] * num_output_box_feature_ * sizeof(float), cudaMemcpyDeviceToHost);
  309. cudaMemcpy(host_filtered_scores, dev_sorted_filtered_scores, host_filter_count[i] * sizeof(float), cudaMemcpyDeviceToHost);
  310. for (int j = 0; j < num_out; ++j) {
  311. out_detection.emplace_back(host_filtered_box[keep_inds[j] * num_output_box_feature_ + 0]);
  312. out_detection.emplace_back(host_filtered_box[keep_inds[j] * num_output_box_feature_ + 1]);
  313. out_detection.emplace_back(host_filtered_box[keep_inds[j] * num_output_box_feature_ + 2]);
  314. out_detection.emplace_back(host_filtered_box[keep_inds[j] * num_output_box_feature_ + 3]);
  315. out_detection.emplace_back(host_filtered_box[keep_inds[j] * num_output_box_feature_ + 4]);
  316. out_detection.emplace_back(host_filtered_box[keep_inds[j] * num_output_box_feature_ + 5]);
  317. out_detection.emplace_back(host_filtered_box[keep_inds[j] * num_output_box_feature_ + 6]);
  318. out_score.emplace_back(host_filtered_scores[keep_inds[j]]);
  319. out_label.emplace_back(i);
  320. }
  321. delete[] keep_inds;
  322. delete[] host_filtered_scores;
  323. delete[] host_filtered_box;
  324. GPU_CHECK(cudaFree(dev_indexes));
  325. GPU_CHECK(cudaFree(dev_sorted_filtered_box));
  326. }
  327. }