用张量流遮盖多螺纹遮罩

发布于 2025-02-01 23:22:24 字数 902 浏览 2 评论 0原文

我一直在尝试为我的多键层的查询和钥匙的靶向组合做一个自定义掩码,但无法找到使用此层掩蔽的方法。

这是一个带虚拟数据集(批次1)的示例:

key     = tf.ones([1, 32 , 128])
mask    = tf.concat([
    tf.concat([tf.zeros([16 , 16]) , tf.zeros([16 , 16]) ] , 0) ,
    tf.concat([tf.zeros([16 , 16]) , tf.ones([16 , 16])  ] , 0) ] , 1)
mask    = mask[tf.newaxis, tf.newaxis, : , : ]


# key shape  -> ( 1 , 32 , 128 )
# mask shape -> ( 1 , 1,  32 , 32 )

当我打印bask [0] [0] [0] .numpy()我获取:

​头部,自我注意力):

mha_layer =  tf.keras.layers.MultiHeadAttention( num_heads=1, key_dim=128 )
attention_output, attention_scores = mha_layer(  key , key , attention_mask=mask  ,  return_attention_scores=True)

我得到了注意的注意分数(activation_scores [0] [0] .numpy()):

​黄色为0.06,绿色为0.03,

我期望由于掩盖遮罩,绿色蓝色部分将为0.0。

我是否使用掩盖错误?还是不可能掩盖整个查询/密钥?

我希望我的问题有意义

I have been trying to make a custom mask for targetted combinations of queries and keys for my MultiHeadAttention layer but can not figure out the way to use this layer masking.

Here is an example with a dummy dataset (batch size 1) :

key     = tf.ones([1, 32 , 128])
mask    = tf.concat([
    tf.concat([tf.zeros([16 , 16]) , tf.zeros([16 , 16]) ] , 0) ,
    tf.concat([tf.zeros([16 , 16]) , tf.ones([16 , 16])  ] , 0) ] , 1)
mask    = mask[tf.newaxis, tf.newaxis, : , : ]


# key shape  -> ( 1 , 32 , 128 )
# mask shape -> ( 1 , 1,  32 , 32 )

when I print mask[0][0].numpy() I get :

masking

Now using the foolowing layer ( 1 head , self-attention ) :

mha_layer =  tf.keras.layers.MultiHeadAttention( num_heads=1, key_dim=128 )
attention_output, attention_scores = mha_layer(  key , key , attention_mask=mask  ,  return_attention_scores=True)

I get the folowing attention scores (attention_scores[0][0].numpy()) :

attention scores

Here the dark-violet color stands for 0.0 , yellow for 0.06 and green-blue for 0.03

I would expect to have expected the green-blue part to be 0.0s because of the masking.

Am I using the masking wrong ? or it is not possible to mask entire queries/keys ?

I hope my question makes sense ???? and that it is not too obvious.
Thank you in advance, if you can help :)

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。
列表为空,暂无数据
我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
原文