如何从Pytorch的张量中弹出元素?
我想从Pytorch的张量中删除/流行元素,类似于Python的Pop Operation。在以下代码中,如果满足条件,它将从数组,当前和下一个元素中删除两个元素。 我有一个相应的pytorch张量。这意味着,如果数组的长度为10
,我的张量last_hidden_state size (1,10,768)
。在获得元素的平均值last_hidded_state [:,index-1,:],last_hidded_state [:,index,:]和last_hidden_state [:index+1,index+1,:]
我想删除> last_hidden_state [:,index,:]和last_hidden_state [:,index+1,:]
就像弹出数组中的当前和下一个元素一样。这意味着我应该得到大小(1,8,768)
的张量/代码>。 我做错了什么?我是Pytorch Tensors的新手,谢谢
def function_merge (prev_el, curr_el, next_el,index, array):
if(curr_el.startswith('##') and next_el.startswith('##')):
array[index-1] = prev_el + curr_el + next_el
array.pop(index) #remove current element
array.pop(index) #remove next element
last_hidden_state[:,index-1,:] = torch.add(last_hidden_state[:,index-1,:],last_hidden_state[:,index,:])
last_hidden_state[:,index-1,:] = torch.add(last_hidden_state[:,index-1,:],last_hidden_state[:,index+1,:])
last_hidden_state[:,index-1,:] = torch.mean(last_hidden_state[:,index-1,:])
last_hidden_state = torch.cat((last_hidden_state[:,:index,:],last_hidden_state[:,index+2:,:]), axis=1)
return array, last_hidden_state
I want to drop/pop elements from a tensor in Pytorch, something similar to pop operation in python. In the following code , if the condition is met, it removes two elements from the array, current and the next.
I have a corresponding pytorch tensor. That means if the length of array is 10
, I have the tensor last_hidden_state of size (1,10,768)
. After taking the mean of elements last_hidden_state[:,index-1,:], last_hidden_state[:,index,:] and last_hidden_state[:,index+1,:]
I want to drop last_hidden_state[:,index,:] and last_hidden_state[:,index+1,:]
Just like popping the current and next element from array. That means I should get a tensor of size (1,8,768)
but with this code sometimes it returns (1,7,768)
or (1,6,768)
.
What is it that I am doing wrong? I am new to Pytorch tensors, thank you
def function_merge (prev_el, curr_el, next_el,index, array):
if(curr_el.startswith('##') and next_el.startswith('##')):
array[index-1] = prev_el + curr_el + next_el
array.pop(index) #remove current element
array.pop(index) #remove next element
last_hidden_state[:,index-1,:] = torch.add(last_hidden_state[:,index-1,:],last_hidden_state[:,index,:])
last_hidden_state[:,index-1,:] = torch.add(last_hidden_state[:,index-1,:],last_hidden_state[:,index+1,:])
last_hidden_state[:,index-1,:] = torch.mean(last_hidden_state[:,index-1,:])
last_hidden_state = torch.cat((last_hidden_state[:,:index,:],last_hidden_state[:,index+2:,:]), axis=1)
return array, last_hidden_state
如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

绑定邮箱获取回复消息
由于您还没有绑定你的真实邮箱,如果其他用户或者作者回复了您的评论,将不能在第一时间通知您!
发布评论