How do I write a generation function for a text PyTorch transformer?
Following this PyTorch tutorial, I'm able to create and train a transformer model on a custom dataset. The problem is, I've scoured the web and found no clear answer: how do I use this model to generate text? I took a stab at it by encoding my SOS token and seed text and passing them through the model's forward method, but this only produces repeating garbage. The src_mask doesn't appear to be the right size, or to be functioning at all.
def generate(model: nn.Module, src_text: str):
    src = BeatleSet.encode(src_text.lower())  # encodes seed text
    SOS = BeatleSet.textTokDict['<sos>']; EOS = BeatleSet.textTokDict['<eos>']  # obtain sos and eos tokens
    model.eval(); entry = [SOS] + src
    y_input = torch.tensor([entry], dtype=torch.long, device=device)  # convert entry to tensor
    num_tokens = len(BeatleSet)
    for i in range(50):
        src_mask = generate_square_subsequent_mask(y_input.size(0)).to(device)  # creates a mask of size 1x1 (???)
        pred = model(y_input, src_mask)  # passed through forward method
        next_item = pred.topk(1)[1].view(-1)[-1].item()  # selecting the most probable next token (I think)
        next_item = torch.tensor([[next_item]], device=device)
        y_input = torch.cat((y_input, next_item), dim=1)  # appended to the inputs to be run again
        if next_item.view(-1).item() == EOS:
            break
    return " ".join(BeatleSet.decode(y_input.view(-1).tolist()))

print(generate(model, "Oh yeah I"))
For the record, I'm following the tutorial's architecture to the letter. This should be reproducible with the WikiText-2 dataset used in the tutorial. Please advise; I've been banging my head against this one.
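For what it's worth, the mask-size suspicion can be checked in isolation. Below is a minimal, self-contained sketch (the generate_square_subsequent_mask helper is re-implemented here, matching the tutorial's definition, so the snippet runs standalone). With a batch-first y_input of shape (1, seq_len), y_input.size(0) is the batch dimension, so the mask is always 1x1; the causal mask must be seq_len x seq_len. Note too that the tutorial's model expects (seq_len, batch) input rather than (batch, seq_len), so under that convention the next token should be read from the logits at the last sequence position, pred[-1].

```python
import torch

def generate_square_subsequent_mask(sz: int) -> torch.Tensor:
    # Same idea as the tutorial's helper: an upper-triangular mask of -inf
    # above the diagonal, 0 elsewhere, blocking attention to future positions.
    return torch.triu(torch.full((sz, sz), float('-inf')), diagonal=1)

# A batch-first tensor: batch=1, seq_len=4 (token ids are arbitrary here).
y_input = torch.tensor([[2, 5, 7, 9]])

bad_mask = generate_square_subsequent_mask(y_input.size(0))   # size(0) is batch -> 1x1
good_mask = generate_square_subsequent_mask(y_input.size(1))  # size(1) is seq_len -> 4x4

print(bad_mask.shape)   # torch.Size([1, 1])
print(good_mask.shape)  # torch.Size([4, 4])
```

If the model was built exactly as in the tutorial, the fix would be to feed it a (seq_len, 1) tensor, build the mask from the sequence length, and take the argmax of the final position's logits each step.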