ONNX内置预处理

之前我们研究了如何使用 Kotlin 预处理 ONNX 模型的图像输入。了解此过程很有用,因为这些原则适用于拟希望使用的任何模型。另一方面,必须编写用于输入处理的样板代码可能很繁琐 - 这也意味着有更多代码可能有错误并需要测试。

ONNXRuntime-Extensions 项目通过包含用于模型输入的常见预处理和后处理的自定义运算符来简化此过程。这意味着拟可以构建、测试和部署包含必要输入预处理和输出后处理的模型,使其更容易融入 Android 等应用程序项目。

超分辨率示例演示了如何向模型添加预处理和后处理,并仅用几行代码运行它。

1、向模型添加预处理

无需用 Kotlin 编写输入图像大小调整和其他操作,可以使用较新版本的 ONNX 运行时(例如 ORT 1.14/opset 18)或 onnxruntime_extensions 库中提供的运算符将预处理和后处理步骤添加到模型中。

超分辨率示例很好地展示了此功能,因为输入和输出都是图像,因此在模型中包含处理步骤可大大减少特定于平台的代码。

将预处理和后处理添加到超分辨率模型的步骤位于示例文档的准备模型部分。可以在 superresolution_e2e.py 中查看创建更新模型的 Python 脚本。

按照说明运行脚本时,它会生成两个 ONNX 格式的模型:

  • 基本的 pytorch_superresolution.onnx 模型
  • 包含额外处理的pytorch_superresolution_with_pre_and_post_proceessing.onnx

第二个模型(包括处理说明)可以在 Android 应用程序中调用,代码行数比我们讨论的上一个示例要少。

2、比较模型

在深入研究 Android 示例应用程序之前,可以使用在线工具 Netron.app 观察这两个模型之间的差异。Netron 有助于可视化各种神经网络、深度学习和机器学习模型。

第一个模型 - 需要对输入和输出进行本机预处理和后处理 - 显示以下内容:

图 1:转换为 ONNX 并在 Netron 中查看的超分辨率模型

这是一个相对简单的模型,现在理解图表并不重要,只需将其与下面显示的包含预处理和后处理操作的模型进行比较:

图 2:具有输入和输出处理的超分辨率模型,转换为 ONNX 并在 Netron 中查看

格式化和调整图像字节大小的附加操作在原始模型之前和之后链接。再次深入研究图表的细节并不重要,它包含在这里是为了说明附加操作已打包到模型中,以便在任何受支持的平台上以更少的本机代码更轻松地使用。

3、与 Android 集成

可以托管此模型的 Android 示例可在 GitHub 上找到。 MainActivity.kt 中的初始化步骤类似于图像分类器示例,但增加了 sessionOptions 对象,其中 ONNX 运行时扩展添加到会话中。如果没有扩展,具有额外处理的模型可能会缺少运行所需的操作。

此代码片段突出显示了创建 ONNX 运行时环境类的关键行:

  var ortEnv: OrtEnvironment = OrtEnvironment.getEnvironment()
  // fun onCreate
  val sessionOptions: OrtSession.SessionOptions = OrtSession.SessionOptions()
  sessionOptions.registerCustomOpLibrary(OrtxPackage.getLibraryPath())
  ortSession = ortEnv.createSession(readModel(), sessionOptions) // the model is in raw resources
  // fun performSuperResolution
  var superResPerformer = SuperResPerformer()
  var result = superResPerformer.upscale(readInputImage(), ortEnv, ortSession)
  // result.outputBitmap contains the output image!

superResPerformer.upscal 函数如下所示。

SuperResPerformer.kt 中运行模型的代码比图像分类器示例少很多,因为不需要本机处理(例如不再需要的 ImageUtil.kt 辅助类)。图像字节可用于创建张量,并将图像字节作为结果发出。 ortSession.run 是关键函数,它接受输入张量并返回生成的放大图像:

fun upscale(inputStream: InputStream, ortEnv: OrtEnvironment, ortSession: OrtSession): Result {
    var result = Result()
    // Step 1: convert image into byte array (raw image bytes)
    val rawImageBytes = inputStream.readBytes()
    // Step 2: get the shape of the byte array and make ort tensor
    val shape = longArrayOf(rawImageBytes.size.toLong())
    val inputTensor = OnnxTensor.createTensor(
        ortEnv,
        ByteBuffer.wrap(rawImageBytes),
        shape,
        OnnxJavaType.UINT8
    )
    inputTensor.use {
        // Step 3: call ort inferenceSession run
        val output = ortSession.run(Collections.singletonMap("image", inputTensor))
        // Step 4: output analysis
        output.use {
            val rawOutput = (output?.get(0)?.value) as ByteArray
            val outputImageBitmap =
                byteArrayToBitmap(rawOutput)
            // Step 5: set output result
            result.outputBitmap = outputImageBitmap
        }
    }
    return result
}

Android 超分辨率演示的屏幕截图如下:

图 3:Android 演示“超分辨率”,可提高图像的分辨率

这个特定的升级模型相对简单,但打包模型使其更易于使用的概念可以应用于具有预处理和后处理要求的其他模型。


原文链接:Built-in model pre-processing with ONNX

BimAnt翻译整理,转载请标明出处