注册网站好的平台百度关键词搜索优化
 PyTorch学习之 torch.squeeze 函数  
 
一、功能
torch.squeeze 的主要作用是从给定的张量 input 中移除所有尺寸为1的维度。
二、基本语法
torch.squeeze(input, dim=None)
 
三、参数说明
input(Tensor): 输入的张量。dim(int, 可选): 指定要移除的尺寸为1的维度- 如果未指定,函数将移除所有尺寸为1的维度。
 - 如果指定的维度不为1,则 
torch.squeeze不会对该维度进行操作 - 如果所有维度都不为1且未指定 
dim参数,则返回的张量与输入张量相同 
四、返回值
- 返回一个新的张量,移除了指定的尺寸为1的维度。
 - ⚠️如果没有可以移除的维度,则返回与输入相同的张量。
 
五、示例
以下是一些使用 torch.squeeze 的示例,以帮助更好地理解其用法。
示例 1: 移除所有尺寸为1的维度
import torch# 创建一个张量,其形状为 (1, 3, 1, 5)
x = torch.randn(1, 3, 1, 5)
print("原始张量形状:", x.shape)# 使用 torch.squeeze 移除所有尺寸为1的维度
y = x.squeeze()
print("移除后张量形状:", y.shape)
 
输出:
原始张量形状: torch.Size([1, 3, 1, 5])
移除后张量形状: torch.Size([3, 5])
 
示例 2: 移除指定维度(该维度尺寸为1)
import torch# 创建一个张量,其形状为 (1, 3, 1, 5)
x = torch.randn(1, 3, 1, 5)
print("原始张量形状:", x.shape)# 指定维度移除,尝试移除第0维
y = x.squeeze(0)
print("移除第0维后的张量形状:", y.shape)# 尝试移除第2维
z = x.squeeze(2)
print("移除第2维后的张量形状:", z.shape)
 
输出:
原始张量形状: torch.Size([1, 3, 1, 5])
移除第0维后的张量形状: torch.Size([3, 1, 5])
移除第2维后的张量形状: torch.Size([1, 3, 5])
