今天使用hiddenlayer测试了下retinanet网络的可视化。
首先,安装hiddlayer,直接pip pip install git+https://github.com/waleedka/hiddenlayer.git
而后在终端加载模型并显示:python
import model, torch import hiddenlayer as hl retinanet = model.resnet18(num_classes=100, pretrained=True).cuda() x = torch.rand((1, 3, 224, 224)).cuda().float() ann = torch.tensor([[[20.0, 30.0, 53.2, 33.3, 32.0]]]).cuda().float() hl.build_graph(retinanet, [x, ann]) hl.save('/home/willer/model.pdf')
模型太复杂了,放在这里了。
昨天晚上对比着模型结构的pdf和代码又看了下,发现仍是颇有用的,起码对网络的数据流动的认识更加清晰了。git