一文搞懂 | Pytorch维度转换操作:view,reshape,permute,flatten函数详解

阿旭 阿旭算法与机器学习 2024年09月06日 15:59

公众号

小伙伴们好,我是阿旭。专注于人工智能、AIGC、python、深度学习相关分享研究。欢迎共同学习交流!

------------

AI

1.人脸识别与管理系统2.车牌识别与管理系统
3.手势识别系统4.人脸面部活体检测系统
5.YOLOv8自动标注6.人脸表情识别系统
7.行人跌倒检测系统8.PCB板缺陷检测系统
9.安全帽检测系统
10.生活垃圾分类检测
11.火焰烟雾检测系统
12.路面坑洞检测系统
13.钢材表面缺陷检测
14.102种犬类检测系统
15.面部口罩检测系统16.西红柿成熟度检测
17.血细胞检测计数
18.舰船分类检测系统
19.吸烟行为检测
20.水稻害虫检测识别
21.车辆行人检测计数
22.小麦害虫检测识别
23.玉米害虫检测识别
24.200种鸟类检测识别
25.交通标志检测识别
26.苹果病害识别
27.肺炎诊断系统
28.100种中草药识别
29.102种花卉识别
30.100种蝴蝶识别
31.车辆行人追踪系统
32.水稻病害识别
33.车牌检测识别系统
34.草莓病害检测分割
35.复杂环境船舶检测
36.裂缝检测分析系统
37.田间杂草检测系统
38.葡萄病害识别
39.路面坑洞检测分割
40.遥感地面物体检测
41.无人机视角检测
42.木薯病害识别预防
43.野火烟雾检测
44.脑肿瘤检测
45.玉米病害检测
46.橙子病害识别
47.车辆追踪计数
48.行人追踪计数
49.反光衣检测预警
50.人员闯入报警
51.高密度人脸检测52.肾结石检测
53.水果检测识别54.蔬菜检测识别
55.水果质量检测56.非机动车头盔检测
57.螺栓螺母检测58.焊缝缺陷检测
59.金属品瑕疵检测60.链条缺陷检测
61.条形码检测识别62.交通信号灯检测

------------

引言

在深度学习网络构建与计算过程中,我们经常会使用到张量维度之间的各种转换,用于不同操作。Pytorch中常见的维度转换函数有view,reshape,permute,flatten。本文将详细介绍这几个函数的作用与使用方式,并给出了具体的代码示例,希望能够帮助大家。

常见的维度有四维:比如(batch, channel, height, width);三维:比如(b,n,c);二维:比如(h,w)。下面介绍如何使用上述函数进行维度之间的转换。

view函数

作用

tensor.view() 可以用来调整张量的形状,这对于在网络层之间传递数据或者在处理图像数据时非常有用。需要注意的是,新的形状必须与原始张量的元素数量一致。

参数

size (tuple of ints) – 新的大小应该与原张量元素数量相匹配。可以指定一个尺寸为 -1 的维度来自动计算合适的大小。

代码示例

将计算机视觉中的常见四维张量(Batch, Channel, Height, Width)转为三维(Batch,N,Channel)形式。

import torch
# view使用示例
x = torch.randn(16,3,64,64# B, C, H, W
print(x.shape) #torch.Size([16,3,64,64])
B, C, H, W = x.size()

# 转为BNC
x = x.view(B, -1, C)
# 或者 x = x.view(B, H*W, C)
print(x.shape) #torch.Size([16, 4096, 3])

torch.randn() 是 PyTorch 中的一个函数,用于生成一个填充了从标准正态分布(均值为 0,方差为 1)中随机抽取的数字的张量。

permute函数

作用permute() 函数用于改变张量的维度顺序。它接受一个新的维度顺序作为参数,并返回一个新的张量,其维度顺序按照给定的顺序排列。参数说明参数:一个元组,表示新的维度顺序。例如,对于一个形状为 (10, 3, 32, 32) 的张量,permute(0, 2, 3, 1) 表示新的维度顺序为 (10, 32, 32, 3)。其中0,1,2,3分别表示4个维度(10, 3, 32, 32)的索引。

代码示例

依然将计算机视觉中的常见四维张量(Batch, Channel, Height, Width)转为三维(Batch,N,Channel)形式。

import torch
# permute使用示例:permute转换唯独顺序
x = torch.randn(16,3,64,64# B, C, H, W
print(x.shape) #torch.Size([16,3,64,64])

# 16,3,64,64的维度索引分别为0,1,2,3
dim_change = x.permute(0,2,3,1# 转为 B,H,W,C
# 然后将中间两个通道索引为[1,2]展平
out = dim_change.flatten(start_dim=1,end_dim=2)
print(out.shape) #torch.Size([16, 4096, 3])

flatten() 方法用于展平张量的一个或多个维度。它可以接受两个可选参数:start_dim:从哪个维度开始展平,默认为 0。 

end_dim:到哪个维度结束展平,默认为 -1,表示直到最后一个维度。 

此处的作用是将第二个和第三个维度进行展平。start_dim=1 表示从第二个维度(即 64)开始展平。end_dim=2 表示到第三个维度(即 64)结束展平。展平后的结果为 (16, 4096, 3),其中 4096= 64 * 64。 

过这些步骤,你可以将原始张量从 (16,3,64,64) 转换为 (16, 4096, 3)。

Reshape函数

torch.reshape() 可以改变张量的形状,而不改变张量中的数据。与view函数的作用类似。注意事项:新旧形状的元素总数必须相同。

import torch

# 创建一个简单的张量
x = torch.randn(43)
print("Original tensor:")
print(x)

# 使用 torch.reshape() 来改变张量的形状
# 将 (4, 3) 的张量转换成 (2, 6) 的张量
reshaped_x = torch.reshape(x, (26))
print("\nReshaped tensor:")
print(reshaped_x)

# 如果不确定某个维度的大小,可以使用 -1 让 PyTorch 自动计算
# 这里将 (4, 3) 转换为 (12,) 的一维张量
flat_x = torch.reshape(x, (-1))
print("\nFlattened tensor:")
print(flat_x)

# 更复杂的形状变换
# 将 (4, 3) 转换为 (3, 4) 的张量
complex_reshaped_x = torch.reshape(x, (34))
print("\nComplex reshaped tensor:")
print(complex_reshaped_x)

flatten函数

torch.flatten 是 PyTorch 库中的一个函数,用于将一个多维张量转换为一维张量或降低其维度。

torch.flatten参数说明

input: 这是要被展平的张量。这是必需的参数。 

start_dim (可选): 指定从哪个维度开始展平。默认值为 0,这意味着展平将从第一个维度(通常是批量大小)开始。如果你希望保留前几个维度并只展平后续的维度,你可以设置这个参数。 

end_dim (可选): 指定展平到哪个维度结束。默认值为 -1,这表示展平将一直持续到最后一个维度。如果只想展平中间的一部分维度,可以设置这个参数来指定结束维度。

当 start_dim 和 end_dim 都没有被显式地指定时,torch.flatten 将会展平除了第一个维度之外的所有维度,通常第一个维度是批量大小,会被保留以便于批次处理。

代码示例

举个例子,假设你有一个形状为 [batch_size, channels, height, width] 的四维张量,如果你想将其展平为 [batch_size, channels * height * width] 的二维张量,你可以直接调用 torch.flatten 而不需要额外的参数。但是,如果你想保留通道维度,并展平高度和宽度维度,你可以设置 start_dim=1 和 end_dim=2。

import torch

# 创建一个形状为 [8, 3, 64, 64] 的随机张量
x = torch.randn(836464)

# 展平除了第一个维度外的所有维度
y = torch.flatten(x)
print(y.shape)  # 输出: torch.Size([8, 12288])

# 只展平第二和第三个维度[也就是最后两个维度],0,1,2,3
z = torch.flatten(x, 12)
print(z.shape)  # 输出: torch.Size([8, 3, 4096])

关注下方公众号:【阿旭算法与机器学习】,发送【开源】可获取更多学习资源

图片

好了,这篇文章就介绍到这里,喜欢的小伙伴感谢给点个赞和关注,更多精彩内容持续更新~~ 

end



python




图片




MoileSAM