张板:如何查看Pytorch模型摘要?
我有以下网络。
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
class Net(nn.Module):
def __init__(self,input_shape, num_classes):
super(Net, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=(4,4)),
nn.Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=(4,4)),
)
x = self.conv(torch.rand(input_shape))
in_features = np.prod(x.shape)
self.classifier = nn.Sequential(
nn.Linear(in_features=in_features, out_features=num_classes),
)
def forward(self, x):
x = self.feature_extractor(x)
x = x.view(x.size(0), -1)
x = self.classifier(x)
return x
net = Net(input_shape=(1,64,1292), num_classes=4)
print(net)
这打印了以下内容: -
Net(
(conv): Sequential(
(0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace=True)
(2): MaxPool2d(kernel_size=(4, 4), stride=(4, 4), padding=0, dilation=1, ceil_mode=False)
(3): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): ReLU(inplace=True)
(5): MaxPool2d(kernel_size=(4, 4), stride=(4, 4), padding=0, dilation=1, ceil_mode=False)
)
(classifier): Sequential(
(0): Linear(in_features=320, out_features=4, bias=True)
)
)
但是,我正在尝试各种实验,我想跟踪张量板上的网络体系结构。我知道有一个函数writer.add_graph(模型,input_to_model)
,但它需要输入,或者至少应该知道其形状。
因此,我尝试了writer.add_text(“模型”,str(模型))
,但格式化在张板中拧紧。
我的问题是,有没有办法至少通过在张量板中使用打印功能可以看到我可以看到的方式?
I have the following network.
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
class Net(nn.Module):
def __init__(self,input_shape, num_classes):
super(Net, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=(4,4)),
nn.Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=(4,4)),
)
x = self.conv(torch.rand(input_shape))
in_features = np.prod(x.shape)
self.classifier = nn.Sequential(
nn.Linear(in_features=in_features, out_features=num_classes),
)
def forward(self, x):
x = self.feature_extractor(x)
x = x.view(x.size(0), -1)
x = self.classifier(x)
return x
net = Net(input_shape=(1,64,1292), num_classes=4)
print(net)
This prints the following:-
Net(
(conv): Sequential(
(0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace=True)
(2): MaxPool2d(kernel_size=(4, 4), stride=(4, 4), padding=0, dilation=1, ceil_mode=False)
(3): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): ReLU(inplace=True)
(5): MaxPool2d(kernel_size=(4, 4), stride=(4, 4), padding=0, dilation=1, ceil_mode=False)
)
(classifier): Sequential(
(0): Linear(in_features=320, out_features=4, bias=True)
)
)
However, I am trying various experiments and I want to keep track of network architecture on Tensorboard. I know there is a function writer.add_graph(model, input_to_model)
but it requires input, or at least its shape should be known.
So, I tried writer.add_text("model", str(model))
, but formatting is screwed up in tensorboard.
My question is, is there a way to at least visualize the way I can see by using print function in the tensorboard?
如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

绑定邮箱获取回复消息
由于您还没有绑定你的真实邮箱,如果其他用户或者作者回复了您的评论,将不能在第一时间通知您!
发布评论
评论(1)
我可以看到一切都正确,但是有一个格式的问题。 Tensorboard了解Markdown,因此您实际上可以用
\ n
用< br/>
和& nbsp;
>。这是一个详细的演练。假设您有以下模型: -
它打印以下以及是否可以在张量板中显示。
add_graph(模型,输入)
在summaryWriter
中有函数,但是您必须创建虚拟输入,在某些情况下,很难始终了解它们。取而代之的是以下: -上面在张板中产生以下文本: -
I can see everything is going right but there is just a formatting issue. Tensorboard understands markdown so you can actually replace
.
\n
with<br/>
andwith
Here is a detailed walkthrough. Suppose you have the following model:-
This prints the following and if can actually show it in the Tensorboard.
There is function in
add_graph(model, input)
inSummaryWriter
but you must create dummy input and in some cases it is difficult of to always know them. Instead do following:-Above produces following text in tensorboard:-