from torchsummary import summary
from vggnet import VGGNet
myNet = VGGNet()
summary(myNet, (3, 28, 28))
# print(myNet)
重点
说明我们应该把模型放到cuda上面去运行
主要修改位置
myNet = VGGNet() =>myNet = VGGNet().cuda()
整体
from torchsummary import summary
from vggnet import VGGNet
myNet = VGGNet().cuda()
summary(myNet, (3, 28, 28))
# print(myNet)
summary(myNet, (3, 28, 28))这里的(3, 28, 28)里面的3是深度,一定要和定义的一样
比如:
这里我们输入的深度是1
所以应该是summary(myNet, (1, 28, 28))