可以在pytorch中学习标量重量,并保证标量的总和为1

发布于 2025-01-22 20:12:25 字数 1549 浏览 2 评论 0原文

我有这样的代码:

class MyModule(nn.Module):
    
    def __init__(self, channel, reduction=16, n_segment=8):
        super(MyModule, self).__init__()
        self.channel = channel
        self.reduction = reduction
        self.n_segment = n_segment
        
        self.conv1 = nn.Conv2d(in_channels=self.channel, out_channels=self.channel//self.reduction, kernel_size=1, bias=False)
        self.conv2 = nn.Conv2d(in_channels=self.channel, out_channels=self.channel//self.reduction, kernel_size=1, bias=False)
        self.conv3 = nn.Conv2d(in_channels=self.channel, out_channels=self.channel//self.reduction, kernel_size=1, bias=False)
        #whatever

        # learnable weight
        self.W_1 = nn.Parameter(torch.randn(1), requires_grad=True)
        self.W_2 = nn.Parameter(torch.randn(1), requires_grad=True)
        self.W_3 = nn.Parameter(torch.randn(1), requires_grad=True)

    def forward(self, x):
        
        # whatever
        
        ## branch1                
        bottleneck_1 = self.conv1(x)
        
        ## branch2
        bottleneck_2 = self.conv2(x)
        
        ## branch3                
        bottleneck_3 = self.conv3(x)
        
        ## summation
        output = self.avg_pool(self.W_1*bottleneck_1 + 
                          self.W_2*bottleneck_2 + 
                          self.W_3*bottleneck_3) 
        
        return output

如您所见,3个可学习的标量(w_1w_2w_3)用于加权目的。但是,这种方法不能保证这些标量的总和是1。如何使我的可学习标量的总和等于pytorch中的1?谢谢

I have code like this:

class MyModule(nn.Module):
    
    def __init__(self, channel, reduction=16, n_segment=8):
        super(MyModule, self).__init__()
        self.channel = channel
        self.reduction = reduction
        self.n_segment = n_segment
        
        self.conv1 = nn.Conv2d(in_channels=self.channel, out_channels=self.channel//self.reduction, kernel_size=1, bias=False)
        self.conv2 = nn.Conv2d(in_channels=self.channel, out_channels=self.channel//self.reduction, kernel_size=1, bias=False)
        self.conv3 = nn.Conv2d(in_channels=self.channel, out_channels=self.channel//self.reduction, kernel_size=1, bias=False)
        #whatever

        # learnable weight
        self.W_1 = nn.Parameter(torch.randn(1), requires_grad=True)
        self.W_2 = nn.Parameter(torch.randn(1), requires_grad=True)
        self.W_3 = nn.Parameter(torch.randn(1), requires_grad=True)

    def forward(self, x):
        
        # whatever
        
        ## branch1                
        bottleneck_1 = self.conv1(x)
        
        ## branch2
        bottleneck_2 = self.conv2(x)
        
        ## branch3                
        bottleneck_3 = self.conv3(x)
        
        ## summation
        output = self.avg_pool(self.W_1*bottleneck_1 + 
                          self.W_2*bottleneck_2 + 
                          self.W_3*bottleneck_3) 
        
        return output

As you see, 3 learnable scalars (W_1, W_2, and W_3) are used for weighting purpose. But, this approach will not guarantee that the sum of those scalars is 1. How to make the summation of my learnable scalars equals to 1 in Pytorch? Thanks

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。

评论(1

原谅我要高飞 2025-01-29 20:12:25

保持简单:

    ## summation
    WSum = self.W_1 + self.W_2 + self.W_3
    output = self.avg_pool( self.W_1/WSum *bottleneck_1 + 
                            self.W_2/WSum *bottleneck_2 + 
                            self.W_3/WSum *bottleneck_3)

另外,人们可以使用分发性法:

    output = self.avg_pool(self.W_1*bottleneck_1 + 
                      self.W_2*bottleneck_2 + 
                      self.W_3*bottleneck_3) /WSum

Keep it simple:

    ## summation
    WSum = self.W_1 + self.W_2 + self.W_3
    output = self.avg_pool( self.W_1/WSum *bottleneck_1 + 
                            self.W_2/WSum *bottleneck_2 + 
                            self.W_3/WSum *bottleneck_3)

Also, one can use distributivity law:

    output = self.avg_pool(self.W_1*bottleneck_1 + 
                      self.W_2*bottleneck_2 + 
                      self.W_3*bottleneck_3) /WSum
~没有更多了~
我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
原文