如何从Pytorch的张量中弹出元素?

发布于 2025-01-20 21:17:39 字数 1296 浏览 4 评论 0原文

我想从Pytorch的张量中删除/流行元素,类似于Python的Pop Operation。在以下代码中,如果满足条件,它将从数组,当前和下一个元素中删除两个元素。 我有一个相应的pytorch张量。这意味着,如果数组的长度为10,我的张量last_hidden_​​state size (1,10,768)。在获得元素的平均值last_hidded_state [:,index-1,:],last_hidded_state [:,index,:]和last_hidden_​​state [:index+1,index+1,:]我想删除> last_hidden_​​state [:,index,:]和last_hidden_​​state [:,index+1,:]就像弹出数组中的当前和下一个元素一样。这意味着我应该得到大小(1,8,768)的张量/代码>。 我做错了什么?我是Pytorch Tensors的新手,谢谢

def function_merge (prev_el, curr_el, next_el,index, array):
    if(curr_el.startswith('##') and next_el.startswith('##')):
         array[index-1] =  prev_el + curr_el + next_el
         array.pop(index) #remove current element
         array.pop(index) #remove next element

         last_hidden_state[:,index-1,:] = torch.add(last_hidden_state[:,index-1,:],last_hidden_state[:,index,:])
         last_hidden_state[:,index-1,:] = torch.add(last_hidden_state[:,index-1,:],last_hidden_state[:,index+1,:])
         last_hidden_state[:,index-1,:] = torch.mean(last_hidden_state[:,index-1,:])
            
         last_hidden_state = torch.cat((last_hidden_state[:,:index,:],last_hidden_state[:,index+2:,:]), axis=1)
         return array, last_hidden_state
               

I want to drop/pop elements from a tensor in Pytorch, something similar to pop operation in python. In the following code , if the condition is met, it removes two elements from the array, current and the next.
I have a corresponding pytorch tensor. That means if the length of array is 10, I have the tensor last_hidden_state of size (1,10,768). After taking the mean of elements last_hidden_state[:,index-1,:], last_hidden_state[:,index,:] and last_hidden_state[:,index+1,:] I want to drop last_hidden_state[:,index,:] and last_hidden_state[:,index+1,:] Just like popping the current and next element from array. That means I should get a tensor of size (1,8,768) but with this code sometimes it returns (1,7,768) or (1,6,768).
What is it that I am doing wrong? I am new to Pytorch tensors, thank you

def function_merge (prev_el, curr_el, next_el,index, array):
    if(curr_el.startswith('##') and next_el.startswith('##')):
         array[index-1] =  prev_el + curr_el + next_el
         array.pop(index) #remove current element
         array.pop(index) #remove next element

         last_hidden_state[:,index-1,:] = torch.add(last_hidden_state[:,index-1,:],last_hidden_state[:,index,:])
         last_hidden_state[:,index-1,:] = torch.add(last_hidden_state[:,index-1,:],last_hidden_state[:,index+1,:])
         last_hidden_state[:,index-1,:] = torch.mean(last_hidden_state[:,index-1,:])
            
         last_hidden_state = torch.cat((last_hidden_state[:,:index,:],last_hidden_state[:,index+2:,:]), axis=1)
         return array, last_hidden_state
               

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。
列表为空,暂无数据
我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
原文