How do I write a generation function for a text PyTorch transformer?

Following this PyTorch tutorial, I'm able to create and train a transformer model on a custom dataset. The problem is, I've scoured the web and found no clear answer: how do I use this model to generate text? I took a stab at it by encoding my SOS token and seed text and passing them through the model's forward method, but this only produces repeating garbage. The src_mask doesn't appear to be the right size, or isn't functioning at all.

def generate(model: nn.Module, src_text: str):
    src = BeatleSet.encode(src_text.lower())  # encode the seed text
    SOS = BeatleSet.textTokDict['<sos>']  # obtain the sos token
    EOS = BeatleSet.textTokDict['<eos>']  # obtain the eos token
    model.eval()
    entry = [SOS] + src
    y_input = torch.tensor([entry], dtype=torch.long, device=device)  # convert entry to a tensor
    num_tokens = len(BeatleSet)
    for i in range(50):
        src_mask = generate_square_subsequent_mask(y_input.size(0)).to(device)  # creates a mask of size 1,1 (???)
        pred = model(y_input, src_mask)  # passed through the forward method
        next_item = pred.topk(1)[1].view(-1)[-1].item()  # selecting the most probable next token (I think)
        next_item = torch.tensor([[next_item]], device=device)
        y_input = torch.cat((y_input, next_item), dim=1)  # appended to the input to be run again
        if next_item.view(-1).item() == EOS:
            break
    return " ".join(BeatleSet.decode(y_input.view(-1).tolist()))

print(generate(model, "Oh yeah I"))

For the record, I'm following the tutorial's architecture to the letter, so this should be reproducible with the WikiText-2 dataset used there. Please advise; I've been banging my head against this one.
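
Is something like the following the right direction? This is only a greedy-decoding sketch of what I suspect the fix might look like, assuming the tutorial's model expects input of shape [seq_len, batch] and returns logits of shape [seq_len, batch, vocab_size]. Under that convention, the y_input above is [1, seq_len], so y_input.size(0) is always 1: the mask stays 1x1 and the model would treat the prompt as a batch of one-token sequences. BeatleSet, generate_square_subsequent_mask, model, and device are the same names as in my code above.

import torch
import torch.nn as nn

def generate(model: nn.Module, src_text: str, max_new_tokens: int = 50) -> str:
    # Greedy decoding sketch; assumes the tutorial's [seq_len, batch] input layout.
    model.eval()
    SOS = BeatleSet.textTokDict['<sos>']
    EOS = BeatleSet.textTokDict['<eos>']
    tokens = [SOS] + BeatleSet.encode(src_text.lower())
    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Shape [seq_len, 1]: sequence on dim 0, a batch of one on dim 1.
            y_input = torch.tensor(tokens, dtype=torch.long, device=device).unsqueeze(1)
            # Causal mask sized to the sequence dimension, so it grows each step.
            src_mask = generate_square_subsequent_mask(y_input.size(0)).to(device)
            logits = model(y_input, src_mask)            # [seq_len, 1, vocab_size]
            next_token = logits[-1, 0].argmax().item()   # most probable next token
            tokens.append(next_token)
            if next_token == EOS:
                break
    return " ".join(BeatleSet.decode(tokens))

print(generate(model, "Oh yeah I"))

The only real differences from my attempt are that the sequence lives on dim 0 (so the causal mask is sized by the sequence length, not the batch size) and the next token is read from the last position of the sequence dimension.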
