当前位置: 首页 > news >正文

有网站有安全狗进不去了海南省住房和城乡建设官方网站

有网站有安全狗进不去了,海南省住房和城乡建设官方网站,wordpress果酱主题,网站个人中心模板本文尝试将pytorch搭建的ViT模型转为onnx模型。 首先将博主上一篇文章中搭建的模型ViT Vision Transformer超详细解析,网络构建,可视化,数据预处理,全流程实例教程-CSDN博客转存为.pth torch.save(model, my_vit_model.pth) 然…

本文尝试将pytorch搭建的ViT模型转为onnx模型。

首先将博主上一篇文章中搭建的模型ViT Vision Transformer超详细解析,网络构建,可视化,数据预处理,全流程实例教程-CSDN博客转存为.pth

torch.save(model, 'my_vit_model.pth')

然后新建一个py文件,要新建py文件的原因是,博主上一篇文章的main.py文件引用了很多torch相关的库,如果还是在main.py文件中运行转onnx的代码,回报错circle import 重复循环引用的错误,所以姑且将.pth作为一个中转。

新建一个py文件,写入

import importlib
torch = importlib.import_module('torch')model = torch.load("my_vit_model.pth")model.cpu()
# 创建一个随机的输入张量
dummy_input = torch.randn(1, 3, 16, 16)
torch.onnx.export(model, dummy_input, 'model.onnx', opset_version=18)

引入importlib,通过它来引用torch也是为了解决循环引用的问题。

这时运行这段代码,会报错onnx 不支持aten::unflatten运算。这里有两种解决方法,一种是将自己pytorch模型中的unflatten运算全部换成onnx支持的reshape函数(参见文章:https://www.cnblogs.com/antelx/p/17564039.html)

还有一种方法是,修改onnx库中的 symbolic_opset18.py 文件(/home/.local/lib/python3.8/site-packages/torch/onnx),改为如下形式

"""This file exports ONNX ops for opset 18.Note [ONNX Operators that are added/updated in opset 18]~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-18-of-the-default-onnx-operator-set
New operators:CenterCropPadCol2ImMishOptionalGetElementOptionalHasElementPadResizeScatterElementsScatterND
"""import functools
from typing import Sequenceimport torch
import torch._C._onnx as _C_onnx
from torch.onnx import (_constants,_type_utils,errors,symbolic_helper,symbolic_opset11 as opset11,symbolic_opset9 as opset9,utils,
)
from torch.onnx._internal import _beartype, jit_utils, registrationfrom torch import _C
from torch.onnx import symbolic_helper
from torch.onnx._internal import _beartype, registration# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in symbolic_helper.py__all__ = ["col2im"]_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=18)@_onnx_symbolic("aten::col2im")
@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is")
@_beartype.beartype
def col2im(g,input: _C.Value,output_size: _C.Value,kernel_size: _C.Value,dilation: Sequence[int],padding: Sequence[int],stride: Sequence[int],
):# convert [i0, i1, ..., in] into [i0, i0, i1, i1, ..., in, in]adjusted_padding = []for pad in padding:for _ in range(2):adjusted_padding.append(pad)num_dimensional_axis = symbolic_helper._get_tensor_sizes(output_size)[0]if not adjusted_padding:adjusted_padding = [0, 0] * num_dimensional_axisif not dilation:dilation = [1] * num_dimensional_axisif not stride:stride = [1] * num_dimensional_axisreturn g.op("Col2Im",input,output_size,kernel_size,dilations_i=dilation,pads_i=adjusted_padding,strides_i=stride,)@_onnx_symbolic("aten::unflatten")
def unflatten(g:jit_utils.GraphContext, input, dim, unflattened_size):input_dim = symbolic_helper._get_tensor_rank(input)if input_dim is None:return symbolic_helper._unimplemented("dim","ONNX and PyTorch use different strategies to split the input. ""Input rank must be known at export time.",)# dim could be negativeinput_dim = g.op("Constant", value_t=torch.tensor([input_dim], dtype=torch.int64))dim = g.op("Add", input_dim, dim)dim = g.op("Mod", dim, input_dim)input_size = g.op("Shape", input)head_start_idx = g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64))head_end_idx = g.op("Reshape", dim, g.op("Constant", value_t=torch.tensor([1], dtype=torch.int64)))head_part_rank = g.op("Slice", input_size, head_start_idx, head_end_idx)dim_plus_one = g.op("Add", dim, g.op("Constant", value_t=torch.tensor([1], dtype=torch.int64)))tail_start_idx = g.op("Reshape",dim_plus_one,g.op("Constant", value_t=torch.tensor([1], dtype=torch.int64)),)tail_end_idx = g.op("Constant", value_t=torch.tensor([_constants.INT64_MAX], dtype=torch.int64))tail_part_rank = g.op("Slice", input_size, tail_start_idx, tail_end_idx)final_shape = g.op("Concat", head_part_rank, unflattened_size, tail_part_rank, axis_i=0)return symbolic_helper._reshape_helper(g, input, final_shape)

这里这样做是相当于自己在onnx库中注册aten::unflatten运算。

再新建一个py文件,写入

import onnxruntime as rt
import numpy as np# 加载模型
sess = rt.InferenceSession("model.onnx")# 获取输入和输出名称
input_name = sess.get_inputs()[0].name
output_name = sess.get_outputs()[0].name# 创建输入数据
input_data = np.random.rand(1, 3, 16, 16).astype(np.float32)# 运行模型
pred_onnx = sess.run([output_name], {input_name: input_data})# 打印预测结果
print(pred_onnx)

就可以运行onnx模型了。

http://www.yayakq.cn/news/809083/

相关文章:

  • 做教育网站需要规划哪些内容申请百度账号注册
  • 搭建企业网站宽带多大合肥网站开发cnfg
  • 目录做排名 网站电商型网站是否是趋势
  • 什么网站可以做推广的百度搜索排名购买
  • 怎样在文章后做网站链接学校ui设计培训
  • php做网站怎么健免费网站吗
  • 农村网站建设补助开一家网络公司做网站前景如何
  • 广州魔站建站网架加工费多少钱一吨
  • 网站代理维护百度网盘网页版
  • 自建商城网站用什么技术好做网站的公司都缴什么税金
  • 怎么在阿里云建网站深圳宝安大型网站建设
  • 网站开发笔试题网站建设做的快
  • 高校里做网站的工作广告设计毕业设计
  • 自适应网站优点缺点门户网站怎样做
  • 2017网站趋势移动端优化
  • 男男做暧暧视频网站网站开发设计选题背景
  • 昆山网站建设电话饮食网站开发需求
  • 怀仁有做网站的公司吗中国机械加工网19易5下2拉i
  • 阿里巴巴国际站买家入口seo实战密码第四版
  • 海淀网站设计广州玩的地方有哪些地方
  • 陕西省交通建设集团公司西商分公司网站网站开发部
  • 图书馆网站建设费用重庆网站建设哪里有
  • 无锡网站推广优化网络管理系统包括哪五大功能
  • 郑州住房和城乡建设厅网站机械设计软件solidworks
  • 搭建公司网站多少钱免费商城版网站制作
  • 网站开发合同变更.我爱你 域名网站
  • 网站开发需要什么软件有哪些东莞百度seo排名
  • 谢岗镇网站仿做wordpress 适配 手机
  • 办网站租服务器招标网公告
  • 局机关门户网站建设情况汇报影视网站开发背景