Using model.generate() from Transformers - TypeError: forward() got an unexpected keyword argument 'return_dict'
I am trying to perform inference with a fine-tuned GPT2HeadWithValueModel from the Transformers library. I'm using the model.generate() method from generation_utils.py.
I am using this function to call the generate() method:
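```python
def top_p_sampling(text, model, tokenizer):
    # Encode the prompt as a batch of token ids (PyTorch tensor).
    encoding = tokenizer(text, return_tensors="pt")['input_ids']
    # max_len is defined elsewhere in the notebook (not shown in the question).
    output_tensor = model.generate(
        encoding,
        do_sample=True,
        max_length=max_len,
        top_k=50,
        top_p=.92,
        temperature=.9,
        early_stopping=False)
    # Decode the first (and only) generated sequence back to text.
    return tokenizer.decode(output_tensor[0], skip_special_tokens=True).strip()
```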
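The temperature, top_k and top_p arguments above control how each next token is drawn when do_sample=True. The following is a simplified pure-Python sketch of that filtering, assuming the order temperature, then top-k, then top-p (nucleus); the function names are hypothetical, and this is an illustration of the technique, not the Transformers implementation (which works on tensors).

```python
import math
import random


def filter_top_k_top_p(logits, top_k=50, top_p=0.92, temperature=0.9):
    """Return {token_id: probability} after temperature, top-k and top-p filtering.

    Hypothetical helper; defaults mirror the values used in the question.
    temperature must be > 0; top_k <= 0 disables the top-k step.
    """
    # Temperature < 1 sharpens the distribution, > 1 flattens it.
    scaled = [x / temperature for x in logits]

    # Rank token ids by scaled logit, highest first, and keep at most top_k.
    ranked = sorted(range(len(scaled)), key=lambda i: scaled[i], reverse=True)
    if top_k > 0:
        ranked = ranked[:top_k]

    # Numerically stable softmax over the surviving tokens only.
    m = scaled[ranked[0]]
    exps = {i: math.exp(scaled[i] - m) for i in ranked}
    total = sum(exps.values())
    probs = {i: e / total for i, e in exps.items()}

    # Nucleus step: keep the smallest prefix whose cumulative mass reaches top_p.
    kept, cumulative = [], 0.0
    for i in ranked:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break

    # Renormalise so the kept probabilities sum to 1.
    mass = sum(probs[i] for i in kept)
    return {i: probs[i] / mass for i in kept}


def sample_next_token(logits, rng, **kwargs):
    """Draw one token id from the filtered distribution using `rng`."""
    dist = filter_top_k_top_p(logits, **kwargs)
    ids = list(dist)
    return rng.choices(ids, weights=[dist[i] for i in ids], k=1)[0]
```

Repeating this step (append the sampled token, recompute logits, sample again until max_length) is roughly what a sampling loop does; the real library additionally handles batching, caching and stopping criteria.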
But when I try:
```python
text = "this is an example of input text"
comp = top_p_sampling(text, model_name, tokenizer_name)
```
I get the following error:
```
TypeError: forward() got an unexpected keyword argument 'return_dict'
```
Full traceback:
```
TypeError Traceback (most recent call last)
<ipython-input-24-cc7c3f8aa367> in <module>()
1 text = "this is an example of input text"
----> 2 comp = top_p_sampling(text, model_name, tokenizer_name)
4 frames
<ipython-input-23-a5241487f309> in top_p_sampling(text, model, tokenizer)
9 temperature=temp,
10 early_stopping=False,
---> 11 return_dict=False)
12
13 return tokenizer.decode(output_tensor[0], skip_special_tokens=True).strip()
/usr/local/lib/python3.7/dist-packages/torch/autograd/grad_mode.py in decorate_context(*args, **kwargs)
26 def decorate_context(*args, **kwargs):
27 with self.__class__():
---> 28 return func(*args, **kwargs)
29 return cast(F, decorate_context)
30
/usr/local/lib/python3.7/dist-packages/transformers/generation_utils.py in generate(self, input_ids, max_length, min_length, do_sample, early_stopping, num_beams, temperature, top_k, top_p, repetition_penalty, bad_words_ids, bos_token_id, pad_token_id, eos_token_id, length_penalty, no_repeat_ngram_size, encoder_no_repeat_ngram_size, num_return_sequences, decoder_start_token_id, use_cache, num_beam_groups, diversity_penalty, prefix_allowed_tokens_fn, output_attentions, output_hidden_states, output_scores, return_dict_in_generate, **model_kwargs)
938 output_scores=output_scores,
939 return_dict_in_generate=return_dict_in_generate,
--> 940 **model_kwargs,
941 )
942
/usr/local/lib/python3.7/dist-packages/transformers/generation_utils.py in sample(self, input_ids, logits_processor, logits_warper, max_length, pad_token_id, eos_token_id, output_attentions, output_hidden_states, output_scores, return_dict_in_generate, **model_kwargs)
1383 return_dict=True,
1384 output_attentions=output_attentions,
-> 1385 output_hidden_states=output_hidden_states,
1386 )
1387 next_token_logits = outputs.logits[:, -1, :]
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1101 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102 return forward_call(*input, **kwargs)
1103 # Do not call functions when jit is used
1104 full_backward_hooks, non_full_backward_hooks = [], []
TypeError: forward() got an unexpected keyword argument 'return_dict'
```
I'm a bit of a rookie, so I hope someone can point out what I'm doing wrong. Thanks a lot!
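Reading the traceback, generate() dispatches to sample(), which calls the model with return_dict=True (plus output_attentions and output_hidden_states), and module.py then forwards those keywords straight to forward(). The TypeError therefore suggests that this model's forward() does not declare a return_dict parameter. The pure-Python sketch below reproduces that mechanism with a hypothetical LegacyModel class (not the real GPT2HeadWithValueModel) and shows how inspect.signature can be used to check whether a forward() accepts a given keyword before calling generate().

```python
import inspect


class LegacyModel:
    """Hypothetical stand-in for a model whose forward() predates `return_dict`."""

    def forward(self, input_ids, past=None, attention_mask=None):
        # Returns a plain tuple, so there is no `.logits` attribute either.
        return ("logits", "past", "value")

    def __call__(self, *args, **kwargs):
        # Like torch.nn.Module.__call__, delegate straight to forward().
        return self.forward(*args, **kwargs)


def accepts_kwarg(fn, name):
    """True if `fn` declares a parameter called `name` or takes **kwargs."""
    params = inspect.signature(fn).parameters.values()
    return any(p.name == name or p.kind is inspect.Parameter.VAR_KEYWORD
               for p in params)


def call_like_sample(model, input_ids):
    # Mirrors the keyword call made inside sample() in the traceback above.
    return model(input_ids, return_dict=True,
                 output_attentions=False, output_hidden_states=False)


model = LegacyModel()
print(accepts_kwarg(model.forward, "return_dict"))
try:
    call_like_sample(model, [[1, 2, 3]])
except TypeError as err:
    # The message names the first rejected keyword, 'return_dict'.
    print(err)
```

With a real PyTorch model, the same accepts_kwarg check could be applied to model.forward; if it returns False, generate() in this Transformers version will fail in the same way regardless of the sampling arguments passed.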