pytorch softmax的使用

torch.nn.functional.softmax(input, dim=None)
在这里插入图片描述html

import torch
import torch.nn as nn
m = nn.Softmax(dim=1) #注意是沿着那个维度计算
input = torch.randn(2,2)
print("input:")
print(input)
output = m(input)
print(output)

#注意区分如下结果,这是两个不一样size的tensor
input1=torch.randn(2,1)
print("input1:")
print(input1)
print(m(input1))
input2=torch.randn(1,2)
print("input2:")
print(input2)
print(m(input2))

在这里插入图片描述

相关文章
相关标签/搜索