centerpoint_trt.cpp 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196
  1. // Copyright 2021 TIER IV, Inc.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
#include "lidar_centerpoint/centerpoint_trt.hpp"

#include <lidar_centerpoint/centerpoint_config.hpp>
#include <lidar_centerpoint/network/scatter_kernel.hpp>
#include <lidar_centerpoint/preprocess/preprocess_kernel.hpp>
//#include <tier4_autoware_utils/math/constants.hpp>

#include <algorithm>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>
  23. namespace centerpoint
  24. {
  25. CenterPointTRT::CenterPointTRT(
  26. const NetworkParam & encoder_param, const NetworkParam & head_param,
  27. const DensificationParam & densification_param, const CenterPointConfig & config)
  28. : config_(config)
  29. {
  30. vg_ptr_ = std::make_unique<VoxelGenerator>(densification_param, config_);
  31. post_proc_ptr_ = std::make_unique<PostProcessCUDA>(config_);
  32. // encoder
  33. encoder_trt_ptr_ = std::make_unique<VoxelEncoderTRT>(config_, verbose_);
  34. encoder_trt_ptr_->init(
  35. encoder_param.onnx_path(), encoder_param.engine_path(), encoder_param.trt_precision());
  36. encoder_trt_ptr_->context_->setBindingDimensions(
  37. 0,
  38. nvinfer1::Dims3(
  39. config_.max_voxel_size_, config_.max_point_in_voxel_size_, config_.encoder_in_feature_size_));
  40. // head
  41. std::vector<std::size_t> out_channel_sizes = {
  42. config_.class_size_, config_.head_out_offset_size_, config_.head_out_z_size_,
  43. config_.head_out_dim_size_, config_.head_out_rot_size_, config_.head_out_vel_size_};
  44. head_trt_ptr_ = std::make_unique<HeadTRT>(out_channel_sizes, config_, verbose_);
  45. head_trt_ptr_->init(head_param.onnx_path(), head_param.engine_path(), head_param.trt_precision());
  46. head_trt_ptr_->context_->setBindingDimensions(
  47. 0, nvinfer1::Dims4(
  48. config_.batch_size_, config_.encoder_out_feature_size_, config_.grid_size_y_,
  49. config_.grid_size_x_));
  50. initPtr();
  51. cudaStreamCreate(&stream_);
  52. }
  53. CenterPointTRT::~CenterPointTRT()
  54. {
  55. if (stream_) {
  56. cudaStreamSynchronize(stream_);
  57. cudaStreamDestroy(stream_);
  58. }
  59. }
  60. void CenterPointTRT::initPtr()
  61. {
  62. const auto voxels_size =
  63. config_.max_voxel_size_ * config_.max_point_in_voxel_size_ * config_.point_feature_size_;
  64. const auto coordinates_size = config_.max_voxel_size_ * config_.point_dim_size_;
  65. encoder_in_feature_size_ =
  66. config_.max_voxel_size_ * config_.max_point_in_voxel_size_ * config_.encoder_in_feature_size_;
  67. const auto pillar_features_size = config_.max_voxel_size_ * config_.encoder_out_feature_size_;
  68. spatial_features_size_ =
  69. config_.grid_size_x_ * config_.grid_size_y_ * config_.encoder_out_feature_size_;
  70. const auto grid_xy_size = config_.down_grid_size_x_ * config_.down_grid_size_y_;
  71. // host
  72. voxels_.resize(voxels_size);
  73. coordinates_.resize(coordinates_size);
  74. num_points_per_voxel_.resize(config_.max_voxel_size_);
  75. // device
  76. voxels_d_ = cuda::make_unique<float[]>(voxels_size);
  77. coordinates_d_ = cuda::make_unique<int[]>(coordinates_size);
  78. num_points_per_voxel_d_ = cuda::make_unique<float[]>(config_.max_voxel_size_);
  79. encoder_in_features_d_ = cuda::make_unique<float[]>(encoder_in_feature_size_);
  80. pillar_features_d_ = cuda::make_unique<float[]>(pillar_features_size);
  81. spatial_features_d_ = cuda::make_unique<float[]>(spatial_features_size_);
  82. head_out_heatmap_d_ = cuda::make_unique<float[]>(grid_xy_size * config_.class_size_);
  83. head_out_offset_d_ = cuda::make_unique<float[]>(grid_xy_size * config_.head_out_offset_size_);
  84. head_out_z_d_ = cuda::make_unique<float[]>(grid_xy_size * config_.head_out_z_size_);
  85. head_out_dim_d_ = cuda::make_unique<float[]>(grid_xy_size * config_.head_out_dim_size_);
  86. head_out_rot_d_ = cuda::make_unique<float[]>(grid_xy_size * config_.head_out_rot_size_);
  87. head_out_vel_d_ = cuda::make_unique<float[]>(grid_xy_size * config_.head_out_vel_size_);
  88. }
  89. bool CenterPointTRT::detect(
  90. const sensor_msgs::msg::PointCloud2 & input_pointcloud_msg, const tf2_ros::Buffer & tf_buffer,
  91. std::vector<Box3D> & det_boxes3d)
  92. {
  93. std::fill(voxels_.begin(), voxels_.end(), 0);
  94. std::fill(coordinates_.begin(), coordinates_.end(), -1);
  95. std::fill(num_points_per_voxel_.begin(), num_points_per_voxel_.end(), 0);
  96. CHECK_CUDA_ERROR(cudaMemsetAsync(
  97. encoder_in_features_d_.get(), 0, encoder_in_feature_size_ * sizeof(float), stream_));
  98. CHECK_CUDA_ERROR(
  99. cudaMemsetAsync(spatial_features_d_.get(), 0, spatial_features_size_ * sizeof(float), stream_));
  100. if (!preprocess(input_pointcloud_msg, tf_buffer)) {
  101. RCLCPP_WARN_STREAM(
  102. rclcpp::get_logger("lidar_centerpoint"), "Fail to preprocess and skip to detect.");
  103. return false;
  104. }
  105. inference();
  106. postProcess(det_boxes3d);
  107. return true;
  108. }
  109. bool CenterPointTRT::preprocess(
  110. const sensor_msgs::msg::PointCloud2 & input_pointcloud_msg, const tf2_ros::Buffer & tf_buffer)
  111. {
  112. bool is_success = vg_ptr_->enqueuePointCloud(input_pointcloud_msg, tf_buffer);
  113. if (!is_success) {
  114. return false;
  115. }
  116. num_voxels_ = vg_ptr_->pointsToVoxels(voxels_, coordinates_, num_points_per_voxel_);
  117. if (num_voxels_ == 0) {
  118. return false;
  119. }
  120. const auto voxels_size =
  121. num_voxels_ * config_.max_point_in_voxel_size_ * config_.point_feature_size_;
  122. const auto coordinates_size = num_voxels_ * config_.point_dim_size_;
  123. // memcpy from host to device (not copy empty voxels)
  124. CHECK_CUDA_ERROR(cudaMemcpyAsync(
  125. voxels_d_.get(), voxels_.data(), voxels_size * sizeof(float), cudaMemcpyHostToDevice));
  126. CHECK_CUDA_ERROR(cudaMemcpyAsync(
  127. coordinates_d_.get(), coordinates_.data(), coordinates_size * sizeof(int),
  128. cudaMemcpyHostToDevice));
  129. CHECK_CUDA_ERROR(cudaMemcpyAsync(
  130. num_points_per_voxel_d_.get(), num_points_per_voxel_.data(), num_voxels_ * sizeof(float),
  131. cudaMemcpyHostToDevice));
  132. CHECK_CUDA_ERROR(cudaStreamSynchronize(stream_));
  133. CHECK_CUDA_ERROR(generateFeatures_launch(
  134. voxels_d_.get(), num_points_per_voxel_d_.get(), coordinates_d_.get(), num_voxels_,
  135. config_.max_voxel_size_, config_.voxel_size_x_, config_.voxel_size_y_, config_.voxel_size_z_,
  136. config_.range_min_x_, config_.range_min_y_, config_.range_min_z_, encoder_in_features_d_.get(),
  137. stream_));
  138. return true;
  139. }
  140. void CenterPointTRT::inference()
  141. {
  142. if (!encoder_trt_ptr_->context_ || !head_trt_ptr_->context_) {
  143. throw std::runtime_error("Failed to create tensorrt context.");
  144. }
  145. // pillar encoder network
  146. std::vector<void *> encoder_buffers{encoder_in_features_d_.get(), pillar_features_d_.get()};
  147. encoder_trt_ptr_->context_->enqueueV2(encoder_buffers.data(), stream_, nullptr);
  148. // scatter
  149. CHECK_CUDA_ERROR(scatterFeatures_launch(
  150. pillar_features_d_.get(), coordinates_d_.get(), num_voxels_, config_.max_voxel_size_,
  151. config_.encoder_out_feature_size_, config_.grid_size_x_, config_.grid_size_y_,
  152. spatial_features_d_.get(), stream_));
  153. // head network
  154. std::vector<void *> head_buffers = {spatial_features_d_.get(), head_out_heatmap_d_.get(),
  155. head_out_offset_d_.get(), head_out_z_d_.get(),
  156. head_out_dim_d_.get(), head_out_rot_d_.get(),
  157. head_out_vel_d_.get()};
  158. head_trt_ptr_->context_->enqueueV2(head_buffers.data(), stream_, nullptr);
  159. }
  160. void CenterPointTRT::postProcess(std::vector<Box3D> & det_boxes3d)
  161. {
  162. CHECK_CUDA_ERROR(post_proc_ptr_->generateDetectedBoxes3D_launch(
  163. head_out_heatmap_d_.get(), head_out_offset_d_.get(), head_out_z_d_.get(), head_out_dim_d_.get(),
  164. head_out_rot_d_.get(), head_out_vel_d_.get(), det_boxes3d, stream_));
  165. if (det_boxes3d.size() == 0) {
  166. RCLCPP_WARN_STREAM(rclcpp::get_logger("lidar_centerpoint"), "No detected boxes.");
  167. }
  168. }
  169. } // namespace centerpoint