scatter_kernel.cu 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. // Copyright 2022 TIER IV, Inc.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "lidar_centerpoint/network/scatter_kernel.hpp"
  15. #include <lidar_centerpoint/utils.hpp>
  16. namespace
  17. {
  18. const std::size_t THREADS_PER_BLOCK = 32;
  19. } // namespace
  20. namespace centerpoint
  21. {
  22. __global__ void scatterFeatures_kernel(
  23. const float * pillar_features, const int * coords, const std::size_t num_pillars,
  24. const std::size_t pillar_feature_size, const std::size_t grid_size_x,
  25. const std::size_t grid_size_y, float * scattered_features)
  26. {
  27. // pillar_features: shape of (max_num_pillars, pillar_feature_size)
  28. // coords: shape of (max_num_pillars, 3)
  29. // scattered_features: shape of (num_pillars, grid_size_y, grid_size_x)
  30. const auto pillar_i = blockIdx.x * THREADS_PER_BLOCK + threadIdx.x;
  31. const auto feature_i = blockIdx.y * THREADS_PER_BLOCK + threadIdx.y;
  32. if (pillar_i >= num_pillars || feature_i >= pillar_feature_size) {
  33. return;
  34. }
  35. const int3 coord = ((int3 *)coords)[pillar_i]; // zyx
  36. if (coord.x < 0) {
  37. return;
  38. }
  39. const auto feature = pillar_features[pillar_feature_size * pillar_i + feature_i];
  40. scattered_features[grid_size_y * grid_size_x * feature_i + grid_size_x * coord.y + coord.z] =
  41. feature;
  42. }
  43. cudaError_t scatterFeatures_launch(
  44. const float * pillar_features, const int * coords, const std::size_t num_pillars,
  45. const std::size_t max_voxel_size, const std::size_t encoder_out_feature_size,
  46. const std::size_t grid_size_x, const std::size_t grid_size_y, float * scattered_features,
  47. cudaStream_t stream)
  48. {
  49. dim3 blocks(
  50. divup(max_voxel_size, THREADS_PER_BLOCK), divup(encoder_out_feature_size, THREADS_PER_BLOCK));
  51. dim3 threads(THREADS_PER_BLOCK, THREADS_PER_BLOCK);
  52. scatterFeatures_kernel<<<blocks, threads, 0, stream>>>(
  53. pillar_features, coords, num_pillars, encoder_out_feature_size, grid_size_x, grid_size_y,
  54. scattered_features);
  55. return cudaGetLastError();
  56. }
  57. } // namespace centerpoint