PyTorch dimension transformations

import torch

import numpy as np

 

#Dimension transform 1: view. A view easily loses track of how the data is stored (demonstrated in the sketch after these examples).

a = torch.rand(4,1,28,28)

print(a.shape,a.view(4,28,28))#4,28,28: 4 images with the channel dim folded in; flattening each image to 784 values is common for fully-connected layers (see the sketch below)

print(a.view(4,28*28).shape)#torch.Size([4, 784])
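
#A minimal sketch of the fully-connected use case mentioned above; the
#nn.Linear layer and the output size 10 are illustrative assumptions, not
#part of the original code:
import torch.nn as nn
fc = nn.Linear(28*28, 10)
print(fc(a.view(4,28*28)).shape)#torch.Size([4, 10])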

print(a.view(4*28,28))#put all channels and all rows into the first dimension, i.e. merge the batch, channel, and row dims

print(a.view(4*1,28,28))#merge the batch and channel dims (4*1 = 4)
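
#A sketch of the "view loses the storage layout" warning above: reshaping the
#flattened data with a different dim order is legal (element counts match) but
#scrambles the images:
wrong = a.view(4,28*28).view(28,28,4)#valid shape, wrong semantics
print(torch.all(torch.eq(wrong.permute(2,0,1), a.view(4,28,28))))#prints a false/0 tensor: the contents no longer match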

 

#Inserting a dimension: unsqueeze. For this 4-D tensor the valid insert range is [-5, 5): with 4 dims there are 4+1=5 insert positions; e.g. 0 inserts before the first dim, 1 before the second, and -1 inserts at the very end.

b = a.unsqueeze(0)#torch.Size([1, 4, 1, 28, 28])

print(b.shape)

c = a.unsqueeze(-1)#torch.Size([4, 1, 28, 28, 1])

print(c.shape)

"""

-5 -4 -3 -2 -1

0 1 2 3 4

4 1 28 28

 

torch.Size([1, 4, 1, 28, 28])

torch.Size([4, 1, 28, 28, 1])

torch.Size([1, 4, 1, 28, 28])

torch.Size([4, 1, 1, 28, 28])

torch.Size([4, 1, 1, 28, 28])

torch.Size([4, 1, 28, 1, 28])

torch.Size([4, 1, 28, 28, 1])

torch.Size([1, 4, 1, 28, 28])

torch.Size([4, 1, 1, 28, 28])

torch.Size([4, 1, 1, 28, 28])

torch.Size([4, 1, 28, 1, 28])

torch.Size([4, 1, 28, 28, 1])

尽可能不使用负数

"""

for i in range(-5,5):
    d = a.unsqueeze(i)
    print(d.shape)

 

b = torch.rand(32)

f = torch.rand(4,32,14,14)

b = b.unsqueeze(1).unsqueeze(2).unsqueeze(0)#torch.Size([1, 32, 1, 1])

print(b.shape)
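
#With the bias reshaped to [1, 32, 1, 1] it broadcasts against f's
#[4, 32, 14, 14]; a sketch of the typical bias-add this enables:
out = f + b#broadcasting expands the size-1 dims automatically
print(out.shape)#torch.Size([4, 32, 14, 14])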


 

#Removing dimensions: squeeze

"""

torch.Size([32, 1, 1])

torch.Size([1, 32, 1, 1])

torch.Size([1, 32, 1])

torch.Size([1, 32, 1])

torch.Size([32, 1, 1])

torch.Size([1, 32, 1, 1])

torch.Size([1, 32, 1])

torch.Size([1, 32, 1])

"""

c = b.squeeze()#torch.Size([32])

print(c.shape)

for i in range(-4,4):
    print(b.squeeze(i).shape)
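
#A sketch of a common squeeze pitfall: with no argument it drops every size-1
#dim, including ones you may want to keep, so an explicit dim is safer.
#x here is an illustrative assumption:
x = torch.rand(1,3,1,28)
print(x.squeeze().shape)#torch.Size([3, 28]): both size-1 dims removed
print(x.squeeze(0).shape)#torch.Size([3, 1, 28]): only the chosen dim removed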

 

#Expanding dimensions, i.e. changing the shape: expand only changes how the data is interpreted and adds no data, while repeat materializes new data; note that repeat has to copy, so it is slower.

b = torch.rand(1,32,1,1)

a = torch.rand(4,32,14,14)

c = b.expand(4,32,14,14)

print(b.shape,a.shape,c.shape)#torch.Size([1, 32, 1, 1]) torch.Size([4, 32, 14, 14]) torch.Size([4, 32, 14, 14])
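
#expand also accepts -1 to mean "keep this dim as is"; only size-1 dims can be
#expanded (a sketch):
e = b.expand(-1,32,14,14)
print(e.shape)#torch.Size([1, 32, 14, 14])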

 

d = b.repeat(4,32,1,1)#4,32,1,1 are per-dim repeat counts, not a target shape;

print(d.shape)#torch.Size([4, 1024, 1, 1]): not what we want; the correct call is:

d = b.repeat(4,1,1,1)

print(d.shape)#torch.Size([4, 32, 1, 1])
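
#Because expand is only a view while repeat allocates new memory, the two
#results point at different storage (a sketch using data_ptr):
print(c.data_ptr() == b.data_ptr())#True: expand shares b's memory
print(d.data_ptr() == b.data_ptr())#False: repeat copied the data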

 

#Matrix transpose

a = torch.randn(4,3)

print(a,a.t())#.t() only works on 2-D tensors
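
#.t() is shorthand for transpose(0,1) on a 2-D tensor; for higher-rank tensors
#use transpose(dim0, dim1) instead (a sketch):
print(torch.all(torch.eq(a.t(), a.transpose(0,1))))#a true/1 tensor: identical results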

 

a = torch.rand(4,3,32,32)

#b = a.transpose(1,3).view(4,3*32*32).view(4,3,32,32)#wrong: the data is non-contiguous after transpose

#print(a.shape,c.shape)

b = a.transpose(1,3).contiguous().view(4,3*32*32).view(4,3,32,32)

c = a.transpose(1,3).contiguous().view(4,3*32*32).view(4,32,32,3).transpose(1,3)

print(b.shape,c.shape)#torch.Size([4, 3, 32, 32]) torch.Size([4, 32, 32, 3])

print(torch.all(torch.eq(a,b)),torch.all(torch.eq(a,c)))#tensor(0, dtype=torch.uint8) tensor(1, dtype=torch.uint8): checks whether the contents match; b reshaped the transposed memory in the wrong order, c undid the transpose correctly
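
#A sketch showing why the commented-out line above fails: transpose returns a
#non-contiguous view, and view requires contiguous memory:
t = a.transpose(1,3)
print(t.is_contiguous())#False
print(t.contiguous().is_contiguous())#True: contiguous() copies into row-major layout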

 

d = a.permute(0,2,3,1)

print(d.shape)#torch.Size([4, 32, 32, 3]): 0,2,3,1 are the original dim indices in their new order
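
#permute reorders all dims at once (transpose swaps only two); applying the
#inverse permutation recovers the original tensor (a sketch):
e = d.permute(0,3,1,2)#NHWC back to NCHW
print(torch.all(torch.eq(a,e)))#a true/1 tensor: same data, original order restored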
