Browse Source

add detection_trafficlight_classify.

yuchuli 2 years ago
parent
commit
d0d13eca42

+ 2 - 0
src/detection/detection_lidar_centerpoint/README.md

@@ -0,0 +1,2 @@
+from autoware.
+test fail in agx orin.

+ 3 - 3
src/detection/detection_lidar_centerpoint/main.cpp

@@ -24,10 +24,10 @@ void init()
     std::vector<double> yaw_norm_thresholds ;
     const std::string densification_world_frame_id = "map";
     const int densification_num_past_frames = 1;
-    const std::string trt_precision = "fp16";
-    const std::string encoder_onnx_path = "/home/nvidia/models/pts_voxel_encoder_centerpoint.onnx";//this->declare_parameter<std::string>("encoder_onnx_path");
+    const std::string trt_precision = "fp32";
+    const std::string encoder_onnx_path = "/home/nvidia/models/pts_voxel_encoder_centerpoint_tiny.onnx";//this->declare_parameter<std::string>("encoder_onnx_path");
     const std::string encoder_engine_path ="/home/nvidia/models/pts_voxel_encoder_centerpoint.eng";//this->declare_parameter<std::string>("encoder_engine_path");
-    const std::string head_onnx_path = "/home/nvidia/models/pts_backbone_neck_head_centerpoint.onnx";//this->declare_parameter<std::string>("head_onnx_path");
+    const std::string head_onnx_path = "/home/nvidia/models/pts_backbone_neck_head_centerpoint_tiny.onnx";//this->declare_parameter<std::string>("head_onnx_path");
     const std::string head_engine_path ="/home/nvidia/models/pts_backbone_neck_head_centerpoint.eng" ;//this->declare_parameter<std::string>("head_engine_path");
     const std::size_t point_feature_size =4;
     const std::size_t max_voxel_size =40000;

+ 50 - 0
src/detection/detection_trafficlight_classify/detection_trafficlight_classify.pro

@@ -0,0 +1,50 @@
+QT -= gui
+
+CONFIG += c++11 #console
+CONFIG -= app_bundle
+
+# The following define makes your compiler emit warnings if you use
+# any Qt feature that has been marked deprecated (the exact warnings
+# depend on your compiler). Please consult the documentation of the
+# deprecated API in order to know how to port your code away from it.
+DEFINES += QT_DEPRECATED_WARNINGS
+
+# You can also make your code fail to compile if it uses deprecated APIs.
+# In order to do so, uncomment the following line.
+# You can also select to disable deprecated APIs only up to a certain version of Qt.
+#DEFINES += QT_DISABLE_DEPRECATED_BEFORE=0x060000    # disables all the APIs deprecated before Qt 6.0.0
+
+SOURCES += \
+        main.cpp \
+        src/cnn_classifier.cpp \
+        utils/trt_common.cpp
+
+# Default rules for deployment.
+qnx: target.path = /tmp/$${TARGET}/bin
+else: unix:!android: target.path = /opt/$${TARGET}/bin
+!isEmpty(target.path): INSTALLS += target
+
+HEADERS += \
+    include/traffic_light_classifier/classifier_interface.hpp \
+    include/traffic_light_classifier/cnn_classifier.hpp \
+    utils/trt_common.hpp
+
+
+INCLUDEPATH += $$PWD/include
+INCLUDEPATH += $$PWD/utils
+
+INCLUDEPATH += /usr/include/opencv4
+
+
+INCLUDEPATH += /usr/local/cuda-11.4/targets/aarch64-linux/include
+
+INCLUDEPATH += /usr/local/cuda-10.2/targets/aarch64-linux/include
+
+LIBS += -L/usr/local/cuda-11.4/targets/aarch64-linux/lib  # -lcublas
+
+LIBS += -L/usr/local/cuda-10.2/targets/aarch64-linux/lib  # -lcublas
+
+LIBS += -lnvinfer -lcudnn  -lcudart -lnvparsers -lnvcaffe_parser -lnvinfer_plugin -lnvonnxparser -lstdc++fs
+
+
+unix:LIBS += -lopencv_highgui -lopencv_core -lopencv_imgproc -lopencv_imgcodecs -lopencv_video -lopencv_videoio -lpthread  #-lopencv_shape

+ 38 - 0
src/detection/detection_trafficlight_classify/include/traffic_light_classifier/classifier_interface.hpp

@@ -0,0 +1,38 @@
+// Copyright 2020 Tier IV, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TRAFFIC_LIGHT_CLASSIFIER__CLASSIFIER_INTERFACE_HPP_
+#define TRAFFIC_LIGHT_CLASSIFIER__CLASSIFIER_INTERFACE_HPP_
+
+#include <opencv2/core/core.hpp>
+#include <opencv2/highgui/highgui.hpp>
+
+//#include <autoware_auto_perception_msgs/msg/traffic_signal.hpp>
+
+#include <vector>
+
+namespace traffic_light
+{
+class ClassifierInterface
+{
+public:
+    virtual bool getTrafficSignal(
+      const cv::Mat & input_image) = 0;
+//  virtual bool getTrafficSignal(
+//    const cv::Mat & input_image,
+//    autoware_auto_perception_msgs::msg::TrafficSignal & traffic_signal) = 0;
+};
+}  // namespace traffic_light
+
+#endif  // TRAFFIC_LIGHT_CLASSIFIER__CLASSIFIER_INTERFACE_HPP_

+ 121 - 0
src/detection/detection_trafficlight_classify/include/traffic_light_classifier/cnn_classifier.hpp

@@ -0,0 +1,121 @@
+// Copyright 2020 Tier IV, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TRAFFIC_LIGHT_CLASSIFIER__CNN_CLASSIFIER_HPP_
+#define TRAFFIC_LIGHT_CLASSIFIER__CNN_CLASSIFIER_HPP_
+
+#include "traffic_light_classifier/classifier_interface.hpp"
+
+//#include <image_transport/image_transport.hpp>
+#include <opencv2/core/core.hpp>
+#include <opencv2/highgui/highgui.hpp>
+//#include <rclcpp/rclcpp.hpp>
+#include <trt_common.hpp>
+
+//#include <autoware_auto_perception_msgs/msg/traffic_light.hpp>
+
+//#include <cv_bridge/cv_bridge.h>
+
+#include <map>
+#include <memory>
+#include <string>
+#include <vector>
+
+namespace traffic_light
+{
+class CNNClassifier : public ClassifierInterface
+{
+public:
+  explicit CNNClassifier();
+ //   explicit CNNClassifier(rclcpp::Node * node_ptr);
+  virtual ~CNNClassifier() = default;
+
+
+    bool getTrafficSignal(
+      const cv::Mat & input_image) override;
+//  bool getTrafficSignal(
+//    const cv::Mat & input_image,
+//    autoware_auto_perception_msgs::msg::TrafficSignal & traffic_signal) override;
+
+private:
+  void preProcess(cv::Mat & image, std::vector<float> & tensor, bool normalize = true);
+
+  bool postProcess(
+          std::vector<float> & output_data_host);
+//  bool postProcess(
+//    std::vector<float> & output_data_host,
+//    autoware_auto_perception_msgs::msg::TrafficSignal & traffic_signal, bool apply_softmax = false);
+  bool readLabelfile(std::string filepath, std::vector<std::string> & labels);
+  bool isColorLabel(const std::string label);
+  void calcSoftmax(std::vector<float> & data, std::vector<float> & probs, int num_output);
+  std::vector<size_t> argsort(std::vector<float> & tensor, int num_output);
+//  void outputDebugImage(
+//    cv::Mat & debug_image,
+//    const autoware_auto_perception_msgs::msg::TrafficSignal & traffic_signal);
+
+private:
+//  std::map<int, std::string> state2label_{
+//    // color
+//    {autoware_auto_perception_msgs::msg::TrafficLight::RED, "red"},
+//    {autoware_auto_perception_msgs::msg::TrafficLight::AMBER, "yellow"},
+//    {autoware_auto_perception_msgs::msg::TrafficLight::GREEN, "green"},
+//    {autoware_auto_perception_msgs::msg::TrafficLight::WHITE, "white"},
+//    // shape
+//    {autoware_auto_perception_msgs::msg::TrafficLight::CIRCLE, "circle"},
+//    {autoware_auto_perception_msgs::msg::TrafficLight::LEFT_ARROW, "left"},
+//    {autoware_auto_perception_msgs::msg::TrafficLight::RIGHT_ARROW, "right"},
+//    {autoware_auto_perception_msgs::msg::TrafficLight::UP_ARROW, "straight"},
+//    {autoware_auto_perception_msgs::msg::TrafficLight::DOWN_ARROW, "down"},
+//    {autoware_auto_perception_msgs::msg::TrafficLight::DOWN_LEFT_ARROW, "down_left"},
+//    {autoware_auto_perception_msgs::msg::TrafficLight::DOWN_RIGHT_ARROW, "down_right"},
+//    {autoware_auto_perception_msgs::msg::TrafficLight::CROSS, "cross"},
+//    // other
+//    {autoware_auto_perception_msgs::msg::TrafficLight::UNKNOWN, "unknown"},
+//  };
+
+//  std::map<std::string, int> label2state_{
+//    // color
+//    {"red", autoware_auto_perception_msgs::msg::TrafficLight::RED},
+//    {"yellow", autoware_auto_perception_msgs::msg::TrafficLight::AMBER},
+//    {"green", autoware_auto_perception_msgs::msg::TrafficLight::GREEN},
+//    {"white", autoware_auto_perception_msgs::msg::TrafficLight::WHITE},
+//    // shape
+//    {"circle", autoware_auto_perception_msgs::msg::TrafficLight::CIRCLE},
+//    {"left", autoware_auto_perception_msgs::msg::TrafficLight::LEFT_ARROW},
+//    {"right", autoware_auto_perception_msgs::msg::TrafficLight::RIGHT_ARROW},
+//    {"straight", autoware_auto_perception_msgs::msg::TrafficLight::UP_ARROW},
+//    {"down", autoware_auto_perception_msgs::msg::TrafficLight::DOWN_ARROW},
+//    {"down_left", autoware_auto_perception_msgs::msg::TrafficLight::DOWN_LEFT_ARROW},
+//    {"down_right", autoware_auto_perception_msgs::msg::TrafficLight::DOWN_RIGHT_ARROW},
+//    {"cross", autoware_auto_perception_msgs::msg::TrafficLight::CROSS},
+//    // other
+//    {"unknown", autoware_auto_perception_msgs::msg::TrafficLight::UNKNOWN},
+//  };
+
+//  rclcpp::Node * node_ptr_;
+
+  std::shared_ptr<Tn::TrtCommon> trt_;
+ // image_transport::Publisher image_pub_;
+  std::vector<std::string> labels_;
+  std::vector<float> mean_{0.242, 0.193, 0.201};
+  std::vector<float> std_{1.0, 1.0, 1.0};
+  int input_c_;
+  int input_h_;
+  int input_w_;
+  bool apply_softmax_;
+};
+
+}  // namespace traffic_light
+
+#endif  // TRAFFIC_LIGHT_CLASSIFIER__CNN_CLASSIFIER_HPP_

+ 94 - 0
src/detection/detection_trafficlight_classify/include/traffic_light_classifier/color_classifier.hpp

@@ -0,0 +1,94 @@
+// Copyright 2020 Tier IV, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TRAFFIC_LIGHT_CLASSIFIER__COLOR_CLASSIFIER_HPP_
+#define TRAFFIC_LIGHT_CLASSIFIER__COLOR_CLASSIFIER_HPP_
+
+#include "traffic_light_classifier/classifier_interface.hpp"
+
+#include <image_transport/image_transport.hpp>
+#include <opencv2/core/core.hpp>
+#include <opencv2/highgui/highgui.hpp>
+#include <rclcpp/rclcpp.hpp>
+
+#include <autoware_auto_perception_msgs/msg/traffic_light.hpp>
+
+#include <cv_bridge/cv_bridge.h>
+
+#include <vector>
+
+namespace traffic_light
+{
+struct HSVConfig
+{
+  int green_min_h;
+  int green_min_s;
+  int green_min_v;
+  int green_max_h;
+  int green_max_s;
+  int green_max_v;
+  int yellow_min_h;
+  int yellow_min_s;
+  int yellow_min_v;
+  int yellow_max_h;
+  int yellow_max_s;
+  int yellow_max_v;
+  int red_min_h;
+  int red_min_s;
+  int red_min_v;
+  int red_max_h;
+  int red_max_s;
+  int red_max_v;
+};
+
+class ColorClassifier : public ClassifierInterface
+{
+public:
+  explicit ColorClassifier(rclcpp::Node * node_ptr);
+  virtual ~ColorClassifier() = default;
+
+  bool getTrafficSignal(
+    const cv::Mat & input_image,
+    autoware_auto_perception_msgs::msg::TrafficSignal & traffic_signal) override;
+
+private:
+  bool filterHSV(
+    const cv::Mat & input_image, cv::Mat & green_image, cv::Mat & yellow_image,
+    cv::Mat & red_image);
+  rcl_interfaces::msg::SetParametersResult parametersCallback(
+    const std::vector<rclcpp::Parameter> & parameters);
+
+private:
+  enum HSV {
+    Hue = 0,
+    Sat = 1,
+    Val = 2,
+  };
+  image_transport::Publisher image_pub_;
+
+  rclcpp::node_interfaces::OnSetParametersCallbackHandle::SharedPtr set_param_res_;
+  rclcpp::Node * node_ptr_;
+
+  HSVConfig hsv_config_;
+  cv::Scalar min_hsv_green_;
+  cv::Scalar max_hsv_green_;
+  cv::Scalar min_hsv_yellow_;
+  cv::Scalar max_hsv_yellow_;
+  cv::Scalar min_hsv_red_;
+  cv::Scalar max_hsv_red_;
+};
+
+}  // namespace traffic_light
+
+#endif  // TRAFFIC_LIGHT_CLASSIFIER__COLOR_CLASSIFIER_HPP_

+ 90 - 0
src/detection/detection_trafficlight_classify/include/traffic_light_classifier/nodelet.hpp

@@ -0,0 +1,90 @@
+// Copyright 2020 Tier IV, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TRAFFIC_LIGHT_CLASSIFIER__NODELET_HPP_
+#define TRAFFIC_LIGHT_CLASSIFIER__NODELET_HPP_
+
+#include "traffic_light_classifier/classifier_interface.hpp"
+
+#include <image_transport/image_transport.hpp>
+#include <image_transport/subscriber_filter.hpp>
+#include <rclcpp/rclcpp.hpp>
+
+#include <autoware_auto_perception_msgs/msg/traffic_light.hpp>
+#include <autoware_auto_perception_msgs/msg/traffic_light_roi_array.hpp>
+#include <autoware_auto_perception_msgs/msg/traffic_signal.hpp>
+#include <autoware_auto_perception_msgs/msg/traffic_signal_array.hpp>
+#include <sensor_msgs/image_encodings.hpp>
+#include <sensor_msgs/msg/image.hpp>
+#include <std_msgs/msg/header.hpp>
+
+#include <cv_bridge/cv_bridge.h>
+#include <message_filters/subscriber.h>
+#include <message_filters/sync_policies/approximate_time.h>
+#include <message_filters/synchronizer.h>
+#include <message_filters/time_synchronizer.h>
+
+#include <memory>
+#include <mutex>
+
+#if ENABLE_GPU
+#include "traffic_light_classifier/cnn_classifier.hpp"
+#endif
+
+#include "traffic_light_classifier/color_classifier.hpp"
+
+#include <opencv2/core/core.hpp>
+#include <opencv2/highgui/highgui.hpp>
+
+namespace traffic_light
+{
+class TrafficLightClassifierNodelet : public rclcpp::Node
+{
+public:
+  explicit TrafficLightClassifierNodelet(const rclcpp::NodeOptions & options);
+  void imageRoiCallback(
+    const sensor_msgs::msg::Image::ConstSharedPtr & input_image_msg,
+    const autoware_auto_perception_msgs::msg::TrafficLightRoiArray::ConstSharedPtr &
+      input_rois_msg);
+
+  enum ClassifierType {
+    HSVFilter = 0,
+    CNN = 1,
+  };
+
+private:
+  void connectCb();
+
+  rclcpp::TimerBase::SharedPtr timer_;
+  image_transport::SubscriberFilter image_sub_;
+  message_filters::Subscriber<autoware_auto_perception_msgs::msg::TrafficLightRoiArray> roi_sub_;
+  typedef message_filters::sync_policies::ExactTime<
+    sensor_msgs::msg::Image, autoware_auto_perception_msgs::msg::TrafficLightRoiArray>
+    SyncPolicy;
+  typedef message_filters::Synchronizer<SyncPolicy> Sync;
+  std::shared_ptr<Sync> sync_;
+  typedef message_filters::sync_policies::ApproximateTime<
+    sensor_msgs::msg::Image, autoware_auto_perception_msgs::msg::TrafficLightRoiArray>
+    ApproximateSyncPolicy;
+  typedef message_filters::Synchronizer<ApproximateSyncPolicy> ApproximateSync;
+  std::shared_ptr<ApproximateSync> approximate_sync_;
+  bool is_approximate_sync_;
+  rclcpp::Publisher<autoware_auto_perception_msgs::msg::TrafficSignalArray>::SharedPtr
+    traffic_signal_array_pub_;
+  std::shared_ptr<ClassifierInterface> classifier_ptr_;
+};
+
+}  // namespace traffic_light
+
+#endif  // TRAFFIC_LIGHT_CLASSIFIER__NODELET_HPP_

+ 28 - 0
src/detection/detection_trafficlight_classify/main.cpp

@@ -0,0 +1,28 @@
+#include <QCoreApplication>
+
+#include "traffic_light_classifier/cnn_classifier.hpp"
+
+using namespace traffic_light;
+
+using namespace cv;
+
+int main(int argc, char *argv[])
+{
+    QCoreApplication a(argc, argv);
+
+    CNNClassifier * pclass;
+    pclass =  new CNNClassifier();
+
+    std::cout<<" load. "<<std::endl;
+
+
+
+    Mat mat = imread("/home/nvidia/tra1.jpg", IMREAD_ANYCOLOR);
+
+    int time1 = std::chrono::system_clock::now().time_since_epoch().count()/1000000;
+    pclass->getTrafficSignal(mat);
+    int time2 = std::chrono::system_clock::now().time_since_epoch().count()/1000000;
+    std::cout<<" infer use: "<<(time2-time1)<<std::endl;
+
+    return a.exec();
+}

+ 377 - 0
src/detection/detection_trafficlight_classify/src/cnn_classifier.cpp

@@ -0,0 +1,377 @@
+// Copyright 2020 Tier IV, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "traffic_light_classifier/cnn_classifier.hpp"
+
+#include <opencv2/core.hpp>
+
+//#include <ament_index_cpp/get_package_share_directory.hpp>
+
+#include "opencv2/imgproc/imgproc.hpp"
+
+#include <boost/algorithm/string/classification.hpp>
+#include <boost/algorithm/string/split.hpp>
+
+#include <memory>
+#include <string>
+#include <vector>
+
+namespace traffic_light
+{
+//CNNClassifier::CNNClassifier(rclcpp::Node * node_ptr) : node_ptr_(node_ptr)
+//{
+//  image_pub_ = image_transport::create_publisher(
+//    node_ptr_, "~/output/debug/image", rclcpp::QoS{1}.get_rmw_qos_profile());
+
+//  std::string precision;
+//  std::string label_file_path;
+//  std::string model_file_path;
+//  precision = node_ptr_->declare_parameter("precision", "fp16");
+//  label_file_path = node_ptr_->declare_parameter("label_file_path", "labels.txt");
+//  model_file_path = node_ptr_->declare_parameter("model_file_path", "model.onnx");
+//  input_c_ = node_ptr_->declare_parameter("input_c", 3);
+//  input_h_ = node_ptr_->declare_parameter("input_h", 224);
+//  input_w_ = node_ptr_->declare_parameter("input_w", 224);
+//  auto input_name = node_ptr_->declare_parameter("input_name", "input_0");
+//  auto output_name = node_ptr_->declare_parameter("output_name", "output_0");
+//  apply_softmax_ = node_ptr_->declare_parameter("apply_softmax", true);
+
+//  readLabelfile(label_file_path, labels_);
+
+//  trt_ = std::make_shared<Tn::TrtCommon>(model_file_path, precision, input_name, output_name);
+//  trt_->setup();
+
+//  if (node_ptr_->declare_parameter("build_only", false)) {
+//    RCLCPP_INFO(node_ptr_->get_logger(), "TensorRT engine is built and shutdown node.");
+//    rclcpp::shutdown();
+//  }
+//}
+
+CNNClassifier::CNNClassifier()
+{
+
+  std::string precision;
+  std::string label_file_path;
+  std::string model_file_path;
+  precision = "fp16";//node_ptr_->declare_parameter("precision", "fp16");
+  label_file_path = "labels.txt";//node_ptr_->declare_parameter("label_file_path", "labels.txt");
+  model_file_path = "model.onnx";//node_ptr_->declare_parameter("model_file_path", "model.onnx");
+  input_c_ = 3;//node_ptr_->declare_parameter("input_c", 3);
+  input_h_ = 224; //node_ptr_->declare_parameter("input_h", 224);
+  input_w_ = 224; //node_ptr_->declare_parameter("input_w", 224);
+  auto input_name = "input_0";//node_ptr_->declare_parameter("input_name", "input_0");
+  auto output_name = "output_0";//node_ptr_->declare_parameter("output_name", "output_0");
+  apply_softmax_ = true;//node_ptr_->declare_parameter("apply_softmax", true);
+
+  readLabelfile(label_file_path, labels_);
+
+  trt_ = std::make_shared<Tn::TrtCommon>(model_file_path, precision, input_name, output_name);
+  trt_->setup();
+
+//  if (node_ptr_->declare_parameter("build_only", false)) {
+//    RCLCPP_INFO(node_ptr_->get_logger(), "TensorRT engine is built and shutdown node.");
+//    rclcpp::shutdown();
+//  }
+}
+
+
+bool CNNClassifier::getTrafficSignal(
+  const cv::Mat & input_image)
+{
+  if (!trt_->isInitialized()) {
+      std::cout<<"failed to init tensorrt"<<std::endl;
+  //  RCLCPP_WARN(node_ptr_->get_logger(), "failed to init tensorrt");
+    return false;
+  }
+
+  int num_input = trt_->getNumInput();
+  int num_output = trt_->getNumOutput();
+
+  std::vector<float> input_data_host(num_input);
+
+  cv::Mat image = input_image.clone();
+  preProcess(image, input_data_host, true);
+
+  auto input_data_device = Tn::make_unique<float[]>(num_input);
+  cudaMemcpy(
+    input_data_device.get(), input_data_host.data(), num_input * sizeof(float),
+    cudaMemcpyHostToDevice);
+
+  auto output_data_device = Tn::make_unique<float[]>(num_output);
+
+  // do inference
+  std::vector<void *> bindings = {input_data_device.get(), output_data_device.get()};
+
+  trt_->context_->executeV2(bindings.data());
+
+  std::vector<float> output_data_host(num_output);
+  cudaMemcpy(
+    output_data_host.data(), output_data_device.get(), num_output * sizeof(float),
+    cudaMemcpyDeviceToHost);
+
+  int a = 1;
+  a++;
+
+  postProcess(output_data_host);
+//  postProcess(output_data_host, traffic_signal, apply_softmax_);
+
+  /* debug */
+//  if (0 < image_pub_.getNumSubscribers()) {
+//    cv::Mat debug_image = input_image.clone();
+//    outputDebugImage(debug_image, traffic_signal);
+//  }
+
+  return true;
+}
+
+
+bool CNNClassifier::postProcess(
+        std::vector<float> & output_tensor)
+{
+
+    std::vector<float> probs;
+    int num_output = trt_->getNumOutput();
+
+        calcSoftmax(output_tensor, probs, num_output);
+
+    std::vector<size_t> sorted_indices = argsort(output_tensor, num_output);
+
+    std::cout<<" lable: "<<labels_[sorted_indices[0]].c_str()<<" score: "<<probs[sorted_indices[0]] * 100<<std::endl;
+    // ROS_INFO("label: %s, score: %.2f\%",
+    //          labels_[sorted_indices[0]].c_str(),
+    //          probs[sorted_indices[0]] * 100);
+
+
+    return true;
+}
+
+//bool CNNClassifier::getTrafficSignal(
+//  const cv::Mat & input_image, autoware_auto_perception_msgs::msg::TrafficSignal & traffic_signal)
+//{
+//  if (!trt_->isInitialized()) {
+//    RCLCPP_WARN(node_ptr_->get_logger(), "failed to init tensorrt");
+//    return false;
+//  }
+
+//  int num_input = trt_->getNumInput();
+//  int num_output = trt_->getNumOutput();
+
+//  std::vector<float> input_data_host(num_input);
+
+//  cv::Mat image = input_image.clone();
+//  preProcess(image, input_data_host, true);
+
+//  auto input_data_device = Tn::make_unique<float[]>(num_input);
+//  cudaMemcpy(
+//    input_data_device.get(), input_data_host.data(), num_input * sizeof(float),
+//    cudaMemcpyHostToDevice);
+
+//  auto output_data_device = Tn::make_unique<float[]>(num_output);
+
+//  // do inference
+//  std::vector<void *> bindings = {input_data_device.get(), output_data_device.get()};
+
+//  trt_->context_->executeV2(bindings.data());
+
+//  std::vector<float> output_data_host(num_output);
+//  cudaMemcpy(
+//    output_data_host.data(), output_data_device.get(), num_output * sizeof(float),
+//    cudaMemcpyDeviceToHost);
+
+//  postProcess(output_data_host, traffic_signal, apply_softmax_);
+
+//  /* debug */
+//  if (0 < image_pub_.getNumSubscribers()) {
+//    cv::Mat debug_image = input_image.clone();
+//    outputDebugImage(debug_image, traffic_signal);
+//  }
+
+//  return true;
+//}
+
+//void CNNClassifier::outputDebugImage(
+//  cv::Mat & debug_image, const autoware_auto_perception_msgs::msg::TrafficSignal & traffic_signal)
+//{
+//  float probability;
+//  std::string label;
+//  for (std::size_t i = 0; i < traffic_signal.lights.size(); i++) {
+//    auto light = traffic_signal.lights.at(i);
+//    const auto light_label = state2label_[light.color] + "-" + state2label_[light.shape];
+//    label += light_label;
+//    // all lamp confidence are the same
+//    probability = light.confidence;
+//    if (i < traffic_signal.lights.size() - 1) {
+//      label += ",";
+//    }
+//  }
+
+//  const int expand_w = 200;
+//  const int expand_h =
+//    std::max(static_cast<int>((expand_w * debug_image.rows) / debug_image.cols), 1);
+
+//  cv::resize(debug_image, debug_image, cv::Size(expand_w, expand_h));
+//  cv::Mat text_img(cv::Size(expand_w, 50), CV_8UC3, cv::Scalar(0, 0, 0));
+//  std::string text = label + " " + std::to_string(probability);
+//  cv::putText(
+//    text_img, text, cv::Point(5, 25), cv::FONT_HERSHEY_COMPLEX, 0.5, cv::Scalar(0, 255, 0), 1);
+//  cv::vconcat(debug_image, text_img, debug_image);
+
+//  const auto debug_image_msg =
+//    cv_bridge::CvImage(std_msgs::msg::Header(), "rgb8", debug_image).toImageMsg();
+//  image_pub_.publish(debug_image_msg);
+//}
+
+void CNNClassifier::preProcess(cv::Mat & image, std::vector<float> & input_tensor, bool normalize)
+{
+  /* normalize */
+  /* ((channel[0] / 255) - mean[0]) / std[0] */
+
+  // cv::cvtColor(image, image, cv::COLOR_BGR2RGB, 3);
+  cv::resize(image, image, cv::Size(input_w_, input_h_));
+
+  const size_t strides_cv[3] = {
+    static_cast<size_t>(input_w_ * input_c_), static_cast<size_t>(input_c_), 1};
+  const size_t strides[3] = {
+    static_cast<size_t>(input_h_ * input_w_), static_cast<size_t>(input_w_), 1};
+
+  for (int i = 0; i < input_h_; i++) {
+    for (int j = 0; j < input_w_; j++) {
+      for (int k = 0; k < input_c_; k++) {
+        const size_t offset_cv = i * strides_cv[0] + j * strides_cv[1] + k * strides_cv[2];
+        const size_t offset = k * strides[0] + i * strides[1] + j * strides[2];
+        if (normalize) {
+          input_tensor[offset] =
+            ((static_cast<float>(image.data[offset_cv]) / 255) - mean_[k]) / std_[k];
+        } else {
+          input_tensor[offset] = static_cast<float>(image.data[offset_cv]);
+        }
+      }
+    }
+  }
+}
+
+//bool CNNClassifier::postProcess(
+//  std::vector<float> & output_tensor,
+//  autoware_auto_perception_msgs::msg::TrafficSignal & traffic_signal, bool apply_softmax)
+//{
+//  std::vector<float> probs;
+//  int num_output = trt_->getNumOutput();
+//  if (apply_softmax) {
+//    calcSoftmax(output_tensor, probs, num_output);
+//  }
+//  std::vector<size_t> sorted_indices = argsort(output_tensor, num_output);
+
+//  // ROS_INFO("label: %s, score: %.2f\%",
+//  //          labels_[sorted_indices[0]].c_str(),
+//  //          probs[sorted_indices[0]] * 100);
+
+//  size_t max_indice = sorted_indices.front();
+//  std::string match_label = labels_[max_indice];
+//  float probability = apply_softmax ? probs[max_indice] : output_tensor[max_indice];
+
+//  // label names are assumed to be comma-separated to represent each lamp
+//  // e.g.
+//  // match_label: "red","red-cross","right"
+//  // split_label: ["red","red-cross","right"]
+//  // if shape doesn't have color suffix, set GREEN to color state.
+//  // if color doesn't have shape suffix, set CIRCLE to shape state.
+//  std::vector<std::string> split_label;
+//  boost::algorithm::split(split_label, match_label, boost::is_any_of(","));
+//  for (auto label : split_label) {
+//    if (label2state_.find(label) == label2state_.end()) {
+//      RCLCPP_DEBUG(
+//        node_ptr_->get_logger(), "cnn_classifier does not have a key [%s]", label.c_str());
+//      continue;
+//    }
+//    autoware_auto_perception_msgs::msg::TrafficLight light;
+//    if (label.find("-") != std::string::npos) {
+//      // found "-" delimiter in label string
+//      std::vector<std::string> color_and_shape;
+//      boost::algorithm::split(color_and_shape, label, boost::is_any_of("-"));
+//      light.color = label2state_[color_and_shape.at(0)];
+//      light.shape = label2state_[color_and_shape.at(1)];
+//    } else {
+//      if (label == state2label_[autoware_auto_perception_msgs::msg::TrafficLight::UNKNOWN]) {
+//        light.color = autoware_auto_perception_msgs::msg::TrafficLight::UNKNOWN;
+//        light.shape = autoware_auto_perception_msgs::msg::TrafficLight::UNKNOWN;
+//      } else if (isColorLabel(label)) {
+//        light.color = label2state_[label];
+//        light.shape = autoware_auto_perception_msgs::msg::TrafficLight::CIRCLE;
+//      } else {
+//        light.color = autoware_auto_perception_msgs::msg::TrafficLight::GREEN;
+//        light.shape = label2state_[label];
+//      }
+//    }
+//    light.confidence = probability;
+//    traffic_signal.lights.push_back(light);
+//  }
+
+//  return true;
+//}
+
+bool CNNClassifier::readLabelfile(std::string filepath, std::vector<std::string> & labels)
+{
+  std::ifstream labelsFile(filepath);
+  if (!labelsFile.is_open()) {
+      std::cout<<"Could not open label file"<<std::endl;
+ //   RCLCPP_ERROR(node_ptr_->get_logger(), "Could not open label file. [%s]", filepath.c_str());
+    return false;
+  }
+  std::string label;
+  while (getline(labelsFile, label)) {
+    labels.push_back(label);
+  }
+  return true;
+}
+
+void CNNClassifier::calcSoftmax(
+  std::vector<float> & data, std::vector<float> & probs, int num_output)
+{
+  float exp_sum = 0.0;
+  for (int i = 0; i < num_output; ++i) {
+    exp_sum += exp(data[i]);
+  }
+
+  for (int i = 0; i < num_output; ++i) {
+    probs.push_back(exp(data[i]) / exp_sum);
+  }
+}
+
+std::vector<size_t> CNNClassifier::argsort(std::vector<float> & tensor, int num_output)
+{
+  std::vector<size_t> indices(num_output);
+  for (int i = 0; i < num_output; i++) {
+    indices[i] = i;
+  }
+  std::sort(indices.begin(), indices.begin() + num_output, [tensor](size_t idx1, size_t idx2) {
+    return tensor[idx1] > tensor[idx2];
+  });
+
+  return indices;
+}
+
+bool CNNClassifier::isColorLabel(const std::string label)
+{
+//  using autoware_auto_perception_msgs::msg::TrafficSignal;
+//  if (
+//    label == state2label_[autoware_auto_perception_msgs::msg::TrafficLight::GREEN] ||
+//    label == state2label_[autoware_auto_perception_msgs::msg::TrafficLight::AMBER] ||
+//    label == state2label_[autoware_auto_perception_msgs::msg::TrafficLight::RED] ||
+//    label == state2label_[autoware_auto_perception_msgs::msg::TrafficLight::WHITE]) {
+//    return true;
+//  }
+  return false;
+}
+
+}  // namespace traffic_light

+ 248 - 0
src/detection/detection_trafficlight_classify/src/color_classifier.cpp

@@ -0,0 +1,248 @@
+// Copyright 2020 Tier IV, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#include "traffic_light_classifier/color_classifier.hpp"
+
+#include <opencv2/imgproc/imgproc_c.h>
+
+#include <algorithm>
+#include <string>
+#include <vector>
+
+namespace traffic_light
+{
+ColorClassifier::ColorClassifier(rclcpp::Node * node_ptr) : node_ptr_(node_ptr)
+{
+  using std::placeholders::_1;
+  image_pub_ = image_transport::create_publisher(
+    node_ptr_, "~/debug/image", rclcpp::QoS{1}.get_rmw_qos_profile());
+
+  hsv_config_.green_min_h = node_ptr_->declare_parameter("green_min_h", 50);
+  hsv_config_.green_min_s = node_ptr_->declare_parameter("green_min_s", 100);
+  hsv_config_.green_min_v = node_ptr_->declare_parameter("green_min_v", 150);
+  hsv_config_.green_max_h = node_ptr_->declare_parameter("green_max_h", 120);
+  hsv_config_.green_max_s = node_ptr_->declare_parameter("green_max_s", 200);
+  hsv_config_.green_max_v = node_ptr_->declare_parameter("green_max_v", 255);
+  hsv_config_.yellow_min_h = node_ptr_->declare_parameter("yellow_min_h", 0);
+  hsv_config_.yellow_min_s = node_ptr_->declare_parameter("yellow_min_s", 80);
+  hsv_config_.yellow_min_v = node_ptr_->declare_parameter("yellow_min_v", 150);
+  hsv_config_.yellow_max_h = node_ptr_->declare_parameter("yellow_max_h", 50);
+  hsv_config_.yellow_max_s = node_ptr_->declare_parameter("yellow_max_s", 200);
+  hsv_config_.yellow_max_v = node_ptr_->declare_parameter("yellow_max_v", 255);
+  hsv_config_.red_min_h = node_ptr_->declare_parameter("red_min_h", 160);
+  hsv_config_.red_min_s = node_ptr_->declare_parameter("red_min_s", 100);
+  hsv_config_.red_min_v = node_ptr_->declare_parameter("red_min_v", 150);
+  hsv_config_.red_max_h = node_ptr_->declare_parameter("red_max_h", 180);
+  hsv_config_.red_max_s = node_ptr_->declare_parameter("red_max_s", 255);
+  hsv_config_.red_max_v = node_ptr_->declare_parameter("red_max_v", 255);
+
+  // set parameter callback
+  set_param_res_ = node_ptr_->add_on_set_parameters_callback(
+    std::bind(&ColorClassifier::parametersCallback, this, _1));
+}
+
+bool ColorClassifier::getTrafficSignal(
+  const cv::Mat & input_image, autoware_auto_perception_msgs::msg::TrafficSignal & traffic_signal)
+{
+  cv::Mat green_image;
+  cv::Mat yellow_image;
+  cv::Mat red_image;
+  filterHSV(input_image, green_image, yellow_image, red_image);
+  // binarize
+  cv::Mat green_bin_image;
+  cv::Mat yellow_bin_image;
+  cv::Mat red_bin_image;
+  const int bin_threshold = 127;
+  cv::threshold(green_image, green_bin_image, bin_threshold, 255, cv::THRESH_BINARY);
+  cv::threshold(yellow_image, yellow_bin_image, bin_threshold, 255, cv::THRESH_BINARY);
+  cv::threshold(red_image, red_bin_image, bin_threshold, 255, cv::THRESH_BINARY);
+  // filter noise
+  cv::Mat green_filtered_bin_image;
+  cv::Mat yellow_filtered_bin_image;
+  cv::Mat red_filtered_bin_image;
+  cv::Mat element4 = (cv::Mat_<uchar>(3, 3) << 0, 1, 0, 1, 1, 1, 0, 1, 0);
+  cv::erode(green_bin_image, green_filtered_bin_image, element4, cv::Point(-1, -1), 1);
+  cv::erode(yellow_bin_image, yellow_filtered_bin_image, element4, cv::Point(-1, -1), 1);
+  cv::erode(red_bin_image, red_filtered_bin_image, element4, cv::Point(-1, -1), 1);
+  cv::dilate(green_filtered_bin_image, green_filtered_bin_image, cv::Mat(), cv::Point(-1, -1), 1);
+  cv::dilate(yellow_filtered_bin_image, yellow_filtered_bin_image, cv::Mat(), cv::Point(-1, -1), 1);
+  cv::dilate(red_filtered_bin_image, red_filtered_bin_image, cv::Mat(), cv::Point(-1, -1), 1);
+
+  /* debug */
+#if 1
+  if (0 < image_pub_.getNumSubscribers()) {
+    cv::Mat debug_raw_image;
+    cv::Mat debug_green_image;
+    cv::Mat debug_yellow_image;
+    cv::Mat debug_red_image;
+    cv::hconcat(input_image, input_image, debug_raw_image);
+    cv::hconcat(debug_raw_image, input_image, debug_raw_image);
+    cv::hconcat(green_image, green_bin_image, debug_green_image);
+    cv::hconcat(debug_green_image, green_filtered_bin_image, debug_green_image);
+    cv::hconcat(yellow_image, yellow_bin_image, debug_yellow_image);
+    cv::hconcat(debug_yellow_image, yellow_filtered_bin_image, debug_yellow_image);
+    cv::hconcat(red_image, red_bin_image, debug_red_image);
+    cv::hconcat(debug_red_image, red_filtered_bin_image, debug_red_image);
+
+    cv::Mat debug_image;
+    cv::vconcat(debug_green_image, debug_yellow_image, debug_image);
+    cv::vconcat(debug_image, debug_red_image, debug_image);
+    cv::cvtColor(debug_image, debug_image, cv::COLOR_GRAY2RGB);
+    cv::vconcat(debug_raw_image, debug_image, debug_image);
+    const int width = input_image.cols;
+    const int height = input_image.rows;
+    cv::line(
+      debug_image, cv::Point(0, 0), cv::Point(debug_image.cols, 0), cv::Scalar(255, 255, 255), 1,
+      CV_AA, 0);
+    cv::line(
+      debug_image, cv::Point(0, height), cv::Point(debug_image.cols, height),
+      cv::Scalar(255, 255, 255), 1, CV_AA, 0);
+    cv::line(
+      debug_image, cv::Point(0, height * 2), cv::Point(debug_image.cols, height * 2),
+      cv::Scalar(255, 255, 255), 1, CV_AA, 0);
+    cv::line(
+      debug_image, cv::Point(0, height * 3), cv::Point(debug_image.cols, height * 3),
+      cv::Scalar(255, 255, 255), 1, CV_AA, 0);
+
+    cv::line(
+      debug_image, cv::Point(0, 0), cv::Point(0, debug_image.rows), cv::Scalar(255, 255, 255), 1,
+      CV_AA, 0);
+    cv::line(
+      debug_image, cv::Point(width, 0), cv::Point(width, debug_image.rows),
+      cv::Scalar(255, 255, 255), 1, CV_AA, 0);
+    cv::line(
+      debug_image, cv::Point(width * 2, 0), cv::Point(width * 2, debug_image.rows),
+      cv::Scalar(255, 255, 255), 1, CV_AA, 0);
+    cv::line(
+      debug_image, cv::Point(width * 3, 0), cv::Point(width * 3, debug_image.rows),
+      cv::Scalar(255, 255, 255), 1, CV_AA, 0);
+
+    cv::putText(
+      debug_image, "green", cv::Point(0, height * 1.5), cv::FONT_HERSHEY_SIMPLEX, 1.0,
+      cv::Scalar(255, 255, 255), 1, CV_AA);
+    cv::putText(
+      debug_image, "yellow", cv::Point(0, height * 2.5), cv::FONT_HERSHEY_SIMPLEX, 1.0,
+      cv::Scalar(255, 255, 255), 1, CV_AA);
+    cv::putText(
+      debug_image, "red", cv::Point(0, height * 3.5), cv::FONT_HERSHEY_SIMPLEX, 1.0,
+      cv::Scalar(255, 255, 255), 1, CV_AA);
+    const auto debug_image_msg =
+      cv_bridge::CvImage(std_msgs::msg::Header(), "bgr8", debug_image).toImageMsg();
+    image_pub_.publish(debug_image_msg);
+  }
+#endif
+  /* --- */
+
+  const int green_pixel_num = cv::countNonZero(green_filtered_bin_image);
+  const int yellow_pixel_num = cv::countNonZero(yellow_filtered_bin_image);
+  const int red_pixel_num = cv::countNonZero(red_filtered_bin_image);
+  const double green_ratio =
+    static_cast<double>(green_pixel_num) /
+    static_cast<double>(green_filtered_bin_image.rows * green_filtered_bin_image.cols);
+  const double yellow_ratio =
+    static_cast<double>(yellow_pixel_num) /
+    static_cast<double>(yellow_filtered_bin_image.rows * yellow_filtered_bin_image.cols);
+  const double red_ratio =
+    static_cast<double>(red_pixel_num) /
+    static_cast<double>(red_filtered_bin_image.rows * red_filtered_bin_image.cols);
+
+  if (yellow_ratio < green_ratio && red_ratio < green_ratio) {
+    autoware_auto_perception_msgs::msg::TrafficLight light;
+    light.color = autoware_auto_perception_msgs::msg::TrafficLight::GREEN;
+    light.confidence = std::min(1.0, static_cast<double>(green_pixel_num) / (20.0 * 20.0));
+    traffic_signal.lights.push_back(light);
+  } else if (green_ratio < yellow_ratio && red_ratio < yellow_ratio) {
+    autoware_auto_perception_msgs::msg::TrafficLight light;
+    light.color = autoware_auto_perception_msgs::msg::TrafficLight::AMBER;
+    light.confidence = std::min(1.0, static_cast<double>(yellow_pixel_num) / (20.0 * 20.0));
+    traffic_signal.lights.push_back(light);
+  } else if (green_ratio < red_ratio && yellow_ratio < red_ratio) {
+    autoware_auto_perception_msgs::msg::TrafficLight light;
+    light.color = ::autoware_auto_perception_msgs::msg::TrafficLight::RED;
+    light.confidence = std::min(1.0, static_cast<double>(red_pixel_num) / (20.0 * 20.0));
+    traffic_signal.lights.push_back(light);
+  } else {
+    autoware_auto_perception_msgs::msg::TrafficLight light;
+    light.color = ::autoware_auto_perception_msgs::msg::TrafficLight::UNKNOWN;
+    light.confidence = 0.0;
+    traffic_signal.lights.push_back(light);
+  }
+  return true;
+}
+
+bool ColorClassifier::filterHSV(
+  const cv::Mat & input_image, cv::Mat & green_image, cv::Mat & yellow_image, cv::Mat & red_image)
+{
+  cv::Mat hsv_image;
+  cv::cvtColor(input_image, hsv_image, cv::COLOR_BGR2HSV);
+  try {
+    cv::inRange(hsv_image, min_hsv_green_, max_hsv_green_, green_image);
+    cv::inRange(hsv_image, min_hsv_yellow_, max_hsv_yellow_, yellow_image);
+    cv::inRange(hsv_image, min_hsv_red_, max_hsv_red_, red_image);
+  } catch (cv::Exception & e) {
+    RCLCPP_ERROR(node_ptr_->get_logger(), "failed to filter image by hsv value : %s", e.what());
+    return false;
+  }
+  return true;
+}
+rcl_interfaces::msg::SetParametersResult ColorClassifier::parametersCallback(
+  const std::vector<rclcpp::Parameter> & parameters)
+{
+  auto update_param = [&](const std::string & name, int & v) {
+    auto it = std::find_if(
+      parameters.cbegin(), parameters.cend(),
+      [&name](const rclcpp::Parameter & parameter) { return parameter.get_name() == name; });
+    if (it != parameters.cend()) {
+      v = it->as_int();
+      return true;
+    }
+    return false;
+  };
+
+  update_param("green_min_h", hsv_config_.green_min_h);
+  update_param("green_min_s", hsv_config_.green_min_s);
+  update_param("green_min_v", hsv_config_.green_min_v);
+  update_param("green_max_h", hsv_config_.green_max_h);
+  update_param("green_max_s", hsv_config_.green_max_s);
+  update_param("green_max_v", hsv_config_.green_max_v);
+  update_param("yellow_min_h", hsv_config_.yellow_min_h);
+  update_param("yellow_min_s", hsv_config_.yellow_min_s);
+  update_param("yellow_min_v", hsv_config_.yellow_min_v);
+  update_param("yellow_max_h", hsv_config_.yellow_max_h);
+  update_param("yellow_max_s", hsv_config_.yellow_max_s);
+  update_param("yellow_max_v", hsv_config_.yellow_max_v);
+  update_param("red_min_h", hsv_config_.red_min_h);
+  update_param("red_min_s", hsv_config_.red_min_s);
+  update_param("red_min_v", hsv_config_.red_min_v);
+  update_param("red_max_h", hsv_config_.red_max_h);
+  update_param("red_max_s", hsv_config_.red_max_s);
+  update_param("red_max_v", hsv_config_.red_max_v);
+
+  min_hsv_green_ =
+    cv::Scalar(hsv_config_.green_min_h, hsv_config_.green_min_s, hsv_config_.green_min_v);
+  max_hsv_green_ =
+    cv::Scalar(hsv_config_.green_max_h, hsv_config_.green_max_s, hsv_config_.green_max_v);
+  min_hsv_yellow_ =
+    cv::Scalar(hsv_config_.yellow_min_h, hsv_config_.yellow_min_s, hsv_config_.yellow_min_v);
+  max_hsv_yellow_ =
+    cv::Scalar(hsv_config_.yellow_max_h, hsv_config_.yellow_max_s, hsv_config_.yellow_max_v);
+  min_hsv_red_ = cv::Scalar(hsv_config_.red_min_h, hsv_config_.red_min_s, hsv_config_.red_min_v);
+  max_hsv_red_ = cv::Scalar(hsv_config_.red_max_h, hsv_config_.red_max_s, hsv_config_.red_max_v);
+
+  rcl_interfaces::msg::SetParametersResult result;
+  result.successful = true;
+  result.reason = "success";
+  return result;
+}
+
+}  // namespace traffic_light

+ 116 - 0
src/detection/detection_trafficlight_classify/src/nodelet.cpp

@@ -0,0 +1,116 @@
+// Copyright 2020 Tier IV, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#include "traffic_light_classifier/nodelet.hpp"
+
+#include <iostream>
+#include <memory>
+#include <utility>
+#include <vector>
+
+namespace traffic_light
+{
+TrafficLightClassifierNodelet::TrafficLightClassifierNodelet(const rclcpp::NodeOptions & options)
+: Node("traffic_light_classifier_node", options)
+{
+  using std::placeholders::_1;
+  using std::placeholders::_2;
+  is_approximate_sync_ = this->declare_parameter("approximate_sync", false);
+  if (is_approximate_sync_) {
+    approximate_sync_.reset(new ApproximateSync(ApproximateSyncPolicy(10), image_sub_, roi_sub_));
+    approximate_sync_->registerCallback(
+      std::bind(&TrafficLightClassifierNodelet::imageRoiCallback, this, _1, _2));
+  } else {
+    sync_.reset(new Sync(SyncPolicy(10), image_sub_, roi_sub_));
+    sync_->registerCallback(
+      std::bind(&TrafficLightClassifierNodelet::imageRoiCallback, this, _1, _2));
+  }
+
+  traffic_signal_array_pub_ =
+    this->create_publisher<autoware_auto_perception_msgs::msg::TrafficSignalArray>(
+      "~/output/traffic_signals", rclcpp::QoS{1});
+
+  using std::chrono_literals::operator""ms;
+  timer_ = rclcpp::create_timer(
+    this, get_clock(), 100ms, std::bind(&TrafficLightClassifierNodelet::connectCb, this));
+
+  int classifier_type = this->declare_parameter(
+    "classifier_type", static_cast<int>(TrafficLightClassifierNodelet::ClassifierType::HSVFilter));
+  if (classifier_type == TrafficLightClassifierNodelet::ClassifierType::HSVFilter) {
+    classifier_ptr_ = std::make_shared<ColorClassifier>(this);
+  } else if (classifier_type == TrafficLightClassifierNodelet::ClassifierType::CNN) {
+#if ENABLE_GPU
+    classifier_ptr_ = std::make_shared<CNNClassifier>(this);
+#else
+    RCLCPP_ERROR(
+      this->get_logger(), "please install CUDA, CUDNN and TensorRT to use cnn classifier");
+#endif
+  }
+}
+
+void TrafficLightClassifierNodelet::connectCb()
+{
+  // set callbacks only when there are subscribers to this node
+  if (
+    traffic_signal_array_pub_->get_subscription_count() == 0 &&
+    traffic_signal_array_pub_->get_intra_process_subscription_count() == 0) {
+    image_sub_.unsubscribe();
+    roi_sub_.unsubscribe();
+  } else if (!image_sub_.getSubscriber()) {
+    image_sub_.subscribe(this, "~/input/image", "raw", rmw_qos_profile_sensor_data);
+    roi_sub_.subscribe(this, "~/input/rois", rclcpp::QoS{1}.get_rmw_qos_profile());
+  }
+}
+
+void TrafficLightClassifierNodelet::imageRoiCallback(
+  const sensor_msgs::msg::Image::ConstSharedPtr & input_image_msg,
+  const autoware_auto_perception_msgs::msg::TrafficLightRoiArray::ConstSharedPtr & input_rois_msg)
+{
+  if (classifier_ptr_.use_count() == 0) {
+    return;
+  }
+
+  cv_bridge::CvImagePtr cv_ptr;
+  try {
+    cv_ptr = cv_bridge::toCvCopy(input_image_msg, sensor_msgs::image_encodings::RGB8);
+  } catch (cv_bridge::Exception & e) {
+    RCLCPP_ERROR(
+      this->get_logger(), "Could not convert from '%s' to 'rgb8'.",
+      input_image_msg->encoding.c_str());
+  }
+
+  autoware_auto_perception_msgs::msg::TrafficSignalArray output_msg;
+
+  for (size_t i = 0; i < input_rois_msg->rois.size(); ++i) {
+    const sensor_msgs::msg::RegionOfInterest & roi = input_rois_msg->rois.at(i).roi;
+    cv::Mat clipped_image(
+      cv_ptr->image, cv::Rect(roi.x_offset, roi.y_offset, roi.width, roi.height));
+
+    autoware_auto_perception_msgs::msg::TrafficSignal traffic_signal;
+    traffic_signal.map_primitive_id = input_rois_msg->rois.at(i).id;
+    if (!classifier_ptr_->getTrafficSignal(clipped_image, traffic_signal)) {
+      RCLCPP_ERROR(this->get_logger(), "failed classify image, abort callback");
+      return;
+    }
+    output_msg.signals.push_back(traffic_signal);
+  }
+
+  output_msg.header = input_image_msg->header;
+  traffic_signal_array_pub_->publish(output_msg);
+}
+
+}  // namespace traffic_light
+
+#include <rclcpp_components/register_node_macro.hpp>
+
+RCLCPP_COMPONENTS_REGISTER_NODE(traffic_light::TrafficLightClassifierNodelet)

+ 164 - 0
src/detection/detection_trafficlight_classify/src/single_image_debug_inference_node.cpp

@@ -0,0 +1,164 @@
+// Copyright 2023 Tier IV, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <rclcpp/rclcpp.hpp>
+
+#if ENABLE_GPU
+#include <traffic_light_classifier/cnn_classifier.hpp>
+#endif
+
+#include <traffic_light_classifier/color_classifier.hpp>
+#include <traffic_light_classifier/nodelet.hpp>
+
+#include <memory>
+#include <string>
+
+namespace
+{
+std::string toString(const uint8_t state)
+{
+  if (state == autoware_auto_perception_msgs::msg::TrafficLight::RED) {
+    return "red";
+  } else if (state == autoware_auto_perception_msgs::msg::TrafficLight::AMBER) {
+    return "yellow";
+  } else if (state == autoware_auto_perception_msgs::msg::TrafficLight::GREEN) {
+    return "green";
+  } else if (state == autoware_auto_perception_msgs::msg::TrafficLight::WHITE) {
+    return "white";
+  } else if (state == autoware_auto_perception_msgs::msg::TrafficLight::CIRCLE) {
+    return "circle";
+  } else if (state == autoware_auto_perception_msgs::msg::TrafficLight::LEFT_ARROW) {
+    return "left";
+  } else if (state == autoware_auto_perception_msgs::msg::TrafficLight::RIGHT_ARROW) {
+    return "right";
+  } else if (state == autoware_auto_perception_msgs::msg::TrafficLight::UP_ARROW) {
+    return "straight";
+  } else if (state == autoware_auto_perception_msgs::msg::TrafficLight::DOWN_ARROW) {
+    return "down";
+  } else if (state == autoware_auto_perception_msgs::msg::TrafficLight::DOWN_LEFT_ARROW) {
+    return "down_left";
+  } else if (state == autoware_auto_perception_msgs::msg::TrafficLight::DOWN_RIGHT_ARROW) {
+    return "down_right";
+  } else if (state == autoware_auto_perception_msgs::msg::TrafficLight::CROSS) {
+    return "cross";
+  } else if (state == autoware_auto_perception_msgs::msg::TrafficLight::UNKNOWN) {
+    return "unknown";
+  } else {
+    return "";
+  }
+}
+}  // namespace
+
+namespace traffic_light
+{
+class SingleImageDebugInferenceNode : public rclcpp::Node
+{
+public:
+  explicit SingleImageDebugInferenceNode(const rclcpp::NodeOptions & node_options)
+  : Node("single_image_debug_inference", node_options)
+  {
+    const auto image_path = declare_parameter("image_path", "");
+
+    int classifier_type = this->declare_parameter(
+      "classifier_type",
+      static_cast<int>(TrafficLightClassifierNodelet::ClassifierType::HSVFilter));
+    if (classifier_type == TrafficLightClassifierNodelet::ClassifierType::HSVFilter) {
+      classifier_ptr_ = std::make_unique<ColorClassifier>(this);
+    } else if (classifier_type == TrafficLightClassifierNodelet::ClassifierType::CNN) {
+#if ENABLE_GPU
+      classifier_ptr_ = std::make_unique<CNNClassifier>(this);
+#else
+      RCLCPP_ERROR(get_logger(), "please install CUDA, CUDNN and TensorRT to use cnn classifier");
+#endif
+    }
+
+    image_ = cv::imread(image_path);
+    if (image_.empty()) {
+      RCLCPP_ERROR(get_logger(), "image is empty");
+      return;
+    }
+    cv::namedWindow("inference image", cv::WINDOW_NORMAL);
+    cv::setMouseCallback("inference image", SingleImageDebugInferenceNode::onMouse, this);
+
+    cv::imshow("inference image", image_);
+
+    // loop until q character is pressed
+    while (cv::waitKey(0) != 113) {
+    }
+    cv::destroyAllWindows();
+    rclcpp::shutdown();
+  }
+
+private:
+  static void onMouse(int event, int x, int y, int flags, void * param)
+  {
+    SingleImageDebugInferenceNode * node = static_cast<SingleImageDebugInferenceNode *>(param);
+    if (node) {
+      node->inferWithCrop(event, x, y, flags);
+    }
+  }
+
+  void inferWithCrop(int action, int x, int y, [[maybe_unused]] int flags)
+  {
+    if (action == cv::EVENT_LBUTTONDOWN) {
+      top_left_corner_ = cv::Point(x, y);
+    } else if (action == cv::EVENT_LBUTTONUP) {
+      bottom_right_corner_ = cv::Point(x, y);
+      cv::Mat tmp = image_.clone();
+      cv::Mat crop = image_(cv::Rect{top_left_corner_, bottom_right_corner_}).clone();
+      if (crop.empty()) {
+        RCLCPP_ERROR(get_logger(), "crop image is empty");
+        return;
+      }
+      cv::cvtColor(crop, crop, cv::COLOR_BGR2RGB);
+      autoware_auto_perception_msgs::msg::TrafficSignal traffic_signal;
+      if (!classifier_ptr_->getTrafficSignal(crop, traffic_signal)) {
+        RCLCPP_ERROR(get_logger(), "failed to classify image");
+        return;
+      }
+      cv::Scalar color;
+      cv::Scalar text_color;
+      for (const auto & light : traffic_signal.lights) {
+        auto color_str = toString(light.color);
+        auto shape_str = toString(light.shape);
+        auto confidence_str = std::to_string(light.confidence);
+        if (shape_str == "circle") {
+          if (color_str == "red") {
+            color = cv::Scalar(0, 0, 255);
+          } else if (color_str == "green") {
+            color = cv::Scalar(0, 255, 0);
+          } else if (color_str == "yellow") {
+            color = cv::Scalar(0, 255, 255);
+          } else if (color_str == "white") {
+            color = cv::Scalar(0, 0, 0);
+          } else {
+            color = cv::Scalar(255, 255, 255);
+          }
+        }
+        RCLCPP_INFO_STREAM(get_logger(), color_str << " " << shape_str << " " << confidence_str);
+      }
+      cv::rectangle(tmp, top_left_corner_, bottom_right_corner_, color, 2, 8);
+      cv::imshow("inference image", tmp);
+    }
+  }
+
+  cv::Point top_left_corner_;
+  cv::Point bottom_right_corner_;
+  cv::Mat image_;
+  std::unique_ptr<ClassifierInterface> classifier_ptr_;
+};
+}  // namespace traffic_light
+
+#include "rclcpp_components/register_node_macro.hpp"
+RCLCPP_COMPONENTS_REGISTER_NODE(traffic_light::SingleImageDebugInferenceNode)

+ 167 - 0
src/detection/detection_trafficlight_classify/utils/trt_common.cpp

@@ -0,0 +1,167 @@
+// Copyright 2020 Tier IV, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <trt_common.hpp>
+
+#if (defined(_MSC_VER) or (defined(__GNUC__) and (7 <= __GNUC_MAJOR__)))
+#include <filesystem>
+namespace fs = ::std::filesystem;
+#else
+#include <experimental/filesystem>
+namespace fs = ::std::experimental::filesystem;
+#endif
+
+#include <functional>
+#include <string>
+
+namespace Tn
+{
+void check_error(const ::cudaError_t e, decltype(__FILE__) f, decltype(__LINE__) n)
+{
+  if (e != ::cudaSuccess) {
+    std::stringstream s;
+    s << ::cudaGetErrorName(e) << " (" << e << ")@" << f << "#L" << n << ": "
+      << ::cudaGetErrorString(e);
+    throw std::runtime_error{s.str()};
+  }
+}
+
+TrtCommon::TrtCommon(
+  std::string model_path, std::string precision, std::string input_name, std::string output_name)
+: model_file_path_(model_path),
+  precision_(precision),
+  input_name_(input_name),
+  output_name_(output_name),
+  is_initialized_(false)
+{
+  runtime_ = UniquePtr<nvinfer1::IRuntime>(nvinfer1::createInferRuntime(logger_));
+}
+
+void TrtCommon::setup()
+{
+  const fs::path path(model_file_path_);
+  std::string extension = path.extension().string();
+
+  if (fs::exists(path)) {
+    if (extension == ".engine") {
+      loadEngine(model_file_path_);
+    } else if (extension == ".onnx") {
+      fs::path cache_engine_path{model_file_path_};
+      cache_engine_path.replace_extension("engine");
+      if (fs::exists(cache_engine_path)) {
+        loadEngine(cache_engine_path.string());
+      } else {
+        logger_.log(nvinfer1::ILogger::Severity::kINFO, "start build engine");
+        buildEngineFromOnnx(model_file_path_, cache_engine_path.string());
+        logger_.log(nvinfer1::ILogger::Severity::kINFO, "end build engine");
+      }
+    } else {
+      is_initialized_ = false;
+      return;
+    }
+  } else {
+    is_initialized_ = false;
+    return;
+  }
+
+  context_ = UniquePtr<nvinfer1::IExecutionContext>(engine_->createExecutionContext());
+
+#if (NV_TENSORRT_MAJOR * 10000) + (NV_TENSORRT_MINOR * 100) + NV_TENSOR_PATCH >= 80500
+  input_dims_ = engine_->getTensorShape(input_name_.c_str());
+  output_dims_ = engine_->getTensorShape(output_name_.c_str());
+#else
+  // Deprecated since 8.5
+  input_dims_ = engine_->getBindingDimensions(engine_->getBindingIndex(input_name_.c_str()));
+  output_dims_ = engine_->getBindingDimensions(engine_->getBindingIndex(output_name_.c_str()));
+#endif
+
+  is_initialized_ = true;
+}
+
+bool TrtCommon::loadEngine(std::string engine_file_path)
+{
+  std::ifstream engine_file(engine_file_path);
+  std::stringstream engine_buffer;
+  engine_buffer << engine_file.rdbuf();
+  std::string engine_str = engine_buffer.str();
+  engine_ = UniquePtr<nvinfer1::ICudaEngine>(runtime_->deserializeCudaEngine(
+    reinterpret_cast<const void *>(engine_str.data()), engine_str.size()));
+  return true;
+}
+
+bool TrtCommon::buildEngineFromOnnx(std::string onnx_file_path, std::string output_engine_file_path)
+{
+  auto builder = UniquePtr<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(logger_));
+  const auto explicitBatch =
+    1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
+  auto network = UniquePtr<nvinfer1::INetworkDefinition>(builder->createNetworkV2(explicitBatch));
+  auto config = UniquePtr<nvinfer1::IBuilderConfig>(builder->createBuilderConfig());
+
+  auto parser = UniquePtr<nvonnxparser::IParser>(nvonnxparser::createParser(*network, logger_));
+  if (!parser->parseFromFile(
+        onnx_file_path.c_str(), static_cast<int>(nvinfer1::ILogger::Severity::kERROR))) {
+    return false;
+  }
+
+#if (NV_TENSORRT_MAJOR * 1000) + (NV_TENSORRT_MINOR * 100) + NV_TENSOR_PATCH >= 8400
+  config->setMemoryPoolLimit(nvinfer1::MemoryPoolType::kWORKSPACE, 16 << 20);
+#else
+  config->setMaxWorkspaceSize(16 << 20);
+#endif
+
+  if (precision_ == "fp16") {
+    config->setFlag(nvinfer1::BuilderFlag::kFP16);
+  } else if (precision_ == "int8") {
+    config->setFlag(nvinfer1::BuilderFlag::kINT8);
+  } else {
+    return false;
+  }
+
+  auto plan = UniquePtr<nvinfer1::IHostMemory>(builder->buildSerializedNetwork(*network, *config));
+  if (!plan) {
+    return false;
+  }
+  engine_ =
+    UniquePtr<nvinfer1::ICudaEngine>(runtime_->deserializeCudaEngine(plan->data(), plan->size()));
+  if (!engine_) {
+    return false;
+  }
+
+  // save engine
+  std::ofstream file;
+  file.open(output_engine_file_path, std::ios::binary | std::ios::out);
+  if (!file.is_open()) {
+    return false;
+  }
+  file.write((const char *)plan->data(), plan->size());
+  file.close();
+
+  return true;
+}
+
+bool TrtCommon::isInitialized() { return is_initialized_; }
+
+int TrtCommon::getNumInput()
+{
+  return std::accumulate(
+    input_dims_.d, input_dims_.d + input_dims_.nbDims, 1, std::multiplies<int>());
+}
+
+int TrtCommon::getNumOutput()
+{
+  return std::accumulate(
+    output_dims_.d, output_dims_.d + output_dims_.nbDims, 1, std::multiplies<int>());
+}
+
+}  // namespace Tn

+ 142 - 0
src/detection/detection_trafficlight_classify/utils/trt_common.hpp

@@ -0,0 +1,142 @@
+// Copyright 2020 Tier IV, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef PERCEPTION__TRAFFIC_LIGHT_CLASSIFIER__UTILS__TRT_COMMON_HPP_
+#define PERCEPTION__TRAFFIC_LIGHT_CLASSIFIER__UTILS__TRT_COMMON_HPP_
+
+#include <opencv2/core/core.hpp>
+#include <opencv2/highgui/highgui.hpp>
+
+#include <./cudnn.h>
+#include <NvInfer.h>
+#include <NvOnnxParser.h>
+#include <stdio.h>
+
+#include <algorithm>
+#include <chrono>
+#include <fstream>
+#include <iostream>
+#include <memory>
+#include <numeric>
+#include <sstream>
+#include <string>
+
+#define CHECK_CUDA_ERROR(e) (Tn::check_error(e, __FILE__, __LINE__))
+
+namespace Tn
+{
+class Logger : public nvinfer1::ILogger
+{
+public:
+  Logger() : Logger(Severity::kINFO) {}
+
+  explicit Logger(Severity severity) : reportableSeverity(severity) {}
+
+  void log(Severity severity, const char * msg) noexcept override
+  {
+    // suppress messages with severity enum value greater than the reportable
+    if (severity > reportableSeverity) {
+      return;
+    }
+
+    switch (severity) {
+      case Severity::kINTERNAL_ERROR:
+        std::cerr << "[TRT_COMMON][INTERNAL_ERROR]: ";
+        break;
+      case Severity::kERROR:
+        std::cerr << "[TRT_COMMON][ERROR]: ";
+        break;
+      case Severity::kWARNING:
+        std::cerr << "[TRT_COMMON][WARNING]: ";
+        break;
+      case Severity::kINFO:
+        std::cerr << "[TRT_COMMON][INFO]: ";
+        break;
+      default:
+        std::cerr << "[TRT_COMMON][UNKNOWN]: ";
+        break;
+    }
+    std::cerr << msg << std::endl;
+  }
+
+  Severity reportableSeverity{Severity::kWARNING};
+};
+
+void check_error(const ::cudaError_t e, decltype(__FILE__) f, decltype(__LINE__) n);
+
+struct InferDeleter
+{
+  void operator()(void * p) const { ::cudaFree(p); }
+};
+
+template <typename T>
+using UniquePtr = std::unique_ptr<T, InferDeleter>;
+
+// auto array = Tn::make_unique<float[]>(n);
+// ::cudaMemcpy(array.get(), src_array, sizeof(float)*n, ::cudaMemcpyHostToDevice);
+template <typename T>
+typename std::enable_if<std::is_array<T>::value, Tn::UniquePtr<T>>::type make_unique(
+  const std::size_t n)
+{
+  using U = typename std::remove_extent<T>::type;
+  U * p;
+  ::cudaMalloc(reinterpret_cast<void **>(&p), sizeof(U) * n);
+  return Tn::UniquePtr<T>{p};
+}
+
+// auto value = Tn::make_unique<my_class>();
+// ::cudaMemcpy(value.get(), src_value, sizeof(my_class), ::cudaMemcpyHostToDevice);
+template <typename T>
+Tn::UniquePtr<T> make_unique()
+{
+  T * p;
+  ::cudaMalloc(reinterpret_cast<void **>(&p), sizeof(T));
+  return Tn::UniquePtr<T>{p};
+}
+
+class TrtCommon
+{
+public:
+  TrtCommon(
+    std::string model_path, std::string precision, std::string input_name, std::string output_name);
+  ~TrtCommon() {}
+
+  bool loadEngine(std::string engine_file_path);
+  bool buildEngineFromOnnx(std::string onnx_file_path, std::string output_engine_file_path);
+  void setup();
+
+  bool isInitialized();
+  int getNumInput();
+  int getNumOutput();
+
+  UniquePtr<nvinfer1::IExecutionContext> context_;
+
+private:
+  Logger logger_;
+  std::string model_file_path_;
+  UniquePtr<nvinfer1::IRuntime> runtime_;
+  UniquePtr<nvinfer1::ICudaEngine> engine_;
+
+  nvinfer1::Dims input_dims_;
+  nvinfer1::Dims output_dims_;
+  std::string cache_dir_;
+  std::string precision_;
+  std::string input_name_;
+  std::string output_name_;
+  bool is_initialized_;
+};
+
+}  // namespace Tn
+
+#endif  // PERCEPTION__TRAFFIC_LIGHT_CLASSIFIER__UTILS__TRT_COMMON_HPP_