avatar

晓安哥

A text-focused Halo theme

  • 首页
  • 高性能计算
  • 关于
主页 TensorRT ---- 使用自定义layer扩展 TensorRT
文章

TensorRT ---- 使用自定义layer扩展 TensorRT

发表于 2025-11-29 更新于 2025-11- 29
作者 Administrator
28~36 分钟 阅读

TensorRT支持众多layer,并且功能还在进行持续扩展;然而,可能存在支持的layer无法满足模型特定需求的情况。此时,可通过实现自定义层(通常称为插件)来扩展TensorRT。

TensorRT包含可加载至应用程序的标准插件。如需查看开源插件列表,请参阅GitHub:TensorRT插件。

要在应用程序中使用标准TensorRT插件,必须加载libnvinfer_plugin.so(Windows系统下为nvinfer_plugin.dll)库,并通过在应用程序代码中调用initLibNvInferPlugins函数完成所有插件的注册。有关这些插件的更多信息,请参阅NvInferPlugin.h文件。

如果这些插件还是无法满足需求,则可以自行编写并添加新的插件。

1.使用C++ API添加自定义layer

为了正确的识别到自定义插件,需要4步:

  1. 继承TensorRT的插件基类,目前唯一推荐的基类是IPluginV3

  2. 实现一个与自定义插件绑定的creator类,这个类要继承自TensorRT的插件creator基类,当前唯一推荐的基类是IPluginCreatorV3One

  3. 在TensorRT的插件注册表中注册插件创建者类的实例

  4. 通过直接使用TensorRT的网络API或使用TensorRT ONNX解析器API加载ONNX模型,将插件类的实例添加到TensorRT网络中

以下部分将详细探讨上面提到的每一个步骤

1.1 实现插件类

您可以通过从TensorRT的插件基类之一派生来实现自定义layer。从TensorRT 10.0开始,唯一推荐的插件接口是IPluginV3,因为其他接口已被弃用。因此,本节主要描述使用IPluginV3的插件实现。请参阅将V2插件迁移到IPluginV3一节,了解如何将实现V2插件接口的插件迁移到IPluginV3。

IPluginV3是一组功能接口的包装器,这些接口定义了三种功能:核心功能、构建和运行时

  • 核心能力:指插件生命周期的构建和运行时阶段共有的插件属性和行为。

  • 构建能力:指插件必须为TensorRT构建器展示的插件属性和行为。

  • 运行时能力:指插件运行时必须表现出的插件属性和行为,以便在TensorRT构建阶段进行自动调优或在TensorRT运行时阶段的推理执行。

IPluginV3OneCore(C++,Python)、IPluginV3OneBuild(C++,Python)和IPluginV3OneRuntime(C++,Python)是IPluginV3插件必须实现的基类,以分别显示核心、构建和运行时功能。如果需要I/O别名,可以使用IPluginV3OneBuildV2(C++、Python)作为构建功能的基类,其是包含IPluginV3OneBuild功能的超集

1.2 实现插件创建者类

要在网络中使用插件,您必须首先在TensorRT的PluginRegistry(C++、Python)中注册它。与其直接注册插件,不如为插件注册一个工厂类的实例,该实例派生自IPluginCreatorInterface(C++、Python)的子类。插件创建者类还提供有关插件的其他信息:名称、版本和插件字段参数。

IPluginCreatorV3One是IPluginV3的工厂类。IPluginCreatorV3One::createPlugin(),其签名如下

  • C++

IPluginV3* createPlugin(AsciiChar const *name, PluginFieldCollection const *fc, TensorRTPhase phase)
  • Python

create_plugin(self: trt.IPluginCreatorV3, name: str, field_collection: trt.PluginFieldCollection, phase: trt.TensorRTPhase) -> trt.IPluginV3

IPluginCreatorV3One::createPlugin()可以在TensorRT的构建阶段或TensorRT的运行时阶段调用,以创建插件实例,这是通过TensorRTPhase(C++,Python)类型的phase参数进行通信的

  • 返回的IPluginV3对象在每个阶段都必须具有有效的核心功能。

  • 在构建阶段,返回的IPluginV3对象必须同时具有构建和运行时功能。

  • 在运行时阶段,返回的IPluginV3对象必须具有运行时功能。构建功能不是必需的,会被忽略

1.3 在插件注册表中注册插件创建者

有两种方法可以在注册表中注册插件创建者:

  1. 通过调用register_TENSORRT_PLUGIN进行静态注册。REGISTER_TENSORRT_PLUGIN始终在默认命名空间(“”)下注册创建者。

  2. 通过创建类似于initLibNvInferPlugins的入口点并在插件注册表上调用registerCreator来动态注册。这比静态注册更可取,因为它允许插件在唯一的命名空间下注册。这确保了在构建过程中,不同插件库之间不会发生名称冲突。

在序列化过程中,TensorRT引擎在内部存储所有插件的插件名称、插件版本和命名空间(如果存在),以及IPluginV3OneRuntime::getFieldsToSerialize()返回的PluginFieldCollection中的任何插件字段。在反序列化过程中,TensorRT从插件注册表中查找具有相同插件名称、版本和命名空间的插件创建者,并在其上调用IPluginCreatorV3One::createPlugin()——序列化的PluginFieldCollection作为fc参数传回。

1.4 向TensorRT网络添加插件实例

您可以使用addPluginV3()将插件添加到TensorRT网络中,该方法使用给定的插件创建网络层。

例如,您可以按如下方式向网络添加插件层:

// Look up the plugin in the registry
// Cast to appropriate child class of IPluginCreatorInterface
auto creator = static_cast<IPluginCreatorV3One*>(getPluginRegistry()->getCreator(pluginName, pluginVersion, pluginNamespace));
PluginFieldCollection const* pluginFC = creator->getFieldNames();

// Populate the field parameters for the plugin layer
PluginFieldCollection *pluginData = parseAndFillFields(pluginFC, layerFields);

// Create the plugin object using the layerName and the plugin metadata for use by the TensorRT builder
IPluginV3 *pluginObj = creator->createPlugin(layerName, pluginData, TensorRTPhase::kBUILD);

// Add the plugin to the TensorRT network
auto layer = network.addPluginV3(inputs.data(), int(inputs.size()),  shapeInputs.data(), int(shapeInputs.size()), pluginObj);
//… (build rest of the network and serialize engine)

// Delete the plugin object
delete pluginObj;d
// … (free allocated pluginData)

前面描述的createPlugin方法在堆上创建一个新的插件对象并返回一个指针。如前所示,确保删除pluginObj以避免内存泄漏。

删除引擎后,引擎会销毁在构建过程中创建的插件对象的任何克隆。您有责任确保您创建的插件对象在添加到网络后被释放。

注意:

不要序列化所有插件参数,只序列化运行时正常运行所需的参数。构建时参数可以省略。

如果您是auto用户,则必须调用getSafePluginRegistry()而不是getPluginRegistry()。您还必须使用宏REGISTER_SAFE_TENSORRT_PLUGIN,而不是REGISTER_TENSORRT_PLUGIN。

1.5 示例:使用C++添加具有动态形状的自定义layer

考虑实现一个自定义layer实现类似填充的操作,其中输入batch中的每个图像都必须被重塑为32 x 32。输入张量X的形状为(B,C,H,W),输出张量Y的形状为(B, C, 32, 32)。为了实现这一点,可以使用IPluginV3接口编写TensorRT插件;称其为PadPlugin。

由于IPluginV3插件必须具备多种功能,每种功能都由单独的接口定义,因此您可以使用组合或多重继承的原则来实现插件。然而,对于大多数用例来说,多重继承方法更容易,特别是在需要将构建和运行时功能耦合到单个类中的情况下。

使用多重继承,PadPlugin可以实现如下

class PadPlugin : public IPluginV3, public IPluginV3OneCore, public IPluginV3OneBuild, public IPluginV3OneRuntime
{
    ...override inherited virtual methods.
};

IPluginV3::getCapabilityInterface的重写必须返回指向各个功能接口的指针。对于每个PluginCapabilityType,必须通过相应的能力接口进行强制转换,以消除编译器的歧义

IPluginCapability* PadPlugin::getCapabilityInterface(PluginCapabilityType type) noexcept override
{
    // All plugin interface methods are noexcept and care should be
    // taken not to throw exceptions across the API boundary. It is
    // recommended to catch any exceptions and return a value that
    // appropriately represents the error status.
    try
    {
        if (type == PluginCapabilityType::kBUILD)
        {
            return static_cast<IPluginV3OneBuild*>(this);
        }
        if (type == PluginCapabilityType::kRUNTIME)
        {
            return static_cast<IPluginV3OneRuntime*>(this);
        }
        ASSERT(type == PluginCapabilityType::kCORE);
        return static_cast<IPluginV3OneCore*>(this);
    }
    catch(...)
    {
        // log error
    }
    return nullptr;

}

在这个特定的例子中,重要的方法是:

  • INetworkDefinition::addPluginV3

  • IPluginV3OneBuild::getNbOutputs

  • IPluginV3OneBuild::getOutputDataTypes

  • IPluginV3OneBuild::getOutputShapes

  • IPluginV3OneBuild::supportsFormatCombination

  • IPluginV3OneBuild::configurePlugin

  • IPluginV3OneRuntime::onShapeChange

  • IPluginV3OneRuntime::enqueue

INetworkDefinition::addPluginV3 (C++, Python) 能够将插件添加到网络

std::vector<ITensor*> inputs{X};

auto pluginLayer = network->addPluginV3(inputs.data(), inputs.size(), nullptr, 0, *plugin);

您可以通过覆盖IPluginV3OneBuild::getNbOutputs来传达有一个插件输出

int32_t PadPlugin::getNbOutputs() const noexcept override
{
    return 1;
}

输出将具有与输入相同的数据类型,可以在覆盖PluginV3OneBuild::getOutputDataTypes进行表示


int32_t PadPlugin::getOutputDataTypes(
        DataType* outputTypes, int32_t nbOutputs, DataType const* inputTypes, int32_t nbInputs) const noexcept override
{
    outputTypes[0] = inputTypes[0];
    return 0;
}

getOutputShapes的重写根据输入维度返回输出维度的符号表达式,但数据相关的输出形状除外,这将在后面的示例中介绍:使用C++添加具有数据相关和形状输入相关形状的自定义层。在当前示例中,输出的前二维将分别等于输入的前二维,后二维将是常数,每个常数等于32。传递给getOutputShapes的IExprBuilder可用于定义常量符号表达式

int32_t PadPlugin::getOutputShapes(DimsExprs const* inputs, int32_t nbInputs, DimsExprs const* shapeInputs, int32_t nbShapeInputs, DimsExprs* outputs, int32_t nbOutputs, IExprBuilder& exprBuilder) noexcept
{
    outputs[0].nbDims = 4;
    // first two output dims are equal to the first two input dims
    outputs[0].d[0] = inputs[0].d[0];
    outputs[0].d[1] = inputs[0].d[1];
    // The last two output dims are equal to 32
    outputs[0].d[2] = exprBuilder.constant(32);
    outputs[0].d[3] = exprBuilder.constant(32);
    return 0;
}

TensorRT使用supportsFormatCombination来询问插件是否接受给定位置pos的连接的给定类型和格式组合,以及较小索引连接的给定格式/类型。该接口将输入/输出统一索引为连接,从第一个输入的0开始,然后按顺序对其余输入进行索引,然后对输出进行编号。在该示例中,输入为连接0,输出为连接1

为了简单起见,该示例仅支持线性格式和FP32类型

bool PadPlugin::supportsFormatCombination(
        int32_t pos, DynamicPluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) noexcept override
{
    assert(0 <= pos && pos < 2);
    return inOut[pos].desc.format == PluginFormat::kLINEAR && inOut[pos].desc.type == DataType::kFLOAT;
}

TensorRT调用两种方法,允许插件在入队()之前做出任何配置选择,无论是在自动调优期间(在引擎构建阶段)还是在执行引擎时(在运行时阶段)。

IPluginV3OneBuild::configurePlugin:当插件准备进行分析(自动调优)时调用,但不适用于任何特定的输入大小。DynamicPluginTensorDesc的min、max和opt值对应于张量形状及其自动调整形状的边界。desc.dims字段对应于网络创建时指定的插件的维度,包括动态维度的任何通配符(-1)。

IPluginV3OneRuntime::onShapeChange:在排队()之前的构建阶段和运行阶段调用,以传达后续排队()的输入和输出形状。输出PluginTensorDesc将包含通配符(-1),用于通过getOutputShapes()指定的任何数据相关维度

这个插件不需要configurePlugin和onShapeChange来做任何事情,所以它们是空操作

int32_t PadPlugin::configurePlugin(DynamicPluginTensorDesc const* in, int32_t nbInputs, DynamicPluginTensorDesc const* out, int32_t nbOutputs) noexcept override
{
    return 0;
}

int32_t PadPlugin::onShapeChange(PluginTensorDesc const* in, int32_t nbInputs, PluginTensorDesc const* out, int32_t nbOutputs) noexcept override
{
    return 0;
}

最后,重写必须的PadPlugin::enqueue。由于形状是动态的,enqueue会收到一个PluginTensorDesc,描述每个输入和输出的维度、类型和格式。

int32_t enqueue(PluginTensorDesc const* inputDesc, PluginTensorDesc const* outputDesc, void const* const* inputs,
        void* const* outputs, void* workspace, cudaStream_t stream) noexcept override
{
    // populate outputs and return status code
}

许可协议:  CC BY 4.0
分享

相关文章

下一篇

TensorRT 中的量化

上一篇

TensorRT------性能优化

最近更新

  • pytorch compile ------ backend详解
  • cuda编程 --------- warp级规约操作 __shfl_xor_sync
  • TensorRT ---- Myelin
  • TensorRT------性能优化
  • TensorRT ---- 使用自定义layer扩展 TensorRT

热门标签

Halo gpu hpc

目录

©2026 晓安哥. 保留部分权利。

使用 Halo 主题 Chirpy