TensorRT ---- 使用自定义layer扩展 TensorRT
TensorRT支持众多layer,并且功能还在进行持续扩展;然而,可能存在支持的layer无法满足模型特定需求的情况。此时,可通过实现自定义层(通常称为插件)来扩展TensorRT。
TensorRT包含可加载至应用程序的标准插件。如需查看开源插件列表,请参阅GitHub:TensorRT插件。
要在应用程序中使用标准TensorRT插件,必须加载libnvinfer_plugin.so(Windows系统下为nvinfer_plugin.dll)库,并通过在应用程序代码中调用initLibNvInferPlugins函数完成所有插件的注册。有关这些插件的更多信息,请参阅NvInferPlugin.h文件。
如果这些插件还是无法满足需求,则可以自行编写并添加新的插件。
1.使用C++ API添加自定义layer
为了正确的识别到自定义插件,需要4步:
继承TensorRT的插件基类,目前唯一推荐的基类是IPluginV3
实现一个与自定义插件绑定的creator类,这个类要继承自TensorRT的插件creator基类,当前唯一推荐的基类是IPluginCreatorV3One
在TensorRT的插件注册表中注册插件创建者类的实例
通过直接使用TensorRT的网络API或使用TensorRT ONNX解析器API加载ONNX模型,将插件类的实例添加到TensorRT网络中
以下部分将详细探讨上面提到的每一个步骤
1.1 实现插件类
您可以通过从TensorRT的插件基类之一派生来实现自定义layer。从TensorRT 10.0开始,唯一推荐的插件接口是IPluginV3,因为其他接口已被弃用。因此,本节主要描述使用IPluginV3的插件实现。请参阅将V2插件迁移到IPluginV3一节,了解如何将实现V2插件接口的插件迁移到IPluginV3。
IPluginV3是一组功能接口的包装器,这些接口定义了三种功能:核心功能、构建和运行时
核心能力:指插件生命周期的构建和运行时阶段共有的插件属性和行为。
构建能力:指插件必须为TensorRT构建器展示的插件属性和行为。
运行时能力:指插件运行时必须表现出的插件属性和行为,以便在TensorRT构建阶段进行自动调优或在TensorRT运行时阶段的推理执行。
IPluginV3OneCore(C++,Python)、IPluginV3OneBuild(C++,Python)和IPluginV3OneRuntime(C++,Python)是IPluginV3插件必须实现的基类,以分别显示核心、构建和运行时功能。如果需要I/O别名,可以使用IPluginV3OneBuildV2(C++、Python)作为构建功能的基类,其是包含IPluginV3OneBuild功能的超集
1.2 实现插件创建者类
要在网络中使用插件,您必须首先在TensorRT的PluginRegistry(C++、Python)中注册它。与其直接注册插件,不如为插件注册一个工厂类的实例,该实例派生自IPluginCreatorInterface(C++、Python)的子类。插件创建者类还提供有关插件的其他信息:名称、版本和插件字段参数。
IPluginCreatorV3One是IPluginV3的工厂类。IPluginCreatorV3One::createPlugin(),其签名如下
C++
IPluginV3* createPlugin(AsciiChar const *name, PluginFieldCollection const *fc, TensorRTPhase phase)Python
create_plugin(self: trt.IPluginCreatorV3, name: str, field_collection: trt.PluginFieldCollection, phase: trt.TensorRTPhase) -> trt.IPluginV3IPluginCreatorV3One::createPlugin()可以在TensorRT的构建阶段或TensorRT的运行时阶段调用,以创建插件实例,这是通过TensorRTPhase(C++,Python)类型的phase参数进行通信的
返回的IPluginV3对象在每个阶段都必须具有有效的核心功能。
在构建阶段,返回的IPluginV3对象必须同时具有构建和运行时功能。
在运行时阶段,返回的IPluginV3对象必须具有运行时功能。构建功能不是必需的,会被忽略
1.3 在插件注册表中注册插件创建者
有两种方法可以在注册表中注册插件创建者:
通过调用register_TENSORRT_PLUGIN进行静态注册。REGISTER_TENSORRT_PLUGIN始终在默认命名空间(“”)下注册创建者。
通过创建类似于initLibNvInferPlugins的入口点并在插件注册表上调用registerCreator来动态注册。这比静态注册更可取,因为它允许插件在唯一的命名空间下注册。这确保了在构建过程中,不同插件库之间不会发生名称冲突。
在序列化过程中,TensorRT引擎在内部存储所有插件的插件名称、插件版本和命名空间(如果存在),以及IPluginV3OneRuntime::getFieldsToSerialize()返回的PluginFieldCollection中的任何插件字段。在反序列化过程中,TensorRT从插件注册表中查找具有相同插件名称、版本和命名空间的插件创建者,并在其上调用IPluginCreatorV3One::createPlugin()——序列化的PluginFieldCollection作为fc参数传回。
1.4 向TensorRT网络添加插件实例
您可以使用addPluginV3()将插件添加到TensorRT网络中,该方法使用给定的插件创建网络层。
例如,您可以按如下方式向网络添加插件层:
// Look up the plugin in the registry
// Cast to appropriate child class of IPluginCreatorInterface
auto creator = static_cast<IPluginCreatorV3One*>(getPluginRegistry()->getCreator(pluginName, pluginVersion, pluginNamespace));
PluginFieldCollection const* pluginFC = creator->getFieldNames();
// Populate the field parameters for the plugin layer
PluginFieldCollection *pluginData = parseAndFillFields(pluginFC, layerFields);
// Create the plugin object using the layerName and the plugin metadata for use by the TensorRT builder
IPluginV3 *pluginObj = creator->createPlugin(layerName, pluginData, TensorRTPhase::kBUILD);
// Add the plugin to the TensorRT network
auto layer = network.addPluginV3(inputs.data(), int(inputs.size()), shapeInputs.data(), int(shapeInputs.size()), pluginObj);
//… (build rest of the network and serialize engine)
// Delete the plugin object
delete pluginObj;d
// … (free allocated pluginData)前面描述的createPlugin方法在堆上创建一个新的插件对象并返回一个指针。如前所示,确保删除pluginObj以避免内存泄漏。
删除引擎后,引擎会销毁在构建过程中创建的插件对象的任何克隆。您有责任确保您创建的插件对象在添加到网络后被释放。
注意:
不要序列化所有插件参数,只序列化运行时正常运行所需的参数。构建时参数可以省略。
如果您是auto用户,则必须调用getSafePluginRegistry()而不是getPluginRegistry()。您还必须使用宏REGISTER_SAFE_TENSORRT_PLUGIN,而不是REGISTER_TENSORRT_PLUGIN。
1.5 示例:使用C++添加具有动态形状的自定义layer
考虑实现一个自定义layer实现类似填充的操作,其中输入batch中的每个图像都必须被重塑为32 x 32。输入张量X的形状为(B,C,H,W),输出张量Y的形状为(B, C, 32, 32)。为了实现这一点,可以使用IPluginV3接口编写TensorRT插件;称其为PadPlugin。
由于IPluginV3插件必须具备多种功能,每种功能都由单独的接口定义,因此您可以使用组合或多重继承的原则来实现插件。然而,对于大多数用例来说,多重继承方法更容易,特别是在需要将构建和运行时功能耦合到单个类中的情况下。
使用多重继承,PadPlugin可以实现如下
class PadPlugin : public IPluginV3, public IPluginV3OneCore, public IPluginV3OneBuild, public IPluginV3OneRuntime
{
...override inherited virtual methods.
};IPluginV3::getCapabilityInterface的重写必须返回指向各个功能接口的指针。对于每个PluginCapabilityType,必须通过相应的能力接口进行强制转换,以消除编译器的歧义
IPluginCapability* PadPlugin::getCapabilityInterface(PluginCapabilityType type) noexcept override
{
// All plugin interface methods are noexcept and care should be
// taken not to throw exceptions across the API boundary. It is
// recommended to catch any exceptions and return a value that
// appropriately represents the error status.
try
{
if (type == PluginCapabilityType::kBUILD)
{
return static_cast<IPluginV3OneBuild*>(this);
}
if (type == PluginCapabilityType::kRUNTIME)
{
return static_cast<IPluginV3OneRuntime*>(this);
}
ASSERT(type == PluginCapabilityType::kCORE);
return static_cast<IPluginV3OneCore*>(this);
}
catch(...)
{
// log error
}
return nullptr;
}在这个特定的例子中,重要的方法是:
INetworkDefinition::addPluginV3IPluginV3OneBuild::getNbOutputsIPluginV3OneBuild::getOutputDataTypesIPluginV3OneBuild::getOutputShapesIPluginV3OneBuild::supportsFormatCombinationIPluginV3OneBuild::configurePluginIPluginV3OneRuntime::onShapeChangeIPluginV3OneRuntime::enqueue
INetworkDefinition::addPluginV3 (C++, Python) 能够将插件添加到网络
std::vector<ITensor*> inputs{X};
auto pluginLayer = network->addPluginV3(inputs.data(), inputs.size(), nullptr, 0, *plugin);您可以通过覆盖IPluginV3OneBuild::getNbOutputs来传达有一个插件输出
int32_t PadPlugin::getNbOutputs() const noexcept override
{
return 1;
}输出将具有与输入相同的数据类型,可以在覆盖PluginV3OneBuild::getOutputDataTypes进行表示
int32_t PadPlugin::getOutputDataTypes(
DataType* outputTypes, int32_t nbOutputs, DataType const* inputTypes, int32_t nbInputs) const noexcept override
{
outputTypes[0] = inputTypes[0];
return 0;
}getOutputShapes的重写根据输入维度返回输出维度的符号表达式,但数据相关的输出形状除外,这将在后面的示例中介绍:使用C++添加具有数据相关和形状输入相关形状的自定义层。在当前示例中,输出的前二维将分别等于输入的前二维,后二维将是常数,每个常数等于32。传递给getOutputShapes的IExprBuilder可用于定义常量符号表达式
int32_t PadPlugin::getOutputShapes(DimsExprs const* inputs, int32_t nbInputs, DimsExprs const* shapeInputs, int32_t nbShapeInputs, DimsExprs* outputs, int32_t nbOutputs, IExprBuilder& exprBuilder) noexcept
{
outputs[0].nbDims = 4;
// first two output dims are equal to the first two input dims
outputs[0].d[0] = inputs[0].d[0];
outputs[0].d[1] = inputs[0].d[1];
// The last two output dims are equal to 32
outputs[0].d[2] = exprBuilder.constant(32);
outputs[0].d[3] = exprBuilder.constant(32);
return 0;
}TensorRT使用supportsFormatCombination来询问插件是否接受给定位置pos的连接的给定类型和格式组合,以及较小索引连接的给定格式/类型。该接口将输入/输出统一索引为连接,从第一个输入的0开始,然后按顺序对其余输入进行索引,然后对输出进行编号。在该示例中,输入为连接0,输出为连接1
为了简单起见,该示例仅支持线性格式和FP32类型
bool PadPlugin::supportsFormatCombination(
int32_t pos, DynamicPluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) noexcept override
{
assert(0 <= pos && pos < 2);
return inOut[pos].desc.format == PluginFormat::kLINEAR && inOut[pos].desc.type == DataType::kFLOAT;
}TensorRT调用两种方法,允许插件在入队()之前做出任何配置选择,无论是在自动调优期间(在引擎构建阶段)还是在执行引擎时(在运行时阶段)。
IPluginV3OneBuild::configurePlugin:当插件准备进行分析(自动调优)时调用,但不适用于任何特定的输入大小。DynamicPluginTensorDesc的min、max和opt值对应于张量形状及其自动调整形状的边界。desc.dims字段对应于网络创建时指定的插件的维度,包括动态维度的任何通配符(-1)。
IPluginV3OneRuntime::onShapeChange:在排队()之前的构建阶段和运行阶段调用,以传达后续排队()的输入和输出形状。输出PluginTensorDesc将包含通配符(-1),用于通过getOutputShapes()指定的任何数据相关维度
这个插件不需要configurePlugin和onShapeChange来做任何事情,所以它们是空操作
int32_t PadPlugin::configurePlugin(DynamicPluginTensorDesc const* in, int32_t nbInputs, DynamicPluginTensorDesc const* out, int32_t nbOutputs) noexcept override
{
return 0;
}
int32_t PadPlugin::onShapeChange(PluginTensorDesc const* in, int32_t nbInputs, PluginTensorDesc const* out, int32_t nbOutputs) noexcept override
{
return 0;
}最后,重写必须的PadPlugin::enqueue。由于形状是动态的,enqueue会收到一个PluginTensorDesc,描述每个输入和输出的维度、类型和格式。
int32_t enqueue(PluginTensorDesc const* inputDesc, PluginTensorDesc const* outputDesc, void const* const* inputs,
void* const* outputs, void* workspace, cudaStream_t stream) noexcept override
{
// populate outputs and return status code
}