Why aren't the key and query linear layers in vision transformers merged into one matrix?

Posted 2025-01-26 16:52:52


I was studying some vision transformer code (e.g. vit-pytorch) and found this in the attention module:

# x is the input
key = nn.Linear(..., bias=False)(x)
query = nn.Linear(..., bias=False)(x)
similar_matrix = torch.matmul(query, key.transpose(...))

Since a Linear layer with no bias is just a matrix multiplication, I get

key = K^T @ x
query = Q^T @ x
similar_matrix = query^T @ key = x^T @ (Q @ K^T) @ x

(K and Q are the learnable matrices, @ is matmul, ^T is transpose)

Here Q @ K^T appears as a single product, so I think the two layers could be merged into one matrix to reduce the number of parameters and the amount of computation.

Why isn't this done? Is it because it trains poorly?


Comments (1)

口干舌燥 2025-02-02 16:52:52


Let's clear a few things up.

Since bias=False, your premise is correct: each layer has only a weight, and its bias is None. Let's quickly check:

import torch
import torch.nn as nn

m = nn.Linear(3, 4, bias=False)
n = nn.Linear(3, 4, bias=False)

if hasattr(m, 'weight'):
    print(m.weight)  # prints the (4, 3) weight matrix
else:
    print("doesn't exist")
if hasattr(m, 'bias'):
    print(m.bias)  # prints None: the attribute exists but holds no bias
else:
    print("doesn't exist")

Now let's run a snippet that follows your idea:

x = torch.randint(1, 5, (2, 3)).float()
m = nn.Linear(3, 4, bias=False)
n = nn.Linear(3, 4, bias=False)

K = m.weight
Q = n.weight

k = m(x)  # same as x @ K.T
q = n(x)  # same as x @ Q.T
print(q @ k.T)
print(x @ Q.T @ K @ x.T)


# output (your exact values will differ with random initialization, but the two prints always match)
tensor([[-4.6655, -5.2234],
        [-6.8665, -7.6535]], grad_fn=<MmBackward0>)
tensor([[-4.6655, -5.2234],
        [-6.8665, -7.6535]], grad_fn=<MmBackward0>)

So the two expressions are equivalent (confirmed).
But the number of parameters is exactly the same as in the original approach, since we are still storing the same weight matrices K and Q.
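
To make that concrete, here is a quick count (a minimal sketch reusing the toy m and n above; the merged matrix is hypothetical, not part of the original code):

# The two-projection form stores K and Q separately: 2 * (4 * 3) = 24 weights.
two_proj = sum(p.numel() for p in m.parameters()) + sum(p.numel() for p in n.parameters())
print(two_proj)  # 24

# A single precomputed matrix M = Q.T @ K would instead have
# in_features * in_features entries, so the count only stays unchanged
# while K and Q are kept as separate weights.
merged = n.weight.T @ m.weight  # shape (3, 3)
print(merged.numel())  # 9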

Which is faster?

timeit("m(x) @  n(x).T", globals())
timeit("[email protected] @  K @x.T", globals())

 #Output
1000 loops, best of 3: 50.1 usec per loop
1000 loops, best of 3: 31.1 usec per loop

So the first form takes noticeably more time than the second. The same holds if we increase the sizes:

x = torch.randint(1, 5, (2000, 3000)).float()
m = nn.Linear(3000, 4000, bias=False)
n = nn.Linear(3000, 4000, bias=False)

K = m.weight
Q = n.weight

timeit("m(x) @ n(x).T", globals=globals())
timeit("x @ Q.T @ K @ x.T", globals=globals())

# Output
1 loops, best of 3: 2.93 sec per loop
1 loops, best of 3: 2.73 sec per loop

Here as well the second form is measurably faster, so ideally one could use it for a performance gain, just as you guessed.
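
One plausible reason for the gap (my own back-of-the-envelope addition, not a profile): the two forms perform a different number of multiply-accumulate operations at these shapes, because matrix-product cost depends on the association order. A rough count:

# Rough multiply-accumulate (MAC) counts for the large example:
# x: (N, D), K and Q: (M, D), with N=2000, D=3000, M=4000.
N, D, M = 2000, 3000, 4000

# Form 1: k = x @ K.T, q = x @ Q.T, then q @ k.T
form1 = 2 * (N * D * M) + N * M * N

# Form 2: ((x @ Q.T) @ K) @ x.T, evaluated left to right
form2 = N * D * M + N * M * D + N * D * N

print(form1, form2)  # 64e9 vs 60e9 MACs; form 2 is slightly cheaper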

But the main reason they don't do this comes down to keeping the code open to future changes where a bias is used; in that case the first version stays far more readable and conforms to existing conventions. In a small case like this the merged form might be fine, but in larger models it can get quite ugly. For better readability and future flexibility, I believe they chose the first one.
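
As an aside, a different kind of merging is common in real implementations, vit-pytorch included if I recall correctly: the query, key, and value projections are fused into one nn.Linear and then split. That preserves the math exactly (it just batches three matmuls into one call). A minimal sketch with illustrative dimensions:

import torch
import torch.nn as nn

dim, inner_dim = 64, 64                             # illustrative sizes
to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)  # one fused projection

x = torch.randn(2, 16, dim)              # (batch, tokens, dim)
q, k, v = to_qkv(x).chunk(3, dim=-1)     # split back into q, k, v
attn = (q @ k.transpose(-1, -2)) * (inner_dim ** -0.5)
attn = attn.softmax(dim=-1)
out = attn @ v
print(out.shape)                         # torch.Size([2, 16, 64])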
