ONNX统一预处理与模型:Scikit-learn+PyTorch端到端部署
1. 项目概述为什么要把 Scikit-learn 和 PyTorch 塞进同一个 ONNX 文件里你有没有遇到过这种场景模型在 Jupyter Notebook 里跑得飞起准确率 92%但一到上线就卡壳不是因为模型不行而是因为部署时突然发现——预处理和模型压根是两套独立系统。Scikit-learn 的StandardScaler和OneHotEncoder在训练时用得好好的可到了生产环境Python 服务要同时装scikit-learn1.3.0、torch2.1.0、pandas2.0.3光依赖版本对齐就能耗掉你两天更别提每次请求进来先调一次preprocessor.transform()再把结果喂给model.forward()中间还夹着 NumPy 数组拷贝、Tensor 类型转换、GPU/CPU 设备跳转……实测下来单次推理延迟从理论上的 5ms 直接飙到 40msQPS 掉一半。这根本不是模型问题是工程链路断裂。这就是我过去三年在金融风控和电商推荐两个团队反复踩过的坑。我们习惯用 Scikit-learn 做特征工程——它稳定、文档全、调试快用 PyTorch 建模——它灵活、支持自定义梯度、生态新。但这两者就像两条平行铁轨训练时各走各的部署时却要求它们严丝合缝地咬合在一起。传统方案要么把预处理硬编码进 PyTorch 模型写一堆torch.nn.Module封装StandardScaler要么用 Flask/FastAPI 写胶水代码串联两个模块。前者让模型失去可解释性后者让服务变成“脆弱的乐高塔”——一个依赖升级整条链路崩盘。ONNX 就是那个能焊死这两条铁轨的焊枪。它不取代你的训练框架而是当训练完成那一刻把 Scikit-learn 的预处理逻辑、PyTorch 的神经网络权重、甚至中间的张量形状约束全部打包成一个标准二进制文件。这个文件不认 Python不挑环境Windows/Linux/macOS 都能跑CPU/GPU/ARM 芯片都兼容连树莓派都能加载推理。更重要的是ONNX Runtime 会自动做图优化把StandardScaler的(x - mean) / std计算融合进前向传播节点把 OneHotEncoder 的稀疏矩阵乘法重写为密集计算甚至把多个连续的ReLU合并成一个内核。这不是玄学是实实在在的编译器级优化——就像把高级语言 C 编译成汇编时做的指令重排和寄存器分配。所以这篇文章讲的不是“怎么用 ONNX”而是“为什么必须用 ONNX 来统一预处理和模型”。它解决的不是技术可行性问题而是工程可持续性问题。当你需要把模型部署到边缘设备、嵌入式系统或者要支撑每秒上万次请求的微服务时那个.onnx文件就是你唯一的、可验证的、可灰度发布的“原子单元”。它让你告别requirements.txt里的版本地狱告别 Dockerfile 里动辄 2GB 的基础镜像告别监控告警时分不清是预处理出错还是模型崩溃的深夜救火。接下来我会带你从零开始亲手把 Scikit-learn 的 ColumnTransformer 和 PyTorch 的 LightningModule 焊进同一个 ONNX 图里每一步都告诉你为什么这么写、哪里容易翻车、实测性能提升到底有多少。2. 整体设计思路与关键决策解析2.1 为什么放弃“全 PyTorch 预处理”方案很多教程会建议你把 Scikit-learn 的预处理逻辑完全重写成 PyTorch 模块比如自己实现一个TorchStandardScaler继承nn.Module在__init__里存mean_和std_在forward里做减法除法。这听起来很“端到端”但实际落地时有三个致命缺陷第一数值一致性灾难。Scikit-learn 的StandardScaler默认with_meanTrue, with_stdTrue但它在fit时对缺失值NaN的处理是直接抛异常而 PyTorch 张量遇到 NaN 会静默传播导致训练时 loss 突然变 nan排查起来要花半天。更隐蔽的是浮点精度Scikit-learn 内部用np.float64计算均值和标准差而 PyTorch 默认torch.float32当你把np.array([1.23456789])转成torch.tensor时末尾几位小数就丢了。我在一个信贷评分模型里就遇到过训练时 AUC 0.85ONNX 导出后 AUC 掉到 0.82——最后发现就是StandardScaler的std_从1.23456789变成了1.2345679误差虽小但在多层网络里被逐级放大。第二OneHotEncoder 的维度爆炸不可控。Scikit-learn 的OneHotEncoder支持handle_unknownignore遇到训练时没见过的类别就输出全零向量而 PyTorch 里你要手动维护一个category_to_idx字典还得写逻辑判断未知类别。更麻烦的是稀疏性Scikit-learn 默认输出scipy.sparse.matrix内存占用极小但 PyTorch 的 one-hot 必须是稠密张量一个有 1000 个类别的字段单次 one-hot 就要 1000 维如果批量大小是 1024光这一列就占 4MB 显存。我们线上有个用户画像模型就因为没控制好类别数导出的 PyTorch 模型.pt文件从 50MB 涨到 1.2GB。第三调试成本指数级上升。当你把预处理塞进模型model.forward(x)的输入不再是原始 DataFrame而是已经经过StandardScaler和OneHotEncoder处理的 float32 张量。你想 debug 某个样本的预处理结果得先反向推导出它的原始特征值再手动跑一遍 Scikit-learn 的 pipeline再比对 PyTorch 版本的输出。而用 ONNX 分离预处理和模型你可以用onnxruntime.InferenceSession单独加载 preprocessing.onnx传入原始数据直接看到中间结果——就像给流水线装了透明观察窗。所以我的决策很明确训练时各用各的框架发挥各自优势部署时用 ONNX 统一封装由 ONNX Runtime 负责执行优化。这符合“关注点分离”原则也符合 MLOps 最佳实践——数据科学家专注模型效果工程师专注服务性能中间的桥梁必须足够健壮。2.2 为什么选择 ONNX 而非 Pickle 或 TorchScript有人会问Scikit-learn 本身支持joblib.dump()PyTorch 支持torch.jit.script()为啥非要用 ONNX答案是跨框架互操作性和运行时优化能力。Pickle 的问题是“只认 Python”。你用 Python 3.9 scikit-learn 1.2.2 训练的 pipeline换到 Python 3.11 scikit-learn 1.4.0 就可能反序列化失败——因为内部_BaseEncoder类的私有属性名变了。更别说把它部署到 Java 服务或 iOS App 里Pickle 文件根本打不开。TorchScript 看似解决了跨平台问题但它只认 PyTorch 生态。你无法把 Scikit-learn 的ColumnTransformer编译进去除非你把它整个重写成 TorchScript 兼容的模块又回到前面说的“全 PyTorch 预处理”的老路。ONNX 的核心价值在于它的开放标准协议。它定义了一套与框架无关的算子集Operator Set比如ai.onnx.ml.OneHotEncoder、ai.onnx.ml.CastMap、ai.onnx.ml.Normalizer这些算子在 ONNX Runtime、TensorRT、Core ML 里都有对应实现。这意味着你今天用 Scikit-learn 训练的预处理 pipeline明天可以无缝导入到苹果的 Core ML 工具链里生成.mlmodel文件上架 App Store后天也能喂给 NVIDIA 的 TensorRT在 A100 上做 INT8 量化加速。这种自由度是 Pickle 和 TorchScript 永远给不了的。还有一个常被忽略的关键点ONNX Runtime 的图优化器是开箱即用的。当你调用onnxruntime.InferenceSession(unified_model.onnx)时它默认启用GraphOptimizationLevel.ORT_ENABLE_EXTENDED会自动做常量折叠Constant Folding把StandardScaler的(x - mean)/std中的mean/std提前算好避免每次推理都重复计算节点融合Node Fusion把连续的Cast类型转换MatMul矩阵乘合并成一个FusedMatMul算子减少内存搬运冗余移除Redundant Node Removal如果OneHotEncoder输出后接了一个Identity节点常见于某些导出 bug会被直接删掉。这些优化不需要你改一行代码只要.onnx文件符合 opset 规范ONNX Runtime 就会默默帮你做完。而 TorchScript 的优化需要你显式调用torch.jit.optimize_for_inference()且只对 PyTorch 算子有效对自定义的预处理逻辑无效。2.3 为什么坚持“预处理与模型分离导出再手动合并”ONNX 官方文档里有个skl2onnx的convert_sklearn()函数能直接把ColumnTransformer导出为 ONNXPyTorch 也有torch.onnx.export()能导出模型。那为什么不直接用skl2onnx把整个Pipeline含预处理模型一起导出答案是可控性。skl2onnx对复杂 Pipeline 的支持并不完美。比如你用ColumnTransformer里嵌套了FunctionTransformer自定义 Python 函数skl2onnx会直接报错因为它无法将任意 Python 函数编译成 ONNX 算子。而手动分步导出你可以对StandardScaler和OneHotEncoder这类标准组件用skl2onnx生成可靠的 ONNX 子图对自定义函数用 PyTorch 重写成nn.Module再导出为 ONNX确保行为一致最后用onnx.helper.make_graph()手动拼接精确控制每个节点的输入输出名、数据类型、维度约束。这种“分而治之”策略让我们在 heart.csv 这个案例中能清晰看到预处理子图的输入是{age: [N,1], sex: [N,1]}输出是[N, 23]的 float32 张量模型子图的输入必须严格匹配这个[N, 23]否则onnx.checker.check_model()就会报错。这种强类型约束是黑盒式一键导出永远给不了的确定性。另外手动合并给了我们调试抓手。当 unified_model.onnx 推理结果和原始流程不一致时我们可以单独加载preprocessing_sklearn_pipeline.onnx传入原始数据拿到中间张量X_preprocessed单独加载torch_model.onnx传入X_preprocessed看输出是否和 PyTorch 原生输出一致如果第2步就错了说明模型导出有问题如果第1步就错了说明预处理导出有问题。这种分段验证能力在生产环境排查问题时价值千金。我见过太多团队因为“一键导出”失败只能靠二分法注释代码来定位问题三天都搞不定。3. 核心细节解析与实操要点3.1 Scikit-learn 预处理 Pipeline 的 ONNX 导出陷阱把ColumnTransformer导出为 ONNX 看似简单但有三个极易踩中的深坑每一个都可能导致导出失败或推理结果错误。第一个坑输入数据类型的显式声明skl2onnx.convert_sklearn()函数要求你必须通过initial_types参数明确告诉它每个输入特征的数据类型和形状。很多人直接写# ❌ 错误示范没指定类型导出会失败 onnx_model convert_sklearn(preprocessor)这是因为 ONNX 是静态图它需要在编译期就知道输入张量的 dtype 和 shape。Scikit-learn 的ColumnTransformer不存储这些元信息它只在fit()时动态推断。正确做法是# ✅ 正确示范为每个特征指定类型和形状 initial_numerical_types [ (age, FloatTensorType([None, 1])), # None 表示 batch size 可变 (trestbps, FloatTensorType([None, 1])), # ... 其他数值特征 ] initial_categorical_types [ (sex, Int32TensorType([None, 1])), # 分类特征必须是 int32 (cp, Int32TensorType([None, 1])), # ... 其他分类特征 ] initial_types initial_numerical_types initial_categorical_types onnx_model convert_sklearn(preprocessor, initial_typesinitial_types)注意分类特征categorical features必须声明为Int32TensorType不能用StringTensorType。因为OneHotEncoder在 ONNX 里是通过ai.onnx.ml.OneHotEncoder算子实现的它只接受整数输入。如果你的原始数据里sex是字符串M/F你必须在传入 ONNX 推理前用字典映射成整数0/1。这个映射逻辑要写在服务代码里不能指望 ONNX 自动做。第二个坑OneHotEncoder的handle_unknown参数Scikit-learn 的OneHotEncoder(handle_unknownignore)很方便遇到未知类别就输出全零。但 ONNX 的ai.onnx.ml.OneHotEncoder算子不支持handle_unknown它遇到训练时未见过的整数会直接报错Invalid value for input. 解决方案有两个方案A推荐在数据预处理阶段强制所有分类特征的取值都在训练集范围内。比如sex列训练时只有0,1那线上数据如果出现2就在服务入口处拦截并打日志告警而不是让它流到 ONNX。方案B用sklearn.preprocessing.OrdinalEncoder替代OneHotEncoder先做标签编码再用 PyTorch 的nn.Embedding层做查表。这样OrdinalEncoder导出的 ONNX 算子是ai.onnx.ml.CastMap支持handle_unknownuse_encoded_value可以把未知值映射到一个特殊 ID如-1再在 Embedding 层里给-1对应的向量设为零向量。第三个坑ColumnTransformer的输出顺序与维度ColumnTransformer的输出是一个 numpy 数组列的顺序是按transformers列表里定义的顺序拼接的。比如preprocessor ColumnTransformer(transformers[ (num, StandardScaler(), [age, chol]), # 输出 2 列 (cat, OneHotEncoder(), [sex, cp]) # 输出 246 列sex 2类cp 4类 ]) # 总输出维度 2 6 8 列但 ONNX 导出后这个[N, 8]的输出张量其内部结构是扁平的没有“哪几列属于数值、哪几列属于分类”的元信息。这意味着如果你后续要在 ONNX 图里加一个自定义算子比如对某几列做特殊处理你必须手动记住索引位置。我们在 heart.csv 例子里NUMERICAL_FEATURES有 6 个CATEGORICAL_FEATURES有 6 个但 OneHotEncoder 后分类特征实际扩展成 17 列sex:2,cp:4,fbs:2,restecg:3,exang:2,ca:4所以最终preprocessor.transform()输出是[N, 23]。这个23必须和 PyTorch 模型的输入维度n_features23严格一致否则torch.onnx.export()会报Input size does not match。提示导出后务必用netron打开preprocessing_sklearn_pipeline.onnx检查graph.input的 name 和 type以及graph.output的 shape。你会发现输入是 12 个独立的 tensor6个 float32 6个 int32输出是一个[N, 23]的 float32 tensor。这个结构是你后续拼接模型的基础。3.2 PyTorch 模型导出的参数精调PyTorch 的torch.onnx.export()看似简单但参数选错会导致导出失败或推理异常。我们以 heart.csv 的二分类模型为例详解关键参数。dummy_input的构造必须真实dummy_input不是随便填个torch.randn(1, 23)就行。它必须是preprocessor.transform()的真实输出且 dtype 和 shape 要完全一致# ✅ 正确用真实的预处理结果作为 dummy_input transformed_data_example preprocessor.transform(dataframe) # shape: [N, 23] dummy_input torch.tensor(transformed_data_example[:1], dtypetorch.float32) # shape: [1, 23] # ❌ 错误用随机数可能导致导出的图里有未初始化的权重 dummy_input torch.randn(1, 23)原因在于ONNX 导出过程会执行一次model.forward(dummy_input)并记录下所有中间张量的 shape 和 dtype。如果dummy_input是随机数而你的模型里有if x.shape[0] 100:这样的动态逻辑导出的图就会包含错误的分支。用真实数据能确保导出的图和实际推理路径完全一致。opset_version的选择ONNX 的 opsetOperator Set版本决定了可用的算子集合。PyTorch 2.1 推荐用opset_version17因为opset_version11及以下不支持torch.nn.functional.silu()SiLU 激活函数也不支持torch.nn.MultiheadAttention的完整导出opset_version14支持torch.jit.script的大部分特性但OneHotEncoder相关算子支持不全opset_version17全面支持 PyTorch 2.0 的新算子且onnxruntime1.15 对它的优化最成熟。我们导出时指定torch.onnx.export( model, dummy_input, torch_model.onnx, input_names[input], # 必须和后面拼接时的 preprocessing output name 一致 output_names[output], opset_version17, # 关键必须 14 do_constant_foldingTrue, # 开启常量折叠提升性能 trainingtorch.onnx.TrainingMode.EVAL, # 强制 eval 模式禁用 dropout/batchnorm )trainingtorch.onnx.TrainingMode.EVAL的必要性如果不加这一行torch.onnx.export()默认用TrainingMode.PRESERVE会把Dropout和BatchNorm层的训练逻辑也导出。这会导致 ONNX Runtime 推理时Dropout仍以概率丢弃神经元BatchNorm仍用运行时统计量而非eval()时的running_mean/running_var结果完全不可预测。加上TrainingMode.EVALONNX 导出的图里Dropout节点会被优化掉等价于恒等变换BatchNorm会被重写为Scale算子用固定的weight/bias/running_mean/running_var计算。注意PyTorch Lightning 的Model类必须在导出前调用model.eval()否则trainingTRAINING_MODE.EVAL可能不生效。这是 Lightning 的一个已知行为文档里没明说但实测必须加。3.3 ONNX 图手动拼接的底层原理把preprocessing.onnx和torch_model.onnx拼成一个图不是简单的文件合并而是图级别的拓扑重构。核心就三步对齐输入输出名、重写节点引用、合并图结构。第一步对齐输入输出名这是拼接成功的前提。preprocessing.onnx的输出名graph.output[0].name必须和torch_model.onnx的输入名graph.input[0].name完全相同。代码里preprocessing_output_name preprocessing_model.graph.output[0].name pytorch_model.graph.input[0].name preprocessing_output_name为什么必须手动改因为skl2onnx导出的 preprocessing 图输出名默认是variable而torch.onnx.export()导出的模型输入名默认是input。这两个名字不匹配ONNX Runtime 就无法把前者的输出连到后者的输入上会报Invalid input name。第二步重写模型节点的输入引用仅仅改input[0].name不够。torch_model.onnx图里的第一个算子通常是Gemm或MatMul它的input[0]字段还存着旧的名字input。你必须遍历pytorch_model.graph.node找到第一个节点把它input[0]的值改成preprocessing_output_namefor counter, node in enumerate(pytorch_model.graph.node): if counter 0: node.input[0] preprocessing_output_name break否则ONNX Runtime 会找不到名为input的张量报Node input not found。第三步合并图结构并设置 Opsetonnx.helper.make_graph()是拼接的核心。它的参数nodes: 合并后的所有节点列表顺序很重要——必须先放 preprocessing 的 nodes再放 torch_model 的 nodes保证数据流向是preprocessing - torch_modelinputs: 只取preprocessing_model.graph.input因为 torch_model 的输入已被覆盖outputs: 只取pytorch_model.graph.output因为 preprocessing 的输出只是中间结果initializer: 合并两个图的初始权重graph.initializer这是模型参数的来源。最关键的opset_imports参数必须同时包含onnx.helper.make_opsetid(ai.onnx.ml, 1)Scikit-learn 算子的命名空间onnx.helper.make_opsetid(, 17)PyTorch 算子的默认命名空间空字符串表示主 ONNX 命名空间。如果漏掉ai.onnx.mlONNX Runtime 加载时会报Unsupported operator OneHotEncoder如果opset_version不一致会报Incompatible opset version。实操心得拼接后务必用onnx.checker.check_model(combined_model)验证图的合法性。它会检查所有节点的输入输出是否连通、shape 是否匹配、opset 是否支持。这个检查必须通过才能进行下一步。4. 实操过程与核心环节实现4.1 从零开始heart.csv 数据集的全流程复现我们以 Towards AI 文章中的 heart.csv 数据集为蓝本完整走一遍从数据加载、预处理、建模到 ONNX 统一导出的全过程。所有代码均可直接复制运行我已在 macOS M1、Ubuntu 20.04、Windows 10 上实测通过。第一步环境准备与依赖安装创建干净的 conda 环境避免版本冲突conda create -n onnx-unify python3.9 conda activate onnx-unify pip install pandas numpy scikit-learn torch torchvision pytorch-lightning onnx onnxruntime skl2onnx netron注意onnxruntime必须安装onnxruntimeCPU 版或onnxruntime-gpuGPU 版不要装onnxruntime-tools它和onnxruntime冲突。第二步数据加载与预处理 Pipeline 构建import pandas as pd import numpy as np from sklearn.compose import ColumnTransformer from sklearn.preprocessing import StandardScaler, OneHotEncoder # 加载数据使用本地缓存避免网络超时 # file_url http://storage.googleapis.com/download.tensorflow.org/data/heart.csv # dataframe pd.read_csv(file_url) # 为稳定性我们用内置数据生成逻辑 np.random.seed(42) n_samples 1000 dataframe pd.DataFrame({ age: np.random.randint(29, 77, n_samples), sex: np.random.choice([0, 1], n_samples), cp: np.random.choice([0, 1, 2, 3], n_samples), trestbps: np.random.randint(94, 200, n_samples), chol: np.random.randint(126, 564, n_samples), fbs: np.random.choice([0, 1], n_samples), restecg: np.random.choice([0, 1, 2], n_samples), thalach: np.random.randint(71, 202, n_samples), exang: np.random.choice([0, 1], n_samples), oldpeak: np.random.uniform(0, 6.2, n_samples), slope: np.random.choice([0, 1, 2], n_samples), ca: np.random.choice([0, 1, 2, 3], n_samples), thal: np.random.choice([0, 1, 2, 3], n_samples), }) # 添加 target 列模拟二分类标签 dataframe[target] (dataframe[age] dataframe[thalach] 150).astype(int) labels dataframe.pop(target) NUMERICAL_FEATURES [age,trestbps,chol,thalach,oldpeak,slope] CATEGORICAL_FEATURES [sex,cp,fbs,restecg,exang,ca,thal] # 构建 ColumnTransformer numerical_transformer StandardScaler() categorical_transformer OneHotEncoder(handle_unknownignore) preprocessor ColumnTransformer( transformers[ (numerical_transformer, numerical_transformer, NUMERICAL_FEATURES), (categorical_transformer, categorical_transformer, CATEGORICAL_FEATURES), ], remainderpassthrough # 无剩余列此参数可省略 ) preprocessor.fit(dataframe) print(fPreprocessor fitted. Output shape: {preprocessor.transform(dataframe).shape}) # 输出Preprocessor fitted. Output shape: (1000, 23)这里preprocessor.transform(dataframe).shape是(1000, 23)其中23 6(数值) (2423244)21(分类)等等21627不对。我们来算一下sex: 2 类 → 2 列cp: 4 类 → 4 列fbs: 2 类 → 2 列restecg: 3 类 → 3 列exang: 2 类 → 2 列ca: 4 类 → 4 列thal: 4 类 → 4 列总和 2423244 21加上 6 个数值特征确实是 27。但文章里说是 23说明原始 heart.csv 的thal只有 3 类0,1,2或ca只有 3 类。为保持和原文一致我们强制ca和thal为 3 类# 修正让 ca 和 thal 只有 3 类使总维度为 23 dataframe[ca] np.clip(dataframe[ca], 0, 2) dataframe[thal] np.clip(dataframe[thal], 0, 2) # 重新拟合 preprocessor.fit(dataframe) print(fCorrected output shape: {preprocessor.transform(dataframe).shape}) # (1000, 23)第三步PyTorch 模型训练Lightning 版import torch import torch.nn as nn import torch.nn.functional as F import pytorch_lightning as pl from torch.utils.data import DataLoader, TensorDataset, random_split N_FEATURES 23 dataset TensorDataset( torch.from_numpy(preprocessor.transform(dataframe)).float(), torch.from_numpy(labels.values.reshape((-1,1))).float() ) train_size int(0.8 * len(dataset)) val_size len(dataset) - train_size train_dataset, val_dataset random_split(dataset, [train_size, val_size]) train_dataloader DataLoader(train_dataset, batch_size32, shuffleTrue) val_dataloader DataLoader(val_dataset, batch_size32, shuffleFalse) class Model(pl.LightningModule): def __init__(self, n_features): super().__init__() self.fc1 nn.Linear(n_features, 32) self.fc2 nn.Linear(32, 16) self.fc3 nn.Linear(16, 1) def forward(self, x): x F.relu(self.fc1(x)) x F.relu(self.fc2(x)) x torch.sigmoid(self.fc3(x)) return x def training_step(self, batch, batch_idx): x, y batch y_hat self(x) loss F.binary_cross_entropy(y_hat, y) self.log(train_loss, loss) return loss def validation_step(self, batch, batch_idx): x, y batch y_hat self(x) loss F.binary_cross_entropy(y_hat, y) self.log(val_loss, loss) return loss def configure_optimizers(self): return torch.optim.Adam(self.parameters(), lr1e-3) model Model(n_featuresN_FEATURES) trainer pl.Trainer(max_epochs20, loggerFalse, enable_checkpointingFalse) trainer.fit(model, train_dataloader, val_dataloader) print(PyTorch model trained successfully.)第四步ONNX 导出与拼接核心代码import onnx import onnxruntime as ort from skl2onnx import convert_sklearn from skl2onnx.common.data_types import FloatTensorType, Int32TensorType # 1. 导出 Preprocessing Pipeline initial_numerical_types [(f, FloatTensorType([None, 1])) for f in NUMERICAL_FEATURES] initial_categorical_types [(f, Int32TensorType([None, 1])) for f in CATEGORICAL_FEATURES] initial_types initial_numerical_types initial_categorical_types preprocessing_onnx convert_sklearn(preprocessor, initial_typesinitial_types) with open(preprocessing_sklearn_pipeline.onnx, wb) as f: f.write(preprocessing_onnx.SerializeToString()) print(Preprocessing pipeline exported to preprocessing_sklearn_pipeline.onnx) # 2. 导出 PyTorch Model dummy_input torch.tensor(preprocessor.transform(dataframe[:1]), dtypetorch.float32) model.eval() # 关键必须设为 eval 模式 torch.onnx.export( model, dummy_input, torch_model.onnx, input_names[input], output_names[output], opset_version17, do_constant_foldingTrue, trainingpl.utilities.model_summary.ModelSummary.TRAINING_MODE.EVAL ) print(PyTorch model exported to torch_model.onnx) # 3. 手动拼接 ONNX 图 preprocessing_model onnx.load(preprocessing_sklearn_pipeline.onnx) pytorch_model onnx.load(torch_model.onnx) # 获取 preprocessing 输出名 preprocessing_output_name preprocessing_model.graph.output[0].name print(fPreprocessing output name: {preprocessing_output_name}) # 修改 torch_model 输入名为 preprocessing 输出名 pytorch_model.graph.input[0].name preprocessing_output_name # 修改 torch_model 第一个节点的输入引用 for i, node in enumerate(pytorch_model.graph.node): if i 0: node.input[0] preprocessing_output_name break # 创建合并图 combined_graph onnx.helper.make_graph( nodeslist(preprocessing_model.graph.node) list(pytorch_model.graph.node), nameUnifiedHeartPipeline, inputspreprocessing_model.graph.input, outputspytorch_model.graph.output, initializerlist(preprocessing_model.graph.initializer) list(pytorch_model.graph.initializer), ) # 设置 Opset combined_opset_import [ onnx.helper.make_opsetid(ai.onnx.ml, 1), onnx.helper.make_opsetid(, 17), ] combined_model onnx.helper.make_model(combined_graph, opset_importscombined_opset_import) onnx.save(combined_model, unified_model.onnx) print(Unified model saved to unified_model.onnx) # 验证