Implementing 1D Self-Attention in PyTorch



I'm trying to implement the 1D self-attention block below using PyTorch:

<img src="https://i.sstatic.net/RQc9H.png" alt="1D self-attention block diagram">

proposed in the following paper. Below you can find my (provisional) attempt:

import torch
import torch.nn as nn

# INPUT shape (B, CH, L): batch, channels, sequence length


class Self_Attention1D(nn.Module):

    def __init__(self, in_channels=1, out_channels=3):
        super().__init__()

        # pointwise (kernel size 1) convolutions over the length dimension
        self.pointwise_conv1 = nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1)
        self.pointwise_conv2 = nn.Conv1d(in_channels=out_channels, out_channels=in_channels, kernel_size=1)

        # per-position feature transforms
        self.phi = MLP(in_size=out_channels, out_size=32)
        self.psi = MLP(in_size=out_channels, out_size=32)
        self.gamma = MLP(in_size=32, out_size=out_channels)

    def forward(self, x):
        # x: (B, in_channels, L)
        x = self.pointwise_conv1(x)                # (B, out_channels, L)

        # nn.Linear acts on the last dimension, so move channels there
        phi = self.phi(x.transpose(1, 2))          # (B, L, 32)
        psi = self.psi(x.transpose(1, 2))          # (B, L, 32)

        delta = phi - psi                          # (B, L, 32)

        gamma = self.gamma(delta).transpose(1, 2)  # (B, out_channels, L)

        out = self.pointwise_conv2(gamma * x)      # (B, in_channels, L)
        return out


class MLP(nn.Module):

    def __init__(self, in_size, out_size):
        super().__init__()

        self.in_size = in_size
        self.out_size = out_size

        self.layers = nn.Sequential(
            nn.Linear(in_size, 64),
            nn.ReLU(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, out_size))

    def forward(self, x):
        return self.layers(x)
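
For what it's worth, here is the quick shape check I ran on the block above (toy sizes, assuming a 3D input of shape (B, CH, L)), so at least the tensors line up:

# quick shape check with toy sizes
block = Self_Attention1D(in_channels=1, out_channels=3)
x = torch.randn(4, 1, 128)   # (B, CH, L)
print(block(x).shape)        # torch.Size([4, 1, 128])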

I'm not sure at all that this is correct: the operations in my implementation happen globally, whereas the image suggests we should compute some operation between each entry and its neighbours, one at a time. I was initially tempted to write a for loop that iteratively computes phi, psi and delta for each entry, but I felt that wasn't the right way to do it.
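
To make the "neighbours" part concrete, this is a rough, loop-free sketch of what I think the per-neighbourhood computation would look like, using torch.Tensor.unfold to gather a sliding window of K neighbours per position. The window size K, the zero padding and the plain nn.Linear stand-ins for the MLPs are my own assumptions, not something from the paper:

import torch
import torch.nn as nn
import torch.nn.functional as F

# Sketch only: compute phi(x_i) - psi(x_j) for every position i and each of
# its K neighbours j at once, with no explicit Python loop over positions.
B, C, L, K = 2, 3, 16, 5              # toy sizes; K = neighbourhood width (odd)
x = torch.randn(B, C, L)

phi = nn.Linear(C, 32)                # stand-ins for the MLPs above
psi = nn.Linear(C, 32)

feat = x.transpose(1, 2)              # (B, L, C) so Linear acts on channels

# zero-pad the length dimension so every position has K neighbours
padded = F.pad(feat, (0, 0, K // 2, K // 2))   # (B, L + K - 1, C)

# unfold gathers a sliding window of K neighbours for every position
neigh = padded.unfold(1, K, 1)        # dim=1, window=K, step=1 -> (B, L, C, K)
neigh = neigh.permute(0, 1, 3, 2)     # (B, L, K, C)

# phi of the centre entry minus psi of each of its neighbours, broadcast over K
delta = phi(feat).unsqueeze(2) - psi(neigh)    # (B, L, K, 32)
print(delta.shape)                    # torch.Size([2, 16, 5, 32])

This keeps everything vectorised over both the positions and the K neighbours, instead of looping entry by entry.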

Apologies if this is trivial, but I don't have much experience with PyTorch yet.
