// yololayer.h (4.6 KB)
  1. #ifndef _YOLO_LAYER_H
  2. #define _YOLO_LAYER_H
#include <assert.h>
#include <cmath>
#include <iostream>
#include <string>
#include <string.h>
#include <vector>

#include <cublas_v2.h>

#include "NvInfer.h"
#include "NvInferPlugin.h"
#include "Utils.h"
  // Description of one YOLO output layer.
  // Plain aggregate: field order/types are deliberately unchanged, since other
  // code may copy or serialize it byte-wise (NOTE(review): confirm in the .cu file).
  struct YoloKernel {
      int width, height;      // presumably this output layer's grid size — TODO confirm
      int everyYoloAnchors;   // presumably the number of anchors used by this layer
      // Anchor values for this layer. One YOLO output layer normally needs
      // 3 anchors * 2 values = 6 floats; the array is sized 10 for headroom,
      // and the extra capacity is harmless.
      float anchors[10];
  };
  // One detection result. Every member is a float and the struct is aligned
  // like a float, so it is laid out as 7 consecutive floats with no padding.
  struct alignas(alignof(float)) Detection {
      float bbox[4];            // box as x, y, w, h
      float det_confidence;     // detection (objectness) confidence
      float class_id;           // class index, stored as a float
      float class_confidence;   // confidence for class_id
  };
  25. namespace nvinfer1
  26. {
  27. class YoloLayerPlugin: public IPluginV2IOExt
  28. {
  29. public:
  30. YoloLayerPlugin(const PluginFieldCollection& fc);
  31. YoloLayerPlugin(const void* data, size_t length);
  32. ~YoloLayerPlugin();
  33. int getNbOutputs() const override
  34. {
  35. return 1;
  36. }
  37. Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) override;
  38. int initialize() override;
  39. virtual void terminate() override {};
  40. virtual size_t getWorkspaceSize(int maxBatchSize) const override { return 0;}
  41. virtual int enqueue(int batchSize, const void*const * inputs, void** outputs, void* workspace, cudaStream_t stream) override;
  42. virtual size_t getSerializationSize() const override;
  43. virtual void serialize(void* buffer) const override;
  44. bool supportsFormatCombination(int pos, const PluginTensorDesc* inOut, int nbInputs, int nbOutputs) const override {
  45. return inOut[pos].format == TensorFormat::kLINEAR && inOut[pos].type == DataType::kFLOAT;
  46. }
  47. const char* getPluginType() const override;
  48. const char* getPluginVersion() const override;
  49. void destroy() override;
  50. IPluginV2IOExt* clone() const override;
  51. void setPluginNamespace(const char* pluginNamespace) override;
  52. const char* getPluginNamespace() const override;
  53. DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const override;
  54. bool isOutputBroadcastAcrossBatch(int outputIndex, const bool* inputIsBroadcasted, int nbInputs) const override;
  55. bool canBroadcastInputAcrossBatch(int inputIndex) const override;
  56. void attachToContext(
  57. cudnnContext* cudnnContext, cublasContext* cublasContext, IGpuAllocator* gpuAllocator) override;
  58. void configurePlugin(const PluginTensorDesc* in, int nbInput, const PluginTensorDesc* out, int nbOutput) override;
  59. void detachFromContext() override;
  60. private:
  61. void forwardGpu(const float *const * inputs,float * output, cudaStream_t stream,int batchSize = 1);
  62. int mClassCount; // 检测的目标的类别,从cfg文件获取,在cfg 设置
  63. int mInput_w; // 图像输入的尺寸,从cfg获取
  64. int mInput_h; // 由于umsample层的原因,宽度和高度要想等,TODO 调整
  65. int mNumYoloLayers; // yolo输出层的数量,从cfg获取,无需设置
  66. std::vector<YoloKernel> mYoloKernel;
  67. float mIgnore_thresh = 0.4; // 置信度阈值,可以调整
  68. int max_output_box = 1000; // 最大输出数量
  69. int mThreadCount = 256; // cuda 内核函数,每一block中线程数量
  70. const char* mPluginNamespace; // 该插件名称
  71. };
  72. // 继承与IPluginCreator,重写虚函数
  73. class YoloPluginCreator : public IPluginCreator
  74. {
  75. public:
  76. YoloPluginCreator();
  77. ~YoloPluginCreator() override = default;
  78. const char* getPluginName() const override;
  79. const char* getPluginVersion() const override;
  80. const PluginFieldCollection* getFieldNames() override;
  81. // 生成插件,这个是在 build network时调用
  82. IPluginV2IOExt* createPlugin(const char* name, const PluginFieldCollection* fc) override;
  83. // 反序列化,在读取保存的trt模型engine时调用,负责解析插件
  84. IPluginV2IOExt* deserializePlugin(const char* name, const void* serialData, size_t serialLength) override;
  85. void setPluginNamespace(const char* libNamespace) override{
  86. mNamespace = libNamespace;
  87. }
  88. const char* getPluginNamespace() const override{
  89. return mNamespace.c_str();
  90. }
  91. private:
  92. std::string mNamespace;
  93. };
  94. };
  95. #endif