Use **kwargs in generate_chat_prompt

This commit is contained in:
oobabooga
2023-04-05 21:38:49 -03:00
parent cf239c1232
commit 97e8ea219b
2 changed files with 10 additions and 5 deletions

View File

@@ -18,7 +18,12 @@ from modules.text_generation import (encode, generate_reply,
get_max_prompt_length)
def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat_prompt_size, is_instruct, end_of_turn="", impersonate=False, also_return_rows=False):
def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat_prompt_size, **kwargs):
is_instruct = kwargs['is_instruct'] if 'is_instruct' in kwargs else False
end_of_turn = kwargs['end_of_turn'] if 'end_of_turn' in kwargs else ''
impersonate = kwargs['impersonate'] if 'impersonate' in kwargs else False
also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False
user_input = fix_newlines(user_input)
rows = [f"{context.strip()}\n"]
@@ -112,11 +117,11 @@ def chatbot_wrapper(text, max_new_tokens, generation_params, seed, name1, name2,
visible_text = text
text = apply_extensions(text, "input")
is_instruct = mode == 'instruct'
kwargs = {'end_of_turn': end_of_turn, 'is_instruct': mode == 'instruct'}
if custom_generate_chat_prompt is None:
prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size, is_instruct, end_of_turn=end_of_turn)
prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size, **kwargs)
else:
prompt = custom_generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size, is_instruct, end_of_turn=end_of_turn)
prompt = custom_generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size, **kwargs)
# Yield *Is typing...*
if not regenerate: