Generator `max_length` not respected in query()

Posted on 2025-01-11 20:59:56

Goal: set min_length and max_length in a Hugging Face Transformers generator query.

I've passed 50 and 200 as these parameters, yet my outputs come back much longer than that.

There is no runtime failure.

from transformers import pipeline, set_seed

generator = pipeline('text-generation', model='gpt2')
set_seed(42)

def query(payload, multiple, min_char_len, max_char_len):
    print(min_char_len, max_char_len)
    # passing the limits as if min_length / max_length counted characters
    list_dict = generator(payload, min_length=min_char_len, max_length=max_char_len,
                          num_return_sequences=multiple)
    # strip the prompt from each generated sequence
    test = [d['generated_text'].split(payload)[1].strip() for d in list_dict]
    for t in test:
        print(len(t))  # character length of each continuation
    return test

query('example', 1, 50, 200)

Output:

50 200
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
1015

Comments (1)

吃兔兔 2025-01-18 20:59:56

Explanation:

As explained by Narsil in a Hugging Face 🤗 Transformers GitHub issue response:

Models don't ingest the text one character at a time, but one token at
a time. There are different algorithms to achieve this, but basically
"My name is Nicolas" gets transformed into ["my", " name", " is",
" nic", "olas"] for instance, and each of those tokens has a number.

So when you are generating tokens, they can themselves contain 1 or
more characters (usually several, and almost any common word, for
instance). That's why you are seeing 1015 instead of your expected 200
(the tokens here have an average of 5 chars).
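
To see this token-versus-character gap directly, here is a minimal sketch (the model name gpt2 matches the pipeline above; the exact split is determined by GPT-2's BPE vocabulary, so treat the printed tokens as illustrative):

from transformers import AutoTokenizer

# load the same vocabulary the gpt2 pipeline uses
tokenizer = AutoTokenizer.from_pretrained('gpt2')

text = "My name is Nicolas"
tokens = tokenizer.tokenize(text)
print(tokens)       # a short list of sub-word tokens
print(len(tokens))  # far fewer tokens than characters...
print(len(text))    # ...which is why max_length and len() disagree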

Solution:

Here is how I resolved it:

Rename min_char_len and max_char_len to min_tokens and max_tokens, and
scale the values down to roughly 1/4 or 1/5 of the intended character
counts (since max_length counts tokens, and a token averages about 4-5
characters here).
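
A minimal sketch of that adjustment, assuming the renamed parameters are passed straight through to min_length / max_length as token counts (the example values are illustrative):

from transformers import pipeline, set_seed

generator = pipeline('text-generation', model='gpt2')
set_seed(42)

def query(payload, multiple, min_tokens, max_tokens):
    # min_length / max_length bound the number of tokens, not characters
    list_dict = generator(payload, min_length=min_tokens, max_length=max_tokens,
                          num_return_sequences=multiple)
    test = [d['generated_text'].split(payload)[1].strip() for d in list_dict]
    for t in test:
        print(len(t))  # character length will be roughly 4-5x the token count
    return test

# aiming for outputs on the order of 200 characters: ask for roughly 200 / 5 = 40 tokens
query('example', 1, 10, 40)

Depending on the transformers version, max_new_tokens (which counts only the newly generated tokens, excluding the prompt) may be a cleaner way to bound the continuation length than max_length.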
