Pytorch查看模型结构
控制台输出深度学习模型各层结构与特征图尺寸。
准备
对于2D模型和部分3D模型可以使用torchsummary:
1 | pip install torchsummary |
对于3D模型则推荐使用torchinfo:
1 | pip install torchinfo |
torchinfo算是对torchsummary的完全重写版,更为稳定,实际使用中推荐优先选择前者。
使用方法
模型定义
假设我们有一个简单的AlexNet(GPT完成)
1 | import torch |
print(model)
这是最简单直接的查看模型的方式,在定义模型:
1 | model = AlexNet2D(num_class=10) |
之后,直接使用print方法将模型的结构打印出来,但是知会打印模型init
的部分:
1 | print(model) |
效果:
1 | AlexNet2D( |
torchsummary
torchsummary.summary
中传入实例化的模型和输入参数尺寸。
1 | import torchsummary |
Output:
1 | ---------------------------------------------------------------- |
torchinfo
使用方法和torchsummary
几乎一直,只是需要手动指定batch_size
的尺寸。但是输出的结构更为清晰。
1 | summary(model, input_size=(1, 3, 224, 224), device="cuda") |
效果
1 | ========================================================================================== |
补充torchinfo
- 支持RNNs、LSTM等其他递归模型
- 查看指定层数的尺寸