From 849a54ef2dfd0c64342c2f1243aded9c337b1163 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Wed, 5 Apr 2023 22:53:21 -0300 Subject: [PATCH] Remove variables --- modules/text_generation.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/modules/text_generation.py b/modules/text_generation.py index 24f2b69..ce912ee 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -103,12 +103,8 @@ def stop_everything_event(): shared.stop_everything = True def generate_reply(question, generate_params, eos_token=None, stopping_strings=[]): - max_new_tokens = generate_params['max_new_tokens'] - seed = generate_params['seed'] - print(generate_params) - print('---------------') clear_torch_cache() - set_manual_seed(seed) + set_manual_seed(generate_params['seed']) shared.stop_everything = False updated_params = {} t0 = time.time() @@ -155,7 +151,7 @@ def generate_reply(question, generate_params, eos_token=None, stopping_strings=[ print(f"Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens})") return - input_ids = encode(question, max_new_tokens) + input_ids = encode(question, generate_params['max_new_tokens']) original_input_ids = input_ids output = input_ids[0] @@ -168,7 +164,7 @@ def generate_reply(question, generate_params, eos_token=None, stopping_strings=[ t = [encode(string, 0, add_special_tokens=False) for string in stopping_strings] stopping_criteria_list.append(_SentinelTokenStoppingCriteria(sentinel_token_ids=t, starting_idx=len(input_ids[0]))) - updated_params["max_new_tokens"] = max_new_tokens + updated_params["max_new_tokens"] = generate_params['max_new_tokens'] if not shared.args.flexgen: updated_params["eos_token_id"] = eos_token_ids updated_params["stopping_criteria"] = stopping_criteria_list @@ -244,7 +240,7 @@ def generate_reply(question, generate_params, eos_token=None, stopping_strings=[ # Stream the output naively for FlexGen since it doesn't support 'stopping_criteria' else: - for i in range(max_new_tokens//8+1): + for i in range(generate_params['max_new_tokens']//8+1): clear_torch_cache() with torch.no_grad(): output = shared.model.generate(**updated_params)[0]