/******************************************************************************
 * Copyright 2020 The Apollo Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *****************************************************************************/
/*
 * Copyright 2018-2019 Autoware Foundation. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
/**
 * @author Kosuke Murakami
 * @date 2019/02/26
 */
/**
 * @author Yan haixu
 * Contact: github.com/hova88
 * @date 2021/04/30
 */
  40. #pragma once
  41. #include <memory>
  42. #include <vector>
  43. #include "nms.h"
  44. class PostprocessCuda {
  45. private:
  46. // initializer list
  47. const int num_threads_;
  48. const float float_min_;
  49. const float float_max_;
  50. const int num_class_;
  51. const int num_anchor_per_cls_;
  52. const float score_threshold_;
  53. const float nms_overlap_threshold_;
  54. const int nms_pre_maxsize_;
  55. const int nms_post_maxsize_;
  56. const int num_box_corners_;
  57. const int num_input_box_feature_;
  58. const int num_output_box_feature_;
  59. const std::vector<std::vector<int>> multihead_label_mapping_;
  60. // end initializer list
  61. std::unique_ptr<NmsCuda> nms_cuda_ptr_;
  62. public:
  63. /**
  64. * @brief Constructor
  65. * @param[in] num_threads Number of threads when launching cuda kernel
  66. * @param[in] float_min The lowest float value
  67. * @param[in] float_max The maximum float value
  68. * @param[in] num_class Number of classes
  69. * @param[in] num_anchor_per_cls Number anchor per category
  70. * @param[in] multihead_label_mapping
  71. * @param[in] score_threshold Score threshold for filtering output
  72. * @param[in] nms_overlap_threshold IOU threshold for NMS
  73. * @param[in] nms_pre_maxsize Maximum number of boxes into NMS
  74. * @param[in] nms_post_maxsize Maximum number of boxes after NMS
  75. * @param[in] num_box_corners Number of box's corner
  76. * @param[in] num_output_box_feature Number of output box's feature
  77. * @details Captital variables never change after the compile, non-capital
  78. * variables could be changed through rosparam
  79. */
  80. PostprocessCuda(const int num_threads,
  81. const float float_min, const float float_max,
  82. const int num_class, const int num_anchor_per_cls,
  83. const std::vector<std::vector<int>> multihead_label_mapping,
  84. const float score_threshold,
  85. const float nms_overlap_threshold,
  86. const int nms_pre_maxsize,
  87. const int nms_post_maxsize,
  88. const int num_box_corners,
  89. const int num_input_box_feature,
  90. const int num_output_box_feature);
  91. ~PostprocessCuda(){}
  92. /**
  93. * @brief Postprocessing for the network output
  94. * @param[in] rpn_box_output Box predictions from the network output
  95. * @param[in] rpn_cls_output Class predictions from the network output
  96. * @param[in] rpn_dir_output Direction predictions from the network output
  97. * @param[in] dev_filtered_box Filtered box predictions
  98. * @param[in] dev_filtered_score Filtered score predictions
  99. * @param[in] dev_filter_count The number of filtered output
  100. * @param[out] out_detection Output bounding boxes
  101. * @param[out] out_label Output labels of objects
  102. * @details dev_* represents device memory allocated variables
  103. */
  104. void DoPostprocessCuda(
  105. float* cls_pred_0,
  106. float* cls_pred_12,
  107. float* cls_pred_34,
  108. float* cls_pred_5,
  109. float* cls_pred_67,
  110. float* cls_pred_89,
  111. const float* box_preds,
  112. float* dev_filtered_box,
  113. float* dev_filtered_score,
  114. int* dev_filter_count,
  115. std::vector<float>& out_detection, std::vector<int>& out_label , std::vector<float>& out_score);
  116. };