centerpoint_config.hpp 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. // Copyright 2021 TIER IV, Inc.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #ifndef LIDAR_CENTERPOINT__CENTERPOINT_CONFIG_HPP_
  15. #define LIDAR_CENTERPOINT__CENTERPOINT_CONFIG_HPP_
  16. #include <cstddef>
  17. #include <vector>
  18. namespace centerpoint
  19. {
  /// @brief Plain parameter holder for the CenterPoint pipeline.
  ///
  /// All members are public and grouped as input, network, post-process and
  /// calculated parameters. Every member carries an in-class default; the
  /// constructor overrides a default only when the supplied value is usable,
  /// and then derives the calculated parameters from the final values.
  class CenterPointConfig
  {
  public:
    /// @brief Build a configuration from user-supplied values.
    ///
    /// Invalid arguments never abort construction: the corresponding default
    /// is silently kept instead.
    ///
    /// @param class_size                 stored as-is in class_size_.
    /// @param point_feature_size         stored in point_feature_size_ after conversion to
    ///                                   std::size_t; negative or NaN values are ignored.
    /// @param max_voxel_size             stored as-is in max_voxel_size_.
    /// @param point_cloud_range          {min_x, min_y, min_z, max_x, max_y, max_z}; applied only
    ///                                   when it has exactly 6 entries and min <= max on each axis.
    /// @param voxel_size                 {x, y, z}; applied only when it has exactly 3 entries
    ///                                   that are all strictly positive after conversion to float.
    /// @param downsample_factor          applied only when non-zero (it is used as a divisor).
    /// @param encoder_in_feature_size    stored as-is in encoder_in_feature_size_.
    /// @param score_threshold            applied only when inside the open interval (0, 1).
    /// @param circle_nms_dist_threshold  applied only when strictly positive.
    explicit CenterPointConfig(
      const std::size_t class_size, const float point_feature_size, const std::size_t max_voxel_size,
      const std::vector<double> & point_cloud_range, const std::vector<double> & voxel_size,
      const std::size_t downsample_factor, const std::size_t encoder_in_feature_size,
      const float score_threshold, const float circle_nms_dist_threshold)
    {
      class_size_ = class_size;
      // The parameter is a float for API compatibility; converting a negative or
      // NaN float to an unsigned integer is undefined, so such values are skipped.
      if (point_feature_size >= 0.0f) {
        point_feature_size_ = static_cast<std::size_t>(point_feature_size);
      }
      max_voxel_size_ = max_voxel_size;

      if (point_cloud_range.size() == 6) {
        const float min_x = static_cast<float>(point_cloud_range[0]);
        const float min_y = static_cast<float>(point_cloud_range[1]);
        const float min_z = static_cast<float>(point_cloud_range[2]);
        const float max_x = static_cast<float>(point_cloud_range[3]);
        const float max_y = static_cast<float>(point_cloud_range[4]);
        const float max_z = static_cast<float>(point_cloud_range[5]);
        // An inverted (or NaN) range would yield a negative grid extent, which
        // cannot be converted to std::size_t; keep the defaults in that case.
        if (min_x <= max_x && min_y <= max_y && min_z <= max_z) {
          range_min_x_ = min_x;
          range_min_y_ = min_y;
          range_min_z_ = min_z;
          range_max_x_ = max_x;
          range_max_y_ = max_y;
          range_max_z_ = max_z;
        }
      }

      if (voxel_size.size() == 3) {
        const float size_x = static_cast<float>(voxel_size[0]);
        const float size_y = static_cast<float>(voxel_size[1]);
        const float size_z = static_cast<float>(voxel_size[2]);
        // Voxel sizes are divisors of the grid extent below; checked after the
        // float conversion because a tiny double can round to 0.0f.
        if (size_x > 0.0f && size_y > 0.0f && size_z > 0.0f) {
          voxel_size_x_ = size_x;
          voxel_size_y_ = size_y;
          voxel_size_z_ = size_z;
        }
      }

      // Used as an integer divisor for down_grid_size_*; zero would be a
      // division by zero, so the default factor is kept instead.
      if (downsample_factor > 0) {
        downsample_factor_ = downsample_factor;
      }
      encoder_in_feature_size_ = encoder_in_feature_size;

      if (score_threshold > 0 && score_threshold < 1) {
        score_threshold_ = score_threshold;
      }
      if (circle_nms_dist_threshold > 0) {
        circle_nms_dist_threshold_ = circle_nms_dist_threshold;
      }

      // Re-derive the calculated parameters from the (possibly overridden) inputs.
      grid_size_x_ = static_cast<std::size_t>((range_max_x_ - range_min_x_) / voxel_size_x_);
      grid_size_y_ = static_cast<std::size_t>((range_max_y_ - range_min_y_) / voxel_size_y_);
      grid_size_z_ = static_cast<std::size_t>((range_max_z_ - range_min_z_) / voxel_size_z_);
      offset_x_ = range_min_x_ + voxel_size_x_ / 2;
      offset_y_ = range_min_y_ + voxel_size_y_ / 2;
      offset_z_ = range_min_z_ + voxel_size_z_ / 2;
      down_grid_size_x_ = grid_size_x_ / downsample_factor_;
      down_grid_size_y_ = grid_size_y_ / downsample_factor_;
    }

    // input params
    std::size_t class_size_{3};
    const std::size_t point_dim_size_{3};  // x, y and z
    std::size_t point_feature_size_{4};    // x, y, z and timelag
    std::size_t max_point_in_voxel_size_{32};
    std::size_t max_voxel_size_{40000};
    float range_min_x_{-89.6f};
    float range_min_y_{-89.6f};
    float range_min_z_{-3.0f};
    float range_max_x_{89.6f};
    float range_max_y_{89.6f};
    float range_max_z_{5.0f};
    float voxel_size_x_{0.32f};
    float voxel_size_y_{0.32f};
    float voxel_size_z_{8.0f};

    // network params
    const std::size_t batch_size_{1};
    std::size_t downsample_factor_{2};
    std::size_t encoder_in_feature_size_{9};
    const std::size_t encoder_out_feature_size_{32};
    const std::size_t head_out_size_{6};
    const std::size_t head_out_offset_size_{2};
    const std::size_t head_out_z_size_{1};
    const std::size_t head_out_dim_size_{3};
    const std::size_t head_out_rot_size_{2};
    const std::size_t head_out_vel_size_{2};

    // post-process params
    float score_threshold_{0.4f};
    float circle_nms_dist_threshold_{1.5f};

    // calculated params
    // These initializers evaluate against the defaults above (members are
    // initialized in declaration order); the constructor recomputes them once
    // the final input values are known.
    std::size_t grid_size_x_ = static_cast<std::size_t>((range_max_x_ - range_min_x_) / voxel_size_x_);
    std::size_t grid_size_y_ = static_cast<std::size_t>((range_max_y_ - range_min_y_) / voxel_size_y_);
    std::size_t grid_size_z_ = static_cast<std::size_t>((range_max_z_ - range_min_z_) / voxel_size_z_);
    float offset_x_ = range_min_x_ + voxel_size_x_ / 2;  // min bound plus half a voxel
    float offset_y_ = range_min_y_ + voxel_size_y_ / 2;
    float offset_z_ = range_min_z_ + voxel_size_z_ / 2;
    std::size_t down_grid_size_x_ = grid_size_x_ / downsample_factor_;
    std::size_t down_grid_size_y_ = grid_size_y_ / downsample_factor_;
  };
  101. } // namespace centerpoint
  102. #endif // LIDAR_CENTERPOINT__CENTERPOINT_CONFIG_HPP_