From 9c3a585915abbbdb97540f290e622d5ec5c9d480 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Thu, 6 Apr 2023 00:31:58 -0300 Subject: [PATCH] Create new API --- modules/api.py | 38 ++++++++++++++++++++++++++++++++++++++ server.py | 4 +++- 2 files changed, 41 insertions(+), 1 deletion(-) create mode 100644 modules/api.py diff --git a/modules/api.py b/modules/api.py new file mode 100644 index 0000000..26249fd --- /dev/null +++ b/modules/api.py @@ -0,0 +1,38 @@ +import json + +import gradio as gr + +from modules import shared +from modules.text_generation import generate_reply + + +def generate_reply_wrapper(string): + generate_params = { + 'do_sample': True, + 'temperature': 1, + 'top_p': 1, + 'typical_p': 1, + 'repetition_penalty': 1, + 'encoder_repetition_penalty': 1, + 'top_k': 50, + 'num_beams': 1, + 'penalty_alpha': 0, + 'min_length': 0, + 'length_penalty': 1, + 'no_repeat_ngram_size': 0, + 'early_stopping': False, + } + params = json.loads(string) + for k in params[1]: + generate_params[k] = params[1][k] + for i in generate_reply(params[0], generate_params): + yield i + +def create_apis(): + t1 = gr.Textbox(visible=False) + t2 = gr.Textbox(visible=False) + dummy = gr.Button(visible=False) + + input_params = [t1] + output_params = [t2] + [shared.gradio[k] for k in ['markdown', 'html']] + dummy.click(generate_reply_wrapper, input_params, output_params, api_name='textgen') diff --git a/server.py b/server.py index 14ff577..0f66860 100644 --- a/server.py +++ b/server.py @@ -15,7 +15,7 @@ import gradio as gr from PIL import Image import modules.extensions as extensions_module -from modules import chat, shared, training, ui +from modules import chat, shared, training, ui, api from modules.html_generator import chat_html_wrapper from modules.LoRA import add_lora_to_model from modules.models import load_model, load_soft_prompt @@ -538,6 +538,8 @@ def create_interface(): else: shared.gradio[k].release(lambda state, value, copy=k: change_dict_value(state, copy, value), inputs=[shared.gradio['generate_state'], shared.gradio[k]], outputs=shared.gradio['generate_state']) + api.create_apis() + # Authentication auth = None if shared.args.gradio_auth_path is not None: