
first commit centerpoint_paddle

liyupeng 7 months ago
commit 6be9acaa3f

+ 247 - 0
src/detection/centerpoint_paddle/CMakeLists.txt

@@ -0,0 +1,247 @@
+cmake_minimum_required(VERSION 3.0)
+project(cpp_inference_demo CXX C)
+option(WITH_MKL        "Compile demo with MKL/OpenBLAS support; MKL is used by default."   ON)
+option(WITH_GPU        "Compile demo with GPU support (CPU otherwise); GPU is on by default." ON)
+option(USE_TENSORRT "Compile demo with TensorRT."   ON)
+option(CUSTOM_OPERATOR_FILES "List of file names for custom operators" "")
+
+execute_process(COMMAND ${CMAKE_C_COMPILER} -dumpfullversion -dumpversion
+                OUTPUT_VARIABLE GCC_VERSION)
+string(REGEX MATCHALL "[0-9]+" GCC_VERSION_COMPONENTS ${GCC_VERSION})
+list(GET GCC_VERSION_COMPONENTS 0 GCC_MAJOR)
+list(GET GCC_VERSION_COMPONENTS 1 GCC_MINOR)
+set(GCC_VERSION "${GCC_MAJOR}.${GCC_MINOR}")
+if (GCC_VERSION LESS "8.0")
+    set(CMAKE_CXX_FLAGS "-Wl,--no-as-needed")
+endif()
+
+set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
+#include(external/boost)
+
+if(WITH_GPU)
+  find_package(CUDA REQUIRED)
+  add_definitions("-DPADDLE_WITH_CUDA")
+endif()
+
+if(NOT WITH_STATIC_LIB)
+  add_definitions("-DPADDLE_WITH_SHARED_LIB")
+else()
+  # PD_INFER_DECL is mainly used to set the dllimport/dllexport attribute in dynamic library mode.
+  # Set it to empty in static library mode to avoid compilation issues.
+  add_definitions("/DPD_INFER_DECL=")
+endif()
+
+macro(safe_set_static_flag)
+    foreach(flag_var
+        CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_DEBUG CMAKE_CXX_FLAGS_RELEASE
+        CMAKE_CXX_FLAGS_MINSIZEREL CMAKE_CXX_FLAGS_RELWITHDEBINFO)
+      if(${flag_var} MATCHES "/MD")
+        string(REGEX REPLACE "/MD" "/MT" ${flag_var} "${${flag_var}}")
+      endif(${flag_var} MATCHES "/MD")
+    endforeach(flag_var)
+endmacro()
+
+if(NOT DEFINED PADDLE_LIB)
+  message(FATAL_ERROR "please set PADDLE_LIB with -DPADDLE_LIB=/path/paddle/lib")
+endif()
+if(NOT DEFINED DEMO_NAME)
+  message(FATAL_ERROR "please set DEMO_NAME with -DDEMO_NAME=demo_name")
+endif()
+
+include_directories("${PADDLE_LIB}/")
+set(PADDLE_LIB_THIRD_PARTY_PATH "${PADDLE_LIB}/third_party/install/")
+#include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}protobuf/include")
+include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}glog/include")
+include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}gflags/include")
+include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}xxhash/include")
+include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}onnxruntime/include")
+include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}paddle2onnx/include")
+include_directories("/home/nvidia/modularization/include")
+include_directories("/home/nvidia/modularization/src/include/msgtype")
+
+
+#link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}protobuf/lib")
+link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}glog/lib")
+link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}gflags/lib")
+link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}xxhash/lib")
+link_directories("${PADDLE_LIB}/paddle/lib")
+link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}onnxruntime/lib")
+link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}paddle2onnx/lib")
+link_directories("/home/nvidia/modularization/bin")
+
+find_package(Protobuf REQUIRED)
+include_directories(${Protobuf_INCLUDE_DIRS})
+
+find_package(Qt5Core REQUIRED)
+
+find_package(PCL REQUIRED)
+include_directories(${PCL_INCLUDE_DIRS})
+
+find_package(Boost REQUIRED system)
+
+find_package(OpenCV REQUIRED)
+include_directories(${OpenCV_INCLUDE_DIRS})
+
+
+if (WIN32)
+  add_definitions("/DGOOGLE_GLOG_DLL_DECL=")
+  option(MSVC_STATIC_CRT "use static C Runtime library by default" ON)
+  if (MSVC_STATIC_CRT)
+    if (WITH_MKL)
+      set(FLAG_OPENMP "/openmp")
+    endif()
+    set(CMAKE_C_FLAGS_DEBUG   "${CMAKE_C_FLAGS_DEBUG} /bigobj /MTd ${FLAG_OPENMP}")
+    set(CMAKE_C_FLAGS_RELEASE  "${CMAKE_C_FLAGS_RELEASE} /bigobj /MT ${FLAG_OPENMP}")
+    set(CMAKE_CXX_FLAGS_DEBUG  "${CMAKE_CXX_FLAGS_DEBUG} /bigobj /MTd ${FLAG_OPENMP}")
+    set(CMAKE_CXX_FLAGS_RELEASE   "${CMAKE_CXX_FLAGS_RELEASE} /bigobj /MT ${FLAG_OPENMP}")
+    safe_set_static_flag()
+    if (WITH_STATIC_LIB)
+      add_definitions(-DSTATIC_LIB)
+    endif()
+  endif()
+else()
+  if(WITH_MKL)
+    set(FLAG_OPENMP "-fopenmp")
+  endif()
+  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++14 ${FLAG_OPENMP}")
+endif()
+
+if(WITH_GPU)
+  if(NOT WIN32)
+    set(CUDA_LIB "/usr/local/cuda/lib64/" CACHE STRING "CUDA Library")
+  else()
+    if(CUDA_LIB STREQUAL "")
+      set(CUDA_LIB "C:\\Program\ Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v8.0\\lib\\x64")
+    endif()
+  endif(NOT WIN32)
+endif()
+
+if (USE_TENSORRT AND WITH_GPU)
+  set(TENSORRT_ROOT "" CACHE STRING "The root directory of TensorRT library")
+  if("${TENSORRT_ROOT}" STREQUAL "")
+      message(FATAL_ERROR "The TENSORRT_ROOT is empty, you must assign it a value with CMake command. Such as: -DTENSORRT_ROOT=TENSORRT_ROOT_PATH ")
+  endif()
+  set(TENSORRT_INCLUDE_DIR ${TENSORRT_ROOT}/include)
+  set(TENSORRT_LIB_DIR ${TENSORRT_ROOT}/lib)
+endif()
+
+if (NOT WIN32)
+  if (USE_TENSORRT AND WITH_GPU)
+      include_directories("${TENSORRT_INCLUDE_DIR}")
+      link_directories("${TENSORRT_LIB_DIR}")
+  endif()
+endif(NOT WIN32)
+
+if(WITH_MKL)
+  set(MATH_LIB_PATH "${PADDLE_LIB_THIRD_PARTY_PATH}mklml")
+  include_directories("${MATH_LIB_PATH}/include")
+  if(WIN32)
+    set(MATH_LIB ${MATH_LIB_PATH}/lib/mklml${CMAKE_STATIC_LIBRARY_SUFFIX}
+                 ${MATH_LIB_PATH}/lib/libiomp5md${CMAKE_STATIC_LIBRARY_SUFFIX})
+  else()
+    set(MATH_LIB ${MATH_LIB_PATH}/lib/libmklml_intel${CMAKE_SHARED_LIBRARY_SUFFIX}
+                 ${MATH_LIB_PATH}/lib/libiomp5${CMAKE_SHARED_LIBRARY_SUFFIX})
+  endif()
+  set(MKLDNN_PATH "${PADDLE_LIB_THIRD_PARTY_PATH}mkldnn")
+  if(EXISTS ${MKLDNN_PATH})
+    include_directories("${MKLDNN_PATH}/include")
+    if(WIN32)
+      set(MKLDNN_LIB ${MKLDNN_PATH}/lib/mkldnn.lib)
+    else(WIN32)
+      set(MKLDNN_LIB ${MKLDNN_PATH}/lib/libmkldnn.so.0)
+    endif(WIN32)
+  endif()
+else()
+  set(OPENBLAS_LIB_PATH "${PADDLE_LIB_THIRD_PARTY_PATH}openblas")
+  include_directories("${OPENBLAS_LIB_PATH}/include/openblas")
+  if(WIN32)
+    set(MATH_LIB ${OPENBLAS_LIB_PATH}/lib/openblas${CMAKE_STATIC_LIBRARY_SUFFIX})
+  else()
+    set(MATH_LIB ${OPENBLAS_LIB_PATH}/lib/libopenblas${CMAKE_STATIC_LIBRARY_SUFFIX})
+  endif()
+endif()
+
+if(WITH_STATIC_LIB)
+  set(DEPS ${PADDLE_LIB}/paddle/lib/libpaddle_inference${CMAKE_STATIC_LIBRARY_SUFFIX})
+else()
+  if(WIN32)
+    set(DEPS ${PADDLE_LIB}/paddle/lib/libpaddle_inference${CMAKE_STATIC_LIBRARY_SUFFIX})
+  else()
+    set(DEPS ${PADDLE_LIB}/paddle/lib/libpaddle_inference${CMAKE_SHARED_LIBRARY_SUFFIX})
+  endif()
+endif()
+
+
+if (NOT WIN32)
+  if (GCC_VERSION LESS "8.0")
+      set(EXTERNAL_LIB ${EXTERNAL_LIB} "-lssl -lcrypto -lz -lleveldb -lsnappy")
+  endif()
+  set(EXTERNAL_LIB ${EXTERNAL_LIB} "-lrt -ldl -lpthread")
+  set(DEPS ${DEPS}
+      ${MATH_LIB} ${MKLDNN_LIB}
+      glog gflags protobuf  xxhash
+      ${EXTERNAL_LIB})
+else()
+  set(DEPS ${DEPS}
+      ${MATH_LIB} ${MKLDNN_LIB}
+      glog gflags_static libprotobuf  xxhash ${EXTERNAL_LIB})
+  set(DEPS ${DEPS} shlwapi.lib)
+endif(NOT WIN32)
+
+if(WITH_GPU)
+  if(NOT WIN32)
+    if (USE_TENSORRT)
+      set(DEPS ${DEPS} ${TENSORRT_LIB_DIR}/libnvinfer${CMAKE_SHARED_LIBRARY_SUFFIX})
+      set(DEPS ${DEPS} ${TENSORRT_LIB_DIR}/libnvinfer_plugin${CMAKE_SHARED_LIBRARY_SUFFIX})
+    endif()
+    set(DEPS ${DEPS} ${CUDA_LIB}/libcudart${CMAKE_SHARED_LIBRARY_SUFFIX})
+  else()
+    if(USE_TENSORRT)
+      set(DEPS ${DEPS} ${TENSORRT_LIB_DIR}/nvinfer${CMAKE_STATIC_LIBRARY_SUFFIX})
+      set(DEPS ${DEPS} ${TENSORRT_LIB_DIR}/nvinfer_plugin${CMAKE_STATIC_LIBRARY_SUFFIX})
+    endif()
+    set(DEPS ${DEPS} ${CUDA_LIB}/cudart${CMAKE_STATIC_LIBRARY_SUFFIX} )
+    set(DEPS ${DEPS} ${CUDA_LIB}/cublas${CMAKE_STATIC_LIBRARY_SUFFIX} )
+    set(DEPS ${DEPS} ${CUDA_LIB}/cudnn${CMAKE_STATIC_LIBRARY_SUFFIX} )
+  endif()
+endif()
+
+cuda_add_library(pd_infer_custom_op ${CUSTOM_OPERATOR_FILES} SHARED)
+add_executable(${DEMO_NAME} main.cc
+    /home/nvidia/modularization/src/include/msgtype/object.pb.cc
+    /home/nvidia/modularization/src/include/msgtype/objectarray.pb.cc)
+
+if (GCC_VERSION GREATER_EQUAL "8.0")
+    set(DEPS ${DEPS} libssl.a libcrypto.a libz.a libleveldb.a libsnappy.a)
+endif()
+set(DEPS ${DEPS} Boost::system pd_infer_custom_op)# libssl.a libcrypto.a libz.a libleveldb.a libsnappy.a)
+
+if(WIN32)
+  if(USE_TENSORRT)
+    add_custom_command(TARGET ${DEMO_NAME} POST_BUILD
+            COMMAND ${CMAKE_COMMAND} -E copy ${TENSORRT_LIB_DIR}/nvinfer${CMAKE_SHARED_LIBRARY_SUFFIX}
+              ${CMAKE_BINARY_DIR}/${CMAKE_BUILD_TYPE}
+            COMMAND ${CMAKE_COMMAND} -E copy ${TENSORRT_LIB_DIR}/nvinfer_plugin${CMAKE_SHARED_LIBRARY_SUFFIX}
+              ${CMAKE_BINARY_DIR}/${CMAKE_BUILD_TYPE}
+    )
+  endif()
+  if(WITH_MKL)
+    add_custom_command(TARGET ${DEMO_NAME} POST_BUILD
+          COMMAND ${CMAKE_COMMAND} -E copy ${MATH_LIB_PATH}/lib/mklml.dll ${CMAKE_BINARY_DIR}/Release
+          COMMAND ${CMAKE_COMMAND} -E copy ${MATH_LIB_PATH}/lib/libiomp5md.dll ${CMAKE_BINARY_DIR}/Release
+          COMMAND ${CMAKE_COMMAND} -E copy ${MKLDNN_PATH}/lib/mkldnn.dll  ${CMAKE_BINARY_DIR}/Release
+    )
+  else()
+    add_custom_command(TARGET ${DEMO_NAME} POST_BUILD
+          COMMAND ${CMAKE_COMMAND} -E copy ${OPENBLAS_LIB_PATH}/lib/openblas.dll ${CMAKE_BINARY_DIR}/Release
+    )
+  endif()
+  if(NOT WITH_STATIC_LIB)
+      add_custom_command(TARGET ${DEMO_NAME} POST_BUILD
+        COMMAND ${CMAKE_COMMAND} -E copy "${PADDLE_LIB}/paddle/lib/paddle_fluid.dll" ${CMAKE_BINARY_DIR}/${CMAKE_BUILD_TYPE}
+      )
+  endif()
+endif()
+
+target_link_libraries(${DEMO_NAME} ${DEPS} ${PROTOBUF_LIBRARIES} Qt5::Core ${PCL_LIBRARIES} ${OpenCV_LIBRARIES}
+    /home/nvidia/modularization/bin/libmodulecomm.so /home/nvidia/centerpoint/cpp/build/libpd_infer_custom_op.so)

+ 46 - 0
src/detection/centerpoint_paddle/cmake/external/boost.cmake

@@ -0,0 +1,46 @@
+include(ExternalProject)
+
+set(BOOST_PROJECT       "extern_boost")
+# To release PaddlePaddle as a pip package, we have to follow the
+# manylinux1 standard, which features as old Linux kernels and
+# compilers as possible and recommends CentOS 5. Indeed, the earliest
+# CentOS version that works with NVIDIA CUDA is CentOS 6.  And a new
+# version of boost, say, 1.66.0, doesn't build on CentOS 6.  We
+# checked that the devtools package of CentOS 6 installs boost 1.41.0.
+# So an older release, 1.55.0, is pinned here (used header-only, no build step below).
+set(BOOST_VER           "1.55.0")
+set(BOOST_TAR "boost_1_55_0" CACHE STRING "" FORCE)
+set(BOOST_URL "http://paddlepaddledeps.bj.bcebos.com/${BOOST_TAR}.tar.gz" CACHE STRING "" FORCE)
+
+MESSAGE(STATUS "BOOST_TAR: ${BOOST_TAR}, BOOST_URL: ${BOOST_URL}")
+
+set(BOOST_SOURCES_DIR ${THIRD_PARTY_PATH}/boost)
+set(BOOST_DOWNLOAD_DIR  "${BOOST_SOURCES_DIR}/src/${BOOST_PROJECT}")
+
+set(BOOST_INCLUDE_DIR "${BOOST_DOWNLOAD_DIR}" CACHE PATH "boost include directory." FORCE)
+set_directory_properties(PROPERTIES CLEAN_NO_CUSTOM 1)
+include_directories(${BOOST_INCLUDE_DIR})
+
+ExternalProject_Add(
+    ${BOOST_PROJECT}
+    ${EXTERNAL_PROJECT_LOG_ARGS}
+    DOWNLOAD_DIR          ${BOOST_DOWNLOAD_DIR}
+    URL      ${BOOST_URL}
+    DOWNLOAD_NO_PROGRESS  1
+    PREFIX                ${BOOST_SOURCES_DIR}
+    CONFIGURE_COMMAND     ""
+    BUILD_COMMAND         ""
+    INSTALL_COMMAND       ""
+    UPDATE_COMMAND        ""
+    )
+
+if (${CMAKE_VERSION} VERSION_LESS "3.3.0" OR NOT WIN32)
+    set(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/boost_dummy.c)
+    file(WRITE ${dummyfile} "const char *dummy = \"${dummyfile}\";")
+    add_library(boost STATIC ${dummyfile})
+else()
+    add_library(boost INTERFACE)
+endif()
+
+add_dependencies(boost ${BOOST_PROJECT})
+set(Boost_INCLUDE_DIR ${BOOST_INCLUDE_DIR})

+ 46 - 0
src/detection/centerpoint_paddle/compile.sh

@@ -0,0 +1,46 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+mkdir -p build
+sudo chmod 777 build
+
+cd build
+rm -rf *
+
+DEMO_NAME=centerpoint_paddle
+
+WITH_MKL=OFF
+WITH_GPU=ON
+USE_TENSORRT=OFF
+
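+# Paths below are machine-specific; point them at the local Paddle inference,
+# cuDNN, CUDA, and TensorRT installs.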
+LIB_DIR=/home/nvidia/Paddle/paddle_inference_install_dir
+CUDNN_LIB=/usr/lib/aarch64-linux-gnu
+CUDA_LIB=/usr/local/cuda-11.4/targets/aarch64-linux/lib
+TENSORRT_ROOT=/home/nvidia/Paddle/tensorrt
+CUSTOM_OPERATOR_FILES="custom_ops/voxelize_op.cu;custom_ops/voxelize_op.cc;custom_ops/iou3d_nms_kernel.cu;custom_ops/postprocess.cc;custom_ops/postprocess.cu"
+
+
+cmake .. -DPADDLE_LIB=${LIB_DIR} \
+  -DWITH_MKL=${WITH_MKL} \
+  -DDEMO_NAME=${DEMO_NAME} \
+  -DWITH_GPU=${WITH_GPU} \
+  -DWITH_STATIC_LIB=OFF \
+  -DUSE_TENSORRT=${USE_TENSORRT} \
+  -DCUDNN_LIB=${CUDNN_LIB} \
+  -DCUDA_LIB=${CUDA_LIB} \
+  -DTENSORRT_ROOT=${TENSORRT_ROOT} \
+  -DCUSTOM_OPERATOR_FILES=${CUSTOM_OPERATOR_FILES}
+
+make -j
+

+ 352 - 0
src/detection/centerpoint_paddle/custom_ops/iou3d_nms_kernel.cu

@@ -0,0 +1,352 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+/*
+3D IoU Calculation and Rotated NMS (modified from 2D NMS written by others)
+Written by Shaoshuai Shi
+All Rights Reserved 2019-2020.
+*/
+
+#include <stdio.h>
+#define THREADS_PER_BLOCK 16
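+// DIVUP(m, n): integer ceiling of m / n.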
+#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
+
+// #define DEBUG
+const int THREADS_PER_BLOCK_NMS = sizeof(int64_t) * 8;
+const float EPS = 1e-8;
+struct Point {
+  float x, y;
+  __device__ Point() {}
+  __device__ Point(double _x, double _y) { x = _x, y = _y; }
+
+  __device__ void set(float _x, float _y) {
+    x = _x;
+    y = _y;
+  }
+
+  __device__ Point operator+(const Point &b) const {
+    return Point(x + b.x, y + b.y);
+  }
+
+  __device__ Point operator-(const Point &b) const {
+    return Point(x - b.x, y - b.y);
+  }
+};
+
+__device__ inline float cross(const Point &a, const Point &b) {
+  return a.x * b.y - a.y * b.x;
+}
+
+__device__ inline float cross(const Point &p1, const Point &p2,
+                              const Point &p0) {
+  return (p1.x - p0.x) * (p2.y - p0.y) - (p2.x - p0.x) * (p1.y - p0.y);
+}
+
+__device__ int check_rect_cross(const Point &p1, const Point &p2,
+                                const Point &q1, const Point &q2) {
+  int ret = min(p1.x, p2.x) <= max(q1.x, q2.x) &&
+            min(q1.x, q2.x) <= max(p1.x, p2.x) &&
+            min(p1.y, p2.y) <= max(q1.y, q2.y) &&
+            min(q1.y, q2.y) <= max(p1.y, p2.y);
+  return ret;
+}
+
+__device__ inline int check_in_box2d(const float *box, const Point &p) {
+  // params: (7) [x, y, z, dx, dy, dz, heading]
+  const float MARGIN = 1e-2;
+
+  float center_x = box[0], center_y = box[1];
+  // rotate the point in the opposite direction of box
+  float angle_cos = cos(-box[6]), angle_sin = sin(-box[6]);
+  float rot_x = (p.x - center_x) * angle_cos + (p.y - center_y) * (-angle_sin);
+  float rot_y = (p.x - center_x) * angle_sin + (p.y - center_y) * angle_cos;
+
+  return (fabs(rot_x) < box[3] / 2 + MARGIN &&
+          fabs(rot_y) < box[4] / 2 + MARGIN);
+}
+
+__device__ inline int intersection(const Point &p1, const Point &p0,
+                                   const Point &q1, const Point &q0,
+                                   Point *ans) {
+  // fast exclusion
+  if (check_rect_cross(p0, p1, q0, q1) == 0) return 0;
+
+  // check that the two segments straddle each other
+  float s1 = cross(q0, p1, p0);
+  float s2 = cross(p1, q1, p0);
+  float s3 = cross(p0, q1, q0);
+  float s4 = cross(q1, p1, q0);
+
+  if (!(s1 * s2 > 0 && s3 * s4 > 0)) return 0;
+
+  // calculate intersection of two lines
+  float s5 = cross(q1, p1, p0);
+  if (fabs(s5 - s1) > EPS) {
+    ans->x = (s5 * q0.x - s1 * q1.x) / (s5 - s1);
+    ans->y = (s5 * q0.y - s1 * q1.y) / (s5 - s1);
+
+  } else {
+    float a0 = p0.y - p1.y, b0 = p1.x - p0.x, c0 = p0.x * p1.y - p1.x * p0.y;
+    float a1 = q0.y - q1.y, b1 = q1.x - q0.x, c1 = q0.x * q1.y - q1.x * q0.y;
+    float D = a0 * b1 - a1 * b0;
+
+    ans->x = (b0 * c1 - b1 * c0) / D;
+    ans->y = (a1 * c0 - a0 * c1) / D;
+  }
+
+  return 1;
+}
+
+__device__ inline void rotate_around_center(const Point &center,
+                                            const float angle_cos,
+                                            const float angle_sin, Point *p) {
+  float new_x = (p->x - center.x) * angle_cos +
+                (p->y - center.y) * (-angle_sin) + center.x;
+  float new_y =
+      (p->x - center.x) * angle_sin + (p->y - center.y) * angle_cos + center.y;
+  p->set(new_x, new_y);
+}
+
+__device__ inline int point_cmp(const Point &a, const Point &b,
+                                const Point &center) {
+  return atan2(a.y - center.y, a.x - center.x) >
+         atan2(b.y - center.y, b.x - center.x);
+}
+
+__device__ inline float box_overlap(const float *box_a, const float *box_b) {
+  // params box_a: [x, y, z, dx, dy, dz, heading]
+  // params box_b: [x, y, z, dx, dy, dz, heading]
+
+  float a_angle = box_a[6], b_angle = box_b[6];
+  float a_dx_half = box_a[3] / 2, b_dx_half = box_b[3] / 2,
+        a_dy_half = box_a[4] / 2, b_dy_half = box_b[4] / 2;
+  float a_x1 = box_a[0] - a_dx_half, a_y1 = box_a[1] - a_dy_half;
+  float a_x2 = box_a[0] + a_dx_half, a_y2 = box_a[1] + a_dy_half;
+  float b_x1 = box_b[0] - b_dx_half, b_y1 = box_b[1] - b_dy_half;
+  float b_x2 = box_b[0] + b_dx_half, b_y2 = box_b[1] + b_dy_half;
+
+  Point center_a(box_a[0], box_a[1]);
+  Point center_b(box_b[0], box_b[1]);
+
+#ifdef DEBUG
+  printf(
+      "a: (%.3f, %.3f, %.3f, %.3f, %.3f), b: (%.3f, %.3f, %.3f, %.3f, %.3f)\n",
+      a_x1, a_y1, a_x2, a_y2, a_angle, b_x1, b_y1, b_x2, b_y2, b_angle);
+  printf("center a: (%.3f, %.3f), b: (%.3f, %.3f)\n", center_a.x, center_a.y,
+         center_b.x, center_b.y);
+#endif
+
+  Point box_a_corners[5];
+  box_a_corners[0].set(a_x1, a_y1);
+  box_a_corners[1].set(a_x2, a_y1);
+  box_a_corners[2].set(a_x2, a_y2);
+  box_a_corners[3].set(a_x1, a_y2);
+
+  Point box_b_corners[5];
+  box_b_corners[0].set(b_x1, b_y1);
+  box_b_corners[1].set(b_x2, b_y1);
+  box_b_corners[2].set(b_x2, b_y2);
+  box_b_corners[3].set(b_x1, b_y2);
+
+  // get oriented corners
+  float a_angle_cos = cos(a_angle), a_angle_sin = sin(a_angle);
+  float b_angle_cos = cos(b_angle), b_angle_sin = sin(b_angle);
+
+  for (int k = 0; k < 4; k++) {
+#ifdef DEBUG
+    printf("before corner %d: a(%.3f, %.3f), b(%.3f, %.3f) \n", k,
+           box_a_corners[k].x, box_a_corners[k].y, box_b_corners[k].x,
+           box_b_corners[k].y);
+#endif
+    rotate_around_center(center_a, a_angle_cos, a_angle_sin, box_a_corners + k);
+    rotate_around_center(center_b, b_angle_cos, b_angle_sin, box_b_corners + k);
+#ifdef DEBUG
+    printf("corner %d: a(%.3f, %.3f), b(%.3f, %.3f) \n", k, box_a_corners[k].x,
+           box_a_corners[k].y, box_b_corners[k].x, box_b_corners[k].y);
+#endif
+  }
+
+  box_a_corners[4] = box_a_corners[0];
+  box_b_corners[4] = box_b_corners[0];
+
+  // get intersection of lines
+  Point cross_points[16];
+  Point poly_center;
+  int cnt = 0, flag = 0;
+
+  poly_center.set(0, 0);
+  for (int i = 0; i < 4; i++) {
+    for (int j = 0; j < 4; j++) {
+      flag = intersection(box_a_corners[i + 1], box_a_corners[i],
+                          box_b_corners[j + 1], box_b_corners[j],
+                          cross_points + cnt);
+      if (flag) {
+        poly_center = poly_center + cross_points[cnt];
+        cnt++;
+#ifdef DEBUG
+        printf(
+            "Cross points (%.3f, %.3f): a(%.3f, %.3f)->(%.3f, %.3f), "
+            "b(%.3f, %.3f)->(%.3f, %.3f) \n",
+            cross_points[cnt - 1].x, cross_points[cnt - 1].y,
+            box_a_corners[i].x, box_a_corners[i].y, box_a_corners[i + 1].x,
+            box_a_corners[i + 1].y, box_b_corners[i].x, box_b_corners[i].y,
+            box_b_corners[i + 1].x, box_b_corners[i + 1].y);
+#endif
+      }
+    }
+  }
+
+  // check corners
+  for (int k = 0; k < 4; k++) {
+    if (check_in_box2d(box_a, box_b_corners[k])) {
+      poly_center = poly_center + box_b_corners[k];
+      cross_points[cnt] = box_b_corners[k];
+      cnt++;
+#ifdef DEBUG
+      printf("b corners in a: corner_b(%.3f, %.3f)", cross_points[cnt - 1].x,
+             cross_points[cnt - 1].y);
+#endif
+    }
+    if (check_in_box2d(box_b, box_a_corners[k])) {
+      poly_center = poly_center + box_a_corners[k];
+      cross_points[cnt] = box_a_corners[k];
+      cnt++;
+#ifdef DEBUG
+      printf("a corners in b: corner_a(%.3f, %.3f)", cross_points[cnt - 1].x,
+             cross_points[cnt - 1].y);
+#endif
+    }
+  }
+
+  poly_center.x /= cnt;
+  poly_center.y /= cnt;
+
+  // sort the points of polygon
+  Point temp;
+  for (int j = 0; j < cnt - 1; j++) {
+    for (int i = 0; i < cnt - j - 1; i++) {
+      if (point_cmp(cross_points[i], cross_points[i + 1], poly_center)) {
+        temp = cross_points[i];
+        cross_points[i] = cross_points[i + 1];
+        cross_points[i + 1] = temp;
+      }
+    }
+  }
+
+#ifdef DEBUG
+  printf("cnt=%d\n", cnt);
+  for (int i = 0; i < cnt; i++) {
+    printf("All cross point %d: (%.3f, %.3f)\n", i, cross_points[i].x,
+           cross_points[i].y);
+  }
+#endif
+
+  // get the overlap areas
+  float area = 0;
+  for (int k = 0; k < cnt - 1; k++) {
+    area += cross(cross_points[k] - cross_points[0],
+                  cross_points[k + 1] - cross_points[0]);
+  }
+
+  return fabs(area) / 2.0;
+}
+
+__device__ inline float iou_bev(const float *box_a, const float *box_b) {
+  // params box_a: [x, y, z, dx, dy, dz, heading]
+  // params box_b: [x, y, z, dx, dy, dz, heading]
+  float sa = box_a[3] * box_a[4];
+  float sb = box_b[3] * box_b[4];
+  float s_overlap = box_overlap(box_a, box_b);
+  return s_overlap / fmaxf(sa + sb - s_overlap, EPS);
+}
+
+__global__ void nms_kernel(const int num_bboxes, const int num_bboxes_for_nms,
+                           const float nms_overlap_thresh,
+                           const int decode_bboxes_dims, const float *bboxes,
+                           const int *index, const int64_t *sorted_index,
+                           int64_t *mask) {
+  // params: boxes (N, 7) [x, y, z, dx, dy, dz, heading]
+  // params: mask (N, N/THREADS_PER_BLOCK_NMS)
+
+  const int row_start = blockIdx.y;
+  const int col_start = blockIdx.x;
+
+  // if (row_start > col_start) return;
+
+  const int row_size =
+      fminf(num_bboxes_for_nms - row_start * THREADS_PER_BLOCK_NMS,
+            THREADS_PER_BLOCK_NMS);
+  const int col_size =
+      fminf(num_bboxes_for_nms - col_start * THREADS_PER_BLOCK_NMS,
+            THREADS_PER_BLOCK_NMS);
+
+  __shared__ float block_boxes[THREADS_PER_BLOCK_NMS * 7];
+
+  if (threadIdx.x < col_size) {
+    int box_idx =
+        index[sorted_index[THREADS_PER_BLOCK_NMS * col_start + threadIdx.x]];
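+    // Re-pack for BEV IoU: channels 3 and 4 (dx/dy) of the decoded box are swapped and
+    // the heading becomes -theta - pi/2 (see the matching note in postprocess.cu).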
+    block_boxes[threadIdx.x * 7 + 0] = bboxes[box_idx * decode_bboxes_dims];
+    block_boxes[threadIdx.x * 7 + 1] = bboxes[box_idx * decode_bboxes_dims + 1];
+    block_boxes[threadIdx.x * 7 + 2] = bboxes[box_idx * decode_bboxes_dims + 2];
+    block_boxes[threadIdx.x * 7 + 3] = bboxes[box_idx * decode_bboxes_dims + 4];
+    block_boxes[threadIdx.x * 7 + 4] = bboxes[box_idx * decode_bboxes_dims + 3];
+    block_boxes[threadIdx.x * 7 + 5] = bboxes[box_idx * decode_bboxes_dims + 5];
+    block_boxes[threadIdx.x * 7 + 6] =
+        -bboxes[box_idx * decode_bboxes_dims + decode_bboxes_dims - 1] -
+        3.141592653589793 / 2;
+  }
+  __syncthreads();
+
+  if (threadIdx.x < row_size) {
+    const int cur_box_idx = THREADS_PER_BLOCK_NMS * row_start + threadIdx.x;
+    const int act_box_idx = index[sorted_index[cur_box_idx]];
+    float cur_box[7];
+    cur_box[0] = bboxes[act_box_idx * decode_bboxes_dims];
+    cur_box[1] = bboxes[act_box_idx * decode_bboxes_dims + 1];
+    cur_box[2] = bboxes[act_box_idx * decode_bboxes_dims + 2];
+    cur_box[3] = bboxes[act_box_idx * decode_bboxes_dims + 4];
+    cur_box[4] = bboxes[act_box_idx * decode_bboxes_dims + 3];
+    cur_box[5] = bboxes[act_box_idx * decode_bboxes_dims + 5];
+    cur_box[6] =
+        -bboxes[act_box_idx * decode_bboxes_dims + decode_bboxes_dims - 1] -
+        3.141592653589793 / 2;
+
+    int i = 0;
+    int64_t t = 0;
+    int start = 0;
+    if (row_start == col_start) {
+      start = threadIdx.x + 1;
+    }
+    for (i = start; i < col_size; i++) {
+      if (iou_bev(cur_box, block_boxes + i * 7) > nms_overlap_thresh) {
+        t |= 1ULL << i;
+      }
+    }
+    const int col_blocks = DIVUP(num_bboxes_for_nms, THREADS_PER_BLOCK_NMS);
+    mask[cur_box_idx * col_blocks + col_start] = t;
+  }
+}
+
+void NmsLauncher(const cudaStream_t &stream, const float *bboxes,
+                 const int *index, const int64_t *sorted_index,
+                 const int num_bboxes, const int num_bboxes_for_nms,
+                 const float nms_overlap_thresh, const int decode_bboxes_dims,
+                 int64_t *mask) {
+  dim3 blocks(DIVUP(num_bboxes_for_nms, THREADS_PER_BLOCK_NMS),
+              DIVUP(num_bboxes_for_nms, THREADS_PER_BLOCK_NMS));
+  dim3 threads(THREADS_PER_BLOCK_NMS);
+  nms_kernel<<<blocks, threads, 0, stream>>>(
+      num_bboxes, num_bboxes_for_nms, nms_overlap_thresh, decode_bboxes_dims,
+      bboxes, index, sorted_index, mask);
+}
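For orientation, a minimal host-side sketch of the bitmask bookkeeping used by nms_kernel above. The box count is illustrative only; as in the kernel, one 64-bit suppression word is written per box per 64-box column block, and NmsLauncher launches a col_blocks x col_blocks grid of 64-thread blocks.

#include <cstdint>
#include <cstdio>

// Mirrors THREADS_PER_BLOCK_NMS and DIVUP from iou3d_nms_kernel.cu.
constexpr int kThreadsPerBlockNms = sizeof(int64_t) * 8;  // 64
constexpr int DivUp(int m, int n) { return m / n + (m % n > 0); }

int main() {
  const int num_bboxes_for_nms = 1000;  // illustrative candidate count
  const int col_blocks = DivUp(num_bboxes_for_nms, kThreadsPerBlockNms);  // 16
  // nms_kernel writes one int64 suppression word per (box, column block) pair:
  const long long mask_words = 1LL * num_bboxes_for_nms * col_blocks;     // 16000
  std::printf("col_blocks=%d, mask words=%lld\n", col_blocks, mask_words);
  return 0;
}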

+ 105 - 0
src/detection/centerpoint_paddle/custom_ops/postprocess.cc

@@ -0,0 +1,105 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <cuda.h>
+#include <cuda_runtime_api.h>
+
+#include "paddle/include/experimental/ext_all.h"
+
+std::vector<paddle::Tensor> postprocess_gpu(
+    const std::vector<paddle::Tensor> &hm,
+    const std::vector<paddle::Tensor> &reg,
+    const std::vector<paddle::Tensor> &height,
+    const std::vector<paddle::Tensor> &dim,
+    const std::vector<paddle::Tensor> &vel,
+    const std::vector<paddle::Tensor> &rot,
+    const std::vector<float> &voxel_size,
+    const std::vector<float> &point_cloud_range,
+    const std::vector<float> &post_center_range,
+    const std::vector<int> &num_classes, const int down_ratio,
+    const float score_threshold, const float nms_iou_threshold,
+    const int nms_pre_max_size, const int nms_post_max_size,
+    const bool with_velocity);
+
+std::vector<paddle::Tensor> centerpoint_postprocess(
+    const std::vector<paddle::Tensor> &hm,
+    const std::vector<paddle::Tensor> &reg,
+    const std::vector<paddle::Tensor> &height,
+    const std::vector<paddle::Tensor> &dim,
+    const std::vector<paddle::Tensor> &vel,
+    const std::vector<paddle::Tensor> &rot,
+    const std::vector<float> &voxel_size,
+    const std::vector<float> &point_cloud_range,
+    const std::vector<float> &post_center_range,
+    const std::vector<int> &num_classes, const int down_ratio,
+    const float score_threshold, const float nms_iou_threshold,
+    const int nms_pre_max_size, const int nms_post_max_size,
+    const bool with_velocity) {
+  if (hm[0].is_gpu()) {
+    return postprocess_gpu(hm, reg, height, dim, vel, rot, voxel_size,
+                           point_cloud_range, post_center_range, num_classes,
+                           down_ratio, score_threshold, nms_iou_threshold,
+                           nms_pre_max_size, nms_post_max_size, with_velocity);
+  } else {
+    PD_THROW(
+        "Unsupported device type for centerpoint postprocess "
+        "operator.");
+  }
+}
+
+std::vector<std::vector<int64_t>> PostProcessInferShape(
+    const std::vector<std::vector<int64_t>> &hm_shape,
+    const std::vector<std::vector<int64_t>> &reg_shape,
+    const std::vector<std::vector<int64_t>> &height_shape,
+    const std::vector<std::vector<int64_t>> &dim_shape,
+    const std::vector<std::vector<int64_t>> &vel_shape,
+    const std::vector<std::vector<int64_t>> &rot_shape,
+    const std::vector<float> &voxel_size,
+    const std::vector<float> &point_cloud_range,
+    const std::vector<float> &post_center_range,
+    const std::vector<int> &num_classes, const int down_ratio,
+    const float score_threshold, const float nms_iou_threshold,
+    const int nms_pre_max_size, const int nms_post_max_size,
+    const bool with_velocity) {
+  if (with_velocity) {
+    return {{-1, 9}, {-1}, {-1}};
+  } else {
+    return {{-1, 7}, {-1}, {-1}};
+  }
+}
+
+std::vector<paddle::DataType> PostProcessInferDtype(
+    const std::vector<paddle::DataType> &hm_dtype,
+    const std::vector<paddle::DataType> &reg_dtype,
+    const std::vector<paddle::DataType> &height_dtype,
+    const std::vector<paddle::DataType> &dim_dtype,
+    const std::vector<paddle::DataType> &vel_dtype,
+    const std::vector<paddle::DataType> &rot_dtype) {
+  return {reg_dtype[0], hm_dtype[0], paddle::DataType::INT64};
+}
+
+PD_BUILD_OP(centerpoint_postprocess)
+    .Inputs({paddle::Vec("HM"), paddle::Vec("REG"), paddle::Vec("HEIGHT"),
+             paddle::Vec("DIM"), paddle::Vec("VEL"), paddle::Vec("ROT")})
+    .Outputs({"BBOXES", "SCORES", "LABELS"})
+    .SetKernelFn(PD_KERNEL(centerpoint_postprocess))
+    .Attrs({"voxel_size: std::vector<float>",
+            "point_cloud_range: std::vector<float>",
+            "post_center_range: std::vector<float>",
+            "num_classes: std::vector<int>", "down_ratio: int",
+            "score_threshold: float", "nms_iou_threshold: float",
+            "nms_pre_max_size: int", "nms_post_max_size: int",
+            "with_velocity: bool"})
+    .SetInferShapeFn(PD_INFER_SHAPE(PostProcessInferShape))
+    .SetInferDtypeFn(PD_INFER_DTYPE(PostProcessInferDtype));

+ 280 - 0
src/detection/centerpoint_paddle/custom_ops/postprocess.cu

@@ -0,0 +1,280 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/include/experimental/ext_all.h"
+
+#define CHECK_INPUT_CUDA(x) PD_CHECK(x.is_gpu(), #x " must be a GPU Tensor.")
+
+#define CHECK_INPUT_BATCHSIZE(x) \
+  PD_CHECK(x.shape()[0] == 1, #x " batch size must be 1.")
+
+#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
+
+const int THREADS_PER_BLOCK_NMS = sizeof(int64_t) * 8;
+
+void NmsLauncher(const cudaStream_t &stream, const float *bboxes,
+                 const int *index, const int64_t *sorted_index,
+                 const int num_bboxes, const int num_bboxes_for_nms,
+                 const float nms_overlap_thresh, const int decode_bboxes_dims,
+                 int64_t *mask);
+
+__global__ void decode_kernel(
+    const float *score, const float *reg, const float *height, const float *dim,
+    const float *vel, const float *rot, const float score_threshold,
+    const int feat_w, const float down_ratio, const float voxel_size_x,
+    const float voxel_size_y, const float point_cloud_range_x_min,
+    const float point_cloud_range_y_min, const float post_center_range_x_min,
+    const float post_center_range_y_min, const float post_center_range_z_min,
+    const float post_center_range_x_max, const float post_center_range_y_max,
+    const float post_center_range_z_max, const int num_bboxes,
+    const bool with_velocity, const int decode_bboxes_dims, float *bboxes,
+    bool *mask, int *score_idx) {
+  int box_idx = blockIdx.x * blockDim.x + threadIdx.x;
+  if (box_idx == num_bboxes || box_idx > num_bboxes) {
+    return;
+  }
+  const int xs = box_idx % feat_w;
+  const int ys = box_idx / feat_w;
+
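+  // Head tensors are channel-major: channel c of pixel box_idx lives at [box_idx + c * num_bboxes].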
+  float x = reg[box_idx];
+  float y = reg[box_idx + num_bboxes];
+  float z = height[box_idx];
+
+  bboxes[box_idx * decode_bboxes_dims] =
+      (x + xs) * down_ratio * voxel_size_x + point_cloud_range_x_min;
+  bboxes[box_idx * decode_bboxes_dims + 1] =
+      (y + ys) * down_ratio * voxel_size_y + point_cloud_range_y_min;
+  bboxes[box_idx * decode_bboxes_dims + 2] = z;
+  bboxes[box_idx * decode_bboxes_dims + 3] = dim[box_idx];
+  bboxes[box_idx * decode_bboxes_dims + 4] = dim[box_idx + num_bboxes];
+  bboxes[box_idx * decode_bboxes_dims + 5] = dim[box_idx + 2 * num_bboxes];
+  if (with_velocity) {
+    bboxes[box_idx * decode_bboxes_dims + 6] = vel[box_idx];
+    bboxes[box_idx * decode_bboxes_dims + 7] = vel[box_idx + num_bboxes];
+    bboxes[box_idx * decode_bboxes_dims + 8] =
+        atan2f(rot[box_idx], rot[box_idx + num_bboxes]);
+  } else {
+    bboxes[box_idx * decode_bboxes_dims + 6] =
+        atan2f(rot[box_idx], rot[box_idx + num_bboxes]);
+  }
+
+  if (score[box_idx] > score_threshold && x <= post_center_range_x_max &&
+      y <= post_center_range_y_max && z <= post_center_range_z_max &&
+      x >= post_center_range_x_min && y >= post_center_range_y_min &&
+      z >= post_center_range_z_min) {
+    mask[box_idx] = true;
+  }
+
+  score_idx[box_idx] = box_idx;
+}
+
+void DecodeLauncher(
+    const cudaStream_t &stream, const float *score, const float *reg,
+    const float *height, const float *dim, const float *vel, const float *rot,
+    const float score_threshold, const int feat_w, const float down_ratio,
+    const float voxel_size_x, const float voxel_size_y,
+    const float point_cloud_range_x_min, const float point_cloud_range_y_min,
+    const float post_center_range_x_min, const float post_center_range_y_min,
+    const float post_center_range_z_min, const float post_center_range_x_max,
+    const float post_center_range_y_max, const float post_center_range_z_max,
+    const int num_bboxes, const bool with_velocity,
+    const int decode_bboxes_dims, float *bboxes, bool *mask, int *score_idx) {
+  dim3 blocks(DIVUP(num_bboxes, THREADS_PER_BLOCK_NMS));
+  dim3 threads(THREADS_PER_BLOCK_NMS);
+  decode_kernel<<<blocks, threads, 0, stream>>>(
+      score, reg, height, dim, vel, rot, score_threshold, feat_w, down_ratio,
+      voxel_size_x, voxel_size_y, point_cloud_range_x_min,
+      point_cloud_range_y_min, post_center_range_x_min, post_center_range_y_min,
+      post_center_range_z_min, post_center_range_x_max, post_center_range_y_max,
+      post_center_range_z_max, num_bboxes, with_velocity, decode_bboxes_dims,
+      bboxes, mask, score_idx);
+}
+
+std::vector<paddle::Tensor> postprocess_gpu(
+    const std::vector<paddle::Tensor> &hm,
+    const std::vector<paddle::Tensor> &reg,
+    const std::vector<paddle::Tensor> &height,
+    const std::vector<paddle::Tensor> &dim,
+    const std::vector<paddle::Tensor> &vel,
+    const std::vector<paddle::Tensor> &rot,
+    const std::vector<float> &voxel_size,
+    const std::vector<float> &point_cloud_range,
+    const std::vector<float> &post_center_range,
+    const std::vector<int> &num_classes, const int down_ratio,
+    const float score_threshold, const float nms_iou_threshold,
+    const int nms_pre_max_size, const int nms_post_max_size,
+    const bool with_velocity) {
+  int num_tasks = hm.size();
+  int decode_bboxes_dims = 9;
+  if (!with_velocity) {
+    decode_bboxes_dims = 7;
+  }
+  float voxel_size_x = voxel_size[0];
+  float voxel_size_y = voxel_size[1];
+  float point_cloud_range_x_min = point_cloud_range[0];
+  float point_cloud_range_y_min = point_cloud_range[1];
+
+  float post_center_range_x_min = post_center_range[0];
+  float post_center_range_y_min = post_center_range[1];
+  float post_center_range_z_min = post_center_range[2];
+  float post_center_range_x_max = post_center_range[3];
+  float post_center_range_y_max = post_center_range[4];
+  float post_center_range_z_max = post_center_range[5];
+  std::vector<paddle::Tensor> scores;
+  std::vector<paddle::Tensor> labels;
+  std::vector<paddle::Tensor> bboxes;
+  for (int task_id = 0; task_id < num_tasks; ++task_id) {
+    CHECK_INPUT_BATCHSIZE(hm[0]);
+
+    int feat_h = hm[0].shape()[2];
+    int feat_w = hm[0].shape()[3];
+    int num_bboxes = feat_h * feat_w;
+
+    // score and label
+    auto sigmoid_hm_per_task = paddle::experimental::sigmoid(hm[task_id]);
+    auto label_per_task =
+        paddle::experimental::argmax(sigmoid_hm_per_task, 1, true, false, 3);
+    auto score_per_task =
+        paddle::experimental::max(sigmoid_hm_per_task, {1}, true);
+    // dim
+    auto exp_dim_per_task = paddle::experimental::exp(dim[task_id]);
+
+    // decode bboxes and get the mask of bboxes for nms
+    const float *score_ptr = score_per_task.data<float>();
+    const float *reg_ptr = reg[task_id].data<float>();
+    const float *height_ptr = height[task_id].data<float>();
+    // const float* dim_ptr = dim[task_id].data<float>();
+    const float *exp_dim_per_task_ptr = exp_dim_per_task.data<float>();
+    const float *vel_ptr = vel[task_id].data<float>();
+    const float *rot_ptr = rot[task_id].data<float>();
+    auto decode_bboxes =
+        paddle::empty({num_bboxes, decode_bboxes_dims},
+                      paddle::DataType::FLOAT32, paddle::GPUPlace());
+    float *decode_bboxes_ptr = decode_bboxes.data<float>();
+    auto thresh_mask = paddle::full({num_bboxes}, 0, paddle::DataType::BOOL,
+                                    paddle::GPUPlace());
+    bool *thresh_mask_ptr = thresh_mask.data<bool>();
+    auto score_idx = paddle::empty({num_bboxes}, paddle::DataType::INT32,
+                                   paddle::GPUPlace());
+    int *score_idx_ptr = score_idx.data<int32_t>();
+
+    DecodeLauncher(score_per_task.stream(), score_ptr, reg_ptr, height_ptr,
+                   exp_dim_per_task_ptr, vel_ptr, rot_ptr, score_threshold,
+                   feat_w, down_ratio, voxel_size_x, voxel_size_y,
+                   point_cloud_range_x_min, point_cloud_range_y_min,
+                   post_center_range_x_min, post_center_range_y_min,
+                   post_center_range_z_min, post_center_range_x_max,
+                   post_center_range_y_max, post_center_range_z_max, num_bboxes,
+                   with_velocity, decode_bboxes_dims, decode_bboxes_ptr,
+                   thresh_mask_ptr, score_idx_ptr);
+
+    // select score by mask
+    auto selected_score_idx =
+        paddle::experimental::masked_select(score_idx, thresh_mask);
+    auto flattened_selected_score =
+        paddle::experimental::reshape(score_per_task, {num_bboxes});
+    auto selected_score = paddle::experimental::masked_select(
+        flattened_selected_score, thresh_mask);
+    int num_selected = selected_score.numel();
+    if (num_selected == 0 || num_selected < 0) {
+      auto fake_out_boxes =
+          paddle::full({1, decode_bboxes_dims}, 0., paddle::DataType::FLOAT32,
+                       paddle::GPUPlace());
+      auto fake_out_score =
+          paddle::full({1}, -1., paddle::DataType::FLOAT32, paddle::GPUPlace());
+      auto fake_out_label =
+          paddle::full({1}, 0, paddle::DataType::INT64, paddle::GPUPlace());
+      scores.push_back(fake_out_score);
+      labels.push_back(fake_out_label);
+      bboxes.push_back(fake_out_boxes);
+      continue;
+    }
+
+    // sort scores in descending order
+    auto sort_out = paddle::experimental::argsort(selected_score, 0, true);
+    auto sorted_index = std::get<1>(sort_out);
+    int num_bboxes_for_nms =
+        num_selected > nms_pre_max_size ? nms_pre_max_size : num_selected;
+
+    // nms
+    // in NmsLauncher, rot = - theta - pi / 2
+    const int col_blocks = DIVUP(num_bboxes_for_nms, THREADS_PER_BLOCK_NMS);
+    auto nms_mask = paddle::empty({num_bboxes_for_nms * col_blocks},
+                                  paddle::DataType::INT64, paddle::GPUPlace());
+    int64_t *nms_mask_data = nms_mask.data<int64_t>();
+
+    NmsLauncher(score_per_task.stream(), decode_bboxes.data<float>(),
+                selected_score_idx.data<int>(), sorted_index.data<int64_t>(),
+                num_selected, num_bboxes_for_nms, nms_iou_threshold,
+                decode_bboxes_dims, nms_mask_data);
+
+    const paddle::Tensor nms_mask_cpu_tensor =
+        nms_mask.copy_to(paddle::CPUPlace(), true);
+    const int64_t *nms_mask_cpu = nms_mask_cpu_tensor.data<int64_t>();
+
+    auto remv_cpu = paddle::full({col_blocks}, 0, paddle::DataType::INT64,
+                                 paddle::CPUPlace());
+    int64_t *remv_cpu_data = remv_cpu.data<int64_t>();
+    int num_to_keep = 0;
+    auto keep = paddle::empty({num_bboxes_for_nms}, paddle::DataType::INT32,
+                              paddle::CPUPlace());
+    int *keep_data = keep.data<int>();
+
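+    // Greedy NMS on the host: visit boxes in descending-score order and keep box i only
+    // if no previously kept box has set its suppression bit for i.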
+    for (int i = 0; i < num_bboxes_for_nms; i++) {
+      int nblock = i / THREADS_PER_BLOCK_NMS;
+      int inblock = i % THREADS_PER_BLOCK_NMS;
+
+      if (!(remv_cpu_data[nblock] & (1ULL << inblock))) {
+        keep_data[num_to_keep++] = i;
+        const int64_t *p = &nms_mask_cpu[0] + i * col_blocks;
+        for (int j = nblock; j < col_blocks; j++) {
+          remv_cpu_data[j] |= p[j];
+        }
+      }
+    }
+
+    int num_for_gather =
+        num_to_keep > nms_post_max_size ? nms_post_max_size : num_to_keep;
+    auto keep_gpu = paddle::empty({num_for_gather}, paddle::DataType::INT32,
+                                  paddle::GPUPlace());
+    int *keep_gpu_ptr = keep_gpu.data<int>();
+    cudaMemcpy(keep_gpu_ptr, keep_data, num_for_gather * sizeof(int),
+               cudaMemcpyHostToDevice);
+
+    auto gather_sorted_index =
+        paddle::experimental::gather(sorted_index, keep_gpu, 0);
+    auto gather_index = paddle::experimental::gather(selected_score_idx,
+                                                     gather_sorted_index, 0);
+
+    auto gather_score =
+        paddle::experimental::gather(selected_score, gather_sorted_index, 0);
+    auto flattened_label =
+        paddle::experimental::reshape(label_per_task, {num_bboxes});
+    auto gather_label =
+        paddle::experimental::gather(flattened_label, gather_index, 0);
+    auto gather_bbox =
+        paddle::experimental::gather(decode_bboxes, gather_index, 0);
+    auto start_label = paddle::full(
+        {1}, num_classes[task_id], paddle::DataType::INT64, paddle::GPUPlace());
+    auto added_label = paddle::experimental::add(gather_label, start_label);
+    scores.push_back(gather_score);
+    labels.push_back(added_label);
+    bboxes.push_back(gather_bbox);
+  }
+
+  auto out_scores = paddle::experimental::concat(scores, 0);
+  auto out_labels = paddle::experimental::concat(labels, 0);
+  auto out_bboxes = paddle::experimental::concat(bboxes, 0);
+  return {out_bboxes, out_scores, out_labels};
+}

+ 191 - 0
src/detection/centerpoint_paddle/custom_ops/voxelize_op.cc

@@ -0,0 +1,191 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <vector>
+
+#include "paddle/include/experimental/ext_all.h"
+
+template <typename T, typename T_int>
+bool hard_voxelize_cpu_kernel(
+    const T *points, const float point_cloud_range_x_min,
+    const float point_cloud_range_y_min, const float point_cloud_range_z_min,
+    const float voxel_size_x, const float voxel_size_y,
+    const float voxel_size_z, const int grid_size_x, const int grid_size_y,
+    const int grid_size_z, const int64_t num_points, const int num_point_dim,
+    const int max_num_points_in_voxel, const int max_voxels, T *voxels,
+    T_int *coords, T_int *num_points_per_voxel, T_int *grid_idx_to_voxel_idx,
+    T_int *num_voxels) {
+  std::fill(voxels,
+            voxels + max_voxels * max_num_points_in_voxel * num_point_dim,
+            static_cast<T>(0));
+
+  num_voxels[0] = 0;
+  int voxel_idx, grid_idx, curr_num_point;
+  int coord_x, coord_y, coord_z;
+  for (int point_idx = 0; point_idx < num_points; ++point_idx) {
+    coord_x = floor(
+        (points[point_idx * num_point_dim + 0] - point_cloud_range_x_min) /
+        voxel_size_x);
+    coord_y = floor(
+        (points[point_idx * num_point_dim + 1] - point_cloud_range_y_min) /
+        voxel_size_y);
+    coord_z = floor(
+        (points[point_idx * num_point_dim + 2] - point_cloud_range_z_min) /
+        voxel_size_z);
+
+    if (coord_x < 0 || coord_x > grid_size_x || coord_x == grid_size_x) {
+      continue;
+    }
+    if (coord_y < 0 || coord_y > grid_size_y || coord_y == grid_size_y) {
+      continue;
+    }
+    if (coord_z < 0 || coord_z > grid_size_z || coord_z == grid_size_z) {
+      continue;
+    }
+
+    grid_idx =
+        coord_z * grid_size_y * grid_size_x + coord_y * grid_size_x + coord_x;
+    voxel_idx = grid_idx_to_voxel_idx[grid_idx];
+    if (voxel_idx == -1) {
+      voxel_idx = num_voxels[0];
+      if (num_voxels[0] == max_voxels || num_voxels[0] > max_voxels) {
+        continue;
+      }
+      num_voxels[0]++;
+      grid_idx_to_voxel_idx[grid_idx] = voxel_idx;
+      coords[voxel_idx * 3 + 0] = coord_z;
+      coords[voxel_idx * 3 + 1] = coord_y;
+      coords[voxel_idx * 3 + 2] = coord_x;
+    }
+    curr_num_point = num_points_per_voxel[voxel_idx];
+    if (curr_num_point < max_num_points_in_voxel) {
+      for (int j = 0; j < num_point_dim; ++j) {
+        voxels[voxel_idx * max_num_points_in_voxel * num_point_dim +
+               curr_num_point * num_point_dim + j] =
+            points[point_idx * num_point_dim + j];
+      }
+      num_points_per_voxel[voxel_idx] = curr_num_point + 1;
+    }
+  }
+  return true;
+}
+
+std::vector<paddle::Tensor> hard_voxelize_cpu(
+    const paddle::Tensor &points, const std::vector<float> &voxel_size,
+    const std::vector<float> &point_cloud_range,
+    const int max_num_points_in_voxel, const int max_voxels) {
+  auto num_points = points.shape()[0];
+  auto num_point_dim = points.shape()[1];
+
+  const float voxel_size_x = voxel_size[0];
+  const float voxel_size_y = voxel_size[1];
+  const float voxel_size_z = voxel_size[2];
+  const float point_cloud_range_x_min = point_cloud_range[0];
+  const float point_cloud_range_y_min = point_cloud_range[1];
+  const float point_cloud_range_z_min = point_cloud_range[2];
+  int grid_size_x = static_cast<int>(
+      round((point_cloud_range[3] - point_cloud_range[0]) / voxel_size_x));
+  int grid_size_y = static_cast<int>(
+      round((point_cloud_range[4] - point_cloud_range[1]) / voxel_size_y));
+  int grid_size_z = static_cast<int>(
+      round((point_cloud_range[5] - point_cloud_range[2]) / voxel_size_z));
+
+  auto voxels =
+      paddle::empty({max_voxels, max_num_points_in_voxel, num_point_dim},
+                    paddle::DataType::FLOAT32, paddle::CPUPlace());
+
+  auto coords = paddle::full({max_voxels, 3}, 0, paddle::DataType::INT32,
+                             paddle::CPUPlace());
+  auto *coords_data = coords.data<int>();
+
+  auto num_points_per_voxel = paddle::full(
+      {max_voxels}, 0, paddle::DataType::INT32, paddle::CPUPlace());
+  auto *num_points_per_voxel_data = num_points_per_voxel.data<int>();
+  std::fill(num_points_per_voxel_data,
+            num_points_per_voxel_data + num_points_per_voxel.size(),
+            static_cast<int>(0));
+
+  auto num_voxels =
+      paddle::full({1}, 0, paddle::DataType::INT32, paddle::CPUPlace());
+  auto *num_voxels_data = num_voxels.data<int>();
+
+  auto grid_idx_to_voxel_idx =
+      paddle::full({grid_size_z, grid_size_y, grid_size_x}, -1,
+                   paddle::DataType::INT32, paddle::CPUPlace());
+  auto *grid_idx_to_voxel_idx_data = grid_idx_to_voxel_idx.data<int>();
+
+  PD_DISPATCH_FLOATING_TYPES(
+      points.type(), "hard_voxelize_cpu_kernel", ([&] {
+        hard_voxelize_cpu_kernel<data_t, int>(
+            points.data<data_t>(), point_cloud_range_x_min,
+            point_cloud_range_y_min, point_cloud_range_z_min, voxel_size_x,
+            voxel_size_y, voxel_size_z, grid_size_x, grid_size_y, grid_size_z,
+            num_points, num_point_dim, max_num_points_in_voxel, max_voxels,
+            voxels.data<data_t>(), coords_data, num_points_per_voxel_data,
+            grid_idx_to_voxel_idx_data, num_voxels_data);
+      }));
+
+  return {voxels, coords, num_points_per_voxel, num_voxels};
+}
+
+#ifdef PADDLE_WITH_CUDA
+std::vector<paddle::Tensor> hard_voxelize_cuda(
+    const paddle::Tensor &points, const std::vector<float> &voxel_size,
+    const std::vector<float> &point_cloud_range, int max_num_points_in_voxel,
+    int max_voxels);
+#endif
+
+std::vector<paddle::Tensor> hard_voxelize(
+    const paddle::Tensor &points, const std::vector<float> &voxel_size,
+    const std::vector<float> &point_cloud_range,
+    const int max_num_points_in_voxel, const int max_voxels) {
+  if (points.is_cpu()) {
+    return hard_voxelize_cpu(points, voxel_size, point_cloud_range,
+                             max_num_points_in_voxel, max_voxels);
+#ifdef PADDLE_WITH_CUDA
+  } else if (points.is_gpu() || points.is_gpu_pinned()) {
+    return hard_voxelize_cuda(points, voxel_size, point_cloud_range,
+                              max_num_points_in_voxel, max_voxels);
+#endif
+  } else {
+    PD_THROW(
+        "Unsupported device type for hard_voxelize "
+        "operator.");
+  }
+}
+
+std::vector<std::vector<int64_t>> HardInferShape(
+    std::vector<int64_t> points_shape, const std::vector<float> &voxel_size,
+    const std::vector<float> &point_cloud_range,
+    const int &max_num_points_in_voxel, const int &max_voxels) {
+  return {{max_voxels, max_num_points_in_voxel, points_shape[1]},
+          {max_voxels, 3},
+          {max_voxels},
+          {1}};
+}
+
+std::vector<paddle::DataType> HardInferDtype(paddle::DataType points_dtype) {
+  return {points_dtype, paddle::DataType::INT32, paddle::DataType::INT32,
+          paddle::DataType::INT32};
+}
+
+PD_BUILD_OP(hard_voxelize)
+    .Inputs({"POINTS"})
+    .Outputs({"VOXELS", "COORS", "NUM_POINTS_PER_VOXEL", "num_voxels"})
+    .SetKernelFn(PD_KERNEL(hard_voxelize))
+    .Attrs({"voxel_size: std::vector<float>",
+            "point_cloud_range: std::vector<float>",
+            "max_num_points_in_voxel: int", "max_voxels: int"})
+    .SetInferShapeFn(PD_INFER_SHAPE(HardInferShape))
+    .SetInferDtypeFn(PD_INFER_DTYPE(HardInferDtype));

+ 345 - 0
src/detection/centerpoint_paddle/custom_ops/voxelize_op.cu

@@ -0,0 +1,345 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/include/experimental/ext_all.h"
+
+#define CHECK_INPUT_CUDA(x) \
+  PD_CHECK(x.is_gpu() || x.is_gpu_pinned(), #x " must be a GPU Tensor.")
+
+#define CUDA_KERNEL_LOOP(i, n)                                  \
+  for (auto i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
+       i += blockDim.x * gridDim.x)
+
+template <typename T, typename T_int>
+__global__ void init_num_point_grid(
+    const T *points, const float point_cloud_range_x_min,
+    const float point_cloud_range_y_min, const float point_cloud_range_z_min,
+    const float voxel_size_x, const float voxel_size_y,
+    const float voxel_size_z, const int grid_size_x, const int grid_size_y,
+    const int grid_size_z, const int64_t num_points, const int num_point_dim,
+    T_int *num_points_in_grid, int *points_valid) {
+  int64_t point_idx = blockIdx.x * blockDim.x + threadIdx.x;
+  if (point_idx > num_points || point_idx == num_points) {
+    return;
+  }
+  int coord_x =
+      floor((points[point_idx * num_point_dim + 0] - point_cloud_range_x_min) /
+            voxel_size_x);
+  int coord_y =
+      floor((points[point_idx * num_point_dim + 1] - point_cloud_range_y_min) /
+            voxel_size_y);
+  int coord_z =
+      floor((points[point_idx * num_point_dim + 2] - point_cloud_range_z_min) /
+            voxel_size_z);
+
+  if (coord_x < 0 || coord_x > grid_size_x || coord_x == grid_size_x) {
+    return;
+  }
+  if (coord_y < 0 || coord_y > grid_size_y || coord_y == grid_size_y) {
+    return;
+  }
+  if (coord_z < 0 || coord_z > grid_size_z || coord_z == grid_size_z) {
+    return;
+  }
+
+  int grid_idx =
+      coord_z * grid_size_y * grid_size_x + coord_y * grid_size_x + coord_x;
+  num_points_in_grid[grid_idx] = 0;
+  points_valid[grid_idx] = num_points;
+}
+
+template <typename T, typename T_int>
+__global__ void map_point_to_grid_kernel(
+    const T *points, const float point_cloud_range_x_min,
+    const float point_cloud_range_y_min, const float point_cloud_range_z_min,
+    const float voxel_size_x, const float voxel_size_y,
+    const float voxel_size_z, const int grid_size_x, const int grid_size_y,
+    const int grid_size_z, const int64_t num_points, const int num_point_dim,
+    const int max_num_points_in_voxel, T_int *points_to_grid_idx,
+    T_int *points_to_num_idx, T_int *num_points_in_grid, int *points_valid) {
+  int64_t point_idx = blockIdx.x * blockDim.x + threadIdx.x;
+  if (point_idx > num_points || point_idx == num_points) {
+    return;
+  }
+  int coord_x =
+      floor((points[point_idx * num_point_dim + 0] - point_cloud_range_x_min) /
+            voxel_size_x);
+  int coord_y =
+      floor((points[point_idx * num_point_dim + 1] - point_cloud_range_y_min) /
+            voxel_size_y);
+  int coord_z =
+      floor((points[point_idx * num_point_dim + 2] - point_cloud_range_z_min) /
+            voxel_size_z);
+
+  if (coord_x < 0 || coord_x > grid_size_x || coord_x == grid_size_x) {
+    return;
+  }
+  if (coord_y < 0 || coord_y > grid_size_y || coord_y == grid_size_y) {
+    return;
+  }
+  if (coord_z < 0 || coord_z > grid_size_z || coord_z == grid_size_z) {
+    return;
+  }
+
+  int grid_idx =
+      coord_z * grid_size_y * grid_size_x + coord_y * grid_size_x + coord_x;
+  T_int num = atomicAdd(num_points_in_grid + grid_idx, 1);
+  if (num < max_num_points_in_voxel) {
+    points_to_num_idx[point_idx] = num;
+    points_to_grid_idx[point_idx] = grid_idx;
+    atomicMin(points_valid + grid_idx, static_cast<int>(point_idx));
+  }
+}
+
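+// map_point_to_grid_kernel lets the lowest-index point that got a slot in each occupied cell
+// "claim" it via atomicMin on points_valid; update_points_flag then marks exactly that one
+// point per occupied cell.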
+template <typename T_int>
+__global__ void update_points_flag(const int *points_valid,
+                                   const T_int *points_to_grid_idx,
+                                   const int num_points, int *points_flag) {
+  int tid = threadIdx.x + blockIdx.x * blockDim.x;
+  for (int i = tid; i < num_points; i += gridDim.x * blockDim.x) {
+    T_int grid_idx = points_to_grid_idx[i];
+    if (grid_idx >= 0) {
+      int id = points_valid[grid_idx];
+      if (id != num_points && id == i) {
+        points_flag[i] = 1;
+      }
+    }
+  }
+}
+
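+// get_voxel_idx_kernel: use the prefix sum of points_flag to assign a compact
+// voxel index to every occupied grid cell (capped at max_voxels) and write
+// the final voxel count into num_voxels[0].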
+template <typename T_int>
+__global__ void get_voxel_idx_kernel(const int *points_flag,
+                                     const T_int *points_to_grid_idx,
+                                     const int *points_flag_prefix_sum,
+                                     const int num_points, const int max_voxels,
+                                     T_int *num_voxels,
+                                     T_int *grid_idx_to_voxel_idx) {
+  int tid = threadIdx.x + blockIdx.x * blockDim.x;
+  for (int i = tid; i < num_points; i += gridDim.x * blockDim.x) {
+    if (points_flag[i] == 1) {
+      T_int grid_idx = points_to_grid_idx[i];
+      int num = points_flag_prefix_sum[i];
+      if (num < max_voxels) {
+        grid_idx_to_voxel_idx[grid_idx] = num;
+      }
+    }
+    if (i == num_points - 1) {
+      int num = points_flag_prefix_sum[i] + points_flag[i];
+      if (num < max_voxels) {
+        num_voxels[0] = num;
+      } else {
+        num_voxels[0] = max_voxels;
+      }
+    }
+  }
+}
+
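+// init_voxels_kernel: zero-fill the voxels buffer before points are scattered
+// into it.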
+template <typename T>
+__global__ void init_voxels_kernel(const int64_t num, T *voxels) {
+  int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
+  if (idx >= num) {
+    return;
+  }
+  voxels[idx] = static_cast<T>(0);
+}
+
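+// assign_voxels_kernel: copy the features of every kept point into its
+// (voxel_idx, num_idx) slot of the voxels tensor.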
+template <typename T, typename T_int>
+__global__ void assign_voxels_kernel(
+    const T *points, const T_int *points_to_grid_idx,
+    const T_int *points_to_num_idx, const T_int *grid_idx_to_voxel_idx,
+    const int64_t num_points, const int num_point_dim,
+    const int max_num_points_in_voxel, T *voxels) {
+  int64_t point_idx = blockIdx.x * blockDim.x + threadIdx.x;
+  if (point_idx >= num_points) {
+    return;
+  }
+  T_int grid_idx = points_to_grid_idx[point_idx];
+  T_int num_idx = points_to_num_idx[point_idx];
+  if (grid_idx > -1 && num_idx > -1) {
+    T_int voxel_idx = grid_idx_to_voxel_idx[grid_idx];
+    if (voxel_idx > -1) {
+      for (int64_t i = 0; i < num_point_dim; ++i) {
+        voxels[voxel_idx * max_num_points_in_voxel * num_point_dim +
+               num_idx * num_point_dim + i] =
+            points[point_idx * num_point_dim + i];
+      }
+    }
+  }
+}
+
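+// assign_coords_kernel: for every occupied grid cell, decode its (z, y, x)
+// coordinate from the flat grid index and clamp the per-voxel point count to
+// max_num_points_in_voxel.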
+template <typename T, typename T_int>
+__global__ void assign_coords_kernel(const T_int *grid_idx_to_voxel_idx,
+                                     const T_int *num_points_in_grid,
+                                     const int num_grids, const int grid_size_x,
+                                     const int grid_size_y,
+                                     const int grid_size_z,
+                                     const int max_num_points_in_voxel,
+                                     T *coords, T *num_points_per_voxel) {
+  int64_t grid_idx = blockIdx.x * blockDim.x + threadIdx.x;
+  if (grid_idx >= num_grids) {
+    return;
+  }
+  T_int voxel_idx = grid_idx_to_voxel_idx[grid_idx];
+  if (voxel_idx > -1) {
+    T_int coord_z = grid_idx / grid_size_x / grid_size_y;
+    T_int coord_y =
+        (grid_idx - coord_z * grid_size_x * grid_size_y) / grid_size_x;
+    T_int coord_x =
+        grid_idx - coord_z * grid_size_x * grid_size_y - coord_y * grid_size_x;
+    coords[voxel_idx * 3 + 0] = coord_z;
+    coords[voxel_idx * 3 + 1] = coord_y;
+    coords[voxel_idx * 3 + 2] = coord_x;
+    num_points_per_voxel[voxel_idx] =
+        min(num_points_in_grid[grid_idx], max_num_points_in_voxel);
+  }
+}
+
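+// hard_voxelize_cuda: host entry point of the custom op. Given an
+// (num_points, num_point_dim) point tensor it returns
+//   voxels               [max_voxels, max_num_points_in_voxel, num_point_dim]
+//   coords               [max_voxels, 3]  in (z, y, x) order
+//   num_points_per_voxel [max_voxels]
+//   num_voxels           [1]
+// by running the kernels above in four stages.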
+std::vector<paddle::Tensor> hard_voxelize_cuda(
+    const paddle::Tensor &points, const std::vector<float> &voxel_size,
+    const std::vector<float> &point_cloud_range, int max_num_points_in_voxel,
+    int max_voxels) {
+  // check device
+  CHECK_INPUT_CUDA(points);
+
+  int64_t num_points = points.shape()[0];
+  int64_t num_point_dim = points.shape()[1];
+
+  const float voxel_size_x = voxel_size[0];
+  const float voxel_size_y = voxel_size[1];
+  const float voxel_size_z = voxel_size[2];
+  const float point_cloud_range_x_min = point_cloud_range[0];
+  const float point_cloud_range_y_min = point_cloud_range[1];
+  const float point_cloud_range_z_min = point_cloud_range[2];
+  int grid_size_x = static_cast<int>(
+      round((point_cloud_range[3] - point_cloud_range[0]) / voxel_size_x));
+  int grid_size_y = static_cast<int>(
+      round((point_cloud_range[4] - point_cloud_range[1]) / voxel_size_y));
+  int grid_size_z = static_cast<int>(
+      round((point_cloud_range[5] - point_cloud_range[2]) / voxel_size_z));
+  int num_grids = grid_size_x * grid_size_y * grid_size_z;
+
+  auto voxels =
+      paddle::empty({max_voxels, max_num_points_in_voxel, num_point_dim},
+                    paddle::DataType::FLOAT32, paddle::GPUPlace());
+
+  auto coords = paddle::full({max_voxels, 3}, 0, paddle::DataType::INT32,
+                             paddle::GPUPlace());
+  auto *coords_data = coords.data<int>();
+
+  auto num_points_per_voxel = paddle::full(
+      {max_voxels}, 0, paddle::DataType::INT32, paddle::GPUPlace());
+  auto *num_points_per_voxel_data = num_points_per_voxel.data<int>();
+
+  auto points_to_grid_idx = paddle::full(
+      {num_points}, -1, paddle::DataType::INT32, paddle::GPUPlace());
+  auto *points_to_grid_idx_data = points_to_grid_idx.data<int>();
+
+  auto points_to_num_idx = paddle::full(
+      {num_points}, -1, paddle::DataType::INT32, paddle::GPUPlace());
+  auto *points_to_num_idx_data = points_to_num_idx.data<int>();
+
+  auto num_points_in_grid =
+      paddle::empty({grid_size_z, grid_size_y, grid_size_x},
+                    paddle::DataType::INT32, paddle::GPUPlace());
+  auto *num_points_in_grid_data = num_points_in_grid.data<int>();
+
+  auto grid_idx_to_voxel_idx =
+      paddle::full({grid_size_z, grid_size_y, grid_size_x}, -1,
+                   paddle::DataType::INT32, paddle::GPUPlace());
+  auto *grid_idx_to_voxel_idx_data = grid_idx_to_voxel_idx.data<int>();
+
+  auto num_voxels =
+      paddle::full({1}, 0, paddle::DataType::INT32, paddle::GPUPlace());
+  auto *num_voxels_data = num_voxels.data<int>();
+
+  auto points_valid =
+      paddle::empty({grid_size_z, grid_size_y, grid_size_x},
+                    paddle::DataType::INT32, paddle::GPUPlace());
+  int *points_valid_data = points_valid.data<int>();
+  auto points_flag = paddle::full({num_points}, 0, paddle::DataType::INT32,
+                                  paddle::GPUPlace());
+
+  // 1. Find the grid index for each point, compute the
+  // number of points in each grid
+  int64_t threads = 512;
+  int64_t blocks = (num_points + threads - 1) / threads;
+
+  PD_DISPATCH_FLOATING_TYPES(
+      points.type(), "init_num_point_grid", ([&] {
+        init_num_point_grid<data_t, int>
+            <<<blocks, threads, 0, points.stream()>>>(
+                points.data<data_t>(), point_cloud_range_x_min,
+                point_cloud_range_y_min, point_cloud_range_z_min, voxel_size_x,
+                voxel_size_y, voxel_size_z, grid_size_x, grid_size_y,
+                grid_size_z, num_points, num_point_dim, num_points_in_grid_data,
+                points_valid_data);
+      }));
+
+  PD_DISPATCH_FLOATING_TYPES(
+      points.type(), "map_point_to_grid_kernel", ([&] {
+        map_point_to_grid_kernel<data_t, int>
+            <<<blocks, threads, 0, points.stream()>>>(
+                points.data<data_t>(), point_cloud_range_x_min,
+                point_cloud_range_y_min, point_cloud_range_z_min, voxel_size_x,
+                voxel_size_y, voxel_size_z, grid_size_x, grid_size_y,
+                grid_size_z, num_points, num_point_dim, max_num_points_in_voxel,
+                points_to_grid_idx_data, points_to_num_idx_data,
+                num_points_in_grid_data, points_valid_data);
+      }));
+
+  // 2. Find the number of non-zero voxels
+  int *points_flag_data = points_flag.data<int>();
+
+  threads = 512;
+  blocks = (num_points + threads - 1) / threads;
+  update_points_flag<int><<<blocks, threads, 0, points.stream()>>>(
+      points_valid_data, points_to_grid_idx_data, num_points, points_flag_data);
+
+  auto points_flag_prefix_sum =
+      paddle::experimental::cumsum(points_flag, 0, false, true, false);
+  int *points_flag_prefix_sum_data = points_flag_prefix_sum.data<int>();
+  get_voxel_idx_kernel<int><<<blocks, threads, 0, points.stream()>>>(
+      points_flag_data, points_to_grid_idx_data, points_flag_prefix_sum_data,
+      num_points, max_voxels, num_voxels_data, grid_idx_to_voxel_idx_data);
+
+  // 3. Store points to voxels coords and num_points_per_voxel
+  int64_t num = max_voxels * max_num_points_in_voxel * num_point_dim;
+  threads = 512;
+  blocks = (num + threads - 1) / threads;
+  PD_DISPATCH_FLOATING_TYPES(points.type(), "init_voxels_kernel", ([&] {
+                               init_voxels_kernel<data_t>
+                                   <<<blocks, threads, 0, points.stream()>>>(
+                                       num, voxels.data<data_t>());
+                             }));
+
+  threads = 512;
+  blocks = (num_points + threads - 1) / threads;
+  PD_DISPATCH_FLOATING_TYPES(
+      points.type(), "assign_voxels_kernel", ([&] {
+        assign_voxels_kernel<data_t, int>
+            <<<blocks, threads, 0, points.stream()>>>(
+                points.data<data_t>(), points_to_grid_idx_data,
+                points_to_num_idx_data, grid_idx_to_voxel_idx_data, num_points,
+                num_point_dim, max_num_points_in_voxel, voxels.data<data_t>());
+      }));
+
+  // 4. Store coords, num_points_per_voxel
+  blocks = (num_grids + threads - 1) / threads;
+  assign_coords_kernel<int><<<blocks, threads, 0, points.stream()>>>(
+      grid_idx_to_voxel_idx_data, num_points_in_grid_data, num_grids,
+      grid_size_x, grid_size_y, grid_size_z, max_num_points_in_voxel,
+      coords_data, num_points_per_voxel_data);
+
+  return {voxels, coords, num_points_per_voxel, num_voxels};
+}

+ 1 - 0
src/detection/centerpoint_paddle/demo.sh

@@ -0,0 +1 @@
+./build/main --model_file /home/nvidia/centerpoint/cpp/model_0.4/centerpoint.pdmodel --params_file /home/nvidia/centerpoint/cpp/model_0.4/centerpoint.pdiparams --lidar_file /home/nvidia/centerpoint/cpp/bindata/000103.bin --num_point_dim 4

+ 1 - 0
src/detection/centerpoint_paddle/demo_trt.sh

@@ -0,0 +1 @@
+./build/main --model_file /home/nvidia/centerpoint/cpp/model_0.5/centerpoint.pdmodel --params_file /home/nvidia/centerpoint/cpp/model_0.5/centerpoint.pdiparams --lidar_file /home/nvidia/centerpoint/cpp/bindata/000103.bin --num_point_dim 4 --use_trt 1 --dynamic_shape_file /home/nvidia/centerpoint/cpp/model_0.5/shape_info.txt --trt_precision 1 --trt_use_static 1 --trt_static_dir /home/nvidia/centerpoint/cpp/model_0.5 

+ 796 - 0
src/detection/centerpoint_paddle/main.cc

@@ -0,0 +1,796 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gflags/gflags.h>
+#include <glog/logging.h>
+
+#include <chrono>
+#include <cmath>
+#include <fstream>
+#include <iostream>
+#include <numeric>
+#include <string>
+#include <dlfcn.h>
+
+#include "paddle/include/paddle_inference_api.h"
+
+
+#include "modulecomm.h"
+#include "objectarray.pb.h"
+#include <thread>
+#include <pcl/point_cloud.h>
+#include <pcl/point_types.h>
+#include <pcl/io/io.h>
+#include <pcl/io/pcd_io.h>
+#include <chrono>
+#include <opencv2/opencv.hpp>
+
+//////////////////// NMS helpers ////////////////////
+#ifndef M_PI
+#define M_PI       3.14159265358979323846
+#endif
+#define rad2Angle(rad) ((rad) * 180.0 / M_PI)
+
+
+float score_threshold = 0.4;
+float smallinbig_threshold = 0.8;
+float distance_threshold = 0.2;
+float second_nms_threshold = 0.5;
+
+std::string config_file = "./yaml/centerpoint_paddle.yaml";
+
+using paddle_infer::Config;
+using paddle_infer::CreatePredictor;
+using paddle_infer::Predictor;
+
+void * gpa;
+std::string lidarname = "lidar_pc";
+void * gpdetect;
+//std::string detectname = "lidar_track";
+std::string detectname = "lidar_pointpillar";
+
+int gnothavedatatime = 0;
+std::thread * gpthread;
+
+std::shared_ptr<paddle_infer::Predictor> predictor;
+
+static const char *labels[] = {"smallMot","bigMot","nonMot","pedestrian","trafficCone"};
+
+
+DEFINE_string(model_file, "./model/centerpoint.pdmodel", "Path of the inference model");
+DEFINE_string(params_file, "./model/centerpoint.pdiparams", "Path of the inference params file");
+DEFINE_string(pd_infer_custom_op,"./lib/libpd_infer_custom_op.so","Path of the libpd_infer_custom_op.so");
+DEFINE_string(lidar_file, "", "Path of a lidar file to be predicted");
+DEFINE_int32(num_point_dim, 4, "Dimension of a point in the lidar file");
+DEFINE_int32(with_timelag, 0,
+             "Whether timelag is the 5th dimension of each point feature, "
+             "like: x, y, z, intensity, timelag");
+DEFINE_int32(gpu_id, 0, "GPU card id");
+DEFINE_int32(use_trt, 0,
+             "Whether to use tensorrt to accelerate when using gpu");
+DEFINE_int32(trt_precision, 0,
+             "Precision type of tensorrt, 0: kFloat32, 1: kHalf");
+DEFINE_int32(
+    trt_use_static, 0,
+    "Whether to load the tensorrt graph optimization from a disk path");
+DEFINE_string(trt_static_dir, "",
+              "Path of a tensorrt graph optimization directory");
+DEFINE_int32(collect_shape_info, 0,
+             "Whether to collect dynamic shape before using tensorrt");
+DEFINE_string(dynamic_shape_file, "",
+              "Path of a dynamic shape file for tensorrt");
+
+
+///////////// NMS helpers for removing overlapping / duplicate boxes /////////////
+struct Box{
+    float x;
+    float y;
+    float z;
+    float l;
+    float h;
+    float w;
+    float theta;
+
+    float score;
+    int cls;
+    bool isDrop; // for nms
+};
+typedef struct {
+    cv::RotatedRect box;
+    Box detection;
+    int label;
+    float score;
+}BBOX3D;
+
+
+// Convert detection results to cv::RotatedRect so rotated-box NMS can be applied
+bool GetRotatedRect(const std::vector<float> &box3d_lidar,
+                        const std::vector<int64_t> &label_preds,
+                        const std::vector<float> &scores,
+                        std::vector<BBOX3D> &results)
+{
+    int num_bbox3d = scores.size();
+    if(num_bbox3d>0)
+    {
+        for(int box_idx=0;box_idx<num_bbox3d;box_idx++)
+        {
+
+            if(scores[box_idx] < score_threshold)
+            {
+                //std::cout<<" the score < threshold"<<std::endl;
+                continue;
+            }
+
+            Box parse_bbox;
+            parse_bbox.x = box3d_lidar[box_idx * 7 + 0];
+            parse_bbox.y = box3d_lidar[box_idx * 7 + 1];
+            parse_bbox.z = box3d_lidar[box_idx * 7 + 2];
+            parse_bbox.l = box3d_lidar[box_idx * 7 + 3];
+            parse_bbox.w = box3d_lidar[box_idx * 7 + 4];
+            parse_bbox.h = box3d_lidar[box_idx * 7 + 5];
+            parse_bbox.theta = box3d_lidar[box_idx * 7 + 6];
+            parse_bbox.score = scores[box_idx];
+            parse_bbox.cls = label_preds[box_idx];
+
+            BBOX3D result;
+            result.box = cv::RotatedRect(cv::Point2f(parse_bbox.x,parse_bbox.y),
+                                         cv::Size2f(parse_bbox.l,parse_bbox.w),
+                                         rad2Angle(parse_bbox.theta));
+
+            result.detection = parse_bbox;
+            result.label = parse_bbox.cls;
+            result.score = parse_bbox.score;
+            results.push_back(result);
+        }
+        return true;
+    }
+    else{
+        std::cout<<"The out_detections size == 0 "<<std::endl;
+        return false;
+    }
+}
+
+
+bool sort_score(const BBOX3D &box1, const BBOX3D &box2)
+{
+    return (box1.score > box2.score);
+}
+
+// Compute the IoU of two rotated rectangles
+float calcIOU(cv::RotatedRect rect1, cv::RotatedRect rect2)
+{
+    float areaRect1 = rect1.size.width * rect1.size.height;
+    float areaRect2 = rect2.size.width * rect2.size.height;
+    std::vector<cv::Point2f> vertices;
+    int intersectionType = cv::rotatedRectangleIntersection(rect1, rect2, vertices);
+    if (vertices.size()==0)
+        return 0.0;
+    else{
+        std::vector<cv::Point2f> order_pts;
+        cv::convexHull(cv::Mat(vertices), order_pts, true);
+        double area = cv::contourArea(order_pts);
+        float inner = (float) (area / (areaRect1 + areaRect2 - area + 0.0001));
+        // Handle the case where a small box lies completely inside a bigger one
+        float areaMin = (areaRect1 < areaRect2)?areaRect1:areaRect2;
+        float innerMin = (float)(area / (areaMin + 0.0001));
+        if(innerMin > smallinbig_threshold)
+            inner = innerMin;
+        return inner;
+    }
+}
+// Compute the Euclidean distance between two center points
+float calcdistance(cv::Point2f center1, cv::Point2f center2)
+{
+    float distance = sqrt((center1.x-center2.x)*(center1.x-center2.x)+
+                          (center1.y-center2.y)*(center1.y-center2.y));
+    return distance;
+}
+
+
+// NMS: greedily keep the highest-scoring box and drop boxes that overlap it
+// too much or whose centers are too close
+void nms(std::vector<BBOX3D> &vec_boxs,float threshold,std::vector<BBOX3D> &results)
+{
+    std::sort(vec_boxs.begin(),vec_boxs.end(),sort_score);
+    while(vec_boxs.size() > 0)
+    {
+        results.push_back(vec_boxs[0]);
+        vec_boxs.erase(vec_boxs.begin());
+        for (auto it = vec_boxs.begin(); it != vec_boxs.end();)
+        {
+            float iou_value =calcIOU(results.back().box,(*it).box);
+            float distance_value = calcdistance(results.back().box.center,(*it).box.center);
+            if ((iou_value > threshold) || (distance_value<distance_threshold))
+                it = vec_boxs.erase(it);
+            else it++;
+        }
+
+//        std::cout<<"results: "<<results.back().detection.at(0)<<" "<<results.back().detection.at(1)<<
+//                   " "<<results.back().detection.at(2)<<std::endl;
+
+    }
+}
+///////////// end of NMS helpers /////////////
+
+// Convert lidar points to the Waymo-style layout: x --> forward, y --> left
+void PclXYZITToArray(
+        const pcl::PointCloud<pcl::PointXYZI>::Ptr& in_pcl_pc_ptr,
+        float* out_points_array, const float normalizing_factor) {
+    for (size_t i = 0; i < in_pcl_pc_ptr->size(); ++i) {
+        pcl::PointXYZI point = in_pcl_pc_ptr->at(i);
+        out_points_array[i * 4 + 0] = point.x;
+        out_points_array[i * 4 + 1] = point.y;
+        out_points_array[i * 4 + 2] = point.z;
+        out_points_array[i * 4 + 3] = static_cast<float>(point.intensity / normalizing_factor);
+        //out_points_array[i * 5 + 4] = 0;
+
+        //std::cout<<"the intensity = "<< out_points_array[i * 5 + 3]<< std::endl;
+    }
+}
+
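+// preprocess_new: flatten a pcl::PointCloud<pcl::PointXYZI> into the
+// (num_points, num_point_dim) float layout expected by the model, with
+// intensity normalised by 255.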
+bool preprocess_new(const pcl::PointCloud<pcl::PointXYZI>::Ptr &pc_ptr, const int num_point_dim,
+                std::vector<int> *points_shape,
+                std::vector<float> *points_data)
+{
+  int num_points;
+  // Use an array deleter so the buffer is released with delete[]
+  std::shared_ptr<float> points_array_ptr(new float[pc_ptr->size() * num_point_dim], std::default_delete<float[]>());
+  PclXYZITToArray(pc_ptr, points_array_ptr.get(), 255.0);
+  num_points = pc_ptr->width;
+
+  if(num_points <= 0)
+      return false;
+
+  float *points = points_array_ptr.get();
+  points_data->assign(points, points + num_points * num_point_dim);
+  points_shape->push_back(num_points);
+  points_shape->push_back(num_point_dim);
+
+  //free(points);
+  return true;
+}
+
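+// read_point: load a raw .bin point file into a newly allocated buffer and
+// report how many points (rows of num_point_dim floats) it contains.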
+bool read_point(const std::string &file_path, const int num_point_dim,
+                void **buffer, int *num_points) {
+  std::ifstream file_in(file_path, std::ios::in | std::ios::binary);
+  if (num_point_dim < 4) {
+    LOG(ERROR) << "Point dimension must not be less than 4, but received "
+               << "num_point_dim is " << num_point_dim << ".\n";
+    return false;
+  }
+
+  if (!file_in) {
+    LOG(ERROR) << "Failed to read file: " << file_path << "\n";
+    return false;
+  }
+
+  std::streampos file_size;
+  file_in.seekg(0, std::ios::end);
+  file_size = file_in.tellg();
+  file_in.seekg(0, std::ios::beg);
+
+  *buffer = malloc(file_size);
+  if (*buffer == nullptr) {
+    LOG(ERROR) << "Failed to malloc memory of size: " << file_size << "\n";
+    return false;
+  }
+  file_in.read(reinterpret_cast<char *>(*buffer), file_size);
+  file_in.close();
+
+  if (file_size / sizeof(float) % num_point_dim != 0) {
+    LOG(ERROR) << "Loaded file size (" << file_size
+               << ") is not evenly divisible by num_point_dim ("
+               << num_point_dim << ")\n";
+    return false;
+  }
+  *num_points = file_size / sizeof(float) / num_point_dim;
+  return true;
+}
+
+
+bool insert_time_to_points(const int num_points, const int num_point_dim,
+                           float *points) {
+  for (int i = 0; i < num_points; ++i) {
+    *(points + i * num_point_dim + 4) = 0.;
+  }
+  return true;
+}
+
+bool preprocess(const std::string &file_path, const int num_point_dim,
+                const int with_timelag, std::vector<int> *points_shape,
+                std::vector<float> *points_data) {
+  void *buffer = nullptr;
+  int num_points;
+  if (!read_point(file_path, num_point_dim, &buffer, &num_points)) {
+    return false;
+  }
+  float *points = static_cast<float *>(buffer);
+
+  if ((!with_timelag && num_point_dim == 5) || num_point_dim > 5) {
+    // The original point layout is [x, y, z, intensity, ring_index],
+    // but the model expects [x, y, z, intensity, timelag], so the 5th
+    // channel is overwritten with a zero sweep-time index in
+    // insert_time_to_points.
+    insert_time_to_points(num_points, num_point_dim, points);
+  }
+
+  points_data->assign(points, points + num_points * num_point_dim);
+  points_shape->push_back(num_points);
+  points_shape->push_back(num_point_dim);
+
+  free(points);
+  return true;
+}
+
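+// create_predictor: build a Paddle Inference predictor on the given GPU,
+// optionally with the TensorRT backend (FP32 or FP16), tuned dynamic shapes
+// loaded from dynamic_shape_file, and an optimisation cache in trt_static_dir.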
+std::shared_ptr<paddle_infer::Predictor> create_predictor(
+    const std::string &model_path, const std::string &params_path,
+    const int gpu_id, const int use_trt, const int trt_precision,
+    const int trt_use_static, const std::string trt_static_dir,
+    const int collect_shape_info, const std::string dynamic_shape_file) {
+  paddle::AnalysisConfig config;
+  config.EnableUseGpu(1000, gpu_id);
+  config.SetModel(model_path, params_path);
+  if (use_trt) {
+    paddle::AnalysisConfig::Precision precision;
+    if (trt_precision == 0) {
+      precision = paddle_infer::PrecisionType::kFloat32;
+    } else if (trt_precision == 1) {
+      precision = paddle_infer::PrecisionType::kHalf;
+    } else {
+      LOG(ERROR) << "Tensorrt type can only support 0 or 1, but received is"
+                 << trt_precision << "\n";
+      return nullptr;
+    }
+    config.EnableTensorRtEngine(1 << 30, 1, 3, precision, trt_use_static,
+                                false);
+
+    if (dynamic_shape_file == "") {
+      LOG(ERROR) << "dynamic_shape_file should be set, but received is "
+                 << dynamic_shape_file << "\n";
+      return nullptr;
+    }
+    if (collect_shape_info) {
+      config.CollectShapeRangeInfo(dynamic_shape_file);
+    } else {
+      config.EnableTunedTensorRtDynamicShape(dynamic_shape_file, true);
+    }
+
+    if (trt_use_static) {
+      if (trt_static_dir == "") {
+        LOG(ERROR) << "trt_static_dir should be set, but received is "
+                   << trt_static_dir << "\n";
+        return nullptr;
+      }
+      config.SetOptimCacheDir(trt_static_dir);
+    }
+  }
+  config.SwitchIrOptim(true);
+  return paddle_infer::CreatePredictor(config);
+}
+
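+// run: copy the preprocessed points into the "data" input tensor, execute the
+// model and copy the three outputs (boxes, labels, scores) back to the CPU.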
+void run(Predictor *predictor, const std::vector<int> &points_shape,
+         const std::vector<float> &points_data, std::vector<float> *box3d_lidar,
+         std::vector<int64_t> *label_preds, std::vector<float> *scores) {
+  auto input_names = predictor->GetInputNames();
+  for (const auto &tensor_name : input_names) {
+    auto in_tensor = predictor->GetInputHandle(tensor_name);
+    if (tensor_name == "data") {
+      in_tensor->Reshape(points_shape);
+      in_tensor->CopyFromCpu(points_data.data());
+    }
+  }
+
+  CHECK(predictor->Run());
+
+  auto output_names = predictor->GetOutputNames();
+  for (size_t i = 0; i != output_names.size(); i++) {
+    auto output = predictor->GetOutputHandle(output_names[i]);
+    std::vector<int> output_shape = output->shape();
+    int out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1,
+                                  std::multiplies<int>());
+    if (i == 0) {
+      box3d_lidar->resize(out_num);
+      output->CopyToCpu(box3d_lidar->data());
+    } else if (i == 1) {
+      label_preds->resize(out_num);
+      output->CopyToCpu(label_preds->data());
+    } else if (i == 2) {
+      scores->resize(out_num);
+      output->CopyToCpu(scores->data());
+    }
+  }
+}
+
+bool parse_result(const std::vector<float> &box3d_lidar,
+                  const std::vector<int64_t> &label_preds,
+                  const std::vector<float> &scores) {
+  int num_bbox3d = scores.size();
+  if (num_bbox3d == 0) {
+    return false;
+  }
+  int bbox3d_dims = box3d_lidar.size() / num_bbox3d;
+  for (int box_idx = 0; box_idx != num_bbox3d; ++box_idx) {
+    // filter fake results:  score = -1
+    if (scores[box_idx] < 0) {
+      continue;
+    }
+    LOG(INFO) << "Score: " << scores[box_idx]
+              << " Label: " << label_preds[box_idx] << " ";
+    if (bbox3d_dims == 9) {
+      LOG(INFO) << "Box (x_c, y_c, z_c, w, l, h, vec_x, vec_y, -rot): "
+                << box3d_lidar[box_idx * 9 + 0] << " "
+                << box3d_lidar[box_idx * 9 + 1] << " "
+                << box3d_lidar[box_idx * 9 + 2] << " "
+                << box3d_lidar[box_idx * 9 + 3] << " "
+                << box3d_lidar[box_idx * 9 + 4] << " "
+                << box3d_lidar[box_idx * 9 + 5] << " "
+                << box3d_lidar[box_idx * 9 + 6] << " "
+                << box3d_lidar[box_idx * 9 + 7] << " "
+                << box3d_lidar[box_idx * 9 + 8] << "\n";
+    } else if (bbox3d_dims == 7) {
+      LOG(INFO) << "Box (x_c, y_c, z_c, w, l, h, -rot): "
+                << box3d_lidar[box_idx * 7 + 0] << " "
+                << box3d_lidar[box_idx * 7 + 1] << " "
+                << box3d_lidar[box_idx * 7 + 2] << " "
+                << box3d_lidar[box_idx * 7 + 3] << " "
+                << box3d_lidar[box_idx * 7 + 4] << " "
+                << box3d_lidar[box_idx * 7 + 5] << " "
+                << box3d_lidar[box_idx * 7 + 6] << "\n";
+    }
+  }
+
+  return true;
+}
+
+void GetLidarObj(const std::vector<float> &box3d_lidar,
+                     const std::vector<int64_t> &label_preds,
+                     const std::vector<float> &scores,
+                     iv::lidar::objectarray & lidarobjvec)
+{
+    int num_bbox3d = scores.size();
+    if (num_bbox3d == 0) {
+      return;
+    }
+    for (int box_idx = 0; box_idx != num_bbox3d; ++box_idx)
+    {
+      // filter fake results:  score = -1
+      if (scores[box_idx] < 0) {
+        continue;
+      }
+
+      iv::lidar::lidarobject lidarobj;
+      //std::cout<<" The scores = "<<result.score<<std::endl;
+
+      lidarobj.set_tyaw( -box3d_lidar[box_idx * 7 + 6]);
+
+      //std::cout<<" The theta = "<<result.theta<<std::endl;
+      //std::cout<<" The xyz = "<<result.cls<<" "<<result.w<<" "<<result.l<<" "<<result.h<<std::endl;
+
+      std::cout<<"obstacle id is: "<<box_idx<<std::endl;
+      std::cout<<"obstacle score is: "<<scores[box_idx]<<std::endl;
+      std::cout<<"(x,y,z,dx,dy,dz,yaw,class)=("<<box3d_lidar[box_idx * 7 + 0]<<","
+              <<box3d_lidar[box_idx * 7 + 1]<<","
+              <<box3d_lidar[box_idx * 7 + 2]<<","
+              <<box3d_lidar[box_idx * 7 + 3]<<","
+             <<box3d_lidar[box_idx * 7 + 4]<<","
+            <<box3d_lidar[box_idx * 7 + 5]<<","
+             <<box3d_lidar[box_idx * 7 + 6]<<","
+            <<label_preds[box_idx]<<")"<<std::endl;
+
+      //givlog->verbose("obstacle id is: %d",idx);
+      //givlog->verbose("(x,y,z,dx,dy,dz,yaw,class)=(%f,%f,%f,%f,%f,%f,%f,%s)",-result.y,result.x,result.z,
+                      //result.l,result.w,result.h,-result.theta,labels[result.cls]);
+
+
+      iv::lidar::PointXYZ centroid;
+      iv::lidar::PointXYZ * _centroid;
+      centroid.set_x(box3d_lidar[box_idx * 7 + 0]);
+      centroid.set_y(box3d_lidar[box_idx * 7 + 1]);
+      centroid.set_z(box3d_lidar[box_idx * 7 + 2]);
+      _centroid = lidarobj.mutable_centroid();
+      _centroid->CopyFrom(centroid);
+
+      iv::lidar::PointXYZ min_point;
+      iv::lidar::PointXYZ * _min_point;
+      min_point.set_x(0);
+      min_point.set_y(0);
+      min_point.set_z(0);
+      _min_point = lidarobj.mutable_min_point();
+      _min_point->CopyFrom(min_point);
+
+      iv::lidar::PointXYZ max_point;
+      iv::lidar::PointXYZ * _max_point;
+      max_point.set_x(0);
+      max_point.set_y(0);
+      max_point.set_z(0);
+      _max_point = lidarobj.mutable_max_point();
+      _max_point->CopyFrom(max_point);
+
+      iv::lidar::PointXYZ position;
+      iv::lidar::PointXYZ * _position;
+      position.set_x(box3d_lidar[box_idx * 7 + 0]);
+      position.set_y(box3d_lidar[box_idx * 7 + 1]);
+      position.set_z(box3d_lidar[box_idx * 7 + 2]);
+      _position = lidarobj.mutable_position();
+      _position->CopyFrom(position);
+
+      lidarobj.set_mntype(label_preds[box_idx]);
+
+      lidarobj.set_score(scores[box_idx]);
+      lidarobj.add_type_probs(scores[box_idx]);
+
+      iv::lidar::PointXYZI point_cloud;
+      iv::lidar::PointXYZI * _point_cloud;
+
+      point_cloud.set_x(box3d_lidar[box_idx * 7 + 0]);
+      point_cloud.set_y(box3d_lidar[box_idx * 7 + 1]);
+      point_cloud.set_z(box3d_lidar[box_idx * 7 + 2]);
+
+      point_cloud.set_i(0);
+
+      _point_cloud = lidarobj.add_cloud();
+      _point_cloud->CopyFrom(point_cloud);
+
+      iv::lidar::Dimension ld;
+      iv::lidar::Dimension * pld;
+
+
+      ld.set_x(box3d_lidar[box_idx * 7 + 3]);// w
+      ld.set_y(box3d_lidar[box_idx * 7 + 4]);// l
+      ld.set_z(box3d_lidar[box_idx * 7 + 5]);// h
+
+
+      pld = lidarobj.mutable_dimensions();
+      pld->CopyFrom(ld);
+      iv::lidar::lidarobject * po = lidarobjvec.add_obj();
+      po->CopyFrom(lidarobj);
+
+    }
+
+
+}
+
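+// GetLidarObj_nms: convert the NMS-filtered boxes into the iv::lidar
+// objectarray protobuf consumed by downstream modules.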
+void GetLidarObj_nms(std::vector<BBOX3D> &predResult,iv::lidar::objectarray & lidarobjvec)
+{
+    //    givlog->verbose("OBJ","object size is %d",obj_size);
+    for(size_t idx = 0; idx < predResult.size(); idx++)
+    {
+        iv::lidar::lidarobject lidarobj;
+        Box result = predResult[idx].detection;
+        lidarobj.set_tyaw(-result.theta);
+
+        std::cout<<"obstacle id is: "<<idx<<std::endl;
+        std::cout<<"obstacle score is: "<<result.score<<std::endl;
+        std::cout<<"obstacle class is: "<<labels[result.cls]<<std::endl;
+        std::cout<<"(x,y,z,dx,dy,dz,yaw)=("<<result.x<<","<<result.y<<","<<result.z<<","
+                <<result.l<<","<<result.w<<","<<result.h<<","<<-result.theta<<")"<<std::endl;
+
+//        givlog->verbose("obstacle id is: %d",idx);
+//        givlog->verbose("(x,y,z,dx,dy,dz,yaw,classs)=(%f,%f,%f,%f,%f,%f,%f,%s)",result.x,result.y,result.z,
+//                        result.l,result.w,result.h,-result.theta,labels[result.cls]);
+
+
+        iv::lidar::PointXYZ centroid;
+        iv::lidar::PointXYZ * _centroid;
+        centroid.set_x(result.x);
+        centroid.set_y(result.y);
+        centroid.set_z(result.z);
+        _centroid = lidarobj.mutable_centroid();
+        _centroid->CopyFrom(centroid);
+
+        iv::lidar::PointXYZ min_point;
+        iv::lidar::PointXYZ * _min_point;
+        min_point.set_x(0);
+        min_point.set_y(0);
+        min_point.set_z(0);
+        _min_point = lidarobj.mutable_min_point();
+        _min_point->CopyFrom(min_point);
+
+        iv::lidar::PointXYZ max_point;
+        iv::lidar::PointXYZ * _max_point;
+        max_point.set_x(0);
+        max_point.set_y(0);
+        max_point.set_z(0);
+        _max_point = lidarobj.mutable_max_point();
+        _max_point->CopyFrom(max_point);
+
+        iv::lidar::PointXYZ position;
+        iv::lidar::PointXYZ * _position;
+        position.set_x(result.x);
+        position.set_y(result.y);
+        position.set_z(result.z);
+        _position = lidarobj.mutable_position();
+        _position->CopyFrom(position);
+
+        lidarobj.set_mntype(result.cls);
+
+        lidarobj.set_score(result.score);
+        lidarobj.add_type_probs(result.score);
+
+        iv::lidar::PointXYZI point_cloud;
+        iv::lidar::PointXYZI * _point_cloud;
+
+        point_cloud.set_x(result.x);
+        point_cloud.set_y(result.y);
+        point_cloud.set_z(result.z);
+
+        point_cloud.set_i(0);
+
+        _point_cloud = lidarobj.add_cloud();
+        _point_cloud->CopyFrom(point_cloud);
+
+        iv::lidar::Dimension ld;
+        iv::lidar::Dimension * pld;
+
+
+        ld.set_x(result.l);// w
+        ld.set_y(result.w);// l
+        ld.set_z(result.h);// h
+
+        pld = lidarobj.mutable_dimensions();
+        pld->CopyFrom(ld);
+        iv::lidar::lidarobject * po = lidarobjvec.add_obj();
+        po->CopyFrom(lidarobj);
+    }
+
+}
+
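+// ListenPointCloud: modulecomm callback for every shared-memory point-cloud
+// frame; it rebuilds a pcl::PointCloud from the raw buffer, runs inference,
+// applies the extra NMS pass and publishes the detections.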
+void ListenPointCloud(const char *strdata,const unsigned int nSize,const unsigned int index,const QDateTime * dt,const char * strmemname)
+{
+    //    std::cout<<" is  ok  ------------  "<<std::endl;
+
+    std::cout<<"ListenPointCloud is  ok  ------------  "<<std::endl;
+
+    if(nSize <=16)return;
+    unsigned int * pHeadSize = (unsigned int *)strdata;
+    if(*pHeadSize > nSize)
+    {
+        //givlog->verbose("ListenPointCloud data is small headsize = %d, data size is %d", *pHeadSize, nSize);
+        std::cout<<"ListenPointCloud data is small headsize ="<<*pHeadSize<<"  data size is"<<nSize<<std::endl;
+    }
+
+    gnothavedatatime = 0;
+    QTime xTime;
+    xTime.start();
+
+    pcl::PointCloud<pcl::PointXYZI>::Ptr point_cloud(
+                new pcl::PointCloud<pcl::PointXYZI>());
+    int nNameSize;
+    nNameSize = *pHeadSize - 4-4-8;
+    char * strName = new char[nNameSize+1];
+    strName[nNameSize] = 0;
+    // Use an array deleter so the name buffer is released with delete[]
+    std::shared_ptr<char> str_ptr(strName, std::default_delete<char[]>());
+    memcpy(strName,(char *)((char *)strdata +4),nNameSize);
+    point_cloud->header.frame_id = strName;
+    memcpy(&point_cloud->header.seq,(char *)strdata+4+nNameSize,4);
+    memcpy(&point_cloud->header.stamp,(char *)strdata+4+nNameSize+4,8);
+    int nPCount = (nSize - *pHeadSize)/sizeof(pcl::PointXYZI);
+    int i;
+    pcl::PointXYZI * p;
+    p = (pcl::PointXYZI *)((char *)strdata + *pHeadSize);
+    for(i=0;i<nPCount;i++)
+    {
+        pcl::PointXYZI xp;
+        memcpy(&xp,p,sizeof(pcl::PointXYZI));
+        point_cloud->push_back(xp);
+        p++;
+    }
+
+
+    std::chrono::time_point<std::chrono::system_clock> start,end;
+
+    start = std::chrono::system_clock::now();  // start timing
+    std::vector<int> points_shape;
+    std::vector<float> points_data;
+
+    if (!preprocess_new(point_cloud, 4,&points_shape, &points_data)) {
+        LOG(ERROR) << "Failed to preprocess!\n";
+    }
+
+    std::vector<float> box3d_lidar;
+    std::vector<int64_t> label_preds;
+    std::vector<float> scores;
+
+    run(predictor.get(), points_shape, points_data, &box3d_lidar, &label_preds,
+        &scores);
+
+    //parse_result(box3d_lidar, label_preds, scores);
+
+    end = std::chrono::system_clock::now();  // stop timing
+    std::cout <<"centerpoint infer time: "<<
+                std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count()<< "ms" << std::endl;
+
+    ///////////////// extra NMS pass to further remove overlapping boxes /////////////////
+    std::vector<BBOX3D>results_rect;
+    GetRotatedRect(box3d_lidar, label_preds, scores,results_rect);
+    //std::cout<<"results_rect size: "<<results_rect.size()<<std::endl;
+
+    std::vector<BBOX3D>results_bbox;
+    nms(results_rect,second_nms_threshold,results_bbox);
+    //std::cout<<"results_bbox size: "<<results_bbox.size()<<std::endl;
+    //std::cout<<"obj size is "<<predResult.size()<<std::endl;
+    ///////////////// end of extra NMS pass /////////////////
+
+    iv::lidar::objectarray lidarobjvec;
+    //GetLidarObj(box3d_lidar, label_preds, scores,lidarobjvec);
+    GetLidarObj_nms(results_bbox,lidarobjvec);
+
+
+    double timex = point_cloud->header.stamp;
+    timex = timex/1000.0;
+    lidarobjvec.set_timestamp(point_cloud->header.stamp);
+
+    int ntlen;
+    std::string out = lidarobjvec.SerializeAsString();
+    //   char * strout = lidarobjtostr(lidarobjvec,ntlen);
+    iv::modulecomm::ModuleSendMsg(gpdetect,out.data(),out.length());
+
+    end = std::chrono::system_clock::now();  // stop timing
+    std::cout <<"total process time: "<<
+                std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count()<< "ms" << std::endl;
+
+    std::cout<<"ListenPointCloud is  end  ------------  "<<std::endl;
+}
+
+int main(int argc, char *argv[]) {
+
+
+//    google::ParseCommandLineFlags(&argc, &argv, true);
+//    if (FLAGS_model_file == "" || FLAGS_params_file == "" ||
+//            FLAGS_lidar_file == "") {
+//        LOG(INFO) << "Missing required parameter"
+//                  << "\n";
+//        LOG(INFO) << "Usage: " << std::string(argv[0])
+//                << " --model_file ${MODEL_FILE} "
+//                << "--params_file ${PARAMS_FILE} "
+//                << "--lidar_file ${LIDAR_FILE}"
+//                << "\n";
+//        return -1;
+//    }
+
+    cv::FileStorage config(config_file, cv::FileStorage::READ);
+    bool config_isOpened = config.isOpened();
+    //const char* onnx_path_;
+    if(config_isOpened)
+    {
+        FLAGS_model_file = std::string(config["model_file"]);
+        FLAGS_params_file = std::string(config["params_file"]);
+        FLAGS_pd_infer_custom_op = std::string(config["pd_infer_custom_op"]);
+        score_threshold = float(config["score_threshold"]);
+        smallinbig_threshold = float(config["smallinbig_threshold"]);
+        distance_threshold = float(config["distance_threshold"]);
+        second_nms_threshold = float(config["second_nms_threshold"]);
+    }
+
+    // Load the custom-operator library (hard_voxelize etc.) before creating the predictor
+    void *handle = dlopen(FLAGS_pd_infer_custom_op.c_str(),RTLD_NOW);
+    if(!handle)
+    {
+        fprintf(stderr,"%s\n",dlerror());
+        exit(EXIT_FAILURE);
+    }
+
+
+    predictor = create_predictor(
+                FLAGS_model_file, FLAGS_params_file, FLAGS_gpu_id, FLAGS_use_trt,
+                FLAGS_trt_precision, FLAGS_trt_use_static, FLAGS_trt_static_dir,
+                FLAGS_collect_shape_info, FLAGS_dynamic_shape_file);
+    if (predictor == nullptr) {
+        return 0;
+    }
+
+
+    gpa = iv::modulecomm::RegisterRecv(&lidarname[0],ListenPointCloud);
+    gpdetect = iv::modulecomm::RegisterSend(&detectname[0], 10000000,1);
+    while(1)
+    {
+        std::this_thread::sleep_for(std::chrono::milliseconds(100));
+    }
+
+    return 0;
+}

+ 9 - 0
src/detection/centerpoint_paddle/yaml/centerpoint_paddle.yaml

@@ -0,0 +1,9 @@
+%YAML:1.0
+---
+model_file: ./model/centerpoint.pdmodel
+params_file: ./model/centerpoint.pdiparams
+pd_infer_custom_op: ./lib/libpd_infer_custom_op.so
+score_threshold: 0.4
+smallinbig_threshold: 0.8
+distance_threshold: 0.2
+second_nms_threshold: 0.5