// Copyright 2022 TIER IV, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Modified from
// https://github.com/open-mmlab/OpenPCDet/blob/master/pcdet/ops/iou3d_nms/src/iou3d_nms_kernel.cu

/*
3D IoU Calculation and Rotated NMS(modified from 2D NMS written by others)
Written by Shaoshuai Shi
All Rights Reserved 2019-2020.
*/

#include "lidar_centerpoint/postprocess/circle_nms_kernel.hpp"

#include <lidar_centerpoint/cuda_utils.hpp>
#include <lidar_centerpoint/utils.hpp>

#include <thrust/host_vector.h>

namespace
{
// Number of boxes handled per tile. It is used as the CUDA block size, as the length of the
// shared-memory tile in circleNMS_Kernel, and as the divisor when splitting the boxes into
// column blocks. The comparison results of one tile are packed into a single std::uint64_t
// (one bit per box), so this value must not exceed 64.
constexpr std::size_t THREADS_PER_BLOCK_NMS{16};
}  // namespace

namespace centerpoint
{

/**
 * @brief Squared Euclidean distance between two boxes in the x-y plane.
 *
 * Only the `x` and `y` members of each box are read. No square root is taken; the caller
 * compares the result against an already squared threshold.
 *
 * @param a first box
 * @param b second box
 * @return (a->x - b->x)^2 + (a->y - b->y)^2
 */
__device__ inline float dist2dPow(const Box3D * a, const Box3D * b)
{
  // Square by explicit multiplication instead of the general-purpose powf() routine: this
  // function runs once per box pair, and a plain multiply is both cheaper and exactly rounded.
  const float dx = a->x - b->x;
  const float dy = a->y - b->y;
  return dx * dx + dy * dy;
}

/**
 * @brief Compute the pairwise "too close" bit mask used by circle NMS.
 *
 * Boxes are split into tiles of THREADS_PER_BLOCK_NMS boxes. The block at
 * (blockIdx.x = column tile, blockIdx.y = row tile) compares every box of the row tile against
 * every box of the column tile; one thread handles one row box. Blocks with row tile > column
 * tile exit immediately, so the corresponding mask words are never written by this kernel.
 *
 * Output layout: `mask[box_idx * col_blocks + col_tile]` holds one 64-bit word whose bit `i` is
 * set when box `col_tile * THREADS_PER_BLOCK_NMS + i` lies closer than the threshold to box
 * `box_idx`. On the diagonal tile only bits for boxes with a larger index than `box_idx` are set.
 *
 * Launch requirements: blockDim.x must be at least THREADS_PER_BLOCK_NMS (threads beyond the
 * tile size do no work); static shared memory of THREADS_PER_BLOCK_NMS Box3D elements is used.
 *
 * @param boxes device pointer to num_boxes3d boxes
 * @param num_boxes3d number of boxes
 * @param col_blocks number of 64-bit words per mask row
 * @param dist2d_pow_threshold squared distance threshold
 * @param mask device pointer to num_boxes3d * col_blocks words
 */
__global__ void circleNMS_Kernel(
  const Box3D * boxes, const std::size_t num_boxes3d, const std::size_t col_blocks,
  const float dist2d_pow_threshold, std::uint64_t * mask)
{
  // params: boxes (N,)
  // params: mask (N, divup(N, THREADS_PER_BLOCK_NMS))

  static_assert(
    THREADS_PER_BLOCK_NMS <= 64, "the results of one tile must fit into a single 64-bit word");

  const auto row_start = blockIdx.y;
  const auto col_start = blockIdx.x;

  // Block-uniform exit: every thread of the block leaves, so nobody waits at __syncthreads().
  if (row_start > col_start) return;

  // Work on a by-value copy of the tile size and stay in integer arithmetic (no float
  // round-trip through fminf).
  const std::size_t tile_size = THREADS_PER_BLOCK_NMS;
  const std::size_t row_offset = tile_size * row_start;
  const std::size_t col_offset = tile_size * col_start;

  // Boxes left from the tile start to the end of the array. Clamped to zero so that a tile
  // lying completely past the last box is treated as empty instead of wrapping around.
  const std::size_t row_remaining = row_offset < num_boxes3d ? num_boxes3d - row_offset : 0;
  const std::size_t col_remaining = col_offset < num_boxes3d ? num_boxes3d - col_offset : 0;

  // The last tile may be only partially filled.
  const std::size_t row_size = row_remaining < tile_size ? row_remaining : tile_size;
  const std::size_t col_size = col_remaining < tile_size ? col_remaining : tile_size;

  __shared__ Box3D block_boxes[THREADS_PER_BLOCK_NMS];

  // Stage the column tile in shared memory; each thread loads at most one box.
  if (threadIdx.x < col_size) {
    block_boxes[threadIdx.x] = boxes[col_offset + threadIdx.x];
  }
  __syncthreads();

  if (threadIdx.x < row_size) {
    const std::size_t cur_box_idx = row_offset + threadIdx.x;
    const Box3D * cur_box = boxes + cur_box_idx;

    std::uint64_t t = 0;
    // On the diagonal tile skip the box itself and all boxes with a smaller index.
    const std::size_t start = (row_start == col_start) ? threadIdx.x + 1 : 0;
    for (std::size_t i = start; i < col_size; i++) {
      if (dist2dPow(cur_box, block_boxes + i) < dist2d_pow_threshold) {
        t |= 1ULL << i;
      }
    }
    mask[cur_box_idx * col_blocks + col_start] = t;
  }
}

/**
 * @brief Enqueue circleNMS_Kernel on `stream`.
 *
 * The kernel is launched on a col_blocks x col_blocks grid with THREADS_PER_BLOCK_NMS threads
 * per block. The launch is asynchronous; this function does not synchronize the stream.
 *
 * @param boxes3d device boxes to compare
 * @param num_boxes3d number of boxes to process
 * @param col_blocks number of column tiles, i.e. 64-bit words per mask row
 * @param distance_threshold distance threshold (not squared); it is squared here before being
 *        passed to the kernel
 * @param mask device buffer receiving the pairwise mask; the kernel indexes it as
 *        num_boxes3d rows of col_blocks words
 * @param stream CUDA stream to launch on
 * @return cudaSuccess when there is nothing to do, otherwise the result of cudaGetLastError()
 *         after the launch
 */
cudaError_t circleNMS_launch(
  const thrust::device_vector<Box3D> & boxes3d, const std::size_t num_boxes3d,
  std::size_t col_blocks, const float distance_threshold,
  thrust::device_vector<std::uint64_t> & mask, cudaStream_t stream)
{
  // Nothing to compare. Launching would use a zero-sized grid, which CUDA rejects as an
  // invalid configuration even though there is no work to fail on.
  if (num_boxes3d == 0) {
    return cudaSuccess;
  }

  const float dist2d_pow_thres = powf(distance_threshold, 2);

  dim3 blocks(col_blocks, col_blocks);
  dim3 threads(THREADS_PER_BLOCK_NMS);
  circleNMS_Kernel<<<blocks, threads, 0, stream>>>(
    thrust::raw_pointer_cast(boxes3d.data()), num_boxes3d, col_blocks, dist2d_pow_thres,
    thrust::raw_pointer_cast(mask.data()));

  // Only launch-time errors are visible here; execution errors surface at a later sync.
  return cudaGetLastError();
}

/**
 * @brief Circle (center-distance) NMS over `boxes3d`.
 *
 * The pairwise mask is computed on the GPU, then boxes are visited on the host in index order:
 * a box that has not been suppressed yet is kept, and every later box closer than
 * `distance_threshold` to it (in the x-y plane) is suppressed.
 * NOTE(review): this greedy order presumably expects the boxes to be sorted by descending
 * score - confirm against the caller.
 *
 * @param boxes3d device boxes to filter
 * @param distance_threshold suppression radius (not squared)
 * @param keep_mask output; replaced by the host-computed flags, true for kept boxes. Its size is
 *        preserved unless it is smaller than the number of boxes, in which case it grows to it.
 * @param stream CUDA stream used for the kernel launch; it is synchronized before returning
 *        whenever there is at least one box
 * @return number of boxes kept
 */
std::size_t circleNMS(
  thrust::device_vector<Box3D> & boxes3d, const float distance_threshold,
  thrust::device_vector<bool> & keep_mask, cudaStream_t stream)
{
  const auto num_boxes3d = boxes3d.size();
  const auto col_blocks = divup(num_boxes3d, THREADS_PER_BLOCK_NMS);

  thrust::host_vector<std::uint64_t> mask_h(num_boxes3d * col_blocks);

  // With no boxes there is nothing to launch or copy; the code below then simply clears
  // keep_mask and returns 0.
  if (num_boxes3d > 0) {
    thrust::device_vector<std::uint64_t> mask_d(num_boxes3d * col_blocks);

    CHECK_CUDA_ERROR(
      circleNMS_launch(boxes3d, num_boxes3d, col_blocks, distance_threshold, mask_d, stream));

    // The copy below is not issued on `stream`, so wait for the kernel first. Otherwise the
    // mask could be read before it is fully written when `stream` does not implicitly
    // synchronize with the default stream.
    CHECK_CUDA_ERROR(cudaStreamSynchronize(stream));

    // memcpy device to host
    thrust::copy(mask_d.begin(), mask_d.end(), mask_h.begin());
  }

  // generate keep_mask
  std::vector<std::uint64_t> remv_h(col_blocks);
  // Never smaller than the number of boxes, so the writes in the loop stay in bounds even if the
  // caller passed a shorter keep_mask. Entries beyond the boxes stay false.
  const std::size_t keep_mask_size =
    keep_mask.size() < num_boxes3d ? num_boxes3d : keep_mask.size();
  thrust::host_vector<bool> keep_mask_h(keep_mask_size, false);
  std::size_t num_to_keep = 0;
  for (std::size_t i = 0; i < num_boxes3d; i++) {
    const auto nblock = i / THREADS_PER_BLOCK_NMS;
    const auto inblock = i % THREADS_PER_BLOCK_NMS;

    if (!(remv_h[nblock] & (1ULL << inblock))) {
      keep_mask_h[i] = true;
      num_to_keep++;
      // Merge the mask row of the kept box into the running "removed" set. Only words from its
      // own tile onwards are read; earlier words are not written by the kernel.
      const std::uint64_t * p = &mask_h[i * col_blocks];
      for (std::size_t j = nblock; j < col_blocks; j++) {
        remv_h[j] |= p[j];
      }
    } else {
      keep_mask_h[i] = false;
    }
  }

  // memcpy host to device
  keep_mask = keep_mask_h;

  return num_to_keep;
}

}  // namespace centerpoint